题目链接:LOJ 2377
给定一个长度为 $n$ 的字符串 $S$,令 $T_i$ 表示它从第 $i$ 个字符开始的后缀,求:
$$ \sum_{1 \le i < j \le n} \text{len}(T_i) + \text{len}(T_j) - 2 \times \text{lcp}(T_i, T_j) $$
其中,$\text{len}(a)$ 表示字符串 $a$ 的长度,$\text{lcp}(a, b)$ 表示字符串 $a$ 和字符串 $b$ 的最长公共前缀。
数据范围:$2 \le n \le 5 \times 10 ^ 5$。
Solution
首先我们可以在 $\mathcal O(1)\sim \mathcal O(n)$ 的时间内求出 $\text{len}(T_i) + \text{len}(T_j)$ 的值。把这部分贡献提出来之后,我们需要求的只是:
$$ \sum_{i = 1} ^ n \sum_{j = i + 1} ^ n \text{lcp}(T_i, T_j) $$
看到 $\text{lcp}$ 一定会想到后缀数组,于是我们先求出字符串的后缀数组。然后设 $\text{LCP}(i, j) = \text{lcp}(sa(i), sa(j))$,将枚举下标改为枚举排名。式子化为:
$$ \sum_{i = 1} ^ {n - 1} \sum_{j = i + 1} ^ n \text{LCP}(i, j) $$
将 $\text{LCP}(i, j)$ 拆成区间 $\text{min}$ 的形式得到:
$$ \sum_{i = 1} ^ {n - 1} \sum_{j = i + 1} ^ n \min_{i < k \le j} height(k) $$
我们把 $i + 1$ 替换为 $i'$ 得到:
$$ \sum_{i' = 2} ^ n \sum_{j = i'} ^ n \min_{i' \le k \le j} height(k) $$
现在问题转化为:对于所有区间,求出其最小值之和。
这是一个单调栈经典问题,考虑每个数的贡献进行统计就行了。
时间复杂度:$\mathcal O(n \log n)$。
Code
#include <cstdio>
#include <cstring>
#include <algorithm>
const int N = 5e5 + 5;
int n, stk[N], l[N], r[N];
char s[N];
template <int S>
struct SuffixArray {
static const int N = S << 1;
int n, m, a[N], sa[N], rk[N], bin[N], tmp[N], height[N];
void clear() {
memset(a, 0, sizeof(a));
memset(sa, 0, sizeof(sa));
memset(rk, 0, sizeof(rk));
memset(height, 0, sizeof(height));
}
void radixSort() {
for (int i = 1; i <= m; i++) bin[i] = 0;
for (int i = 1; i <= n; i++) bin[rk[i]]++;
for (int i = 1; i <= m; i++) bin[i] += bin[i - 1];
for (int i = n; i >= 1; i--) sa[bin[rk[tmp[i]]]--] = tmp[i];
}
template <class Tp>
void build(Tp *_a, int _n, int _m) {
n = _n, m = _m;
std::copy(_a + 1, _a + n + 1, a + 1);
for (int i = 1; i <= n; i++) rk[i] = a[i], tmp[i] = i;
radixSort();
for (int l = 1, p = 0; p < n; l <<= 1, m = p) {
p = 0;
for (int i = n - l + 1; i <= n; i++) tmp[++p] = i;
for (int i = 1; i <= n; i++) if (sa[i] > l) tmp[++p] = sa[i] - l;
radixSort();
std::swap(rk, tmp);
p = rk[sa[1]] = 1;
for (int i = 2; i <= n; i++) {
rk[sa[i]] = (tmp[sa[i - 1]] == tmp[sa[i]] && tmp[sa[i - 1] + l] == tmp[sa[i] + l]) ? p : ++p;
}
}
for (int i = 1, k = 0; i <= n; i++) {
k -= (k > 0);
int j = sa[rk[i] - 1];
for (; a[i + k] == a[j + k]; k++);
height[rk[i]] = k;
}
}
};
SuffixArray<N> A;
int main() {
scanf("%s", s + 1);
n = strlen(s + 1);
A.build(s, n, 255);
int top;
stk[top = 0] = 1;
for (int i = 2; i <= n; i++) {
for (; top && A.height[stk[top]] > A.height[i]; top--);
l[i] = stk[top];
stk[++top] = i;
}
stk[top = 0] = n + 1;
for (int i = n; i >= 2; i--) {
for (; top && A.height[stk[top]] >= A.height[i]; top--);
r[i] = stk[top];
stk[++top] = i;
}
long long ans = 1LL * (n - 1) * n * (n + 1) >> 1;
for (int i = 1; i <= n; i++) {
ans -= 2LL * A.height[i] * (r[i] - i) * (i - l[i]);
}
printf("%lld\n", ans);
return 0;
}