题目链接:UOJ 395

小 A 被选为了 ION2018 的出题人,他精心准备了一道质量十分高的题目,且已经把除了题目命名以外的工作都做好了。

由于 ION 已经举办了很多届,所以在题目命名上也是有规定的,ION 命题手册规定:每年由命题委员会规定一个小写字母字符串,我们称之为那一年的命名串,要求每道题的名字必须是那一年的命名串的一个非空连续子串且不能和前一年的任何一道题目的名字相同

由于一些特殊的原因,小 A 不知道 ION2017 每道题的名字,但是他通过一些特殊手段得到了 ION2017 的命名串,现在小 A 有 $Q$ 次询问:每次给定 ION2017 的命名串 $S$ 和 ION2018 的命名串 $T$,求有几种题目的命名,使得这个名字一定满足命题委员会的规定,即是 ION2018 的命名串的一个非空连续子串且一定不会和 ION2017 的任何一道题目的名字相同。

由于一些特殊原因,所有询问给出的 ION2017 的命名串都是某个串的连续子串 $S[l \dots r]$。

数据范围:$1 \le \lvert S \rvert \le 5 \times 10 ^ 5$,$1 \le Q \le 10 ^ 5$,$\sum \lvert T \rvert \le 10 ^ 6$,$1 \le l \le r \le \lvert S \rvert$。


Solution

68 pts

考虑 $l = 1, r = \lvert S \rvert$ 的做法,这意味着每次的 ION2017 命名串一定是原串 $S$。

我们对 $S, T$ 分别建立 $\text{SAM}$,对 $\text{SAM}_T$ 的每个状态分别计算贡献。那么答案为:

$$ \sum_{i = 1} ^ {tot} \min(len(i) - len(link(i)), len(i) - lim(pos(i))) $$

注意:上述式子由于美观问题省略了一些细节,实际答案应该为 $\max(\min(\dots), 0)$。

其中 $pos(i)$ 为状态 $i$ 的 $endpos$ 集合中的任何一个元素(因为他们都是等价的),$lim(i)$ 表示 $T$ 以 $i$ 结尾的子串中,和 $S$ 匹配的最长长度,这个过程可以在 $\text{SAM}$ 进行转移或跳后缀链接,在 $\mathcal O(\lvert T \rvert)$ 的时间内求出。

时间复杂度:$\mathcal O(\lvert S \rvert + \sum \lvert T \rvert)$。

100 pts

注意到满分做法也可以套用部分分做法的式子,只不过 $lim(i)$ 没法直接在 $\text{SAM}$ 上求出了。

发现求 $lim(i)$ 的本质就是判断区间 $[l, r]$ 内是否存在某个子串,那么我们直接用线段树维护 $S$ 的 $\text{SAM}$ 中每个状态的 $endpos$ 集合,可以用线段树合并解决。具体实现详见代码。

时间复杂度:$\mathcal O((\lvert S \rvert + \sum \lvert T \rvert) \log \lvert S \rvert)$。


Code

不用在意两份代码中后缀自动机的初始状态下标不同……

68 pts

#include <cstdio>
#include <cstring>
#include <algorithm>

const int N = 1e6 + 5;

int n, m, q, lim[N];
char s[N], t[N];

template <int MAX, int S>
struct SAM {
    static const int N = MAX << 1;
    int tot, lst, len[N], lnk[N], pos[N], nxt[N][S];
    void clear(int x) {
        len[x] = lnk[x] = pos[x] = 0;
        std::fill(nxt[x], nxt[x] + S, 0);
    }
    void clear() {
        clear(tot = lst = 0);
        lnk[0] = -1;
    }
    void insert(int x, int i) {
        int cur = ++tot, p = lst;
        clear(cur);
        len[cur] = len[lst] + 1;
        pos[cur] = i;
        for (; ~p && !nxt[p][x]; p = lnk[p]) nxt[p][x] = cur;
        if (p == -1) {
            lnk[cur] = 0;
        } else {
            int q = nxt[p][x];
            if (len[q] == len[p] + 1) {
                lnk[cur] = q;
            } else {
                int c = ++tot;
                clear(c);
                len[c] = len[p] + 1;
                lnk[c] = lnk[q];
                pos[c] = pos[q];
                std::copy(nxt[q], nxt[q] + S, nxt[c]);
                for (; ~p && nxt[p][x] == q; p = lnk[p]) nxt[p][x] = c;
                lnk[cur] = lnk[q] = c;
            }
        }
        lst = cur;
    }
};

SAM<N, 26> S1, S2;

