题目链接:BZOJ 1500

请写一个程序,要求维护一个数列。一共有 $m$ 个操作,支持以下 $6​$ 种操作:

操作输入格式说明
插入INSERT post tot c[1] c[2] ... c[tot]在当前数列的第 $pos$ 个数字后插入 $tot$ 个数字:$c_1,c_2,\cdots,c_{tot}$;若在数列首插入,则 $pos$ 为 $0$。
删除DELETE pos tot从当前数列的第 $pos$ 个数字开始连续删除 $tot$ 个数字。
修改MAKE-SAME pos tot c将当前数列的第 $pos$ 个数字开始的连续 $tot$ 个数字统一修改为 $c$。
翻转REVERSE pos tot取出从当前数列的第 $pos$ 个数字开始的 $tot$ 个数字,翻转后放入原来的位置。
求和GET-SUM pos tot计算从当前数列的第 $pos$ 个数字开始的 $tot$ 个数字的和并输出。
求和最大的子列MAX-SUM求出当前数列中和最大的一段非空子列,并输出最大和。

数据范围:$1\le m\le 2\times 10 ^ 4$,任何时刻数列中最多含有 $5\times 10 ^ 5$,数列中任何一个数字均在 $[-10^3,10^3]$,插入的数字总数不超过 $4\times 10 ^ 6$ 个。


Solution

平衡树操作大合集吼啊!

考虑用 $\text{Splay}$ 维护这个序列(具体实现过程可以参考「算法笔记」Splay 维护序列),当我们要提取 $[l,r]$ 时,将 $l-1$ 旋转到根节点,将 $r+1$ 旋转到根节点的右儿子。维护最大子段和时,和线段树维护相同套路。

代码实现中需要注意许多细节:空指针越界问题、如何定义空儿子的信息以化简代码……代码也就两百多行而已 QAQ

时间复杂度:$\mathcal O(n\log n)$。


Code

#include <cstdio>
#include <algorithm>

const int N = 5e5 + 5;
const int INF = 0x3f3f3f3f;

int n, m, a[N];

struct Node;
Node *null;

struct Node {
    Node *ch[2], *fa;
    int val, sz, tag, lmx, rmx, sum, ans;
    bool rev, cov;
    Node(Node *l = null, Node *r= null, Node *_fa = null, int _val = 0) {
        ch[0] = l, ch[1] = r, fa = _fa, val = sum = ans = _val, sz = 1, tag = rev = cov = 0;
        lmx = rmx = std::max(val, 0);
    }
    void setc(int o, Node *son) {
        ch[o] = son, son->fa = this;
    }
    bool get() {
        return fa->ch[1] == this;
    }
    void pushup() {
        sz = ch[0]->sz + ch[1]->sz + 1;
        sum = ch[0]->sum + ch[1]->sum + val;
        lmx = std::max(ch[0]->lmx, ch[0]->sum + val + ch[1]->lmx);
        rmx = std::max(ch[1]->rmx, ch[0]->rmx + val + ch[1]->sum);
        ans = std::max(ch[0]->rmx + val + ch[1]->lmx, std::max(ch[0]->ans, ch[1]->ans));
    }
    void reverse() {
        rev ^= 1;
        std::swap(ch[0], ch[1]), std::swap(lmx, rmx);
    }
    void cover(int v) {
        cov = 1, sum = sz * (tag = val = v);
        if (v > 0) {
            lmx = rmx = ans = sum;
        } else {
            lmx = rmx = 0, ans = v;
        }
    }
    void pushdown() {
        if (rev) {
            if (ch[0] != null) ch[0]->reverse();
            if (ch[1] != null) ch[1]->reverse();
            rev = 0;
        }
        if (cov) {
            if (ch[0] != null) ch[0]->cover(tag);
            if (ch[1] != null) ch[1]->cover(tag);
            tag = cov = 0;
        }
    }
};

Node *newNode(Node *l = null, Node *r = null, Node *fa = null, int val = 0) {
    Node *t = new Node();
    *t = Node(l, r, fa, val);
    return t;
}

