题目
给出一棵有根树,问有多少个拓扑序满足 p i = i p_i=i pi=i 。 n ≤ 5000 n\leq 5000 n≤5000
样例
input1:
4
1 1 2
output1:
3 2 1 2
input2:
9
1 1 2 2 3 3 4 4 5
output2:
672 420 180 160 152 108 120 170 210
题解
考虑设 f[x][i]
表示还未在序列中插入 x
子树内除了 x
外的其他节点,满足
p
x
=
i
p_x=i
px=i 的拓扑序个数。
那么对于某一个点 x
,答案就是 f[x][x]*C(n-x,sz[x]-1)*g[x]
。
其中 C(n-x,sz[x]-1)
表示 sz[x]-1
个元素只能放在位置 x
的后面。
g[x]
表示的是 x
的子树的拓扑序个数,这个可以一遍 dfs
轻松得到。
考虑 f[x][i]
到 f[y][j]
的转移:
发现除了 x
的子树外,其他所有 n-sz[x]+1
个点(包括 x
)的拓扑序已经确定好了,现在要转移到 f[y][j]
,所以我们要先将 y
所在的子树剔除掉(n-sz[y]
),然后再将 x
的其余旁支加进 x
的拓扑序后面(注意是位置 i
,所以总共的位置是 n-sz[y]-i
)。
那么转移系数就是
(
n
−
s
z
[
y
]
−
i
s
z
[
y
1
]
,
s
z
[
y
2
]
,
.
.
.
,
s
z
[
y
k
−
1
]
,
n
−
s
z
[
y
]
−
i
−
∑
i
=
1
k
−
1
s
z
[
y
i
]
)
\binom{n-sz[y]-i}{sz[y_1],sz[y_2],...,sz[y_{k-1}],n-sz[y]-i-\sum_{i=1}^{k-1}sz[y_i]}
(sz[y1],sz[y2],...,sz[yk−1],n−sz[y]−i−∑i=1k−1sz[yi]n−sz[y]−i)
其中
y
1
,
.
.
.
y
k
−
1
y_1,...y_{k-1}
y1,...yk−1 表示 x
节点的其他儿子。
其实最后的一个式子恰好就是 x
祖先的其他旁支在位置 i
之后的点数。
然后再将 y
节点自然插入到第 j
个位置,这里并不需要乘任何系数,当
j
>
i
j>i
j>i 就可以转移,所以使用前缀和优化即可。
#include<bits/stdc++.h>
using namespace std;
const int N=5010,mod=998244353;
int T,fac[N],inv[N],n,sz[N],fin[N],f[N][N],g[N];
int ans[N];
vector<int> V[N];
int ksm(int x,int t){
int tot=1;
while(t){
if(t&1) tot=1ll*tot*x%mod;
x=1ll*x*x%mod;
t>>=1;
}
return tot;
}
int C(int x,int y){
return 1ll*fac[x]*inv[y]%mod*inv[x-y]%mod;
}
void dfs(int x){
sz[x]=0;g[x]=1;
for(auto y:V[x]){
dfs(y);
sz[x]+=sz[y];
g[x]=1ll*g[x]*g[y]%mod*C(sz[x],sz[y])%mod;
}
sz[x]++;
}
void gs(int x){
// printf("%d:\n",x);
// for(int i=1;i<=n;i++) printf("%d ",f[x][i]);printf("\n");
ans[x]=1ll*f[x][x]*C(n-x,sz[x]-1)%mod*g[x]%mod;
int prod=1;
for(auto y:V[x]) prod=1ll*prod*inv[sz[y]]%mod*g[y]%mod;
for(auto y:V[x]){
prod=1ll*prod*fac[sz[y]]%mod*ksm(g[y],mod-2)%mod;
for(int i=1;i<=n-sz[x]+1;i++){
f[y][i+1]=(f[y][i]+1ll*f[x][i]*fac[n-sz[y]-i]%mod*prod%mod*inv[n-i-sz[x]+1])%mod;
// printf("%d\n",f[y][i+1]);
}
for(int i=n-sz[x]+2;i<n;i++) f[y][i+1]=f[y][i];
gs(y);
prod=1ll*prod*inv[sz[y]]%mod*g[y]%mod;
}
}
int main(){
//scanf("%d",&T);
scanf("%d",&n);
fac[0]=1;for(int i=1;i<=n;i++) fac[i]=1ll*fac[i-1]*i%mod;
inv[n]=ksm(fac[n],mod-2);for(int i=n-1;i>=0;i--) inv[i]=1ll*inv[i+1]*(i+1)%mod;
for(int i=1;i<=n;i++) fin[i]=1ll*fac[i]*inv[i-1]%mod;
int x;
for(int i=1;i<=n;i++) V[i].resize(0);
for(int i=2;i<=n;i++) scanf("%d",&x),V[x].push_back(i);
f[1][1]=1;
dfs(1);
gs(1);
for(int i=1;i<=n;i++) printf("%d ",ans[i]);printf("\n");
}