题目链接:Codeforces 452E
你有三个字符串 $(s_1, s_2, s_3)$。对于每个整数 $l(1 \le l \le \min(\vert s_1 \vert, \vert s_2 \vert, \vert s_3 \vert)$,你需要求出有多少三元组 $(i_1, i_2, i_3)$ 满足 $s_k[i_k \dots i_k + l - 1](k = 1, 2, 3)$ 两两相等。答案对 $10 ^ 9 + 7$ 取模。
数据范围:$3 \le \sum_{i = 1} ^ 3 \vert s_i \vert \le 3 \times 10 ^ 5$。
Solution
首先考虑当 $l$ 确定时的做法。我们只需要对 $height$ 分组,保证每组内 $height(i) \ge l$。根据乘法原理和加法原理,如果第 $i$ 组内在第 $1, 2, 3$ 个字符串内的后缀分别有 $a_i, b_i, c_i$ 个,那么答案为 $\sum a_i \times b_i \times c_i$。
又注意到,随着 $l$ 的增大,组数是单调不降的。于是我们考虑从大到小枚举长度 $l$,并不断合并区间,同时维护每组(每个区间)的三个值。可以通过并查集实现。
时间复杂度:$\mathcal O(n \log n)$。
Code
#include <cstdio>
#include <cstring>
#include <algorithm>
const int N = 1e6 + 5;
const int P = 1e9 + 7;
int n, n1, n2, n3, bl[N], f[N], idx[N], sum[N][4], ans[N];
char s[N];
template <int S>
struct SuffixArray {
static const int N = S << 1;
int n, m, a[N], sa[N], rk[N], bin[N], tmp[N], height[N];
void clear() {
memset(a, 0, sizeof(a));
memset(sa, 0, sizeof(sa));
memset(rk, 0, sizeof(rk));
memset(height, 0, sizeof(height));
}
void radixSort() {
for (int i = 1; i <= m; i++) bin[i] = 0;
for (int i = 1; i <= n; i++) bin[rk[i]]++;
for (int i = 1; i <= m; i++) bin[i] += bin[i - 1];
for (int i = n; i >= 1; i--) sa[bin[rk[tmp[i]]]--] = tmp[i];
}
template <class Tp>
void build(Tp *_a, int _n, int _m) {
n = _n, m = _m;
std::copy(_a + 1, _a + n + 1, a + 1);
for (int i = 1; i <= n; i++) rk[i] = a[i], tmp[i] = i;
radixSort();
for (int l = 1, p = 0; p < n; l <<= 1, m = p) {
p = 0;
for (int i = n - l + 1; i <= n; i++) tmp[++p] = i;
for (int i = 1; i <= n; i++) if (sa[i] > l) tmp[++p] = sa[i] - l;
radixSort();
std::swap(rk, tmp);
p = rk[sa[1]] = 1;
for (int i = 2; i <= n; i++) {
rk[sa[i]] = (tmp[sa[i - 1]] == tmp[sa[i]] && tmp[sa[i - 1] + l] == tmp[sa[i] + l]) ? p : ++p;
}
}
for (int i = 1, k = 0; i <= n; i++) {
k -= (k > 0);
int j = sa[rk[i] - 1];
for (; a[i + k] == a[j + k]; k++);
height[rk[i]] = k;
}
}
};
SuffixArray<N> A;
void read(int &_n, int p) {
scanf("%s", s + n + 1);
_n = strlen(s + n + 1);
std::fill(bl + n + 1, bl + n + _n + 1, p);
s[n += _n + 1] = '#' + p;
}
int find(int x) {
return f[x] == x ? x : f[x] = find(f[x]);
}
void add(int &x, int y) {
(x += y) >= P && (x -= P);
}
int calc(int x) {
return 1LL * sum[x][1] * sum[x][2] % P * sum[x][3] % P;
}
int main() {
read(n1, 1), read(n2, 2), read(n3, 3);
A.build(s, n, 255);
int l = std::min(n1, std::min(n2, n3));
for (int i = 1; i <= n; i++) {
f[i] = idx[i] = i, sum[i][bl[i]] = 1;
}
std::sort(idx + 1, idx + n + 1, [](int a, int b) {
return A.height[a] > A.height[b];
});
for (int i = l, j = 1, now = 0; i >= 1; i--) {
for (; j <= n && A.height[idx[j]] >= i; j++) {
int u = find(A.sa[idx[j] - 1]), v = find(A.sa[idx[j]]);
add(now, P - calc(u));
add(now, P - calc(v));
add(sum[u][1], sum[v][1]);
add(sum[u][2], sum[v][2]);
add(sum[u][3], sum[v][3]);
add(now, calc(u));
f[v] = u;
}
ans[i] = now;
}
for (int i = 1; i <= l; i++) {
printf("%d%c", ans[i], " \n"[i == l]);
}
return 0;
}
1 条评论
siyuan tql!