题目:https://loj.ac/problem/2980
线段树维护矩阵。
然后是 30 分。似乎是被卡常了?……
#include<cstdio> #include<cstring> #include<algorithm> #define ll long long #define ls Ls[cr] #define rs Rs[cr] using namespace std; int rdn() { int ret=0;bool fx=1;char ch=getchar(); while(ch>'9'||ch<'0'){if(ch=='-')fx=0;ch=getchar();} while(ch>='0'&&ch<='9')ret=ret*10+ch-'0',ch=getchar(); return fx?ret:-ret; } const int N=2e5+5e4+5,M=5e5+5,mod=998244353; int upt(int x){while(x>=mod)x-=mod;while(x<0)x+=mod;return x;} int n,m,a[N],b[N],c[N],tot,Ls[M],Rs[M]; struct Mtr{ int a[4][4]; Mtr(){memset(a,0,sizeof a);} void init(){for(int i=0;i<4;i++)a[i][i]=1;} Mtr operator* (const Mtr &b)const { Mtr c; for(int i=0;i<4;i++) for(int k=0;k<4;k++) for(int j=0;j<4;j++) c.a[i][j]=(c.a[i][j]+(ll)a[i][k]*b.a[k][j])%mod; return c; } Mtr operator+ (const Mtr &b)const { Mtr c=*this; for(int i=0;i<3;i++)c.a[0][i]=upt(c.a[0][i]+b.a[0][i]); return c; } bool operator== (const Mtr &b)const { for(int i=0;i<4;i++) for(int j=0;j<4;j++)if(a[i][j]!=b.a[i][j])return false; return true; } }Ml[5],vl[M],tg[M],I,tp; void build(int l,int r,int cr) { if(l==r) { vl[cr].a[0][0]=a[l]; vl[cr].a[0][1]=b[l]; vl[cr].a[0][2]=c[l]; vl[cr].a[0][3]=1; return; } int mid=l+r>>1; tg[cr].init(); ls=++tot; build(l,mid,ls); rs=++tot; build(mid+1,r,rs); vl[cr]=vl[ls]+vl[rs]; vl[cr].a[0][3]=r-l+1; } void init() { tot=1; build(1,n,1); for(int i=1;i<=3;i++)Ml[i].init(); Ml[1].a[1][0]=1; Ml[2].a[2][1]=1; Ml[3].a[0][2]=1; I.init(); } void pshd(int cr) { if(tg[cr]==I)return; tg[ls]=tg[ls]*tg[cr]; tg[rs]=tg[rs]*tg[cr]; vl[ls]=vl[ls]*tg[cr]; vl[rs]=vl[rs]*tg[cr]; tg[cr]=I; } void mdfy(int l,int r,int cr,int L,int R,int op,int v) { if(l>=L&&r<=R) { if(op<=3)tp=Ml[op]; if(op==4){ tp=I; tp.a[3][0]=v;} if(op==5){ tp=I; tp.a[1][1]=v;} if(op==6){ tp=I; tp.a[2][2]=0; tp.a[3][2]=v;} tg[cr]=tg[cr]*tp; vl[cr]=vl[cr]*tp; return; } int mid=l+r>>1; pshd(cr); if(L<=mid)mdfy(l,mid,ls,L,R,op,v); if(mid<R)mdfy(mid+1,r,rs,L,R,op,v); vl[cr]=vl[ls]+vl[rs]; vl[cr].a[0][3]=r-l+1; } Mtr qry(int l,int r,int cr,int L,int R) { if(l>=L&&r<=R)return vl[cr]; int mid=l+r>>1; pshd(cr); if(L>mid)return qry(mid+1,r,rs,L,R); if(R<=mid)return qry(l,mid,ls,L,R); return qry(l,mid,ls,L,R)+qry(mid+1,r,rs,L,R); } int main() { n=rdn(); for(int i=1;i<=n;i++)a[i]=rdn(),b[i]=rdn(),c[i]=rdn(); init(); m=rdn(); int op,l,r,v=0; while(m--) { op=rdn();l=rdn();r=rdn(); if(op==7) { Mtr ans=qry(1,n,1,l,r); printf("%d %d %d ",ans.a[0][0],ans.a[0][1],ans.a[0][2]); } if(op>=4&&op<=6)v=rdn(); if(op<7)mdfy(1,n,1,l,r,op,v); } return 0; }
1.行向量乘4*4矩阵,写成 42 而不是 43 的。
2.矩阵乘法的时候,一定要多写很多 if(a[i][k]) 之类的判断。会快一大截!
3.查询的时候,无需返回一个矩阵,可以用全局变量去累计答案。
4.快速输出。
然后还是 95 分……但是把其他 AC 代码交上去,TLE得更严重。这到底是……
UPD(2019.5.18):再交了一次,卡过了。
#include<cstdio> #include<cstring> #include<algorithm> #define ll long long #define ls Ls[cr] #define rs Rs[cr] using namespace std; int rdn() { int ret=0;bool fx=1;char ch=getchar(); while(ch>'9'||ch<'0'){if(ch=='-')fx=0;ch=getchar();} while(ch>='0'&&ch<='9')ret=ret*10+ch-'0',ch=getchar(); return fx?ret:-ret; } int g[20]; void wrt(int x) { int t=0; while(x)g[++t]=x%10,x/=10; for(int i=t;i;i--)putchar(g[i]+'0'); } const int N=2e5+5e4+5,M=5e5+5,mod=998244353; int upt(int x){while(x>=mod)x-=mod;while(x<0)x+=mod;return x;} int n,m,tot,Ls[M],Rs[M],vl[M][5],s0,s1,s2,tp[5]; bool lz[M]; struct Mtr{ int a[4][4]; Mtr(){memset(a,0,sizeof a);} void init() { for(int i=0;i<4;i++)a[i][i]=1;} Mtr operator* (const Mtr &b)const { Mtr c; for(int i=0;i<4;i++) for(int k=0;k<4;k++) if(a[i][k])/// for(int j=0;j<4;j++) if(b.a[k][j]) c.a[i][j]=(c.a[i][j]+(ll)a[i][k]*b.a[k][j])%mod; return c; } Mtr operator+ (const Mtr &b)const { Mtr c=*this; for(int i=0;i<3;i++)c.a[0][i]=upt(c.a[0][i]+b.a[0][i]); return c; } bool operator== (const Mtr &b)const { for(int i=0;i<4;i++) for(int j=0;j<4;j++)if(a[i][j]!=b.a[i][j])return false; return true; } }Ml[5],tg[M],I,trans; void Mul(int cr,Mtr b) { for(int j=0;j<3;j++)tp[j]=vl[cr][j],vl[cr][j]=0; tp[3]=vl[cr][3]; for(int k=0;k<4;k++) if(tp[k]) for(int j=0;j<3;j++)//3 if(b.a[k][j]) vl[cr][j]=(vl[cr][j]+(ll)tp[k]*b.a[k][j])%mod; } void Plu(int cr,int b) { for(int j=0;j<3;j++)tp[j]=upt(vl[cr][j]+vl[b][j]); } void build(int l,int r,int cr) { if(l==r) { vl[cr][0]=rdn(); vl[cr][1]=rdn(); vl[cr][2]=rdn(); vl[cr][3]=1; return; } int mid=l+r>>1; tg[cr].init(); ls=++tot; build(l,mid,ls); rs=++tot; build(mid+1,r,rs); Plu(ls,rs); for(int j=0;j<3;j++)vl[cr][j]=tp[j]; vl[cr][3]=r-l+1; } void init() { tot=1; build(1,n,1); for(int i=1;i<=3;i++)Ml[i].init(); Ml[1].a[1][0]=1; Ml[2].a[2][1]=1; Ml[3].a[0][2]=1; I.init(); } void pshd(int cr) { if(!lz[cr])return; lz[cr]=0; tg[ls]=tg[ls]*tg[cr]; tg[rs]=tg[rs]*tg[cr]; Mul(ls,tg[cr]); Mul(rs,tg[cr]);// tg[cr]=I; lz[ls]=lz[rs]=1; } void mdfy(int l,int r,int cr,int L,int R) { if(l>=L&&r<=R) { tg[cr]=tg[cr]*trans; Mul(cr,trans);// lz[cr]=1; return; } int mid=l+r>>1; pshd(cr); if(L<=mid)mdfy(l,mid,ls,L,R); if(mid<R)mdfy(mid+1,r,rs,L,R); Plu(ls,rs); for(int j=0;j<3;j++)vl[cr][j]=tp[j]; } void qry(int l,int r,int cr,int L,int R) { if(l>=L&&r<=R) { s0=upt(s0+vl[cr][0]); s1=upt(s1+vl[cr][1]); s2=upt(s2+vl[cr][2]); return; } int mid=l+r>>1; pshd(cr); if(L<=mid)qry(l,mid,ls,L,R); if(mid<R)qry(mid+1,r,rs,L,R); } int main() { n=rdn(); init(); m=rdn(); int op,l,r,v=0; while(m--) { op=rdn();l=rdn();r=rdn(); if(op==7) { s0=s1=s2=0; qry(1,n,1,l,r); wrt(s0);putchar(' '); wrt(s1);putchar(' '); wrt(s2);puts(""); continue; } if(op>=4&&op<=6) { v=rdn(); trans=I; if(op==4) trans.a[3][0]=v; if(op==5) trans.a[1][1]=v; if(op==6){ trans.a[2][2]=0; trans.a[3][2]=v;} } else trans=Ml[op]; if(op<7)mdfy(1,n,1,l,r); } return 0; }