题目链接:POJ 3415
字符串 $T$ 的子串定义为:
$$ T(i,k) = T_iT_{i + 1}\cdots T_{t + k - 1}, 1 \le i \le i + k - 1 \le \lvert T \rvert $$
给定两个字符串 $A, B$ 和一个整数 $K$,我们定义 $S$ 为三元组 $(i, j, k)$ 集合:
$$ S = \{(i, j, k) \mid k \ge K, A(i, k) = B(j, k)\} $$
你需要求出集合 $S$ 的大小 $\lvert S \rvert$。
数据范围:$1 \le \lvert A \rvert, \lvert B \rvert \le 10 ^ 5$,$1 \le K \le \min(\lvert A \rvert, \lvert B \rvert)$。
Solution
按照套路,我们将两个字符串拼接起来并建立后缀数组,按照 $height$ 分组。
接下来我们要统计每组中的后缀的最长公共前缀之和。按照 $rk$ 从小到大扫描一遍,每遇到一个 $A$ 串的后缀就和排名在其前面的 $B$ 串后缀进行统计。对 $B$ 串同样统计一遍。
直接计算的复杂度是 $\mathcal O(n ^ 2)$ 的,注意到两个后缀的 $\text{lcp}$ 是一段区间的 $height$ 的最小值,我们可以用单调栈来维护一下。栈内每个元素维护 $2$ 个值,记录当前的 $height$ 值和不小于该 $height$ 值的 $B$ 串后缀的数量。
每次加入新的后缀就要把栈内元素进行合并,并重新计算前缀答案。具体实现详见代码。
时间复杂度:$\mathcal O(n \log n)$。
Code
#include <cstdio>
#include <cstring>
#include <algorithm>
const int N = 2e5 + 5;
int n, n1, n2, k, stk[N], num[N];
char s[N];
template <int S>
struct SuffixArray {
static const int N = S << 1;
int n, m, a[N], sa[N], rk[N], bin[N], tmp[N], height[N];
void clear() {
memset(a, 0, sizeof(a));
memset(sa, 0, sizeof(sa));
memset(rk, 0, sizeof(rk));
memset(height, 0, sizeof(height));
}
void radixSort() {
for (int i = 1; i <= m; i++) bin[i] = 0;
for (int i = 1; i <= n; i++) bin[rk[i]]++;
for (int i = 1; i <= m; i++) bin[i] += bin[i - 1];
for (int i = n; i >= 1; i--) sa[bin[rk[tmp[i]]]--] = tmp[i];
}
template <class Tp>
void build(Tp *_a, int _n, int _m) {
n = _n, m = _m;
std::copy(_a + 1, _a + n + 1, a + 1);
for (int i = 1; i <= n; i++) rk[i] = a[i], tmp[i] = i;
radixSort();
for (int l = 1, p = 0; p < n; l <<= 1, m = p) {
p = 0;
for (int i = n - l + 1; i <= n; i++) tmp[++p] = i;
for (int i = 1; i <= n; i++) if (sa[i] > l) tmp[++p] = sa[i] - l;
radixSort();
std::swap(rk, tmp);
p = rk[sa[1]] = 1;
for (int i = 2; i <= n; i++) {
rk[sa[i]] = (tmp[sa[i - 1]] == tmp[sa[i]] && tmp[sa[i - 1] + l] == tmp[sa[i] + l]) ? p : ++p;
}
}
for (int i = 1, k = 0; i <= n; i++) {
k -= (k > 0);
int j = sa[rk[i] - 1];
for (; a[i + k] == a[j + k]; k++);
height[rk[i]] = k;
}
}
};
SuffixArray<N> A;
void read(int &_n, int p) {
scanf("%s", s + n + 1);
_n = strlen(s + n + 1);
s[n += _n + 1] = '#' + p;
}
int main() {
while (scanf("%d", &k) && k) {
n = 0;
read(n1, 1), read(n2, 2);
A.clear();
A.build(s, n, 255);
int top = 0;
long long ans = 0, sum = 0;
for (int i = 1; i <= n; i++) {
if (A.height[i] < k) {
top = sum = 0;
} else {
int cnt = 0;
if (A.sa[i - 1] <= n1) {
cnt++;
sum += A.height[i] - k + 1;
}
for (; top && A.height[stk[top]] >= A.height[i]; top--) {
cnt += num[top];
sum -= 1LL * num[top] * (A.height[stk[top]] - A.height[i]);
}
stk[++top] = i, num[top] = cnt;
if (A.sa[i] > n1) ans += sum;
}
}
for (int i = 1; i <= n; i++) {
if (A.height[i] < k) {
top = sum = 0;
} else {
int cnt = 0;
if (A.sa[i - 1] > n1) {
cnt++;
sum += A.height[i] - k + 1;
}
for (; top && A.height[stk[top]] >= A.height[i]; top--) {
cnt += num[top];
sum -= 1LL * num[top] * (A.height[stk[top]] - A.height[i]);
}
stk[++top] = i, num[top] = cnt;
if (A.sa[i] <= n1) ans += sum;
}
}
printf("%lld\n", ans);
}
return 0;
}