• ZZNU 2182 矩阵dp (矩阵快速幂+递推式 || 杜教BM)


    题目链接:http://47.93.249.116/problem.php?id=2182

    题目描述

    河神喜欢吃零食,有三种最喜欢的零食,鱼干,猪肉脯,巧克力。他每小时会选择一种吃一包。

    不幸的是,医生告诉他,他吃这些零食的时候,如果在连续的三小时内他三种都吃了,并且在中间一小时

    吃的是巧克力,他就会食物中毒。并且,如果河神在连续三小时内吃到相同种类的食物,他就会不开心。

    假设每种类零食的数量都是无限的,那么如果经过n小时,让河神满意的零食吃法有多少种呢?(开心又不

    会食物中毒的吃法)答案可能过大,结果对1000000007取模。

    输入

    第一行一个T,代表测试实例数(T<=1000)

    第二行一个n,代表经过n个小时。(n<=1e10)

    输出

    每行一个结果,代表经过n小时河神满意的零食吃法的数量。

    样例输入

    3
    3
    4
    15

    样例输出

    22
    54
    1034422

    一道比较水的矩阵题,但因出题人语文素养过低,签到阅读题表意不清导致全场没时间开这题,emmm。。。。

    官方题解+AC代码:
    /*
    
    从数据量上可以很明显的看出,1e10必定不是暴力直接递推。联想到矩阵快速幂
    
     
    
    可以很明显的看出,前两项是不受规则影响的,所以前两项直接输出。
    
     
    
    从第三项开始是满足规则的状态转移(dp),因为从第三项开始才满足规则,
    
     
    
    因此第三项相当于dp序列的第一项,用快速幂计算时需要减去2.
    
     
    
    解决dp问题的关键在于能否把大问题分解为较小数量级的子问题,即可以子问题
    
     
    
    作为已知条件推出更多答案。
    
     
    
    对于初始规则矩阵的构造,假设我们用0,1,2分别表示三种零食,其中1代表
    
     
    
    巧克力,那么对于规则来说,第n小时的可选择零食只需看之前两个小时的
    
     
    
    选择(因为只需要判断连续三小时内是不是合法),为了方便书写,在代码
    
     
    
    中采用状态压缩的方法表示,那么代表每个小时选择的有三种,即压缩为三
    
     
    
    进制数即可。比如,21=(2*3^1+1*3^0=5)代表前两小时分别吃了鱼干和
    
     
    
    巧克力。那么对于当前状态5可达成的状态,只有11,12,因为规则中210是会
    
     
    
    食物中毒的(我们保留了后两位但在保留过程中以三位判断是否是合法的状态
    
     
    
    转移)。dp[n][k] = Σdp[n-1][u](所有合法的可以转移到k的u),矩阵幂完美
    
     
    
    实现任意状态的转移结果,最后把所有转移状态累加即得到答案。
    
    */
    
    #include<stdio.h>
    
    #include<string.h>
    
    #include<algorithm>
    
    #include<queue>
    
    using namespace std;
    const int MAXN = 9;
    const long long mod=1e9+7;
    struct Matrix
    {
        long long edge[MAXN][MAXN];
    };
    int N;
    void Mul(Matrix a, Matrix b, Matrix &ans)
    {
        memset(ans.edge,0,sizeof(ans.edge));
        for(int i=0; i<N; i++)
            for(int j=0; j<N; j++)
                for(int p=0; p<N; p++)
                {
                    ans.edge[i][j] += (a.edge[i][p] * b.edge[p][j])%mod;
                    ans.edge[i][j]%=mod;
                }
    
    }
    
    void QuickPow(Matrix Map, long long k, Matrix &ans)
    
    {
        memset(ans.edge,0,sizeof(ans.edge));
        for(int i=0; i<N; i++)ans.edge[i][i] = true;
        while(k)
        {
            if(k & 1)
                Mul(ans, Map, ans);
            Mul(Map, Map, Map);
            k /= 2;
        }
    }
    bool check(int i, int j)
    {
        int a = i%3, b = i/3%3;
        int c = j%3, d = j/3%3;
        if(a == b && b == c)
            return false;
        if(a!=d)
            return false;
        if(a==1 && b!=1 && c!=1 && b!=c)
            return false;
        return true;
    }
    int main()
    {
        //freopen("qq.txt","r",stdin);
       // freopen("pp.txt","w",stdout);
        int T; long long n;
        scanf("%d", &T);
        N = 9;
        while(T--)
        {
    
            scanf("%lld", &n);
            Matrix ans;
            memset(ans.edge, 0, sizeof(ans.edge));
            for(int i=0; i<9; i++)
            {
                for(int j=0; j<9; j++)
                {
                    if(check(i, j))
                        ans.edge[i][j] = 1;
                }
            }
            if(n<=2)
            {
                if(n == 1)
                    printf("3
    ");
                else    printf("9
    ");
            }
            else
            {
                QuickPow(ans, n-2, ans);
                long long p = 0;
                for(int i=0; i<9; i++)
                    for(int j=0; j<9; j++)
                    {
                        p = p+ans.edge[i][j];
                        p %= mod;
                    }
                printf("%lld
    ", p);
            }
        }
        return 0;
    }
    

    其实直接手推前几项,然后杜教BM板子一套,加个杜教读入挂,恐怖如斯

    #include<bits/stdc++.h>
    using namespace std;
    #define rep(i,a,n) for (int i=a;i<n;i++)
    #define per(i,a,n) for (int i=n-1;i>=a;i--)
    #define pb push_back
    #define mp make_pair
    #define all(x) (x).begin(),(x).end()
    #define fi first
    #define se second
    #define SZ(x) ((int)(x).size())
    typedef vector<int> VI;
    typedef long long ll;
    typedef pair<int,int> PII;
    const ll mod=1000000007;
    ll powmod(ll a,ll b) {ll res=1;a%=mod; assert(b>=0); for(;b;b>>=1){if(b&1)res=res*a%mod;a=a*a%mod;}return res;}
     //head
    namespace IO{
        #define BUF_SIZE 100000
        #define OUT_SIZE 100000
        #define ll long long
       // fread->read
    
        bool IOerror=0;
        inline char nc(){
            static char buf[BUF_SIZE],*p1=buf+BUF_SIZE,*pend=buf+BUF_SIZE;
            if (p1==pend){
                p1=buf; pend=buf+fread(buf,1,BUF_SIZE,stdin);
                if (pend==p1){IOerror=1;return -1;}
                {printf("IO error!
    ");system("pause");for (;;);exit(0);}
            }
            return *p1++;
        }
        inline bool blank(char ch){return ch==' '||ch=='
    '||ch=='
    '||ch=='	';}
        inline void read(int &x){
            bool sign=0; char ch=nc(); x=0;
            for (;blank(ch);ch=nc());
            if (IOerror)return;
            if (ch=='-')sign=1,ch=nc();
            for (;ch>='0'&&ch<='9';ch=nc())x=x*10+ch-'0';
            if (sign)x=-x;
        }
        inline void read(ll &x){
            bool sign=0; char ch=nc(); x=0;
            for (;blank(ch);ch=nc());
            if (IOerror)return;
            if (ch=='-')sign=1,ch=nc();
            for (;ch>='0'&&ch<='9';ch=nc())x=x*10+ch-'0';
            if (sign)x=-x;
        }
        inline void read(double &x){
            bool sign=0; char ch=nc(); x=0;
            for (;blank(ch);ch=nc());
            if (IOerror)return;
            if (ch=='-')sign=1,ch=nc();
            for (;ch>='0'&&ch<='9';ch=nc())x=x*10+ch-'0';
            if (ch=='.'){
                double tmp=1; ch=nc();
                for (;ch>='0'&&ch<='9';ch=nc())tmp/=10.0,x+=tmp*(ch-'0');
            }
            if (sign)x=-x;
        }
        inline void read(char *s){
            char ch=nc();
            for (;blank(ch);ch=nc());
            if (IOerror)return;
            for (;!blank(ch)&&!IOerror;ch=nc())*s++=ch;
            *s=0;
        }
        inline void read(char &c){
            for (c=nc();blank(c);c=nc());
            if (IOerror){c=-1;return;}
        }
    //    fwrite->write
        struct Ostream_fwrite{
            char *buf,*p1,*pend;
            Ostream_fwrite(){buf=new char[BUF_SIZE];p1=buf;pend=buf+BUF_SIZE;}
            void out(char ch){
                if (p1==pend){
                    fwrite(buf,1,BUF_SIZE,stdout);p1=buf;
                }
                *p1++=ch;
            }
            void print(int x){
                static char s[15],*s1;s1=s;
                if (!x)*s1++='0';if (x<0)out('-'),x=-x;
                while(x)*s1++=x%10+'0',x/=10;
                while(s1--!=s)out(*s1);
            }
            void println(int x){
                static char s[15],*s1;s1=s;
                if (!x)*s1++='0';if (x<0)out('-'),x=-x;
                while(x)*s1++=x%10+'0',x/=10;
                while(s1--!=s)out(*s1); out('
    ');
            }
            void print(ll x){
                static char s[25],*s1;s1=s;
                if (!x)*s1++='0';if (x<0)out('-'),x=-x;
                while(x)*s1++=x%10+'0',x/=10;
                while(s1--!=s)out(*s1);
            }
            void println(ll x){
                static char s[25],*s1;s1=s;
                if (!x)*s1++='0';if (x<0)out('-'),x=-x;
                while(x)*s1++=x%10+'0',x/=10;
                while(s1--!=s)out(*s1); out('
    ');
            }
            void print(double x,int y){
                static ll mul[]={1,10,100,1000,10000,100000,1000000,10000000,100000000,
                    1000000000,10000000000LL,100000000000LL,1000000000000LL,10000000000000LL,
                    100000000000000LL,1000000000000000LL,10000000000000000LL,100000000000000000LL};
                if (x<-1e-12)out('-'),x=-x;x*=mul[y];
                ll x1=(ll)floor(x); if (x-floor(x)>=0.5)++x1;
                ll x2=x1/mul[y],x3=x1-x2*mul[y]; print(x2);
                if (y>0){out('.'); for (size_t i=1;i<y&&x3*mul[i]<mul[y];out('0'),++i); print(x3);}
            }
            void println(double x,int y){print(x,y);out('
    ');}
            void print(char *s){while (*s)out(*s++);}
            void println(char *s){while (*s)out(*s++);out('
    ');}
            void flush(){if (p1!=buf){fwrite(buf,1,p1-buf,stdout);p1=buf;}}
            ~Ostream_fwrite(){flush();}
        }Ostream;
        inline void print(int x){Ostream.print(x);}
        inline void println(int x){Ostream.println(x);}
        inline void print(char x){Ostream.out(x);}
        inline void println(char x){Ostream.out(x);Ostream.out('
    ');}
        inline void print(ll x){Ostream.print(x);}
        inline void println(ll x){Ostream.println(x);}
        inline void print(double x,int y){Ostream.print(x,y);}
        inline void println(double x,int y){Ostream.println(x,y);}
        inline void print(char *s){Ostream.print(s);}
        inline void println(char *s){Ostream.println(s);}
        inline void println(){Ostream.out('
    ');}
        inline void flush(){Ostream.flush();}
        #undef ll
        #undef OUT_SIZE
        #undef BUF_SIZE
    };
    
    int _;
    long long n;
    namespace linear_seq {
        const int N=10010;
        ll res[N],base[N],_c[N],_md[N];
    
        vector<int> Md;
        void mul(ll *a,ll *b,int k) {
            rep(i,0,k+k) _c[i]=0;
            rep(i,0,k) if (a[i]) rep(j,0,k) _c[i+j]=(_c[i+j]+a[i]*b[j])%mod;
            for (int i=k+k-1;i>=k;i--) if (_c[i])
                rep(j,0,SZ(Md)) _c[i-k+Md[j]]=(_c[i-k+Md[j]]-_c[i]*_md[Md[j]])%mod;
            rep(i,0,k) a[i]=_c[i];
        }
        int solve(ll n,VI a,VI b) { // a 系数 b 初值 b[n+1]=a[0]*b[n]+...
            printf("%d
    ",SZ(b));
            ll ans=0,pnt=0;
            int k=SZ(a);
            assert(SZ(a)==SZ(b));
            rep(i,0,k) _md[k-1-i]=-a[i];_md[k]=1;
            Md.clear();
            rep(i,0,k) if (_md[i]!=0) Md.push_back(i);
            rep(i,0,k) res[i]=base[i]=0;
            res[0]=1;
            while ((1ll<<pnt)<=n) pnt++;
            for (int p=pnt;p>=0;p--) {
                mul(res,res,k);
                if ((n>>p)&1) {
                    for (int i=k-1;i>=0;i--) res[i+1]=res[i];res[0]=0;
                    rep(j,0,SZ(Md)) res[Md[j]]=(res[Md[j]]-res[k]*_md[Md[j]])%mod;
                }
            }
            rep(i,0,k) ans=(ans+res[i]*b[i])%mod;
            if (ans<0) ans+=mod;
            return ans;
        }
        VI BM(VI s) {
            VI C(1,1),B(1,1);
            int L=0,m=1,b=1;
            rep(n,0,SZ(s)) {
                ll d=0;
                rep(i,0,L+1) d=(d+(ll)C[i]*s[n-i])%mod;
                if (d==0) ++m;
                else if (2*L<=n) {
                    VI T=C;
                    ll c=mod-d*powmod(b,mod-2)%mod;
                    while (SZ(C)<SZ(B)+m) C.pb(0);
                    rep(i,0,SZ(B)) C[i+m]=(C[i+m]+c*B[i])%mod;
                    L=n+1-L; B=T; b=d; m=1;
                } else {
                    ll c=mod-d*powmod(b,mod-2)%mod;
                    while (SZ(C)<SZ(B)+m) C.pb(0);
                    rep(i,0,SZ(B)) C[i+m]=(C[i+m]+c*B[i])%mod;
                    ++m;
                }
            }
            return C;
        }
        int gao(VI a,ll n) {
            VI c=BM(a);
            c.erase(c.begin());
            rep(i,0,SZ(c)) c[i]=(mod-c[i])%mod;
            return solve(n,c,VI(a.begin(),a.begin()+SZ(c)));
        }
    };
    
    int main() {
        //freopen("int.txt","r",stdin);
       // freopen("out.txt","w",stdout);
        int t;
        IO::read(t);
        while(t--){
            IO::read(n);
            if(n==1) {IO::println("3");continue;}
            else if(n==2) {IO::println("9");continue;}
    
            vector<int>v;
            v.push_back(3);//前几项
            v.push_back(9);
            v.push_back(22);
            v.push_back(54);
            v.push_back(132);
            v.push_back(324);
            v.push_back(794);
            v.push_back(1946);
            v.push_back(4770);
            v.push_back(11692);
            v.push_back(28658);
            v.push_back(70244);
            v.push_back(172176);
            v.push_back(422022);
    
            //输入n ,输出第n项的值
            IO::println(linear_seq::gao(v,n-1));
            printf("%d
    ",linear_seq::gao(v,n-1));
        }
    }
    // 3 9 22 54 132 324 794 1946 4770 11692 28658 70244 172176 422022 1034422
    


  • 相关阅读:
    Discrete Logging
    P1378 油滴扩展
    P3390 【模板】矩阵快速幂
    RMQ算法
    P1372 又是毕业季I
    P1440 求m区间内的最小值
    高效判断素数方法
    阿尔贝喝我
    浙江大学PAT上机题解析之2-11. 两个有序链表序列的合并
    顺序队列之C++实现
  • 原文地址:https://www.cnblogs.com/weimeiyuer/p/9693241.html
Copyright © 2020-2023  润新知