树上DP题。
其实有点类似于01的问题。方程很容易想到。首先,因为一条链的节点其实都是在树上的,所以很容易想到应该先求一个LCA。然后,当某节点不是链的LCA时,它的转移就是:
dp[i]=sum[i],其中,sum[i]是i的子节点的dp[i]的和。
如果它是某个点的LCA时,那么它的转移就是dp[i]=val[p]+(sum(sum[k])-sum(dp[k except i])),k是一条chain上的节点。
然而,这样做是不会过的,TLE。。。
树状数组维和chain上的节点的和。QAQ.....
怎么样维护呢?首先,使用DFS把树处理成DFS序,并维护时间戳,这是很显然的,记为l[u],r[u]。然后,假设对于某条链的LCA为u,计算出sum[u]和dp[u]后,对l[u]~r[u]进行区间更新。为什么要区间更新呢?因为,如果以后存在一条chain的起点是在u的子树内,那么,我们在计算链上节点和时,该链必定是经过u点的,这样,求前缀和就可以得到一个和是包含了u点的。区间更新,单点查询。单点查询就是完成了从起或终点到它们的LCA的链上节点求和,因为每次区间更新的时候,因为起点都在子树内,就可以对链上的节点的和不停地累加。
我弱菜没想到优化,TLE的代码。。
#include <iostream> #include <cstdio> #include <cstring> #include <algorithm> #include <vector> #pragma comment(linker, "/STACK:1024000000,1024000000") using namespace std; const int MAX=100010; struct Point{ int v,i; Point(int vv,int ii){ v=vv,i=ii; } }; vector<Point>chain[MAX]; vector<int>head_chain[MAX]; //vector<int>chain_point; int chain_from[MAX],chain_to[MAX],chain_val[MAX]; int pre[MAX]; struct Edge{ int u,v,next; }edge[MAX*2]; int head[MAX],tot,n,m; int sum[MAX],dp[MAX]; int parent[MAX],depth[MAX]; bool color[MAX]; void addedge(int u,int v){ edge[tot].u=u; edge[tot].v=v; edge[tot].next=head[u]; head[u]=tot++; } int get_point(int i){ // chain_point.clear(); int res=0; int u=chain_from[i],v=chain_to[i]; if(depth[u]<depth[v]) swap(u,v); while(depth[u]>depth[v]){ // chain_point.push_back(u); res+=(sum[u]-dp[u]); u=parent[u]; } while(u!=v){ // chain_point.push_back(u); // chain_point.push_back(v); res+=(sum[u]-dp[u]); res+=(sum[v]-dp[v]); u=parent[u]; v=parent[v]; } res+=sum[u]; // chain_point.push_back(u); return res; } void dfs(int u,int par){ sum[u]=0; for(int e=head[u];e!=-1;e=edge[e].next){ int v=edge[e].v; if(v!=par){ dfs(v,u); sum[u]+=dp[v]; } } dp[u]=sum[u]; if(head_chain[u].size()){ int sz=head_chain[u].size(); int tmp,index,tsz; for(int i=0;i<sz;i++){ index=head_chain[u][i]; tmp=get_point(index); // tmp=0; tsz=chain_point.size(); // for(int k=0;k<tsz;k++){ // tmp+=(sum[chain_point[k]]-dp[chain_point[k]]); // } dp[u]=max(dp[u],tmp+chain_val[index]); } } } int findx(int u){ int x=u; while(pre[u]!=-1){ u=pre[u]; } while(pre[x]!=-1){ int t=pre[x]; pre[x]=u; x=t; } return u; } void DFS(int u,int par,int dep){ parent[u]=par; depth[u]=dep; for(int e=head[u];e!=-1;e=edge[e].next){ int v=edge[e].v; if(v!=par){ DFS(v,u,dep+1); pre[findx(v)]=u; } } color[u]=true; int sz=chain[u].size(),v; for(int i=0;i<sz;i++){ v=chain[u][i].v; if(color[v]){ head_chain[findx(v)].push_back(chain[u][i].i); } } } /* void get_head(int i){ int u=chain_from[i],v=chain_to[i]; if(depth[u]<depth[v]) swap(u,v); while(depth[u]>depth[v]){ u=parent[u]; } while(u!=v){ u=parent[u],v=parent[v]; } head_chain[u].push_back(i); } */ int main(){ int T,u,v; scanf("%d",&T); while(T--){ scanf("%d%d",&n,&m); tot=0; memset(head,-1,sizeof(head)); memset(pre,-1,sizeof(pre)); memset(color,false,sizeof(color)); for(int i=1;i<n;i++){ scanf("%d%d",&u,&v); addedge(u,v); addedge(v,u); head_chain[i].clear(); chain[i].clear(); } chain[n].clear(); head_chain[n].clear(); for(int i=1;i<=m;i++){ scanf("%d%d%d",&chain_from[i],&chain_to[i],&chain_val[i]); chain[chain_from[i]].push_back(Point(chain_to[i],i)); chain[chain_to[i]].push_back(Point(chain_from[i],i)); // get_head(i); } DFS(1,1,1); dfs(1,0); printf("%d ",dp[1]); } return 0; }
http://blog.csdn.net/qq_24451605/article/details/47003497
简单的代码,其实写起来真的不难,主要是优化没想到。。。。
#pragma comment(linker, "/STACK:1024000000,1024000000") #include <iostream> #include <cstdio> #include <algorithm> #include <cstring> #include <vector> #define MAX 200007 using namespace std; int n,m,t; typedef long long LL; LL d[MAX]; LL sum[MAX]; LL c1[MAX<<1]; LL c2[MAX<<1]; int lowbit ( int x ) { return x&-x; } void add1 ( int x , LL v ) { while ( x <= n ) { c1[x] += v; x += lowbit ( x ); } } void add2 ( int x , LL v ) { while ( x <= n ) { c2[x] += v; x += lowbit ( x ); } } LL sum1 ( int x ) { LL res = 0; while ( x ) { res += c1[x]; x -= lowbit ( x ); } return res; } LL sum2 ( int x ) { LL res = 0; while ( x ) { res += c2[x]; x -= lowbit ( x ); } return res; } typedef pair<int,int> PII; vector<int> e[MAX]; vector<int> chain[MAX]; vector<PII> a[MAX]; vector<LL> w[MAX]; vector<LL> val[MAX]; int fa[MAX]; int times; bool used[MAX]; int l[MAX]; int r[MAX]; int _find ( int x ) { return fa[x] == x ? x: fa[x] = _find ( fa[x]); } void LCA ( int u ) { fa[u] = u; l[u] = ++times; used[u] = true; for ( int i = 0 ; i < e[u].size() ; i++ ) { int v = e[u][i]; if ( used[v] ) continue; LCA ( v ); fa[v] = u; } for ( int i = 0 ; i < chain[u].size() ; i++ ) { int v = chain[u][i]; if ( !used[v] ) continue; int x = _find ( v ); a[x].push_back ( make_pair ( u , v )); w[x].push_back ( val[u][i] ); } r[u] = ++times; } void dfs ( int u , int p ) { sum[u] = 0; d[u] = 0; for ( int i = 0 ; i < e[u].size() ; i++ ) { int v = e[u][i]; if ( v == p ) continue; dfs ( v , u ); sum[u] += d[v]; } for ( int i = 0 ; i < a[u].size() ; i++ ) { int x = a[u][i].first; int y = a[u][i].second; LL temp = sum1(l[x]) + sum1(l[y]) + sum[u] -sum2(l[x]) - sum2(l[y]); d[u] = max ( temp + w[u][i] , d[u] ); } d[u] = max ( d[u] , sum[u] ); add1 ( l[u] , sum[u] ); add1 ( r[u] , -sum[u] ); add2 ( l[u] , d[u] ); add2 ( r[u] , -d[u] ); } void init ( ) { times = 0; memset ( c1 , 0 , sizeof ( c1 ) ); memset ( c2 , 0 , sizeof ( c2 )); memset ( used , 0 , sizeof ( used )); for ( int i = 0 ; i < MAX ; i++ ) { e[i].clear(); val[i].clear(); a[i].clear(); w[i].clear(); chain[i].clear(); } } int main ( ) { int u,v,x; scanf ( "%d" , &t ); while ( t-- ) { init(); scanf ( "%d%d" , &n , &m ); int nn = n-1; while ( nn-- ) { scanf ( "%d%d" , &u , &v ); e[u].push_back ( v ); e[v].push_back ( u ); } n = n*2; while ( m-- ) { scanf ( "%d%d%d" , &u , &v , &x ); chain[u].push_back ( v ); chain[v].push_back ( u ); val[u].push_back ( x ); val[v].push_back ( x ); } //cout <<"YES" << endl; LCA ( 1 ); dfs ( 1, -1 ); printf ( "%I64d " , d[1] ); } }