题意
给你一个1e5的字符串,让你求出子串的数量,使得子串满足:
长度为(3n-2),且([1,2n-1])和([n,3n-2])都是回文串
思路
可以看出,这两个回文串长度都是奇数
我们先用马拉车处理出每个字符(s[i])可以左右扩展的最长长度(a[i])
我们找的就是找((i,j))的对数((i<j)),使得:
(i+a[i]geq j)且(j-a[j] leq i)
处理方式有以下几种
- 由第二个不等式可以将满足条件的(j)放入bit中,每次枚举(i)的时候查询([i,i+a[i]])的和
- 按照(a[i])从大到小排序,每次在BIT中将(i)加一,答案就是每次在(a[i])影响范围内的和
代码
这是第二种处理方式
#include<iostream>
#include<cstdio>
#include<algorithm>
#include<cmath>
#include<cstring>
#include<string>
#include<stack>
#include<queue>
#include<deque>
#include<set>
#include<vector>
#include<map>
#include<functional>
#define fst first
#define sc second
#define pb push_back
#define mem(a,b) memset(a,b,sizeof(a))
#define lson l,mid,root<<1
#define rson mid+1,r,root<<1|1
#define lc root<<1
#define rc root<<1|1
using namespace std;
typedef double db;
typedef long double ldb;
typedef long long ll;
typedef unsigned long long ull;
typedef pair<int,int> PI;
typedef pair<ll,ll> PLL;
typedef pair<ll,int> PIL;
const db eps = 1e-6;
const int mod = 1e9+7;
const int maxn = 2e6+100;
const int maxm = 2e6+100;
const int inf = 0x3f3f3f3f;
const db pi = acos(-1.0);
int t;
int n;
int a[maxn];
char s[maxn];
char ma[maxn*2];
int mp[maxn*2];
void manacher(char s[], int len){
int l = 0;
ma[l++]='$';ma[l++]='#';
for(int i = 0; i < len; i++){
ma[l++]=s[i];ma[l++]='#';
}
ma[l]=0;
int mx = 0,id=0;
for(int i = 0; i < l; i++){
mp[i]=mx>i?min(mp[2*id-i],mx-i):1;
while(ma[i+mp[i]]==ma[i-mp[i]])mp[i]++;
if(i+mp[i]>mx){
mx=i+mp[i];id=i;
}
}
}
int len;
ll tree[maxn];
int lowbit(int x){return x&-x;}
void add(int x ,int c){
for(int i = x; i <= len; i+=lowbit(i))tree[i]+=c;
}
ll sum(int x){
ll ans = 0;
for(int i = x; i; i-=lowbit(i))ans+=tree[i];
return ans;
}
struct node{
int l,r,id,len;
node(){}
node(int l, int r, int id, int len):l(l),r(r),id(id),len(len){}
};
bool cmp(node a, node b){return a.len>b.len;}
vector<node>v;
int main(){
scanf("%d", &t);
while(t--){
v.clear();
scanf("%s",s);
len = strlen(s);
for(int i = 0; i <= len; i++)tree[i]=0;
manacher(s,len);
for(int i = 0; i< 2*len+2; i++){
if(i&&i%2==0){
a[i/2]=(mp[i]-1)/2;
int len = a[i/2];
v.pb(node(i/2-len,i/2+len,i/2,len));
}
}
sort(v.begin(),v.end(),cmp);
ll ans = 0;
for(int i = 1; i <= len; i++){
ans+=1ll*(sum(v[i-1].r)-sum(v[i-1].l-1));
add(v[i-1].id,1);
}
printf("%lld
",ans);
}
return 0;
}
/*
*/