题面
https://www.luogu.org/problem/P3600
题解
洛谷试炼场里“概率期望”的题,我写了$4$个,这应该是第$5$个。
我查了一下,这是有出处的。见【洛谷2017春节联欢赛】Hello Dingyou 是这场比赛的$F$题,说明在这场比赛中还是一道很难的题。
但从$CF$赛制和题数来看,这道题如果选对了方法也应该是很好做的。
我们先观察一下答案的形式,一般的概率期望处理这些最大/最小的都没有很好的办法,但是因为是求最小的最大,这种形式在二分问题中出现的很多($mbox{NOIP2018day1T3}$),在这个概率期望/计数问题中,我们显然可以枚举答案。
发现如果是枚举答案恰好是$x_0$,众所周知并不是很好做。可以进一步的转换,枚举答案$le x_0$,算这样的方案数,这样的话,我们会有一步更优美的转化。
“所有答案的最大值$le x_0$,等价于所有答案$le x_0$”,神奇地把最大值去掉了。
假设我们已经求出了$P_i$(所有最终答案$le x_0$的序列的方案数),我们就有$$ans=frac{1}{x^n}sum{i(P_i-P_{i-1})}$$
另一个很显然的结论,若存在区间$A,B$,有$A subseteq B$,则$B$是没用的,因为$B$中最小值一定不比$A$中最小值更大,肯定不会是最终的最大值,所以就可以把所有这样的区间全去掉。(这是题解上说的,事实上,我的方法中,这个显然的结论不是必须的,可见性质的价值在一定程度上取决于发掘它的难度)
去掉了这样的区间之后,左端点有序则右端点有序,每个区间最小值$le x_0$,则对序列的要求是每一个区间至少有一个数$le x_0$。
直接令$f[i]$为$[1..i]$的序列,最后一个$le x_0$的地方在$i$处,然后就是一个简单的序列$dp$(前几天看的$mbox{ShichengXiao}$大佬告诉我的)
然后每次对$i$有影响的$j$是一段区间,直接算前缀和然后用类似哈希的方法更新就好了。
#include<cstdio> #include<cstring> #include<iostream> #include<vector> #define ri register int #define N 2050 #define mod 666623333 #define LL long long using namespace std; inline int read() { int ret=0,f=0; char ch=getchar(); while (ch<'0' || ch>'9') f|=(ch=='-'),ch=getchar(); while (ch>='0' && ch<='9') ret*=10,ret+=ch-'0',ch=getchar(); return f?-ret:ret; } int n,x,q; int f[N],p[N]; int pre[N]; int sum[N]; vector<int> back[N]; inline int mul(int a,int b) { LL ret=a; ret*=b; return (int)(ret%mod); } inline int pow(int a,int b) { int ret=1; for (;b;b>>=1,a=mul(a,a)) if (b&1) ret=mul(ret,a); return ret; } int getval(int l,int r,int x0) { if (!l) return sum[r]; int ret=sum[r]-mul(sum[l-1],pow(x-x0,r-l+1)); if (ret<0) return ret+mod; else return ret; } int count(int x0) { memset(f,0,sizeof(f)); memset(sum,0,sizeof(sum)); f[0]=1; sum[0]=1; for (ri i=1;i<=n;i++) { f[i]=mul(mul(getval(pre[i-1],i-1,x0),1),x0); sum[i]=mul(sum[i-1],x-x0)+f[i]; if (sum[i]>=mod) sum[i]-=mod; } int ret=0; for (ri i=pre[n];i<=n;i++) { ret+=mul(f[i],pow(x-x0,n-i)); if (ret>=mod) ret-=mod; } return ret; } int main() { n=read(); x=read(); q=read(); for (ri i=1,l,r;i<=q;i++) { l=read(); r=read(); back[r].push_back(l); } pre[0]=0; for (ri i=1;i<=n;i++) { int ret=0; for (ri j=0;j<back[i].size();j++) if (back[i][j]>ret) ret=back[i][j]; pre[i]=max(pre[i-1],ret); } int ans=0; p[0]=0; for (ri i=1;i<=x;i++) p[i]=count(i); for (ri i=1;i<=x;i++) { int a=p[i]-p[i-1]; if (a<0) a+=mod; ans+=mul(a,i); if (ans>=mod) ans-=mod; } ans=mul(ans,pow(pow(x,n),mod-2)); cout<<ans<<endl; return 0; }