题目大意
有一棵 n n n个顶点的树,这棵树上长度为 d d d的简单路径的价值为 2 d 2^d 2d。
有 q q q次询问,每次给出两个正整数 x , y x,y x,y,请你回答所有通过顶点 x x x和 y y y的简单路径的价值之和,输出答案模 998244353 998244353 998244353后的值。
1 ≤ n , q ≤ 1 0 6 , x ≠ y 1\leq n,q\leq 10^6,x\neq y 1≤n,q≤106,x=y
时间限制 1500 m s 1500ms 1500ms,空间限制 512 M B 512MB 512MB。
题解
对于每组询问 x , y x,y x,y,设 x x x到 y y y的路径长度为 d d d,则路径 x , y x,y x,y的价值为 2 d 2^d 2d。设 x x x不经过路径 x − y x-y x−y可以到达的部分为 A A A, y y y不经过路径 x − y x-y x−y可以到达的部分为 B B B。
设两个点 u u u到 v v v的距离为 d i s ( u , v ) dis(u,v) dis(u,v),那么,询问 x , y x,y x,y的答案为:
a n s = ∑ u ∈ A ∑ v ∈ B 2 d + d i s ( u , x ) + d i s ( v , y ) = 2 d × ( ∑ u ∈ A 2 d i s ( u , x ) ) × ( ∑ v ∈ B 2 d i s ( v , y ) ) ans=\sum\limits_{u\in A}\sum\limits_{v\in B}2^{d+dis(u,x)+dis(v,y)}=2^d\times (\sum\limits_{u\in A}2^{dis(u,x)})\times (\sum\limits_{v\in B}2^{dis(v,y)}) ans=u∈A∑v∈B∑2d+dis(u,x)+dis(v,y)=2d×(u∈A∑2dis(u,x))×(v∈B∑2dis(v,y))
也就是说,我们只需要求出 ∑ u ∈ A 2 d i s ( u , x ) \sum\limits_{u\in A}2^{dis(u,x)} u∈A∑2dis(u,x)和 ∑ v ∈ B 2 d i s ( v , y ) \sum\limits_{v\in B}2^{dis(v,y)} v∈B∑2dis(v,y)即可。
设 s u m x sum_x sumx表示点 x x x的子树中的每个点到 x x x的距离的二次幂之和, v s x vs_x vsx表示不在 x x x的子树中的点到 x x x的距离的二次幂之和,用一个换根 D P DP DP,做两次 d f s dfs dfs即可。
然后,求出每个询问中 x , y x,y x,y的 l c a lca lca。如果 x ≠ l c a x\neq lca x=lca且 y ≠ l c a y\neq lca y=lca,则答案为 2 d × s u m x × s u m y 2^d\times sum_x\times sum_y 2d×sumx×sumy。否则,不妨设 x = l c a x=lca x=lca,则从 y y y往上跳到 d e p x + 1 dep_x+1 depx+1的位置 t t t( d e p x dep_x depx表示 x x x的深度),则答案为 2 d × s u m y × ( v s x + s u m x − 2 × s u m t ) 2^d\times sum_y\times (vs_x+sum_x-2\times sum_t) 2d×sumy×(vsx+sumx−2×sumt)。
用倍增跳 l c a lca lca或 l c a lca lca下面的点,时间复杂度为 O ( n log n ) O(n\log n) O(nlogn)。因为 n n n的范围比较大,这题常数也大,时间限制还比较小,所以我们考虑优化。
用 t a r j a n tarjan tarjan求 l c a lca lca可以将时间复杂度平摊到 O ( n ) O(n) O(n),但求 l c a lca lca下面的点怎么处理呢?我们可以做一次 d f s dfs dfs,记录每个点 i i i当前遍历的是它的哪个儿子的子树,记这个儿子为 t o i to_i toi。那么,对于每一对 x , y x,y x,y,如果 x x x为 l c a lca lca,则 x x x一定为 y y y的祖先,那么遍历到 y y y时, t o x to_x tox就是 y y y往上跳到 x x x下面一个位置的的点 t t t。这样做的话,时间复杂度平摊下来也是 O ( n ) O(n) O(n)的。
总时间复杂度为 O ( n ) O(n) O(n)。
code
#include<bits/stdc++.h>
#define rg register
using namespace std;
const int N=1000000;
const long long mod=998244353;
int n,q,tot=0,d[2*N+5],l[2*N+5],r[N+5],z[N+5],dep[N+5],fa[N+5],to[N+5],lca[N+5];
int tot1=0,d1[2*N+5],l1[2*N+5],r1[N+5],id1[2*N+5];
int tot2=0,d2[N+5],l2[N+5],r2[N+5],id2[N+5];
long long mi[N+5],sum[N+5],vs[N+5],ans[N+5];
struct que{
int x,y;
}v[N+5];
inline int in(){
int t=0;
char ch=getchar();
while(ch<'0'||ch>'9') ch=getchar();
while(ch>='0'&&ch<='9'){
t=(t<<3)+(t<<1)+(ch^48);
ch=getchar();
}
return t;
}
inline void add(int xx,int yy){
l[++tot]=r[xx];d[tot]=yy;r[xx]=tot;
}
inline void add1(int xx,int yy,int zz){
l1[++tot1]=r1[xx];d1[tot1]=yy;r1[xx]=tot1;id1[tot1]=zz;
}
inline void add2(int xx,int yy,int zz){
l2[++tot2]=r2[xx];d2[tot2]=yy;r2[xx]=tot2;id2[tot2]=zz;
}
inline void dfs1(int u,int f){
sum[u]=1;
dep[u]=dep[f]+1;
for(rg int i=r[u];i;i=l[i]){
if(d[i]==f) continue;
dfs1(d[i],u);
sum[u]=(sum[u]+2*sum[d[i]])%mod;
}
}
inline void dfs2(int u,int f){
for(rg int i=r[u];i;i=l[i]){
if(d[i]==f) continue;
vs[d[i]]=(vs[u]+sum[u]-2*sum[d[i]]+2*mod)*2%mod;
dfs2(d[i],u);
}
}
inline int find(int ff){
if(fa[ff]!=ff) fa[ff]=find(fa[ff]);
return fa[ff];
}
inline void pt(int x,int y){
int v1=find(x),v2=find(y);
if(v1!=v2) fa[v1]=v2;
}
inline void dfs3(int u,int f){
z[u]=1;
for(rg int i=r[u];i;i=l[i]){
if(d[i]==f) continue;
dfs3(d[i],u);
pt(d[i],u);
}
for(rg int i=r1[u];i;i=l1[i]){
if(z[d1[i]]) lca[id1[i]]=find(d1[i]);
}
}
inline void dfs4(int u,int f){
for(rg int i=r2[u];i;i=l2[i]){
int v=d2[i];
ans[id2[i]]=sum[u]
*(vs[v]+sum[v]-2*sum[to[v]]+2*mod)%mod;
}
for(rg int i=r[u];i;i=l[i]){
if(d[i]==f) continue;
to[u]=d[i];
dfs4(d[i],u);
}
}
inline void wt(int wk){
if(wk>=10) wt(wk/10);
putchar(wk%10+'0');
}
int main()
{
// freopen("d.in","r",stdin);
// freopen("d.out","w",stdout);
n=in();
mi[0]=1;
for(rg int i=1;i<=n;i++) mi[i]=mi[i-1]*2%mod;
for(rg int i=1,x,y;i<n;i++){
x=in();y=in();
add(x,y);add(y,x);
}
dfs1(1,0);
dfs2(1,0);
q=in();
for(rg int i=1;i<=q;i++){
v[i]=(que){in(),in()};
if(dep[v[i].x]>dep[v[i].y]) swap(v[i].x,v[i].y);
add1(v[i].x,v[i].y,i);add1(v[i].y,v[i].x,i);
}
for(rg int i=1;i<=n;i++) fa[i]=i;
dfs3(1,0);
for(rg int i=1;i<=q;i++){
if(lca[i]==v[i].x) add2(v[i].y,v[i].x,i);
else ans[i]=sum[v[i].x]*sum[v[i].y]%mod;
}
dfs4(1,0);
for(rg int i=1;i<=q;i++){
int dis=dep[v[i].x]+dep[v[i].y]-2*dep[lca[i]];
ans[i]=ans[i]*mi[dis]%mod;
wt(ans[i]);putchar('\n');
}
return 0;
}