有一面由 n × m n\times m n×m 个格子组成的墙,每个格子要么是黑色,要么是白色。你每次将会进行这样的操作:等概率随机选择一个位置 ( x , y ) (x,y) (x,y),和一个颜色 c c c(黑色或者白色)( 1 ≤ x ≤ n , 1 ≤ y ≤ m 1≤x≤n,1≤y≤m 1≤x≤n,1≤y≤m,任意 ( x , y , c ) (x,y,c) (x,y,c) 的组合选择它的概率均为 1 2 ∗ n ∗ m \frac1{2∗n∗m} 2∗n∗m1),然后将在 ( x , y ) (x,y) (x,y) 左上⻆的所有格子的颜色涂成 c c c。即将所有满足 1 ≤ x ′ ≤ x , 1 ≤ y ′ ≤ y 1≤x′≤x,1≤y′≤y 1≤x′≤x,1≤y′≤y 的 ( x ′ , y ′ ) (x′,y′) (x′,y′) 格子上的颜色涂成 c c c。这次操作的代价为涂的格子的数量,即 x × y x\times y x×y。给定初始状态和终止状态,问期望要花费多少代价才能将墙面从初始状态涂成终止状态。答案模 998244353 998244353 998244353。
n , m ≤ 5 n,m\le5 n,m≤5
先考虑朴素的 dp,对墙上的每个格子进行状态压缩,总共有
2
n
m
2^{nm}
2nm 个状态,设
f
i
f_i
fi 表示状态
i
i
i 期望还要花费多少代价才能到达最终状态。转移为
f
i
=
∑
f
j
+
w
2
n
m
f_i=\sum\dfrac{f_j+w}{2nm}
fi=∑2nmfj+w
其中 j j j 表示一个格子的左上角都涂成黑色或白色所得到的状态, w w w 是涂一次颜色的代价。
这是有后效性 dp,可以高斯消元解决,时间复杂度 O ( 2 3 n m ) O(2^{3nm}) O(23nm),可以得到 60pts。
下面考虑减少状态。
考虑一个状态,若一个格子的颜色与最终状态的颜色不同,则这个点后面一定会修改,其左上角的点同样,所以这些点的颜色是什么就不重要了,反正后面都要被改。记
p
i
,
j
p_{i,j}
pi,j 表示状态中坐标为
(
i
,
j
)
(i,j)
(i,j) 的点的颜色是否(1/0)与终止状态一样,通过模拟可以发现,
p
p
p 数组中构成
1
1
1 的元素是类似阶梯的形状。如下图,圆点表示在这个格子颜色与最终状态的不一样,矩形表示
范围内的格子要被修改,红色部分就是
p
i
,
j
p_{i,j}
pi,j 为
1
1
1 的部分。
称这个矩阵为阶梯矩阵,原来的状态称为 01 矩阵。
那么,所有的 01 矩阵都能唯一转化为一个阶梯矩阵,所以只需将原来的状态换成新的阶梯状态进行高斯消元即可。
而阶梯状态是不多的,只有 ( n + m n ) \binom{n+m}{n} (nn+m) 种,下面证明。
考虑每一行,看每行红色部分的格子数量,容易发现它们是有单调性的,设格子数量为 i i i 的行数为 x i x_i xi,即方案就是求 ∑ i = 0 n x i = m \sum\limits_{i=0}^nx_i=m i=0∑nxi=m 的解的个数,这是组合数学的经典问题,易证。
而 ( n + m n ) \binom{n+m}{n} (nn+m) 最大也就 ( 10 5 ) = 252 \binom{10}{5}=252 (510)=252,高斯消元绰绰有余。
实现上,可以用 dfs 求出阶梯矩阵的状态和个数,01 矩阵转换成阶梯矩阵直接暴力枚举(时间充裕)每个点,若颜色不相同就直接对左上角修改。高斯消元要写模意义下的,我的实现直接用快速幂求逆元。
总的时间复杂度为 O ( ( n + m n ) 3 + ( n + m n ) n 3 m 3 ) O\left(\binom{n+m}{n}^3+\binom{n+m}{n}n^3m^3\right) O((nn+m)3+(nn+m)n3m3)
具体实现参照代码
#include<bits/stdc++.h>
using namespace std;
#define ll long long
const ll mod=998244353;
const int N=260;
ll a[N][N],val[N];
int fr[N],nn,mm,n,m,cnt;
char s1[N][N],s2[N][N];
unordered_map<int,int> ma;
vector<int> v;
ll ksm(ll a,ll b)
{
ll ans=1;
while(b){
if(b&1) ans=ans*a%mod;
b>>=1;
a=a*a%mod;
}
return ans;
}
int gauss()
{
int r=1,c=1;
for(;r<=nn&&c<=mm;r++,c++){
int maxn=r;
for(int i=r+1;i<=nn;i++) if(abs(a[maxn][c])<abs(a[i][c])) maxn=i;
for(int j=1;j<=mm+1;j++){
swap(a[maxn][j],a[r][j]);
}
if(abs(a[r][c])==0){r--;continue;}
for(int i=1;i<=nn;i++){
if(i==r) continue;
ll g=(a[i][c]*ksm((a[r][c]%mod+mod)%mod,mod-2)%mod+mod)%mod;
for(int j=1;j<=mm+1;j++)
a[i][j]=(a[i][j]-a[r][j]*g%mod+2*mod)%mod;
}
}
for(int i=r;i<=nn;i++){
if(abs(a[i][i])==0&&abs(a[i][mm+1])>0){
return -1;
}
}
memset(fr,0x3f,sizeof(fr));
for(int i=1;i<r;i++){
int cnt=0,num=0;
for(int j=1;j<=mm;j++) if(fr[j]&&abs(a[i][j])>0) cnt++,num=j;
if(cnt==1) fr[num]=0,val[num]=(a[i][mm+1]*ksm((a[i][num]%mod+mod)%mod,mod-2)%mod+mod)%mod;
}
return 1;
}
int getid(int x,int y){return (x-1)*m+y;}
void dfs(int x,int k,int state)
{
if(!x){
ma[state]=++cnt;
v.push_back(state);
return;
}
for(int i=0;i<=k;i++) dfs(x-1,i,state|((1<<x*m)-1^(1<<x*m-i)-1));
}
int change(int id)
{
int newid=(1<<n*m)-1;
for(int i=1;i<=n;i++){
for(int j=1;j<=m;j++){
if((id>>getid(i,j)-1&1)!=(s2[i][j]=='B')){
for(int ii=1;ii<=i;ii++){
for(int jj=1;jj<=j;jj++){
newid&=INT_MAX^(1<<getid(ii,jj)-1);
}
}
}
}
}
return newid;
}
int main()
{
freopen("graffiti.in","r",stdin);
freopen("graffiti.out","w",stdout);
scanf("%d%d",&n,&m);
for(int i=1;i<=n;i++) scanf("%s",s1[i]+1);
for(int i=1;i<=n;i++) scanf("%s",s2[i]+1);
int finish=0,start=0;
for(int i=n;i>=1;i--){
for(int j=m;j>=1;j--){
finish=finish*2+(s2[i][j]=='B');
start=start*2+(s1[i][j]=='B');
}
}
dfs(n,m,0);
int N=1<<n*m;
nn=mm=v.size();
for(int i=0;i<v.size();i++){
int t=v[i];
if(i==ma[change(finish)]-1){
a[i+1][i+1]=1;
continue;
}
int tt=0;
for(int i=1;i<=n;i++){
for(int j=1;j<=m;j++){
if(t>>getid(i,j)-1&1) tt|=(s2[i][j]=='B')<<getid(i,j)-1;
else tt|=(s2[i][j]=='W')<<getid(i,j)-1;
}
}
a[ma[t]][ma[t]]+=2*n*m;
for(int i=1;i<=n;i++){
for(int j=1;j<=m;j++){
int id=tt;
for(int ii=1;ii<=i;ii++){
for(int jj=1;jj<=j;jj++){
id|=1<<(getid(ii,jj)-1);
}
}
int newid=change(id);
a[ma[t]][nn+1]+=i*j;
a[ma[t]][ma[newid]]--;
for(int ii=1;ii<=i;ii++){
for(int jj=1;jj<=j;jj++){
id^=1<<(getid(ii,jj)-1);
}
}
newid=change(id);
a[ma[t]][nn+1]+=i*j;
a[ma[t]][ma[newid]]--;
}
}
}
int czn=gauss();
printf("%lld\n",val[ma[change(start)]]);
}