例题
- 求一个字符串的最长回文子串的长度
O ( N 2 ) O(N^2) O(N2)的解法很容易想,就是从每个字符位置向左右同时拓展,然后检查当前是不是回文,更新长度,可以简单写一下代码
int solve(string &ss){
int ans = 0;
int n = ss.length();
string s;
s.resize(2 * n + 1);
int l = 0;
s[l++] = '$';
s[l++] = '#';
for(int i=0;i<n;i++){
s[l++] = ss[i];
s[l++] = '#';
}
for(int i=0;i<l;i++){
int p = 0;
while(i - p >= 0 && i + p < l && s[i + p] == s[i - p]){
p += 1;
}
ans = max(ans, p - 1);
}
return ans;
}
- 这里注意一个小技巧,一个字符串有两种情况,分别是长度为奇数和长度为偶数,如何将这两种情况归一化呢?考虑在字符串前面加一个$,在每两个字符之间放一个#,不一定非得是$和#,只要不会在字符串中出现即可,容易计算得到这样做得到的字符串长度一定是奇数
- 考虑优化,容易想到如果计算已经得到了前面的回文半径(以某个字符为中心的回文串长度的一半),那么对称的两个点中尚未计算的那个点的回文半径的最小值也就是已经计算得到的那个点的回文半径
- 观察上面的字符串,第一个红色的b的回文半径容易求出是2,后面的c的回文半径容易求出是6,那么我们如何根据它求出后面的红色的b的回文半径呢?显然因为c的回文半径范围覆盖了第一个红色的b,所以第二个红色的b的回文半径的最小值是第一个红色的b的回文半径(这里同时因为c的回文半径也覆盖了第二个红色的b的回文半径区域,因为如果第二个红色的b之后没有足够的字符也是到不了第一个红色的b的回文半径的),这样我们就得到了第二个b的回文半径的最小值,然后暴力拓展,就为之后的字符串也做好了铺垫,可以证明总的时间复杂度是 O ( n ) O(n) O(n)的
- 上面是我对manacher算法的个人理解,接下来从代码的角度来说一下,首先需要两个变量
r
和c
,因为有一个问题,怎么判断某个字符是不是在某个回文区间之内,我们可以从左边递推找到一个右端点最远的回文区域,这样记录一下中心端点c
和右端点r
,在从左到右计算的过程中更新最大的r
,这样就找到了回文半径和回文中心 - 然后使用一个数组
P[i]
记录以i
为中心的回文半径,具体代码如下
int manacher(string &ss){
int n = ss.length();
string s;
s.resize(2 * n + 1);
int l = 0;
s[l++] = '$';s[l++] = '#';
for(int i=0;i<n;i++){
s[l++] = ss[i];
s[l++] = '#';
}
vector<int> P(l);
int r, c;
r = c = 0;
int ans = 0;
for(int i=0;i<l;i++){
int &p = P[i];
// 用r - i约束的原因上面已经说过
p = (i + p < r ? min(r - i, P[2 * c - i]) : 1);// 2 * c - i是对称位置的字符
while(s[i + p] == s[i - p]) p += 1;
if(i + p > r){
r = i + p;
c = i;
}
ans = max(ans, p - 1);
}
return ans;
}
例题
https://www.luogu.com.cn/problem/P4287
在一个字符串里面找前后两个长度相等的子串都是回文串的字符串的最大长度
- 考虑manacher,如果计算出了一个回文串,因为它的前面已经计算好了,比如现在回文半径右端点最远在 r r r,那么从 i i i到 r r r这一段是回文串的一半我们是知道的,现在考虑它前面一段,怎么考虑呢?根据对称性,设 j ≥ r j\geq r j≥r,对称到左边就是 i − j − i 2 i-\frac{j-i}{2} i−2j−i同时还需要满足 ( j − i ) % 4 = 0 (j-i)\%4=0 (j−i)%4=0,这里的 i i i是右侧字符串的回文中心, j j j是左侧字符串关于 c c c对称的回文中心
- 因为要求这样的回文串长度必须是偶数,所以根据我们的回文串构造方法,每次枚举的 i i i必须是奇数
#include <bits/stdc++.h>
using namespace std;
int main(){
ios::sync_with_stdio(false);
cin.tie(0);
cout.tie(0);
int n;
string ss;
cin >> n >> ss;
string s;
s.resize(2 * n + 1);
int l = 0;
s[l++] = '$';
s[l++] = '#';
for(int i=0;i<n;i++){
s[l++] = ss[i];
s[l++] = '#';
}
vector<int> P(l);
int r, c;
r = c = 0;
int ans = 0;
for(int i=0;i<l;i++){
int &p = P[i];
p = (i + p < r ? min(P[2 * c - i], r - i) : 1);
while(s[i + p] == s[i - p]) p += 1;
if(i + p > r){
if(i & 1){
for(int j=max(r, i + 4);j<=i+p;j++){
if((j - i) % 4 == 0 && P[i - (j - i) / 2] >= (j - i) / 2){
ans = max(ans, j - i);
}
}
}
r = i + p;
c = i;
}
}
cout << ans << '\n';
return 0;
}