题目链接:BZOJ 3295

对于序列 $a_i$,它的逆序对数定义为满足 $i<j$,且 $a_i > a_j$ 的数对 $(i,j)$ 的个数。给 $1$ 到 $n$ 的一个排列,按照某种顺序依次删除 $m$ 个元素,你的任务是在每次删除一个元素之前统计整个序列的逆序对数。

数据范围:$1\le n\le 10^5$,$0\le m\le 5\times 10^4$。


Solution

我们将删除操作倒过来变成插入操作。考虑使用 $\text{CDQ}$ 分治,我们对插入时间分治,合并时按照插入位置归并。考虑左侧对右侧的影响:位置更小值更大的有贡献、位置更大值更小的有贡献。分两类对右区间算贡献即可。

时间复杂度:$\mathcal O(n\log^2 n)$。


Code

#include <cstdio>

typedef long long LL;

const int N = 1e5 + 5;

int n, m, a[N], d[N], p[N];
LL ans[N];
bool del[N];

struct Data {
    int idx, pos, val;
    LL ans;
    Data(int _idx = 0, int _pos = 0, int _val = 0, LL _ans = 0) {
        idx = _idx, pos = _pos, val = _val, ans = _ans;
    }
} q[N], t[N];
struct BIT {
    int b[N];
    void add(int x, int v) {
        for (; x <= n; x += x & -x) b[x] += v;
    }
    int query(int x) {
        int ans = 0;
        for (; x; x ^= x & -x) ans += b[x];
        return ans;
    }
} bit;

void CDQ(int l, int r) {
    if (l == r) {
        return;
    }
    int m = (l + r) >> 1;
    CDQ(l, m), CDQ(m + 1, r);
    int i = l, j = m + 1, k = l;
    while (i <= m && j <= r) {
        if (q[i].pos < q[j].pos) {
            t[k++] = q[i++];
        } else {
            t[k++] = q[j++];
        }
    }
    while (i <= m) t[k++] = q[i++];
    while (j <= r) t[k++] = q[j++];
    for (int i = l; i <= r; i++) q[i] = t[i];
    int cnt = 0;
    for (int i = l; i <= r; i++) {
        if (q[i].idx <= m) {
            cnt++, bit.add(q[i].val, 1);
        } else {
            q[i].ans += cnt - bit.query(q[i].val);
        }
    }
    for (int i = l; i <= r; i++) {
        if (q[i].idx <= m) bit.add(q[i].val, -1);
    }
    for (int i = r; i >= l; i--) {
        if (q[i].idx <= m) {
            bit.add(q[i].val, 1);
        } else {
            q[i].ans += bit.query(q[i].val);
        }
    }
    for (int i = r; i >= l; i--) {
        if (q[i].idx <= m) bit.add(q[i].val, -1);
    }

}
int main() {
    scanf("%d%d", &n, &m);
    for (int i = 1; i <= n; i++) {
        scanf("%d", &a[i]);
        p[a[i]] = i;
    }
    for (int i = 1; i <= m; i++) {
        scanf("%d", &d[i]);
        del[p[d[i]]] = 1;
    }
    int idx = 0;
    for (int i = 1; i <= n; i++) {
        if (!del[i]) {
            idx++, q[idx] = Data(idx, i, a[i]);
        }
    }
    for (int i = m; i >= 1; i--) {
        idx++, q[idx] = Data(idx, p[d[i]], d[i]);
    }
    CDQ(1, n);
    for (int i = 1; i <= n; i++) {
        ans[q[i].idx] = q[i].ans;
    }
    for (int i = 1; i <= n; i++) {
        ans[i] += ans[i - 1];
    }
    for (int i = n; i >= n - m + 1; i--) {
        printf("%lld\n", ans[i]);
    }
    return 0;
}
最后修改:2019 年 07 月 25 日
如果觉得我的文章对你有用,请随意赞赏