• CF809E 【Surprise me!】


    题意:

    给一颗树,求:

    [frac{1}{n(n-1)}sum_{i=1}^nsum_{j=1}^nvarphi(a_i*a_j)·dist(i,j) ]

    (10^9+7)取模

    首先我们知道

    [varphi(x)=x*prod_{p|x}(1-dfrac{1}{p}) ]

    所以就有:

    [varphi(x)*varphi(y)=x*y*prod_{p|x}(1-dfrac{1}{p})*prod_{p|y}(1-dfrac{1}{p}) ]

    (k=gcd(x,y))

    我们假装让(varphi(x)*varphi(y)=varphi(xy)=x*y*prod_{p|xy}(1-dfrac{1}{p}))

    不难发现有少乘的嫌疑...

    被少计算的是(gcd)质因数分解后的那些因数,所以我们可以得出:

    [varphi(xy)=dfrac{varphi(x)*varphi(y)*gcd(x,y)}{varphi(gcd(x,y))} ]

    忽略掉前面的系数,我们知道我们要求的是:

    [sum_{i=1}^nsum_{j=1}^ndfrac{varphi(a_i)*varphi(a_j)*gcd(a_i,a_j)}{varphi(gcd(a_i,a_j))}·dist(i,j) ]

    [sum_{i=1}^nsum_{j=1}^ndfrac{varphi(a_i)*varphi(a_j)*gcd(a_i,a_j)}{varphi(gcd(a_i,a_j))}(dep_x+dep_y-dep_{lca}) ]

    这样做显然很不好做

    我们考虑枚举(gcd),可以发现每个(gcd)对答案造成的贡献固定

    于是就有:

    [sum_{d=1}^ndfrac{d}{varphi(d)}sum_{i=1}^nsum_{j=1}^nvarphi(a_i)varphi(a_j)(gcd(a_i,a_j)==d)dist(i,j) ]

    看到了一个恰好等于...

    我们考虑记

    [f(d)=sum_{i=1}^nsum_{j=1}^nvarphi(a_i)varphi(a_j)(gcd(a_i,a_j)==d)dist(i,j) ]

    假装有:

    [F(x)=sum_{x|d}f(d) ]

    我们考虑求解(F(x))

    不难发现

    [F(x)=sum_{i=1}^nsum_{j=1}^nvarphi(a_i)varphi(a_j)[x|gcd(a_i,a_j)]dist(i,j) ]

    且有:

    [f(x)=sum_{x|d}F(d)*mu(dfrac{d}{x}) ]

    可以发现涉及到的点可以直接枚举出来,而且因为(a)(1-n)的排列,所以涉及到的点的总数应该是可以由调和级数得到为(O(nlog n))

    于是我们枚举(x),对于所有(x|d)的点拿下来建虚树,中间那个艾弗森括号就可以忽略了

    现在式子变成,给的(m)个点,点有点权,求:

    [sum_{i=1}^msum_{j=1}^mv_i*v_j*dist(i,j) ]

    [sum_{i=1}^msum_{j=1}^mv_i*v_j*(dep_i+dep_j-2*dep_{lca}) ]

    (s=v_i*v_j)

    [sum_{i=1}^msum_{j=1}^ms*dep_i+s*dep_j-2*s*dep_{lca} ]

    我们将其拆成(3)(sum)的形式

    [sum_{i=1}^msum_{j=1}^mv_i*v_j*dep_i+sum_{i=1}^msum_{j=1}^mv_i*v_j*dep_j-2sum_{i=1}^msum_{j=1}^mv_i*v_j*dep_{lca} ]

    不难发现前面两个是相同的

    [2sum_{i=1}^mv_i*dep_isum_{j=1}^mv_j-2*sum_{i=1}^msum_{j=1}^mv_i*v_j*dep_{lca} ]

    前面那个显然是可以(O(m))预处理出来的

    于是问题的关键就在于后面那一坨

    我们考虑在虚树上(dp),可以发现只需要在(lca)处合并答案即可,所以记(dp_i)表示(i)内部的子树权值总和即可,复杂度同样(O(m))

    当然建立虚树的复杂度是带(log)的,所以总复杂度为(O(nlog^2n))

    qwq

    #include<bits/stdc++.h>
    using namespace std;
    #define Next( i, x ) for( register int i = head[x]; i; i = e[i].next )
    #define rep( i, s, t ) for( register int i = s; i <= t; ++ i )
    #define re register
    #define int long long
    int gi() {
    	char cc = getchar(); int cn = 0, flus = 1;
    	while(cc < '0' || cc > '9') {  if( cc == '-' ) flus = -flus;  cc = getchar();  }
    	while(cc >= '0' && cc <= '9')  cn = cn * 10 + cc - '0', cc = getchar();
    	return cn * flus;
    }
    const int N = 4e5 + 5 ; 
    const int P = 1e9 + 7 ;
    int n, cnt, head[N], w[N], rev[N], vis[N], dp[N], Ans ; 
    int phi[N], mu[N], p[N], F[N], f[N], top, s[N] ; 
    int dfn[N], Fa[N], dep[N], sz[N], Son[N], Top[N], idnex ; 
    bool isp[N] ; 
    struct E {
    	int to, next ; 
    } e[N * 2] ;
    void add( int x, int y ) {
    	e[++ cnt] = (E){ y, head[x] }, head[x] = cnt , 
    	e[++ cnt] = (E){ x, head[y] }, head[y] = cnt ; 
    }
    void init() {
    	mu[1] = 1, phi[1] = 1 ; rep( i, 2, n ) {
    		if( !isp[i] ) phi[i] = i - 1, mu[i] = -1, p[++ top] = i ;
    		for( re int j = 1; j <= top && p[j] * i <= n; ++ j ) {
    			isp[i * p[j]] = 1 ; 
    			if( i % p[j] == 0 ) { phi[i * p[j]] = phi[i] * p[j] ; break ; } 
    			phi[i * p[j]] = phi[i] * phi[p[j]], mu[i * p[j]] = - mu[i] ; 
    		}
    	}
    }
    inline void dfs( int x, int fa ) {
    	dfn[x] = ++ idnex, Fa[x] = fa, dep[x] = dep[fa] + 1, sz[x] = 1 ; 
    	Next( i, x ) {
    		int v = e[i].to ; if( v == fa ) continue ; 
    		dfs( v, x ), sz[x] += sz[v] ;
    		if( sz[v] > sz[Son[x]] ) Son[x] = v ; 
    	}
    }
    inline void dfs2( int x, int high ) {
    	Top[x] = high ; 
    	if( Son[x] ) dfs2( Son[x], high ) ;
    	Next( i, x ) {
    		int v = e[i].to ; if( v == Fa[x] || v == Son[x] ) continue ; 
    		dfs2( v, v ) ; 
    	}
    }
    inline int LCA( int x, int y ) {
    	while( Top[x] != Top[y] ) {
    		if( dep[Top[x]] < dep[Top[y]] ) swap( x, y ) ;
    		x = Fa[Top[x]] ; 
    	}
    	return ( dep[x] < dep[y] ) ? x : y ; 
    }
    inline void Dp( int x, int fa ) {
    	if( vis[x] ) Ans = ( Ans + 2 * phi[w[x]] * phi[w[x]] * dep[x] % P ) % P, dp[x] = phi[w[x]] ; 
    	Next( i, x ) {
    		int v = e[i].to ; if( v == fa ) continue ; 
    		Dp( v, x ), Ans = ( Ans + 4 * dp[x] * dp[v] % P * dep[x] % P ) % P, 
    		dp[x] = ( dp[x] + dp[v] ) % P, dp[v] = 0 ;
    	}
    	head[x] = 0 ; 
    }
    int fpow( int x, int k ) {
    	int ans = 1, base = x ; 
    	while( k ) {
    		if( k & 1 ) ans = ( ans * base ) % P ;
    		base = ( base * base ) % P, k >>= 1 ;
    	}
    	return ans ;  
    }
    int K[N], tot ; 
    bool cmp( int x, int y ) {
    	return dfn[x] < dfn[y] ; 
    } 
    void insert( int x ) {
    	if( top == 1 && x != 1 ) { s[++ top] = x; return ; }
    	int lca = LCA( s[top], x ) ; if( lca == x ) return ; 
    	while( top > 1 && dfn[s[top - 1]] > dfn[lca] ) add( s[top - 1], s[top] ), -- top ; 
    	if( dfn[lca] < dfn[s[top]] ) add( lca, s[top] ), -- top ; 
    	if( dfn[lca] > dfn[s[top]] ) s[++ top] = lca ; 
    	s[++ top] = x ; 
    }
    void input() {
    	n = gi(), init(), top = 0 ;
    	rep( i, 1, n ) w[i] = gi(), rev[w[i]] = i ; 
    	rep( i, 2, n ) add( gi(), gi() ) ; 
    	dfs( 1, 1 ), dfs2( 1, 1 ), memset( head, 0, sizeof(head) ) ;
    }
    void solve() {
    	rep( i, 1, n ) {
    		if( i > n / 2 ) break ; 
    		int Re = 0, Sum = 0 ;
    		for( re int j = i; j <= n; j += i ) {
    			vis[rev[j]] = 1, Sum = ( Sum + phi[j] * dep[rev[j]] % P ) % P,  
    			Re = ( Re + phi[j] ) % P, K[++ tot] = rev[j] ; 
    		}
    		sort( K + 1, K + tot + 1, cmp ), s[top = 1] = 1 ;
    		rep( i, 1, tot ) insert( K[i] ) ;
    		while( top > 1 ) add( s[top - 1], s[top] ), -- top ; 
    		Sum *= Re, Sum %= P, Dp( 1, 1 ), dp[1] = head[1] = 0, F[i] = ( Sum * 2 - Ans + P ) % P ;
    		for( re int j = i; j <= n; j += i ) vis[rev[j]] = 0 ; Ans = 0, tot = 0, cnt = 0 ;
    	}
    }
    void Get_F() {
    	rep( i, 1, n ) for( re int j = i; j <= n; j += i ) 
    	f[i] += F[j] * mu[j / i], f[i] = ( f[i] + P ) % P ;
    	rep( i, 1, n ) Ans = ( Ans + i * fpow( phi[i], P - 2 ) % P * f[i] ) % P ; 
    	int Ks = Ans * fpow( n, P - 2 ) % P ; Ks = Ks * fpow( n - 1, P - 2 ) % P ;
    	printf("%I64d
    ", ( Ks + P ) % P ) ;
    }
    signed main()
    {
    	input(), solve(), Get_F() ; 
    	return 0 ;
    }
    
  • 相关阅读:
    Maven入门教程
    认识Java Core和Heap Dump
    [Java IO]03_字符流
    Eclipse 实用技巧
    可变和不可变的区分
    什么猴子补丁待补充
    当退出python时,是否释放全部内存
    解释python中的help()和dir()函数
    在python中是如何管理内存的
    解释一下python中的继承
  • 原文地址:https://www.cnblogs.com/Soulist/p/11637301.html
Copyright © 2020-2023  润新知