题解-代价
前置知识:前缀和,暴搜找规律。
参考资料
暂无
题目描述
有一个 \(n+2\) 项的序列 \(a_1,a_2,...,a_{n+2}\),给你 \(a_2\sim a_{n+1}\),\(a_1=a_{n+2}=1\),每删除一个数 \(a_i\) 的代价是 \(a_i\times a_{i-1}\times a_{i+1}\),求删除序列的最小代价。
数据范围:\(1\le n\le 10^6\),\(1\le a_i\le 10^4\)。
考场上我没做出来,自闭了。
这题的暴力非常好写,可以写个双向链表,这样删数时维护 \(a_{i+1}\) 和 \(a_{i-1}\) 就非常方便了。代码如下:
#include <bits/stdc++.h>
using namespace std;
//&Start
#define lng long long
#define lit long double
#define kk(i,n) "\n "[i<n]
const int inf=0x3f3f3f3f;
const lng Inf=1e17;
//&Data
const int N=1e6+10;
int n,l[N],r[N];
lng a[N],b[N],sm[N],tmp,ans=Inf;
//&Main
void dfs(int x,lng sum){//我写了一个ida*,但其实没必要
if(sum+sm[(n+1-x)]>=ans) return;
if(x==n+1){ans=sum;return;}
for(int i=r[0];i!=n+1;i=r[i]){
l[r[i]]=l[i],r[l[i]]=r[i];
dfs(x+1,sum+a[i]*a[l[i]]*a[r[i]]);
l[r[i]]=i,r[l[i]]=i;
}
}
int main(){
scanf("%d",&n);
a[0]=a[n+1]=1;
r[0]=1,l[n+1]=n;
for(int i=1;i<=n;i++){
scanf("%lld",a+i),b[i]=a[i];
l[i]=i-1,r[i]=i+1;
tmp=min(tmp,a[i]);
}
sort(b+1,b+n+1);
for(int i=1;i<=n;i++)
sm[i]=sm[i-1]+b[i];
dfs(1,0);
printf("%lld\n",ans);
return 0;
}
时间复杂度是必爆的,但是这种结论题,可以用暴搜找规律(我考场上没想到啊啊啊)。多造几个序列,最好小一点,要不然暴搜太慢。然后看暴搜最佳选择方案,代码如下:
#include <bits/stdc++.h>
using namespace std;
//&Start
#define lng long long
#define lit long double
#define kk(i,n) "\n "[i<n]
const int inf=0x3f3f3f3f;
const lng Inf=1e17;
//&Data
const int N=1e6+10;
int n,l[N],r[N],xuan[N],as[N];
lng a[N],tmp,ans=Inf;
//&Main
void dfs(int x,lng sum){
if(sum+(n+1-x)*tmp>=ans) return;
if(x==n+1){
ans=sum;
memcpy(as,xuan,sizeof xuan);
return;
}
for(int i=r[0];i!=n+1;i=r[i]){
l[r[i]]=l[i],r[l[i]]=r[i];
xuan[x]=i;
dfs(x+1,sum+a[i]*a[l[i]]*a[r[i]]);
l[r[i]]=i,r[l[i]]=i;
}
}
int main(){
scanf("%d",&n);
if(n>10) return printf("%d\n",n),0;
a[0]=a[n+1]=1;
r[0]=1,l[n+1]=n;
for(int i=1;i<=n;i++){
scanf("%lld",a+i);
l[i]=i-1,r[i]=i+1;
tmp=min(tmp,a[i]);
}
dfs(1,0);
for(int i=1;i<=n;i++) printf("%d%c",as[i],kk(i,n));
printf("%lld\n",ans);
return 0;
}
然后你会发现,最优删法大部分都是从两边开始删的,唯独类似于
1 1 5 1 1 //n=3
这样的序列,不得不从中间把大的数去掉。
然后你能会脑洞大开,但是最终的正确方法是——把每个两边是 \(1\) 并且不包含 \(1\) 的子序列从两边删,然后最后删 \(1\)。
大致证明:
\(\texttt{First}\) ... \(1\) 必须最后取。
如果不最后取 \(a_i=1\),那么取 \(1\) 还会有 \(a_{i-1}\times a_{i+1}\) 的代价,并且取走两边的数的代价肯定会有个因子 \(a_{i-1}\times a_{i+1}\)。
而最后取 \(a_i=1\),代价为 \(1\)(因为最后只剩 \(1\) 了),并且取 \(a_{i-1}\) 和 \(a_{i+1}\) 时也能把因子 \(1\) 用上。
\(\texttt{Second}\) ... 每个不包含 \(1\) 并且两边是 \(1\) 的序列从两边删最佳。
如果从两边删就相当于代价只有相邻两项的乘积(因为左边或右边一定是 \(1\)),然后如果不从两边删代价至少为三项相乘,最后总和总是要大一点的(真是很口胡的证明,得出结论还是主要靠找规律)。
然后要知道得到了一个不包含 \(1\) 并且两边是 \(1\) 的区间 \([L,R]\) 要怎么 \(\Theta(R-L+1)\) 求最优取法(数据范围已经说明了时间复杂度)。
设 \([L,R]\) 最后取的数为 \(a_t(L\le t\le R)\),那么从 \(a_L\) 一直取到 \(a_{t-1}\) 的代价是
因为每次取的数左边肯定是 \(1\)。
同理,从 \(a_R\) 向左取到 \(a_{t+1}\) 的代价是
然后取 \(a_t\) 的代价是 \(1\times a_t\times 1=a_t\)。
所以总代价是
而这个式子的前两项可以用前缀后缀和维护。
区间 \([L,R]\) 最优取法的代价就是
所以拿到区间 \([L,R]\) 求最优取法就是 \(\Theta(R-L+1)\)。
因为最后每取一个 \(1\) 的代价是 \(1\),所以最后答案是
所以总时间复杂度 \(\Theta(n)\)。
Code
#include <bits/stdc++.h>
using namespace std;
//&Start
#define lng long long
#define lit long double
#define kk(i,n) "\n "[i<n]
const int inf=0x3f3f3f3f;
const lng Inf=1e17;
//&Data
const int N=1e6+10;
int n;
lng a[N],sm[N],ms[N],ans;
//&Main
void solve(int l,int r){
for(int i=l+1;i<=r;i++) sm[i]=sm[i-1]+a[i]*a[i-1]; //前缀和
for(int i=r-1;i>=l;i--) ms[i]=ms[i+1]+a[i]*a[i+1]; //后缀和
lng res=Inf;
for(int i=l;i<=r;i++) res=min(res,a[i]+sm[i]+ms[i]);
ans+=(res==Inf)?0:res; //如果序列长度为0,特判一下
}
int main(){
scanf("%d",&n),a[0]=a[n+1]=1;//用0~n+1代替1~n+2
for(int i=1;i<=n;i++) scanf("%lld",a+i);
for(int lst=0,i=1;i<=n+1;i++)if(a[i]==1) solve(lst+1,i-1),lst=i,ans++;
printf("%lld\n",ans-1);//为了方便统计区间,把n+2的1也算进去了,减掉
return 0;
}
祝大家学习愉快!