今天午睡的时候忽然想到这样一类问题:如何求一个序列(可正可负可零)的前k大(小)个子区间的和?随即想到了一个$O((n+k)(logn+logk))$的算法,思路并不难
看到区间和,首先想到的就是构造出前缀和,一般这么搞准没错~
然后看到前k小问题,一般能想到的做法是构造出一个边权非负的有向图(显式或隐式均可),从源点出发的每一条路径(也可能是每一个结点)对应一个解,转换成k短路问题,用优先队列求解
于是就自然而然地想到以每个左端点j作为起点,对其所有的右端点按权值构造构造一个小根堆,以点权之差作为边权,然后建个超级源点作为起点,向每个小根堆连一条权值为对应左端点点权的边(边权可能为负,但是是从起点连出去的所以没关系),假设序列长度为n,那么问题的答案就是这n个小根堆构成的图的前k短路
关键就在于如何快速构造出这个n个小根堆。注意到每个左端点j和j+1之间只差了一个结点,因此可以用类似主席树的方法构造出可持久化的堆,可以用单次合并复杂度为$O(logn)$的左偏树实现
1 #include<bits/stdc++.h> 2 #define l(u) ch[u][0] 3 #define r(u) ch[u][1] 4 using namespace std; 5 const int N=1e5+10,M=N*40; 6 int val[M],ds[M],ch[M][2],rt[N],tot,a[N],s[N],n,k; 7 int newnode(int x) {int u=++tot; ds[u]=l(u)=r(u)=0,val[u]=x; return u;} 8 int cpy(int u) {int w=++tot; ds[w]=ds[u],l(w)=l(u),r(w)=r(u),val[w]=val[u]; return w;} 9 void mg(int& w,int u,int v) { 10 if(!u||!v) {w=u|v; return;} 11 if(val[v]<val[u])swap(u,v); 12 w=cpy(u); 13 mg(r(w),r(u),v); 14 if(ds[r(w)]>ds[l(w)])swap(l(w),r(w)); 15 ds[w]=ds[r(w)]+1; 16 } 17 struct D { 18 int u,g; 19 bool operator<(const D& b)const {return g>b.g;} 20 }; 21 priority_queue<D> q; 22 void solve(int k) { 23 int flag=0; 24 while(q.size())q.pop(); 25 for(int i=0; i<n; ++i)q.push({rt[i+1],val[rt[i+1]]-s[i]}); 26 int cnt=0; 27 while(q.size()) { 28 int u=q.top().u,g=q.top().g; 29 q.pop(); 30 if(flag++)printf(" "); 31 printf("%d",g); 32 if(++cnt==k)break; 33 if(l(u))q.push({l(u),g-val[u]+val[l(u)]}); 34 if(r(u))q.push({r(u),g-val[u]+val[r(u)]}); 35 } 36 puts(""); 37 } 38 int main() { 39 scanf("%d%d",&n,&k); 40 for(int i=1; i<=n; ++i)scanf("%d",&a[i]); 41 for(int i=1; i<=n; ++i)s[i]=s[i-1]+a[i]; 42 for(int i=n; i>=1; --i)mg(rt[i],rt[i+1],newnode(s[i])); 43 solve(k); 44 return 0; 45 }
sample input:
10 10 1 3 -2 4 7 -5 -4 3 -3 5
sample output:
-9 -9 -6 -5 -4 -4 -4 -3 -2 -2
如果是子序列(不要求连续,相当于子集)的话更简单,直接对所有元素从小到大排序,每个小的结点向比它大的结点所有连一条权值为大的结点的点权的边,然后建一个超级源点与每个结点连一条权值为该结点的点权的边,则从源点出发的每一条路径对应一个可行解,利用多叉树转二叉树的思路隐式遍历即可,如果有负权就先把所有的负权加起来记为sum,然后把负权变成正权,最后的结果加上sum即可
1 #include<bits/stdc++.h> 2 using namespace std; 3 const int N=1e5+10; 4 int a[N],n,k; 5 struct D { 6 int u,g; 7 bool operator<(const D& b)const {return g>b.g;} 8 }; 9 priority_queue<D> q; 10 void solve(int k) { 11 int flag=0; 12 while(q.size())q.pop(); 13 int sum=0; 14 for(int i=1; i<=n; ++i)if(a[i]<0)sum+=a[i],a[i]=-a[i]; 15 sort(a+1,a+1+n); 16 q.push({0,sum}); 17 int cnt=0; 18 while(q.size()) { 19 int u=q.top().u,g=q.top().g; 20 q.pop(); 21 if(flag++)printf(" "); 22 printf("%d",g); 23 if(++cnt==k)break; 24 if(u!=0&&u<n)q.push({u+1,g-a[u]+a[u+1]}); 25 if(u<n)q.push({u+1,g+a[u+1]}); 26 } 27 puts(""); 28 } 29 int main() { 30 scanf("%d%d",&n,&k); 31 for(int i=1; i<=n; ++i)scanf("%d",&a[i]); 32 solve(k); 33 return 0; 34 }
sample input:
10 10 1 3 -2 4 7 -5 -4 3 -3 5
sample output:
-14 -13 -12 -11 -11 -11 -11 -10 -10 -10
如果是树上路径,设结点u的点权为a[u],到根结点的路径上的点权和为s[u],那么要求的实际就是s[u]+s[v]-2*s[LCA(u,v)]+a[LCA(u,v)]的前k小。因此可以枚举LCA,用类似树形dp的方法构造可持久化左偏树,一边合并一边向全局优先队列里加入结点,与序列不同的是全局优先队列里的一个结点对应左偏树上的一对结点,即(u,v),每次扩张要扩4个结点,即(l(u),v),(r(u),v),(u,l(v)),(u,r(v)),还要注意用hash来防止重复扩张
1 #include<bits/stdc++.h> 2 #define l(u) ch[u][0] 3 #define r(u) ch[u][1] 4 using namespace std; 5 typedef long long ll; 6 const int N=1e5+10,M=N*40; 7 int hd[N],ne,a[N],s[N],n,k; 8 struct E {int v,nxt;} e[N<<1]; 9 void link(int u,int v) {e[ne]= {v,hd[u]},hd[u]=ne++;} 10 int val[M],ds[M],ch[M][2],rt[N],tot; 11 int newnode(int x) {int u=++tot; ds[u]=l(u)=r(u)=0,val[u]=x; return u;} 12 int cpy(int u) {int w=++tot; ds[w]=ds[u],l(w)=l(u),r(w)=r(u),val[w]=val[u]; return w;} 13 void mg(int& w,int u,int v) { 14 if(!u||!v) {w=u|v; return;} 15 if(val[v]<val[u])swap(u,v); 16 w=cpy(u); 17 mg(r(w),r(u),v); 18 if(ds[r(w)]>ds[l(w)])swap(l(w),r(w)); 19 ds[w]=ds[r(w)]+1; 20 } 21 struct D { 22 int u,v,g; 23 bool operator<(const D& b)const {return g>b.g;} 24 }; 25 priority_queue<D> q; 26 void dfs(int u,int fa,int sum) { 27 s[u]=sum,rt[u]=newnode(s[u]); 28 q.push({rt[u],0,a[u]}); 29 for(int i=hd[u]; ~i; i=e[i].nxt) { 30 int v=e[i].v; 31 if(v==fa)continue; 32 dfs(v,u,sum+a[v]); 33 q.push({rt[u],rt[v],val[rt[u]]+val[rt[v]]-2*s[u]+a[u]}); 34 mg(rt[u],rt[u],rt[v]); 35 } 36 } 37 unordered_set<ll> vis; 38 void solve(int k) { 39 int flag=0; 40 while(q.size())q.pop(); 41 vis.clear(); 42 dfs(1,0,a[1]); 43 int cnt=0; 44 while(q.size()) { 45 int u=q.top().u,v=q.top().v,g=q.top().g; 46 q.pop(); 47 if(flag++)printf(" "); 48 printf("%d",g); 49 if(++cnt==k)break; 50 if(l(u)&&!vis.count((ll)l(u)*M+v))q.push({l(u),v,g-val[u]+val[l(u)]}),vis.insert((ll)l(u)*M+v); 51 if(r(u)&&!vis.count((ll)r(u)*M+v))q.push({r(u),v,g-val[u]+val[r(u)]}),vis.insert((ll)r(u)*M+v); 52 if(l(v)&&!vis.count((ll)u*M+l(v)))q.push({u,l(v),g-val[v]+val[l(v)]}),vis.insert((ll)u*M+l(v)); 53 if(r(v)&&!vis.count((ll)u*M+r(v)))q.push({u,r(v),g-val[v]+val[r(v)]}),vis.insert((ll)u*M+r(v)); 54 } 55 puts(""); 56 } 57 int main() { 58 memset(hd,-1,sizeof hd),ne=0; 59 scanf("%d%d",&n,&k); 60 for(int i=1; i<=n; ++i)scanf("%d",&a[i]); 61 for(int i=1; i<n; ++i) { 62 int u,v; 63 scanf("%d%d",&u,&v); 64 link(u,v); 65 link(v,u); 66 } 67 solve(k); 68 return 0; 69 }
sample input:
9 40 2 3 -4 5 1 2 -3 3 -2 1 4 1 3 1 2 3 7 3 6 3 5 4 9 4 8
sample output:
-7 -6 -5 -5 -4 -3 -3 -2 -2 -2 -2 -2 -1 -1 0 0 1 1 1 2 2 2 2 3 3 3 3 3 3 3 4 5 5 5 5 6 6 7 7 8