单独考虑每一种颜色,答案就是对于每种颜色至少经过一次这种的路径条数之和。反过来思考只需要求有多少条路径没有经过这种颜色即可。
具体实现过程比较复杂,很神奇的一个树形dp,下面给出一个含较详细注释的代码及对应的一组自造的数据以及图片来进行解释
欢迎交流,给出意见~~~
数据
/*
第二行的1 2 3在图中分别用红黄蓝来表示
15 1 2 3 1 2 3 1 2 3 1 2 3 1 2 3 1 2 1 3 2 4 2 5 3 6 3 7 4 8 4 9 5 10 5 11 6 12 6 13 7 14 7 15 */
含注释的代码
#include<bits/stdc++.h> using namespace std; typedef long long LL; const int N=2e5+5; int n; int c[N]; //各结点的颜色 int size[N]; //size[i]记录结点i的子树的大小,叶子为1 int vis[N]; //vis[i]记录颜色i是否出现过 int sum[N]; //见函数内注释及图片 vector<int> adj[N]; LL ans,de; LL C(LL a,LL b) //计算组合数,这样写是为了增加后面的代码的可读性 { LL ret=1; for(int i=1;i<=b;i++) ret=ret*(a+1-i)/i; return ret; } void dfs(int u,int pre) { // 定义 u 为“当前结点” // sum[c[u]]表示, 整棵树中,已经被dfs过的, // 与当前结点具有相同颜色的结点的所有子树的大小之和 // 不重复计数 // 图一展示了的是刚刚进入dfs(7,3)时sum[c[u]]包含的结点 printf(" ================================== Enter->%d ",u); int all=0; size[u]=1; for(int to:adj[u]) { if(to==pre) continue; int sumu_bd=sum[c[u]]; //bd : before dfs //记录此次dfs前的sum[c[u]] dfs(to,u); size[u]+=size[to]; int part=sum[c[u]]-sumu_bd; //dfs过后,sum[c[u]]会产生一个增量,用part来记录这个增量 all+=part; //all用来记录,对u dfs的过程中,sum[c[u]]产生的总增量 //这是在为 u返回pre时更新sum[c[u]]做准备 printf("from %d return to %d ",to,u); printf("all=%d,sumu_bd=%d,part=%d,sum[c[u]]=%d ",all,sumu_bd,part,sum[c[u]]); printf("to=%d,size[to]=%d ",to,size[to]); de+=C(size[to]-part,2); //size[to]-part的含义: 在to为根节点的子树中,有一部分结点与to位于同一“块”, // 这个块以 与u颜色相同的结点(不含),或是叶子(含,若叶子与u颜色相同则不含) 为边界 // size[to]-part表示的是这个块的大小 // 图二展示了 dfs(1,0)内,刚刚执行完dfs(2,1)后,更新de时的size[2]-part printf(" now,de=%d ",de); } printf(" === leaving... === "); printf("pre_sum[c[u]]=%d ",sum[c[u]]); sum[c[u]]+=size[u]-all; printf("after_sum[c[u]]=%d ",sum[c[u]]); printf(" leave from %d ================================== ",u); } int main() { int kase=0; while(~scanf("%d",&n)) { memset(vis,0,sizeof(vis)); memset(size,0,sizeof(size)); memset(sum,0,sizeof(sum)); for(int i=0;i<=n;i++) adj[i].clear(); int c_num=0; de=0; for(int i=1;i<=n;i++) { scanf("%d",&c[i]); if(!vis[c[i]]) c_num++; vis[c[i]]=1; } for(int i=1;i<n;i++) { int u,v; scanf("%d%d",&u,&v); adj[u].push_back(v); adj[v].push_back(u); } dfs(1,0); for(int i=1;i<=n;i++) if(i!=c[1]&&vis[i]) de+=C(n-sum[i],2); ans=C(n,2)*c_num-de; printf("Case #%d: %lld ",++kase,ans); } }
图一:
图二:
AC代码
#include<bits/stdc++.h> using namespace std; typedef long long LL; const int N=2e5+5; int n; int c[N]; int size[N]; int vis[N]; int sum[N]; vector<int> adj[N]; LL ans,de; LL C(LL a,LL b) { LL ret=1; for(int i=1;i<=b;i++) ret=ret*(a+1-i)/i; return ret; } void dfs(int u,int pre) { int all=0; size[u]=1; for(int to:adj[u]) { if(to==pre) continue; int sumu_bd=sum[c[u]]; dfs(to,u); size[u]+=size[to]; int part=sum[c[u]]-sumu_bd; all+=part; de+=C(size[to]-part,2); } sum[c[u]]+=size[u]-all; } int main() { int kase=0; while(~scanf("%d",&n)) { memset(vis,0,sizeof(vis)); memset(size,0,sizeof(size)); memset(sum,0,sizeof(sum)); for(int i=0;i<=n;i++) adj[i].clear(); int c_num=0; de=0; for(int i=1;i<=n;i++) { scanf("%d",&c[i]); if(!vis[c[i]]) c_num++; vis[c[i]]=1; } for(int i=1;i<n;i++) { int u,v; scanf("%d%d",&u,&v); adj[u].push_back(v); adj[v].push_back(u); } dfs(1,0); for(int i=1;i<=n;i++) if(i!=c[1]&&vis[i]) de+=C(n-sum[i],2); ans=C(n,2)*c_num-de; printf("Case #%d: %lld ",++kase,ans); } }