正题
题目链接:http://www.ybtoj.com.cn/contest/121/problem/2
题目大意
给出\(n\)个点的一棵树,每个点有一个权值\(a_i\),求
\[\frac{1}{n(n-1)}\sum_{i=1}^n\sum_{j=1}^ndis(i,j)\times \varphi(a_i\times a_j)
\]
答案对\(10^9+7\)取模。
\(2\leq n\leq 2\times 10^5\),\(a\)恰好是一个排列。
解题思路
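一个比较显然的结论是\(\varphi(x\times y)=\varphi(x)\times \varphi(y)\times \frac{\gcd(x,y)}{\varphi(\gcd(x,y))}\)。(由\(\varphi(n)=n\prod_{p|n}(1-\frac{1}{p})\)可知:\(x\)和\(y\)的公共质因子\(p\)在\(\varphi(x)\varphi(y)\)里被乘了两次\((1-\frac{1}{p})\),而在\(\varphi(xy)\)里只应乘一次;乘上\(\frac{\gcd(x,y)}{\varphi(\gcd(x,y))}=\prod_{p|\gcd(x,y)}\frac{p}{p-1}\)正好把多乘的那一次抵消掉。)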
然后把点重新编号使得\(a_i=i\),再枚举\(d=\gcd(i,j)\),和式就变成
\[\sum_{d=1}^n\frac{d}{\varphi(d)}\sum_{i=1}^n\sum_{j=1}^{n}dis(i,j) \varphi(i)\varphi(j)\times [\gcd(i,j)=d]
\]
然后就可以用莫比乌斯反演了,定义
\[g_d=\sum_{d|i}^n\sum_{d|j}^ndis(i,j)\varphi(i)\varphi(j)
\]
\[g_d=\sum_{d|i}^n\sum_{d|j}^n(dep_i+dep_j-2dep_{lca(i,j)})\varphi(i)\varphi(j)
\]
\[g_d=2\sum_{d|i}^ndep_{i}\varphi(i)\sum_{d|j}^n\varphi(j)-2\sum_{k=1}^n\sum_{d|i}^n\sum_{d|j}^n[lca(i,j)=k]dep_{k}\varphi(i)\varphi(j)
\]
把编号是\(d\)的倍数的点都加入虚树,用树形\(dp\)求出每个点\(k\)作为\(lca(i,j)\)时的\(\sum\varphi(i)\varphi(j)\),从而算出后面那一项;前面那一项可以直接算。
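记\(h_d=\sum_{i=1}^n\sum_{j=1}^ndis(i,j)\varphi(i)\varphi(j)[\gcd(i,j)=d]\),则\(g_d=\sum_{d|i}h_i\),由莫比乌斯反演得\(h_d=\sum_{d|i}\mu(\frac{i}{d})g_i\)。所以和式就是
\[\sum_{d=1}^n\frac{d}{\varphi(d)}\sum_{d|i}\mu\left(\frac{i}{d}\right)g_i
\]
代码最后再乘上\(n(n-1)\)的逆元。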
时间复杂度\(O(n\log^2 n)\),有点卡常。
code
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<cctype>
// Shorthand for the 64-bit integer type used in all modular arithmetic below.
#define ll long long
// GCC-specific optimisation hints (other compilers ignore unknown pragmas).
#pragma GCC optimize(2)
#pragma GCC optimize(3)
#pragma GCC optimize("Ofast")
#pragma GCC optimize("inline")
using namespace std;
// N: capacity of every per-vertex / per-Euler-tour-position array
//    (the Euler tour of an n-vertex tree has 2n-1 entries).
// T: number of sparse-table levels kept per position (2^19 > 4e5).
// P: the prime modulus for all arithmetic.
const int N = 400010, T = 20, P = 1000000007;
// Read one decimal integer from stdin.
// Characters before the first digit are skipped; every '-' seen while skipping
// flips the sign. Returns 0 when end-of-file is reached before any digit
// (instead of looping forever).
int read() {
	int x=0,f=1;
	// int, not char: getchar() may return EOF, and isdigit() is only defined
	// for EOF or values representable as unsigned char.
	int c=getchar();
	while(c!=EOF&&!isdigit(c)) {
		if(c=='-')f=-f;
		c=getchar();
	}
	while(isdigit(c)) x=x*10+(c-'0'),c=getchar();  // isdigit(EOF) is false
	return x*f;
}
// Adjacency-list edge. Each vertex's list starts at ls[v] and is threaded
// through `next`; index 0 terminates a list.
struct node{
	int to;
	int next;
};
node a[N<<1];        // edge pool: tree edges first, later reused for each virtual tree

int n;               // number of vertices
int m;               // number of vertices gathered for the current divisor
int tot;             // number of edges currently stored in a[]
int top;             // height of the virtual-tree construction stack s[]
int dfc;             // DFS entry-time counter
int ls[N];           // head edge index of each vertex's adjacency list
int rfl[N];          // rfl[i]: value of input vertex i, used as its new label

int mu[N];           // Moebius function values from the sieve
int phi[N];          // Euler's totient values from the sieve
int pri[N];          // primes found by the sieve
int prn;             // number of primes stored in pri[]
int s[N];            // stack used while building a virtual tree
int p[N];            // vertices whose label is a multiple of the current divisor

int stn;             // length of the Euler tour stored in f[*][0]
int lg[N];           // lg[i] = floor(log2(i)), for sparse-table queries
int wz[N];           // first position of each vertex in the Euler tour
int rfn[N];          // DFS entry time of each vertex
int dep[N];          // depth of each vertex
int f[N][T];         // sparse table: shallowest vertex over Euler-tour ranges

ll S[N];             // virtual-tree DP: subtree sums of phi
ll dp[N];            // virtual-tree DP: pair sums whose LCA is the vertex
ll g[N];             // g[k]: weighted distance sum over pairs of multiples of k
ll ans;              // accumulated answer
bool v[N];           // sieve flag: true for composite numbers
bool mark[N];        // true for vertices inserted as real (not LCA-only) points
// Binary exponentiation: returns x^b mod P for b >= 0 (1 when b == 0).
// Callers pass x already reduced mod P, which keeps x*x within 64 bits.
ll power(ll x,ll b){
	ll res=1;
	for(;b;b>>=1,x=x*x%P)
		if(b&1)res=res*x%P;
	return res;
}
void prime(){
mu[1]=phi[1]=1;
for(int i=2;i<=n;i++){
if(!v[i])pri[++prn]=i,phi[i]=i-1,mu[i]=-1;
for(int j=1;j<=prn&&i*pri[j]<=n;j++){
v[i*pri[j]]=1;
if(i%pri[j]==0){
phi[i*pri[j]]=phi[i]*pri[j];
break;
}
phi[i*pri[j]]=phi[i]*(pri[j]-1);
mu[i*pri[j]]=-mu[i];
}
}
return;
}
// Add the directed edge x -> y by pushing it onto the front of x's list.
void addl(int x,int y){
	++tot;
	a[tot].to=y;
	a[tot].next=ls[x];
	ls[x]=tot;
}
// Sort comparator: vertex x precedes y when it was entered earlier by the DFS.
bool cmp(int x,int y){
	return rfn[y]>rfn[x];
}
void dfs(int x,int fa){
dep[x]=dep[fa]+1;rfn[x]=++dfc;
f[++stn][0]=x;wz[x]=stn;
for(int i=ls[x];i;i=a[i].next){
int y=a[i].to;
if(y==fa)continue;
dfs(y,x);f[++stn][0]=x;
}
return;
}
int LCA(int l,int r){
l=wz[l];r=wz[r];
if(l>r)swap(l,r);
int z=lg[r-l+1],x=f[l][z],y=f[r-(1<<z)+1][z];
return dep[x]<dep[y]?x:y;
}
// Insert vertex x into the virtual tree being built on the stack s[1..top].
// Vertices are expected in increasing rfn (DFS entry) order. The stack holds
// the current root-to-x chain. A vertex that is popped is linked to its
// virtual-tree parent with addl(parent, child). Vertices still on the stack
// after the last insertion are linked by the caller (see the while-loop after
// the insertions in main).
void Ins(int x){
if(!top){s[++top]=x;return;}
int lca=LCA(x,s[top]);
// Pop every stacked vertex strictly deeper than the second-from-top while that
// one is still below lca, linking each popped vertex to the one beneath it.
while(top>1&&dep[s[top-1]]>dep[lca])
addl(s[top-1],s[top]),top--;
// The last vertex deeper than lca hangs directly below lca.
if(dep[s[top]]>dep[lca])addl(lca,s[top]),top--;
// Ensure lca itself is on the stack; it may be a branching vertex that was
// not inserted explicitly.
if((!top)||s[top]!=lca)s[++top]=lca;
s[++top]=x;return;
}
void calc(int x,ll &ans){
if(mark[x])S[x]=phi[x],dp[x]=1ll*phi[x]*phi[x]%P;
else S[x]=dp[x]=0;
for(int i=ls[x];i;i=a[i].next){
int y=a[i].to;calc(y,ans);
(dp[x]+=S[x]*S[y]*2ll%P)%=P;
S[x]=(S[x]+S[y])%P;
}
(ans+=P-1ll*dp[x]*dep[x]%P)%=P;
ls[x]=mark[x]=0;return;
}
signed main()
{
	freopen("sm.in","r",stdin);
	freopen("sm.out","w",stdout);
	n=read();prime();
	// Relabel every vertex by its value: input vertex i becomes vertex a_i,
	// so from here on "vertex v" and "value v" are the same thing.
	for(int i=1;i<=n;i++)rfl[i]=read();
	for(int i=1;i<n;i++){
		int x=read(),y=read();
		x=rfl[x];y=rfl[y];
		addl(x,y);addl(y,x);
	}
	// Euler tour from vertex 1, then a sparse table for O(1) LCA queries.
	dfs(1,1);
	for(int j=1;(1<<j)<=stn;j++)
		for(int i=1;i+(1<<j)-1<=stn;i++){
			int x=f[i][j-1],y=f[i+(1<<(j-1))][j-1];
			f[i][j]=(dep[x]<dep[y])?x:y;
		}
	for(int i=2;i<=stn;i++)lg[i]=lg[i>>1]+1;
	// The tree's adjacency lists are no longer needed; ls[] is reused for virtual trees.
	memset(ls,0,sizeof(ls));
	// For each k, g[k] = sum over ordered pairs (i,j) of multiples of k of
	// dis(i,j)*phi(i)*phi(j)
	//   = 2*sum_i dep_i*phi(i)*sum_j phi(j) - 2*sum_{i,j} dep_{lca(i,j)}*phi(i)*phi(j).
	for(int k=1;k<=n;k++){
		m=top=tot=0;ll sum=0;
		for(int i=k;i<=n;i+=k)
			p[++m]=i,sum+=phi[i];
		sort(p+1,p+1+m,cmp);sum%=P;
		// Root every virtual tree at vertex 1, the first vertex in DFS order.
		if(p[1]!=1)s[++top]=1;
		for(int i=1;i<=m;i++){
			Ins(p[i]);mark[p[i]]=1;
			(g[k]+=1ll*phi[p[i]]*dep[p[i]]%P*sum%P)%=P;
		}
		while(top>1)addl(s[top-1],s[top]),top--;
		calc(1,g[k]);g[k]=g[k]*2ll%P;
	}
	// Moebius inversion recovers the sum over pairs with gcd exactly i.
	// It is weighted by i/phi(i) because phi(xy) = phi(x)phi(y)*d/phi(d), d = gcd(x,y).
	for(int i=1;i<=n;i++){
		ll tmp=0;
		for(int j=i;j<=n;j+=i)
			(tmp+=mu[j/i]*g[j]%P)%=P;
		(ans+=tmp*i%P*power(phi[i],P-2)%P)%=P;
	}
	// Divide by n(n-1). The value is a long long, so it must be printed with
	// %lld (printing a long long with %d is undefined behaviour).
	printf("%lld\n",(ans+P)%P*power(1ll*n*(n-1)%P,P-2)%P);
	return 0;
}