概述
众所周知,平面上 $n+1$ 个点可以确定一个 $n$ 次多项式 $P(x)$。现在已知 $n+1$ 个点 $(x_i,y_i)$,请确定经过这 $n+1$ 个点的多项式 $P(x)$ 在 $x=k$ 时的值。
拉格朗日插值可以解决这类问题!
实现
拉格朗日插值的经精妙之处就在于拉格朗日基本多项式。
我们设这 $n$ 个点为 $(x_0,y_0),(x_1,y_1),\dots,(x_n,y_n)$。那么我们构造多项式:
$$ \ell_i(k)=\prod_{j=0,i\neq j}^n \frac{k-x_j}{x_i-x_j} $$
这个多项式的构造十分巧妙,我们可以注意到如下规律:
$$ \ell_i(x_j)= \begin{cases} 1 & i=j \\ 0 & i\neq j \end{cases} $$
接着我们构造这个 $n$ 次多项式:
$$ P(x)=\sum_{i=0}^n y_i\ell_i(x) $$
根据基本多项式的性质,我们可以得到 $P(x_i)=y_i$,也就是说保证经过这 $n+1$ 个点,经过简单的计算就可以得到系数表达式,当然也很容易得到 $P(k)$ 的值了。
时间复杂度:$\mathcal O(n^2)$
扩展
取值连续
考虑在 $x$ 的值连续的情况下,如何快速求 $P(k)$ 的值?(这里的 $x$ 的值连续的意思是:对于 $0\le i\le n$,$x_i=i$)
那么式子化为:
$$ P(x)=\sum_{i=0}^n y_i\prod_{j=0,i\neq j}^n \frac{x-j}{i-j} $$
对于分子,我们维护 $pre_i=\prod_{j=0}^i k-j$,$suf_i=\prod_{j=i}^n k-j$;而对于分母,我们发现就是阶乘的形式!于是这个式子变成了:
$$ P(k)=\sum_{i=0}^n y_i\frac{pre_{i-1}\cdot suf_{i+1}}{i!\cdot (n-i)!} $$
但是注意:分母可能出现符号问题,也就是说说当 $n-i$ 为奇数时,分母应该取负号!
时间复杂度:$\mathcal O(n)$
经典应用
求解 $1$ 到 $n$ 的 $k$ 次方和。即如下式子:
$$ \sum_{i=1}^n i^k $$
数据范围:$1\le n\le 10^{18}$,$1\le k\le 10^6$
这个东西显然是一个以 $n$ 为自变量的 $k+1$ 次多项式。那么我们代入 $1\sim k+2$ 这一共 $k+2$ 个点后直接拉格朗日插值即可!
定理:若多项式 $P(x)$ 的次数为 $n$,那么对于任何 $k\ge n+1$,多项式 $P(x)$ 的 $k$ 阶差分恒等于 $0$。
我们预处理出 $pre_i$ 和 $suf_i$,在处理 $k+2$ 个点时可以采用线性筛求出每个数的 $k$ 次方。
时间复杂度:$\mathcal O(k)$
代码
我们用一个函数 $\text{lagrange}$ 来实现。其中 *x, *y
分别存点的坐标,函数的返回值为 $P(k)$ 的值(其中 $\text{mod}$ 为模数,$\text{inv}(x)$ 表示 $x$ 的逆元,$\text{add}(x,y)$ 表示将 $x$ 加上 $y$)。
常用写法
int lagrange(int n, int *x, int *y, int k) {
int ans = 0;
for (int i = 1; i <= n; i++) {
int f = 1, g = 1;
for (int j = 1; j <= n; j++) {
if (i == j) continue;
f = 1LL * f * (k - x[j] + mod) % mod;
g = 1LL * g * (x[i] - x[j] + mod) % mod;
}
add(ans, 1LL * y[i] * f % mod * inv(g) % mod);
}
return ans;
}
取值连续
int lagrange(int n, int *y, int k) {
pre[0] = suf[n + 1] = 1;
for (int i = 1; i <= n; i++) {
pre[i] = 1LL * pre[i - 1] * (k - i) % mod;
}
for (int i = n; i >= 1; i--) {
suf[i] = 1LL * suf[i + 1] * (k - i) % mod;
}
int ans = 0;
for (int i = 1; i <= n; i++) {
int a = 1LL * pre[i - 1] * suf[i + 1] % mod * inv[i - 1] % mod * inv[n - i] % mod;
if ((n - i) & 1) a = mod - a;
add(ans, 1LL * a * y[i] % mod);
}
return ans;
}