题目链接:SPOJ 2798
给定一棵 $n$ 个节点的树,初始状态每个节点都是白色的。接下来有 $q$ 次操作,操作分为如下 $2$ 种:
0 i
:反转节点 $i$ 的颜色(白色变成黑色,黑色变成白色)。1 v
:询问从节点 $1$ 到 $v$ 的有向路径上第一个黑点。如果没有黑点则输出 $-1$。
数据范围:$1\le n, q\le 10 ^ 5$。
Solution
我们对每个点记一个标记,表示以这个点所在的 $\text{Splay}$ 中,以该点为根的子树中是否有黑点。询问时,我们提取 $1$ 到 $v$ 的路径,并把 $v$ 作为 $\text{Splay}$ 的根。由于询问的点是满足条件的深度最小的点,因此我们尽可能往 $\text{Splay}$ 的左子树走;如果左子树不满足条件则判断当前点;否则向右子树走。
时间复杂度:$\mathcal O((n + q)\log n)$。
--
Code
#include <cstdio>
#include <algorithm>
const int N = 1e5 + 5;
int n, m;
struct Node;
Node *null;
struct Node {
Node *ch[2], *fa;
int idx;
bool col, flg, rev;
Node(int _idx = 0) {
ch[0] = ch[1] = fa = null;
idx = _idx, flg = col = rev = 0;
}
bool get() {
return fa->ch[1] == this;
}
bool isroot() {
return fa->ch[0] != this && fa->ch[1] != this;
}
void reverse() {
rev ^= 1, std::swap(ch[0], ch[1]);
}
void pushup() {
flg = (ch[0]->flg || ch[1]->flg || col);
}
void pushdown() {
if (rev) {
ch[0]->reverse();
ch[1]->reverse();
rev = 0;
}
}
} *a[N];
struct LCT {
LCT() {
null = new Node();
null->ch[0] = null->ch[1] = null->fa = null;
}
void build(int n) {
for (int i = 1; i <= n; i++) {
a[i] = new Node(i);
}
}
void pushtag(Node *x) {
if (!x->isroot()) pushtag(x->fa);
x->pushdown();
}
void rotate(Node *x) {
Node *y = x->fa, *z = y->fa;
int k = x->get();
!y->isroot() && (z->ch[y->get()] = x), x->fa = z;
y->ch[k] = x->ch[!k], x->ch[!k]->fa = y;
x->ch[!k] = y, y->fa = x;
y->pushup();
}
void splay(Node *x) {
pushtag(x);
while (!x->isroot()) {
Node *y = x->fa;
if (!y->isroot()) {
rotate(x->get() == y->get() ? y : x);
}
rotate(x);
}
x->pushup();
}
void access(Node *x) {
for (Node *y = null; x != null; x = (y = x)->fa) {
splay(x), x->ch[1] = y, x->pushup();
}
}
void makeroot(Node *x) {
access(x), splay(x), x->reverse();
}
void split(Node *x, Node *y) {
makeroot(x), access(y), splay(y);
}
void link(Node *x, Node *y) {
makeroot(x);
x->fa = y;
}
int solve(Node *rt) {
Node *u = rt;
while (1) {
u->pushdown();
if (u->ch[0]->flg) {
u = u->ch[0];
} else if(u->col) {
return u->idx;
} else if(u->ch[1]->flg) {
u = u->ch[1];
} else {
return -1;
}
}
}
};
int main() {
scanf("%d%d", &n, &m);
LCT T;
T.build(n);
for (int i = 1; i < n; i++) {
int u, v;
scanf("%d%d", &u, &v);
T.link(a[u], a[v]);
}
for (int i = 1; i <= m; i++) {
int opt, x;
scanf("%d%d", &opt, &x);
if (!opt) {
T.splay(a[x]);
a[x]->col ^= 1;
a[x]->pushup();
} else {
T.split(a[1], a[x]);
printf("%d\n", T.solve(a[x]));
}
}
return 0;
}