• Fish eating fruit 沈阳网络赛(树形dp)


    Fish eating fruit

    [Time Limit: 1000 ms quad Memory Limit: 262144 kB ]

    题意

    大体的题意就是给出一棵树,求每一对点之间的距离,然后把该距离存在距离 (mod 3) 的位置,输出总和。

    思路

    令两个 (dp) 数组和两个辅助 (dp) 的数组。
    (dp1[i][j]) 表示从 (i) 为起点往下到各个点距离 (mod 3) 后为 (j) 的距离总和。
    (cnt1[i][j]) 表示以 (i) 为起点往下到各个点距离 (mod 3) 后为 (j) 的节点个数。
    (dp2[i][j]) 表示从 (i) 起点往上一步后到各个点距离 (mod 3) 后为 (j) 的距离总和。
    (cnt2[i][j]) 表示以 (i) 为起点往上一步后到各个点距离 (mod 3) 后为 (j) 的节点个数。


    对于两个 (dp) 分别跑一遍 (dfs)


    对于 (dp1) 比较好处理,直接往下 (dfs)
    (u) 开始的答案等于从 (v) 开始的答案加上这一条边 (w) 的贡献,可以得到

    [ dp1[u][(j+w)\%3] = sum (dp1[v][j] + cnt1[v][j]*w)\ cnt1[u][(j+w)\%3] = sum cnt1[v][j] ]


    对于 (dp2) 会比较麻烦,需要用 (fa) 节点向上的贡献在加上 (fa) 节点往下的贡献在减去 (fa) 节点往 (u) 走的贡献。这些节点就是 (u) 往上走一步后可以走到的所有节点。这样算出真实的节点数和距离总和,然后 (u) 才能开始转移。
    (faw) 为从 (u)(fa) 的路径长度
    计算真实的节点数:

    [ c[j] = cnt2[fa][j]+cnt1[fa][j]\ c[(j+faw)\%3] -= cnt1[u][j] ]

    计算真实的距离总和:

    [ d[j] = dp2[fa][j]+dp1[fa][j] \ d[(j+faw)\%3] -= dp1[u][0]+cnt1[u][j]*faw ]

    则最后的 (dp2) 就可以利用 (d)(c) 得到了

    [ dp2[u][(j+faw)\%3] = c[j]*faw+d[j] \ cnt2[u][(j+faw)\%3] = c[j] ]

    /*************************************************************** 
        > File Name    : a.cpp
        > Author       : Jiaaaaaaaqi
        > Created Time : Mon 16 Sep 2019 08:55:33 PM CST
     ***************************************************************/
    
    #include <map>
    #include <set>
    #include <list>
    #include <ctime>
    #include <cmath>
    #include <stack>
    #include <queue>
    #include <cfloat>
    #include <string>
    #include <vector>
    #include <cstdio>
    #include <bitset>
    #include <cstdlib>
    #include <cstring>
    #include <iostream>
    #include <algorithm>
    #include <unordered_map>
    #define  lowbit(x)  x & (-x)
    #define  mes(a, b)  memset(a, b, sizeof a)
    #define  fi         first
    #define  se         second
    #define  pb         push_back
    #define  pii        pair<int, int>
    
    typedef unsigned long long int ull;
    typedef long long int ll;
    const int    maxn = 1e5 + 10;
    const int    maxm = 1e5 + 10;
    const ll     mod  = 1e9 + 7;
    const ll     INF  = 1e18 + 100;
    const int    inf  = 0x3f3f3f3f;
    const double pi   = acos(-1.0);
    const double eps  = 1e-8;
    using namespace std;
    
    int n, m;
    int cas, tol, T;
    
    vector< pii > vv[maxn];
    ll cnt1[maxn][3], cnt2[maxn][3];
    ll dp1[maxn][3], dp2[maxn][3];
    
    void dfs1(int u, int fa) {
    	cnt1[u][0] = 1;
    	for(auto i : vv[u]) {
    		int v = i.fi, w = i.se;
    		if(v == fa)	continue;
    		dfs1(v, u);
    		dp1[u][(0+w)%3] += (cnt1[v][0]*w%mod+dp1[v][0])%mod;
    		dp1[u][(1+w)%3] += (cnt1[v][1]*w%mod+dp1[v][1])%mod;
    		dp1[u][(2+w)%3] += (cnt1[v][2]*w%mod+dp1[v][2])%mod;
    		for(int j=0; j<3; j++)	dp1[u][j] %= mod;
    		cnt1[u][(0+w)%3] += cnt1[v][0];
    		cnt1[u][(1+w)%3] += cnt1[v][1];
    		cnt1[u][(2+w)%3] += cnt1[v][2];
    	}
    }
    
    void dfs2(int u, int fa) {
    	if(u!=1) {
    		int faw;
    		for(auto i : vv[u]) {
    			if(i.fi == fa) {
    				faw = i.se;
    				break;
    			}
    		}
    		int c[3] = { 0 };
    		c[0] = cnt2[fa][0]+cnt1[fa][0];
    		c[1] = cnt2[fa][1]+cnt1[fa][1];
    		c[2] = cnt2[fa][2]+cnt1[fa][2];
    		c[(0+faw)%3] -= cnt1[u][0];
    		c[(1+faw)%3] -= cnt1[u][1];
    		c[(2+faw)%3] -= cnt1[u][2];
    		ll d[3] = { 0 };
    		d[0] = (dp2[fa][0]+dp1[fa][0])%mod;
    		d[1] = (dp2[fa][1]+dp1[fa][1])%mod;
    		d[2] = (dp2[fa][2]+dp1[fa][2])%mod;
    		d[(0+faw)%3] = ((d[(0+faw)%3] - (cnt1[u][0]*faw%mod+dp1[u][0])%mod+mod)%mod+mod)%mod;
    		d[(1+faw)%3] = ((d[(1+faw)%3] - (cnt1[u][1]*faw%mod+dp1[u][1])%mod+mod)%mod+mod)%mod;
    		d[(2+faw)%3] = ((d[(2+faw)%3] - (cnt1[u][2]*faw%mod+dp1[u][2])%mod+mod)%mod+mod)%mod;
    		
    		dp2[u][(0+faw)%3] = (c[0]*faw%mod+d[0])%mod;
    		dp2[u][(1+faw)%3] = (c[1]*faw%mod+d[1])%mod;
    		dp2[u][(2+faw)%3] = (c[2]*faw%mod+d[2])%mod;
    		cnt2[u][(0+faw)%3] += c[0];
    		cnt2[u][(1+faw)%3] += c[1];
    		cnt2[u][(2+faw)%3] += c[2];
    	}
    	for(auto i : vv[u]) {
    		int v = i.fi, w = i.se;
    		if(v == fa)	continue;
    		dfs2(v, u);
    	}
    }
    
    int main() {
    	// freopen("in", "r", stdin);
    	while(~scanf("%d", &n)) {
    		for(int i=1; i<=n; i++) {
    			vv[i].clear();
    		}
    		mes(dp1, 0), mes(dp2, 0);
    		mes(cnt1, 0), mes(cnt2, 0);
    		for(int i=1, u, v, w; i<n; i++) {
    			scanf("%d%d%d", &u, &v, &w);
    			u++, v++;
    			vv[u].pb(make_pair(v, w));
    			vv[v].pb(make_pair(u, w));
    		}
    		dfs1(1, 0);
    		dfs2(1, 0);
    		// for(int i=1; i<=n; i++) {
    		//     for(int j=0; j<3; j++) {
    		//         printf("dp1[%d][%d] = %lld, cnt1[%d][%d] = %lld
    ", i, j, dp1[i][j], i, j, cnt1[i][j]);
    		//     }
    		// }
    		// cout << "-----------------" << endl;
    		// for(int i=1; i<=n; i++) {
    		//     for(int j=0; j<3; j++) {
    		//         printf("dp2[%d][%d] = %lld, cnt2[%d][%d] = %lld
    ", i, j, dp2[i][j], i, j, cnt2[i][j]);
    		//     }
    		// }
    		ll ans0, ans1, ans2;
    		ans0 = ans1 = ans2 = 0;
    		for(int i=1; i<=n; i++) {
    			ans0 = (ans0+dp1[i][0]+dp2[i][0])%mod;
    			ans1 = (ans1+dp1[i][1]+dp2[i][1])%mod;
    			ans2 = (ans2+dp1[i][2]+dp2[i][2])%mod;
    		}
    		printf("%lld %lld %lld
    ", ans0, ans1, ans2);
    	}
    	return 0;
    }
    
  • 相关阅读:
    C++——overloading
    C++——调用优化
    C++——Big Three(copy ctor、copy op=、dtor)
    C++——引用 reference
    C++——构造函数 constructor
    003——矩阵的掩膜操作
    002——加载、修改、保存图像
    001——搭建OpenCV实验环境
    宏——基础
    剖析可执行文件ELF组成
  • 原文地址:https://www.cnblogs.com/Jiaaaaaaaqi/p/11530717.html
Copyright © 2020-2023  润新知