• [Luogu4705] 玩游戏


    Description

    给定两个长度分别为 (n)(m) 的序列 (a)(b)。要从这两个序列中分别随机一个数,设为 (a_x,b_y),定义该次游戏的 (k) 次收益为 ((a_x+b_y)^k) 。对于 (i=1,2,dots,t),求一次游戏 (i) 次收益的期望。(n,m,tleq 10^5)

    Sol

    根据期望的线性性,显然可以求每个点对的 (i) 次收益,最后再除以 (nm) 就好了。

    所以问题转化为,对于每个 (k),求:

    [sum_{i=1}^nsum_{j=1}^m (a_i+b_j)^k ]

    接下来直接推导:

    [egin{aligned} ans_k&=sum_{i=1}^nsum_{j=1}^m (a_i+b_j)^k\ &=sum_{i=1}^nsum_{j=1}^msum_{p=0}^k inom kpa_i^pb_j^{k-p}\ &=sum_{p=0}^kinom kp left(sum_{i=1}^na_i^p ight) left(sum_{j=1}^mb_j^{k-p} ight)\ &=k!cdotsum_{p=0}^k left(sum_{i=1} ^n frac{a_i^p}{p!} ight) left(sum_{j=1}^mfrac{b_j^{k-p}}{(k-p)!} ight) end{aligned} ]

    发现这是个卷积式子,现在问题变成了如何求:

    [sum_{i=1}^n a_i^p ]

    (F(x)=prodlimits_{i=1}^n(1+a_ix),G(x)=ln(F(x)))

    那么:

    [egin{aligned} G(x)&=ln(prod_{i=1}^n 1+a_ix)\ &=sum_{i=1}^n ln(1+a_ix) end{aligned} ]

    (ln(1+a_ix)) 泰勒展开:

    [egin{aligned} G(x)&=sum_{i=1}^n ln(1+a_ix)\ &= sum_{i=1}^n sum_{k=1}^infty frac{(-1)^{k+1}}{k}cdot a_i^kcdot x^k\ &= sum_{k=1}^infty frac{(-1)^{k+1}}kcdot x^kcdot left( sum_{i=1}^n a_i^k ight) end{aligned} ]

    后边那项就是我们要求的。

    总结一下,先分治( ext{NTT})求出(F(x)),再取对数求出(G(x)),然后第 (k) 项乘上一个系数就是 (sumlimits_{i=1}^n a_i^k) 了。

    Code

    #pragma GCC optimize(2)
    #include<bits/stdc++.h>
    using namespace std;
    typedef double db;
    typedef long long ll;
    typedef vector<int> vec;
    const int N=262144+5;
    const int mod=998244353;
    #define pb push_back
    
    int w[2][N],in[N];
    int fac[N],ifac[N],A[N],B[N];
    int n,m,t,a[N],b[N],c[N],d[N];
    int lim,maxn,rev[N],tmpa[N],tmpb[N];
    
    int ksm(int a,int b=mod-2,int ans=1){
        while(b){
            if(b&1) ans=1ll*ans*a%mod;
            a=1ll*a*a%mod;b>>=1;
        } return ans;
    }
    
    void ntt(int *f,int g){
        for(int i=1;i<lim;i++) if(i<rev[i]) swap(f[i],f[rev[i]]);
        for(int mid=1;mid<lim;mid<<=1){
            for(int R=mid<<1,j=0;j<lim;j+=R){
                for(int k=0;k<mid;k++){
                    int x=f[j+k],y=1ll*w[g][maxn/R*k]*f[j+k+mid]%mod;
                    f[j+k]=x+y>=mod?x+y-mod:x+y,f[j+k+mid]=x-y<0?x-y+mod:x-y;
                }
            }
        } if(g)
            for(int i=0;i<lim;i++) f[i]=1ll*f[i]*in[lim]%mod;
    }
    
    vec calc(int *a,int l,int r){
        if(l==r){vec now;now.pb(1);now.pb(a[l]);return now;}
        int mid=l+r>>1;
        vec L=calc(a,l,mid),R=calc(a,mid+1,r);
        lim=1;while(lim<=r-l+1) lim<<=1;
        for(int i=1;i<lim;i++) rev[i]=(rev[i>>1]>>1)|(i&1?lim>>1:0);
        for(int i=0;i<(int)L.size();i++) A[i]=L[i];
        for(int i=0;i<(int)R.size();i++) B[i]=R[i];
        ntt(A,0),ntt(B,0);
        for(int i=0;i<lim;i++) A[i]=1ll*A[i]*B[i]%mod;
        ntt(A,1); vec now;
        for(int i=0;i<=r-l+1;i++) now.pb(A[i]),A[i]=B[i]=0;
        for(int i=r-l+2;i<lim;i++) A[i]=B[i]=0;
        return now;
    }
    
    void solveinv(int *a,int *b,int len){
        if(len==1) return b[0]=ksm(a[0]),void();
        solveinv(a,b,len>>1); lim=len<<1;
        for(int i=1;i<lim;i++) rev[i]=(rev[i>>1]>>1)|(i&1?lim>>1:0);
        for(int i=len;i<lim;i++) A[i]=0;
        for(int i=0;i<len;i++) A[i]=a[i];
        ntt(A,0),ntt(b,0);
        for(int i=0;i<lim;i++)
            b[i]=1ll*b[i]*(2ll-1ll*A[i]*b[i]%mod+mod)%mod;
        ntt(b,1); for(int i=len;i<lim;i++) b[i]=0;
    }
    
    void ds(int *a,int *b,int n){
        for(int i=0;i<n;i++)
            b[i]=1ll*a[i+1]*(i+1)%mod;
        b[n]=0;
    }
    
    void jf(int *a,int n){
        for(int i=n;i;i--)
            a[i]=1ll*a[i-1]*in[i]%mod;
        a[0]=0;
    }
    
    void solveln(int *a,int *b,int n){
        memset(tmpa,0,sizeof tmpa);
        memset(tmpb,0,sizeof tmpb);
        lim=1;while(lim<n) lim<<=1;
        solveinv(a,tmpa,lim);
        lim=1;while(lim<n<<1) lim<<=1;
        for(int i=1;i<lim;i++) rev[i]=(rev[i>>1]>>1)|(i&1?lim>>1:0);
        ds(a,tmpb,n);
        ntt(tmpa,0),ntt(tmpb,0);
        for(int i=0;i<lim;i++) b[i]=1ll*tmpa[i]*tmpb[i]%mod;
        ntt(b,1); jf(b,n);
    }
    
    void init(int n){
        fac[0]=ifac[0]=1;
        for(int i=1;i<=n;i++) fac[i]=1ll*fac[i-1]*i%mod;
        ifac[n]=ksm(fac[n]);
        for(int i=n-1;i;i--) ifac[i]=1ll*ifac[i+1]*(i+1)%mod;
    }
    
    signed main(){
        scanf("%d%d",&n,&m);
        for(int i=1;i<=n;i++) scanf("%d",&a[i]);
        for(int i=1;i<=m;i++) scanf("%d",&b[i]);
        scanf("%d",&t);
        init(t);
        maxn=1;while(maxn<=max(t<<1,n+m-2)) maxn<<=1;
        w[0][0]=w[1][0]=1; in[1]=1;
        w[0][1]=ksm(3,(mod-1)/maxn),w[1][1]=ksm((mod+1)/3,(mod-1)/maxn);
        for(int i=2;i<=maxn;i++)
        	in[i]=ksm(i),
            w[0][i]=1ll*w[0][i-1]*w[0][1]%mod,
            w[1][i]=1ll*w[1][i-1]*w[1][1]%mod;
        vec aa=calc(a,1,n),bb=calc(b,1,m);
        for(int i=0;i<=n;i++) c[i]=aa[i];
        for(int i=0;i<=m;i++) d[i]=bb[i];
        memset(a,0,sizeof a),memset(b,0,sizeof b);
        solveln(c,a,t); a[0]=n; // 注意这里的0次项 积分给消掉了 所以要特殊赋值
        solveln(d,b,t); b[0]=m; 
        for(int i=1;i<=t;i++){
            a[i]=1ll*a[i]*i%mod;
            b[i]=1ll*b[i]*i%mod;
            if(!(i&1)) a[i]=mod-a[i],b[i]=mod-b[i];
            a[i]=1ll*a[i]*ifac[i]%mod;
            b[i]=1ll*b[i]*ifac[i]%mod;
        }
        for(int i=t+1;i<lim;i++) a[i]=b[i]=0;
        lim=maxn;
        for(int i=1;i<lim;i++) rev[i]=(rev[i>>1]>>1)|(i&1?lim>>1:0);
        ntt(a,0),ntt(b,0);
        for(int i=0;i<lim;i++) a[i]=1ll*a[i]*b[i]%mod;
        ntt(a,1);
        for(int inn=ksm(1ll*n*m%mod),i=1;i<=t;i++) 
            printf("%lld
    ",1ll*a[i]*fac[i]%mod*inn%mod);
        return 0;
    }
    
  • 相关阅读:
    cookie的设置、获取和删除封装
    原生javascript封装ajax和jsonp
    javascript模块化应用
    图解javascript this指向什么?
    学习bootstrap心得
    javascript使用两个逻辑非运算符(!!)的原因
    dubbo小教程
    JSTL与EL表达式(为空判断)
    自己整理的常用SQL Server 2005 语句、
    python基础:迭代器、生成器(yield)详细解读
  • 原文地址:https://www.cnblogs.com/YoungNeal/p/10394096.html
Copyright © 2020-2023  润新知