题目链接:SPOJ 2666

给定一棵 $n$ 个节点的数,第 $i$ 条边的边权为 $c_i$,初始状态所有的点都是白色的。接下来要进行 $q$ 次操作,操作问题如下 $2$ 种:

  • C a:反转点 $a$ 的颜色(白色变成黑色,黑色变成白色)。
  • A:询问 $\max\{\text{dist(a, b)}\}$,其中 $a, b$ 都是白点(两个点可以相同)。这意味着,只要树上存在白点,则答案一定是非负整数。如果不存在白点则输出 They have disappeared.

数据范围:$1\le n, q\le 10 ^ 5$,$-10 ^ 3\le c_i \le 10 ^ 3$。


Solution

这道题主要麻烦在需要维护整棵树的信息,而不是一条树链。这意味着我们不但要记录左右儿子的信息,还需要维护虚子树的信息。

我们以点 $x$ 为例,模仿线段树维护最大子段和的形式,定义如下变量:

$\text{set}\ down$$\text{set}\ ians$$lmx$$rmx$$sum$$ans$$len$
以虚子树深度最小的节点向下的链的集合。虚子树的答案的集合。在该 $\text{Splay}$ 维护的链中,从深度最小的节点出发的链的最长长度。在该 $\text{Splay}$ 维护的链中,从深度最大的节点出发的链的最长长度。该 $\text{Splay}$ 维护的链的总长度。该 $\text{Splay}$ 中的答案的最大值。该点代表的边权。

对于每个变量,我们都有一系列复杂的转移,接下来会慢慢分析。

首先考虑虚子树信息的转移,这一部分比较简单:

  • $down,ians$:在 access 操作中,我们需要加入 x->ch[1] 的答案,删除 y 的答案。

接下来考虑 pushup 操作:

为了方便转移,定义一些临时变量:

  • 虚子树内向下的链的最长长度 $t=\max(down)$,如果当前点是白点则需要和 $0$ 取最大值。
  • 从当前点向上、向虚子树的最长链长度 $L = \max(t, rmx_{\text{lson}} + len)$。
  • 从当前点向下、向虚子树的最长链长度 $R = \max(t, lmx_{\text{rson}})$。

有了这些信息,我们就可以进行精彩的转移啦!

这一部分最重要的转移,由于都是分类讨论,具体详见代码。在此提几个关键点:

  1. 某些转移中,需要考虑当前点维护的边权 $len$。
  2. 考虑虚子树组合的情况。
  3. 注意当前点的颜色对转移的影响。

时间复杂度:$\mathcal O((n + q)\log ^ 2 n)$。


Code

#include <cstdio>
#include <algorithm>
#include <set>

const int N = 1e5 + 5, M = 2e5 + 5;
const int INF = 0x3f3f3f3f;

int n, m, tot, lnk[N], ter[M], nxt[M], val[M];

struct Node;
Node *null;

int fir(std::multiset<int> &s) {
    return s.empty() ? -INF : *s.rbegin();
}
int sec(std::multiset<int> &s) {
    return s.size() <= 1 ? -INF : *(++s.rbegin());
}
struct Node {
    Node *ch[2], *fa;
    int len, sum, lmx, rmx, ans;
    bool col;
    std::multiset<int> ians, down;
    Node() {
        ch[0] = ch[1] = fa = null;
        len = sum = lmx = rmx = ans = col = 0;
    }
    bool get() {
        return fa->ch[1] == this;
    }
    bool isroot() {
        return fa->ch[0] != this && fa->ch[1] != this;
    }
    void pushup() {
        sum = ch[0]->sum + ch[1]->sum + len;
        int t = std::max(col ? -INF : 0, fir(down));
        int L = std::max(t, ch[0]->rmx + len);
        int R = std::max(t, ch[1]->lmx);
        lmx = std::max(ch[0]->lmx, ch[0]->sum + len + R);
        rmx = std::max(ch[1]->rmx, ch[1]->sum + L);
        ans = std::max(ch[0]->rmx + len + R, ch[1]->lmx + L);
        ans = std::max(ans, std::max(ch[0]->ans, ch[1]->ans));
        ans = std::max(ans, fir(ians));
        ans = std::max(ans, fir(down) + sec(down));
        if (!col) ans = std::max(ans, std::max(fir(down), 0));
    }
} *a[N];

struct LCT {
    LCT() {
        null = new Node();
        null->ch[0] = null->ch[1] = null->fa = null;
        null->lmx = null->rmx = null->ans = -INF;
    }
    void build(int n) {
        for (int i = 1; i <= n; i++) {
            a[i] = new Node();
        }
    }
    void rotate(Node *x) {
        Node *y = x->fa, *z = y->fa;
        int k = x->get();
        !y->isroot() && (z->ch[y->get()] = x), x->fa = z;
        y->ch[k] = x->ch[!k], x->ch[!k]->fa = y;
        x->ch[!k] = y, y->fa = x;
        y->pushup();
    }
    void splay(Node *x) {
        while (!x->isroot()) {
            Node *y = x->fa;
            if (!y->isroot()) {
                rotate(x->get() == y->get() ? y : x);
            }
            rotate(x);
        }
        x->pushup();
    }
    void access(Node *x) {
        for (Node *y = null; x != null; x = (y = x)->fa) {
            splay(x);
            if (x->ch[1] != null) {
                x->ians.insert(x->ch[1]->ans);
                x->down.insert(x->ch[1]->lmx);
            }
            if ((x->ch[1] = y) != null) {
                x->ians.erase(x->ians.find(y->ans));
                x->down.erase(x->down.find(y->lmx));
            }
            x->pushup();
        }
    }
};

void add(int u, int v, int w) {
    ter[++tot] = v, nxt[tot] = lnk[u], lnk[u] = tot, val[tot] = w;
}
void dfs(int u, int p) {
    for (int i = lnk[u]; i; i = nxt[i]) {
        int v = ter[i];
        if (v == p) continue;
        a[v]->fa = a[u];
        a[v]->len = val[i];
        dfs(v, u);
        a[u]->ians.insert(a[v]->ans);
        a[u]->down.insert(a[v]->lmx);
    }
    a[u]->pushup();
}
int main() {
    scanf("%d", &n);
    LCT T;
    T.build(n);
    for (int i = 1; i < n; i++) {
        int u, v, w;
        scanf("%d%d%d", &u, &v, &w);
        add(u, v, w), add(v, u, w);
    }
    dfs(1, 0);
    int ans = a[1]->ans;
    scanf("%d", &m);
    for (int i = 1; i <= m; i++) {
        char s[5];
        scanf("%s", s + 1);
        if (s[1] == 'C') {
            int x;
            scanf("%d", &x);
            T.access(a[x]), T.splay(a[x]);
            a[x]->col ^= 1;
            a[x]->pushup();
            ans = a[x]->ans;
        } else {
            if (ans < 0) {
                puts("They have disappeared.");
            } else {
                printf("%d\n", ans);
            }
        }
    }
    return 0;
}
最后修改:2019 年 06 月 28 日
如果觉得我的文章对你有用,请随意赞赏