今天校内考了次试,居然全是洛谷上的原题...
是的,就是C艹考试
T1还行,T2打了个超大模拟,手动模拟二叉树,搞了一堆数组用来存关系,又模拟了最近公共祖先(明明知道可以用倍增,但想不起来怎么写了),最后发现直接跑一个Floyd就可以了
T3就是这道题。由于T2模拟花了太多时间,T3基本上没怎么看,到了讲评的时候还不太明白题意。听了讲解瞬间感觉应该先做T1T3,因为T3我昨天刚刚做过一个相似的题(也是区间DP,能量项链了解一下)QwQ
题目传送门:戳我进入
首先想到的就是枚举每一种情况,以谁为根节点,然后将问题又分为两个相同的问题,之后在所有的答案里面选择一个最大的就可以了
假设所有的点为1,2,3,4....,n,如果我们已经确定了点i为根节点,由于输入是中序遍历,i左边的点都在他的左子树上,而i右边的点都在i的右子树上
我们再来看这棵树的加分计算方法:左子树加分*右子树加分+当前节点的分数
而当前节点的分数已经知道,我们只需要求两个子树的加分情况就好了
至于子树的形态...这个重要吗?
考虑区间DP
如何设置状态?
我们可以假设f[i][j]为区间[i,j]的最大加分,这样看来,我们最终的答案就是f[1][n]了
边界条件是什么?
我们看输入的数据,输入的是每一个节点的分数,实际上在区间DP中可以看成是区间[i,i]的最大分数。而且叶子的加分就是叶节点本身的分数,不考虑它的空子树,所以我们这样设置边界是正确的
状态转移方程?
当我们确定了区间长度后,开始枚举区间位置,然后枚举树的根节点。由于是求最大的分数,所以状态转移方程就出来了:确定区间位置后枚举所有的根节点,然后取最大值
f[i][j]=max(f[i][k-1]*f[k+1][j]+f[k][k])
由于是前序遍历输出,我们只需要再设置一个root数组表示根节点,在输出的时候递归输出就可以了
代码及注释:
#include<cstdio> #include<iostream> #include<cstdlib> #include<iomanip> #include<cmath> #include<cstring> #include<string> #include<algorithm> #include<time.h> #include<queue> using namespace std; typedef long long ll; typedef long double ld; typedef pair<int,int> pr; const double pi=acos(-1); #define rep(i,a,n) for(int i=a;i<=n;i++) #define per(i,n,a) for(int i=n;i>=a;i--) #define Rep(i,u) for(int i=head[u];i;i=Next[i]) #define clr(a) memset(a,0,sizeof a) #define pb push_back #define mp make_pair #define fi first #define sc second ld eps=1e-9; ll pp=1000000007; ll mo(ll a,ll pp){if(a>=0 && a<pp)return a;a%=pp;if(a<0)a+=pp;return a;} ll powmod(ll a,ll b,ll pp){ll ans=1;for(;b;b>>=1,a=mo(a*a,pp))if(b&1)ans=mo(ans*a,pp);return ans;} ll read(){ ll ans=0; char last=' ',ch=getchar(); while(ch<'0' || ch>'9')last=ch,ch=getchar(); while(ch>='0' && ch<='9')ans=ans*10+ch-'0',ch=getchar(); if(last=='-')ans=-ans; return ans; } //head int n; int a[500],f[500][500],root[500][500]; inline void print(int x,int y) { if(x>y) return; printf("%d ",root[x][y]); if(x==y) return; print(x,root[x][y]-1); print(root[x][y]+1,y); //由于是前序遍历,所以分成两个区间继续 } int main() { n=read(); rep(i,1,n) a[i]=read(); rep(i,1,n) f[i][i]=a[i],root[i][i]=i; //设置边界条件 for(int l=1;l<n;l++)//枚举区间长度 { for(int i=1;i+l<=n;i++)//枚举区间位置 { int t=i+l;//t表示区间的右端点 f[i][t]=f[i+1][t]+f[i][i];//先初始化为以i为根节点 root[i][t]=i;//同上 for(int j=i+1;j<=t;j++)//开始枚举根节点的位置 { if(f[i][t]<f[i][j-1]*f[j+1][t]+f[j][j]) //如果有更优的解,就更新f和root { f[i][t]=f[i][j-1]*f[j+1][t]+f[j][j]; root[i][t]=j; } } } } cout<<f[1][n]<<endl;//最后的答案就是f[1][n] print(1,n);//按照前序遍历输出 }