题目描述
给出n个结点的树结构T,其中每一个结点上有一个字符,这里我们所说的字符只考虑大写字母A到Z,再给出长度为m的模式串s,其中每一位仍然是A到z的大写字母。
Alice希望知道,有多少对结点<u,v>满足T上从u到V的最短路径形成的字符串可以由模式串S重复若干次得到?
这里结点对<u,v>是有序的,也就是说<u,v>和<v,u>需要被区分。
所谓模式串的重复,是将若干个模式串S依次相接(不能重叠)。例如当S=PLUS
的时候,重复两次会得到PLUSPLUS
,重复三次会得到PLUSPLUSPLUS
,同时要注恿,重复必须是整数次的。例如当S=XYXY
时,因为必须重复整数次,所以XYXYXY
不能看作是S重复若干次得到的。
输入输出格式
输入格式:
每一个数据有多组测试,
第一行输入一个整数C,表示总的测试个数。
对于每一组测试来说:
第一行输入两个整数,分别表示树T的结点个数n与模式长度m。结点被依次编号为1到n,
之后一行,依次给出了n个大写字母(以一个长度为n的字符串的形式给出),依次对应树上每一个结点上的字符(第i个字符对应了第i个结点)。
之后n-1行,每行有两个整数u和v表示树上的一条无向边,之后一行给定一个长度为m的由大写字母组成的字符串,为模式串S。
输出格式:
给出C行,对应C组测试。
每一行输出一个整数,表示有多少对节点<u,v>满足从u到v的路径形成的字符串恰好是模式串的若干次重复.
输入输出样例
说明
1<=C<=10,3<=N<=1000000,3<=M<=1000000
题解
题解大概看懂一点了……就是说用hash+点分治……好讨厌hash……总感觉还是半懂不懂……
考虑每一个分治点,从他延伸下去能形成长度为多少的前缀和后缀(不包含自己和包含自己),然后两个两两组合起来计算答案
据说时间复杂度$O(Tnlogn)$,数据就是为了卡点分的,然而因为全世界都只有前三组数据……所以能A……
1 //minamoto 2 #include<cstdio> 3 #include<iostream> 4 #include<cstring> 5 #define N 1000003 6 #define ull unsigned long long 7 #define ll long long 8 #define p 2000001001 9 #define inf 1000000000 10 using namespace std; 11 template<class T>inline bool cmax(T&a,const T&b){return a<b?a=b,1:0;} 12 inline int read(){ 13 #define num ch-'0' 14 char ch;bool flag=0;int res; 15 while(!isdigit(ch=getchar())) 16 (ch=='-')&&(flag=true); 17 for(res=num;isdigit(ch=getchar());res=res*10+num); 18 (flag)&&(res=-res); 19 #undef num 20 return res; 21 } 22 char sr[1<<21],z[30];int C=-1,Z; 23 inline void Ot(){fwrite(sr,1,C+1,stdout),C=-1;} 24 inline void print(ll x){ 25 if(C>1<<20)Ot();if(x<0)sr[++C]=45,x=-x; 26 while(z[++Z]=x%10+48,x/=10); 27 while(sr[++C]=z[Z],--Z);sr[++C]=' '; 28 } 29 int n,m,T,n1,rt,head[N],Next[N<<1],ver[N<<1],tot,size[N],son[N],sz,sz1; 30 int st[N],st1[N],len[N],cnt[N],cnt1[N]; 31 ull mi[N],a[N],a1[N],b[N],c[N],val[N],sum[N],sum1[N]; 32 ll ans; 33 bool vis[N];char s[N]; 34 inline void add(int u,int v){ 35 ver[++tot]=v,Next[tot]=head[u],head[u]=tot; 36 ver[++tot]=u,Next[tot]=head[v],head[v]=tot; 37 } 38 void findrt(int u,int fa){ 39 size[u]=1,son[u]=0; 40 for(int i=head[u];i;i=Next[i]){ 41 int v=ver[i]; 42 if(vis[v]||v==fa) continue; 43 findrt(v,u); 44 cmax(son[u],size[v]); 45 size[u]+=size[v]; 46 } 47 cmax(son[u],n1-size[u]); 48 if(son[u]<son[rt]) rt=u; 49 } 50 void getdep(int u,int fa){ 51 if(b[len[u]]==sum[u]&&val[u]==a[1]) st[++sz]=u; 52 for(int i=head[u];i;i=Next[i]){ 53 int v=ver[i];if(vis[v]||v==fa) continue; 54 sum[v]=sum[u]*p+val[v]; 55 len[v]=len[u]+1; 56 getdep(v,u); 57 } 58 } 59 void getdep1(int u,int fa){ 60 if(c[len[u]]==sum1[u]&&val[u]==a1[1]) st1[++sz1]=u; 61 for(int i=head[u];i;i=Next[i]){ 62 int v=ver[i]; 63 if(vis[v]||v==fa) continue; 64 sum1[v]=sum1[u]*p+val[v]; 65 getdep1(v,u); 66 } 67 } 68 void calc(int u){ 69 for(int i=0;i<=m;++i) cnt[i]=cnt1[i]=0; 70 if(a[1]==val[u]) cnt[1]=1; 71 if(a[m]==val[u]) cnt1[1]=1; 72 if(m==1) ans+=cnt1[1]; 73 for(int i=head[u];i;i=Next[i]){ 74 int v=ver[i]; 75 if(vis[v]) continue; 76 sz=0,len[v]=1,sum[v]=val[v]; 77 getdep(v,u); 78 for(int j=1;j<=sz;++j){ 79 int t=st[j];int pos=m-(len[t]-1)%m-1; 80 if(pos==0) pos+=m; 81 ans+=(ll)cnt1[pos]; 82 } 83 sz1=0,sum1[v]=val[v]; 84 getdep1(v,u); 85 for(int j=1;j<=sz1;++j){ 86 int t=st1[j];int pos=m-(len[t]-1)%m-1; 87 if(pos==0) pos+=m; 88 ans+=(ll)cnt[pos]; 89 } 90 for(int j=1;j<=sz;++j){ 91 int t=st[j];int pos=(len[t])%m+1; 92 if(val[u]==a[pos]) ++cnt[pos]; 93 } 94 for(int j=1;j<=sz1;++j){ 95 int t=st1[j];int pos=(len[t])%m+1; 96 if(val[u]==a1[pos]) ++cnt1[pos]; 97 } 98 } 99 } 100 void solve(int u){ 101 calc(u),vis[u]=1;int totsz=size[u]; 102 for(int i=head[u];i;i=Next[i]){ 103 int v=ver[i]; 104 if(vis[v]) continue; 105 rt=0; 106 n1=size[v]; 107 if(n1<m) continue; 108 findrt(v,u); 109 solve(rt); 110 } 111 } 112 int main(){ 113 T=read(),mi[0]=1; 114 for(int i=1;i<=1000000;++i) mi[i]=mi[i-1]*p; 115 while(T--){ 116 n=read(),m=read(),tot=0,ans=0; 117 memset(head,0,sizeof(head)); 118 scanf("%s",s+1); 119 for(int i=1;i<=n;++i) val[i]=s[i]-'A'+1; 120 for(int i=1;i<n;++i){ 121 int u=read(),v=read();add(u,v); 122 } 123 scanf("%s",s+1); 124 for(int i=1;i<=max(n,m);++i) a[i]=s[(i-1)%m+1]-'A'+1; 125 for(int i=1;i<=max(n,m);++i) b[i]=b[i-1]+a[i]*mi[i-1]; 126 for(int i=1;i<=m;++i) a1[m-i+1]=a[i]; 127 for(int i=1;i<=max(n,m);++i) a1[i]=a1[(i-1)%m+1]; 128 for(int i=1;i<=max(n,m);++i) c[i]=c[i-1]+a1[i]*mi[i-1]; 129 memset(vis,0,sizeof(vis)); 130 son[0]=inf,rt=0,n1=n; 131 findrt(1,0); 132 solve(rt); 133 print(ans); 134 } 135 Ot(); 136 return 0; 137 }