题目链接:Codeforces 271D
你有一个包含小写字母的字符串 $S$,有一些字母是好的,其余的是坏的。如果在 $S_l, S_{l + 1}, \cdots, S_r$ 中有至多 $k$ 个坏的字母,那么子串 $S_{l,r}$ 是好的。
你需要找出 $S$ 中本质不同的好的子串数量。两个子串 $S_{x, y}$ 和 $S_{p, q}$ 是不同的当且仅当 $S_{x, y} \ne S_{p, q}$。
数据范围:$1 \le \lvert S \rvert \le 1500$,$0 \le k \le \lvert S \rvert$。
Solution
首先对字符串建立后缀数组。对于第 $sa(i)$ 个位置,二分找到最后一个位置 $j$ 满足 $S_{sa(i), j}$ 中至多有 $k$ 个坏的字母,那么 $S_{sa(i), sa(i) \sim j}$ 这些字串都是满足条件的。减去重复的子串 $height(i)$ 就是本质不同的子串个数。
时间复杂度:$\mathcal O(n \log n)$。
Code
#include <cstdio>
#include <cstring>
#include <algorithm>
const int N = 2e3 + 5;
int n, k, a[N], cnt[N];
char s[N], t[N];
template <int S>
struct SuffixArray {
static const int N = S << 1;
int n, m, a[N], sa[N], rk[N], bin[N], tmp[N], height[N];
void clear() {
memset(a, 0, sizeof(a));
memset(sa, 0, sizeof(sa));
memset(rk, 0, sizeof(rk));
memset(height, 0, sizeof(height));
}
void radixSort() {
for (int i = 1; i <= m; i++) bin[i] = 0;
for (int i = 1; i <= n; i++) bin[rk[i]]++;
for (int i = 1; i <= m; i++) bin[i] += bin[i - 1];
for (int i = n; i >= 1; i--) sa[bin[rk[tmp[i]]]--] = tmp[i];
}
template <class Tp>
void build(Tp *_a, int _n, int _m) {
n = _n, m = _m;
std::copy(_a + 1, _a + n + 1, a + 1);
for (int i = 1; i <= n; i++) rk[i] = a[i], tmp[i] = i;
radixSort();
for (int l = 1, p = 0; p < n; l <<= 1, m = p) {
p = 0;
for (int i = n - l + 1; i <= n; i++) tmp[++p] = i;
for (int i = 1; i <= n; i++) if (sa[i] > l) tmp[++p] = sa[i] - l;
radixSort();
std::swap(rk, tmp);
p = rk[sa[1]] = 1;
for (int i = 2; i <= n; i++) {
rk[sa[i]] = (tmp[sa[i - 1]] == tmp[sa[i]] && tmp[sa[i - 1] + l] == tmp[sa[i] + l]) ? p : ++p;
}
}
for (int i = 1, k = 0; i <= n; i++) {
k -= (k > 0);
int j = sa[rk[i] - 1];
for (; a[i + k] == a[j + k]; k++);
height[rk[i]] = k;
}
}
};
SuffixArray<N> A;
int calc(int st) {
int l = st - 1, r = n, ans = 0;
while (l <= r) {
int mid = (l + r) >> 1;
cnt[mid] - cnt[st - 1] <= k ? l = (ans = mid) + 1 : r = mid - 1;
}
return ans;
}
int main() {
scanf("%s%s%d", s + 1, t + 1, &k);
n = strlen(s + 1);
A.build(s, n, 255);
for (int i = 1; i <= n; i++) {
cnt[i] = cnt[i - 1] + (t[s[i] - 'a' + 1] == '0');
}
int ans = 0;
for (int i = 1; i <= n; i++) {
ans += std::max(calc(A.sa[i]) - A.sa[i] + 1 - A.height[i], 0);
}
printf("%d\n", ans);
return 0;
}