洛谷 P5326 [ZJOI2019]开关
https://www.luogu.com.cn/problem/P5326
Tutorial
https://www.luogu.com.cn/blog/xht37/solution-p5326
https://www.cnblogs.com/PinkRabbit/p/ZJOI2019D2T1.html
令(p_i=dfrac {p_i}{sum p_i})
设(f(x))表示在第(k)步到达合法状态的概率的生成函数,因为只关心第一次到达合法状态的情况,所以设(g(x))表示走(k)步后回到原来的状态的概率,(h(x))表示第(k)步第一次走到合法状态的概率,则有(f(x)=g(x)h(x) o h(x)=dfrac{f(x)}{g(x)}) .设(h(x)=sum a_k x^k),则我们要求就是
[sum ka_k=h'(1)=dfrac{f'(1)g(1)-f(1)g'(1)}{g^2(1)}
]
考虑如何求(f(x)).到达合法状态的条件为选择开关(i)的次数与(s_i)相等.则有
[F_i(x)=dfrac{e^{p_ix}+(-1)^{s_i}e^{-p_ix}}2
]
发现(f(x))是OGF,(F_i(x))为EGF,为了相互转化,将(prod F_i(x))表示为(sum c_k(e^x)^k)的形式,其中(c_k)可以用背包在(O(nsum p))的时间求得,最后得到
[egin{align}
f(x)&=sum_k ([x^k]k!sum_i c_i(e^x)^i)x^k \
&=sum_k(k!sum_i c_i [x^k](e^x)^i)x^k \
&=sum_k(k!sum_ic_idfrac{i^k}{k!})x^k \
&=sum_k (sum_i c_ii^k)x^k \
&=sum_ic_isum_{k}i^kx^k \
&=sum_idfrac{c_i}{1-ix}
end{align}
]
(g(x))的处理类似,最后得到
[g(x)=sum_idfrac{d_i}{1-ix}
]
但是发现当(i=1)时会有(1-x)这一项,所以不能直接将(x=1)带入,考虑分子分母同乘((1-x)),得到新的(f(x),g(x))
[f(x)=sum_idfrac{c_i(1-x)}{1-ix}=c_1+sum_{i
ot=1}dfrac{c_i(1-x)}{1-ix}
]
所以此时(f(1)=c_1)
[f'(x)=sum_{i
ot=1}dfrac{c_i(ix-1)+ic_i(1-x)}{(1-ix)^2} \
f'(1)=sum_{i
ot=1}dfrac{c_i(i-1)}{(1-i)^2}=sum_{i
ot=1}dfrac{c_i}{i-1}
]
(g(1),g'(1))也类似计算,即可得到答案.
Code
#include <cstdio>
#include <cstring>
#include <iostream>
#define debug(...) fprintf(stderr,__VA_ARGS__)
#define inver(a) power(a,mod-2)
using namespace std;
inline char gc() {
// return getchar();
static char buf[100000],*l=buf,*r=buf;
return l==r&&(r=(l=buf)+fread(buf,1,100000,stdin),l==r)?EOF:*l++;
}
template<class T> void rd(T &x) {
x=0; int f=1,ch=gc();
while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=gc();}
while(ch>='0'&&ch<='9'){x=x*10-'0'+ch;ch=gc();}
x*=f;
}
typedef long long ll;
const int mod=998244353,r2=(mod+1)>>1;
const int maxn=100+5,maxP=1e5+50;
int n,P,s[maxn],p[maxn];
int c[2][maxP],d[2][maxP];
inline int sub(int x) {return x<0?x+mod:x;}
ll power(ll x,ll y) {
ll re=1;
while(y) {
if(y&1) re=re*x%mod;
x=x*x%mod;
y>>=1;
}
return re;
}
inline int sqr(int x) {return (ll)x*x%mod;}
inline void upd(int *a,int *b,int v,int w) {
for(int i=0;i<=(P<<1);++i) if(b[i]) {
a[i+w]=(a[i+w]+(ll)v*b[i])%mod;
}
}
int main() {
rd(n);
for(int i=1;i<=n;++i) rd(s[i]);
for(int i=1;i<=n;++i) rd(p[i]),P+=p[i];
int cur=0;
c[cur][P]=d[cur][P]=1;
for(int i=1;i<=n;++i) {
cur^=1;
memset(c[cur],0,sizeof(c[cur])),memset(d[cur],0,sizeof(d[cur]));
upd(c[cur],c[cur^1],r2,p[i]),upd(c[cur],c[cur^1],(ll)r2*(s[i]==1?mod-1:1)%mod,-p[i]);
upd(d[cur],d[cur^1],r2,p[i]),upd(d[cur],d[cur^1],r2,-p[i]);
}
int an=0,c1=c[cur][P<<1],d1=d[cur][P<<1],t=inver(P);
for(int i=-P;i<P;++i) {
an=(an+inver(sub((ll)i*t%mod-1))*sub((ll)c[cur][i+P]*d1%mod-(ll)c1*d[cur][i+P]%mod))%mod;
}
an=(ll)an*sqr(inver(d1))%mod;
printf("%d
",an);
return 0;
}