线段树。
这道题乍一看是道dp,但是发现1e6的范围。。。~(N^2过百万)~并且有后效性qaqqqqqq......
我的思路是这样的:考虑暴力,我们每次枚举左端点,O(n)求和,复杂度N^2。如何优化呢?考虑前缀和,当我们扫过某个点后会对答案产生什么影响?设NXT[ i ]表示i点之后第一个与 i 相同的数字的编号,那么我们只需要在[1,nxt[i]-1]中减去i 的贡献,在[ nxt[i],nxt[nxt[i]] ]范围内加上贡献就好了。注意特判边界(一开始60分qwq)
code:
#include<iostream>
#include<cstdio>
#include<queue>
#include<algorithm>
#include<cstring>
#define int long long
#define half (l+r)>>1
const int maxn=5000006;
using namespace std;
int n,m;
struct hzw
{
int tag,sum,lc,rc,mx;
}t[maxn];
int id[maxn],cost[maxn],nxt[maxn];
int pan[maxn];
bool flag[maxn];
int val[maxn];
inline void pushup(int s)
{
t[s].sum=t[t[s].lc].sum+t[t[s].rc].sum;
t[s].mx=max(t[t[s].lc].mx,t[t[s].rc].mx);
}
int tot;
inline void build(int s,int l,int r)
{
if (l==r)
{
t[s].sum+=val[l];
t[s].mx=val[l];
return;
}
int mid=half;
t[s].lc=++tot;
build(tot,l,mid);
t[s].rc=++tot;
build(tot,mid+1,r);
pushup(s);
}
inline void pushdown(int s,int l,int r)
{
int mid=half,le=t[s].lc,ri=t[s].rc;
t[le].tag+=t[s].tag,t[ri].tag+=t[s].tag;
t[le].sum+=(mid-l+1)*t[s].tag,t[ri].sum+=(r-mid)*t[s].tag;
t[le].mx+=t[s].tag,t[ri].mx+=t[s].tag;
t[s].tag=0;
}
inline void update(int s,int k,int cl,int cr,int l,int r)
{
if (l==cl&&r==cr)
{
t[s].sum+=(r-l+1)*k;
t[s].tag+=k;
t[s].mx+=k;
return;
}
int mid=half;
if (t[s].tag) pushdown(s,l,r);
if (cr<=mid) update(t[s].lc,k,cl,cr,l,mid);
else if (cl>mid) update(t[s].rc,k,cl,cr,mid+1,r);
else
{
update(t[s].lc,k,cl,mid,l,mid);
update(t[s].rc,k,mid+1,cr,mid+1,r);
}
pushup(s);
}
inline int query(int s,int cl,int cr,int l,int r)
{
if (l==cl&&r==cr)
{
return t[s].mx;
}
int mid=half;
if (t[s].tag) pushdown(s,l,r);
if (cr<=mid) return query(t[s].lc,cl,cr,l,mid);
else if (cl>mid) return query(t[s].rc,cl,cr,mid+1,r);
else
{
return max(query(t[s].lc,cl,mid,l,mid),query(t[s].rc,mid+1,cr,mid+1,r));
}
}
signed main()
{
cin>>n>>m;
tot=1;
for (int i=1;i<=n;++i) scanf("%lld",&id[i]);
for (int i=1;i<=m;++i) scanf("%lld",&cost[i]);
for (int i=1;i<=n;++i)
{
if (pan[id[i]])
{
nxt[pan[id[i]]]=i;
pan[id[i]]=i;
if (!flag[id[i]])
{
flag[id[i]]=1;
val[i]=val[i-1]-cost[id[i]];
}
else val[i]=val[i-1];
continue;
}
int tmp=val[i-1]+cost[id[i]];
val[i]=tmp;
pan[id[i]]=i;
}
int ans=-23333;
build(1,1,n);
for (int i=1;i<=n;++i)
{
ans=max(ans,query(1,i,n,1,n));
if (nxt[i])
{
int tmp;
if (!nxt[nxt[i]]) tmp=n+1;//特判
else tmp=nxt[nxt[i]];
tmp--;
update(1,-cost[id[i]],1,nxt[i]-1,1,n);
update(1,cost[id[i]],nxt[i],tmp,1,n);
}
else update(1,-cost[id[i]],1,n,1,n);//特判
}
cout<<ans;
return 0;
}