• 21牛客9G


    21牛客9G - Glass Balls (树上概率dp)

    题目

    source

    题解

    UPD:
    队友写得题解真好啊,简单清晰,推荐:2021牛客暑期多校训练营9 G (树上概率dp 对于存在不合法的情况dp启示)
    可以直接令分数为状态,从v转移一步到u,贡献是score[v]再加上v上面的那1个球,所以是score[u]=p*(score[v]+1)。关键在于这里的p是条件概率,即在所有合法情况中的占比。具体见上面的博客,这样只需一次dfs。


    对于从(u)点出发掉到(v)点的球来说,它的贡献是(dep[u]-dep[v])。设对于一个固定的局面,掉到(v)点的球的球的个数为(cnt[v]),那么所有球的贡献为(即该局面的分数)为:

    [sumlimits_{i=1}^{n}{dep[i]-sumlimits_{i=1}^{n}{cnt[i]cdot dep[i]}} ]

    因此,只要分别求出深度总和的期望每个结点掉下去球数的期望即可,可以用树上dp计算。这里有几点要注意的:

    • 局面有合法的情况和非法的情况,因此在转移状态时注意确保的是从合法的子状态以合法的过程转移过来。
    • 树上dp一般计算的是子树的结果,在合并统计答案时要考虑上子树外部分的影响,这也是为什么往往需要两个dfs计算down和up的原因。

    从题目中可以容易推得,每个结点的子节点中至多只有一个结点不是“储存点”,否则就是非法的。

    (dp[i])​为(i)​的子树中到(i)​的球数的期望;(down[i])​为点(i)​的子树为合法局面的概率;(up[i])​为整棵树在点(i)​​为“储存点”时且除去了(down[i])​​的合法概率。这里的(up[i])​是为了(i)​子树中到点(i)​的球数的期望转换为整棵树中从(i)​掉下去的球数的期望,即(cnt[i]=up[i] cdot dp[i])​。

    显然,深度总和的期望就是整棵树合法的概率乘上深度的总和,即(down[1] cdot sumlimits_{i=1}^n{dep[i]})​。

    (down)(up)的转移都比较简单,主要是(dp)的转移。设(P)为“储存点的概率”,(t)为点p子结点的个数。

    • 子结点都是“储存点”,且子节点都合法,此时(u)中只有1个球,这种情况的贡献为:

    [dp[u]=1 cdot P^t cdot prod_{v { m 是}u{ m的子节点}} {down[v]} ]

    • 子结点(v)​不是”储存点“,且子节点都合法,此时(u)中除了本身的1个球,还有来自(dp[v])那么多的球,这种情况的贡献为:

    [dp[u]=dp[v]cdot P^{t-1}cdot (1-P)cdot prod_{v' { m 是}u{ m的子节点且}v' eq v}{down[v']}+1cdot P^{t-1}cdot (1-P)cdotprod_{v' { m 是}u{ m的子节点}} {down[v']} ]

    最终答案为:(down[1] cdot sumlimits_{i=1}^n{dep[i]}-sumlimits_{i=1}^{n}{up[i] cdot dp[i]cdot dep[i]})

    #include <bits/stdc++.h>
    
    #define endl '
    '
    #define IOS std::ios::sync_with_stdio(0); cin.tie(0); cout.tie(0)
    #define mp make_pair
    #define seteps(N) fixed << setprecision(N) 
    typedef long long ll;
    
    using namespace std;
    /*-----------------------------------------------------------------*/
    
    ll gcd(ll a, ll b) {return b ? gcd(b, a % b) : a;}
    #define INF 0x3f3f3f3f
    
    const int N = 5e5 + 10;
    const int M = 998244353;
    const double eps = 1e-5;
    
    ll down[N];
    ll up[N];
    ll dp[N];
    int dep[N];
    ll po;
    vector<int> np[N];
    
    inline ll qpow(ll a, ll b, ll m) {
        ll res = 1;
        while(b) {
            if(b & 1) res = (res * a) % m;
            a = (a * a) % m;
            b = b >> 1;
        }
        return res;
    }
    
    void dfs(int p, int fa, int d) {
        dep[p] = d;
        for(int nt : np[p]) {
            if(nt == fa) continue;
            dfs(nt, p, d + 1);
        }
    }
    
    void caldown(int p, int fa) {
        ll lp = 1;
        int num = 0;
        for(int nt : np[p]) {
            if(nt == fa) continue;
            num++;
            caldown(nt, p);
            lp = lp * down[nt] % M;
        }
        if(num)
            lp = lp * (qpow(po, num - 1, M) * (1 - po + M) % M * num % M + qpow(po, num, M)) % M;
        down[p] = lp;
    }
    
    void calup(int p, int fa) {
        int num = 0;
        for(int nt : np[p]) {
            if(nt == fa) continue;
            num++;
            up[nt] = down[1] * qpow(down[nt], M - 2, M) % M;
        }
        if(num) {
            ll tp = (qpow(po, num - 1, M) * (1 - po + M) % M * num % M + qpow(po, num, M)) % M;
            for(int nt : np[p]) {
                if(nt == fa) continue;
                up[nt] = up[nt] * qpow(tp, M - 2, M) % M;
                up[nt] = up[nt] * (qpow(po, num - 1, M) * (1 - po + M) % M * (num - 1) % M + qpow(po, num, M)) % M;
                calup(nt, p);
            }
        }
    }
    
    void solve(int p, int fa) {
        int num = 0;
        ll lp = 1;
        for(int nt : np[p]) {
            if(nt == fa) continue;
            num++;
            lp = lp * down[nt] % M;
            solve(nt, p);
        }
        dp[p] = qpow(po, num, M) * lp % M;
        if(num)
            for(int nt : np[p]) {
                if(nt == fa) continue;
                // 注意后面1的贡献
                // 不要写成(dp[nt] + 1) * qpow(po, num - 1, M) % M * (1 - po + M) % M * lp % M * qpow(down[nt], M - 2, M) % M
                dp[p] += dp[nt] * qpow(po, num - 1, M) % M * (1 - po + M) % M * lp % M * qpow(down[nt], M - 2, M) % M + qpow(po, num - 1, M) % M * (1 - po + M) % M * lp % M;
                // 也可以写成
                // dp[p] += (dp[nt] + down[nt]) * qpow(po, num - 1, M) % M * (1 - po + M) % M * lp % M * qpow(down[nt], M - 2, M) % M;
                
                dp[p] %= M;
            }
    }
    
    int main() {
        IOS;
        up[1] = 1;
        int n;
        cin >> n >> po;
        for(int i = 2; i <= n; i++) {
            int f;
            cin >> f;
            np[i].push_back(f);
            np[f].push_back(i);
        }
        dfs(1, 0, 1);
        caldown(1, 0);
        calup(1, 0);
        solve(1, 0);
        ll ans = 0;
        ll tp = down[1];
        for(int i = 1; i <= n; i++) {
            ans = (ans + (tp - up[i] * (dp[i]) % M + M) * dep[i] % M) % M;
        }
        cout << ans << endl;
    }
    
  • 相关阅读:
    WIN7远程桌面连接--“发生身份验证错误。要求的函数不受支持”
    django-xadmin使用之更改菜单url
    django-xadmin使用之配置页眉页脚
    django-xadmin定制之列表页searchbar placeholder
    django-xadmin定制之分页显示数量
    Chrome无界面浏览模式与自定义插件加载问题
    Chrome开启无界面浏览模式Python+Windows环境
    django-xadmin中APScheduler的启动初始化
    处理nginx访问日志,筛选时间大于1秒的请求
    将Excel文件转为csv文件的python脚本
  • 原文地址:https://www.cnblogs.com/limil/p/15143352.html
Copyright © 2020-2023  润新知