$\text{Splay}$ 是一种二叉查找树,它通过不断将某个节点旋转到根节点,使得整棵树仍然满足二叉查找树的性质,并且保持平衡而不至于退化为链。


结构

二叉查找树的性质

首先肯定是一棵二叉树!

能够在这棵树上查找某个值的性质:左儿子的值 $<$ 根节点的值 $<$ 右儿子的值。

节点维护信息

$rt$$fa$$ch[0/1]$$val$$cnt$$sz$
根节点父亲左右儿子节点权值权值出现次数子树大小

操作

基本操作

  • $\text{get}()$:判断节点 $x$ 是父亲节点的左儿子还是右儿子。
  • $\text{pushup}()​$:在改变节点 $x​$ 的位置前,将节点 $x​$ 的 $\text{size}​$ 更新。
bool get() {
    return fa->ch[1] == this;
}
void pushup() {
    sz = ch[0]->sz + ch[1]->sz + cnt;
}

旋转操作

为了使 $\text{Splay}$ 保持平衡而进行旋转操作,旋转的本质是将某个节点上移一个位置。

旋转需要保证

  • 整棵 $\text{Splay}$ 的中序遍历不变(不能破坏二叉查找树的性质)。
  • 受影响的节点维护的信息依然正确有效。
  • $root$ 必须指向旋转后的根节点。

在 $\text{Splay}$ 中旋转分为两种:左旋和右旋。

具体分析旋转步骤(假设需要旋转的节点为 $x$,$x$ 的父亲为 $y$,$y$ 的父亲为 $z$,以右旋为例)

  1. 将 $z$ 的某个儿子(原来 $y$ 所在的儿子位置即 y->get(y))指向 $x$,且 $x$ 的父亲指向 $z$。
    z->ch[get(y)] = x, x->fa = z;
  2. 将 $y$ 的左儿子指向 $x$ 的右儿子,且 $x$ 的右儿子的父亲指向 $y$。
    y->ch[0] = x->ch[1], x->ch[1]->fa = y;
  3. 将 $x$ 的右儿子指向 $y$,且 $y$ 的父亲指向 $x$。
    x->ch[1] = y, y->fa = x;
  4. 分别更新 $y$ 节点的信息(节点 $x$ 信息在后文会更新)。
    y->pushup();
void rotate(Node *x) {
    Node *y = x->fa, *z = y->fa;
    int k = x->get();
    z->ch[y->get()] = x, x->fa = z;
    y->ch[k] = x->ch[!k], x->ch[!k]->fa = y;
    x->ch[!k] = y, y->fa = x;
    y->pushup();
}

Splay 操作

$\text{Splay}$ 规定:每访问一个节点后都要强制将其旋转到根节点。此时旋转操作具体分为 $6$ 种情况讨论(其中 $x$ 为需要旋转到根的节点)。

  • 如果 $x$ 的父亲是根节点,直接将 $x$ 左旋或右旋(图 $1,2$)。
  • 如果 $x$ 的父亲不是根节点,且 $x$ 和父亲的儿子类型相同,首先将其父亲左旋或右旋,然后将 $x$ 右旋或左旋(图 $3,4$)。
  • 如果 $x$ 的父亲不是根节点,且 $x$ 和父亲的儿子类型不同,将 $x$ 左旋再右旋、或者右旋再左旋(图 $5,6$)。

分析起来一大串,其实代码一小段。大家可以自己模拟一下 $6$ 种旋转情况,就能理解 $\text{splay}$ 的基本思想了。代码 splay(x, g) 表示把 $x$ 旋转到 $g$ 的儿子(当 $g = 0$ 时表示旋转到根)。$\text{splay}$ 结束后,我们需要更新 $x$ 的信息。

void splay(Node *x, Node *g) {
    while (x->fa != g) {
        Node *y = x->fa;
        if (y->fa != g) rotate(x->get() == y->get() ? y : x);
        rotate(x);
    }
    x->pushup();
    if (g == null) rt = x;
}

