题目链接:Codeforces 204E
小象非常喜欢字符串。他拥有 $n$ 个包含小写字母的字符串,第 $i$ 个字符串记为 $a_i$。对于每个字符串 $a_i(1 \le i \le n)$,小象想要求出二元组 $(l, r)$ 的对数,其中 $(l, r)$ 需要满足:$1 \le l \le r \le \lvert a_i \rvert$ 且子串 $a_i[l\dots r]$ 是至少 $k$ 个字符串的子串。
数据范围:$1 \le n, k \le 10 ^ 5$,$\sum_{i = 1} ^ n \lvert a_i \rvert \le 10 ^ 5$。
Solution
首先我们把所有字符串平在一起求后缀数组,对于每个字符串 $a_i$ 从前往后考虑,找到当前位置 $j$ 能够往后延伸的最长长度 $len_j$,使得 $s_i[j\dots j + len_j - 1]$ 至少为 $k$ 个字符串的子串。那么这个位置 $j$ 对答案的贡献就是 $len_j$。在计算同一个字符串的下一个位置 $j + 1$ 时,我们不需要重新二分长度 $len_{j + 1}$;注意到去掉位置 $j$ 后的子串 $s_i[j + 1\dots j + len_j - 1]$ 一定是合法的,因此我们有 $len_{j + 1} \ge len_j - 1$。这个过程和求 $height$ 数组很类似。
接下来问题转化为:如何快速求出从位置 $j$ 开始的长度为 $len$ 的子串是否合法?运用 $height$ 数组分组的思想,我们将 $rk(j)$ 往前扩展到 $rk(l)$、往后扩展到 $rk(r)$,使得 $\text{LCP}(rk(l), rk(r)) \ge len$。此时我们只需要判断后缀 $\text{suffix}(sa(rk(l))), \text{suffix}(sa(rk(l) + 1)), \cdots, \text{suffix}(sa(rk(r)))$ 中是否属于不少于 $k$ 个字符串。
这个问题我们是可以在 $\mathcal O(n)$ 的时间内预处理得到的。对于每个 $i$ 预处理 $pos(i)$ 表示最小的位置 $j$ 使得 $\text{suffix}(sa(i\dots j))$ 属于不少于 $k$ 个字符串。可以通过 $\text{Two Pointers}$ 预处理。
于是上述问题只需要满足 $pos(rk(l)) \le rk(r)$ 则可行。
时间复杂度:$\mathcal O(n \log n)$,貌似比标算复杂度优秀耶 >\_<!
Code
#include <cstdio>
#include <cstring>
#include <algorithm>
const int N = 2e5 + 5;
int n, m, k, a[N], len[N], st[N], bl[N], cnt[N], p[N];
template <int S>
struct SuffixArray {
static const int N = S << 1, logN = 20;
int n, m, a[N], sa[N], rk[N], bin[N], tmp[N], height[N], lg[N], f[N][logN];
void clear() {
memset(a, 0, sizeof(a));
memset(sa, 0, sizeof(sa));
memset(rk, 0, sizeof(rk));
memset(height, 0, sizeof(height));
}
void radixSort() {
for (int i = 1; i <= m; i++) bin[i] = 0;
for (int i = 1; i <= n; i++) bin[rk[i]]++;
for (int i = 1; i <= m; i++) bin[i] += bin[i - 1];
for (int i = n; i >= 1; i--) sa[bin[rk[tmp[i]]]--] = tmp[i];
}
template <class Tp>
void build(Tp *_a, int _n, int _m) {
memset(a, 0, sizeof(a));
memset(sa, 0, sizeof(sa));
memset(rk, 0, sizeof(rk));
memset(height, 0, sizeof(height));
n = _n, m = _m;
std::copy(_a + 1, _a + n + 1, a + 1);
for (int i = 1; i <= n; i++) rk[i] = a[i], tmp[i] = i;
radixSort();
for (int l = 1, p = 0; p < n; l <<= 1, m = p) {
p = 0;
for (int i = n - l + 1; i <= n; i++) tmp[++p] = i;
for (int i = 1; i <= n; i++) if (sa[i] > l) tmp[++p] = sa[i] - l;
radixSort();
std::swap(rk, tmp);
p = rk[sa[1]] = 1;
for (int i = 2; i <= n; i++) {
rk[sa[i]] = (tmp[sa[i - 1]] == tmp[sa[i]] && tmp[sa[i - 1] + l] == tmp[sa[i] + l]) ? p : ++p;
}
}
for (int i = 1, k = 0; i <= n; i++) {
k -= (k > 0);
int j = sa[rk[i] - 1];
for (; a[i + k] == a[j + k]; k++);
height[rk[i]] = k;
}
}
void buildST() {
for (int i = 2; i <= n; i++) lg[i] = lg[i >> 1] + 1;
for (int i = 1; i <= n; i++) f[i][0] = height[i];
for (int j = 1; (1 << j) <= n; j++) {
for (int i = 1; i + (1 << j) - 1 <= n; i++) {
f[i][j] = std::min(f[i][j - 1], f[i + (1 << (j - 1))][j - 1]);
}
}
}
int LCP(int l, int r) {
if (l == r) return n - sa[l] + 1;
int k = lg[r - (++l) + 1];
return std::min(f[l][k], f[r - (1 << k) + 1][k]);
}
};
SuffixArray<N> A;
int calcL(int x, int len) {
int l = 1, r = x, ans = 0;
while (l <= r) {
int mid = (l + r) >> 1;
A.LCP(mid, x) >= len ? r = (ans = mid) - 1 : l = mid + 1;
}
return ans;
}
int calcR(int x, int len) {
int l = x, r = n, ans = 0;
while (l <= r) {
int mid = (l + r) >> 1;
A.LCP(x, mid) >= len ? l = (ans = mid) + 1 : r = mid - 1;
}
return ans;
}
bool check(int pos, int len) {
int x = A.rk[pos];
int l = calcL(x, len), r = calcR(x, len);
return p[l] <= r;
}
int main() {
scanf("%d%d", &m, &k);
for (int i = 1; i <= m; i++) {
static char t[N];
scanf("%s", t + 1);
len[i] = strlen(t + 1);
st[i] = n + 1;
for (int j = 1; j <= len[i]; j++) a[++n] = t[j], bl[n] = i;
a[++n] = i + 256, bl[n] = 0;
}
A.build(a, n, m + 256);
A.buildST();
for (int i = 1, j = 1, now = 0; i <= n; i++) {
for (; now < k && j <= n; j++) {
now += (bl[A.sa[j]] && ++cnt[bl[A.sa[j]]] == 1);
}
p[i] = (now >= k ? j - 1 : n + 1);
now -= (bl[A.sa[i]] && --cnt[bl[A.sa[i]]] == 0);
}
for (int i = 1; i <= m; i++) {
long long ans = 0;
for (int j = 1, l = 0; j <= len[i]; j++) {
l -= (l > 0);
for (; j + l - 1 < len[i] && check(st[i] + j - 1, l + 1); l++);
ans += l;
}
printf("%lld%c", ans, " \n"[i == m]);
}
return 0;
}