最近接触了一点字符串算法,其实也就是一个简单的最大回文串算法,给定字符串s,求出最大字符串长度。
算法是这样的, 用'#'将s字符串中的每个字符分隔,比如s = “aba”,分割后变成#a#b#a#,然后利用下面的算法:
pre:
mx ←0
for i: = 1 to n-1
if(mx>i)
p[i] = min(p[2*id-i], mx-i)
else p[i] = 1
while(str[i+p[i]] == str[i-p[i]])
p[i]++
if(i+p[i]>mx)
mx = i+p[i]
id = i
注意在将s添加'#'之后为了防止越界访问,需要再整个字符串前面加上’$’这样i就是从1开始,p[i]表示在字符中以i为中心的回文串的右半长度,准确的说是r-i+1,r为回文串最右边的字符的下标,
mx表示i之前的位置j的回文串最大右端值,然后每次循环结束的时候更新mx并用id记录i值。
本题求的是字符串s的所有的前缀字符串的k值,k值是这样的定义的,也就是对一个回文串进行二分,分得的两部分仍然是回文串就将s字符串的k值增加1,然后继续分,直到不是回文串。
由于字符串k值取决于自身是不是回文串,所以要先进行判断,然后f[str] = f[substr] + 1,substr表示为str的前半部分的k值,由于substr应该在前面求出了,所以整个过程可以是一个dp过程。
代码:
#include <iostream> #include <sstream> #include <cstdio> #include <climits> #include <cstring> #include <cstdlib> #include <string> #include <stack> #include <map> #include <cmath> #include <vector> #include <queue> #include <algorithm> #define esp 1e-6 #define pi acos(-1.0) #define pb push_back #define lson l, m, rt<<1 #define rson m+1, r, rt<<1|1 #define mp(a, b) make_pair((a), (b)) #define in freopen("in.txt", "r", stdin); #define out freopen("out.txt", "w", stdout); #define print(a) printf("%d ",(a)); #define bug puts("********))))))"); #define stop system("pause"); #define Rep(i, c) for(__typeof(c.end()) i = c.begin(); i != c.end(); i++) #define inf 0x0f0f0f0f using namespace std; typedef long long LL; typedef vector<int> VI; typedef pair<int, int> pii; typedef vector<pii> VII; typedef vector<pii, int> VIII; typedef VI:: iterator IT; const int maxn = 5*1000000+100; char str[maxn<<1], s[maxn]; int p[maxn<<1]; int ans; int n; int f[maxn<<1]; void Init(void) { str[0] = '$', str[1] = '#'; for(int i = 0; i < n; i++) { str[i*2+2] = s[i]; str[i*2+3] = '#'; } int nn = 2*n+2; str[nn] = 0; int mx = 0, id; for(int i = 1; i < nn; i++) { if(mx > i) { p[i] = min(p[2*id-i], mx-i); } else p[i] = 1; while(str[i+p[i]] == str[i-p[i]]) p[i]++; if(i + p[i] > mx) mx = i+p[i], id = i; } } void solve(void) { LL ans = 0; for(int i = 1; i <= n; i++) { int l = 2, r = 2*i; int m = (l+r)>>1; if(p[m]*2-1 >= r-l+1) f[r] = f[m-1+((m%2) ? 0: -1)]+1; ans += f[r]; } printf("%I64d ", ans); } int main(void) { scanf("%s", s); n = strlen(s); Init(); solve(); return 0; }
更详细的介绍在这里:http://www.cnblogs.com/wuyiqi/archive/2012/06/25/2561063.html