NTT (O(nlogn)) 30pts
显然,对每个字母(A/T/G/C)分别做一次 NTT 卷积,即可统计出每个起始位置的失配个数;但这种做法常数较大,会 TLE,只能拿到 30 分。
#include<bits/stdc++.h>
#define il inline
#define vd void
#define mod 998244353
typedef long long ll;
/// Reads the next decimal integer from stdin.
///
/// Every non-digit character in front of the number is skipped; if any of the
/// skipped characters is '-', the result is negated. Reading stops at the first
/// character after the digits (that character is consumed).
///
/// @return the parsed value, or 0 when the input ends before any digit is seen.
inline int gi(){
    int sign=1;
    // Keep the character in an int: getchar() returns EOF as a value outside the
    // unsigned char range, and isdigit() is only defined for such int values.
    int ch=getchar();
    while(ch!=EOF&&!isdigit(ch)){
        if(ch=='-')sign=-1;
        ch=getchar();
    }
    int value=0;
    // At EOF this loop is skipped, so the function returns 0 instead of spinning
    // forever waiting for a digit that will never arrive.
    while(ch!=EOF&&isdigit(ch)){
        value=value*10+(ch-'0');
        ch=getchar();
    }
    return value*sign;
}
/// Modular exponentiation by repeated squaring.
///
/// @param x base
/// @param y exponent (expected to be non-negative)
/// @return x raised to the power y, reduced modulo `mod`
inline int pow(int x,int y){
    long long base=x;
    long long result=1;
    // Consume the exponent one bit at a time, lowest bit first.
    for(;y!=0;y>>=1){
        if(y&1)result=result*base%mod;
        base=base*base%mod;
    }
    return static_cast<int>(result);
}
// Generator used by main to derive the NTT roots of unity, and its inverse
// modulo `mod` (3 * 332748118 = 998244354, which is 1 modulo 998244353).
#define G 3
#define iG 332748118
// S and T are the two input strings; main reads them at offset 1, so the
// characters are 1-based. yyb lists the four letters main iterates over.
char S[100003],T[100003],yyb[]="ATGC";
// ans[i]: value accumulated by main for start position i (summed over letters).
// rev: index permutation table filled by main and applied at the start of ntt.
// A, B: work arrays that main fills per letter and ntt transforms in place.
int ans[100003],rev[525191],A[525191],B[525191];
// P[o] / iP[o]: per-stage multipliers read by ntt for butterfly half-width o
// (P when t is non-zero, iP when t is zero); both are filled by main.
int P[525191],iP[525191];
/// In-place number-theoretic transform over the first n entries of A.
///
/// Relies on the tables prepared by the caller: `rev` (index permutation for
/// length n) and `P` / `iP` (stage multipliers indexed by butterfly half-width).
///
/// @param A data to transform, overwritten with the result
/// @param n transform length (the stage loop assumes it is a power of two)
/// @param t non-zero selects the `P` multipliers; zero selects `iP` and also
///          scales every entry by the modular inverse of n afterwards
inline void ntt(int*A,int n,int t){
    // Reorder the input according to the precomputed permutation; each pair is
    // swapped once, from its smaller index.
    for(int i=0;i<n;++i){
        const int j=rev[i];
        if(i<j)std::swap(A[i],A[j]);
    }
    for(int half=1;half<n;half<<=1){
        const int stageRoot=(t!=0)?P[half]:iP[half];
        const int span=half<<1;
        for(int base=0;base!=n;base+=span){
            int w=1;
            for(int k=0;k<half;++k){
                const int lo=A[base+k];
                const int hi=static_cast<int>(1ll*A[base+k+half]*w%mod);
                // Butterfly: (lo, hi) -> (lo + hi, lo - hi), both reduced mod `mod`.
                A[base+k+half]=(lo-hi+mod)%mod;
                A[base+k]=(lo+hi)%mod;
                w=static_cast<int>(1ll*w*stageRoot%mod);
            }
        }
    }
    if(t==0){
        // Final scaling by n^(mod-2), i.e. the inverse of n modulo `mod`.
        const int invN=pow(n,mod-2);
        for(int i=0;i<n;++i)A[i]=static_cast<int>(1ll*A[i]*invN%mod);
    }
}
int main(){
int TT=gi();
while(TT--){
scanf("%s",S+1),scanf("%s",T+1);
int n=strlen(S+1),m=strlen(T+1),N,lg;
N=1,lg=0;while(N<(n+m+2)<<1)N<<=1,++lg;
for(int i=0;i<N;++i)rev[i]=(rev[i>>1]>>1)|((i&1)<<lg-1);
for(int i=1;i<N;i<<=1)P[i]=pow(G,(mod-1)/(i<<1)),iP[i]=pow(iG,(mod-1)/(i<<1));
for(int i=1;i<=n;++i)ans[i]=0;
for(int c=0;c<4;++c){
for(int i=0;i<N;++i)A[i]=B[i]=0;
for(int i=1;i<=n;++i)A[i]=S[i]==yyb[c];
for(int i=1;i<=m;++i)B[i]=T[i]!=yyb[c];
std::reverse(B+1,B+m+1);
ntt(A,N,1),ntt(B,N,1);
for(int i=0;i<N;++i)A[i]=1ll*A[i]*B[i]%mod;
ntt(A,N,0);
for(int i=1;i<=n-m+1;++i)ans[i]+=A[i+m];
}
int ANS=0;for(int i=1;i<=n-m+1;++i)ANS+=ans[i]<=3;
printf("%d
",ANS);
}
return 0;
}
哈希 (O(nlogn)) 100pts
题目要求失配的字符数不超过 3,所以对每个起始位置最多做 4 次"求 LCP":每次用哈希加二分求出当前的 LCP,然后跳过 LCP 后面那一个失配的字符继续匹配;如果在跳过不超过 3 个字符的情况下能匹配到 T 的末尾,这个起始位置就合法。
#include<bits/stdc++.h>
#define il inline
#define vd void
#define mod1 998244853
#define mod2 1000000009
typedef long long ll;
/// Reads the next decimal integer from stdin.
///
/// Non-digit characters before the number are skipped; a '-' among them makes
/// the result negative. The character right after the digits is consumed.
///
/// @return the parsed value, or 0 if the input is exhausted before a digit.
inline int gi(){
    bool negative=false;
    // getchar() is stored as int so that EOF stays distinguishable and isdigit()
    // never receives a negative char value.
    int ch=getchar();
    for(;ch!=EOF&&!isdigit(ch);ch=getchar()){
        if(ch=='-')negative=true;
    }
    int value=0;
    // Skipped entirely at EOF, which makes the function return 0 rather than
    // loop forever looking for a digit.
    for(;ch!=EOF&&isdigit(ch);ch=getchar()){
        value=value*10+(ch-'0');
    }
    return negative?-value:value;
}
// Input strings; main reads them at offset 1, so characters are 1-based.
char S[100003],T[100003];
// Lengths of S and T for the current test case.
int n,m;
// base1[k] / base2[k]: k-th power of each hash base modulo mod1 / mod2
// (filled once by main for k = 0..100000).
ll base1[100010],base2[100010];
// Prefix hashes built by main: hashXS[i] = sum over k<=i of baseX[k]*S[k]
// modulo modX, and likewise hashXT for T. Index 0 stays 0.
ll hash1S[100010],hash1T[100010];
ll hash2S[100010],hash2T[100010];

// Every range hash is multiplied up to this common exponent so that equal
// substrings hash equally regardless of where they start. Range ends must not
// exceed it.
static const int kHashShift=100000;

/// Position-independent hash of positions l..r (1-based, inclusive) under one
/// modulus: the prefix difference is scaled by power[kHashShift-r].
static inline ll normalizedRangeHash(const ll*prefix,const ll*power,ll modulus,int l,int r){
    return (prefix[r]-prefix[l-1]+modulus)*power[kHashShift-r]%modulus;
}

/// Double hash (mod1, mod2) of S[l..r]; comparable with getHashT results.
inline std::pair<ll,ll>getHashS(int l,int r){
    return std::make_pair(normalizedRangeHash(hash1S,base1,mod1,l,r),
                          normalizedRangeHash(hash2S,base2,mod2,l,r));
}

/// Double hash (mod1, mod2) of T[l..r]; comparable with getHashS results.
inline std::pair<ll,ll>getHashT(int l,int r){
    return std::make_pair(normalizedRangeHash(hash1T,base1,mod1,l,r),
                          normalizedRangeHash(hash2T,base2,mod2,l,r));
}
/// Hash solution: for every test case reads S and T and prints how many
/// length-|T| windows of S differ from T in at most 3 positions.
///
/// For each window the code repeatedly binary-searches (with substring hashes)
/// how far the match extends, then steps over the single mismatching character.
/// Four extension attempts allow up to three mismatches per window.
int main(){
    int TT=gi();
    // Powers of both hash bases, shared by all test cases.
    base1[0]=1;for(int i=1;i<=100000;++i)base1[i]=base1[i-1]*19260817ll%mod1;
    base2[0]=1;for(int i=1;i<=100000;++i)base2[i]=base2[i-1]*23333ll%mod2;
    while(TT-->0){
        // Bounded reads: getHashS/getHashT index base[100000-r], so no string may
        // be longer than 100000 characters.
        if(scanf("%100000s",S+1)!=1||scanf("%100000s",T+1)!=1)break;
        n=static_cast<int>(strlen(S+1)),m=static_cast<int>(strlen(T+1));
        for(int i=1;i<=n;++i){
            hash1S[i]=(hash1S[i-1]+base1[i]*S[i])%mod1;
            hash2S[i]=(hash2S[i-1]+base2[i]*S[i])%mod2;
        }
        for(int i=1;i<=m;++i){
            hash1T[i]=(hash1T[i-1]+base1[i]*T[i])%mod1;
            hash2T[i]=(hash2T[i-1]+base2[i]*T[i])%mod2;
        }
        int ANS=0;
        for(int i=1;i<=n-m+1;++i){
            // lst: first position of T (1-based) not yet matched or skipped.
            int lst=1;
            for(int attempt=0;attempt<4;++attempt){
                // Find the largest l such that T[lst..l] equals the aligned part
                // of S. The search starts from l=lst without checking it.
                int l=lst,r=m;
                while(l<r){
                    const int mid=((l+r)>>1)+1;
                    if(getHashS(i+lst-1,i+mid-1)==getHashT(lst,mid))l=mid;
                    else r=mid-1;
                }
                // Direct check of the last candidate; it covers the case where
                // even the first character at lst mismatches. When lst==m+1 this
                // compares characters just past T, which are still within the
                // buffers, and l ends up >= m either way.
                if(S[i+l-1]!=T[l])--l;
                if(l>=m){++ANS;break;}
                // Position l+1 is a mismatch: spend one of the allowed errors on
                // it and resume right after.
                lst=l+2;
            }
        }
        printf("%d\n",ANS);
    }
    return 0;
}