• 玩游戏——生成函数


    题面

      洛谷P4705

    解析

      对于每个 $k$,答案显然是 $\frac{\sum_{i=1}^n\sum_{j=1}^m (a_i+b_j)^k}{nm}$。

      因此只需对每个 $k$ 求出 $\sum_{i=1}^n\sum_{j=1}^m (a_i+b_j)^k$ 即可。

      暴力展开:$$\begin{align*}\sum_{i=1}^n\sum_{j=1}^m (a_i+b_j)^k&=\sum_{i=1}^n\sum_{j=1}^m\sum_{p=0}^k\binom{k}{p}a_i^pb_j^{k-p}\\&=k!\sum_{p=0}^k\sum_{i=1}^n\frac{a_i^p}{p!}\sum_{j=1}^m\frac{b_j^{k-p}}{(k-p)!}\\&=k!\sum_{p=0}^k\frac{\sum_{i=1}^na_i^p}{p!}\cdot\frac{\sum_{j=1}^mb_j^{k-p}}{(k-p)!}\end{align*}$$

      现在的问题就是对任意 $0\leqslant p\leqslant k$ 求出 $\sum_{i=1}^na_i^p$(求 $\sum_{j=1}^mb_j^{k-p}$ 同理)。

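      这个比较常见,我在生成函数小结里有写,这里直接给出结论(由 $\ln(1-a_ix)=-\sum_{j\geqslant1}\frac{a_i^jx^j}{j}$,求导后乘以 $-x$,再补上常数项 $n$ 即得):$$\begin{align*}F(x)&=\sum_{j=0}^{\infty}\sum_{i=1}^na_i^jx^j\\&=n-x\left(\ln\prod_{i=1}^n(1-a_ix)\right)'\end{align*}$$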

      $\prod_{i=1}^n(1-a_ix)$ 可以用分治 NTT 求出。

      对 $a$、$b$ 分别求出它们的 $F(x)$,将第 $i$ 项系数除以 $i!$ 后卷积起来;卷积结果的第 $k$ 项系数乘以 $k!$ 再除以 $nm$,就是 $k$ 对应的答案。

      总时间复杂度为 $O(N\log^2N)$。

     代码:

    #include<cstdio>
    #include<iostream>
    #include<algorithm>
    #include<cstring>
    #include<vector>
    // Children of segment-tree node x in solve(); these macros expand in terms of a
    // local variable named `x` (the current node index).
    #define ls (x << 1)
    #define rs ((x << 1) | 1)
    using namespace std;
    typedef long long ll;
    // maxn: capacity used to size the global arrays.
    // mod: NTT-friendly prime 998244353; g: generator whose powers give the NTT roots of unity.
    const int maxn = 200005, mod = 998244353, g = 3;
    
    // Read one decimal integer (optionally negative) from stdin.
    // Any non-digit characters before the number are skipped; a '-' among them makes
    // the result negative. The character right after the last digit is consumed.
    // `c` is an int so that EOF is distinguishable from every byte: if the input ends
    // before any digit appears, the function returns 0 instead of spinning forever.
    inline int read()
    {
        int c = getchar(), f = 1;
        while(c != EOF && (c < '0' || c > '9'))
        {
            if(c == '-')
                f = -1;
            c = getchar();
        }
        int ret = 0;
        while(c != EOF && c >= '0' && c <= '9')
        {
            ret = ret * 10 + (c - '0');
            c = getchar();
        }
        return ret * f;
    }
    
    // Modular addition: (x + y) mod `mod`, for operands already reduced into [0, mod).
    int add(int x, int y)
    {
        int sum = x + y;
        if(sum >= mod)
            sum -= mod;
        return sum;
    }
    
    // Modular subtraction: (x - y) mod `mod`, for operands already reduced into [0, mod).
    int rdc(int x, int y)
    {
        int diff = x - y;
        if(diff < 0)
            diff += mod;
        return diff;
    }
    
    // Binary exponentiation: returns x^y modulo `mod`.
    // Callers pass 0 <= x < mod, so every intermediate product fits in a long long.
    ll qpow(ll x, int y)
    {
        ll result = 1;
        for(ll base = x; y != 0; y >>= 1)
        {
            if(y & 1)
                result = result * base % mod;
            base = base * base % mod;
        }
        return result;
    }
    
    // n, m: lengths of the 1-indexed input arrays a[1..n] and b[1..m].
    // lim, bit, rev: current NTT length (a power of two), its log2, and the
    // bit-reversal table; all three are set by NTT_init.
    int n, m, a[maxn], b[maxn], lim, bit, rev[maxn<<1];
    // fac / fnv: factorials and inverse factorials modulo mod, filled by init().
    ll fac[maxn], fnv[maxn];
    // ginv: modular inverse of g, used by the inverse NTT.
    // c, d, iv: scratch buffers for solve/get_inv/get_ln, which reset them to zero after use.
    // t: holds the product polynomial fed to get_ln; A, B: series for a and b, later their product.
    ll ginv, c[maxn<<1], d[maxn<<1], A[maxn<<1], B[maxn<<1], t[maxn<<1], iv[maxn<<1];
    
    void init()
    {
        ginv = qpow(g, mod - 2);
        fac[0] = 1;
        for(int i = 1; i <= 100001; ++i)
            fac[i] = fac[i-1] * i % mod;
        fnv[100001] = qpow(fac[100001], mod - 2);
        for(int i = 100000; i >= 0; --i)
            fnv[i] = fnv[i+1] * (i + 1) % mod;
    }
    
    // Prepare an NTT of length lim = smallest power of two strictly greater than x.
    // Sets bit = log2(lim) and fills rev[1..lim-1] with the bit-reversal permutation.
    void NTT_init(int x)
    {
        for(lim = 1, bit = 0; lim <= x; lim <<= 1)
            ++bit;
        for(int idx = 1; idx < lim; ++idx)
        {
            int lowBitMovedUp = (idx & 1) << (bit - 1);
            rev[idx] = lowBitMovedUp | (rev[idx >> 1] >> 1);
        }
    }
    
    // In-place iterative number-theoretic transform of x[0..lim).
    // lim and rev must have been prepared by NTT_init; entries are assumed to lie in [0, mod).
    // y == 1: forward transform using powers of g.
    // y == -1: inverse transform using powers of ginv, then every entry is multiplied by 1/lim.
    void NTT(ll *x, int y)
    {
        // Permute into bit-reversed order so the butterflies can run bottom-up.
        for(int i = 1; i < lim; ++i)
            if(i < rev[i])
                swap(x[i], x[rev[i]]);
        ll wn, w, u, v;
        // i is half the length of the blocks being merged at this level.
        for(int i = 1; i < lim; i <<= 1)
        {
            // Root of unity of order 2i (its inverse for the inverse transform).
            wn = qpow((y == 1)? g: ginv, (mod - 1) / (i << 1));
            for(int j = 0; j < lim; j += (i << 1))
            {
                w = 1;
                for(int k = 0; k < i; ++k)
                {
                    // Butterfly on x[j+k] and the twiddled x[j+k+i].
                    u = x[j+k];
                    v = x[j+k+i] * w % mod;
                    x[j+k] = add(u, v);
                    x[j+k+i] = rdc(u, v);
                    w = w * wn % mod;
                }
            }
        }
        if(y == -1)
        {
            // Inverse transform: scale by the inverse of the length.
            ll linv = qpow(lim, mod - 2);
            for(int i = 0; i < lim; ++i)
                x[i] = x[i] * linv % mod;
        }
    }
    
    // Polynomial inverse by Newton iteration: x <- y^{-1} mod X^len.
    // Requires y[0] to be invertible mod p. Assumes x (beyond index 0) and the global
    // scratch buffer c are zero on entry; c is left zeroed and x[len..lim) is cleared on exit.
    void get_inv(ll *x, ll *y, int len)
    {
        // Base case: the constant term is the modular inverse of y[0].
        if(len == 1)
        {
            x[0] = qpow(y[0], mod - 2);
            return ;
        }
        // First obtain the inverse modulo X^ceil(len/2).
        get_inv(x, y, (len + 1) >> 1);
        for(int i = 0; i < len; ++i)
            c[i] = y[i];
        // lim > 2*len, enough for c * x * x (degree <= 2*len - 2) without cyclic wrap-around.
        NTT_init(len << 1);
        NTT(x, 1);
        NTT(c, 1);
        for(int i = 0; i < lim; ++i)
        {
            // Newton step evaluated pointwise: x = 2x - y * x^2.
            x[i] = rdc(add(x[i], x[i]), (c[i] * x[i] % mod) * x[i] % mod);
            c[i] = 0;
        }
        NTT(x, -1);
        // Truncate the result to mod X^len.
        for(int i = len; i < lim; ++i)
            x[i] = 0;
    }
    
    // Polynomial logarithm: x <- ln(y) mod X^len, computed as the integral of y' / y.
    // ln is only meaningful when y[0] == 1; the constant term of the result is set to 0.
    // Reads y[0..len]. Assumes x is zero beyond index len-1 on entry. Uses the global
    // buffer iv (left zeroed) and clears x[len..lim) on exit.
    void get_ln(ll *x, ll *y, int len)
    {
        // x <- y' (first len coefficients of the derivative).
        for(int i = 0; i < len; ++i)
            x[i] = y[i+1] * (i + 1) % mod;
        // iv <- 1 / y mod X^len.
        get_inv(iv, y, len);
        // lim > 2*len, so the product y' * (1/y) does not wrap around.
        NTT_init(len << 1);
        NTT(x, 1);
        NTT(iv, 1);
        for(int i = 0; i < lim; ++i)
        {
            x[i] = x[i] * iv[i] % mod;
            iv[i] = 0;
        }
        NTT(x, -1);
        // Integrate term by term: new coefficient i is x[i-1] / i (descending so inputs are unread yet).
        for(int i = len - 1; i >= 1; --i)
            x[i] = x[i-1] * qpow(i, mod - 2) % mod;
        x[0] = 0;
        for(int i = len; i < lim; ++i)
            x[i] = 0;
    }
    
    // G[x]: coefficients of the product polynomial computed by solve() for segment-tree node x.
    vector<int> G[maxn<<1];
    
    // Divide-and-conquer product over segment-tree node x covering [l, r]:
    // G[x] <- prod_{i=l}^{r} (1 - y[i] X), stored as coefficients 0..r-l+1.
    // Children are ls / rs. Uses the global buffers c and d and leaves them zeroed.
    // y[i] is presumably already in [0, mod) (rdc assumes reduced operands).
    void solve(int x, int l, int r, int *y)
    {
        G[x].clear();
        if(l == r)
        {
            // Leaf: the linear factor 1 - y[l] X, with -y[l] taken mod p.
            G[x].push_back(1);
            G[x].push_back(rdc(0, y[l]));
            return ;
        }
        int mid = (l + r) >> 1;
        solve(ls, l, mid, y);
        solve(rs, mid + 1, r, y);
        // Load the child polynomials (degrees mid-l+1 and r-mid) into the scratch buffers.
        for(int i = 0; i <= mid - l + 1; ++i)
            c[i] = G[ls][i];
        for(int i = 0; i <= r - mid; ++i)
            d[i] = G[rs][i];
        // The product has degree r-l+1; lim > r-l+1 avoids cyclic wrap-around.
        NTT_init(r - l + 1);
        NTT(c, 1);
        NTT(d, 1);
        for(int i = 0; i < lim; ++i)
        {
            c[i] = c[i] * d[i] % mod;
            d[i] = 0;
        }
        NTT(c, -1);
        // Store the product and zero c for the next caller.
        for(int i = 0; i <= r - l + 1; ++i)
        {
            G[x].push_back(c[i]);
            c[i] = 0;
        }
        for(int i = r - l + 2; i < lim; ++i)
            c[i] = 0;
    }
    
    int main()
    {
        init();
        n = read(); m = read();
        for(int i = 1; i <= n; ++i)
            a[i] = read();
        for(int i = 1; i <= m; ++i)
            b[i] = read();
        int q = read();
    
        solve(1, 1, n, a);
        for(int i = 0; i <= n; ++i)
            t[i] = G[1][i];
        get_ln(A, t, max(q, n) + 1);
        for(int i = 0; i <= max(q, n); ++i)
            A[i] = A[i+1] * (i + 1) % mod;
        for(int i = max(q, n); i >= 1; --i)
            A[i] = rdc(0, A[i-1]);
        A[0] = n;
        for(int i = 0; i <= max(q, n); ++i)
            A[i] = A[i] * fnv[i] % mod;
    
        solve(1, 1, m, b);
        memset(t, 0, sizeof(t));
        for(int i = 0; i <= m; ++i)
            t[i] = G[1][i];
        get_ln(B, t, max(q, m) + 1);
        for(int i = 0; i <= max(q, m); ++i)
            B[i] = B[i+1] * (i + 1) % mod;
        for(int i = max(q, m); i >= 1; --i)
            B[i] = rdc(0, B[i-1]);
        B[0] = m;
        for(int i = 0; i <= max(q, m); ++i)
            B[i] = B[i] * fnv[i] % mod;
    
        NTT_init(max(q, n) + max(q, m));
        NTT(A, 1);
        NTT(B, 1);
        for(int i = 0; i < lim; ++i)
            A[i] = A[i] * B[i] % mod;
        NTT(A, -1);
    
        ll mul = qpow(1LL * n * m % mod, mod - 2);
        for(int i = 1; i <= q; ++i)
            printf("%lld
    ", (A[i] * fac[i] % mod) * mul % mod);
        return 0;
    }
    View Code
  • 相关阅读:
    MySQL 一般模糊查询的几种用法
    MySQL插入中文数据报错
    BeanUtils.populate 的作用
    分分钟搞定 JSP 技术
    margin-top相对谁的问题
    常用汉字的Unicode码表
    从InputStream到String_写成函数
    Http请求和响应应用
    发布mvc报错:403.14-Forbidden Web 服务器被配置为不列出此目录的内容
    导出到excel
  • 原文地址:https://www.cnblogs.com/Joker-Yza/p/12640512.html
Copyright © 2020-2023  润新知