• [清华集训2017]生成树计数


    代码:

    #include <bits/stdc++.h>
    using namespace std;
    #define rep(i,h,t) for (int i=h;i<=t;i++)
    #define dep(i,t,h) for (int i=t;i>=h;i--)
    #define ll long long
    #define me(x) memset(x,0,sizeof(x))
    #define IL inline
    #define rint register int
    inline ll rd(){
        ll x=0;char c=getchar();bool f=0;
        while(!isdigit(c)){if(c=='-')f=1;c=getchar();}
        while(isdigit(c)){x=(x<<1)+(x<<3)+(c^48);c=getchar();}
        return f?-x:x;
    }
    char ss[1<<24],*A=ss,*B=ss;
    IL char gc()
    {
        return A==B&&(B=(A=ss)+fread(ss,1,1<<24,stdin),A==B)?EOF:*A++;
    }
    template<class T>void maxa(T &x,T y)
    {
        if (y>x) x=y;
    }
    template<class T>void mina(T &x,T y)
    {
        if (y<x) x=y;
    }
    template<class T>void read(T &x)
    {
        int f=1,c; while (c=gc(),c<48||c>57) if (c=='-') f=-1; x=(c^48);
        while(c=gc(),c>47&&c<58) x=x*10+(c^48); x*=f;
    }
    const int mo=998244353;
    ll fsp(int x,int y)
    {
        if (y==1) return x;
        ll ans=fsp(x,y/2);
        ans=ans*ans%mo;
        if (y%2==1) ans=ans*x%mo;
        return ans;
    }
    struct cp {
        ll x,y;
        cp operator +(cp B)
        {
            return (cp){x+B.x,y+B.y};
        }
        cp operator -(cp B)
        {
            return (cp){x-B.x,y-B.y};
        }
        ll operator *(cp B)
        {
            return x*B.y-y*B.x;
        }
        int half() { return y < 0 || (y == 0 && x < 0); }
    };
    struct re{
        int a,b,c;
    };
    const int N=6e5;
    const int G=3;
    int f[N],g[N],n;
    struct fft{
      int l,n,m;
      int r[N],a[N],b[N],w[N],inv[N];
      int C[N],D[N];
      fft()
      {
        inv[0]=inv[1]=1;
        rep(i,2,N-1) inv[i]=(1ll*inv[mo%i]*(mo-(mo/i)))%mo; 
      }
      IL void ntt_init()
      {
        l=0; for (n=1;n<=m;n<<=1) l++;
        for (int i=0;i<n;i++) r[i]=(r[i/2]/2)|((i&1)<<(l-1)); 
      }
      IL void clear()
      {
          rep(i,0,n) a[i]=b[i]=0;
      }
      void ntt(int *a,int o)
      {
          for (int i=0;i<n;i++) if (i>r[i]) swap(a[i],a[r[i]]);
          for (int i=1;i<n;i<<=1)
          {
              int wn=fsp(G,(mo-1)/(i*2)); w[0]=1;
              rep(j,1,i-1) w[j]=(1ll*w[j-1]*wn)%mo;
              for (int j=0;j<n;j+=(i*2))
                for (int k=0;k<i;k++)
                {
                    int x=a[j+k],y=1ll*a[i+j+k]*w[k]%mo;
                //    if (x<0||y<0) cerr<<x<<" "<<y<<endl; 
                    a[j+k]=x+y>mo?x+y-mo:x+y; 
                a[i+j+k]=x-y>=0?x-y:x-y+mo;
            //     a[j+k]=(x+y)%mo;
            //     a[i+j+k]=(x-y)%mo;
                }
        }
        if (o==-1)
        {
            reverse(&a[1],&a[n]);
            for (int i=0,inv=fsp(n,mo-2);i<n;i++)
               a[i]=1ll*a[i]*inv%mo;
        }
      }
      IL void getcj(int *C,int len)
      {
      //    for (int i=0;i<len;i++) a[i]=(A[i]%mo+mo)%mo,b[i]=(B[i]%mo+mo)%mo;      
          m=len*2; ntt_init();
          rep(i,0,n) a[i]=(a[i]+mo)%mo,b[i]=(b[i]+mo)%mo;
          ntt(a,1); ntt(b,1);
          for (int i=0;i<n;i++) a[i]=1ll*a[i]*b[i]%mo;
          ntt(a,-1);
          for (int i=0;i<n;i++) C[i]=a[i];
          clear();
      }
      IL void getcj(int *A,int *B,int len)
      {
        m=len*2; ntt_init();
        for (int i=0;i<len;i++) a[i]=(A[i]%mo+mo)%mo,b[i]=(B[i]%mo+mo)%mo;
        ntt(a,1); ntt(b,1);
        for(int i=0;i<n;i++) a[i]=1ll*a[i]*b[i]%mo;
        ntt(a,-1);
        for (int i=0;i<len;i++) B[i]=a[i];
        clear();
      }
      IL void getinv(int *A,int *B,int len)
      {
        if (len==1) { B[0]=fsp(A[0],mo-2); return; }
        getinv(A,B,(len+1)>>1);
        m=len*2; ntt_init();
        for (int i=0;i<len;i++) a[i]=A[i],b[i]=B[i];
        ntt(a,1); ntt(b,1);
        for (int i=0;i<n;i++) a[i]=1ll*a[i]*b[i]%mo*b[i]%mo;
        ntt(a,-1);
        for (int i=0;i<len;i++) B[i]=((2*B[i]-a[i])%mo+mo)%mo; 
        clear();
      }
      IL void getsqrt(int *A,int *B,int len)
      {
        int inv2=fsp(2,mo-2);
        if (len==1) {B[0]=sqrt(A[0]); return;}
        getsqrt(A,B,(len+1)>>1);
        int C[N]={};
        getinv(B,C,len);
        getcj(A,C,len);
        for (int i=0;i<len;i++) B[i]=1ll*(C[i]+B[i])%mo*inv2%mo;
      }
      IL void getDao(int *a,int *b,int len)
      {
        for (int i=1;i<len;i++) b[i-1]=1ll*i*a[i]%mo;
        b[len-1]=0;
      }
      IL void getjf(int *a,int *b,int len)
      {
        for (int i=0;i<len;i++) b[i+1]=1ll*a[i]*inv[i+1]%mo;
        b[0]=0;
      }
      IL void getln(int *A,int *B,int len)
      {
      //  me(C); me(D);
        getDao(A,C,len);
        getinv(A,D,len);
        getcj(C,D,len);
        getjf(D,B,len);
        rep(i,0,len) C[i]=0,D[i]=0;
      }
      IL void getexp(int *A,int *B,int len)
      {
        if (len==1) {B[0]=1; return;}
        getexp(A,B,(len+1)>>1);
        int C[N];
        getln(B,C,len);
        for(int i=0;i<len;i++) C[i]=((-C[i]+A[i])%mo+mo)%mo;
        C[0]=(C[0]+1)%mo;
        getcj(C,B,len);
      }
    }F;
    /*
    
    f[i]=sum f[j]*g[i-j]; 
    
    */
    /*
    int now[N];
    void solve(int h,int t)
    {
      if (h>=t) return; 
      if (t-h<=32)
      {
          rep(i,h,t)
            rep(j,h,i)
              f[i]=(f[i]+1ll*f[j]*g[i-j])%mo;
        return;
      }
      int mid=(h+t)/2;
      solve(h,mid);
      rep(i,h,mid) F.a[i-h]=f[i];
      rep(i,1,t-h) F.b[i]=g[i];
      F.getcj(now,(t-h+1)+(mid-h+1));
      rep(i,mid+1,t) f[i]=(f[i]+now[i-h])%mo;
      solve(mid+1,t);
    }
    */
    int sum[N],now[N],a[N],b[N],c[N],d[N],e[N];
    ll jc[N],jc2[N];
    /*
    prod (1+a[i]x) 
    */ 
    void solve(int h,int t,int *a)
    {
        if (h==t) return;
        int mid=(h+t)/2;
        solve(h,mid,a); solve(mid+1,t,a);
        rep(i,h,mid) F.a[i-h+1]=a[i];
        rep(i,mid+1,t) F.b[i-mid]=a[i];
        F.a[0]=F.b[0]=1;
        F.getcj(now,(mid-h+2));
        rep(i,h,t) a[i]=now[i-h+1];
    }
    int sum3[N],sum4[N];
    int main()
    {
       ios::sync_with_stdio(false);
       int n,m;
       cin>>n>>m;
       ll ans=1;
       rep(i,1,n) cin>>a[i],ans=ans*a[i]%mo;
       rep(i,1,n) sum[i]=(-a[i]+mo)%mo;
       solve(1,n,sum);
       sum[0]=1;
       F.getln(sum,sum,n+2);
       F.getDao(sum,sum,n+2);
       dep(i,n,1) sum[i]=((mo-sum[i-1])%mo+mo)%mo;
       sum[0]=n;
       jc[0]=jc2[0]=1;
       rep(i,1,n) jc[i]=jc[i-1]*i%mo;
       jc2[n]=fsp(jc[n],mo-2);
       dep(i,n-1,1) jc2[i]=jc2[i+1]*(i+1)%mo;
       rep(i,0,n) a[i]=b[i]=c[i]=0;
       rep(i,0,n-1) a[i]=1ll*fsp(i+1,2*m)*jc2[i]%mo;
       rep(i,0,n-1) c[i]=b[i]=1ll*fsp(i+1,m)*jc2[i]%mo;
       F.getln(c,e,n+1);
       rep(i,0,n) e[i]=1ll*e[i]*sum[i]%mo;
       F.getexp(e,c,n+1);
       F.getinv(b,d,n+1);
       F.getcj(a,d,n+1);
       rep(i,0,n) d[i]=1ll*d[i]*sum[i]%mo;
       F.getcj(c,d,n+1);
       ans=ans*d[n-2]%mo*jc[n-2]%mo;
       cout<<ans<<endl; 
       return 0;
    }
    View Code
  • 相关阅读:
    Spring----MyBatis整合
    VueRouter案列
    Vue-Router
    axios用法
    Fetch的使用
    Promise用法
    组件之间传值
    局部组件注册方式
    学习组件与模板
    如何实现new,call,apply,bind的底层原理。
  • 原文地址:https://www.cnblogs.com/yinwuxiao/p/15143666.html
Copyright © 2020-2023  润新知