Problem Digit Tree
题目大意
给一棵树,有边权1~9。
询问有多少个点对(i,j),将i--j路径上的数字依次连接后所形成新数字可以被k整除。gcd(K,10)=1
解题分析
点分治。考虑某一次分治,根为rt,求出所有子节点到根所形成数字为A,根到所有子节点所形成数字为B。
那么即求出所有满足 ( A[i] * 10 ^ B[j] . len + B[j] ) % K == 0的点对。
转化后为 A[i] == (K - B[j]) * 10 ^ ( - B[j] . len ) , 用map处理一下即可。
参考程序
1 #include <map> 2 #include <set> 3 #include <stack> 4 #include <queue> 5 #include <cmath> 6 #include <ctime> 7 #include <string> 8 #include <vector> 9 #include <cstdio> 10 #include <cstdlib> 11 #include <cstring> 12 #include <cassert> 13 #include <iostream> 14 #include <algorithm> 15 #pragma comment(linker,"/STACK:102400000,102400000") 16 using namespace std; 17 18 #define N 100008 19 #define LL long long 20 #define lson l,m,rt<<1 21 #define rson m+1,r,rt<<1|1 22 #define clr(x,v) memset(x,v,sizeof(x)); 23 #define bitcnt(x) __builtin_popcount(x) 24 #define rep(x,y,z) for (int x=y;x<=z;x++) 25 #define repd(x,y,z) for (int x=y;x>=z;x--) 26 const int mo = 1000000007; 27 const int inf = 0x3f3f3f3f; 28 const int INF = 2000000000; 29 /**************************************************************************/ 30 31 int n,m,phi; 32 int lt[N],size[N],f[N],vis[N],a[N],b[N],sum,tot,root,pre1[N],pre2[N]; 33 LL ans,Alltmp; 34 map <int,int> mp; 35 struct edge{int u,v,w,nt;}eg[N*2]; 36 void add(int u,int v,int w){ 37 eg[++sum]=(edge){u,v,w,lt[u]}; lt[u]=sum; 38 } 39 LL quick(LL x,LL y) 40 { 41 LL res=1; 42 while (y) 43 { 44 if (y & 1) res=res*x % m; 45 x=x*x % m; 46 y>>=1; 47 } 48 return res; 49 } 50 void init() 51 { 52 int mm=m; 53 phi=m; 54 for (int i=2;i*i<=m;i++) 55 { 56 if (mm % i==0) 57 { 58 while (mm % i==0) mm/=i; 59 phi=phi/i*(i-1); 60 } 61 } 62 if (mm>1) phi=phi/mm*(mm-1); 63 clr(lt,0); sum=1; 64 clr(f,0); f[0]=INF; 65 clr(vis,0); 66 ans=0; 67 pre1[0]=1; pre2[0]=1; 68 rep(i,1,100000) 69 { 70 pre1[i]=1ll*pre1[i-1]*10 % m; 71 pre2[i]=quick(pre1[i],phi-1); 72 } 73 } 74 void getRoot(int u,int fa) 75 { 76 size[u]=1; f[u]=0; 77 for (int i=lt[u];i;i=eg[i].nt) 78 { 79 int v=eg[i].v; 80 if (vis[v] || v==fa) continue; 81 getRoot(v,u); 82 size[u]+=size[v]; 83 f[u]=max(f[u],size[v]); 84 } 85 f[u]=max(f[u],tot-size[u]); 86 if (f[u]<f[root]) root=u; 87 } 88 void getA(int u,int fa,int len) 89 { 90 for (int i=lt[u];i;i=eg[i].nt) 91 { 92 int v=eg[i].v; 93 if (vis[v] || v==fa) continue; 94 a[v]=(1ll*eg[i].w*pre1[len]+a[u]) % m; 95 getA(v,u,len+1); 96 } 97 Alltmp+=mp[a[u]]; 98 int tp=(m-b[u]) % m; 99 tp=1ll*tp*pre2[len] % m; 100 if (a[u]==tp) Alltmp--; 101 } 102 void getB(int u,int fa,int len) 103 { 104 int tp=(m-b[u]) % m; 105 tp=1ll*tp*pre2[len]% m; 106 mp[tp]++; 107 for (int i=lt[u];i;i=eg[i].nt) 108 { 109 int v=eg[i].v; 110 if (vis[v] || v==fa) continue; 111 b[v]=(1ll*b[u]*10+eg[i].w) % m; 112 getB(v,u,len+1); 113 } 114 } 115 LL calc(int u,int key) 116 { 117 key%=m; 118 Alltmp=0; mp.clear(); 119 a[u]=b[u]=key; 120 getB(u,0,key==0?0:1); getA(u,0,key==0?0:1); 121 return Alltmp; 122 } 123 void solve(int u) 124 { 125 ans+=calc(u,0); vis[u]=1; 126 for (int i=lt[u];i;i=eg[i].nt) 127 { 128 int v=eg[i].v; 129 if (vis[v]) continue; 130 ans-=calc(v,eg[i].w); 131 root=0; tot=size[v]; 132 getRoot(v,u); 133 solve(root); 134 } 135 } 136 int main() 137 { 138 scanf("%d%d",&n,&m); 139 init(); 140 rep(i,1,n-1) 141 { 142 int u,v,w; 143 scanf("%d%d%d",&u,&v,&w); 144 add(u+1,v+1,w); 145 add(v+1,u+1,w); 146 } 147 root=0; tot=n; 148 getRoot(1,0); 149 solve(root); 150 cout<<ans<<endl; 151 }