bzoj3473
简单的想法就是把这些串的广义(mathrm{SAM})建出来,然后对每个节点求出它代表的串出现在了多少个原串中。假设这个已经求出,接下来我们对每个节点求出它及其祖先节点的贡献(因为它们对应了最长串的一连串后缀),在求每个串的答案时在(mathrm{SAM})匹配就好了。
那么怎么求每个节点的串出现在了多少个原串中呢?暴力的想法是在(mathrm{SAM})上匹配,然后对其祖先打上标记看起来复杂度是(O(|s|^2))的,但是由于(mathrm{SAM})的节点数是(O(sum |s|))的,在配合上一些不等式技巧可证得打标记的总时间复杂度是(O(Lsqrt L))的((L=sum |s|))
#include<iostream>
#include<string.h>
#include<string>
#include<stdio.h>
#include<algorithm>
#include<vector>
#include<bitset>
#include<math.h>
#include<stack>
#include<queue>
#include<set>
#include<map>
using namespace std;
typedef long long ll;
typedef long double db;
typedef pair<int,int> pii;
const int N=100000+100;
const db pi=acos(-1.0);
#define lowbit(x) (x)&(-x)
#define sqr(x) (x)*(x)
#define rep(i,a,b) for (register int i=a;i<=b;i++)
#define per(i,a,b) for (register int i=a;i>=b;i--)
#define go(u,i) for (register int i=head[u];i;i=sq[i].nxt)
#define fir first
#define sec second
#define mp make_pair
#define pb push_back
#define maxd 998244353
#define eps 1e-8
inline int read()
{
int x=0,f=1;char ch=getchar();
while ((ch<'0') || (ch>'9')) {if (ch=='-') f=-1;ch=getchar();}
while ((ch>='0') && (ch<='9')) {x=x*10+(ch-'0');ch=getchar();}
return x*f;
}
int n,m,lst,tot=1,siz[N<<1],ch[N<<1][26],len[N<<1],sum[N<<1],ord[N<<1],fa[N<<1],tax[N<<1],cnt[N<<1],col[N<<1];
char s[N];
vector<char> str[N];
void insert(int x)
{
if ((ch[lst][x]) && (len[ch[lst][x]]==len[lst]+1))
{
lst=ch[lst][x];
return;
}
int np=(++tot),p=lst,flag=0;len[np]=len[p]+1;
while ((p) && (!ch[p][x])) {ch[p][x]=np;p=fa[p];}
if (!p) fa[np]=1;
else
{
int q=ch[p][x];
if (len[q]==len[p]+1) fa[np]=q;
else
{
if (len[np]==len[p]+1) flag=1;
int nq=(++tot);len[nq]=len[p]+1;
memcpy(ch[nq],ch[q],sizeof(ch[nq]));
fa[nq]=fa[q];fa[np]=fa[q]=nq;
while ((p) && (ch[p][x]==q)) {ch[p][x]=nq;p=fa[p];}
if (flag) np=nq;
}
}
siz[np]=1;lst=np;
}
int main()
{
n=read();m=read();
rep(i,1,n)
{
scanf("%s",s+1);
int len=strlen(s+1);lst=1;
rep(j,1,len)
{
insert(s[j]-'a');
str[i].pb(s[j]);
}
}
rep(i,1,tot) tax[len[i]]++;
rep(i,1,tot) tax[i]+=tax[i-1];
per(i,tot,1) ord[tax[len[i]]--]=i;
rep(i,1,n)
{
int now=1,len=str[i].size();
rep(j,0,len-1)
{
int x=str[i][j]-'a';
now=ch[now][x];
int tmp=now;
while ((tmp) && (col[tmp]!=i))
{
col[tmp]=i;cnt[tmp]++;
tmp=fa[tmp];
}
}
}
cnt[1]=0;
rep(i,1,tot)
{
int u=ord[i],f=fa[u];
sum[u]=sum[f];
if (cnt[u]>=m) sum[u]+=len[u]-len[f];
}
rep(i,1,n)
{
ll ans=0;int now=1,len=str[i].size();
rep(j,0,len-1)
{
int x=str[i][j]-'a';
now=ch[now][x];
ans+=sum[now];
}
printf("%lld ",ans);
}
return 0;
}
hdu5343
最naive的想法就是将两个串的子串数目乘起来,这样显然会算重。
考虑对于每个合法串,我们让其与串(A)的匹配长度尽可能大,以保证不会重复计数。
这个过程可以通过在(mathrm{SAM})上进行dp来解决,在串(B)的(mathrm{SAM})上直接跑本质不同的子串个数的dp,在串(A)上还要加上在这个位置终止的方案数。两个dp都可以通过枚举下一个字母来解决。
#include<iostream>
#include<string.h>
#include<string>
#include<stdio.h>
#include<algorithm>
#include<vector>
#include<bitset>
#include<math.h>
#include<stack>
#include<queue>
#include<set>
#include<map>
using namespace std;
typedef long long ll;
typedef long double db;
typedef pair<int,int> pii;
typedef unsigned long long ull;
const int N=100000+100;
const db pi=acos(-1.0);
#define lowbit(x) (x)&(-x)
#define sqr(x) (x)*(x)
#define rep(i,a,b) for (register int i=a;i<=b;i++)
#define per(i,a,b) for (register int i=a;i>=b;i--)
#define go(u,i) for (register int i=head[u];i;i=sq[i].nxt)
#define fir first
#define sec second
#define mp make_pair
#define pb push_back
#define maxd 998244353
#define eps 1e-8
inline int read()
{
int x=0,f=1;char ch=getchar();
while ((ch<'0') || (ch>'9')) {if (ch=='-') f=-1;ch=getchar();}
while ((ch>='0') && (ch<='9')) {x=x*10+(ch-'0');ch=getchar();}
return x*f;
}
int n,m;
ull f[N<<1],g[N<<1];
char s[N],t[N];
bool visf[N<<1],visg[N<<1];
struct Suffix_Automaton{
int tot,lst,fa[N<<1],ch[N<<1][26],len[N<<1];
Suffix_Automaton() {tot=lst=1;}
void insert(int x)
{
int np=(++tot),p=lst;lst=np;len[np]=len[p]+1;
memset(ch[np],0,sizeof(ch[np]));
while ((p) && (!ch[p][x])) {ch[p][x]=np;p=fa[p];}
if (!p) {fa[np]=1;return;}
int q=ch[p][x];
if (len[q]==len[p]+1) {fa[np]=q;return;}
int nq=(++tot);len[nq]=len[p]+1;
memcpy(ch[nq],ch[q],sizeof(ch[nq]));
fa[nq]=fa[q];fa[np]=fa[q]=nq;
while ((p) && (ch[p][x]==q)) {ch[p][x]=nq;p=fa[p];}
}
void clr()
{
tot=lst=1;fa[1]=0;len[1]=0;
memset(ch[1],0,sizeof(ch[1]));
}
}sam1,sam2;
void dfs2(int u)
{
if (!u) return;
if (visg[u]) return;
g[u]=1;visg[u]=1;
rep(i,0,25)
{
int v=sam2.ch[u][i];
if (v) {dfs2(v);g[u]+=g[v];}
}
}
ll calc(int c) {return g[sam2.ch[1][c]];}
void dfs1(int u)
{
if (!u) return;
if (visf[u]) return;
f[u]=1;visf[u]=1;
rep(i,0,25)
{
int v=sam1.ch[u][i];
if (v) {dfs1(v);f[u]+=f[v];}
else f[u]+=calc(i);
}
}
int main()
{
int T=read();
while (T--)
{
scanf("%s",s+1);n=strlen(s+1);
scanf("%s",t+1);m=strlen(t+1);
rep(i,1,n) sam1.insert(s[i]-'a');
rep(i,1,m) sam2.insert(t[i]-'a');
rep(i,1,sam1.tot) visf[i]=0;
rep(i,1,sam2.tot) visg[i]=0;
dfs2(1);dfs1(1);
printf("%llu
",f[1]);
sam1.clr();sam2.clr();
}
return 0;
}
bzoj1396
出现次数为1的子串在(mathrm{SAM})上对应的节点显然是那些(mathrm{endpos})集合大小为(1)的点,对于每一个这样的点,记其(mathrm{endpos})中的元素为(r), 其对应的子串长度为([mn,mx]).
-
(forall pin[r-mn+1,r]),该节点的最短的串有可能成为(p)的答案。
-
(forall pin[r-mx+1,r-mn]), 串(s[p:r])有可能成为(p)的答案。
对上面两种情况分别开一棵线段树维护即可。
#include<iostream>
#include<string.h>
#include<string>
#include<stdio.h>
#include<algorithm>
#include<vector>
#include<bitset>
#include<math.h>
#include<stack>
#include<queue>
#include<set>
#include<map>
using namespace std;
typedef long long ll;
typedef long double db;
typedef pair<int,int> pii;
const int N=100000+100;
const db pi=acos(-1.0);
#define lowbit(x) (x)&(-x)
#define sqr(x) (x)*(x)
#define rep(i,a,b) for (register int i=a;i<=b;i++)
#define per(i,a,b) for (register int i=a;i>=b;i--)
#define go(u,i) for (register int i=head[u];i;i=sq[i].nxt)
#define fir first
#define sec second
#define mp make_pair
#define pb push_back
#define maxd 998244353
#define eps 1e-8
inline int read()
{
int x=0,f=1;char ch=getchar();
while ((ch<'0') || (ch>'9')) {if (ch=='-') f=-1;ch=getchar();}
while ((ch>='0') && (ch<='9')) {x=x*10+(ch-'0');ch=getchar();}
return x*f;
}
struct Segment_Tree{
int seg[N<<2],tag[N<<2];
void pushdown(int id)
{
if (tag[id]!=maxd)
{
seg[id<<1]=min(seg[id<<1],tag[id]);
seg[id<<1|1]=min(seg[id<<1|1],tag[id]);
tag[id<<1]=min(tag[id<<1],tag[id]);
tag[id<<1|1]=min(tag[id<<1|1],tag[id]);
tag[id]=maxd;
}
}
void build(int id,int l,int r)
{
seg[id]=tag[id]=maxd;
if (l==r) return;
int mid=(l+r)>>1;
build(id<<1,l,mid);build(id<<1|1,mid+1,r);
}
void modify(int id,int l,int r,int ql,int qr,int v)
{
if (ql>qr) return;
if ((l>=ql) && (r<=qr))
{
seg[id]=min(seg[id],v);
tag[id]=min(tag[id],v);
return;
}
pushdown(id);int mid=(l+r)>>1;
if (ql<=mid) modify(id<<1,l,mid,ql,qr,v);
if (qr>mid) modify(id<<1|1,mid+1,r,ql,qr,v);
seg[id]=min(seg[id<<1],seg[id<<1|1]);
}
int query(int id,int l,int r,int pos)
{
if (l==r) return seg[id];
pushdown(id);
int mid=(l+r)>>1;
if (pos<=mid) return query(id<<1,l,mid,pos);
else return query(id<<1|1,mid+1,r,pos);
}
}seg1,seg2;
int n,tot=1,lst=1,tax[N<<1],ord[N<<1],ch[N<<1][26],fa[N<<1],siz[N<<1],pos[N<<1],len[N<<1];
char s[N];
void insert(int x,int id)
{
int np=(++tot),p=lst;lst=np;len[np]=len[p]+1;
siz[np]=1;pos[np]=id;
while ((p) && (!ch[p][x])) {ch[p][x]=np;p=fa[p];}
if (!p) {fa[np]=1;return;}
int q=ch[p][x];
if (len[q]==len[p]+1) {fa[np]=q;return;}
int nq=(++tot);len[nq]=len[p]+1;
memcpy(ch[nq],ch[q],sizeof(ch[q]));
fa[nq]=fa[q];fa[np]=fa[q]=nq;
while ((p) && (ch[p][x]==q)) {ch[p][x]=nq;p=fa[p];}
}
int main()
{
scanf("%s",s+1);
n=strlen(s+1);
rep(i,1,n) insert(s[i]-'a',i);
seg1.build(1,1,n);seg2.build(1,1,n);
rep(i,1,tot) tax[len[i]]++;
rep(i,1,n) tax[i]+=tax[i-1];
rep(i,1,tot) ord[tax[len[i]]--]=i;
per(i,tot,1)
{
int u=ord[i];siz[fa[u]]+=siz[u];
if (siz[u]!=1) continue;
int mx=len[u],mn=len[fa[u]]+1,p=pos[u];
seg1.modify(1,1,n,p-mn+1,p,mn);
seg2.modify(1,1,n,p-mx+1,p-mn,p+1);
}
rep(i,1,n)
{
int ans1=seg1.query(1,1,n,i),ans2=seg2.query(1,1,n,i)-i;
printf("%d
",min(ans1,ans2));
}
return 0;
}