题目链接
(Describe)
题目描述
为了提高智商,(ZJY)开始学习线性代数。她的小伙伴菠萝给她出了这样一个问题:给定一个(n×n)的矩阵(B)和一个(1×n)的矩阵(C)。求出一个(1×n)的(01)矩阵(A)。使得(D=(A*B-C)*A^T) 最大,其中(A^T)为(A)的转置。输出(D)。
输入格式:
第一行输入一个整数(n)。接下来(n)行输入(B)矩阵,第(i)行第(j)个数代表(B)接下来一行输入(n)个整数,代表矩阵(C)。矩阵(B)和矩阵(C)中每个数字都是不过(1000)的非负整数
输出格式:
输出一个整数,表示最大的(D)。
输入样例:
3
1 2 1
3 1 0
1 2 3
2 3 7
输出样例:
2
(Solution)
首先来化简一下式子
[D=(A*B-C)*A^T
]
[=sum_{i=1}^{n}(sum_{j=1}^{n}A_j*B_{j,i}-C_i)*A_i
]
[=sum_{i=1}^{n}sum_{j=1}^{n}A_i*A_j*B_{i,j}-sum_{i=1}^{n}C_i*A_i
]
因为题目已经说明了(A)是一个(01)串,所以我们可以发现当(A_i)为(0)的时候对答案并没有任何贡献,不用计算。当(A_i)为(1)时,会有(C_i)的花费。但如果同时选(j)会有(B_{i,j})的花费.所以这显然是一个最小割模型了。讲1看为选,0为不选
建图:
- 将每个(B_{ij})看做一个点,总共有(n*n)个点。将这(S)和这(n*n)个点相连,流量为(B_{i,j})
- 新建(n)个点。将这些点和(T)相连,流量为(C_i)
- 将(n*n)个点和新建节点中的(i,j)相连,流量为(inf)
答案就是(B)矩阵内的和-最小割
(Code)
#include<bits/stdc++.h>
#define inf 1e9
using namespace std;
typedef long long ll;
int read(){
int x=0,f=1;char c=getchar();
while(c<'0'||c>'9') f=(c=='-')?-1:1,c=getchar();
while(c>='0'&&c<='9') x=x*10+c-'0',c=getchar();
return x*f;
}
struct node{
int to,next,v;
}a[2000001];
int head[1000001],cnt,n,m,s,t,x,y,z,dep[260000],sum,cur[260000];
void add(int x,int y,int c){
a[++cnt].to=y,a[cnt].next=head[x],a[cnt].v=c,head[x]=cnt;
a[++cnt].to=x,a[cnt].next=head[y],a[cnt].v=0,head[y]=cnt;
}
queue<int> q;
int bfs(){
memset(dep,0,sizeof(dep));
q.push(s);
dep[s]=1;
while(!q.empty()){
int now=q.front();
q.pop();
for(int i=head[now];i;i=a[i].next){
int v=a[i].to;
if(!dep[v]&&a[i].v>0)
dep[v]=dep[now]+1,q.push(v);
}
}
if(dep[t])
return 1;
return 0;
}
int dfs(int k,int list){
if(k==t||!list)
return list;
for(int &i=cur[k];i;i=a[i].next){
int v=a[i].to;
if(dep[v]==dep[k]+1&&a[i].v>0){
int p=dfs(v,min(list,a[i].v));
if(p){
a[i].v-=p;
i&1?a[i+1].v+=p:a[i-1].v+=p;
return p;
}
}
}
return 0;
}
int Dinic(){
int ans=0,k;
while(bfs()){
for(int i=s;i<=t;i++)
cur[i]=head[i];
while((k=dfs(s,inf)))
ans+=k;
}
return ans;
}
int main(){
n=read(),s=0,t=n*n+n+1;
for(int i=1;i<=n;i++)
for(int j=1;j<=n;j++)
x=read(),sum+=x,add(s,(i-1)*n+j,x),add((i-1)*n+j,i+n*n,inf),add((i-1)*n+j,j+n*n,inf);
for(int i=1;i<=n;i++)
x=read(),add(i+n*n,t,x);
printf("%d
",sum-Dinic());
}