题目链接:BZOJ 3110

有 $n$ 个位置,$m$ 个操作。操作分为以下 $2$ 种:

  • 1 a b c:表示在第 $a$ 个位置到第 $b$ 个位置,每个位置加入一个数 $c$。
  • 2 a b c:表示询问从第 $a$ 个位置到第 $b$ 个位置,第 $c$ 大的数是多少。

数据范围:$1\le n,m\le 5\times 10^4$,操作 $1$ 中 $0\le \vert c \vert \le 2 ^ {63} - 1$。


Solution

我们考虑整体二分。用树状数组维护区间和——树状数组是可以支持区间修改、区间查询的,并且复杂度还是 $\mathcal O(\log n)$ 的(众所周知应该是 $\mathcal O(1)$ 的)。然后按照套个板子就做完了。

时间复杂度:$\mathcal O(n\log^2 n)$。


Code

#include <cstdio>
#include <algorithm>

typedef long long LL;

const int N = 2e5 + 5;
const LL INF = 0x3f3f3f3f3f3f3f3f;

int n, m;
LL ans[N];

struct Data {
    int opt, x, y, idx;
    LL k;
} q[N], q1[N], q2[N];

struct BIT {
    LL c1[N], c2[N];
    void modify(int x, LL v) {
        int _x = x;
        for (; x <= n; x += x & -x) {
            c1[x] += v, c2[x] += 1LL * _x * v;
        }
    }
    LL query(int x) {
        int _x = x;
        LL ans = 0;
        for (; x; x ^= x & -x) {
            ans += 1LL * (_x + 1) * c1[x] - c2[x];
        }
        return ans;
    }
} bit;

void solve(LL l, LL r, int L, int R) {
    if (l > r || L > R) {
        return;
    }
    if (l == r) {
        for (int i = L; i <= R; i++) {
            if (q[i].opt == 2) ans[q[i].idx] = l;
        }
        return;
    }
    LL mid = (l + r) >> 1;
    int t1 = 0, t2 = 0;
    for (int i = L; i <= R; i++) {
        if (q[i].opt == 1) {
            if (q[i].k > mid) {
                q2[++t2] = q[i];
                bit.modify(q[i].x, 1);
                bit.modify(q[i].y + 1, -1);
            } else {
                q1[++t1] = q[i];
            }
        } else {
            LL sum = bit.query(q[i].y) - bit.query(q[i].x - 1);
            if (sum >= q[i].k) {
                q2[++t2] = q[i];
            } else {
                q[i].k -= sum;
                q1[++t1] = q[i];
            }
        }
    }
    for (int i = 1; i <= t2; i++) {
        if (q2[i].opt == 1) {
            bit.modify(q2[i].x, -1);
            bit.modify(q2[i].y + 1, 1);
        }
    }
    std::copy(q1 + 1, q1 + t1 + 1, q + L);
    std::copy(q2 + 1, q2 + t2 + 1, q + L + t1);
    solve(l, mid, L, L + t1 - 1);
    solve(mid + 1, r, L + t1, R);
}
int main() {
    scanf("%d%d", &n, &m);
    int cnt = 0;
    for (int i = 1; i <= m; i++) {
        scanf("%d%d%d%lld", &q[i].opt, &q[i].x, &q[i].y, &q[i].k);
        if (q[i].opt == 2) q[i].idx = ++cnt;
    }
    solve(-INF, INF, 1, m);
    for (int i = 1; i <= cnt; i++) {
        printf("%lld\n", ans[i]);
    }
    return 0;
}
最后修改:2019 年 06 月 28 日
如果觉得我的文章对你有用,请随意赞赏