• 题解 P5591 【小猪佩奇学数学】


    题意

    \[\sum_{i=0}^n \binom{n}{i}p^i\left\lfloor\frac{i}{k}\right\rfloor \pmod{998244353}\]

    \(1 \leq n,p < 998244353,\ k \in \{2^{w} \mid 0 \leq w \leq 20\}\)

    题解

    首先知道一个结论 \([n \mid k]=\frac{1}{n}\sum_{i=0}^{n-1}\omega_n^{ik}\)(单位根反演),下面将用这个柿子乱搞。

    \[\sum_{i=0}^n \binom{n}{i}p^i\left\lfloor\frac{i}{k}\right\rfloor\]

    \[=\sum_{i=0}^n \binom{n}{i}p^i\,\frac{i-i\bmod k}{k}\]

    \[=\frac{1}{k}\sum_{i=0}^n \binom{n}{i}p^i\,(i-i\bmod k)\]

    \[=\frac{1}{k}\left(\sum_{i=0}^n \binom{n}{i}p^i\,i-\sum_{i=0}^n \binom{n}{i}p^i\,(i\bmod k)\right)\]

    目标求出里面那坨。把式子拆成两部分。

    \(\sum_{i=0}^n \binom{n}{i}p^i\,i\)

    首先不难发现

    \[\binom{n}{i}i=\frac{n!}{i!(n-i)!}\,i=\frac{n(n-1)!}{(i-1)!(n-i)!}=\binom{n-1}{i-1}n\]

    于是我们将其代进去。不过需要注意 \(i=0\) 时可能会出现负数,拎出来特判发现是 \(0\)。

    \[\sum_{i=1}^n \binom{n-1}{i-1}np^i\]

    用 \(i+1\) 替换 \(i\):

    \[np\sum_{i=0}^{n-1} \binom{n-1}{i}p^{i}\]

    然后二项式定理就十分显然了。

    \[np(p+1)^{n-1}\]

    \(\sum_{i=0}^n \binom{n}{i}p^i\,(i\bmod k)\)

    这部分就是复习白兔之舞了。

    \[\sum_{i=0}^n\sum_{t=0}^{k-1}[i\bmod k=t]\binom{n}{i}p^i\,t\]

    \[\sum_{i=0}^n\sum_{t=0}^{k-1}[k\mid(i-t)]\binom{n}{i}p^i\,t\]

    把一开始的公式套进去

    \[\sum_{i=0}^n\sum_{t=0}^{k-1}\frac{1}{k}\sum_{j=0}^{k-1}\omega_k^{j(i-t)}\binom{n}{i}p^i\,t\]

    \[\frac{1}{k}\sum_{i=0}^n\sum_{t=0}^{k-1}\sum_{j=0}^{k-1}\omega_k^{ij}\omega_k^{-tj}\binom{n}{i}p^i\,t\]

    \[\frac{1}{k}\sum_{j=0}^{k-1}\sum_{t=0}^{k-1}t\,\omega_k^{-tj}\sum_{i=0}^n \binom{n}{i}\omega_k^{ij}p^i\]

    后面那一串有点意思。

    \[\frac{1}{k}\sum_{j=0}^{k-1}\sum_{t=0}^{k-1}t\,\omega_k^{-tj}\sum_{i=0}^n \binom{n}{i}(\omega_k^jp)^i\]

    然后就把讨厌的循环(n)次弄没了

    \[\frac{1}{k}\sum_{j=0}^{k-1}\sum_{t=0}^{k-1}t\,\omega_k^{-tj}(\omega_k^jp+1)^n\]

    所以只需要对于 \(j\in[0,k)\) 求出后面一串的值就行了。这里用 Bluestein's Algorithm,\(ij=\binom{i+j}{2}-\binom{i}{2}-\binom{j}{2}\)

    \[\frac{1}{k}\sum_{j=0}^{k-1}\sum_{t=0}^{k-1}t\,\omega_k^{-\binom{t+j}{2}+\binom{t}{2}+\binom{j}{2}}(\omega_k^jp+1)^n\]

    整理一下系数。

    \[\frac{1}{k}\sum_{j=0}^{k-1}\omega_k^{\binom{j}{2}}(\omega_k^jp+1)^n\sum_{t=0}^{k-1}t\,\omega_k^{\binom{t}{2}}\times\omega_k^{-\binom{t+j}{2}}\]

    里面随便卷卷就好了。记 \(c_{k+j}=\text{后面一串}\),\(a_{k-i}=i\,\omega_k^{\binom{i}{2}},\ b_{i}=\omega_k^{-\binom{i}{2}}\),有:

    \[c_{k+j}=\sum_{t=0}^{k-1}a_{k-t}\,b_{t+j}\]

    卷积显而易见。

    再带回去。

    \[\frac{1}{k}\sum_{j=0}^{k-1}\omega_k^{\binom{j}{2}}(\omega_k^jp+1)^n\,c_{k+j}\]

    代码

    #include<bits/stdc++.h>
    // Buffered integer reader over stdin.
    namespace in{
        // `buf` holds the current chunk of input; [p1, p2) is the part not yet consumed.
        char buf[1<<21],*p1=buf,*p2=buf;

        // Returns the next input byte as a non-negative int, or EOF once stdin is exhausted.
        // Refills `buf` from stdin whenever the unread range is empty.
        inline int getc(){
            if(p1==p2){
                p2=(p1=buf)+fread(buf,1,1<<21,stdin);
                if(p1==p2)return EOF;
            }
            // Go through unsigned char so bytes >= 0x80 can neither collide with EOF
            // nor be passed to isdigit() as negative values.
            return static_cast<unsigned char>(*p1++);
        }

        // Parses one decimal integer into `t`, skipping any leading non-digit bytes.
        // A '-' seen among the skipped bytes negates the result.
        // If the input ends before a digit is found, `t` is left at 0 instead of looping forever.
        template <typename T>inline void read(T& t){
            t=0;
            bool negative=false;
            int ch=getc();
            while(ch!=EOF&&!isdigit(ch)){
                if(ch=='-')negative=true;
                ch=getc();
            }
            while(ch!=EOF&&isdigit(ch)){
                t=t*10+(ch-48);
                ch=getc();
            }
            if(negative)t=-t;
        }

        // Reads several integers in argument order.
        template <typename T,typename... Args> inline void read(T& t, Args&... args){read(t);read(args...);}
    }
    // Buffered integer writer over stdout.
    namespace out{
        // `buffer[0..p1]` holds pending output; p1 == -1 means empty, p2 is the last valid index.
        char buffer[1<<21];int p1=-1;const int p2 = (1<<21)-1;

        // Writes all pending bytes to stdout and empties the buffer.
        inline void flush(){fwrite(buffer,1,p1+1,stdout),p1=-1;}

        // Appends one byte, flushing first if the buffer is full.
        inline void putc(const char &x) {if(p1==p2)flush();buffer[++p1]=x;}

        // Appends the decimal representation of the integer `x` (with a leading '-' if negative).
        template <typename T>void write(T x) {
            // Three decimal digits per byte over-estimates the digit count of any integer type,
            // so 64-bit values fit (the former fixed 15-byte scratch buffer overflowed for them).
            char digits[3*sizeof(T)+1];
            int len=0;
            const bool negative=x<0;
            do{
                // Take the digit before dividing; for negative x the remainder is <= 0,
                // which also handles the most negative value without negating x itself.
                const int d=static_cast<int>(x%10);
                digits[len++]=static_cast<char>((d<0?-d:d)+48);
                x/=10;
            }while(x);
            if(negative)putc('-');
            while(len>0)putc(digits[--len]);
        }
    }
    using namespace std;
    // Integer residue modulo the compile-time constant `mod`.
    // The stored value `x` is taken as given: neither the constructor nor operator=(int)
    // reduces its argument. The additive operators perform a single conditional
    // correction by `mod`; the multiplicative ones reduce with `%`.
    template<const int mod>
    struct modint{
        int x;

        // Declared with the injected class name; the former `modint<mod>(...)` spelling
        // (a template-id as constructor name) is rejected by C++20 compilers.
        modint(int o=0){x=o;}
        modint &operator = (int o){return x=o,*this;}

        modint &operator +=(modint o){return x=x+o.x>=mod?x+o.x-mod:x+o.x,*this;}
        modint &operator -=(modint o){return x=x-o.x<0?x-o.x+mod:x-o.x,*this;}
        modint &operator *=(modint o){return x=1ll*x*o.x%mod,*this;}

        // Exponentiation by squaring: replaces *this with (*this)^b.
        modint &operator ^=(int b){
            modint base=*this,acc=1;
            for(;b;b>>=1,base*=base)if(b&1)acc*=base;
            return x=acc.x,*this;
        }

        // Division multiplies by o^(mod-2).
        modint &operator /=(modint o){return *this *= (o^=mod-2);}

        modint &operator +=(int o){return x=x+o>=mod?x+o-mod:x+o,*this;}
        modint &operator -=(int o){return x=x-o<0?x-o+mod:x-o,*this;}
        modint &operator *=(int o){return x=1ll*x*o%mod,*this;}
        modint &operator /=(int o){
            modint inv(o);
            inv^=mod-2;
            return *this *= inv;
        }

        // Binary forms forward to the compound assignments; `I` may be int or modint.
        template<class I>friend modint operator +(modint a,I b){return a+=b;}
        template<class I>friend modint operator -(modint a,I b){return a-=b;}
        template<class I>friend modint operator *(modint a,I b){return a*=b;}
        template<class I>friend modint operator /(modint a,I b){return a/=b;}
        // NOTE: `^` binds more loosely than `*` and `+`; parenthesise powers at call sites.
        friend modint operator ^(modint a,int b){return a^=b;}
        friend bool operator ==(modint a,int b){return a.x==b;}
        friend bool operator !=(modint a,int b){return a.x!=b;}

        // Const-qualified so they can be applied to const objects.
        bool operator ! () const {return !x;}
        modint operator - () const {return x?mod-x:0;}

        // Adds one in place and returns *this (not the previous value).
        modint &operator++(int){return *this+=1;}
    };
    // Capacity of the global tables rev[] and solve::w[].
    // run() multiplies polynomials of about k and 2k coefficients; with k up to 2^20
    // (the problem's limit) the transform length is 2^22, and init() fills that many
    // entries of rev[]. The former value 4e6+5 was smaller than 2^22 = 4194304.
    const int N=(1<<22)+5;
    
    // Prime modulus required by the problem statement.
    const int mod=998244353;
    // GG is the base from which ntt() derives its roots of unity, Ginv its modular inverse.
    // I is not referenced in the visible code (presumably a square root of -1 mod `mod` — unverified).
    const modint<mod> GG=3,Ginv=modint<mod>(1)/3,I=86583718;
    struct poly{
    	vector<modint<mod>>a;
    	modint<mod>&operator[](int i){return a[i];}
    	int size(){return a.size();}
    	void resize(int n){a.resize(n);}
    	void reverse(){std::reverse(a.begin(),a.end());}
    };
    // Index permutation table: filled by init() and read by ntt() to reorder its input.
    int rev[N];
    // Returns the constant polynomial 1 (a single coefficient equal to 1).
    inline poly one(){
        poly unit;
        unit.a.emplace_back(1);
        return unit;
    }
    // Returns the smallest exponent e >= 0 such that 2^e >= n.
    inline int ext(int n){
        int bits=0;
        for(;(1<<bits)<n;++bits){}
        return bits;
    }
    // Fills rev[0 .. 2^k) so that rev[i] is i with its k low bits reversed.
    // Each entry is built from rev[i>>1], so the table is filled in increasing order.
    inline void init(int k){
        if(k==0){
            // Length-1 transform: the general formula would shift by k-1 = -1, which is undefined.
            rev[0]=0;
            return;
        }
        const int n=1<<k;
        for(int i=0;i<n;i++)rev[i]=(rev[i>>1]>>1)|((i&1)<<(k-1));
    }
    // In-place number-theoretic transform of the first 2^k coefficients of `a`.
    // Requires init(k) to have filled rev[] beforehand.
    // typ > 0 uses roots derived from GG; otherwise roots derived from Ginv.
    // typ < 0 additionally multiplies every coefficient by the inverse of 2^k.
    inline void ntt(poly&a,int k,int typ){
        const int len=1<<k;

        // Reorder the input according to rev[]; each pair is swapped once.
        for(int i=0;i<len;++i){
            if(i<rev[i])std::swap(a[i],a[rev[i]]);
        }

        // Butterfly passes over blocks of doubling size.
        for(int half=1;half<len;half<<=1){
            const int span=half<<1;
            const modint<mod> root=(typ>0?GG:Ginv)^((mod-1)/span);
            for(int base=0;base<len;base+=span){
                modint<mod> twiddle=1;
                for(int off=0;off<half;++off,twiddle=twiddle*root){
                    const modint<mod> lo=a[base+off];
                    const modint<mod> hi=twiddle*a[base+off+half];
                    a[base+off]=lo+hi;
                    a[base+off+half]=lo-hi;
                }
            }
        }

        if(typ<0){
            const modint<mod> scale=modint<mod>(1)/len;
            for(int i=0;i<len;++i)a[i]*=scale;
        }
    }
    // Polynomial product via NTT. Both operands are taken by value and padded to a
    // power-of-two length; the result has a.size()+b.size()-1 coefficients.
    // Calls init(), so the global rev[] table is overwritten.
    inline poly operator*(poly a,poly b){
        const int outLen=a.size()+b.size()-1;
        const int lg=ext(outLen);
        const int full=1<<lg;

        a.resize(full);
        b.resize(full);
        init(lg);

        ntt(a,lg,1);
        ntt(b,lg,1);
        for(int i=0;i<full;++i){
            a[i]*=b[i];
        }
        ntt(a,lg,-1);

        a.resize(outLen);
        return a;
    }
    // Shorthand for residues modulo `mod`.
    typedef modint<mod> mint;
    // Problem inputs; main() overwrites these initial values by reading n, p, k.
    int n=3;
    int p=3;
    int k=2;
    // Not referenced anywhere in the visible code.
    static mint fac[20];
    // Computes sum_{i=0..n} C(n,i) * p^i * floor(i/k) modulo `mod` from the globals n, p, k
    // and writes the result followed by a newline through `out`.
    namespace solve{
        // w[i] = omega^i for i in [0, k), where omega = 3^((mod-1)/k).
        mint w[N];
        // C2(i) = i*(i-1)/2, evaluated in 64-bit arithmetic.
        #define C2(i) (1ll*(i)*((i)-1)/2)
        void run(){
            w[0]=1;
            w[1]=mint(3)^((mod-1)/k);
            for(int i=2;i<k;i++)w[i]=w[i-1]*w[1];

            // a[k-i] = i * omega^{C2(i)} for i in [0, k)  -> indices 1..k are written,
            // so `a` needs k+1 entries (sizing it to k made the i = 0 store out of bounds).
            // b[i]   = omega^{-C2(i)}    for i in [0, 2k).
            poly a,b,c;
            a.resize(k+1);
            b.resize(2*k);
            for(int i=0;i<k;i++)a[k-i]=w[C2(i)%k]*i;
            for(int i=0;i<2*k;i++)b[i]=w[((-C2(i))%k+k)%k];

            // c[k+j] = sum_t a[k-t] * b[t+j].
            c=a*b;

            // ans = (1/k) * sum_j omega^{C2(j)} * (omega^j * p + 1)^n * c[k+j],
            // i.e. sum_i C(n,i) p^i (i mod k).
            mint ans=0;
            for(int j=0;j<k;j++)
                ans+=w[C2(j)%k]*((w[j]*p+1)^n)*c[k+j];
            ans=ans/k;

            // Result = (n*p*(p+1)^(n-1) - ans) / k.
            out::write((((mint(p+1)^(n-1))*n*p-ans)/k).x);
            out::putc('\n');
        }
    }
    // Entry point: reads n, p, k from stdin, runs the solver, then flushes buffered output.
    int main(){
        in::read(n,p,k);
        solve::run();
        out::flush();
        return 0;
    }
    

    vector比较慢要O2才能过

  • 相关阅读:
    [IDEA]高级用法
    [TongEsb]Windows入门教程
    [MAC] JDK版本切换
    [下划线]换行自适应稿纸完美方案
    [IDEA] Jrebel热部署环境搭建Mac
    Python规范:用用assert
    Python规范:提高可读性
    Python规范:代码规范要注意
    计算机网络自顶向下方法--计算机网络中的安全.2
    Python进阶:程序界的垃圾分类回收
  • 原文地址:https://www.cnblogs.com/juruo-cjl/p/14253012.html
Copyright © 2020-2023  润新知