题目链接:Codeforces 593D
今天是 Bogdan 的生日,他的母亲送给他一棵有 $n$ 个节点的树,每条边上有一个数字 $x_i$。有 $m$ 个客人参加了 Bogdan 的生日排队。第 $i$ 个客人到达后,他会进行如下 $2$ 种操作中的恰好一种:
- 选择一个数字 $y_i$ 和两个节点 $a_i,b_i$。接下来他沿着 $a_i$ 到 $b_i$ 的最短路径行走。每经过一条边 $j$,他就把当前的数字 $y_i$ 替换成 $\left\lfloor\frac{y_i}{x_j}\right\rfloor$。最后求出 $y_i$ 的值。
- 选择一条边 $p_i$,将这条边上的值 $x_{p_i}$ 替换成 $c_i$。其中 $c_i<x_{p_i}$。
由于 Bogdan 非常好客,他希望编写一个程序执行所有的操作,并对每个操作 $1$ 求出其结果 $y_i$ 的值。
数据范围:$2\le n \le 2\times 10^5$,$1\le m\le 2\times 10^5$,$1\le x_i,y_i\le 10^{18}$,$1\le c_i<x_{p_i}$。
Solution
显然用树链剖分是可以做的,但是鉴于其码量大、常数大,以及没有利用本题的性质,再此不赘述了。接下来将介绍一种好写、复杂度优越的做法。
首先我们发现,如果不考虑边权为 $1$ 的边,我们至多只会进行 $\mathcal O(\log y_i)$ 次除法。换言之,如果这棵树上没有权值为 $1$ 的边,我们可以暴力往上跳求出最后的 $y_i$。
又因为修改是单调下降的,于是一条边的边权只可能变小。对于一条权值为 $1$ 的边,我们把他两侧的点缩起来,使用并查集即可实现。但是需要注意合并的方向:一定是深度大的点合并到深度小的点上,否则复杂度是错的。
时间复杂度:$\mathcal O(m\log y_i)$。
Code
#include <cstdio>
#include <algorithm>
const int N = 2e5 + 5, M = 4e5 + 5;
int n, m, tot, lnk[N], sta[M], ter[M], nxt[M], up[N], fa[N], f[N], dep[N];
long long val[M];
void add(int u, int v, long long w) {
ter[++tot] = v, sta[tot] = u, nxt[tot] = lnk[u], lnk[u] = tot, val[tot] = w;
}
int find(int x) {
return f[x] == x ? x : f[x] = find(f[x]);
}
void merge(int u, int v) {
f[find(u)] = find(v);
}
void dfs(int u, int p) {
for (int i = lnk[u]; i; i = nxt[i]) {
int v = ter[i];
if (v == p) continue;
fa[v] = u, up[v] = i, dep[v] = dep[u] + 1;
if (val[i] == 1) {
merge(v, u);
dep[v] = dep[find(v)];
}
dfs(v, u);
}
}
long long query(int u, int v) {
long long ans = 1;
for (; (u = find(u)) != (v = find(v)); u = fa[u]) {
if (dep[u] < dep[v]) {
std::swap(u, v);
}
if (ans > 1e18 / val[up[u]]) {
return 0;
}
ans *= val[up[u]];
}
return ans;
}
int main() {
scanf("%d%d", &n, &m);
for (int i = 1; i < n; i++) {
int u, v;
long long w;
scanf("%d%d%lld", &u, &v, &w);
add(u, v, w), add(v, u, w);
}
for (int i = 1; i <= n; i++) {
f[i] = i;
}
dep[1] = 1;
dfs(1, 0);
for (int i = 1; i <= m; i++) {
int opt;
scanf("%d", &opt);
if (opt == 1) {
int u, v;
long long w;
scanf("%d%d%lld", &u, &v, &w);
long long ans = query(u, v);
printf("%lld\n", ans ? w / ans : 0);
} else {
int x;
long long w;
scanf("%d%lld", &x, &w);
x <<= 1;
val[x - 1] = val[x] = w;
if (w == 1) {
int u = find(sta[x]), v = find(ter[x]);
if (dep[u] > dep[v]) {
std::swap(u, v);
}
merge(v, u);
dep[v] = dep[find(v)];
}
}
}
return 0;
}