题目链接
题目思路
看到单点修改,区间查询其实很容易想到线段树
但是这个线段树有点特殊
线段树的每个节点代表这个区间的最小生成树的所有边,用vector存
那么每个节点最多n-1条边
你每次合并使用类似归并排序就是\(O(n)\)
如果暴力合并就是\(O(nlogn)\)
那么使用归并排序总的时间复杂度就是\(O(qnlog(m))\)
代码
#include<bits/stdc++.h>
#define fi first
#define se second
#define debug cout<<"I AM HERE"<<endl;
using namespace std;
typedef long long ll;
const int maxn=3e4+5,inf=0x3f3f3f3f,mod=1e9+7;
const int eps=1e-6;
int n,m,q;
int fa[maxn];
struct edge{
int u,v,w;
}e[maxn];
vector<edge> tree[maxn<<2];
bool cmp(edge a,edge b){
return a.w<b.w;
}
int findd(int x){
return x==fa[x]?x:fa[x]=findd(fa[x]);
}
vector<edge> mer(vector<edge> a,vector<edge> b){
vector<edge> ans;
for(int i=1;i<=n;i++){
fa[i]=i;
}
a.push_back({-1,-1,inf});
b.push_back({-1,-1,inf});
int id1=0,id2=0;
while(id1<((int)a.size()-1)||id2<((int)b.size()-1)){
int x,y;
if(a[id1].w<=b[id2].w){
x=findd(a[id1].u);
y=findd(a[id1].v);
id1++;
if(x==y) continue;
fa[x]=y;
ans.push_back(a[id1-1]);
}else{
x=findd(b[id2].u);
y=findd(b[id2].v);
id2++;
if(x==y) continue;
fa[x]=y;
ans.push_back(b[id2-1]);
}
}
return ans;
}
void build(int node,int l,int r){
if(l==r){
tree[node].push_back(e[l]);
return ;
}
int mid=(l+r)/2;
build(node<<1,l,mid);
build(node<<1|1,mid+1,r);
tree[node]=mer(tree[node<<1],tree[node<<1|1]);
}
void update(int node,int l,int r,int pos){
if(l==r){
tree[node][0]=e[pos];
return ;
}
int mid=(l+r)>>1;
if(mid>=pos) update(node<<1,l,mid,pos);
else update(node<<1|1,mid+1,r,pos);
tree[node]=mer(tree[node<<1],tree[node<<1|1]);
}
vector<edge> query(int node,int l,int r,int L,int R){
if(L<=l&&r<=R){
return tree[node];
}
int mid=(l+r)/2;
vector<edge> ans;
if(mid>=L) ans=mer(ans,query(node<<1,l,mid,L,R));
if(mid<R) ans=mer(ans,query(node<<1|1,mid+1,r,L,R));
return ans;
}
signed main(){
scanf("%d%d%d",&n,&m,&q);
for(int i=1;i<=m;i++){
scanf("%d%d%d",&e[i].u,&e[i].v,&e[i].w);
}
build(1,1,m);
for(int i=1,x,y,z,t,l,r,opt;i<=q;i++){
scanf("%d",&opt);
if(opt==1){
scanf("%d%d%d%d",&x,&y,&z,&t);
e[x]={y,z,t};
update(1,1,m,x);
}else{
scanf("%d%d",&l,&r);
vector<edge> ans=query(1,1,m,l,r);
int sum=0;
for(int i=0;i<ans.size();i++){
sum+=ans[i].w;
}
if(ans.size()!=n-1){
printf("Impossible\n");
}else{
printf("%d\n",sum);
}
}
}
return 0;
}