概述
所谓多项式求逆,就是给定一个 $n - 1$ 次多项式 $A(x)$,你需要求出一个多项式 $B(x)$ 满足 $A(x)B(x) \equiv 1\pmod{x ^ n}$。
思路
我们考虑一个子问题:假如已经求出了多项式 $A(x)$ 在模 $x ^ {\left\lceil\frac{n}{2}\right\rceil}$ 意义下的逆元 $B'(x)$,那么有:
$$ A(x)B'(x) \equiv 1\pmod{x ^ \left\lceil\frac{n}{2}\right\rceil} $$
又因为:
$$ A(x)B(x) \equiv 1\pmod{x ^ \left\lceil\frac{n}{2}\right\rceil} $$
将两式相减得到:
$$ A(x)[B(x) - B'(x)] \equiv 0\pmod{x ^ \left\lceil\frac{n}{2}\right\rceil} $$
由于 $A(x)\not\mid x ^ {\left\lceil\frac{n}{2}\right\rceil}$,那么我们可以把 $A(x)$ 除掉得到:
$$ B(x) - B'(x) \equiv 0\pmod{x ^ \left\lceil\frac{n}{2}\right\rceil} $$
接下来将等式两边平方得到:
$$ [B(x) - B'(x)] ^ 2 \equiv 0\pmod{x ^ \left\lceil\frac{n}{2}\right\rceil} $$
分析一下平方后的多项式有什么特点。设 $P(x) = B(x) - B'(x)$,那么对于 $P(x)$ 任意的 $i \in \left[0, \left\lceil\frac{n}{2}\right\rceil \right)$,第 $i$ 项的系数均为 $0$。考虑将其平方后,得到系数 $a'_i = \sum_{j = 0} ^ i a_j\times a_{i - j}$,对于任意的 $i \in \left[0, 2\times\left\lceil\frac{n}{2}\right\rceil\right)$,$i$ 或 $i - j$ 中必有一个值小于 $\left\lceil\frac{n}{2}\right\rceil$,那么 $a_i$ 和 $a_{i - j}$ 中必有一项值为 $0$。换言之,我们可以得到结论:$P(x) ^ 2$ 在模 $x ^ n$ 的意义下与 $0$ 同余。
$$ B ^ 2(x) +B' ^ 2(x) - 2 B(x) B'(x)\equiv 0 \pmod{x ^ n} $$
我们在两边同时乘上 $A(x)$ 得到:
$$ A(x)B ^ 2(x) + A(x)B' ^ 2(x) - 2A(x)B(x)B'(x) \equiv 0\pmod{x ^ n} $$
通过逆元的定义 $A(x)B(x)\equiv 1\pmod{x ^ n}$ 可以化简为:
$$ B(x) + A(x)B' ^ 2(x) - 2B'(x)\equiv 0\pmod{x ^ n} $$
移项得到:
$$ B(x) = 2B'(x) - A(x)B' ^ 2(x)\pmod{x ^ n} $$
时间复杂度:$\mathcal O(n \log n)$。
实现
实现主要由如下两种方式:
- 递归:通过实现部分的前几句话,我们就可以知道可以递归求解。递归边界为 $n = 1$,此时答案为常数项的逆元。
- 迭代:枚举迭代长度也可以求得答案。迭代实现的常数较小,但是细节较多。
考虑证明复杂度,$T(n) = T(\frac{n}{2}) + \mathcal O(n\log n)$。通过主定理可以得到复杂度为 $\mathcal O(n\log n)$。
代码
此处只给出迭代实现的代码。
Vec operator ~ (Vec A) {
int n = A.size(), N = extend(n);
A.resize(N);
Vec I(N, 0);
I[0] = inv(A[0]);
for (int l = 2; l <= N; l <<= 1) {
Vec P(l), Q(l);
std::copy(A.begin(), A.begin() + l, P.begin());
std::copy(I.begin(), I.begin() + l, Q.begin());
int L = l << 1;
P.resize(L), DFT(P);
Q.resize(L), DFT(Q);
for (int i = 0; i < L; i++) {
P[i] = 1LL * Q[i] * (2 - 1LL * P[i] * Q[i] % MOD + MOD) % MOD;
}
IDFT(P), P.resize(l);
std::copy(P.begin(), P.begin() + l, I.begin());
}
I.resize(n);
return I;
}