查找操作

我们有时在 $\text{Splay}$ 中查找一个值就需要查找操作。它的思想就是二叉查找树的查找过程,每次根据待查找的值 $v$ 与当前节点的值的关系,来判断进入左、右儿子。

void find(int v) {
    if (rt == null) return;
    Node *u = rt;
    while (v != u->val && u->ch[v > u->val] != null) {
        u = u->ch[v > u->val];
    }
    splay(u, null);
}

查询排名

排名定义为第 $1$ 个等于 $v$ 的值的排名。那么我们只需要把 $v$ 旋转到根节点,返回根的左子树的 $sz$ 再减 $1$ 即可!(代码中没有减 $1$ 的原因是笔者在 $\text{Splay}$ 中事先插入了 $-\text{INF}$ 和 $\text{INF}$)

int rank(int v) {
    find(v);
    return rt->ch[0]->sz;
}

第 k 大数

设 $k$ 为剩余排名,具体步骤如下:

  • 如果 $k$ 大于左子树大小与当前节点大小的和,那么向右子树查找。
  • 如果 $k$ 不大于左子树的大小,那么向左子树查找。
  • 否则直接返回当前节点的值。

代码中将 $k$ 增加 $1$ 的原因同上。

Node *kth(int k) {
    k++;
    Node *u = rt;
    while (1) {
        if (k <= u->ch[0]->sz) {
            u = u->ch[0];
        } else if (k > u->ch[0]->sz + u->cnt) {
            k -= u->ch[0]->sz + u->cnt;
            u = u->ch[1];
        } else {
            return u;
        }
    }
}

查询前驱

前驱定义为严格小于 $v$ 的最大的数,那么查询前驱可以转化为:将 $v$ 旋转到根节点, 前驱即为 $v$ 的左子树中最右边的节点。注意当 $v$ 不存在时,根节点的值比 $v$ 小的情况要特判!

Node *pre(int v) {
    find(v);
    if (rt->val < v) return rt;
    Node *u = rt->ch[0];
    while (u->ch[1] != null) u = u->ch[1];
    return u;
}

查询后继

后继定义为严格大于 $v$ 的最小的数,查询方法和前驱类似:$v$ 的右子树中最左边的节点。

Node *suc(int v) {
    find(v);
    if (rt->val > v) return rt;
    Node *u = rt->ch[1];
    while (u->ch[0] != null) u = u->ch[0];
    return u;
}

插入操作

插入操作是一个非常重要的操作:按照二叉查找树的性质向下查找,找到待插入的值 $v$ 应该插入的节点并插入。如果 $v$ 原来就存在,那么直接更新 $cnt$,否则新建一个空节点。最后别忘了 $\text{splay}$ 操作。

void insert(int v) {
    Node *u = rt, *f = null;
    while (u != null && v != u->val) {
        f = u, u = u->ch[v > u->val];
    }
    if (u != null) {
        u->cnt++;
    } else {
        u = new Node(v);
        if (f != null) {
            f->ch[v > f->val] = u;
            u->fa = f;
        }
    }
    splay(u, null);
}

删除操作

删除操作看似是一个比较复杂的操作,但是如果深入理解了 $\text{Splay}$ 的性质,其实非常简单!

  • 首先得到 $v$ 的前驱 $lst$ 和后继 $nxt$。将 $lst$ 旋转到根,将 $nxt$ 旋转到 $lst$ 的儿子(显然是右儿子)。
  • 观察这个过程可以发现:如果 $v$ 存在,那么此时 $nxt$ 的左儿子一定就是 $v$,将这个节点的大小减 $1$ (需要 $\text{splay}$ 操作)或者直接删除即可。
void erase(int v) {
    Node *lst = pre(v), *nxt = suc(v);
    splay(lst, null), splay(nxt, lst);
    Node *u = nxt->ch[0];
    if (u->cnt > 1) {
        u->cnt--;
        splay(u, null);
    } else {
        clear(nxt->ch[0]);
        nxt->ch[0] = null;
    }
}

