http://codeforces.com/gym/101341/problem/K
题意:给出n个区间,每个区间有一个l, r, w,代表区间左端点右端点和区间的权值,现在可以选取一些区间,要求选择的区间不相交,问最大的权和可以是多少,如果权和相同,则选区间长度最短的。要要求输出区间个数和选了哪些区间。
思路:把区间按照右端点排序后,就可以维护从左往右,到p[i].r这个点的时候,已经选择的最大权和是多少,最小长度是多少,区间个数是多少。
因为可以二分找右端点小于等于当前区间的左端点的某个区间(index),然后就有
dp[i] = max(dp[index] + w[index], dp[i]).
这题的我觉得困难的是打印使用了哪些区间,一开始想的方法超时了。
因为用pre数组记录路径,但是路径只是代表当前的解从哪里更新过来的,而不能记录是由哪一个点推出新的最优解。
一开始我用了一个bool型的vis数组判断是否在第i个点更新的答案,然后每次往前找,这样最坏会达到O(n^2)。
但是这样肯定扫了很多重复的情况,因此还可以用递推来构造vis数组,代表当前这个解最后使用了哪一个区间推出来。pre数组也使用递推的形式。
每次如果更新的了,就vis[i] = i,pre[i] = index,否则就vis[i] = vis[i-1], pre[i] = pre[i-1]。
1 #include <bits/stdc++.h> 2 using namespace std; 3 typedef long long LL; 4 #define N 200010 5 #define INF 0x3f3f3f3f 6 struct node { 7 int l, r, w, id; 8 bool operator < (const node &rhs) const { 9 if(r != rhs.r) return r < rhs.r; 10 return l < rhs.l; 11 } 12 } p[N]; 13 LL dp[N], len[N]; 14 int tol[N], pre[N], vis[N]; 15 vector<int> ans; 16 // vis表示新的DP答案由哪个推出来的 17 // pre表示当前这个点可以由哪一个点跳过来 18 // dp[i]表示到第i个右端点的时候最大权和 19 // len[i]表示到第i个右端点的时候最小长度 20 // tol[i]表示区间个数 21 22 int main() { 23 int n; scanf("%d", &n); 24 for(int i = 1; i <= n; i++) scanf("%d%d%d", &p[i].l, &p[i].r, &p[i].w), p[i].id = i; 25 sort(p + 1, p + 1 + n); 26 for(int i = 1; i <= n; i++) { 27 dp[i] = dp[i-1], len[i] = len[i-1], tol[i] = tol[i-1], pre[i] = pre[i-1], vis[i] = vis[i-1]; 28 node now = (node) { INF, p[i].l, 0, 0 }; 29 int index = upper_bound(p + 1, p + n + 1, now) - p - 1; 30 // while(!vis[index] && index) index--; 31 if(dp[index] + p[i].w > dp[i]) { 32 dp[i] = dp[index] + p[i].w; 33 len[i] = len[index] + p[i].r - p[i].l; 34 tol[i] = tol[index] + 1; 35 pre[i] = index; vis[i] = i; 36 } else if(dp[index] + p[i].w == dp[i] && len[index] + p[i].r - p[i].l < len[i]) { 37 len[i] = len[index] + p[i].r - p[i].l; 38 tol[i] = tol[index] + 1; 39 pre[i] = index; vis[i] = i; 40 } 41 // printf("index : %d %d %d %d %lld ", p[i].id, p[vis[i]].id, p[index].id, p[pre[i]].id, dp[i]); 42 } 43 int ed = n; 44 while(ed) { ans.push_back(p[vis[ed]].id); ed = pre[ed]; } 45 sort(ans.begin(), ans.end()); 46 printf("%d %lld %lld ", tol[n], dp[n], len[n]); 47 for(int i = 0; i < ans.size(); i++) printf("%d%c", ans[i], i + 1 == ans.size() ? ' ' : ' '); 48 return 0; 49 }