题目
给定m,n(m<=n<=5e3),
求大小为k的多重集合,满足元素和为n,
且每种数在集合中出现的次数都小于等于m的集合数有多少个
答案对998244353取模
思路来源
官方题解
「解题报告」[ABC221H] Count Multiset - K8He - 洛谷博客
Solution-ABC221H - yllcm 的博客 - 洛谷博客
【AtCoder思维训练】ABC221H Count Multiset - QAQ - 洛谷博客
题解1
整体来说,如果没有每个次数<=m的限制,就是分拆数
1. 把多重集合转成不下降序列(单增序列),
每个序列统计一次(A1,A2,...,Ak)(A1<=A2<=...<=Ak)
2. 把不下降序列转成差分数组,
令B[1]=A[1],B[i]=A[i]-A[i-1],
对于差分数组,需要满足以下三个条件:
①
②B数组中不存在连续的m个0
③
3. 发现①做dp的时候是有后效性的,
与k相关, 第k+1次的时候需要加上前k个的和
考虑对差分数组反转,即令i=k+1-i
反转后的差分数组B,需要满足以下三个条件:
①
②B数组中不存在连续的m个0
③
f[i][j]表示当前选了i个数,总和为j的方案数
①一种方式是,对反转的差分序列后面新增一个0,
如:
原序列2 2 3,差分序列2 0 1,反转差分序列1 0 2,
此时给反转差分序列后面加一个0,得到1 0 2 0,
对应差分序列0 2 0 1,原序列0 2 2 3,即原序列前面加一个0
即f[i][j]从f[i-1][j]转移而来
②另一种方式是,对反转的差分序列的最后一个数加1,
如:
原序列2 2 3,差分序列2 0 1,反转差分序列1 0 2,
此时给反转差分序列最后一个数加1,得到1 0 3,
对应差分序列3 0 1,原序列3 3 4,即原序列整体加1
即f[i][j]从f[i][j-i]转移而来
考虑怎么加上连续最多m个0的限制,
设g[i][j]表示当前填了i个数,总和为j,序列里不含0的方案数,
给整体加1过后的序列,即不包含0,有
f的转移,要么是对f数组整体加1,
要么是钦定0的个数,从一段没有0的g数组转移过来
由于序列里没有0,最后g[i][n]即为所求
当然,可以进一步化简,
因为最后是求g数组,可以上式代入下式联立消掉f,有
,
就与官方题解中的代码一致了,前缀和优化一下,复杂度
题解2
接题解1,反转后的差分数组B,需要满足以下三个条件:
①
②B数组中不存在连续的m个0
③
直接g[i][j]表示当前选了i个数,,最后一个数即b[i]>0的方案数
考虑暴力转移,
从1到m,枚举最后一段0的连续段长度,
也就是枚举上一个非0的位置x,再枚举b[i]选择的数为w,有:
对的第一维,也就是g[x]这一维维护前缀和,
即可实现转移,复杂度
题解3
考虑直接对原序列做dp,
f[i][j]表示前i个数和为j的方案数
如:原序列1 1 2,
①每次要么新增一个1,转移到1 1 1 2,f[i][j]从f[i-1][j-1]转移
②要么令所有数都+1,使得所有数都大于等于2,转移到2 2 3,f[i][j]从f[i][j-i]转移
但是,第一种转移新增了一个1,可能会导致恰出现连续m+1个1的情况,减掉这种情况即可
出现这种情况时,前m+1个数字为1,且第m+2个数为>=2的值,
只需全局减1,即可删掉m+1个1,并且使得第m+2个数的值>=1,也就对应了f[i-(m+1)][j-i]
有
复杂度
题解4
数形结合,
如果对原序列dp,如下图所示,有三条限制,
①
②不存在超过m个xi相同
③
按照箭头视角去看这个图,
也就是先顺时针旋转90度,再翻转,
新的序列仍然有三条限制,
①
②
③
发现限制2更强了,所以可以对新序列dp,
dp[i][j]表示最后一列高为i,柱状图面积总和为j的方案数,
枚举上一列高为x,需要满足x∈[i-m,i],有:
惊奇地发现,这和题解1得到的转移式子一模一样
复杂度
代码1、代码4 O(n^2)
#include<iostream>
using namespace std;
const int N=5e3+10,mod=998244353;
int n,m,dp[N][N],sum[N][N];
void add(int &x,int y){
x=(x+y)%mod;
}
int main(){
scanf("%d%d",&n,&m);
dp[0][0]=sum[0][0]=1;
for(int i=1;i<=n;++i){
for(int j=0;j<=n;++j){
if(j>=i){
dp[i][j]=sum[i][j-i];
if(i-m-1>=0){
add(dp[i][j],mod-sum[i-m-1][j-i]);
}
}
sum[i][j]=(sum[i-1][j]+dp[i][j])%mod;
}
}
for(int i=1;i<=n;++i){
printf("%d\n",dp[i][n]);
}
return 0;
}
代码2 O(n^2logn)
#include<iostream>
using namespace std;
const int N=5e3+10,mod=998244353;
int n,m,g[N][N],sum[N][N];
void add(int &x,int y){
x=(x+y)%mod;
}
int main(){
scanf("%d%d",&n,&m);
g[0][0]=sum[0][0]=1;
for(int i=1;i<=n;++i){
for(int j=0;j<=n;++j){
for(int w=1;w*i<=j;++w){
add(g[i][j],sum[i-1][j-w*i]);
if(i-m-1>=0)add(g[i][j],mod-sum[i-m-1][j-w*i]);
}
sum[i][j]=(sum[i-1][j]+g[i][j])%mod;
}
}
for(int i=1;i<=n;++i){
printf("%d\n",g[i][n]);
}
return 0;
}
代码3 O(n^2)
#include<iostream>
using namespace std;
const int N=5e3+10,mod=998244353;
int n,m,dp[N][N];
void add(int &x,int y){
x=(x+y)%mod;
}
int main(){
scanf("%d%d",&n,&m);
dp[0][0]=1;
for(int i=1;i<=n;++i){
for(int j=1;j<=n;++j){
dp[i][j]=dp[i-1][j-1];
if(j-i>=0)add(dp[i][j],dp[i][j-i]);
if(i>=m+1 && j-i>=0)add(dp[i][j],mod-dp[i-(m+1)][j-i]);
}
}
for(int i=1;i<=n;++i){
printf("%d\n",dp[i][n]);
}
return 0;
}
代码5 O(n^3)
自己乱搞了两个复杂度并不正确的做法,也贴在这里好了
这个是考虑容斥减掉不合法的答案
#include<iostream>
using namespace std;
const int N=5e3+10,mod=998244353;
typedef long long ll;
int n,m,dp[N][N],sum[N];//dp[i][j]选了i个和为j方案数
void add(int &x,int y){x=(x+y)%mod;}
int main(){
scanf("%d%d",&n,&m);
dp[0][0]=1;
for(int l=1;l<=n;++l){
for(int i=1;i<=n;++i){
for(int j=l;j<=n;++j){
add(dp[i][j],dp[i-1][j-l]);
/*
for(int k=1;k<=j;++k){
add(dp[i][j],dp[i-1][j-k]);
}
*/
}
}
for(int i=n;i>=m+1;--i){
for(int j=n;j-l*(m+1)>=0;--j){
add(dp[i][j],mod-dp[i-(m+1)][j-l*(m+1)]);
}
}
}
// for(int i=1;i<=n;++i){
// for(int j=1;j<=n;++j){
// printf("i:%d j:%d dp:%d\n",i,j,dp[i][j]);
// }
// }
for(int i=1;i<=n;++i){
printf("%d\n",dp[i][n]);
}
return 0;
}
代码6 O(n^3logn)
这个是纯纯暴力
#include<iostream>
using namespace std;
const int N=5e3+10,mod=998244353;
typedef long long ll;
int n,m,dp[N][N];//dp[i][j]选了i个和为j方案数
void add(int &x,int y){x=(x+y)%mod;}
int main(){
scanf("%d%d",&n,&m);
dp[0][0]=1;
for(int i=1;i<=n;++i){
for(int j=n;j>=i;--j){
for(int k=1;k<=m;++k){
if(j-k*i<0)break;
for(int l=n;l>=k;--l){
add(dp[l][j],dp[l-k][j-k*i]);
}
}
}
}
for(int i=1;i<=n;++i){
printf("%d\n",dp[i][n]);
}
return 0;
}