【题解】Digit Tree
呵呵以为是数据结构题然后是淀粉质还行...
题目就是给你一颗有边权的树,问你有多少路径,把路径上的数字顺次写出来,是(m)的倍数。
很明显可以点分治嘛,我们可以按照图上的样子,把一条路径本来是(12345678)的路径,变成(1234|5678),我们记录图中左边的那种路径为(f)(往根),右边的那种路径为(g)(从根),记右边的那种到分治中心的深度为(d),那么这条路径就可以被表示成(f imes 10^d+g),条件就变成了
[f imes 10^d +gequiv 0
\
f imes 10^d equiv -g
\
f equiv -g imes 10^{-d}
]
我们把坐边压到一个(map)里面,每次分治时拿右边直接枚举就好了,然后还要用第二个(map)去掉同一颗子树内的非法情况,具体实现看代码。
由于处理这个(f,g)真的很难(博主搞了好久,自己都晕了),所以代码里的(f,g)可能是反的...
不觉得难的可以自己去试试,如果你真的没晕的话..收下我的膝盖orz
咱们把(map)看做一个(log),时间复杂度就是(O(n log^2n))的
#include<bits/stdc++.h>
using namespace std; typedef long long ll;
template < class ccf > inline ccf qr(ccf ret){ ret=0;
register char c=getchar();
while(not isdigit(c)) c=getchar();
while(isdigit(c)) ret=ret*10+c-48,c=getchar();
return ret;
}
const int maxn=1e5+5;
typedef pair < int , ll > P;
vector < P > e[maxn];
vector < int > ve;
#define pb push_back
#define st first
#define nd second
#define mk make_pair
inline void add(int fr,int to,int w){
e[fr].pb(mk(to,w));
e[to].pb(mk(fr,w));
}
int sum;
int siz[maxn];
int d0[maxn];//深度
int f[maxn];
int g[maxn];
int rt;
int spc[maxn];
int inv[maxn];
int ten[maxn];
bool usd[maxn];
int n,mod;
map < int , int > mp,un;
ll ans;
void dfsrt(const int&now){//重心
usd[now]=1;
siz[now]=spc[now]=1;
for(auto t:e[now])
if(not usd[t.first]){
dfsrt(t.st);
siz[now]+=siz[t.st];
if(siz[t.st]>spc[now])spc[now]=siz[t.st];
}
spc[now]=max(spc[now],sum-siz[now]);
if(spc[now]<spc[rt]|| not rt) rt=now;
usd[now]=0;
}
void dfsd(const int&now,const int& last,const int&w){//dis
usd[now]=1;
d0[now]=d0[last]+1;
g[now]=(g[last]+1ll*ten[d0[last]]*w%mod)%mod;
f[now]=(f[last]*10ll%mod+w)%mod;
//printf("now=%d d0=%d f=%d g=%d
",now-1,d0[now],f[now],g[now]);
ans+=(f[now]==0)+(g[now]==0);
++un[g[now]];
++mp[g[now]];
ve.pb(now);
for(auto t:e[now])
if(not usd[t.st])
dfsd(t.st,now,t.nd);
usd[now]=0;
}
inline void calc(const int&now){
d0[now]=f[now]=g[now]=0;
ve.clear();mp.clear();
int k=0;
for(auto t:e[now])
if(not usd[t.st]){
un.clear();
dfsd(t.st,now,t.nd);
register int edd=ve.size();
while(k<edd){
register int it=ve[k];
register int p=1ll*(((mod-f[it])%mod+mod)%mod)*inv[d0[it]]%mod;
if(un.find(p)!=un.end())
ans-=un[p];
++k;
}
}
for(auto t:ve){
register int p=1ll*(((mod-f[t])%mod+mod)%mod)*inv[d0[t]]%mod;
if(mp.find(p)!=mp.end())
/*cout<<"?qaq="<<t-1<<' '<<p<<endl;*/
ans+=mp[p];
}
}
void divd(const int&now){
usd[now]=1;calc(now);
for(auto t:e[now])
if(not usd[t.st]){
sum=siz[t.st];rt=0;
dfsrt(t.st);
divd(rt);
}
}
void exgcd(int a,int b,int&d,int&x,int&y){
if(!b) d=a,x=1,y=0;
else exgcd(b,a%b,d,y,x),y-=x*(a/b);
}
int Inv(const int&a, const int&p){
int d,x,y;
exgcd(a,p,d,x,y);
return d==1?(x+p)%p:-1;
}
int main(){
sum=n=qr(1);mod=qr(1);
if(mod==1)return cout<<1ll*n*(n-1)<<endl,0;
inv[0]=ten[0]=1;
ten[1]=10;
inv[1]=Inv(10,mod);
if(inv[1]==-1)return -1;
for(register int t=2;t<=n+1;++t)
ten[t]=1ll*ten[t-1]*ten[1]%mod,inv[t]=1ll*inv[t-1]*inv[1]%mod;
for(register int t=1,t1,t2,t3;t< n;++t){
t1=qr(1)+1;t2=qr(1)+1;t3=qr(1);
add(t1,t2,t3);
}
dfsrt(1);
divd(rt);
cout<<ans<<endl;
return 0;
}