题目:https://www.luogu.org/problemnew/show/P5024
考场上只会写n,m<=2000的暴力,还想了想A1和A2的情况,不过好像只得了A1的分。然后仔细一看,原来是把dp2[ ][ ]写成dp[ ][ ]了。改一下,就能得到A1和A2的分。
#include<iostream> #include<cstdio> #include<cstring> #include<algorithm> #define ll long long using namespace std; const int N=1e5+5; const ll INF=1e10+5; int n,m,p[N],hd[N],xnt,to[N<<1],nxt[N<<1]; int q0,q1,f0,f1; ll dp[N][3],dp2[N][3]; char ch[5]; int rdn() { int ret=0;bool fx=1;char ch=getchar(); while(ch>'9'||ch<'0'){if(ch=='-')fx=0;ch=getchar();} while(ch>='0'&&ch<='9') ret=(ret<<3)+(ret<<1)+ch-'0',ch=getchar(); return fx?ret:-ret; } void add(int x,int y) { to[++xnt]=y; nxt[xnt]=hd[x]; hd[x]=xnt; } ll Mn(ll a,ll b){return a<b?a:b;} ll Mx(ll a,ll b){return a>b?a:b;} void dfs(int cr,int fa) { dp[cr][0]=0; dp[cr][1]=p[cr]; for(int i=hd[cr],v;i;i=nxt[i]) if((v=to[i])!=fa) { dfs(v,cr); dp[cr][0]+=dp[v][1]; dp[cr][1]+=Mn(dp[v][0],dp[v][1]); } if(cr==q0)dp[cr][!f0]=INF; if(cr==q1)dp[cr][!f1]=INF; } bool chk() { bool flag=0; for(int i=hd[q0];i;i=nxt[i]) if(to[i]==q1){flag=1;break;} if(flag&&!f0&&!f1) { puts("-1");return true; } return false; } void solve1() { for(int i=1;i<=m;i++) { q0=rdn();f0=rdn();q1=rdn();f1=rdn(); if(chk())continue; dfs(1,0); printf("%lld ",Mn(dp[1][0],dp[1][1])); } } void solve2() { dp[1][1]=p[1]; dp[1][0]=INF; for(int i=2;i<=n;i++) { dp[i][1]=Mn(dp[i-1][0],dp[i-1][1])+p[i]; dp[i][0]=dp[i-1][1]; } dp2[n][1]=p[n]; dp2[n][0]=0; for(int i=n-1;i;i--) { dp2[i][1]=Mn(dp2[i+1][0],dp2[i+1][1])+p[i]; dp2[i][0]=dp2[i+1][1]; } for(int i=1;i<=m;i++) { q0=rdn();f0=rdn();q1=rdn();f1=rdn(); if(chk())continue; printf("%lld ",dp[q1][f1]+dp2[q1][f1]-(f1?p[q1]:0)); } } void solve3() { dp[1][1]=p[1]; dp[1][0]=0; for(int i=2;i<=n;i++) { dp[i][1]=Mn(dp[i-1][0],dp[i-1][1])+p[i]; dp[i][0]=dp[i-1][1]; } dp2[n][1]=p[n]; dp2[n][0]=0; for(int i=n-1;i;i--) { dp[i][1]=Mn(dp[i+1][0],dp[i+1][1])+p[i]; dp[i][0]=dp[i+1][1]; } for(int i=1;i<=m;i++) { q0=rdn();f0=rdn();q1=rdn();f1=rdn(); if(chk())continue; if(q0>q1)swap(q0,q1),swap(f0,f1); printf("%lld ",dp[q0][f0]+dp2[q1][f1]); } } int main() { freopen("defense.in","r",stdin); freopen("defense.out","w",stdout); n=rdn();m=rdn();scanf("%s",ch+1); for(int i=1;i<=n;i++)p[i]=rdn(); for(int i=1,u,v;i<n;i++) { u=rdn(); v=rdn(); add(u,v); add(v,u); } if(n<=2000)solve1(); else if(ch[1]=='A'&&ch[2]=='1')solve2(); else if(ch[1]=='A'&&ch[2]=='2')solve3(); else solve1(); return 0; }
然后得知A的分好像就是一个线段树。想一想,就记录一下该区间两端的是0还是1就行了。
正解的一种是倍增。
先做出正常的dp[ ][ 0/1 ]数组。考虑倍增,f[ cr ][ i ][0/1][0/1]表示自己到第 i 个祖先的路上的贡献(不含自己及自己子树,含祖先,含路上的点以及它们的分叉,不含祖先上面的部分);只要把dp数组累加起来就行了;累加的时候注意把自己这一条减去,就是如果用父亲的dp[ ][1]的话,就减去自己的min(dp[ ][0],dp[ ][1]),不然就减去自己的dp[ ][1],然后把父亲的这个减去之后的东西放到f[ ][0][ ][ ]里;i 的其他值正常倍增合并就行了。
考虑统计,两个端点就用它们的dp值就行;路上的就倍增地走,用 f 值就行;lca处需要一个以该点为根的值,减去走来的那两条,然后累加到答案里。因为在lca处的那两条一定是原树的两个子树,所以可以换一遍根记录每个点作为根的值,用的时候减去那两条的dp值即可。
思路似乎也很简单?倍增真是好物。考虑两个点的问题时应该想一想倍增、lca之类的。然后考虑倍增数组维护什么,想一想统计答案的时候分为几种不同的部分,应该就差不多了。
#include<iostream> #include<cstdio> #include<cstring> #include<algorithm> #define ll long long using namespace std; const int N=1e5+5,M=20; const ll INF=1e10+5; int n,m,p[N],pr[N][M],hd[N],xnt,to[N<<1],nxt[N<<1],lm,dep[N]; ll dp[N][2],info[N][2],f[N][M][2][2]; int rdn() { int ret=0;bool fx=1;char ch=getchar(); while(ch>'9'||ch<'0'){if(ch=='-')fx=0;ch=getchar();} while(ch>='0'&&ch<='9') ret=(ret<<3)+(ret<<1)+ch-'0',ch=getchar(); return fx?ret:-ret; } ll Mn(ll a,ll b){return a<b?a:b;} void add(int x,int y){to[++xnt]=y;nxt[xnt]=hd[x];hd[x]=xnt;} void dfs1(int cr,int fa) { dp[cr][1]=p[cr]; dep[cr]=dep[fa]+1; for(int i=hd[cr],v;i;i=nxt[i]) if((v=to[i])!=fa) { dfs1(v,cr); dp[cr][1]+=Mn(dp[v][0],dp[v][1]); dp[cr][0]+=dp[v][1]; } } void dfs2(int cr,int fa,ll w0,ll w1) { info[cr][0]=dp[cr][0]+w1; info[cr][1]=dp[cr][1]+Mn(w0,w1); for(int i=hd[cr],v;i;i=nxt[i]) if((v=to[i])!=fa) { dfs2(v,cr,info[cr][0]-dp[v][1],info[cr][1]-Mn(dp[v][0],dp[v][1])); } } void dfsx(int cr,int fa) { for(int i=1,d;i<=lm&&pr[pr[cr][i-1]][i-1];i++) { d=pr[cr][i-1]; pr[cr][i]=pr[d][i-1]; for(int j=0;j<=1;j++) for(int k=0;k<=1;k++) { f[cr][i][j][k]=Mn(f[cr][i-1][j][0]+f[d][i-1][0][k],f[cr][i-1][j][1]+f[d][i-1][1][k]);//not dec p[d] } } for(int i=hd[cr],v;i;i=nxt[i])//f:without son if((v=to[i])!=fa) { pr[v][0]=cr; for(int j=0;j<=1;j++) { f[v][0][j][0]=dp[cr][0]-dp[v][1];//not pls p[v] f[v][0][j][1]=dp[cr][1]-Mn(dp[v][0],dp[v][1]);//not pls p[v] } f[v][0][0][0]=INF; dfsx(v,cr); } } ll cz(int x,int f0,int y,int f1) { ll d00=0,d01=0,d10=0,d11=0; if(dep[x]<dep[y])swap(x,y),swap(f0,f1); int x0=x,y0=y; ll y00=0,y01=0; if(f0)y00=INF; else y01=INF; for(int i=lm;i>=0;i--) if(dep[pr[x][i]]>dep[y])//> not >= { d00=Mn(y00+f[x][i][0][0],y01+f[x][i][1][0]); d01=Mn(y00+f[x][i][0][1],y01+f[x][i][1][1]); y00=d00; y01=d01; x=pr[x][i]; } ll ret=0; if(pr[x][0]==y) { ret=info[y][f1]; if(f1) ret-=Mn(dp[x][0],dp[x][1]),ret+=Mn(y00,y01);//y** not d** for init else ret-=dp[x][1],ret+=y01; ret+=dp[x0][f0]; return ret; } if(dep[x]!=dep[y]) { d00=Mn(y00+f[x][0][0][0],y01+f[x][0][1][0]); d01=Mn(y00+f[x][0][0][1],y01+f[x][0][1][1]); y00=d00; y01=d01; x=pr[x][0]; } ll y10=0,y11=0; if(f1)y10=INF; else y11=INF; for(int i=lm;i>=0;i--) if(pr[x][i]!=pr[y][i]) { d00=Mn(y00+f[x][i][0][0],y01+f[x][i][1][0]); d01=Mn(y00+f[x][i][0][1],y01+f[x][i][1][1]); y00=d00; y01=d01; x=pr[x][i]; d10=Mn(y10+f[y][i][0][0],y11+f[y][i][1][0]); d11=Mn(y10+f[y][i][0][1],y11+f[y][i][1][1]); y10=d10; y11=d11; y=pr[y][i]; } int d=pr[x][0]; ret=info[d][1]-Mn(dp[x][0],dp[x][1])-Mn(dp[y][0],dp[y][1])+Mn(y00,y01)+Mn(y10,y11);//y** not d** for init ll rt2=info[d][0]-dp[x][1]-dp[y][1]+y01+y11; ret=Mn(ret,rt2)+dp[x0][f0]+dp[y0][f1]; return ret; } int main() { freopen("defense.in","r",stdin); freopen("defense.out","w",stdout); char ch[5]; n=rdn(); m=rdn(); scanf("%s",ch); for(;(1<<lm)<n;lm++); for(int i=1;i<=n;i++)p[i]=rdn(); for(int i=1,u,v;i<n;i++) { u=rdn(); v=rdn(); add(u,v); add(v,u); } dfs1(1,0);dfs2(1,0,0,0);dfsx(1,0); for(int i=1,a,b,x,y;i<=m;i++) { a=rdn();x=rdn();b=rdn();y=rdn(); ll ans=cz(a,x,b,y); printf("%lld ",ans>=INF?-1:ans); } return 0; }