$\text{Splay}$ 是一种二叉查找树,它通过不断将某个节点旋转到根节点,使得整棵树仍然满足二叉查找树的性质,并且保持平衡而不至于退化为链。
例题
我们以「LOJ 105」文艺平衡树 作为例题:你需要维护一个有序数列,支持翻转区间操作。
分析
总体思路
我们还是用 $\text{Splay}$ 来维护这个序列,但是不用权值维护,而是用节点在序列中的位置为关键字维护,显然一个子树对应一个区间。
每次提取区间 $[l,r]$ 然后将左右子树全部交换。这正是利用了 $\text{Splay}$ 在旋转过程中不会改变中序遍历。那么原来的左根右在交换后变为右根左,实现了区间翻转。
提取区间
根据 $\text{Splay}$ 的特性(具体介绍详见「算法笔记」Splay - 维护二叉查找树),对于区间 $[l,r]$,我们可以把 $l-1$ 旋转到根,$r+1$ 旋转到根的儿子(显然是右儿子)。那么根的右儿子的左子树就是区间 $[l,r]$。
对于这里的 $l-1$ 和 $r+1$,指的是序列中第 $l-1$ 和第 $r+1$ 个元素对应的节点,而不是下标为 $l-1$ 和 $r+1$ 的节点。因此我们要调用 $\text{Splay}$ 基本操作中的 kth(l - 1)
和 kth(r + 1)
找到对应的节点编号。
交换子树
我们这里要使用一个和线段树很相似的懒标记,我们对于每个节点记录一个 $\text{rev}$ 标记,标记这个区间是否被翻转。每次 kth
操作时将标记下传、交换左右子树、清空自身标记。
时间复杂度:$\mathcal O(n\log n)$
代码
#include <cstdio>
#include <algorithm>
const int INF = 0x3f3f3f3f;
int n, m;
struct Node;
Node *null;
struct Node {
Node *ch[2], *fa;
int val, sz;
bool rev;
Node(int _val = 0) {
ch[0] = ch[1] = fa = null, val = _val, sz = 1;
}
bool get() {
return fa->ch[1] == this;
}
void reverse() {
rev ^= 1, std::swap(ch[0], ch[1]);
}
void pushup() {
sz = ch[0]->sz + ch[1]->sz + 1;
}
void pushdown() {
if (rev) {
ch[0]->reverse();
ch[1]->reverse();
rev = 0;
}
}
};
struct Splay {
Node *rt;
Splay() {
null = new Node();
null->ch[0] = null->ch[1] = null->fa = null;
null->sz = 0;
rt = null;
}
void rotate(Node *x) {
Node *y = x->fa, *z = y->fa;
int k = x->get();
z->ch[y->get()] = x, x->fa = z;
y->ch[k] = x->ch[!k], x->ch[!k]->fa = y;
x->ch[!k] = y, y->fa = x;
y->pushup();
}
void splay(Node *x, Node *g) {
while (x->fa != g) {
Node *y = x->fa;
if (y->fa != g) rotate(x->get() == y->get() ? y : x);
rotate(x);
}
x->pushup();
if (g == null) rt = x;
}
Node *kth(int k) {
k++;
Node *u = rt;
while (1) {
u->pushdown();
if (k <= u->ch[0]->sz) {
u = u->ch[0];
} else if (k > u->ch[0]->sz + 1) {
k -= u->ch[0]->sz + 1;
u = u->ch[1];
} else {
return u;
}
}
}
void insert(int v) {
Node *u = rt, *f = null;
while (u != null && v != u->val) {
f = u, u = u->ch[v > u->val];
}
u = new Node(v);
if (f != null) {
f->ch[v > f->val] = u;
u->fa = f;
}
splay(u, null);
}
void reverse(int l, int r) {
splay(kth(l - 1), null);
splay(kth(r + 1), rt);
rt->ch[1]->ch[0]->reverse();
}
} splay;
int main() {
splay.insert(-INF), splay.insert(INF);
scanf("%d%d", &n, &m);
for (int i = 1; i <= n; i++) {
splay.insert(i);
}
for (int i = 1; i <= m; i++) {
int l, r;
scanf("%d%d", &l, &r);
splay.reverse(l, r);
}
for (int i = 1; i <= n; i++) {
printf("%d%c", splay.kth(i)->val, " \n"[i == n]);
}
return 0;
}