代码

#include <cstdio>

const int N = 1e5 + 5;
const int INF = 0x3f3f3f3f;

struct Node;
Node *null;

struct Node {
    Node *ch[2], *fa;
    int val, cnt, sz;
    Node(int _val = 0) {
        ch[0] = ch[1] = fa = null, val = _val, cnt = sz = 1;
    }
    bool get() {
        return fa->ch[1] == this;
    }
    void pushup() {
        sz = ch[0]->sz + ch[1]->sz + cnt;
    }
};

struct Splay {
    Node *rt;
    Splay() {
        null = new Node();
        null->ch[0] = null->ch[1] = null->fa = null;
        null->cnt = null->sz = 0;
        rt = null;
    }
    void clear(Node *x) {
        if (x == null) return;
        clear(x->ch[0]);
        clear(x->ch[1]);
        delete x;
    }
    void rotate(Node *x) {
        Node *y = x->fa, *z = y->fa;
        int k = x->get();
        z->ch[y->get()] = x, x->fa = z;
        y->ch[k] = x->ch[!k], x->ch[!k]->fa = y;
        x->ch[!k] = y, y->fa = x;
        y->pushup();
    }
    void splay(Node *x, Node *g) {
        while (x->fa != g) {
            Node *y = x->fa;
            if (y->fa != g) rotate(x->get() == y->get() ? y : x);
            rotate(x);
        }
        x->pushup();
        if (g == null) rt = x;
    }
    void find(int v) {
        if (rt == null) return;
        Node *u = rt;
        while (v != u->val && u->ch[v > u->val] != null) {
            u = u->ch[v > u->val];
        }
        splay(u, null);
    }
    int rank(int v) {
        find(v);
        return rt->ch[0]->sz;
    }
    Node *kth(int k) {
        k++;
        Node *u = rt;
        while (1) {
            if (k <= u->ch[0]->sz) {
                u = u->ch[0];
            } else if (k > u->ch[0]->sz + u->cnt) {
                k -= u->ch[0]->sz + u->cnt;
                u = u->ch[1];
            } else {
                return u;
            }
        }
    }
    Node *pre(int v) {
        find(v);
        if (rt->val < v) return rt;
        Node *u = rt->ch[0];
        while (u->ch[1] != null) u = u->ch[1];
        return u;
    }
    Node *suc(int v) {
        find(v);
        if (rt->val > v) return rt;
        Node *u = rt->ch[1];
        while (u->ch[0] != null) u = u->ch[0];
        return u;
    }
    void insert(int v) {
        Node *u = rt, *f = null;
        while (u != null && v != u->val) {
            f = u, u = u->ch[v > u->val];
        }
        if (u != null) {
            u->cnt++;
        } else {
            u = new Node(v);
            if (f != null) {
                f->ch[v > f->val] = u;
                u->fa = f;
            }
        }
        splay(u, null);
    }
    void erase(int v) {
        Node *lst = pre(v), *nxt = suc(v);
        splay(lst, null), splay(nxt, lst);
        Node *u = nxt->ch[0];
        if (u->cnt > 1) {
            u->cnt--;
            splay(u, null);
        } else {
            clear(nxt->ch[0]);
            nxt->ch[0] = null;
        }
    }
} splay;

int main() {
    splay.insert(-INF), splay.insert(INF);
    int m;
    for (scanf("%d", &m); m--; ) {
        int opt, x;
        scanf("%d%d", &opt, &x);
        if (opt == 1) splay.insert(x);
        if (opt == 2) splay.erase(x);
        if (opt == 3) printf("%d\n", splay.rank(x));
        if (opt == 4) printf("%d\n", splay.kth(x)->val);
        if (opt == 5) printf("%d\n", splay.pre(x)->val);
        if (opt == 6) printf("%d\n", splay.suc(x)->val);
    }
    return 0;
}

习题

最后修改:2023 年 01 月 10 日
如果觉得我的文章对你有用,请随意赞赏