题目链接:LOJ 2377

给定一个长度为 $n$ 的字符串 $S$,令 $T_i$ 表示它从第 $i$ 个字符开始的后缀,求:

$$ \sum_{1 \le i < j \le n} \text{len}(T_i) + \text{len}(T_j) - 2 \times \text{lcp}(T_i, T_j) $$

其中,$\text{len}(a)$ 表示字符串 $a$ 的长度,$\text{lcp}(a, b)$ 表示字符串 $a$ 和字符串 $b​$ 的最长公共前缀。

数据范围:$2 \le n \le 5 \times 10 ^ 5$。


Solution

首先我们可以在 $\mathcal O(1)\sim \mathcal O(n)$ 的时间内求出 $\text{len}(T_i) + \text{len}(T_j)$ 的值。把这部分贡献提出来之后,我们需要求的只是:

$$ \sum_{i = 1} ^ n \sum_{j = i + 1} ^ n \text{lcp}(T_i, T_j) $$

看到 $\text{lcp}$ 一定会想到后缀数组,于是我们先求出字符串的后缀数组。然后设 $\text{LCP}(i, j) = \text{lcp}(sa(i), sa(j))$,将枚举下标改为枚举排名。式子化为:

$$ \sum_{i = 1} ^ {n - 1} \sum_{j = i + 1} ^ n \text{LCP}(i, j) $$

将 $\text{LCP}(i, j)$ 拆成区间 $\text{min}$ 的形式得到:

$$ \sum_{i = 1} ^ {n - 1} \sum_{j = i + 1} ^ n \min_{i < k \le j} height(k) $$

我们把 $i + 1$ 替换为 $i'$ 得到:

$$ \sum_{i' = 2} ^ n \sum_{j = i'} ^ n \min_{i' \le k \le j} height(k) $$

现在问题转化为:对于所有区间,求出其最小值之和。

这是一个单调栈经典问题,考虑每个数的贡献进行统计就行了。

时间复杂度:$\mathcal O(n \log n)$。


Code

#include <cstdio>
#include <cstring>
#include <algorithm>

const int N = 5e5 + 5;

int n, stk[N], l[N], r[N];
char s[N];

template <int S>
struct SuffixArray {
    static const int N = S << 1;
    int n, m, a[N], sa[N], rk[N], bin[N], tmp[N], height[N];
    void clear() {
        memset(a, 0, sizeof(a));
        memset(sa, 0, sizeof(sa));
        memset(rk, 0, sizeof(rk));
        memset(height, 0, sizeof(height));
    }
    void radixSort() {
        for (int i = 1; i <= m; i++) bin[i] = 0;
        for (int i = 1; i <= n; i++) bin[rk[i]]++;
        for (int i = 1; i <= m; i++) bin[i] += bin[i - 1];
        for (int i = n; i >= 1; i--) sa[bin[rk[tmp[i]]]--] = tmp[i];
    }
    template <class Tp>
    void build(Tp *_a, int _n, int _m) {
        n = _n, m = _m;
        std::copy(_a + 1, _a + n + 1, a + 1);
        for (int i = 1; i <= n; i++) rk[i] = a[i], tmp[i] = i;
        radixSort();
        for (int l = 1, p = 0; p < n; l <<= 1, m = p) {
            p = 0;
            for (int i = n - l + 1; i <= n; i++) tmp[++p] = i;
            for (int i = 1; i <= n; i++) if (sa[i] > l) tmp[++p] = sa[i] - l;
            radixSort();
            std::swap(rk, tmp);
            p = rk[sa[1]] = 1;
            for (int i = 2; i <= n; i++) {
                rk[sa[i]] = (tmp[sa[i - 1]] == tmp[sa[i]] && tmp[sa[i - 1] + l] == tmp[sa[i] + l]) ? p : ++p;
            }
        }
        for (int i = 1, k = 0; i <= n; i++) {
            k -= (k > 0);
            int j = sa[rk[i] - 1];
            for (; a[i + k] == a[j + k]; k++);
            height[rk[i]] = k;
        }
    }
};

SuffixArray<N> A;

int main() {
    scanf("%s", s + 1);
    n = strlen(s + 1);
    A.build(s, n, 255);
    int top;
    stk[top = 0] = 1;
    for (int i = 2; i <= n; i++) {
        for (; top && A.height[stk[top]] > A.height[i]; top--);
        l[i] = stk[top];
        stk[++top] = i;
    }
    stk[top = 0] = n + 1;
    for (int i = n; i >= 2; i--) {
        for (; top && A.height[stk[top]] >= A.height[i]; top--);
        r[i] = stk[top];
        stk[++top] = i;
    }
    long long ans = 1LL * (n - 1) * n * (n + 1) >> 1;
    for (int i = 1; i <= n; i++) {
        ans -= 2LL * A.height[i] * (r[i] - i) * (i - l[i]);
    }
    printf("%lld\n", ans);
    return 0;
}
最后修改:2019 年 06 月 28 日
如果觉得我的文章对你有用,请随意赞赏