• 【bzoj3473】字符串 【后缀自动机+树状数组】


    题意:给定n个字符串,询问每个字符串有多少子串(不包括空串)是所有n个字符串中至少k个字符串的子串?(本质相同重复计算)
    题解:首先我们把这n个字符串的广义后缀自动机建立出来,然后处理出每个状态出现在n个串的多少个之中。接着把每个串在后缀自动机跑一遍,统计即可。
    如何处理出每个状态出现在n个串的多少个之中?
    如果一个状态x出现在某个串y中,那么fail[x]一定也出现在y中,因为fail[x]是x的一个后缀。设val[i]代表状态i来自于哪个串。所以如果我们把fail链倒过来建一棵树,状态x出现在n个串之中的个数就是x的子树中的val的不同个数。我们就可以把这棵树dfs一次,处理出每个节点的dfs序区间,就把这个问题转化为了查询一个区间有多少个不同的数字,跟HH的项链那题一模一样。用树状数组处理一下就好了。
    如何统计?
    设ans[x]表示x出现在n个串之中的个数。只需要对每个串在后缀自动机上走,如果ans[now]小于k的话now就不停地跳fail。这就相当于把当前匹配到的不停地截短。然后答案累加上len[now]即可。至于为什么,请读者自行思考。而且匹配的过程中,不会失配,这也很显然。
    时间复杂度: n log n
    代码实现:

    #include<cstdio>
    #include<algorithm>
    #include<string>
    #include<cstring>
    #include<vector>
    using namespace std;
    const int N=100005;
    int n,k,l;
    string s[N];
    char str[N];
    bool cmp(int a,int b);
    struct SAM{
        int last,tot,len[N*2],fail[N*2],val[N*2],ch[N*2][26],c[N*2],a[N*2];
        int idx,in[N*2],out[N*2],pos[N*2],nxt[N*2],ck[N*2],ans[N*2];
        vector<int> e[N*2];
        SAM(){
            last=tot=1;
        }
        void insert(int x,int id){
            int p=last,np=++tot;
            len[np]=len[p]+1;
            last=np;
            val[np]=id;
            for(;p&&!ch[p][x];p=fail[p]){
                ch[p][x]=np;
            }
            if(!p){
                fail[np]=1;
            }else{
                int q=ch[p][x];
                if(len[q]==len[p]+1){
                    fail[np]=q;
                }else{
                    int nq=++tot;
                    len[nq]=len[p]+1;
                    memcpy(ch[nq],ch[q],sizeof(ch[q]));
                    fail[nq]=fail[q];
                    fail[q]=fail[np]=nq;
                    for(;p&&ch[p][x]==q;p=fail[p]){
                        ch[p][x]=nq;
                    }
                }
            }
        }
        void dfs(int u){
            in[u]=++idx;
            pos[idx]=u;
            for(int i=0;i<e[u].size();i++){
                dfs(e[u][i]);
            }
            out[u]=idx;
        }
        int lowbit(int x){
            return x&(-x);
        }
        void add(int i){
            while(i<=tot){
                c[i]++;
                i+=lowbit(i);
            }
        }
        int sum(int i){
            int res=0;
            while(i){
                res+=c[i];
                i-=lowbit(i);
            }
            return res;
        }
        void build(){
            for(int i=2;i<=tot;i++){
                e[fail[i]].push_back(i);
            }
            dfs(1);
            for(int i=1;i<=tot;i++){
                a[i]=i;
            }
            sort(a+1,a+tot+1,cmp);
            for(int i=tot;i>=1;i--){
                if(val[pos[i]]){
                    nxt[i]=ck[val[pos[i]]];
                    ck[val[pos[i]]]=i;
                }
            }
            for(int i=1;i<=tot;i++){
                if(ck[i]){
                    add(ck[i]);
                }
            }
            for(int i=1,j=1;i<=tot;i++){
                while(j<in[a[i]]){
                    if(nxt[j]){
                        add(nxt[j]);
                    }
                    j++;
                }
                ans[a[i]]=sum(out[a[i]])-sum(in[a[i]]-1);
            }
        }
        long long query(const char *s,int l){
            long long res=0;
            int now=1;
            for(int i=0;i<l;i++){
                now=ch[now][s[i]-'a'];
                while(now&&ans[now]<k){
                    now=fail[now];
                }
                if(!now){
                    now=1;
                    continue;
                }
                res+=len[now];
            }
            return res;
        }
    }sam;
    bool cmp(int a,int b){
        return sam.in[a]==sam.in[b]?sam.out[a]<sam.out[b]:sam.in[a]<sam.in[b];
    }
    int main(){
        scanf("%d%d",&n,&k);
        for(int i=1;i<=n;i++){
            scanf("%s",str);
            s[i]=str;
            l=strlen(str);
            sam.last=1;
            for(int j=0;j<l;j++){
                sam.insert(str[j]-'a',i);
            }
        }
        sam.build();
        for(int i=1;i<=n;i++){
            printf("%lld ",sam.query(s[i].c_str(),s[i].size()));
        }
        puts("");
        return 0;
    }
  • 相关阅读:
    怎么重新启动网卡
    @JsonProperty的使用
    JAVA中的反射机制
    spring的IOC入门案例
    spring的IOC底层原理
    maven+Spring环境搭建
    SpringMVC与Struts2区别与比较总结
    Struts2面试题
    oracle自增序列创建
    Hibernate分页查询报错
  • 原文地址:https://www.cnblogs.com/2016gdgzoi471/p/9476884.html
Copyright © 2020-2023  润新知