题目链接:Codeforces 452E

你有三个字符串 $(s_1, s_2, s_3)​$。对于每个整数 $l(1 \le l \le \min(\vert s_1 \vert, \vert s_2 \vert, \vert s_3 \vert)​$,你需要求出有多少三元组 $(i_1, i_2, i_3)​$ 满足 $s_k[i_k \dots i_k + l - 1](k = 1, 2, 3)​$ 两两相等。答案对 $10 ^ 9 + 7​$ 取模。

数据范围:$3 \le \sum_{i = 1} ^ 3 \vert s_i \vert \le 3 \times 10 ^ 5$。


Solution

首先考虑当 $l$ 确定时的做法。我们只需要对 $height$ 分组,保证每组内 $height(i) \ge l$。根据乘法原理和加法原理,如果第 $i$ 组内在第 $1, 2, 3$ 个字符串内的后缀分别有 $a_i, b_i, c_i$ 个,那么答案为 $\sum a_i \times b_i \times c_i$。

又注意到,随着 $l$ 的增大,组数是单调不降的。于是我们考虑从大到小枚举长度 $l​$,并不断合并区间,同时维护每组(每个区间)的三个值。可以通过并查集实现。

时间复杂度:$\mathcal O(n \log n)​$。


Code

#include <cstdio>
#include <cstring>
#include <algorithm>

const int N = 1e6 + 5;
const int P = 1e9 + 7;

int n, n1, n2, n3, bl[N], f[N], idx[N], sum[N][4], ans[N];
char s[N];

template <int S>
struct SuffixArray {
    static const int N = S << 1;
    int n, m, a[N], sa[N], rk[N], bin[N], tmp[N], height[N];
    void clear() {
        memset(a, 0, sizeof(a));
        memset(sa, 0, sizeof(sa));
        memset(rk, 0, sizeof(rk));
        memset(height, 0, sizeof(height));
    }
    void radixSort() {
        for (int i = 1; i <= m; i++) bin[i] = 0;
        for (int i = 1; i <= n; i++) bin[rk[i]]++;
        for (int i = 1; i <= m; i++) bin[i] += bin[i - 1];
        for (int i = n; i >= 1; i--) sa[bin[rk[tmp[i]]]--] = tmp[i];
    }
    template <class Tp>
    void build(Tp *_a, int _n, int _m) {
        n = _n, m = _m;
        std::copy(_a + 1, _a + n + 1, a + 1);
        for (int i = 1; i <= n; i++) rk[i] = a[i], tmp[i] = i;
        radixSort();
        for (int l = 1, p = 0; p < n; l <<= 1, m = p) {
            p = 0;
            for (int i = n - l + 1; i <= n; i++) tmp[++p] = i;
            for (int i = 1; i <= n; i++) if (sa[i] > l) tmp[++p] = sa[i] - l;
            radixSort();
            std::swap(rk, tmp);
            p = rk[sa[1]] = 1;
            for (int i = 2; i <= n; i++) {
                rk[sa[i]] = (tmp[sa[i - 1]] == tmp[sa[i]] && tmp[sa[i - 1] + l] == tmp[sa[i] + l]) ? p : ++p;
            }
        }
        for (int i = 1, k = 0; i <= n; i++) {
            k -= (k > 0);
            int j = sa[rk[i] - 1];
            for (; a[i + k] == a[j + k]; k++);
            height[rk[i]] = k;
        }
    }
};

SuffixArray<N> A;

void read(int &_n, int p) {
    scanf("%s", s + n + 1);
    _n = strlen(s + n + 1);
    std::fill(bl + n + 1, bl + n + _n + 1, p);
    s[n += _n + 1] = '#' + p;
}
int find(int x) {
    return f[x] == x ? x : f[x] = find(f[x]);
}
void add(int &x, int y) {
    (x += y) >= P && (x -= P);
}
int calc(int x) {
    return 1LL * sum[x][1] * sum[x][2] % P * sum[x][3] % P;
}
int main() {
    read(n1, 1), read(n2, 2), read(n3, 3);
    A.build(s, n, 255);
    int l = std::min(n1, std::min(n2, n3));
    for (int i = 1; i <= n; i++) {
        f[i] = idx[i] = i, sum[i][bl[i]] = 1;
    }
    std::sort(idx + 1, idx + n + 1, [](int a, int b) {
        return A.height[a] > A.height[b];
    });
    for (int i = l, j = 1, now = 0; i >= 1; i--) {
        for (; j <= n && A.height[idx[j]] >= i; j++) {
            int u = find(A.sa[idx[j] - 1]), v = find(A.sa[idx[j]]);
            add(now, P - calc(u));
            add(now, P - calc(v));
            add(sum[u][1], sum[v][1]);
            add(sum[u][2], sum[v][2]);
            add(sum[u][3], sum[v][3]);
            add(now, calc(u));
            f[v] = u;
        }
        ans[i] = now;
    }
    for (int i = 1; i <= l; i++) {
        printf("%d%c", ans[i], " \n"[i == l]);
    }
    return 0;
}
最后修改:2019 年 06 月 28 日
如果觉得我的文章对你有用,请随意赞赏