题目链接
令f【x】【0】表示不选根的x子树的最大贡献,f【x】【1】表示选根的x子树最大贡献,g【x】为max(f【x】【0】,f【x】【1】)。
如果我们要连接x和u1,那么贡献是:
w【x】+w【u1】+f【u1】【0】+(f【u2】【0】-g【u1】)+(f【u3】【0】-g【u2】)........+(f【uk】【0】-g【uk-1】)。
我们先考虑只有一个v点的时候,它有两个孩子v1和v2.如果我们要选v那么我们只能从不跨根的v1和v2里选,也就是f【v】【0】。
然后我们考虑有两个v的时候。除了上面的f【v1】【0】外,我们还要算g【v5】(v2也在被选的路径上,不能再选)。容易知道f【v2】【0】包括了两部分,一个是v1为根的子树,一个是v5为根的子树。我们要算v5的子树贡献就只需要让f【v2】【0】减去v1子树的贡献,也就是g【v1】,所以v5的贡献就是(f【v2】【0】-g【v1】)。
我们将路径上的点不断扩充,也就得到了一开始的那个式子。我们发现那个式子括号里的下标并不对齐,我们等价变形一下:
w【x】+w【u1】+(f【u1】【0】-g【u1】)+(f【u2】【0】-g【u2】)....+(f【uk】【0】-g【uk】)+g【uk】。因为差了一个g【uk】我们把它添加一个再减去仍然是相等的。
现在我们观察这个式子,对于后面那些括号的和其实就是路径上的每个点的f【x】【0】-g【x】的和。
然后我们现在再回到一开始的问题,我们要让根x去其子树去找一个颜色相同的点连线当做路线。w【x】是不变的,所以在挑选的时候我们可以不用管w【x】。我们令sum=f【x】【0】,表示sum是每个子树都不跨根的贡献。显然这样每个子树都是互相不影响的,他们的贡献是可以直接求和当做x的贡献的。
我们现在考虑选一个子树来连接根。因为我们选了这个子树,所以我们要先把这个子树的贡献先去掉,然后加上新贡献。令该子树为v,减贡献的操作就是sum-g【v】。我们再回到上面的式子w【x】+w【u1】+(f【u1】【0】-g【u1】)+(f【u2】【0】-g【u2】)....+(f【uk】【0】-g【uk】)+g【uk】。g【uk】其实就是这个g【v】,就是x的直接儿子。也就是说其实我们只要把式子的最后一项去掉,即w【x】+w【u1】+(f【u1】【0】-g【u1】)+(f【u2】【0】-g【u2】)....+(f【uk】【0】-g【uk】)就是这个路径新添加的贡献,也就是说新贡献就是w【x】+w【u1】+(f【u1】【0】-g【u1】)+(f【u2】【0】-g【u2】)....+(f【uk】【0】-g【uk】)+sum。
w【x】是固定的,sum是固定的,所以对于这颗子树,我们只需要让w【u1】+(f【u1】【0】-g【u1】)+(f【u2】【0】-g【u2】)....+(f【uk】【0】-g【uk】)这个式子最大,就是我们要从这颗树考虑的贡献了。
因为我们之前把括号里的下标处理成一样的了,所以我们可以考虑在u1回溯的时候,在每层dfs结束后再把每层的f【u】【0】-g【u】给u1的贡献加上。最后我们只需要从子树里和根x颜色相同的节点u里去找该贡献最大的点即可。
因为我们要将儿子节点的信息在回溯的时候更新,所以我们需要每次回溯到父节点的时候都去更新所有子节点的信息。
暴力合并和暴力修改肯定是不行的....所以对于合并用启发式合并,对于更新添加贡献的操作就用类似线段树的lazy标记来标记整个连通块...
代码:
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
#define int ll
const int N=2e5+10;
const int inf=0x3f3f3f3f;
typedef pair<int,int> pii;
typedef unsigned long long ull;
//const ll P=2281701377;
const ll P=998244353;
int n,c[N],w[N];
vector<int> e[N];
ll f[N][2],g[N],bl[N];//0,1表示取不取根节点,g表示f的0,1里的最大值
ll tag[N];//bl表示每个子树连通块map的祖先节点,tag表示连通块的标记,只标记在祖先节点,表示选了这个点,不选这条路径上的所有点子树会减少的最大贡献
map<int,ll> mp[N];//存连通块某种颜色的w【某种颜色的某个节点】+那个节点到当前节点的(f【路上节点,0】-g【路上节点】)之和的最大值
void upd(ll &x,ll y){
x=max(x,y);
}
void dfs(int x,int fa){
f[x][0]=f[x][1]=g[x]=0;
for(auto to:e[x]){
if(to==fa) continue;
dfs(to,x);
f[x][0]+=g[to];//不取根就可以取所有子树的情况之和
}//因为子树就算取子树根也不会取到当前节点,每个子树没有影响
for(auto v:e[x]){
if(v==fa) continue;
if(mp[bl[v]].count(c[x])){//根为端点
upd(f[x][1],(mp[bl[v]][c[x]]+tag[bl[v]])+f[x][0]+w[x]);//tag【bl【v】】就表示bl【v】这个连通块共同添加的贡献
}//因为现在遍历的点都是x的儿子节点,也就是说x回溯后新添加的贡献对他们都是一样,只要比当前的贡献就可以了,因为后面贡献都是一样的
if(mp[bl[v]].size()>mp[bl[x]].size()){
swap(bl[v],bl[x]);//启发式合并,siz小连到siz大,祖先节点表示一个连通块
}//如果子树节点个数大,那么bl【x】=bl【v】,bl【v】=bl【x】,这时候mp【bl【x】】实际上是子树的mp
//tag【bl【x】】也是子树的tag,直接交换祖先节点就可以交换所有信息了
for(auto [col,val]:mp[bl[v]]){//遍历子树
if(mp[bl[x]].count(col)){//看其他之前的子树有没有这个颜色(不包括当前子树)
upd(f[x][1],(val+tag[bl[v]])+(mp[bl[x]][col]+tag[bl[x]])+f[x][0]);
}//val+tag才是它自己现在的贡献 mp+tag也才是mp真实的贡献 每层都让f【x】【0】-g【x】实际上多减了最后一次的g【x】
//也就是说直接算val+tag【bl【v】】的时候减去了g【直接儿子】,mp【bl【x】】【col】+tag【bl【x】】也减去了那颗子树与x的直接儿子的g【直接儿子】
}
for(auto [c,val]:mp[bl[v]]){
if(mp[bl[x]].count(c)){//用当前子树信息与前面子树信息来更新,这里实际上就是把v的信息合并到x上了
upd(mp[bl[x]][c],val+tag[bl[v]]-tag[bl[x]]);//因为我们只需要保留贡献最大的点即可,所以比较保留max,且因为他们后面会添加的贡献都是相同的
}
else{
mp[bl[x]][c]=val+tag[bl[v]]-tag[bl[x]];//因为tag是对子树内所有点进行添加贡献的
}//但是如果是新添加进去的点,它在父亲子树同一加贡献的时候并不属于父亲子树,所以父亲子树添加的贡献并不能加在它身上
} //但我们并不能修改父亲子树的tag,所以我们只能修改新添加的子树属于自身的贡献,我们令它减去父亲子树的tag,
} //那么在最后算它贡献的时候再加上父亲子树的tag,等价于它本身的贡献是没有被父亲子树的tag影响的
//val是w【u】+它加入父亲子树前的tag才是它加入前的自身贡献
if(mp[bl[x]].count(c[x])){//用自己和所有子树信息来更新
upd(mp[bl[x]][c[x]],w[x]-tag[bl[x]]);
}//同理 mp【bl【x】】【c【x】】的实际值是 mp【bl【x】】【c【x】】+tag【bl【x】】,因为我们tag要对区间操作不能每个点修改
else mp[bl[x]][c[x]]=w[x]-tag[bl[x]];//所以我们将后来区间共同添加的贡献都放到了tag里,我们得把mp加上这些区间共同加的tag贡献
g[x]=max(f[x][0],f[x][1]);//和线段树的lz标记一样
tag[bl[x]]+=f[x][0]-g[x];//每次都会把所有子节点和当前节点都合并到一个集合也就是(bl【x】),
//将这个节点的祖先节点标记上f【x】【0】-g【x】等于将整个子树的节点都新添了这个贡献 ,类似线段树的lz标记
}
void solve(){
cin>>n;
for(int i=1;i<=n;i++){
cin>>c[i];
}
for(int i=1;i<=n;i++){
cin>>w[i];
}
for(int i=1;i<=n;i++){
bl[i]=i,tag[i]=0,mp[i].clear();
e[i].clear();
}
for(int i=1;i<n;i++){
int a,b;
cin>>a>>b;
e[a].push_back(b);
e[b].push_back(a);
}
dfs(1,-1);
cout<<g[1]<<endl;
}
signed main(){
ios::sync_with_stdio(0);cin.tie(0);cout.tie(0);
int t=1;
cin>>t;
while(t--){
solve();
}
}