题目链接:UOJ 424
如果一个序列满足序列长度为 $n$,序列中的每个数都是 $1$ 到 $m$ 内的整数,且所有 $1$ 到 $m$ 内的整数都在序列中出现过,则称这是一个「挺好序列」。
对于一个序列 $A$,记 $f_{A}(l, r)$ 为 $A$ 的第 $l$ 个到第 $r$ 个数中最大值的下标(如果有多个最大值,取下标最小的)。
两个序列 $A$ 和 $B$ 同构,当且仅当 $A$ 和 $B$ 长度相等,且对于任意 $i \le j$,均有 $f_{A}(i, j) = f_{B}(i, j)$。
给出 $n, m$,求有多少种不同构的「挺好序列」。答案对 $998244353$ 取模。
数据范围:$1 \le n, m \le 10 ^ {5}$。
Solution
首先,$n < m$ 一定无解。
如果我们对序列建笛卡尔树,那么可以将「序列同构」转化为「笛卡尔树同构」。
接下来证明一个结论:一棵笛卡尔树满足条件(长度为 $n$ 的序列中 $1 \sim m$ 都至少出现一次)的充要条件是所有左链长度(一条路径中左儿子的个数)均小于 $m$。
- 必要性:由于至多只有 $m$ 种数,且每次都取最左侧的最大值作为根节点,则左儿子一定严格小于父节点。左链长度大于等于 $m$ 则意味着数字种数大于 $m$。
- 充分性:假如我们有一棵所有左链长度均小于 $m$ 的二叉树。可以通过如下方法构造序列:将最长链上的点依次标记为 $m, m - 1, \ldots$,剩下的点按照深度从小到大标记即可。
于是,原问题转化为左链长度小于 $m$ 的二叉树计数。左链长度又可以转化为括号序列中的前缀和,故问题转化为前缀和不超过 $m$ 的合法括号序列计数。
根据经典问题:平面上每次能走 $(1, 1)$ 或 $(1, -1)$,从 $(0, 0)$ 走到 $(2n, 0)$ 且不经过 $y = -1$ 的路径方案数——再加上一个 $y = m + 1$ 的限制就是本问题。
可以使用容斥计算方案数。具体地,我们令这条路径先经过若干次 $y = -1$ 再若干次 $y = m + 1$ 再若干次 $y = -1$ 再若干次 $y = m + 1$……对起点坐标不断对称后计算方案数,容斥系数为 $(-1) ^ {\text{对称次数}}$(当然 $m + 1, -1, m + 1, -1, \ldots$ 也要计算一遍)。
对于每种不合法的方案,容易发现最终一定会被计算 $0$ 次。
时间复杂度:$\mathcal O(n)$。
Code
#include <bits/stdc++.h>
typedef long long int64;
const int N = 2e5, MOD = 998244353;
int n, m, fac[N + 5], ifac[N + 5];
inline int power(int x, int k) {
int ans = 1;
for (; k > 0; k >>= 1, x = (int64)x * x % MOD) {
if (k & 1) {
ans = (int64)ans * x % MOD;
}
}
return ans;
}
inline int inver(int x) {
return power(x, MOD - 2);
}
void init(int n) {
fac[0] = 1;
for (int i = 1; i <= n; i++) {
fac[i] = (int64)fac[i - 1] * i % MOD;
}
ifac[n] = inver(fac[n]);
for (int i = n; i >= 1; i--) {
ifac[i - 1] = (int64)ifac[i] * i % MOD;
}
}
inline int binom(int n, int m) {
return (int64)fac[n] * ifac[m] % MOD * ifac[n - m] % MOD;
}
inline int calc(int x) {
return binom(n * 2, (n * 2 - x) / 2);
}
int main() {
scanf("%d%d", &n, &m);
init(n * 2);
if (m > n) {
printf("0\n");
return 0;
}
int64 ans = calc(0);
for (int x = 0, i = 1; i <= n; i++) {
x = (i & 1) ? (m + 1) * 2 - x : -2 - x;
if (std::abs(x) > n * 2) {
break;
}
(i & 1) ? (ans -= calc(x)) : (ans += calc(x));
}
for (int x = 0, i = 1; i <= n; i++) {
x = (i & 1) ? -2 - x : (m + 1) * 2 - x;
if (std::abs(x) > n * 2) {
break;
}
(i & 1) ? (ans -= calc(x)) : (ans += calc(x));
}
printf("%lld\n", (ans % MOD + MOD) % MOD);
return 0;
}