传送门:http://codeforces.com/problemset/problem/573/d
思路:首先如果没有限制,那么根据排序不等式,肯定按顺序匹配战士和马最好。
但是现在有了战士不能和自己的马匹配的限制。
于是就有了一个重要的性质:
最优匹配的前提下,排序后第i号战士只会与[i-2,i+2]号马匹配
至于证明,可以自己YY,也可以分情况讨论(好像很复杂...)
于是就可以DP了,设f[i]表示1-i号战士正好和1-i号战马匹配
记ban[i]表示第i号战士不能匹配的战马
那么转移方程就是:
f[i]=max{
f[i-1]+a[i]*b[i];(ban[i]!=i)
f[i-2]+a[i]*b[i-1]+a[i-1]*b[i];(ban[i]!=i-1)
f[i-3]+a[i]*b[i-2]+a[i-1]*b[i]+a[i-2]*b[i-1](ban[i]!=i-2&&ban[i-1]!=i&&ban[i-2]!=i-1)
f[i-3]+a[i]*b[i-1]+a[i-1]*b[i-2]+a[i-2]*b[i](ban[i]!=i-1&&ban[i-1]!=i-2&&ban[i-2]!=i)
}
然后每次DP一遍,复杂度O(qn),加一些常数优化就可以过了。
#include<cstdio> #include<cstring> #include<algorithm> typedef long long ll; const ll inf=1ll<<60; const int maxn=30010; using namespace std; struct node{ ll v;int id;//价值,不能匹配的编号 }a[maxn],b[maxn]; int n,q,posa[maxn],posb[maxn],ban[maxn];ll f[maxn],w1[maxn],w2[maxn],w3[maxn]; bool cmp(node a,node b){return a.v<b.v;} void calc(int i){ w1[i]=w2[i]=w3[i]=-inf; if (i>=1&&ban[i]!=i) w1[i]=a[i].v*b[i].v; if (i>=2&&ban[i]!=i-1&&ban[i-1]!=i) w2[i]=a[i].v*b[i-1].v+a[i-1].v*b[i].v; if (i>=3){ if (ban[i]!=i-2&&ban[i-1]!=i&&ban[i-2]!=i-1) w3[i]=a[i].v*b[i-2].v+a[i-1].v*b[i].v+a[i-2].v*b[i-1].v; if (ban[i]!=i-1&&ban[i-1]!=i-2&&ban[i-2]!=i) w3[i]=max(w3[i],a[i].v*b[i-1].v+a[i-1].v*b[i-2].v+a[i-2].v*b[i].v); } } int main(){ scanf("%d%d",&n,&q); for (int i=1;i<=n;i++) scanf("%I64d",&a[i].v),a[i].id=i; for (int i=1;i<=n;i++) scanf("%I64d",&b[i].v),b[i].id=i; sort(a+1,a+1+n,cmp),sort(b+1,b+1+n,cmp); for (int i=1;i<=n;i++) posa[a[i].id]=i,posb[b[i].id]=i; for (int i=1;i<=n;i++) ban[i]=posb[a[i].id]; for (int i=1;i<=n;i++) calc(i); // for (int i=1;i<=n;i++) printf("%d ",ban[i]); for (int j=1,x,y;j<=q;j++){ scanf("%d%d",&x,&y); x=posa[x],y=posa[y],swap(ban[x],ban[y]); for (int i=max(1,x-5);i<=min(n,x+5);i++) calc(i); for (int i=max(1,y-5);i<=min(n,y+5);i++) calc(i); f[0]=0; for (int i=1;i<=n;i++){ if (i>=1) f[i]=f[i-1]+w1[i]; if (i>=2) f[i]=max(f[i],f[i-2]+w2[i]); if (i>=3) f[i]=max(f[i],f[i-3]+w3[i]); } printf("%I64d ",f[n]); } return 0; } /* 4 15 70 46 78 69 90 93 83 11 2 3 3 4 4 1 3 1 4 3 3 1 2 4 3 1 3 2 3 4 2 3 1 2 1 4 4 1 1 2 */
然而还有更优的写法,用矩阵乘法+线段树优化
观察转移方程
f[i]=max{
f[i-1]+a[i]*b[i]; (ban[i]!=i)
f[i-2]+a[i]*b[i-1]+a[i-1]*b[i]; (ban[i]!=i-1)
f[i-3]+a[i]*b[i-2]+a[i-1]*b[i]+a[i-2]*b[i-1]; (ban[i]!=i-2&&ban[i-1]!=i&&ban[i-2]!=i-1)
f[i-3]+a[i]*b[i-1]+a[i-1]*b[i-2]+a[i-2]*b[i]; (ban[i]!=i-1&&ban[i-1]!=i-2&&ban[i-2]!=i)
}
可以发现f[i]只与f[i-1],f[i-2],f[i-3]有关我们于是联想到可以用一个3*3的矩阵来转移
而每次修改又只会影响几个转移矩阵
于是可以用线段树维护区间矩阵乘法,修改时就暴力改动那几个矩阵
但是这里的矩阵乘法和平时的乘法有些不同
现在乘法的定义是:c=a*b
c[i][j]=max(a[i][k]+b[k][j])
而正常的定义是
c[i][j]=Σ(a[i][k]*b[k][j])
我们现在只要证它有结合律就可以用线段树来维护区间乘法
令F=(a*b)*c,G=a*(b*c)
那么就有
F[i][j]=max(max(a[i][u]+b[u][v])+c[v][j])
G[i][j]=max(a[i][u]+max(b[u][v]+c[v][j]))
因为max和+有结合律,+对max有分配律
max((max(a,b)),c)=max(a,max(b,c))
(a+b)+c=a+(b+c)
max(a,b)+c=max(a+c,b+c)
于是G[i][j]=max(max(a[i][u]+b[u][v])+c[v][j])=F[i][j]
那么我们就证明了这种矩阵乘法有结合律
于是就可以上线段树解决了
具体细节:
如果不能转移就填-1
转移矩阵就是
-1 -1 case(i-2)
0 -1 case(i-1)
-1 0 case(i)
-1 -1 -1 *-1 -1 case(i-2)=-1 -1 -1
-1 -1 -1 0 -1 case(i-1) -1 -1 -1
f[i-2] f[i-1] f[i] -1 0 case(i) f[i-1] f[i] f[i-2]
#include<cstdio> #include<cstring> #include<algorithm> #define ls (p<<1) #define rs ((p<<1)|1) #define mid ((l+r)>>1) const int maxn=30010,maxt=maxn<<2; typedef long long ll; using namespace std; int n,q,posa[maxn],posb[maxn],ban[maxn]; struct node{ll v;int id;}a[maxn],b[maxn]; bool cmp(node a,node b){return a.v<b.v;} struct matrix{ ll mat[3][3]; void clear(){memset(mat,-1,sizeof(mat));} }; matrix operator *(matrix a,matrix b){ matrix res;res.clear(); for (int i=0;i<3;i++) for (int k=0;k<3;k++) if (a.mat[i][k]!=-1) for (int j=0;j<3;j++) if (b.mat[k][j]!=-1) res.mat[i][j]=max(res.mat[i][j],a.mat[i][k]+b.mat[k][j]); return res; } matrix get_matrix(int i){ matrix c;c.clear(); c.mat[1][0]=c.mat[2][1]=0; if (ban[i]!=i) c.mat[2][2]=a[i].v*b[i].v; if (i<=1) return c; if (ban[i]!=i-1) c.mat[1][2]=a[i].v*b[i-1].v+a[i-1].v*b[i].v; if (i<=2) return c; ll v1=-1,v2=-1; if (ban[i]!=i-1&&ban[i-1]!=i-2&&ban[i-2]!=i) v1=a[i].v*b[i-1].v+a[i-1].v*b[i-2].v+a[i-2].v*b[i].v; if (ban[i]!=i-2&&ban[i-1]!=i&&ban[i-2]!=i-1) v2=a[i].v*b[i-2].v+a[i-1].v*b[i].v+a[i-2].v*b[i-1].v; c.mat[0][2]=max(v1,v2); return c; } struct Segment_Tree{ matrix t[maxt]; void build(int p,int l,int r){ //printf("%d %d %d ",p,l,r); if (l==r){t[p]=get_matrix(l);return;} build(ls,l,mid),build(rs,mid+1,r); t[p]=t[ls]*t[rs]; } void modify(int p,int l,int r,int a){ if (l==r){t[p]=get_matrix(l);return;} if (a<=mid) modify(ls,l,mid,a); else modify(rs,mid+1,r,a); t[p]=t[ls]*t[rs]; } void modify(int a){modify(1,1,n,a);} ll query(){return t[1].mat[2][2];} }T; int main(){ scanf("%d%d",&n,&q); for (int i=1;i<=n;i++) scanf("%I64d",&a[i].v),a[i].id=i; for (int i=1;i<=n;i++) scanf("%I64d",&b[i].v),b[i].id=i; sort(a+1,a+1+n,cmp),sort(b+1,b+1+n,cmp); for (int i=1;i<=n;i++) posa[a[i].id]=i,posb[b[i].id]=i; for (int i=1;i<=n;i++) ban[i]=posb[a[i].id]; T.build(1,1,n); /*for (int i=1;i<=n;i++,puts("")){ printf("%d ",i);matrix c=get_matrix(i); for (int j=0;j<3;j++,puts("")) for (int k=0;k<3;k++) printf("%lld ",c.mat[j][k]); }*/ for (int x,y;q;q--){ scanf("%d%d",&x,&y),x=posa[x],y=posa[y],swap(ban[x],ban[y]); for (int i=max(1,x-2);i<=min(x+2,n);i++) T.modify(i); for (int i=max(1,y-2);i<=min(y+2,n);i++) T.modify(i); printf("%I64d ",T.query()); } return 0; }