反正先求一遍sa
然后这个问题可以稍微转化一下
默认比较A、B数组中元素的大小都是比较它们rank的大小,毕竟两个位置的LCP就是它们rank的rmq
然后每次只要求B[j]>=A[i]的LCP(B[j],A[i]),然后再求A[j]>B[i]的LCP(A[j],B[i])即可
这两个其实是差不多的,下面只说B[j]>=A[i]的怎么算
排序以后从后往前推着做(当然从前往后也行)
用一个权值线段树记下来LCP(A[i],B[j])==x的B[j]的数量、以及这个数量*x的和
然后考虑怎么把它从A[i]转移到A[i-1]
其实就是对于每个j给LCP(A[i],B[j])和LCP(A[i-1],A[i])取个min,放到权值线段树上,就是把大于LCP(A[i-1],A[i])的都删掉,然后在LCP(A[i-1],A[i])处加上刚才删掉的个数
所以我推着做的时候,先来取个min,然后把新来的B[j]加到线段树里,每做一个A[i]统计一下线段树整体的和即可
1 #include<bits/stdc++.h> 2 #define pa pair<int,int> 3 #define CLR(a,x) memset(a,x,sizeof(a)) 4 using namespace std; 5 typedef long long ll; 6 const int maxn=4e5+10; 7 8 inline ll rd(){ 9 ll x=0;char c=getchar();int neg=1; 10 while(c<'0'||c>'9'){if(c=='-') neg=-1;c=getchar();} 11 while(c>='0'&&c<='9') x=x*10+c-'0',c=getchar(); 12 return x*neg; 13 } 14 15 int N,M,Q; 16 char s[maxn]; 17 int sa[maxn],rnk[maxn],hei[maxn],rank1[maxn],tmp[maxn],cnt[maxn]; 18 int st[maxn][22]; 19 20 inline void getsa(){ 21 int i,j=0,k; 22 for(i=1;i<=N;i++) cnt[s[i]]=1; 23 for(i=1;i<=M;i++) cnt[i]+=cnt[i-1]; 24 for(i=N;i;i--) rnk[i]=cnt[s[i]]; 25 26 for(k=1;j!=N;k<<=1){ 27 memset(cnt,0,sizeof(cnt)); 28 for(i=1;i<=N;i++) cnt[rnk[i+k>N?0:i+k]]++; 29 for(i=1;i<=M;i++) cnt[i]+=cnt[i-1]; 30 for(i=N;i;i--) tmp[cnt[rnk[i+k>N?0:i+k]]--]=i; 31 memset(cnt,0,sizeof(cnt)); 32 for(i=1;i<=N;i++) cnt[rnk[i]]++; 33 for(i=1;i<=M;i++) cnt[i]+=cnt[i-1]; 34 for(i=N;i;i--) sa[cnt[rnk[tmp[i]]]--]=tmp[i]; 35 memcpy(rank1,rnk,sizeof(rank1)); 36 rnk[sa[1]]=j=1; 37 for(i=2;i<=N;i++){ 38 if(rank1[sa[i]]!=rank1[sa[i-1]]||rank1[sa[i]+k>N?0:sa[i]+k]!=rank1[sa[i-1]+k>N?0:sa[i-1]+k]) j++; 39 rnk[sa[i]]=j; 40 }M=j; 41 } 42 for(i=1;i<=N;i++) sa[rnk[i]]=i; 43 } 44 45 inline void geth(){ 46 for(int i=1,j=0;i<=N;i++){ 47 if(rnk[i]==1) continue; 48 if(j) j--; 49 int x=sa[rnk[i]-1]; 50 while(x+j<=N&&i+j<=N&&s[x+j]==s[i+j]) j++; 51 hei[rnk[i]]=j; 52 } 53 } 54 55 inline void getst(){ 56 for(int i=N;i;i--){ 57 st[i][0]=hei[i]; 58 for(int j=1;st[i+(1<<(j-1))][j-1];j++){ 59 st[i][j]=min(st[i][j-1],st[i+(1<<(j-1))][j-1]); 60 } 61 } 62 } 63 64 inline int rmq(int l,int r){ 65 if(l>r) return N-sa[r]+1; 66 int x=log2(r-l+1); 67 return min(st[l][x],st[r-(1<<x)+1][x]); 68 } 69 70 ll sum[maxn<<2],siz[maxn<<2]; 71 bool laz[maxn<<2]; 72 73 inline void update(int p){sum[p]=sum[p<<1]+sum[p<<1|1],siz[p]=siz[p<<1]+siz[p<<1|1];} 74 inline void pushdown(int p){ 75 if(!laz[p]) return; 76 int a=p<<1,b=p<<1|1; 77 sum[a]=siz[a]=0,laz[a]=1; 78 sum[b]=siz[b]=0,laz[b]=1; 79 laz[p]=0; 80 } 81 int erase(int p,int l,int r,int x,int y){ 82 if(x<=l&&r<=y){ 83 int re=siz[p]; 84 siz[p]=sum[p]=0;laz[p]=1; 85 pushdown(p); 86 return re; 87 }else{ 88 pushdown(p); 89 int m=l+r>>1,re=0; 90 if(x<=m) re=erase(p<<1,l,m,x,y); 91 if(y>=m+1) re+=erase(p<<1|1,m+1,r,x,y); 92 update(p); 93 return re; 94 } 95 } 96 void add(int p,int l,int r,int x,int y){ 97 if(l==r){ 98 siz[p]+=y,sum[p]+=1ll*y*l; 99 }else{ 100 pushdown(p); 101 int m=l+r>>1; 102 if(x<=m) add(p<<1,l,m,x,y); 103 else add(p<<1|1,m+1,r,x,y); 104 update(p); 105 } 106 } 107 108 inline bool cmp(int a,int b){return rnk[a]<rnk[b];} 109 110 int A[maxn],B[maxn]; 111 ll solve(int k,int l){ 112 sort(A+1,A+k+1,cmp);sort(B+1,B+l+1,cmp); 113 ll re=0; 114 for(int i=k,j=l;i;i--){ 115 if(i!=k){ 116 int x=rmq(rnk[A[i]]+1,rnk[A[i+1]]),t=erase(1,0,N,x+1,N); 117 add(1,0,N,x,t); 118 } 119 for(;rnk[B[j]]>=rnk[A[i]]&&j;j--) 120 add(1,0,N,rmq(rnk[A[i]]+1,rnk[B[j]]),1); 121 re+=sum[1]; 122 } 123 erase(1,0,N,0,N); 124 for(int i=k,j=l;j;j--){ 125 if(j!=l){ 126 int x=rmq(rnk[B[j]]+1,rnk[B[j+1]]),t=erase(1,0,N,x+1,N); 127 add(1,0,N,x,t); 128 } 129 for(;rnk[A[i]]>rnk[B[j]]&&i;i--) 130 add(1,0,N,rmq(rnk[B[j]]+1,rnk[A[i]]),1); 131 re+=sum[1]; 132 } 133 erase(1,0,N,0,N); 134 return re; 135 } 136 137 int main(){ 138 //freopen(".in","r",stdin); 139 int i,j,k; 140 N=rd(),Q=rd(); 141 scanf("%s",s+1); 142 M=128; 143 getsa();geth();getst(); 144 for(i=1;i<=Q;i++){ 145 int a=rd(),b=rd(); 146 for(j=1;j<=a;j++) 147 A[j]=rd(); 148 for(j=1;j<=b;j++) 149 B[j]=rd(); 150 printf("%I64d ",solve(a,b)); 151 } 152 return 0; 153 }