今天更新一道不错的状态压缩DP题,顺带总结一下状态压缩DP。
摘要:
Part1 浅谈状态压缩DP的理解
Part2 浅谈对状态机DP的理解
Part3 关于状态压缩DP的1道例题
Part1 状态压缩DP
1、状态压缩DP:
事物的状态可能包含多个特征,但是事物的状态之间却可以互相转移,此时我们引入状态压缩DP,将事物的复杂的状态用一个数字来替代,此时事物的状态可以用数组的某个位置表示,从而可以进行状态的转移。
2、常见的状态表示:(1) 用10进制数字本身表示状态,比如表示当前状态%某个数字的余数等等,这里举一个例子。
(2) 用10进制内蕴含的二进制位表示状态, 01, 表示了每个位置上的两种状态,它既可以表示是否存在,也可以表示数量的奇偶性。(3) 用10进制内蕴含的K进制(除了10进制和2进制外的其它进制)表示状态,这种题我没见过,但是基于上面我们很容易可以推广。
3、什么时候我们可以用:首先是在你的状态表示基础上,整体的转移图是一个拓扑图,也就是可以通过递推得来,并且时间空间可以过得去,此时我们就可以用状态压缩DP。
Part2 状态机DP
我在之前写过一篇关于状态机DP的文章,里面有详细的理论和几道很好的例题:
http://t.csdnimg.cn/POtFs
Part3 例题: 小红的回文数
题目链接:E-小红的回文数_牛客周赛 Round 32 (nowcoder.com)
(1)题意:
小红定义一个整数是“好数”,当且仅当该整数通过重排之后可以形成回文数。(可以包含前导零)现在小红拿到了一个正整数x,小红想截取一段连续区间得到好数,她想知道有多少种不同的方案?
(2)题解:
暴力显然会超时,必定需要n * n 的复杂度,此时我们不妨考虑一下DP, 我们此时从左到右去考虑这个数,我们考虑每个以第i位数结尾的情况,最后答案就是累加后的值,我们此时考虑一下以第i位数结尾的区间,我们发现对于一个数字而言,每位的数字只能是0-9的数字之一,我们不妨用10个二进制位表示每一种数字的数量%2是多少,这样我们就可以通过统计1的个数判断有几个奇数,如果要构造一个回文串,显然只能由一个或者0个奇数的位。
状态转移方程:,表示以第i位数结尾的数字区间,且0-9各个的数字情况是state的方案数,,并且对于每个位单独一位的情况也要考虑,所以状态转移代码是:vector<vector<int>> dp(n + 1, vector<int>((1 << 10) + 2)); dp[0][0] = 1; for(int i = 1; i <= n; i ++ ) { for(int j = 0; j < (1 << 10); j ++ ) { int k = j ^ (1 << (s[i - 1] - '0')); dp[i][k] += dp[i - 1][j]; } if(i >= 2) dp[i][1 << (s[i - 1] - '0')] ++; }
你以为事情结束了吗,这是一道比较毒瘤的题,它会卡你的空间,在此基础上我们需要引入滚动数组优化,优化掉一维的空间。
此时的转移代码是:dp[0] = 1; long long res = 0; for(int i = 1; i <= n; i ++ ) { memset(usdp, 0, sizeof usdp); for(int j = 0; j < 1 << 10; j ++ ) usdp[j ^ (1 << (s[i - 1] - '0'))] += dp[j]; if(i >= 2) usdp[1 << (s[i - 1] - '0')] ++; for(int j = 0; j < 1 << 10; j ++ ) { cnt = 0; for(int c = 0; c <= 9; c ++ ) if(j >> c & 1) ++ cnt; if(cnt <= 1) res += usdp[j]; dp[j] = usdp[j]; } } cout << res << endl;
(3) 代码 (C ++):滚动数组优化:
#include <bits/stdc++.h> // #define int long long #define lowbit(x) (x&-x) using namespace std; const int N = 1e5 + 2; const int inf = 0x3f3f3f3f; int n, cnt; long long dp[2025], usdp[2025]; // int a[N]; void solve() { string s; cin >> s; n = s.size(); cnt = 0; // for(int i = 1; i <= n; i ++ ) a[i] = s[i - 1] - '0'; dp[0] = 1; long long res = 0; for(int i = 1; i <= n; i ++ ) { memset(usdp, 0, sizeof usdp); for(int j = 0; j < 1 << 10; j ++ ) usdp[j ^ (1 << (s[i - 1] - '0'))] += dp[j]; if(i >= 2) usdp[1 << (s[i - 1] - '0')] ++; for(int j = 0; j < 1 << 10; j ++ ) { cnt = 0; for(int c = 0; c <= 9; c ++ ) if(j >> c & 1) ++ cnt; if(cnt <= 1) res += usdp[j]; dp[j] = usdp[j]; } } cout << res << endl; } int main() { ios::sync_with_stdio(false); cin.tie(0); cout.tie(0); int ts = 1; // cin >> ts; while(ts -- ) solve(); return 0; }