题意:给定两个 (n) 元环,环上每个点有权值,分别为 (x_i, y_i)。定义两个环的差值为
[sum_{i=0}^{n-1}{(x_i-y_i)^2}
]
可以旋转其中的一个环,或者将其中一个环的每种权值加上一个数。求最小化的差值。
Solution: 加数只需要加在一个上面即可(假设可以为负),那么差值可以写成
[sum_{i=0}^{n-1}{(x_i-y_{i+k}+c)^2}
]
我们可以将差值定义为旋转位数(k)与加数(c)的函数,即 (f(k,c)) 。我们现在要找的就是二元函数的峰值。展开得
[f(k,c)=sum{x_i^2} + sum{y_i^2} + nc^2 + 2c sum{x_i} - 2csum{y_i} - 2 sum{x_i y_{(i+k)\%n}}
]
我们惊喜地发现没有交叉项,即
[f(k,c) = g(k) + h(c)
]
那么我们只需要分别最小化两部分即可。
考虑(g(k)),翻转序列(x),这是一个循环卷积的形式。我们可以将序列(y)扩增一倍来转化为线性卷积。那么此时
[g(k) = sum_{i=0}^{n-1}{x_{n-1-i} y_{i+k}}
]
对应到多项式乘法上,找(g(k))的最小值,即在幂次为 ([n-1,2n-1)) 的项中找最大系数即可。
考虑(h(c)),由于(m leq 100),暴力枚举即可。
#include<bits/stdc++.h>
#define pi acos(-1)
using namespace std;
struct poly {
typedef complex<double> E;
int n,m;
vector <double> c;
void read(int deg) {
c.resize(deg+1);
for(int i=0;i<=deg;i++) scanf("%lf",&c[i]);
}
void write() {
for(int i=0;i<c.size();i++) printf("%lf ",c[i]);
printf("
");
}
void fft(E *a,int f,int *R){
for(int i=0;i<n;i++)if(i<R[i])swap(a[i],a[R[i]]);
for(int i=1;i<n;i<<=1){
E wn(cos(pi/i),f*sin(pi/i));
for(int p=i<<1,j=0;j<n;j+=p){
E w(1,0);
for(int k=0;k<i;k++,w*=wn){
E x=a[j+k],y=w*a[j+k+i];
a[j+k]=x+y;a[j+k+i]=x-y;
}
}
}
}
void mul(vector<double> A, vector<double> B){
int *R; E *a,*b; int L=0;
int _s=(A.size()+B.size()+2)<<5;
a=(E*)malloc(_s); b=(E*)malloc(_s); R=(int*)malloc(_s);
memset(a,0,_s);memset(b,0,_s);memset(R,0,_s);
n=A.size()-1; for(int i=0;i<=n;i++) a[i]=A[i];
m=B.size()-1; for(int i=0;i<=m;i++) b[i]=B[i];
m=n+m;for(n=1;n<=m;n<<=1)L++;
for(int i=0;i<n;i++)R[i]=(R[i>>1]>>1)|((i&1)<<(L-1));
fft(a,1,R);fft(b,1,R);
for(int i=0;i<=n;i++)a[i]=a[i]*b[i];
fft(a,-1,R);
c.resize(m+1);
for(int i=0;i<=m;i++)c[i]=a[i].real()/n;
free(a); free(b); free(R);
}
poly mul(poly pa,poly pb) {
poly pc; pc.mul(pa.c,pb.c); return pc;
}
poly operator * (const poly &pb) {
poly ret = mul(*this,pb); return ret;
}
poly operator *= (const poly &pb) {
mul(this->c,pb.c); return *this;
}
};
int n,m,a[100005],b[100005];
int main() {
ios::sync_with_stdio(false);
cin>>n>>m;
for(int i=0;i<n;i++) cin>>a[i];
for(int i=0;i<n;i++) cin>>b[i];
long long sx=0,sy=0;
for(int i=0;i<n;i++) sx+=a[i],sy+=b[i];
long long px=0,py=0;
for(int i=0;i<n;i++) px+=a[i]*a[i],py+=b[i]*b[i];
long long ans = 1e+9;
for(int i=-100;i<=100;i++) {
ans = min(ans, n*i*i+2*i*sx-2*i*sy);
}
ans += px + py;
poly x,y;
x.c.resize(n);
for(int i=0;i<n;i++) x.c[i]=a[n-i-1];
y.c.resize(2*n);
for(int i=0;i<n;i++) y.c[i]=y.c[i+n]=b[i];
poly z=x*y;
long long mx=0;
for(int i=0;i<n;i++) mx=max(mx,(long long)(z.c[n+i-1]+0.5));
cout<<ans-2*mx<<endl;
}