题意
对于一个字符串(S),我们如下定义(f(S)):
每次以(p_{alpha})的概率将字母(alpha)(可能是26个小写字母中的一种)加入到初始为空的字符序列(T)的末尾,如果(T)中出现了(S),即(S)是(T)的子串,则停止,记(T)的期望长度为(f(S))。
现给定字符串(S)((1le |S|le 5 imes10^5)),求
[sum_{S的所有子串t}f(t)
]
题解
对于一个字符串(S(n=|S|)),令(a_i)为结束时随机序列长度为(i)的概率,其生成函数为(A(x))。令(b_i)为随机序列长度达到(i)且还未结束的概率,其生成函数为(B(x))。则有
[B(x)(prod_{i=1}^{n}p_{S_{i}}x)=sum_{i=1}^{n}[S[1...i]是S的border]A(x)(prod_{j=n-i}^{n}p_{S_j}x),
]
可得
[f(S)=A'(1)=B(1)=sum_{i=1}^{n}[S[1...i]是S的border](prod_{j=1}^{i}frac{1}{p_{S_j}}).
]
即只有(S)的(border)对(f(S))有贡献。于是我们考虑计算(SAM)上每个节点对最后答案的贡献。
对于(SAM)上的节点(u),设(pos[u]_i)表示(u)节点所代表的的最长串在(S)中第(i)次出现的结束位置,(cnt_u)为其在(S)中的出现次数,则对于(u)在(S)中任意两次不同出现(i,j(i<j)),(u)所代表的一系列串在串(S[pos[u]_i-len[u]+1,pos[u]_j])中形成了(len_u-len_{fa_u})个(border),所有这些形成的(border)即为节点(u)对最后答案的贡献,即
[inom{cnt_u}{2}sum_{i=1+len_{fa_u}}^{len_u}(prod_{j=pos_u-i}^{pos_u}frac{1}{p_{S_j}}),
]
记
[w_i=prod_{j=1}^{i}frac{1}{p_{S_j}},
]
则
[inom{cnt_u}{2}sum_{i=1+len_{fa_u}}^{len_u}(prod_{j=pos_u-i}^{pos_u}frac{1}{p_{S_j}})=inom{cnt_u}{2}sum_{i=1+len_{fa_u}}^{len_u}frac{w_{pos_u}}{w_{pos_u-i}}=inom{cnt_u}{2}w_{pos_u}sum_{i=1+len_{fa_u}}^{len_u}frac{1}{w_{pos_u-i}}.
]
记录(frac{1}{w_i})的前缀和即可计算。
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const int N=1000005,M=26;
const ll MOD=1e9+7;
char S[N];
int n,m,x,y;
ll w[N],sw[N],p[N];
ll pm(ll x,ll b){ll res=1;while(b){if(b&1)res=res*x%MOD;b>>=1;x=x*x%MOD;}return res;}
ll getlr(int l,int r){return (sw[r]-sw[l-1]+MOD)%MOD;}
struct sam{
int fa[N],sz[N],len[N],lst,gt,ch[N][M],pos[N];
void init(){gt=lst=1;}
void init2(){
for(int p=1;p<=gt;p++){
for(int i=0;i<M;i++)if(ch[p][i])ch[p][i]=0;
}
lst=gt=1;
}
void ins(int c,int id){
int f=lst,p=++gt;lst=p;
len[p]=len[f]+1;sz[p]=1;pos[p]=id;
while(f&&!ch[f][c])ch[f][c]=p,f=fa[f];
if(!f){fa[p]=1;return ;}
int x=ch[f][c],y=++gt;
if(len[x]==len[f]+1){gt--;fa[p]=x;return ;}
len[y]=len[f]+1;pos[y]=pos[x];fa[y]=fa[x];fa[x]=fa[p]=y;
for(int i=0;i<M;i++)ch[y][i]=ch[x][i];
while(f&&ch[f][c]==x)ch[f][c]=y,f=fa[f];
}
int A[N],c[N];
void rsort(){
for(int i=1;i<=gt;i++){c[i]=0;}
for(int i=1;i<=gt;i++)++c[len[i]];
for(int i=1;i<=gt;i++)c[i]+=c[i-1];
for(int i=gt;i>=1;i--){A[c[len[i]]--]=i;}
for(int i=gt;i>=1;i--){
int u=A[i];
sz[fa[u]]+=sz[u];
}
}
ll f2(){
rsort();
ll ans=0;
for(int u=2;u<=gt;u++){
ll na=1ll*sz[u]*(sz[u]+1)/2%MOD*w[pos[u]]%MOD;
int a=len[fa[u]]+1,b=len[u],r=pos[u]-a,l=pos[u]-b;
if(r>=0){ans=(ans+na*sw[r]%MOD)%MOD;}
if(l-1>=0){ans=(ans-na*sw[l-1]%MOD+MOD)%MOD;}
}
return ans;
}
}g,t;
void f1(){
scanf("%s",S+1);
n=strlen(S+1);
ll na=0;
for(int i=0;i<M;i++){scanf("%lld",&p[i]);na+=p[i];}
for(int i=0;i<M;i++){p[i]=na*pm(p[i],MOD-2)%MOD;}
w[0]=1;sw[0]=1;
for(int i=1;i<=n;i++){
w[i]=w[i-1]*p[S[i]-'a']%MOD;
sw[i]=(sw[i-1]+pm(w[i],MOD-2))%MOD;
}
g.init();
for(int i=1;i<=n;i++){g.ins(S[i]-'a',i);}
printf("%lld",g.f2());
}
int main(){
f1();
return 0;
}