链接:https://nanti.jisuanke.com/t/39277
思路:
一开始看着很像树分治,就用树分治写了下,发现因为异或操作的特殊性,我们是可以优化树分治中的容斥操作的,不合理的情况只有当两点在一条链上才存在,那么直接一遍dfs从根节点向下跑途中维护一下前缀和,把所有情况中不合理情况造成的值修正。
这样的话时间复杂度就可以降得非常低了,感觉还可以优化,但是懒得写了
代码耗时:142ms.
实现代码:
#include<bits/stdc++.h> using namespace std; #define ll long long const ll M = 2e5+10; const ll inf = 1e18+10; struct node{ ll to,next,w; }e[M]; const ll mod = 1000000007; struct node1{ ll num,id; }Xor[M]; bool cmp(node1 x,node1 y){ return x.num < y.num; } vector<ll>mp[M],v[M]; ll cnt,n,ans; ll head[M],sz[M],d[M],md[M]; void add(ll u,ll v,ll w){ e[++cnt].to = v;e[cnt].w = w;e[cnt].next = head[u];head[u] = cnt; } map<ll,ll>sum,sum1,num; void get_dis(ll u,ll fa){ Xor[++Xor[0].num].num = d[u]; Xor[Xor[0].num].id = u; for(ll i = head[u];i;i=e[i].next){ ll v = e[i].to; if(v != fa){ d[v] = d[u]^e[i].w; get_dis(v,u); } } return ; } void get_siz(ll u,ll fa){ sz[u] = 1; for(ll i = head[u];i;i=e[i].next){ ll v = e[i].to; if(v != fa){ get_siz(v,u); sz[u] += sz[v]; } } } void gcd(ll a,ll b,ll &d,ll &x,ll &y) { if(!b) {d=a;x=1;y=0;} else {gcd(b,a%b,d,y,x);y-=x*(a/b);} } ll finv(ll a,ll n) { ll d,x,y; gcd(a,n,d,x,y); return d==1?(x+n)%n:-1; } void cal(ll u){ d[u] = 0; Xor[0].num = 0; get_dis(u,0); sort(Xor+1,Xor+1+Xor[0].num,cmp); ll st = -1,idx = 0; for(ll i = 1;i <= Xor[0].num;i ++){ if(Xor[i].num != st){ st = Xor[i].num; mp[++idx].push_back(Xor[i].id); md[idx] = st; } else{ mp[idx].push_back(Xor[i].id); } } ans = 0; for(ll i = 1;i <= idx;i ++){ ll num1 = 0,num2 = 0; for(ll j = 0;j < mp[i].size();j ++){ num1 += sz[mp[i][j]]; num2 += sz[mp[i][j]]*sz[mp[i][j]]%mod; num1%=mod; num2%=mod; } ans += ((num1*num1%mod+mod - num2)%mod)*finv(2,mod)%mod; ans %= mod; } for(ll i = 1;i <= idx;i ++) mp[i].clear(); } void dfs(ll u,ll fa){ for(ll i = head[u];i;i=e[i].next){ ll v = e[i].to; if(v == fa) continue; sum1[d[u]] += (n - sz[v]+mod)%mod; if(num[d[v]] >= 1){ ans = (ans + mod - (sz[v]*sum[d[v]]%mod))%mod; ans += sz[v]*sum1[d[v]]%mod; ans %= mod; } sum[d[v]] += sz[v]; num[d[v]] += 1; sum[d[v]]%=mod; sum1[d[v]]%=mod; dfs(v,u); sum[d[v]] -= sz[v]-mod; sum1[d[u]] -= (n-sz[v])-mod; sum[d[v]]%=mod; sum1[d[v]]%=mod; num[d[v]] -= 1; } } int main() { ll v,w; scanf("%lld",&n); for(ll i = 2;i <= n;i ++){ scanf("%lld%lld",&v,&w); add(i,v,w); add(v,i,w); } get_siz(1,0); cal(1); sum[0] += sz[1]; num[0] += 1; dfs(1,0); ans %= mod; num.clear(); sum.clear(); sum1.clear(); printf("%lld ",ans); }