题目描述
你有一块 $R$ 行 $C$ 列的矩阵 $G$,矩阵里的每个格子有一个大写字母。
你有 $Q$ 个询问字符串,每个字符串都由大写字母构成。你想要知道这 $Q$ 个字符串每个在矩阵 $G$ 中出现的次数。
一个字符串 $S$ 在矩阵 $G$ 中出现,当且仅当存在一个四元组 $(r,c,dr,dc)$,满足:
- $1le rle R,\, rle r+drle R$.
- $1le cle C,\, cle c+dcle C$.
- $S=G_{r,c}G_{r,c+1}cdots G_{r,c+dc}G_{r+1,c+dc}cdots G_{r+dr,c+dc}$.
题解
考虑我们匹配到了一个串,如下图,我们可以把这个串看成红色的串的一个子串。
那这样的话考虑建出 $ ext{AC}$ 自动机的 $ ext{fail}$ 树,对于每个形如红色的串,我们可以遍历它的同时在 $ ext{AC}$ 自动机上走,每走到一个点就在 $ ext{fail}$ 树上的节点打上 $+1$ 的标记,那对于询问串的贡献就是询问串在 $ ext{fail}$ 树上的结束节点的子树总和,这个做个 $ ext{dfs}$ 序为下标的前缀和。
但是我们注意到如果黄色的串是横着或竖着的话(即没有拐点),那它会被重复计算,所以我们要把这一部分的贡献减去即可,具体操作和上面的类似。
由于红色的串只有 $nm$ 条,长度最多 $n+m$ ,所以效率为 $O(nm(n+m)+|s|)$ 。
代码
#include <bits/stdc++.h> #define LL long long using namespace std; const int N=2e5+5; int q,n,m,tt,tr[N][26],fi[N],id[N],sz[N],len[N],g[N]; LL a[N],b[N],s[N];char ch[505][505],Ch[N]; queue<int>Q;vector<int>e[N]; int ins(){ int v=0,l=strlen(Ch); for (int i=0,j;i<l;i++){ j=Ch[i]-'A'; if (!tr[v][j]) tr[v][j]=++tt; v=tr[v][j]; } return v; } void build(){ for (int i=0;i<26;i++) if (tr[0][i]) Q.push(tr[0][i]); while(!Q.empty()){ int u=Q.front();Q.pop(); for (int v,i=0;i<26;i++){ v=tr[u][i]; if (v) fi[v]=tr[fi[u]][i],Q.push(v); else tr[u][i]=tr[fi[u]][i]; } } for (int i=1;i<=tt;i++) e[fi[i]].push_back(i); } void dfs(int u){ id[u]=++tt;sz[u]=1; int z=e[u].size(); for (int v,i=0;i<z;i++) v=e[u][i],dfs(v),sz[u]+=sz[v]; } int main(){ cin>>n>>m>>q; for (int i=1;i<=n;i++) scanf("%s",ch[i]+1); for (int i=1;i<=q;i++) scanf("%s",Ch),len[i]=strlen(Ch),g[i]=ins(); build();dfs(tt=0); for (int i=1;i<=n;i++) for (int j=1;j<=m;j++){ int p=0,x=i,y=1; while(y<j) p=tr[p][ch[x][y]-'A'],y++,a[id[p]]++; while(x<=n) p=tr[p][ch[x][y]-'A'],x++,a[id[p]]++; } for (int i=1;i<=tt;i++) a[i]+=a[i-1]; for (int i=1,j;i<=q;i++) j=id[g[i]],s[i]=a[j+sz[g[i]]-1]-a[j-1]; for (int i=1;i<=tt;i++) a[i]=0; for (int j=1,p;j<=m;j++){ p=0; for (int i=1;i<=n;i++) p=tr[p][ch[i][j]-'A'],a[id[p]]+=i,b[id[p]]++; } for (int i=1;i<=tt;i++) a[i]+=a[i-1],b[i]+=b[i-1]; for (int i=1,j;i<=q;i++) j=id[g[i]],s[i]-=(a[j+sz[g[i]]-1]-a[j-1])-(b[j+sz[g[i]]-1]-b[j-1])*len[i]; for (int i=1;i<=tt;i++) a[i]=0; for (int i=1,p;i<=n;i++){ p=0; for (int j=1;j<=m;j++) p=tr[p][ch[i][j]-'A'],a[id[p]]+=m-j; } for (int i=1;i<=tt;i++) a[i]+=a[i-1]; for (int i=1,j;i<=q;i++) j=id[g[i]],s[i]-=(a[j+sz[g[i]]-1]-a[j-1]),printf("%lld ",s[i]); return 0; }