Description
给你(n+m)个询问,其中(n)个的答案是(Yes),(m)个的答案是(No),现在依次回答这些询问,每回答一个询问就告诉你听你回答对了还是没对,求最优策略下答对题目期望数量对(998244353)取模
Solution
个人感觉很棒的一题qwq
首先我们可以设出一个无脑dp:(f[n][m])表示有(n)个(Yes)和(m)个(No)情况下的答案,策略的话当然是哪个剩的多就猜哪个,一样多随便猜一个,那么我们可以得到转移:
然后我们可以将这个东西放到一个。。坐标系里面,横坐标对应(n),纵坐标对应(m),那么一种回答的方案就相当于从((n,m))出发到((0,0))的一条路径,考虑画一条(y=x)的直线,整个坐标系被这条线分成了两大部分,线上方的点都满足(y>x),下方则是(x>y),那么放回上面的式子里面,先不看概率只看贡献,会发现在下方横向的路径是有贡献的,上方纵向的路径是有贡献的
为了方便下面的描述,我们不妨令(n>=m),因为我们的策略中如果一样多就随便猜一个,所以从对角线上点转移出来的答案应该还要乘上(frac{1}{2})(随便猜有(frac{1}{2})的概率对,而其他的情况下猜什么是已经确定的了所以可以直接算贡献),这个比较不同所以我们考虑分开,先看那些确定的贡献
先考虑比较简单的(n=m)的情况:考虑一条从((n,n))到((0,0))的不碰到对角线的路径,这样的一种方案中所有的边的贡献都是确定的可以直接计算,会发现不管怎么走,每条路径一定会有(n)的贡献
那么再看(n>m)的情况:考虑一条从((n,m))到((0,0))的路径(可以经过对角线),我们按照触碰对角线的节点将这条路径划分成若干个部分,除了第一部分(也就是从((n,m))走到碰到的第一个对角线上的点的这段)以外,其他部分都可以看成是从对角线上某一个点出发,中途不经过对角线,在对角线上某个点结束的一段路程,其实也就是我们的(n=m)的那种情况,而在第一部分中,为了触碰到对角线,一定会横着走(n-m)段,也就是一定会有(n-m)的贡献,加上前面的那些部分,每条路径一定会有(n)的贡献(当(m>n)的情况下其实一样的,类似的这个时候就是一定会有(m)的贡献了)
所以,我们可以得到一个结论:确定的贡献为(max(n,m)),接下来真正受概率影响的就只有那些对角线上的点的贡献了
而这些点的贡献其实也很好计算,只要有一条路径经过对角线上的一个点,那么不管是横着走还是竖着走的,都有(frac{1}{2})的概率获得(1)的贡献,所以我们只要对于对角线上面的每一个点计算经过它的方案数,然后除以总的路径数量,再乘上(frac{1}{2})即可
mark:没事把这种-1转移的二维dp丢到坐标系里面转成路径什么的好像挺有用的
代码大概长这个样子
#include<iostream>
#include<cstdio>
#include<cstring>
using namespace std;
const int N=5*(1e5)+10,MOD=998244353,inv2=499122177;
int fac[N*2],invfac[N*2];
int n,m;
int mul(int x,int y){return 1LL*x*y%MOD;}
int plu(int x,int y){return (1LL*x+y)%MOD;}
int C(int n,int m){return n<m?0:mul(fac[n],mul(invfac[m],invfac[n-m]));}
int calc(int n,int m){return C(n+m,m);}
int ksm(int x,int y){
int ret=1,base=x;
for (;y;y>>=1,base=mul(base,base))
if (y&1) ret=mul(ret,base);
return ret;
}
void prework(int n){
fac[0]=1;
for (int i=1;i<=n;++i) fac[i]=mul(fac[i-1],i);
invfac[n]=ksm(fac[n],MOD-2);
for (int i=n-1;i>=0;--i) invfac[i]=mul(invfac[i+1],i+1);
}
void solve(){
int ans=0;
for (int i=1;i<=min(n,m);++i)
ans=plu(ans,mul(calc(i,i),calc(n-i,m-i)));
ans=mul(ans,ksm(calc(n,m),MOD-2));
ans=mul(ans,inv2);
printf("%d
",ans+max(n,m));
}
int main(){
#ifndef ONLINE_JUDGE
freopen("a.in","r",stdin);
#endif
scanf("%d%d",&n,&m);
prework(n+m);
solve();
}