题目链接:http://acm.hdu.edu.cn/showproblem.php?pid=5956
题意:一颗树上每条边有个权值,每个节点都有新闻要送到根节点就是1节点,运送过程中如果不换青蛙就是走过的所有边权之和的平方,如果换就每次更换要加上P,也就是求“每个节点到根节点这段路径切分成几块之后 [每块的权值和的平方加上(块个数-1)*P] 的最小值”。然后找到所有节点中消耗最大的那个是多少。
题解:设 dist[ i ] 表示节点 i 到根节点的距离,有 dp[ i ] = min(dp[ j ] + ( dist[ i ] - dist[ j ] ) ^ 2 + p),显然是斜率dp,需要注意的是这是在树上做斜率dp,当遍历了一个节点的某一棵子树后,遍历该节点的下一棵子树要恢复到之前的状态,可以考虑 dfs 过程中多传两个参数代表当前队列的 head 和 tail,而遍历完一棵子树后可能改变的是 tail,所以在比遍历之前把 tail 的节点记录下来即可。
1 #include <bits/stdc++.h> 2 using namespace std; 3 #define ll long long 4 #define ull unsigned long long 5 #define mst(a,b) memset((a),(b),sizeof(a)) 6 #define mp(a,b) make_pair(a,b) 7 #define pi acos(-1) 8 #define pii pair<int,int> 9 #define pb push_back 10 const int INF = 0x3f3f3f3f; 11 const double eps = 1e-6; 12 const int MAXN = 1e5 + 10; 13 const int MAXM = 1e3 + 10; 14 const ll mod = 100000073; 15 16 int n; 17 ll p,ans; 18 vector<pair<int,ll> >vec[MAXN]; 19 ll dist[MAXN],dp[MAXN]; 20 int q[MAXN]; 21 22 ll sqr(ll x) { 23 return x * x; 24 } 25 26 ll getup(int j,int k) { 27 return dp[j] + sqr(dist[j]) - (dp[k] + sqr(dist[k])); 28 } 29 30 ll getdown(int j,int k) { 31 return 2ll * (dist[j] - dist[k]); 32 } 33 34 void dfs(int u,int fa,int st,int en) { 35 dp[u] = dist[u] * dist[u]; 36 int head = st, tail = en; 37 while(head + 1 < tail && getup(q[head + 1],q[head]) <= dist[u] * getdown(q[head + 1],q[head])) head++; 38 dp[u] = min(dp[u], dp[q[head]] + sqr(dist[u] - dist[q[head]]) + p); 39 while(head + 1 < tail && getup(u,q[tail - 1]) * getdown(q[tail - 1],q[tail - 2]) <= getup(q[tail - 1],q[tail - 2]) * getdown(u,q[tail - 1])) 40 tail--; 41 q[tail++] = u; 42 int pre = u; 43 ans = max(ans, dp[u]); 44 for(int i = 0; i < vec[u].size(); i++) { 45 int v = vec[u][i].first; 46 ll w = vec[u][i].second; 47 if(v == fa) continue; 48 dist[v] = dist[u] + w; 49 dfs(v,u,head,tail); 50 } 51 q[tail - 1] = pre; 52 } 53 54 int main() { 55 #ifdef local 56 freopen("data.txt", "r", stdin); 57 #endif 58 int t; 59 scanf("%d",&t); 60 while(t--) { 61 scanf("%d%lld",&n,&p); 62 for(int i = 1; i <= n; i++) { 63 vec[i].clear(); 64 } 65 for(int i = 1; i < n; i++) { 66 int u,v; 67 ll w; 68 scanf("%d%lld%lld",&u,&v,&w); 69 vec[u].push_back(make_pair(v,w)); 70 vec[v].push_back(make_pair(u,w)); 71 } 72 dist[1] = 0; 73 int head = 0, tail = 0; 74 q[tail++] = 1; 75 ans = 0; 76 dfs(1,0,0,1); 77 printf("%lld ",ans); 78 } 79 return 0; 80 }