struct Splay {
    Node *rt;
    Splay() {
        null = newNode(NULL, NULL, NULL, 0);
        null->ch[0] = null->ch[1] = null->fa = null;
        null->sz = 0, null->ans = -INF;
        rt = newNode(null, null, null, -INF);
        rt->ch[1] = newNode(null, null, rt, -INF);
        rt->pushup();
    }
    Splay(int *a, int n) {
        rt = build(a, 1, n);
    }
    Node *build(int *a, int l, int r) {
        if (l > r) {
            return null;
        }
        int mid = (l + r) >> 1;
        Node *x = newNode();
        x->sum = x->ans = x->val = a[mid];
        x->lmx = x->rmx = std::max(a[mid], 0);
        Node *L = build(a, l, mid - 1), *R = build(a, mid + 1, r);
        if (L != null) x->setc(0, L);
        if (R != null) x->setc(1, R);
        x->pushup();
        return x;
    }
    void rotate(Node *x) {
        Node *y = x->fa, *z = y->fa;
        int k = x->get();
        z->setc(y->get(), x);
        y->setc(k, x->ch[!k]);
        x->setc(!k, y);
        y->pushup();
    }
    void splay(Node *x, Node *g) {
        while (x->fa != g) {
            Node *y = x->fa;
            if (y->fa != g) rotate(x->get() == y->get() ? y : x);
            rotate(x);
        }
        x->pushup();
        if (g == null) {
            rt = x;
        }
    }
    void find(int x) {
        if (rt == null) {
            return;
        }
        Node *u = rt;
        u->pushdown();
        while (x != u->val && u->ch[x > u->val] != null) {
            u = u->ch[x > u->val];
            u->pushdown();
        }
        splay(u, null);
    }
    int rnk(int x) {
        find(x);
        return rt->ch[0]->sz;
    }
    Node *kth(int x) {
        x++;
        Node *u = rt;
        while (1) {
            u->pushdown();
            if (x > u->ch[0]->sz + 1) {
                x -= u->ch[0]->sz + 1;
                u = u->ch[1];
            } else if (x <= u->ch[0]->sz) {
                u = u->ch[0];
            } else {
                return u;
            }
        }
    }
    void clear(Node *x) {
        if (x == null) {
            return;
        }
        clear(x->ch[0]);
        clear(x->ch[1]);
        delete x;
    }
    void select(int l, int r) {
        splay(kth(l - 1), null);
        splay(kth(r + 1), rt);
    }
    void ins(int pos, Splay &x) {
        select(pos + 1, pos);
        rt->ch[1]->setc(0, x.rt);
        rt->ch[1]->pushup(), rt->pushup();
    }
    void del(int l, int r) {
        select(l, r);
        clear(rt->ch[1]->ch[0]);
        rt->ch[1]->ch[0] = null;
        rt->ch[1]->pushup(), rt->pushup();
    }
    void cov(int l, int r, int v) {
        select(l, r);
        rt->ch[1]->ch[0]->cover(v);
        rt->ch[1]->pushup(), rt->pushup();
    }
    void rev(int l, int r) {
        select(l, r);
        rt->ch[1]->ch[0]->reverse();
        rt->ch[1]->pushup(), rt->pushup();
    }
    int sum(int l, int r) {
        select(l, r);
        return rt->ch[1]->ch[0]->sum;
    }
    int ans() {
        return rt->ans;
    }
} spl;

int main() {
    scanf("%d%d", &n, &m);
    for (int i = 1; i <= n; i++) {
        scanf("%d", &a[i]);
    }
    Splay tmp = Splay(a, n);
    spl.ins(0, tmp);
    for (int i = 1; i <= m; i++) {
        char s[20];
        scanf("%s", s + 1);
        if (s[1] == 'M' && s[3] == 'X') {
            printf("%d\n", spl.ans());
            continue;
        }
        int pos, tot;
        scanf("%d%d", &pos, &tot);
        if (s[1] == 'I') {
            for (int j = 1; j <= tot; j++) {
                scanf("%d", &a[j]);
            }
            tmp = Splay(a, tot);
            spl.ins(pos, tmp);
        }
        if (s[1] == 'D') {
            spl.del(pos, pos + tot - 1);
        }
        if (s[1] == 'M' && s[3] =='K') {
            int v;
            scanf("%d", &v);
            spl.cov(pos, pos + tot - 1, v);
        }
        if (s[1] == 'R') {
            spl.rev(pos, pos + tot - 1);
        }
        if (s[1] == 'G') {
            printf("%d\n", spl.sum(pos, pos + tot - 1));
        }
    }
    return 0;
}
最后修改:2019 年 06 月 28 日
如果觉得我的文章对你有用,请随意赞赏