题目链接:Codeforces 280D
你有一个长度为 $n$ 的序列 $a_i$,接下来进行 $m$ 次操作,操作分为如下 $2$ 种:
0 i val
:将第 $i$ 个数 $a_i$ 修改为 $val$。1 l r k
:你需要在序列 $a_l, a_{l + 1}, \cdots, a_r$ 中找出至多 $k$ 个不相交的子序列,使得他们的和最大。形式化地,你需要找出至多 $k$ 对 $(x_1, y_1), (x_2, y_2), \cdots, (x_t, y_t)$(其中 $l\le x_1\le y_1<x_2\le y_2<\cdots<x_t \le y_t\le r$,$0\le t\le k$),使得 $(a_{x_1} + a_{x_1 + 1} + \cdots + a_{y_1}) + (a_{x_2} + a_{x_ 2 + 1} + \cdots + a_{y_2})+\cdots + (a_{x_t} + a_{x_t + 1} + \cdots + a_{y_t})$ 的值最大。特别地,你可以选择 $0$ 个子序列,这时和式等于 $0$。
数据范围:$1 \le n, m\le 10 ^ 5$,$\vert a_i, val \vert \le 500$,$1\le k\le 20$,求 $k$ 个子序列和的操作不超过 $10^4$ 个。
Solution
我们很容易建立起费用流的模型。
- 源点像每个点连流量为 $1$、费用为 $0$ 的边。
- 每个点向下一个点连流量 $1$,费用为 $a_i$ 的边。
- 每个点向汇点连流量为 $1$、费用为 $0$ 的边。
可以发现多流一个单位的流量就会多出一个区间。那么问题就变成使用不超过 $k$ 个单位的流量,能够得到的最大费用。
但是直接上费用流肯定 $\text{TLE}$,但是注意到每次増广的贡献是一段区间,我们考虑用线段树维护来模拟费用流。
我们把増广的过程转化为线段树能处理的问题:
- 每次増广相当于询问整个区间的最大子段和。
- 每次更新反向弧的费用相当于将区间取反。
- 増广 $k$ 次相当于在线段树上进行上述过程 $k$ 次。
这样一来我们就可以直接在线段树上操作了!
最后考虑一下代码实现问题!由于我们要实现区间最大子段和及其范围、区间取反,因此需要维护的不止左右端点最大值、区间最大答案……对于每个最大值,还需要记录最小值,这样才能实现区间取反的问题。
为了使代码难度降低,我们可以定义一个 $\text{struct}$ 记录每个值的 $l,r,val$ 等信息,重载运算符来合并区间。
时间复杂度:$\mathcal O(mk\log n)$ 且带有大常数。
Code
#include <cstdio>
#include <algorithm>
#define lson p << 1
#define rson p << 1 | 1
const int N = 1e5 + 5;
int n, m, A[N];
struct Data {
int l, r, val;
Data(int _l = 0, int _r = 0, int _val = 0) {
l = _l, r = _r, val = _val;
}
Data operator+(const Data &b) const {
return Data(l, b.r, val + b.val);
}
bool operator<(const Data &b) const {
return val < b.val;
}
};
struct Node {
int rev;
Data sum, lmx, lmn, rmx, rmn, smx, smn;
Node() {
rev = 0;
sum = lmx = lmn = rmx = rmn = smx = smn = Data();
}
void init(int pos, int val) {
rev = 0;
sum = lmx = lmn = rmx = rmn = smx = smn = Data(pos, pos, val);
}
void reverse() {
rev ^= 1;
std::swap(lmx, lmn);
std::swap(rmx, rmn);
std::swap(smx, smn);
sum.val *= -1;
lmx.val *= -1, lmn.val *= -1;
rmx.val *= -1, rmn.val *= -1;
smx.val *= -1, smn.val *= -1;
}
};
struct Segment {
Node a[N << 2];
Node merge(Node x, Node y) {
Node ans;
ans.sum = x.sum + y.sum;
ans.lmx = std::max(x.lmx, x.sum + y.lmx);
ans.lmn = std::min(x.lmn, x.sum + y.lmn);
ans.rmx = std::max(y.rmx, x.rmx + y.sum);
ans.rmn = std::min(y.rmn, x.rmn + y.sum);
ans.smx = std::max(x.rmx + y.lmx, std::max(x.smx, y.smx));
ans.smn = std::min(x.rmn + y.lmn, std::min(x.smn, y.smn));
return ans;
}
void pushdown(int p) {
if (a[p].rev) {
a[lson].reverse(); a[rson].reverse();
a[p].rev = 0;
}
}
void build(int p, int l, int r) {
if (l == r) {
a[p].init(l, A[l]);
return;
}
int mid = (l + r) >> 1;
build(lson, l, mid);
build(rson, mid + 1, r);
a[p] = merge(a[lson], a[rson]);
}
void modify(int p, int l, int r, int x, int v) {
if (l == r) {
a[p].init(l, v);
return;
}
pushdown(p);
int mid = (l + r) >> 1;
if (x <= mid) {
modify(lson, l, mid, x, v);
} else {
modify(rson, mid + 1, r, x, v);
}
a[p] = merge(a[lson], a[rson]);
}
void reverse(int p, int l, int r, int x, int y) {
if (l == x && y == r) {
a[p].reverse();
return;
}
pushdown(p);
int mid = (l + r) >> 1;
if (y <= mid) {
reverse(lson, l, mid, x, y);
} else if(x > mid) {
reverse(rson, mid + 1, r, x, y);
} else {
reverse(lson, l, mid, x, mid);
reverse(rson, mid + 1, r, mid + 1, y);
}
a[p] = merge(a[lson], a[rson]);
}
Node query(int p, int l, int r, int x, int y) {
if (l == x && y == r) {
return a[p];
}
pushdown(p);
int mid = (l + r) >> 1;
if (y <= mid) {
return query(lson, l, mid, x, y);
} else if(x > mid) {
return query(rson, mid + 1, r, x, y);
} else {
return merge(query(lson, l, mid, x, mid), query(rson, mid + 1, r, mid + 1, y));
}
}
} seg;
int main() {
scanf("%d", &n);
for (int i = 1; i <= n; i++) {
scanf("%d", &A[i]);
}
seg.build(1, 1, n);
scanf("%d", &m);
for (int i = 1; i <= m; i++) {
int opt;
scanf("%d", &opt);
if (!opt) {
int x, v;
scanf("%d%d", &x, &v);
seg.modify(1, 1, n, x, v);
} else {
int l, r, k;
scanf("%d%d%d", &l, &r, &k);
int ans = 0, tp = 0;
Data st[25];
for (int j = 1; j <= k; j++) {
Node now = seg.query(1, 1, n, l, r);
if (now.smx.val < 0) {
break;
}
ans += now.smx.val;
st[++tp] = now.smx;
seg.reverse(1, 1, n, now.smx.l, now.smx.r);
}
for (int j = 1; j <= tp; j++) {
seg.reverse(1, 1, n, st[j].l, st[j].r);
}
printf("%d\n", ans);
}
}
return 0;
}