题目
题目链接:https://www.luogu.com.cn/problem/P7582
小 Soup 在这段时间中记录了 (n) 个有意义的东西,他把它们用字符串表示了出来,第 (i) 个东西被表示成 (s_i),并定义了它的价值 (a_i)。下面,小 Soup 会进行 (m) 次操作。
操作 (1):小 Soup 将区间 (l,r) 里的 (a_i) 都加上一个常数 (k)。
操作 (2):小 Soup 将区间 (l,r) 里的 (a_i) 都赋值成一个常数 (k)。
操作 (3):小 Soup 给出了一段回忆,这段回忆形成了一个字符串 (S),他想求 (S) 在区间 (l,r) 中的意义有多大。定义 (cnt_i) 为 (s_i) 在 (S) 中的出现次数,则 (S) 在区间 (l,r) 中的意义为 (sumlimits_{i=l}^r cnt_i imes a_i)。
(n,mleq 3 imes 10^4),(sum|S|,sum |s|leq 2 imes 10^5)。
思路
如果修改是单点修改,询问是全局查询,那么相当于把所有串的 AC 自动机建出来,每次询问就在 AC 自动机上跑,走到点 (p) 的时候贡献就是 fail 树根到 (p) 的路径上点的权值和。可以转化为子树加单点查询。求出 fail 树的 dfs 序后用树状数组就可以维护。
回到本题。分块,设每块有 (B) 个字符串。
对于每一个块维护 (tag) 和 (val),分别表示是否区间推平,以及区间加的的标记。把每一个块的字符串分别扔进一个 AC 自动机中。
对于一次修改操作,两边零散的部分就暴力处理,更新树状数组。注意在暴力处理一个区间时,需要先把这个区间的标记下传。而中间的整块直接修改标记。如果是操作 (1) 就让 (val) 加上 (k);如果是操作 (2) 就让 (val) 赋值为 (k) 并且打上推平标记。
对于询问操作,两边零散的块中,长度超过 (|S|) 的字符串肯定没有贡献,长度不超过 (|S|) 的字符串就暴力 KMP 求出现次数。这样对于一次询问,这部分复杂度就是 (O(B imes|S|)) 的。中间的整块就把 (S) 扔到 AC 自动机上跑匹配,记匹配次数为 (cnt),匹配中树状数组求出来的贡献为 (sum)。那么 (cnt imes val) 是肯定会被算进答案中的,如果没有区间推平标记,那么还需要加上 (sum)。
时间复杂度 (O(msqrt{n}log (sum |s|)))。
代码
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const int N=201000;
int n,Q,B,bel[N],L[N],R[N],a[N],rt[N],pos[N],id[N],siz[N],beg[N],nxt[N];
char s[N],t[N],tt[N];
bool tag[N];
ll val[N];
struct edge
{
int next,to;
};
struct BIT
{
ll c[N];
void add(int x,ll v)
{
for (int i=x;i<N;i+=i&-i)
c[i]+=v;
}
ll query(int x)
{
ll ans=0;
for (int i=x;i;i-=i&-i)
ans+=c[i];
return ans;
}
}bit;
struct ACA
{
int tot,tot1,tot2,ch[N][3],fail[N],head[N],ed[N];
edge e[N];
ACA() { memset(head,-1,sizeof(head)); }
void add(int from,int to)
{
e[++tot1]=(edge){head[from],to};
head[from]=tot1;
}
int insert(int p,char *s)
{
int len=strlen(s+1);
for (int i=1;i<=len;i++)
{
if (!ch[p][s[i]-'a']) ch[p][s[i]-'a']=++tot;
p=ch[p][s[i]-'a'];
}
ed[p]++;
return p;
}
void build(int p)
{
queue<int> q;
fail[p]=p;
for (int i=0;i<3;i++)
if (ch[p][i])
q.push(ch[p][i]),fail[ch[p][i]]=p;
else
ch[p][i]=p;
while (q.size())
{
int u=q.front(); q.pop();
add(fail[u],u);
ed[u]+=ed[fail[u]];
for (int i=0;i<3;i++)
if (ch[u][i])
q.push(ch[u][i]),fail[ch[u][i]]=ch[fail[u]][i];
else
ch[u][i]=ch[fail[u]][i];
}
}
void dfs(int x)
{
id[x]=++tot2; siz[x]=1;
for (int i=head[x];~i;i=e[i].next)
{
int v=e[i].to;
dfs(v); siz[x]+=siz[v];
}
}
void query(int p,int len,ll &cnt,ll &sum)
{
for (int i=1;i<=len;i++)
{
p=ch[p][s[i]-'a'];
cnt+=ed[p]; sum+=bit.query(id[p]);
}
}
}AC;
void pushdown(int x)
{
if (tag[x])
{
for (int i=L[x];i<=R[x];i++)
{
bit.add(id[pos[i]],-a[i]);
bit.add(id[pos[i]]+siz[pos[i]],a[i]);
a[i]=0;
}
tag[x]=0;
}
if (val[x])
{
for (int i=L[x];i<=R[x];i++)
{
bit.add(id[pos[i]],val[x]);
bit.add(id[pos[i]]+siz[pos[i]],-val[x]);
a[i]+=val[x];
}
val[x]=0;
}
}
void upd1(int l,int r,ll k)
{
pushdown(bel[l]);
for (int i=l;i<=r;i++)
{
bit.add(id[pos[i]],k);
bit.add(id[pos[i]]+siz[pos[i]],-k);
a[i]+=k;
}
}
void update1(int l,int r,ll k)
{
int ql=bel[l],qr=bel[r];
if (ql==qr) { upd1(l,r,k); return; }
upd1(l,R[ql],k); upd1(L[qr],r,k);
for (int i=ql+1;i<qr;i++) val[i]+=k;
}
void upd2(int l,int r,ll k)
{
pushdown(bel[l]);
for (int i=l;i<=r;i++)
{
bit.add(id[pos[i]],-a[i]+k);
bit.add(id[pos[i]]+siz[pos[i]],a[i]-k);
a[i]=k;
}
}
void update2(int l,int r,ll k)
{
int ql=bel[l],qr=bel[r];
if (ql==qr) { upd2(l,r,k); return; }
upd2(l,R[ql],k); upd2(L[qr],r,k);
for (int i=ql+1;i<qr;i++) val[i]=k,tag[i]=1;
}
int kmp(char *t,int n)
{
int m=strlen(t+1),cnt=0;
nxt[1]=0;
for (int i=2,j=0;i<=m;i++)
{
while (j && t[j+1]!=t[i]) j=nxt[j];
if (t[j+1]==t[i]) j++;
nxt[i]=j;
}
for (int i=1,j=0;i<=n;i++)
{
while (j && t[j+1]!=s[i]) j=nxt[j];
if (t[j+1]==s[i]) j++;
if (j==m) cnt++;
}
return cnt;
}
ll qry(int l,int r,int len)
{
pushdown(bel[l]);
ll ans=0;
for (int i=l;i<=r;i++)
if (beg[i+1]-beg[i]<=len)
{
for (int j=beg[i];j<beg[i+1];j++)
tt[j-beg[i]+1]=t[j];
tt[beg[i+1]-beg[i]+1]=' 00';
ans+=1LL*kmp(tt,len)*a[i];
}
return ans;
}
ll query(int l,int r)
{
int ql=bel[l],qr=bel[r],len=strlen(s+1);
if (ql==qr) return qry(l,r,len);
ll ans=qry(l,R[ql],len)+qry(L[qr],r,len);
for (int i=ql+1;i<qr;i++)
{
ll cnt=0,sum=0;
AC.query(rt[i],len,cnt,sum);
ans+=cnt*val[i];
if (!tag[i]) ans+=sum;
}
return ans;
}
int main()
{
scanf("%d%d",&n,&Q);
B=sqrt(n)+1; beg[1]=1;
for (int i=1;i<=B;i++)
{
L[i]=R[i-1]+1; R[i]=min(i*B,n);
rt[i]=++AC.tot;
for (int j=L[i];j<=R[i];j++)
{
scanf("%s%d",s+1,&a[j]);
pos[j]=AC.insert(rt[i],s); bel[j]=i;
int len=strlen(s+1); beg[j+1]=beg[j]+len;
for (int k=1;k<=len;k++) t[k+beg[j]-1]=s[k];
}
AC.build(rt[i]); AC.dfs(rt[i]);
}
for (int i=1;i<=n;i++)
{
bit.add(id[pos[i]],a[i]);
bit.add(id[pos[i]]+siz[pos[i]],-a[i]);
}
while (Q--)
{
int opt,l,r,k;
scanf("%d%d%d",&opt,&l,&r);
if (opt==1) scanf("%d",&k),update1(l,r,k);
if (opt==2) scanf("%d",&k),update2(l,r,k);
if (opt==3) scanf("%s",s+1),cout<<query(l,r)<<"
";
}
return 0;
}