题目链接:LOJ 2483

有 $n$ 根柱子依次排列,每根柱子都有一个高度。第 $i$ 根柱子的高度为 $h_i$。

现在想要建造若干座桥,如果一座桥架在第 $i$ 根柱子和第 $j$ 根柱子之间,那么需要 $(h_i - h_j)^2$ 的代价。

在造桥前,所有用不到的柱子都会被拆除,因为他们会干扰造桥进程。第 $i$ 根柱子被拆除的代价为 $w_i$,注意 $w_i$ 不一定非负,因为可能政府希望拆除某些柱子。

现在政府想要知道,通过桥梁把第 $1$ 根柱子和第 $n$ 根柱子连接的最小代价。注意桥梁不能在端点以外的任何地方相交。

数据范围:$2 \le n \le 10 ^ 5$,$0 \le h_i, \vert w_i \vert \le 10 ^ 6$。


Solution

首先设 $S_i = \sum_{j = 1} ^ i w_j$,我们可以很容易地推出 $\text{DP}$ 方程:

$$ f(i) = \min_{j = 0} ^ {i - 1} f(j) + (h_i - h_j) ^ 2 + S_{i - 1} - S_j $$

注意到这个东西可以使用斜率优化,如果决策 $j$ 比决策 $k$ 优,那么有:

$$ f(j) + (h_i - h_j) ^ 2 - S_j \le f(k) + (h_i - h_k)^ 2 - S_k $$

我们设 $X(i) = h_i, Y(i) = f(i) +{h_i} ^ 2 - S_i$,那么上式化为:

$$ \frac{Y(j) - Y(k)}{X(j) - X(k)} \le 2\cdot h_i $$

但是这里的 $2\cdot h_i$ 不是单调的,我们需要对 $h_i$ 使用 $\text{CDQ}$ 分治。实现过程中需要特别注意 $X(j) = X(k)$ 的情况。

时间复杂度:$\mathcal O(n \log n)$。


Code

#include <cstdio>
#include <cstring>
#include <algorithm>

const int N = 1e5 + 5;

int n, h[N], w[N], pos[N], t[N], q[N];
long long s[N], x[N], y[N], f[N];

double slope(int i, int j) {
    if (x[i] == x[j]) return y[i] < y[j] ? 1e18 : -1e18;
    return 1.0 * (y[i] - y[j]) / (x[i] - x[j]);
}
void CDQ(int l, int r) {
    if (l == r) {
        int i = pos[l];
        x[i] = h[i];
        y[i] = f[i] - s[i] + 1LL * h[i] * h[i];
        return;
    }
    int m = (l + r) >> 1;
    std::stable_partition(pos + l, pos + r + 1, [m](int x) {
        return x <= m;
    });
    CDQ(l, m);
    int _l = 1, _r = 0;
    for (int j = l; j <= m; j++) {
        int i = pos[j];
        while (_l < _r && slope(q[_r - 1], q[_r]) >= slope(q[_r], i)) _r--;
        q[++_r] = i;
    }
    for (int j = m + 1; j <= r; j++) {
        int i = pos[j];
        while (_l < _r && slope(q[_l], q[_l + 1]) <= 2.0 * h[i]) _l++;
        int t = q[_l];
        f[i] = std::min(f[i], f[t] + s[i - 1] - s[t] + 1LL * (h[i] - h[t]) * (h[i] - h[t]));
    }
    CDQ(m + 1, r);
    std::merge(pos + l, pos + m + 1, pos + m + 1, pos + r + 1, t + l, [](int x, int y){
        return h[x] < h[y];
    });
    std::copy(t + l, t + r + 1, pos + l);
}
int main() {
    scanf("%d", &n);
    for (int i = 1; i <= n; i++) scanf("%d", &h[i]), pos[i] = i;
    for (int i = 1 ;i <= n; i++) scanf("%d", &w[i]), s[i] = s[i - 1] + w[i];
    std::sort(pos + 1, pos + n + 1, [](int x, int y) {
        return h[x] < h[y];
    });
    memset(f, 0x7f, sizeof(f));
    f[1] = 0;
    CDQ(1, n);
    printf("%lld\n", f[n]);
    return 0;
}
最后修改:2019 年 06 月 28 日
如果觉得我的文章对你有用,请随意赞赏