有一棵点数为 N 的树,树边有边权。给你一个在 0~ N 之内的正整数 K ,你要在这棵树中选择 K个点,将其染成黑色,并将其他 的N-K个点染成白色 。 将所有点染色后,你会获得黑点两两之间的距离加上白点两两之间的距离的和的受益。问受益最大值是多少。
Solution
比较经典的树形背包问题。
如果只对点进行分析,情况会变得十分麻烦,不放考虑每条变的贡献,每条边会产生两边黑点数的乘积加上两边白点数的乘积。
这样的话我们直接跑背包就可以了,标准的树形背包是n^3的,但是这道题每颗字数背包体积有上限,总复杂度可以做到n^2.
Code
#include<iostream> #include<cstdio> #include<cstring> #define N 2009 using namespace std; long long dp[N][N]; int size[N],m,n,a,b,c,tot,head[N]; struct dsd { int n,to,l; }an[N<<1]; inline void add(int u,int v,int l) { an[++tot].n=head[u]; an[tot].to=v; head[u]=tot; an[tot].l=l; } void dfs(int u,int fa) { size[u]=1; for(int i=head[u];i;i=an[i].n) if(an[i].to!=fa) { int v=an[i].to; dfs(v,u); size[u]+=size[v]; for(int j=min(m,size[u]);j>=0;--j)// for(int k=0;k<=min(j,size[v]);++k) if(dp[v][k]!=-0x3f3f3f3f) { long long num=(long long)(k*(m-k)+(n-size[v]-(m-k))*(size[v]-k))*an[i].l; dp[u][j]=max(dp[u][j],dp[u][j-k]+dp[v][k]+num); } } } int main() { scanf("%d%d",&n,&m); for(int i=1;i<n;++i) { scanf("%d%d%d",&a,&b,&c); add(a,b,c);add(b,a,c); } memset(dp,-0x3f,sizeof(dp)); for(int i=1;i<=n;++i) dp[i][0]=dp[i][1]=0;// dfs(1,0); cout<<dp[1][m]; return 0; }
这种写法太慢了,并没有做到严格n^2,bzoj会TLE,下面这种写法是稳过的。
Code
#include<iostream> #include<cstdio> #include<cstring> #define N 2009 using namespace std; typedef long long ll; ll dp[N][N],size[N],m,n,a,b,c,tot,head[N],g[N]; struct dsd { ll n,to,l; }an[N<<1]; inline void add(ll u,ll v,ll l) { an[++tot].n=head[u]; an[tot].to=v; head[u]=tot; an[tot].l=l; } ll mi(ll x,ll y){return x<y?x:y;} ll ma(ll x,ll y){return x<y?y:x;} void dfs(ll u,ll fa){ size[u]=1; for(ll i=head[u];i;i=an[i].n) if(an[i].to!=fa){ ll v=an[i].to; dfs(v,u); ll x=mi(m,size[u]),y=mi(m,size[v]); for(int j=0;j<=m;++j)g[j]=0; for(ll j=x;j>=0;--j) for(int k=0;k<=y;++k)if(j+k<=m){ ll gyx=((ll)k*(m-k)+(n-size[v]-(m-k))*(size[v]-k))*an[i].l; g[j+k]=ma(g[j+k],dp[u][j]+dp[v][k]+gyx); } for(int j=0;j<=m;++j)dp[u][j]=g[j]; size[u]+=size[v]; } } inline int rd(){ int x=0;char c=getchar(); while(!isdigit(c))c=getchar(); while(isdigit(c)){ x=(x<<1)+(x<<3)+(c^48); c=getchar(); } return x; } int main() { n=rd();m=rd(); for(int i=1;i<n;++i){ a=rd();b=rd();c=rd(); add(a,b,c);add(b,a,c); } memset(dp,-0x3f,sizeof(dp)); for(int i=1;i<=n;++i) dp[i][0]=dp[i][1]=0; dfs(1,0); printf("%lld",dp[1][m]); return 0; }