题面
https://www.luogu.org/problem/P5024
题解
写这道题的时候写了$3$个暴力($n,m le 2000,A,B$),结果都挂了。。。。。
考场暴力:
#include<cstdio> #include<cstring> #include<iostream> #include<vector> #define ri register int #define N 100050 #define LL long long #define INF 1000000000000LL using namespace std; inline int read() { int ret=0,f=0; char ch=getchar(); while (ch<'0' || ch>'9') f|=(ch=='-'),ch=getchar(); while (ch>='0' && ch<='9') ret*=10,ret+=ch-'0',ch=getchar(); return f?-ret:ret; } vector<int> to[N]; int n,m,fa[N],bo[N],dep[N],p[N]; char s[5]; LL f[N][2],g[N][2]; void maketree(int x,int ff) { fa[x]=ff; dep[x]=dep[ff]+1; for (ri i=0;i<to[x].size();i++) { int y=to[x][i]; if (y==ff) continue; maketree(y,x); } } LL min(LL a,LL b) { if (a<b) return a; else return b; } void redp(int x) { if (bo[x]==2) f[x][1]=p[x],f[x][0]=INF; else if (bo[x]==1) f[x][0]=0,f[x][1]=INF; else f[x][0]=0,f[x][1]=p[x]; for (ri i=0;i<to[x].size();i++) { int y=to[x][i]; if (y==fa[x]) continue; redp(y); f[x][0]+=f[y][1]; f[x][1]+=min(f[y][1],f[y][0]); } } void specialsolve1() { for (ri i=1;i<=m;i++) { int a=read(),x=read(),b=read(),y=read(); bo[a]=x+1; bo[b]=y+1; redp(1); if (min(f[1][0],f[1][1])>INF/2) puts("-1"); else printf("%lld ",min(f[1][0],f[1][1])); bo[a]=0; bo[b]=0; } } void specialsolve2() { for (ri i=1;i<=m;i++) { int a=read(),x=read(),b=read(),y=read(); for (ri j=a;j;j=fa[j]) g[j][0]=f[j][0],g[j][1]=f[j][1]; for (ri j=b;j;j=fa[j]) g[j][0]=f[j][0],g[j][1]=f[j][1]; if (dep[a]<dep[b]) swap(a,b),swap(x,y); g[a][x^1]=INF; g[b][y^1]=INF; while (dep[a]>dep[b]) { g[fa[a]][0]-=f[a][1]; g[fa[a]][1]-=min(f[a][0],f[a][1]); g[fa[a]][0]+=g[a][1]; g[fa[a]][1]+=min(g[a][0],g[a][1]); a=fa[a]; } if (a==b) { while (dep[a]>1) { g[fa[a]][0]-=f[a][1]; g[fa[a]][1]-=min(f[a][0],f[a][1]); g[fa[a]][0]+=g[a][1]; g[fa[a]][1]+=min(g[a][0],g[a][1]); a=fa[a]; } } else { while (a^b) { g[fa[b]][0]-=f[b][1]; g[fa[b]][1]-=min(f[b][0],f[b][1]); g[fa[b]][0]+=g[b][1]; g[fa[b]][1]+=min(g[b][0],g[b][1]); b=fa[b]; g[fa[a]][0]-=f[a][1]; g[fa[a]][1]-=min(f[a][0],f[a][1]); g[fa[a]][0]+=g[a][1]; g[fa[a]][1]+=min(g[a][0],g[a][1]); a=fa[a]; } while (dep[a]>1) { g[fa[a]][0]-=f[a][1]; g[fa[a]][1]-=min(f[a][0],f[a][1]); g[fa[a]][0]+=g[a][1]; g[fa[a]][1]+=min(g[a][0],g[a][1]); a=fa[a]; } } if (min(g[1][0],g[1][1])>INF/2) puts("-1"); else printf("%lld ",min(g[1][0],g[1][1])); } } struct matrix { LL v[2][2]; void init() { memset(v,0x3f,sizeof(v)); } matrix operator * (const matrix &rhs) const { matrix ret; ret.init(); for (ri k=0;k<2;k++) for (ri i=0;i<2;i++) for (ri j=0;j<2;j++) ret.v[i][j]=min(ret.v[i][j],v[i][k]+rhs.v[k][j]); return ret; } } t[N<<2]; #define ls (x<<1) #define rs (x<<1|1) void maketree(int x,int lb,int rb) { if (lb==rb) { t[x].v[0][0]=INF; t[x].v[0][1]=0; t[x].v[1][0]=p[lb]; t[x].v[1][1]=p[lb]; return; } int mid=(lb+rb)>>1; maketree(ls,lb,mid); maketree(rs,mid+1,rb); t[x]=t[rs]*t[ls]; } void modify(int x,int lb,int rb,int loc,LL v) { if (lb==rb) { t[x].v[0][0]=INF; t[x].v[0][1]=0; t[x].v[1][0]=t[x].v[1][1]=v; return; } int mid=(lb+rb)>>1; if (loc<=mid) modify(ls,lb,mid,loc,v); else modify(rs,mid+1,rb,loc,v); t[x]=t[rs]*t[ls]; } void specialsolve3() { maketree(1,1,n); for (ri i=1;i<=m;i++) { int a=read(),x=read(),b=read(),y=read(); int cnt=0; if (x==0) { modify(1,1,n,a,INF); } else { modify(1,1,n,a,p[a]-INF); cnt++; } if (y==0) { modify(1,1,n,b,INF); } else { modify(1,1,n,b,p[b]-INF); cnt++; } LL ans=min(min(t[1].v[0][1],t[1].v[1][1]),min(t[1].v[0][0],t[1].v[1][0])); if (ans+cnt*INF<INF/2) printf("%lld ",ans+cnt*INF); else puts("-1"); modify(1,1,n,a,p[a]); modify(1,1,n,b,p[b]); } } int main() { scanf("%d %d %s",&n,&m,s); for (ri i=1;i<=n;i++) scanf("%d",&p[i]); for (ri i=1;i<n;i++) { int u,v; scanf("%d %d",&u,&v); to[u].push_back(v); to[v].push_back(u); } maketree(1,0); redp(1); if (n<=7000 && m<=7000) { specialsolve1(); return 0; } else if (s[0]=='B') { specialsolve2(); return 0; } else if (s[0]=='A') { specialsolve3(); return 0; } }