https://atcoder.jp/contests/arc165/tasks/arc165_e
考虑一个常见套路,我们对每个连通块统计其概率,设为 p ( T ) p(T) p(T),则答案为
∑ ∣ T ∣ > k p ( T ) \sum_{|T|>k} p(T) ∣T∣>k∑p(T)
可以想成对于每个大小大于 k k k 的连通块,都要至少选一个点,也就是进行一次操作。所以我们就统计所以这么大的连通块存在的概率,乘上1,就是其期望。总期望就是单独的期望之和。
然后一些显然的性质,一个连通块必然是树上的一个子树弄走它的一堆子树。而且要形成这样一个连通块,连通块内的点一定没有被选,而且和连通块相邻的所有点一定全部被选过。
那么我们现在就发现决定一个连通块贡献的性质。有了这个性质,我们就可以合并等价类了。
在合并之前,我们先考虑 n n n 个点的块, m m m 个点相连的概率。这 m m m 个点必然早于这 n n n 个点先被选。
但我们发现操作的先后可能影响一个点是否被选。
有一种常见套路,就是我们直接枚举操作顺序,这个操作顺序代表中如果某个点所在连通块大小小于等于 k k k,我们可以视为跳过这个操作。所以对于任意 n n n 个点的操作方案为 n ! n! n!
所以 n n n 个点的块, m m m 个点相连的概率为 n ! m ! ( n + m ) ! \frac{n!m!}{(n+m)!} (n+m)!n!m!
这一步也可以操作官方题解的转化来理解:
然后回到前面那个问题,如何计算合法 ( n , m ) (n,m) (n,m) 对呢?
因为这是一棵树,直接树形dp就行。数据范围就100,暴力转移即可。
复杂度 O ( n 4 ) O(n^4) O(n4)
#include<bits/stdc++.h>
using namespace std;
#define int long long
inline int read(){int x=0,f=1;char ch=getchar(); while(ch<'0'||
ch>'9'){if(ch=='-')f=-1;ch=getchar();}while(ch>='0'&&ch<='9'){
x=(x<<1)+(x<<3)+(ch^48);ch=getchar();}return x*f;}
#define Z(x) (x)*(x)
#define pb push_back
//mt19937 rand(time(0));
//mt19937_64 rand(time(0));
//srand(time(0));
#define N 210
//#define M
#define mo 998244353
int pw(int a, int b) {
int ans=1;
while(b) {
if(b&1) ans*=a;
b>>=1; a*=a;
ans%=mo; a%=mo;
}
return ans;
}
int n, m, i, j, k, T;
int dp[N][N][N], f[N][N], w[N];
int u, v, x, fac[N], inv[N], ans;
vector<int>G[N];
void Mod(int &a) {
a=(a%mo+mo)%mo;
}
void dfs(int x, int fa) {
dp[x][1][0]=w[x]=1;
int a, b, c, d;
for(int y : G[x]) {
if(y==fa) continue;
dfs(y, x);
for(a=0; a<=w[x]+w[y]; ++a)
for(b=0; b<=w[x]+w[y]; ++b)
f[a][b]=dp[x][a][b], dp[x][a][b]=0;
for(a=0; a<=w[x]; ++a)
for(b=0; b<=w[x]; ++b) {
dp[x][a][b+1]+=f[a][b]; Mod(dp[x][a][b+1]);
for(c=0; c<=w[y]; ++c)
for(d=0; d<=w[y]; ++d) {
dp[x][a+c][b+d]+=f[a][b]*dp[y][c][d]%mo;
Mod(dp[x][a+c][b+d]);
}
}
w[x]+=w[y];
}
}
signed main()
{
// freopen("in.txt", "r", stdin);
// freopen("out.txt", "w", stdout);
// T=read();
// while(T--) {
//
// }
n=read(); m=read();
for(i=fac[0]=1; i<=2*n+5; ++i) fac[i]=fac[i-1]*i%mo;
inv[2*n+5]=pw(fac[2*n+5], mo-2);
for(i=2*n+4; i>=0; --i) inv[i]=inv[i+1]*(i+1)%mo;
for(i=1; i<n; ++i) {
u=read(); v=read();
G[u].pb(v); G[v].pb(u);
}
dfs(1, 0);
for(x=1; x<=n; ++x)
for(i=m+1; i<=n; ++i)
for(j=0; j<=n; ++j) {
u=i; v=j+(x!=1);
if(dp[x][i][j])
ans+=fac[u]*fac[v]%mo*inv[u+v]%mo*dp[x][i][j]%mo;
Mod(ans);
}
printf("%lld", ans);
return 0;
}