题目链接
题解
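树形 dp 一下：设 dp[x][i][j] 为 x 的子树内割边的最小总代价，其中 i 表示 x 所在连通块是否含有颜色 0 的点，j 表示该连通块中颜色 1 的点数（不少于 2 时按 2 计）；合并儿子 v 时，要么保留这条边并合并两边的状态，要么在 v 所在连通块合法（i = 0 或 j ≤ 1）时割掉这条边并加上边权。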
代码
#include<cstdio>
#include<cstring>
#include<algorithm>
/// Reads the next non-negative run of decimal digits from stdin and returns it
/// as a (32-bit) int, negated if a '-' was seen among the skipped characters.
///
/// Every non-digit character before the number is discarded; the character
/// that terminates the digit run is consumed as well. If the input ends before
/// any digit is found the function returns 0 instead of looping forever.
/// No overflow checking is performed.
inline int read() {
    int value = 0, sign = 1;
    // Keep getchar()'s result in an int so EOF stays distinguishable from data.
    int c = getchar();
    // Skip separators; any '-' met on the way marks the number as negative.
    while (c != EOF && (c < '0' || c > '9')) {
        if (c == '-') sign = -1;
        c = getchar();
    }
    // Accumulate the digit run (EOF is not a digit, so this stops at end of input).
    while (c >= '0' && c <= '9') {
        value = value * 10 + (c - '0');
        c = getchar();
    }
    return value * sign;
}
// From this line on every `int` token expands to `long long`, so all
// declarations below are 64-bit. read() is defined above the macro and
// therefore still works with a plain int.
#define int long long
// Capacity of the per-node arrays (a, head, dp); main and dfs index nodes from 1.
const int maxn = 300007;
// Number of nodes of the current test case; assigned in main.
int n;
// One directed edge of the array-based linked adjacency list.
struct node {
// v: node the edge points to; next: index of the next edge leaving the same
// source node (0 ends the list); w: weight of the edge.
int v,next,w;
} edge[maxn << 1]; // main stores every input edge in both directions; index 0 is never used
// head[u]: index in edge[] of the most recently added edge leaving u (0 = no edge).
// num: index of the last used slot of edge[]; reset to 0 for each test case in main.
int head[maxn],num = 0;
inline void add_edge(int u,int v,int w ) {
edge[++ num].v = v;edge[num].next = head[u];head[u] = num; edge[num].w = w;
}
// a[i]: colour value of node i as read in main; dfs only tests it against 0 and 1.
int a[maxn];
// Sentinel cost (1e17) that dfs and main treat as "state not reachable".
#define INF 100000000000000000ll
// A component is acceptable when it has no black node or only one white node.
// NOTE(review): judging by the checks in dfs/main, "black" is colour 0 and
// "white" is colour 1, and "only one" is tested as "at most one" - confirm
// against the problem statement.
// dp[x][i][j]: cheapest total weight of cut edges inside x's subtree such that
// the component containing x has i = 1 iff it holds a colour-0 node and
// j = its number of colour-1 nodes, capped at 2.
int dp[maxn][2][3];// per-colour node counts within x's subtree
int tmp[2][3];// note: because an edge may be cut, merging a child needs this scratch array
/// Tree DP over the subtree rooted at x; fa is x's parent (0 for the root).
///
/// On return dp[x][i][j] holds the minimum total weight of cut edges inside
/// x's subtree such that every component cut away is acceptable and the
/// component that still contains x has
///   i = 1 iff it contains a node with a[] == 0,
///   j = its number of nodes with a[] == 1, saturated at 2.
/// A component is acceptable when i == 0 or j <= 1. States that cannot be
/// reached hold INF.
///
/// NOTE: the recursion depth equals the height of the tree, so a path-shaped
/// tree with n close to maxn needs a correspondingly large stack.
void dfs(int x,int fa = 0) {
    // Begin with "x alone". All other states are set to INF explicitly so the
    // sentinel checks below do not rely on whatever fill pattern the caller used.
    for(int i = 0;i <= 1;++ i) for(int j = 0;j <= 2;++ j) dp[x][i][j] = INF;
    dp[x][a[x] == 0][a[x] == 1] = 0;

    for(int e = head[x];e;e = edge[e].next) {
        const int v = edge[e].v;
        if(v == fa) continue;
        dfs(v,x);

        // tmp is shared with the recursive call above, so reset it only now.
        for(int i = 0;i <= 1;++ i) for(int j = 0;j <= 2;++ j) tmp[i][j] = INF;

        for(int i = 0;i <= 1;++ i) for(int j = 0;j <= 2;++ j) {
            if(dp[x][i][j] >= INF) continue;
            for(int k = 0;k <= 1;++ k) for(int l = 0;l <= 2;++ l) {
                if(dp[v][k][l] >= INF) continue;
                const int cost = dp[x][i][j] + dp[v][k][l];

                // Keep the edge: the two components merge, counters saturate.
                const int t1 = (i + k >= 1) ? 1 : 0;
                const int t2 = (j + l >= 2) ? 2 : j + l;
                tmp[t1][t2] = std::min(tmp[t1][t2],cost);

                // Cut the edge: only allowed when the child's component is
                // acceptable on its own; x's state is unchanged, pay the weight.
                if(!k || l <= 1) tmp[i][j] = std::min(tmp[i][j],cost + edge[e].w);
            }
        }
        // tmp has the same shape as dp[x], so this replaces x's whole table.
        std::memcpy(dp[x],tmp,sizeof tmp);
    }
}
/// Entry point. Reads T test cases; each consists of n, the n node colours
/// a[1..n], and n - 1 weighted undirected edges "u v w". For every test case
/// it runs the tree DP from node 1 and prints the minimum cost over the
/// acceptable root states (i == 0 or j < 2) on its own line.
///
/// Declared `signed` because the file-wide `int` macro would otherwise turn
/// the return type into `long long`, and a missing return type is not valid C++.
signed main() {
    int T = read();
    while(T --) {
        n = read();

        // Reset only the slots this test case can touch (node indices 0..n)
        // instead of wiping the whole arrays for every test case, and use the
        // same INF sentinel that dfs compares against.
        for(int i = 0;i <= n;++ i) {
            head[i] = 0;
            for(int c = 0;c < 2;++ c)
                for(int k = 0;k < 3;++ k) dp[i][c][k] = INF;
        }
        num = 0;

        for(int i = 1;i <= n;++ i) a[i] = read();
        for(int i = 1;i < n;++ i) {
            int u = read();
            int v = read();
            int w = read();
            // Store the undirected edge as two directed ones.
            add_edge(u,v,w);
            add_edge(v,u,w);
        }

        dfs(1);

        // The root's own component must be acceptable as well.
        int ans = INF;
        for(int i = 0;i < 2;++ i)
            for(int j = 0;j < 3;++ j)
                if(!i || j < 2) ans = std::min(ans,dp[1][i][j]);
        printf("%lld\n",ans);
    }
    return 0;
}