洛谷题面
感觉是非常经典的一道题,最近好像总是见到,今天也算给它做了,发一篇题解来记录一下。
这道题是一道树形 DP 题,设 f [ u ] [ 0 / 1 ] f[u][0/1] f[u][0/1] 表示 u u u 点属于一个无黑点 / / / 有且仅有一个黑点的联通块时的方案数。我们考虑如何进行转移。
对于 f [ u ] [ 1 ] f[u][1] f[u][1],我们考虑如果 u u u 在一个有且仅有一个黑点的联通块时,对于它的儿子 v v v,有三种情况:
- 儿子是不合法的,那么我们需要把儿子和 u u u 放在一个联通块里,也就是儿子要贴着父亲,对应的贡献是 f [ u ] [ 1 ] × f [ v ] [ 0 ] f[u][1] \times f[v][0] f[u][1]×f[v][0]。
- 儿子是合法的,那么我们可以从中间断开,对应的贡献是 f [ u ] [ 1 ] × f [ v ] [ 1 ] f[u][1] \times f[v][1] f[u][1]×f[v][1]。
- 父亲是不合法的,那么需要把 u u u 和它的儿子放到一个联通块里,对应的贡献是 f [ u ] [ 0 ] × f [ v ] [ 1 ] f[u][0] \times f[v][1] f[u][0]×f[v][1]。
所以对于 f [ u ] [ 1 ] f[u][1] f[u][1] 的转移就是 f [ u ] [ 1 ] = f [ u ] [ 1 ] × ( f [ v ] [ 0 ] + f [ v ] [ 1 ] ) + f [ u ] [ 0 ] × f [ v ] [ 1 ] f[u][1] = f[u][1] \times (f[v][0] + f[v][1]) + f[u][0] \times f[v][1] f[u][1]=f[u][1]×(f[v][0]+f[v][1])+f[u][0]×f[v][1]。
对于 f [ u ] [ 0 ] f[u][0] f[u][0],有两种情况:
- 父亲不合法的同时,儿子也不合法,那么把父亲和儿子放在一起,对应的贡献是 f [ u ] [ 0 ] × f [ v ] [ 0 ] f[u][0] \times f[v][0] f[u][0]×f[v][0]。
- 父亲不合法,但儿子合法,那么就从中间断开,对应的贡献是 f [ u ] [ 0 ] × f [ v ] [ 1 ] f[u][0] \times f[v][1] f[u][0]×f[v][1]。
所以对于 f [ u ] [ 0 ] f[u][0] f[u][0] 的转移就是 f [ u ] [ 0 ] = f [ u ] [ 0 ] × ( f [ v ] [ 0 ] + f [ v ] [ 1 ] ) f[u][0] = f[u][0] \times (f[v][0] + f[v][1]) f[u][0]=f[u][0]×(f[v][0]+f[v][1])。
初始的状态是 f [ u ] [ c o l o r [ u ] ] = 1 f[u][color[u]] = 1 f[u][color[u]]=1,最后的答案如果以 1 1 1 为根的话就是 f [ 1 ] [ 1 ] f[1][1] f[1][1]。
#include <bits/stdc++.h>
#define ll long long
#define drep(a,b,c) for(int a(b) ; a>=(c) ; --a)
#define rep(a,b,c) for(int a(b) ; a<=(c) ; ++a)
using namespace std;
inline int read(){
int x=0,f=1;char ch=getchar();
while(ch<'0'||ch>'9'){if(ch == '-') f=-1 ; ch=getchar();}
while(ch>='0'&&ch<='9'){x=(x<<1)+(x<<3)+(ch^48);ch=getchar();}
return x*f;
}
inline void print(int x){
if(x < 0) putchar('-'),x = -x;
if(x >= 10) print(x / 10);
putchar(x % 10 + '0');
}
const int M = 2e5+10;
const int mod = 1e9+7;
int n,cnt;
int head[M],col[M];
ll f[M][2];
struct Edge{
int to,nxt;
}e[M<<1];
inline void add(int u,int v){
e[++cnt].to = v;
e[cnt].nxt = head[u];
head[u] = cnt;
}
void dfs(int u,int fa){
f[u][col[u]] = 1;
for(int i(head[u]) ; i ; i=e[i].nxt){
int v = e[i].to;
if(v == fa) continue;
dfs(v,u);
f[u][1] = f[u][1] * f[v][1] % mod + f[u][1] * f[v][0] % mod + f[u][0] * f[v][1] % mod;
f[u][0] = f[u][0] * f[v][0] % mod + f[u][0] * f[v][1] % mod;
f[u][1] %= mod;
f[u][0] %= mod;
}
}
signed main(){
n = read();
rep(i,2,n){
int u = read() + 1;
add(u,i),add(i,u);
}
rep(i,1,n) col[i] = read();
dfs(1,0);
printf("%lld\n",f[1][1]);
return 0;
}