int main() {
    scanf("%s%d", s + 1, &q);
    n = strlen(s + 1);
    S1.clear();
    for (int i = 1; i <= n; i++) S1.insert(s[i] - 'a', i);
    for (int _ = 1; _ <= q; _++) {
        int l, r;
        scanf("%s%d%d", t + 1, &l, &r);
        m = strlen(t + 1);
        S2.clear();
        for (int i = 1; i <= m; i++) S2.insert(t[i] - 'a', i);
        for (int i = 1, u = 0, s = 0; i <= m; i++) {
            int c = t[i] - 'a';
            for (; ~u && !S1.nxt[u][c]; u = S1.lnk[u], s = S1.len[u]);
            if (u == -1) {
                u = s = 0;
            } else {
                u = S1.nxt[u][c];
                s++;
            }
            lim[i] = s;
        }
        long long ans = 0;
        for (int i = 1; i <= S2.tot; i++) {
            int mx = std::max(S2.len[S2.lnk[i]], lim[S2.pos[i]]);
            ans += std::max(0, S2.len[i] - mx);
        }
        printf("%lld\n", ans);
    }
    return 0;
}

100 pts

#include <cstdio>
#include <cstring>
#include <algorithm>

const int N = 1e6 + 5;

int n, m, q, c[N], lim[N], rt[N << 1], p[N << 1];
char s[N], t[N];

template <int MAX, int S>
struct SAM {
    static const int N = MAX << 1;
    int tot, lst, len[N], lnk[N], pos[N], nxt[N][S];
    void clear(int x) {
        len[x] = lnk[x] = pos[x] = 0;
        std::fill(nxt[x], nxt[x] + S, 0);
    }
    void clear() {
        clear(tot = lst = 1);
    }
    void insert(int x, int i) {
        int cur = ++tot, p = lst;
        clear(cur);
        len[cur] = len[lst] + 1;
        pos[cur] = i;
        for (; p && !nxt[p][x]; p = lnk[p]) nxt[p][x] = cur;
        if (!p) {
            lnk[cur] = 1;
        } else {
            int q = nxt[p][x];
            if (len[q] == len[p] + 1) {
                lnk[cur] = q;
            } else {
                int c = ++tot;
                clear(c);
                len[c] = len[p] + 1;
                lnk[c] = lnk[q];
                pos[c] = pos[q];
                std::copy(nxt[q], nxt[q] + S, nxt[c]);
                for (; p && nxt[p][x] == q; p = lnk[p]) nxt[p][x] = c;
                lnk[cur] = lnk[q] = c;
            }
        }
        lst = cur;
    }
};

template <int MAX>
struct SegmentTree {
    static const int N = MAX << 5;
    int tot, ls[N], rs[N];
    int merge(int p1, int p2) {
        if (!p1 || !p2) return p1 | p2;
        int p = ++tot;
        ls[p] = merge(ls[p1], ls[p2]);
        rs[p] = merge(rs[p1], rs[p2]);
        return p;
    }
    void modify(int &p, int l, int r, int x) {
        p = ++tot;
        if (l == r) return;
        int mid = (l + r) >> 1;
        if (x <= mid) {
            modify(ls[p], l, mid, x);
        } else {
            modify(rs[p], mid + 1, r, x);
        }
    }
    bool query(int p, int l, int r, int x, int y) {
        if (!p || x > y) return false;
        if (x == l && r == y) return true;
        int mid = (l + r) >> 1;
        if (y <= mid) {
            return query(ls[p], l, mid, x, y);
        } else if (x > mid) {
            return query(rs[p], mid + 1, r, x, y);
        } else {
            return query(ls[p], l, mid, x, mid) | query(rs[p], mid + 1, r, mid + 1, y);
        }
    }
};

SAM<N, 26> S1, S2;
SegmentTree<N << 1> seg;

void build() {
    for (int i = 1; i <= S1.tot; i++) c[S1.len[i]]++;
    for (int i = 1; i <= n; i++) c[i] += c[i - 1];
    for (int i = 1; i <= S1.tot; i++) p[c[S1.len[i]]--] = i;
    for (int i = S1.tot; i >= 2; i--) {
        int u = S1.lnk[p[i]], v = p[i];
        rt[u] = seg.merge(rt[u], rt[v]);
    }
}
int main() {
    scanf("%s%d", s + 1, &q);
    n = strlen(s + 1);
    S1.clear();
    for (int i = 1; i <= n; i++) {
        S1.insert(s[i] - 'a', i);
        seg.modify(rt[S1.lst], 1, n, i);
    }
    build();
    for (int _ = 1; _ <= q; _++) {
        int l, r;
        scanf("%s%d%d", t + 1, &l, &r);
        m = strlen(t + 1);
        S2.clear();
        for (int i = 1; i <= m; i++) S2.insert(t[i] - 'a', i);
        for (int i = 1, u = 1, s = 0; i <= m; i++) {
            int c = t[i] - 'a';
            while (true) {
                int v = S1.nxt[u][c];
                if (v && seg.query(rt[v], 1, n, l + s, r)) {
                    u = v, s++;
                    break;
                }
                if (!s) break;
                if (--s == S1.len[S1.lnk[u]]) u = S1.lnk[u];
            }
            lim[i] = s;
        }
        long long ans = 0;
        for (int i = 1; i <= S2.tot; i++) {
            int mx = std::max(S2.len[S2.lnk[i]], lim[S2.pos[i]]);
            ans += std::max(0, S2.len[i] - mx);
        }
        printf("%lld\n", ans);
    }
    return 0;
}
最后修改:2019 年 06 月 01 日
如果觉得我的文章对你有用,请随意赞赏