• 牛客练习赛71 E- 神奇的迷宫 点分治+NTT


    牛客练习赛71 E- 神奇的迷宫

    题意

    给一颗(n)个点的树,每条边的长度均为(1),Alice和Bob两人依次传送到树的某两个结点。对于任意一个人,传送到点(i)的概率为(p_i),假设两人传送到的结点之间的最短距离为(L),那么他们挑战这个树的困难度为(w_i)

    问他们挑战这个树的困难度的期望是多少。

    (nle 10^5)

    分析

    (ans[i])表示两人最短距离为(i)的概率,答案即为(sum_{i=0}^{n-1}ans[i]cdot w[i])

    (ans[i])可以用点分治来做,以(u)作为分治中心时,枚举每个子树,用(A[i])表示已经枚举过的子树中到根的距离为(i)的点的概率之和,用(B[i])表示当前子树中到根的距离为(i)的点的概率之和,那么就可以更新(ans[k]+=sum_{i=0}^{k}A[i]cdot B[k-i]),注意到这是一个卷积形式,所以我们对(A,B)做一次卷积就能更新(ans[i]),因为答案要取模,所以用(NTT)来做卷积。

    复杂度为(O(nlog^2n))

    Code

    #include <bits/stdc++.h>
    
    using namespace std;
    typedef long long ll;
    const int N=1e6+10;
    const int mod = 998244353, G = 3, Gi = 332748118;
    int n;
    ll p[N],w[N];
    vector<int>g[N];
    int sz[N],vis[N],mx[N],rt,tot,k1,k2;
    ll ans,A[N],B[N];
    int limit = 1, L, r[N];
    ll a[N], b[N];
    ll ksm(ll a, ll b) {
        ll ret = 1;
        while(b) {
            if(b & 1) ret = (ret * a ) % mod;
            a = (a * a) % mod;
            b >>= 1;
        }
        return ret;
    }
    void NTT(ll *A, int type) {
        for(int i = 0; i < limit; i++)
            if(i < r[i]) swap(A[i], A[r[i]]);
        for(int mid = 1; mid < limit; mid <<= 1) {
            ll Wn = ksm( type == 1 ? G : Gi , (mod - 1) / (mid << 1));
            for(int j = 0; j < limit; j += (mid << 1)) {
                ll w = 1;
                for(int k = 0; k < mid; k++, w = (w * Wn) % mod) {
                     int x = A[j + k], y = w * A[j + k + mid] % mod;
                     A[j + k] = (x + y) % mod,
                     A[j + k + mid] = (x - y + mod) % mod;
                }
            }
        }
    }
    void gao() {
        limit=1,L=0;
        for(int i=0;i<=k1;i++) a[i]=A[i];
        for(int i=0;i<=k2;i++) b[i]=B[i];
        while(limit <= k1 + k2) limit <<= 1, L++;
        for(int i = 0; i < limit; i++) r[i] = (r[i >> 1] >> 1) | ((i & 1) << (L - 1));
        NTT(a, 1);NTT(b, 1);
        for(int i = 0; i < limit; i++) a[i] = (a[i] * b[i]) % mod;
        NTT(a, -1);
        ll inv = ksm(limit, mod - 2);
        for(int i=0;i<=k1+k2;i++) a[i]=(a[i]*inv)%mod;
        for(int i = 0; i <= k1 + k2&&i<n; i++){
            ans=(ans+a[i] * w[i]%mod*2%mod)%mod;
        }
        for(int i=0;i<=limit;i++) a[i]=b[i]=r[i]=0;
    }
    void getrt(int u,int fa){
        sz[u]=1,mx[u]=0;
        for(int x:g[u]){
            if(x==fa||vis[x]) continue;
            getrt(x,u);
            sz[u]+=sz[x];
            mx[u]=max(mx[u],sz[x]);
        }
        mx[u]=max(mx[u],tot-sz[u]);
        if(mx[u]<mx[rt]) rt=u;
    }
    void dfs(int u,int fa,int d){
        B[d]=(B[d]+p[u])%mod;
        k2=max(k2,d);
        for(int x:g[u]){
            if(x==fa||vis[x]) continue;
            dfs(x,u,d+1);
        }
    }
    void solve(int u){
        vis[u]=1;k1=k2=0;
        A[0]=p[u];
        for(int x:g[u]){
            if(vis[x]) continue;
            k2=0;
            dfs(x,u,1);
            /*
            for(int i=0;i<=k1;i++){
                for(int j=0;j<=k2;j++) if(i+j<n){
                    ans+=w[i+j]*A[i]%mod*B[j]%mod*2%mod;
                    ans%=mod;
                }
            }
            */
            gao();
            k1=max(k1,k2);
            for(int i=0;i<=k2;i++) A[i]=(A[i]+B[i])%mod,B[i]=0;
        }
        for(int i=0;i<=k1;i++) A[i]=0;
        for(int x:g[u]){
            if(vis[x]) continue;
            tot=sz[x],mx[rt=0]=n;
            getrt(x,0);
            solve(rt);
        }
    }
    int main()
    {
        scanf("%d",&n);
        for(int i=1;i<=n;i++){
            scanf("%lld",&p[i]);
            p[0]+=p[i];
            if(p[0]>=mod) p[0]-=mod;
        }
        p[0]=ksm(p[0],mod-2);
        for(int i=1;i<=n;i++){
            scanf("%lld",&w[i-1]);
            p[i]=p[i]*p[0]%mod;
        }
        for(int i=2,x,y;i<=n;i++){
            scanf("%d%d",&x,&y);
            g[x].push_back(y);
            g[y].push_back(x);
        }
        tot=mx[rt]=n;
        getrt(1,0);
        solve(rt);
        for(int i=1;i<=n;i++) ans=(ans+w[0]*p[i]%mod*p[i]%mod)%mod;
        printf("%lld
    ",ans);
        return 0;
    }
    
  • 相关阅读:
    2018ddctf wp
    装饰器
    python作用域
    闭包
    迭代器
    ord() expected string of length 1, but int found
    pygm2安装问题
    elf逆向入门
    【POJ
    【POJ
  • 原文地址:https://www.cnblogs.com/xyq0220/p/13809545.html
Copyright © 2020-2023  润新知