前言:一开始由于失误,误以为分数相加取模不能,但是其实是可以取模的
这个题目如果按照一般方法,到达每个节点再进行概率统计,但是不知道为什么只过了百分之十五的测试集
题目地址
附上没过关的代码
#include<bits/stdc++.h>
using namespace std;
#define int long long
int n; int ans = 0;
const int N = (int)2e6 + 5;
const int Mod = 998244353;
int e[N], ne[N], h[N / 2], idx = 0;
void add(int a, int b) {
e[++idx] = b, ne[idx] = h[a], h[a] = idx;
}
int qw(int x, int p) {
int temp = 1;
while (p) {
if(p&1)
temp = x * temp % Mod;
x = x * x % Mod;
p >>= 1;
}return temp;
}
void dfs(int u, int fa, int g, int step) {
int cnt = 0;
for (int i = h[u]; i; i = ne[i]) {
int v = e[i]; if (fa == v) continue;
cnt++;
}
if (cnt == 0) {
// 已经是子节点了
//ans = (ans + (step % Mod) * qw(g, Mod - 2)) % Mod; return;
ans = (ans + step*g%Mod) % Mod; return;
}
g = (g % Mod) * (qw(cnt, Mod - 2) % Mod) % Mod;
for (int i = h[u]; i; i = ne[i]) {
int v = e[i]; if (fa == v) continue;
dfs(v, u, g , step + 1);
}
}
signed main() {
cin >> n;
for(int i=1;i<n;i++){
int u,v; cin >> u >> v;
add(u,v),add(v,u);
}
if(n==1){
cout << 0 ; return 0;
}
dfs(1,0,1,0);
cout << ans;
return 0;
}
再写一个过关的,按照官方答案的解法的
#include<bits/stdc++.h>
using namespace std;
#define int long long
int n; int ans = 0;
const int N = (int)2e6 + 5;
const int Mod = 998244353;
const int P = 998244353;
int e[N], ne[N], h[N / 2], idx = 0;
vector<int> a[N / 2];
int siz[N], ye[N]; // 记录每一层的节点个数以及叶子节点的个数
void add(int a, int b) {
e[++idx] = b, ne[idx] = h[a], h[a] = idx;
}
int qw(int x, int p) {
int temp = 1;
while (p) {
if (p & 1)
temp = x * temp % Mod;
x = x * x % Mod;
p >>= 1;
}return temp;
}
void dfs(int u, int fa, int dep) {
int cnt = 0; siz[dep]++;
for (int i = h[u]; i; i = ne[i]) {
int to = e[i]; if (to == fa) continue;
cnt++; dfs(to, u, dep + 1);
}
if (cnt == 0) {
ye[dep]++;
}
}
void solve() {
int pre = 1; // 概率
for (int i = 1; i < n; i++) {
//cout << " siz " << i << " " << ye[i] << endl;
if (siz[i] == 0) break;
//ans = (ans+(pre*(ye[i]*(qw(siz[i],Mod-2),Mod-2)%Mod)%Mod) * (i)%Mod) % Mod;
ans = (ans + pre * ye[i] % P * qw(siz[i], P - 2) % P * (i) % P) % P;
pre = pre * ((siz[i] - ye[i]) * (qw(siz[i], Mod - 2)) % Mod)%Mod;
//pre = pre * (((siz[i] - ye[i]) % P + P) % P) % P * qw(siz[i], P - 2) % P;
}
cout << ans; return;
}
signed main() {
cin >> n;
for (int i = 1; i < n; i++) {
int u, v; cin >> u >> v;
add(u, v), add(v, u);
//a[u].push_back(v); a[v].push_back(u);
}
if (n == 1) {
cout << 0; return 0;
}
dfs(1, 0, 0);
solve();
return 0;
}