• 多项式取模优化线性递推总结


    多项式取模优化线性递推总结

    声明:博主已退役,这是以前的总结,如有错误望指正,如有问题不妨看看别人的博客

    线性递推

    即对于数列({a})

    已知前(k)

    且对于任意(nge k)

    [a_n=sum_{i=0}^{k-1}f_ia_{n-1-i} ]

    其中({f})是一个已知的数列

    现在要求({a})的第(n)

    暴力是(O(n*k))

    如果(n)太大就会超时

    常用的优化方法是矩阵快速幂

    复杂度(O(k^3log n))

    但如果(k)比较大也会超时

    甚至还不如暴力

    (log n)已经很优秀了

    但是(k^3)实在太慢

    注意到根据上面的式子

    ({a})所有数都可以被({a_0,a_1,...,a_{k-1}})线性表示

    考虑已知(a_n)的线性表示如何求出(a_{2n})的线性表示

    这里应用一个性质

    [a_{n}=sum_{i=0}^{k-1}b_ia_i\ ]

    [a_{n+x}=sum_{i=0}^{k-1}b_ia_{i+x} ]

    所以

    [a_{2n}=sum_{i=0}^{k-1}b_ia_{n+i}\ =sum_{i=0}^{k-1}b_isum_{j=0}^{k-1}b_ja_{i+j}\ =sum_{i=0}^{2k-2}a_isum_{j=0}^{i}b_jb_{i-j}\ (这里令b_x=0(xge k)) ]

    这样就用({a_0,a_1,...,a_{2k-2}})线性表示了(a_{2n})

    只要知道({a_k,a_{k+1},...,a_{2k-2}})的线性表示然后带入即可

    这一步倒着依次带入

    复杂度优化为(O(k^2log n))

    Shlw loves matrix I

    #include<bits/stdc++.h>
    
    using namespace std;
    
    #define gc c=getchar()
    #define r(x) read(x)
    #define ll long long
    
    template<typename T>
    inline void read(T&x){
        x=0;T k=1;char gc;
        while(!isdigit(c)){if(c=='-')k=-1;gc;}
        while(isdigit(c)){x=x*10+c-'0';gc;}x*=k;
    }
    
    const int p=1000000007;
    const int N=2000;
    
    inline int add(int a,int b){
        a+=b;
        if(a>=p)a-=p;
        return a;
    }
    
    int n,k;
    
    int Tmp[N<<1];
    
    inline void mul(int* a,int *b,int* f){
        memset(Tmp,0,k<<3);
        for(int i=0;i<k;++i){
            for(int j=0;j<k;++j){
                Tmp[i+j]=add(Tmp[i+j],(ll)a[i]*b[j]%p);
            }
        }
        for(int i=(k<<1)-2;i>=k;--i){
            for(int j=0;j<k;++j){
                Tmp[i-j-1]=add(Tmp[i-j-1],(ll)Tmp[i]*f[j]%p);
            }
        }
        memcpy(a,Tmp,k<<2);
    }
    
    int base[N],ans[N];
    
    inline int solve(int* a,int* f,int n){
        if(n<k)return a[n];
        base[1]=ans[0]=1;
        for(;n;n>>=1){
            if(n&1)mul(ans,base,f);
            mul(base,base,f);
        }
        int ret=0;
        for(int i=0;i<k;++i)ret=add(ret,(ll)a[i]*ans[i]%p);
        return ret;
    }
    
    int a[N],f[N];
    
    int main(){
        r(n);r(k);
        for(int i=0;i<k;++i)r(f[i]),f[i]=add(f[i],p);
        for(int i=0;i<k;++i)r(a[i]),a[i]=add(a[i],p);
        printf("%d
    ",solve(a,f,n));
    }
    

    多项式取模在哪里?

    考虑上面代码的这一部分

    for(int i=(k<<1)-2;i>=k;--i){
        for(int j=0;j<k;++j){
            Tmp[i-j-1]=add(Tmp[i-j-1],(ll)Tmp[i]*f[j]%p);
        }
    }
    

    考虑消去第(n)位的时候

    相当于把多项式({-f_{k-1},-f_{k-2},...,-f_0,1})平移了(n-k)

    并从原数列中减去它的(Tmp_n)

    所以这段代码实际上是在对多项式({-f_{k-1},-f_{k-2},...,-f_0,1}​)取模

    于是复杂度可以优化至(O(klog k log n))

    【模板】线性递推

    #include<bits/stdc++.h>
    
    using namespace std;
    
    #define gc c=getchar()
    #define r(x) read(x)
    #define ll long long
    
    template<typename T>
    inline void read(T&x){
        x=0;T k=1;char gc;
        while(!isdigit(c)){if(c=='-')k=-1;gc;}
        while(isdigit(c)){x=x*10+c-'0';gc;}x*=k;
    }
    
    const int N=500000;
    const int p=998244353;
    const int g=3;
    
    inline int qpow(int a,int b){
        int ans=1;
        for(;b;b>>=1){
            if(b&1)ans=1ll*ans*a%p;
            a=1ll*a*a%p;
        }
        return ans;
    }
    
    namespace polynomial{
        int r[N];
        int NOW_LEN;
        inline void ntt(int *A,int len,int opt=1){
            if(len!=NOW_LEN)for(int i=0;i<len;++i)r[i]=(r[i>>1]>>1)|((i&1)*(len>>1));
            NOW_LEN=len;
            for(int i=0;i<len;++i)if(i<r[i])swap(A[i],A[r[i]]);
            for(int i=2;i<=len;i<<=1){
                int wn=qpow(g,(p-1)/i),n=i>>1;
                if(!opt)wn=qpow(wn,p-2);
                for(int j=0;j<len;j+=i){
                    int w=1;
                    for(int k=0;k<n;++k,w=1ll*w*wn%p){
                        int u=A[j+k],v=1ll*A[j+k+n]*w%p;
                        A[j+k]=(u+v)%p;
                        A[j+k+n]=(u-v+p)%p;
                    }
                }
            }
            if(!opt){
                int inv=qpow(len,p-2);
                for(int i=0;i<len;++i)A[i]=1ll*A[i]*inv%p;
            }
        }
        
        int Tmp_mul1[N],Tmp_mul2[N];
        inline void mul(int *A,int *B,int *C,int lenA,int lenB){
            int len=1,lenC=lenA+lenB-1;
            while(len<lenC)len<<=1;
            memcpy(Tmp_mul1,A,lenA<<2);
            memcpy(Tmp_mul2,B,lenB<<2);
            memset(Tmp_mul1+lenA,0,(len-lenA)<<2);
            memset(Tmp_mul2+lenB,0,(len-lenB)<<2);
            ntt(Tmp_mul1,len);ntt(Tmp_mul2,len);
            for(int i=0;i<len;++i)C[i]=1ll*Tmp_mul1[i]*Tmp_mul2[i]%p;
            ntt(C,len,0);
            memset(C+lenC,0,(len-lenC)<<2);
        }
        
        int Tmp_inv[N];
        inline void inverse(int *A,int *Inv,int len){
            memset(Inv,0,len<<2);
            Inv[0]=qpow(A[0],p-2);
            for(int i=2;i<=len;i<<=1){
                memcpy(Tmp_inv,A,i<<2);
                memset(Tmp_inv+i,0,i<<2);
                ntt(Inv,i<<1);ntt(Tmp_inv,i<<1);
                for(int k=0;k<i<<1;++k)Inv[k]=Inv[k]*(2-1ll*Inv[k]*Tmp_inv[k]%p+p)%p;
                ntt(Inv,i<<1,0);
                memset(Inv+i,0,i<<2);
            }
        }
        
        int A0[N],B0[N];
        inline void mod(int A[],int B[],int R[],int lenA,int lenB){
            int len=1,t=lenA-lenB+1;
            while(len<=t)len<<=1;
            reverse_copy(B,B+lenB,A0);
            inverse(A0,B0,len);
            reverse_copy(A,A+lenA,A0);
            mul(A0,B0,A0,t,t);
            reverse(A0,A0+t);
            for(len=1;len<(lenA<<1);len<<=1);
            copy(B,B+lenB,B0);
            mul(A0,B0,R,t,lenB);
            for(int i=0;i<lenB-1;++i)R[i]=(A[i]-R[i]+p)%p;
        }
    }
    
    int n,k;
    
    int Tmp[N<<1];
    
    inline void mul(int a[],int b[],int f[]){
        polynomial::mul(a,b,Tmp,k,k);
        polynomial::mod(Tmp,f,a,2*k,k+1);
    }
    
    int base[N],ans[N];
    
    inline int solve(int a[],int f[],int n){
        if(n<k)return a[n];
        
        reverse(f,f+k);
        for(int i=0;i<k;++i)f[i]=p-f[i];
        f[k]=1;
        
        base[1]=ans[0]=1;
        for(;n;n>>=1){
            if(n&1)mul(ans,base,f);
            mul(base,base,f);
        }
        int ret=0;
        for(int i=0;i<k;++i)ret=(ret+(ll)a[i]*ans[i]%p)%p;
        return ret;
    }
    
    int a[N],f[N];
    
    int main(){
    //	freopen(".in","r",stdin);
    //	freopen(".out","w",stdout);
        r(n);r(k);
        for(int i=0;i<k;++i)r(f[i]),f[i]=(f[i]+p)%p;
        for(int i=0;i<k;++i)r(a[i]),a[i]=(a[i]+p)%p;
        printf("%d
    ",solve(a,f,n));
    }
    
  • 相关阅读:
    Openssl命令详解
    Openssl命令详解
    Mac根目录下无法创建文件或目录
    解决 mysql from_base64 函数返回乱码的问题
    elementUI日期选择器 el-date-picker根据所选日期选择禁用
    el-dialog设置为点击弹窗以外的区域不自动关闭弹窗
    在vue项目中MD5加密的使用方法
    bower install 报错fatal: unable to access 'https://github.com/angular/bower-angular-touch.git/'类错误解决方法
    angular项目grunt serve报错Cannot find where you keep your Bower packages
    移动端开发--》适配各种机型样式大小
  • 原文地址:https://www.cnblogs.com/yicongli/p/11143002.html
Copyright © 2020-2023  润新知