链接
http://acm.hdu.edu.cn/showproblem.php?pid=6832
题意
给定一张图,每个点不是0就是1,第i条边权值是(2^i),求以下式子的值
[sum_{i=1}^nsum_{j=1}^nd(i,j) imes[a_i=1wedge a_j=0]
]
([hspace{3mm}])是Iverson bracket
(wedge)是逻辑和符号
(d(i,j))表示i和j的最短路长度
思路
根据百度我们可以得到式子的意思是求所有1点和所有0点的最短路的和。
又因为边权是(2^i)这一特殊的形式,我们可以得到如下一个推论:
如果两点i和j间在(2^i)边前已经连通了,那么由(2^0+2^1+2^2+...+2^{i-1}<2^i)可以得到,此时的最短路不是两点直接连边,也就相当于所有的最短路都在MST上。
然后问题就来到了在MST上求答案,我们很容易联想到某一道题(我没找到那题)
枚举每条边,计算这条边对答案的贡献
假设这条边两边共连接了n个0点和m个1点,那么这条边对答案的贡献就是nm边权
具体如何维护这两个值有很多方法,比如并查集,比如bfs
代码
#include <bits/stdc++.h>
#define inf 0x3f3f3f3f
#define ms(a) memset(a, 0, sizeof(a))
#define repu(i, a, b) for (int i = a; i < b; i++)
#define repd(i, a, b) for (int i = a; i > b; i--)
using namespace std;
typedef long long ll;
typedef long double ld;
const int M = int(1e5)+ 10;
const int mod = int(1e9) + 7;
ll one;
ll zero;
ll ans;
int color[M];
ll dp[M][2];
int fa[M];
struct edge {
int to;
ll w;
int next;
};
edge edges[M * 2];
int head[M];
int cnt;
void add(int u, int v, ll w) {
edges[cnt].to = v;
edges[cnt].w = w;
edges[cnt].next = head[u];
head[u] = cnt++;
edges[cnt].to = u;
edges[cnt].w = w;
edges[cnt].next = head[v];
head[v] = cnt++;
}
void init() {
cnt = 0;
memset(head, -1, sizeof(head));
}
int find(int x) { return fa[x] == x ? x : fa[x] = find(fa[x]); }
void dfs(int u, int r) {
dp[u][0] = dp[u][1] = 0;
dp[u][color[u]]++;
for (int i = head[u]; ~i; i = edges[i].next) {
int v = edges[i].to;
if (r == v) continue;
dfs(v, u);
dp[u][0] += dp[v][0];
dp[u][1] += dp[v][1];
}
for (int i = head[u]; ~i; i = edges[i].next) {
int v = edges[i].to;
if (r == v) continue;
ans +=
(((dp[v][0] * (one - dp[v][1])) % mod) * edges[i].w);
ans %= mod;
ans +=
(((dp[v][1] * (zero - dp[v][0])) % mod) * edges[i].w);
ans %= mod;
}
}
int main() {
ios::sync_with_stdio(false);
cin.tie(0);
int t;
// cin >> t;
scanf("%d",&t);
while (t--) {
init();
ans = 0;
zero = one = 0;
int n, m;
// cin >> n >> m;
scanf("%d%d",&n,&m);
repu(i, 1, n + 1) {
dp[i][0] = dp[i][1] = 0;
fa[i] = i;
// cin >> color[i];
scanf("%d",color+i);
if (color[i] == 0)
zero++;
else
one++;
}
ll val = 1;
repu(i, 0, m) {
int u, v;
// cin >> u >> v;
scanf("%d%d",&u,&v);
val = val * 2 % mod;
int x = find(u);
int y = find(v);
if (x == y) continue;
fa[x] = y;
add(u, v, val);
}
dfs(1, -1);
// cout << ans << endl;
printf("%lld
",ans);
}
return 0;
}