• 回忆树


    先上代码

    代码

    #include <bits/stdc++.h>
    /*
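    a. The u->z->v path text is not always exactly 2*len-2 characters long, so its actual length must be computed before running KMP.
    b. A function was missing its return value (the final return in check2).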
    */
    using namespace std;
    
    // Capacity limits for the static arrays below.
    const int N = 101000;   // max number of tree nodes (with slack)
    const int LOGN = 20;    // binary-lifting levels per node
    const int M = 301000;   // size of the pattern / path-text buffers s and d
    
    // Skips stdin until the next lowercase letter and returns it.
    // Stops at end of input and returns '\0' instead of spinning forever:
    // EOF converted to char never lies in 'a'..'z', so the loop must test
    // the int result of getchar() against EOF before narrowing it.
    char read(){
        int c = getchar();
        while (c != EOF && (c < 'a' || 'z' < c)) c = getchar();
        return c == EOF ? '\0' : static_cast<char>(c);
    }
    // Edge of the per-level child lists first[][]: edges are chained through
    // `next` (0 terminates a list) and point at node `end`.
    struct Edge{
        int next;
        int end;
    };
    Edge edge[LOGN*N];
    // Edge of the input-tree adjacency lists first2[]: chained through `next`
    // (0 terminates), pointing at node `end`, labelled with character `ch`.
    struct Edge2{
        int next;
        int end;
        char ch;
    };
    Edge2 edge2[N<<1];
    // Persistent segment-tree node: `cnt` elements inserted in its range,
    // son[0] / son[1] are the ids of its left / right children.
    struct Node{
        int cnt;
        int son[2];
    };
    Node nod[N*LOGN];
    
    
    // ch[v]: character on the edge from v to its parent (filled by dfs);
    // s: current query pattern; d: text of the path u -> z -> v used by KMP.
    char ch[N],s[M],d[M];
    // Binary lifting: fa[v][i] = 2^i-th ancestor of v (0 if none);
    // first[a][i] = head of the edge list of nodes whose 2^i-th ancestor is a.
    int efn,fa[N][LOGN],first[N][LOGN];
    // Adjacency lists of the input tree (each edge stored in both directions).
    int first2[N],efn2;
    // Tree suffix array: sa[i] = node ranked i, rank[v] = rank of node v.
    // NOTE(review): with `using namespace std`, unqualified `rank` and `next`
    // collide with std::rank / std::next (C++11+); qualify uses as ::rank / ::next.
    int sa[N],rank[N];
    // Heavy-light decomposition: depth, subtree size, heavy child, parent, chain head.
    int dep[N],siz[N],hson[N],fat[N],top[N];
    // KMP failure table for s.
    // NOTE(review): only N entries while s can hold up to M-1 characters;
    // building the table for a pattern longer than N overflows it.
    int next[N];
    // root[v]: persistent segment-tree version for the root-to-v path;
    // Len: length of the current pattern (read by check1/check2).
    int root[N],Len;
    // n nodes, m queries, o = highest binary-lifting level filled,
    // lg2[i] = floor(log2(i)), len = segment-tree node-id counter.
    int n,m,o,lg2[N],len;
    // Forward declarations of the helpers defined after main().
    // Input tree and heavy-light decomposition.
    void addedge(int x,int y,char c);
    void dfs(int x,int parent);
    void dfs3(int x,int parent);
    int lca(int x,int y);
    // Binary-lifting child lists and the tree suffix array.
    void addedge(int x,int y,int level);
    void init();
    // Persistent segment tree over suffix ranks.
    void build(int p,int l,int r);
    void insert(int p,int q,int l,int r,int pos);
    void dfs2(int x,int parent);
    int getans(int p,int q,int l,int r,int x,int y);
    // Binary searches over sa for the pattern's rank interval.
    bool check1(int p);
    bool check2(int p);
    int ef1(int l,int r);
    int ef2(int l,int r);
    int main(){
        scanf("%d%d",&n,&m);    
        lg2[0] = -1;
        for (int i = 1;i <= n;i++) lg2[i] = lg2[i/2]+1;
        o = lg2[n]+1;
        
        for (int i = 1;i < n;i++){
            int x,y;char ch;
            scanf("%d%d",&x,&y);
            ch = read();
            addedge(x,y,ch);
        }
        dfs(1,0);
        top[1] = 1;dfs3(1,0);
        
        for (int i = 1;i <= n;i++) if (fa[i][0]) addedge(fa[i][0],i,0);
        
        for (int i = 1;i <= o;i++){
            for (int j = 1;j <= n;j++){
                fa[j][i] = fa[fa[j][i-1]][i-1];
                if (fa[j][i]) addedge(fa[j][i],j,i);
            }
        }
        init();
        len = 1;root[0] = 1;build(1,1,n);
        dfs2(1,0);
        ch[1] = ' ';
        
        for (int i = 1;i <= m;i++){
            int x,y,z,u,v;
            int l,r,ans = 0,len,cntt = 0;
            
            scanf("%d%d",&x,&y);z = lca(x,y);
            scanf("%s",s);len = strlen(s);Len = len;
            u = x;v = y;
            for (int j = 16;j >= 0;j--){
                if (dep[fa[u][j]] >= dep[z]+len-1) u = fa[u][j];
                if (dep[fa[v][j]] >= dep[z]+len-1) v = fa[v][j];
            }    
            cntt = dep[u]-dep[z] + dep[v] - dep[z];//a 
            l = ef1(2,n+1);
            r = ef2(1,n);
            if (l <= r)ans += getans(root[u],root[x],1,n,min(n,l),max(2,r));
            reverse(s+0,s+len);
            l = ef1(2,n+1);
            r = ef2(1,n);
            if (l <= r)ans += getans(root[v],root[y],1,n,min(n,l),max(2,r));
            int t1 = 0,w1 = cntt-1;
            while (u != z){
                d[t1++] = ch[u];
                u = fat[u];
            }
            while (v != z){
                d[w1--] = ch[v];
                v = fat[v];
            }    
            reverse(s+0,s+len);
            w1 = cntt;
            next[0] = -1;t1 = -1;
            for (int j = 1;j < len;j++){
                t1++;
                while (t1 && s[t1] != s[j]) t1 = next[t1];
                if (s[t1] == s[j]) next[j] = t1;
                else next[j] = --t1;
            }
            int now = -1;
            for (int j = 0;j < w1;j++){
                while (now != -1 && s[now+1] != d[j]) now = next[now];
                if (s[now+1] == d[j]) now++;
                if (now == len-1) {ans++;now = next[now];}
            }
            printf("%d
    ",ans);
        }
        
        return 0;
    }
    void init(){
        static int x2[N],y2[N],a[N];
        int *x = x2,*y = y2,m = 256,cnt = -1;
        for (int i = 0;i <= m;i++) a[i] = 0;
        for (int i = 1;i <= n;i++) a[x[i] = ch[i]]++;
        for (int i = 1;i <= m;i++) a[i] += a[i-1];
        for (int i = 1;i <= n;i++) sa[a[ch[i]]--] = i;
        x[0] = -1;
        for (int k = 1;k <= n;k <<= 1){
            int p = 0;cnt++;
            for (int i = 0;i <= m;i++) a[i] = 0;
            for (int i = 1;i <= n;i++)
                if (fa[sa[i]][cnt] <= 1) y[++p] = sa[i];
            for (int i = 2;i <= n;i++)
                for (int h = first[sa[i]][cnt];h;h = edge[h].next){
                    int u = edge[h].end;
                    y[++p] = u;
                }
            for (int i = 1;i <= n;i++) a[x[y[i]]]++;
            for (int i = 1;i <= m;i++) a[i] += a[i-1];
            for (int i = n;i >= 1;i--) sa[a[x[y[i]]]--] = y[i];
            swap(x,y);
            p = 1;
            x[sa[1]] = p;
            for (int i = 2;i <= n;i++){
                int u = fa[sa[i]][cnt],v = fa[sa[i-1]][cnt];
                if (u == 1) u = 0;if (v == 1) v = 0;
                x[sa[i]] = (y[sa[i]] == y[sa[i-1]]) ? (u == 0 && v == 0 ? p : y[u] == y[v] ? p : ++p) : ++p;        
            }
            m = p;
            if (m >= n) break;
        }
        for (int i = 1;i <= n;i++) rank[sa[i]] = i;
    }
    void addedge(int x,int y,int z){
        edge[++efn].end = y;
        edge[  efn].next = first[x][z];
        first[x][z] = efn;
    }
    // First DFS from node x with parent y: records parent (fa[x][0] and fat[x]),
    // depth, subtree size, heavy child, and the character ch[] of each child edge.
    void dfs(int x,int y){
        fat[x] = fa[x][0] = y;
        dep[x] = dep[y] + 1;
        siz[x] = 1;
        for (int h = first2[x];h != 0;h = edge2[h].next){
            const int child = edge2[h].end;
            if (child == y) continue;
            ch[child] = edge2[h].ch;
            dfs(child,x);
            siz[x] += siz[child];
            if (siz[child] > siz[hson[x]]) hson[x] = child;   // first strict maximum wins
        }
    }
    // Second DFS of the heavy-light decomposition: the heavy child inherits
    // top[x] and is visited first; every other child heads its own chain.
    // The caller sets top[] of the starting node.
    void dfs3(int x,int y){
        const int heavy = hson[x];
        if (heavy != 0){
            top[heavy] = top[x];
            dfs3(heavy,x);
        }
        for (int h = first2[x];h;h = edge2[h].next){
            const int child = edge2[h].end;
            if (child == y || child == heavy) continue;
            top[child] = child;
            dfs3(child,x);
        }
    }
    void dfs2(int x,int y){
        root[x] = ++len;
        insert(root[y],root[x],1,n,rank[x]);
        for (int h = first[x][0];h;h = edge[h].next){
            int u = edge[h].end;
            dfs2(u,x);
        }
    }
    // Allocates the initial all-zero segment tree over [l, r] rooted at node p.
    // Child ids come from the counter len (left child numbered first), and both
    // are numbered before either subtree is built.
    void build(int p,int l,int r){
        if (l == r) return;
        const int mid = (l + r) >> 1;
        int &left = nod[p].son[0];
        int &right = nod[p].son[1];
        left = ++len;
        right = ++len;
        build(left,l,mid);
        build(right,mid+1,r);
    }
    // Persistent point insert: version q becomes version p plus one element at
    // position x. Only nodes on the path to x are newly allocated (ids from
    // len); the untouched child at each level is shared with version p.
    void insert(int p,int q,int l,int r,int x){
        for (;;){
            nod[q].cnt = nod[p].cnt + 1;
            if (l == r) return;
            const int mid = (l + r) >> 1;
            const int side = (x <= mid) ? 0 : 1;
            nod[q].son[side] = ++len;
            nod[q].son[side ^ 1] = nod[p].son[side ^ 1];
            p = nod[p].son[side];
            q = nod[q].son[side];
            if (side == 0) r = mid;
            else l = mid + 1;
        }
    }
    // Lowest common ancestor via heavy-light decomposition: lift the endpoint
    // whose chain head is deeper to that head's parent until both share a
    // chain; the shallower endpoint is then the LCA (y on ties).
    int lca(int x,int y){
        for (;top[x] != top[y];){
            if (dep[top[x]] >= dep[top[y]]) x = fat[top[x]];
            else y = fat[top[y]];
        }
        return dep[y] <= dep[x] ? y : x;
    }
    void addedge(int x,int y,char ch){
        edge2[++efn2].end = y;
        edge2[  efn2].ch = ch;
        edge2[  efn2].next = first2[x];
        first2[x] = efn2;
        edge2[++efn2].end = x;
        edge2[  efn2].ch = ch;
        edge2[  efn2].next = first2[y];
        first2[y] = efn2;    
    }
    // Binary search over suffix-array positions [l, r): returns the first
    // position whose node satisfies check1, or r if none does. Assumes check1
    // is monotone (false...true) along sa.
    int ef1(int l,int r){
        while (l < r){
            const int mid = (l + r) >> 1;
            if (check1(sa[mid])) r = mid;
            else l = mid + 1;
        }
        return l;
    }
    // Binary search over suffix-array positions [l, r]: returns the last
    // position whose node satisfies check2. Assumes check2 is monotone
    // (true...false) along sa; position l is returned if nothing further qualifies.
    int ef2(int l,int r){
        while (l < r){
            const int mid = (l + r + 1) >> 1;
            if (check2(sa[mid])) l = mid;
            else r = mid - 1;
        }
        return l;
    }
    // True when the upward string from node p, compared with s over its first
    // Len characters, is >= s: its first differing character is larger, or all
    // Len characters match. Reaching the root (node 1) first counts as smaller.
    bool check1(int p){
        int i = 0;
        while (i < Len){
            if (p == 1 || ch[p] < s[i]) return 0;
            if (ch[p] != s[i]) return 1;
            p = fat[p];
            ++i;
        }
        return 1;
    }
    // True when the upward string from node p, compared with s over its first
    // Len characters, is <= s: its first differing character is smaller, it
    // reaches the root (node 1) first, or all Len characters match.
    bool check2(int p){
        for (int i = 0;i < Len;++i,p = fat[p]){
            if (p == 1 || ch[p] < s[i]) return 1;
            if (ch[p] != s[i]) return 0;
        }
        return 1;   // full match (this return was once missing: note b in the header)
    }
    // Number of elements with position in [x, y] present in version q but not
    // in version p, for the segment-tree node covering [l, r].
    int getans(int p,int q,int l,int r,int x,int y){
        if (x > y) return 0;
        if (x == l && y == r) return nod[q].cnt - nod[p].cnt;
        const int mid = (l + r) >> 1;
        int total = 0;
        if (x <= mid)
            total += getans(nod[p].son[0],nod[q].son[0],l,mid,x,y < mid ? y : mid);
        if (y > mid)
            total += getans(nod[p].son[1],nod[q].son[1],mid+1,r,x > mid+1 ? x : mid+1,y);
        return total;
    }
  • 相关阅读:
    【STM32】串行通信原理
    【STM32】NVIC中断优先级管理
    【Linux-驱动】RTC设备驱动架构
    【Linux-驱动】在sysfs下创建对应的class节点---class_create
    【Linux-驱动】将cdev加入到系统中去---cdev_add
    【Linux-驱动】简单字符设备驱动结构和初始化
    【Linux-驱动】printk的打印级别
    【Linux 网络编程】REUSADDR
    【Linux 网络编程】常用TCP/IP网络编程函数
    linux定时重启tomcat脚本
  • 原文地址:https://www.cnblogs.com/victbr/p/6510725.html
Copyright © 2020-2023  润新知