• 把 Strassen 乘法调出来了……


    完美...

    指针搞死我了

    ///
    /// Author: zball
    /// No rights reserved
    /// (Creative Commons CC0)
    ///
    #include <cstdio>
    #include <malloc.h>
    #include <cstring>
    #define maxn 10000100
    #define qm 1000000007
    namespace matrix_mem{
    	// Bump allocator over a fixed static pool of maxn ints.
    	// mem_start is the pool and mem_cur points at the first unused int.
    	// mem_start is declared first so that mem_cur can start at it: get() is then
    	// well-defined even if init() has not been called yet.
    	int mem_start[maxn],*mem_cur=mem_start;
    	// Zero the whole pool and make all of it available again. Pointers returned
    	// by earlier get() calls may be handed out again afterwards.
    	inline void init(){
    		memset(mem_start,0,sizeof(mem_start));
    		mem_cur=mem_start;
    	}
    	// Reserve `size` consecutive ints from the pool and return a pointer to the
    	// first one. If size is negative or fewer than `size` ints remain, nothing is
    	// reserved and NULL is returned instead of running past the end of mem_start.
    	inline int* get(int size){
    		if(size<0 || size>mem_start+maxn-mem_cur) return NULL;
    		int* p=mem_cur;
    		mem_cur=mem_cur+size;
    		return p;
    	}
    	// There is deliberately no free(): the pool is a static array, so passing
    	// its pointers to free() (as an earlier, disabled freea() draft did) would be
    	// undefined behaviour. Call init() to reset the pool instead.
    }
    // Fold n back into [0, qm) by subtracting qm at most once; exact whenever
    // 0 <= n < 2*qm, e.g. for the sum of two already-reduced values.
    inline int fitQm(int n){
    	return n<qm ? n : n-qm;
    }
    // Fold n back into [0, qm) by adding qm at most once; exact whenever
    // -qm <= n < qm, e.g. for the difference of two already-reduced values.
    inline int fitQm1(int n){
    	return n<0 ? n+qm : n;
    }
    namespace matrices{
    	// Scratch buffers for the Strassen recursion. init() fills levels 1..9:
    	// mats[i][0..8] are nine (2^i x 2^i) int buffers taken from matrix_mem, and
    	// curs[i][0..8] (declared below) wrap them as submatrix views. Multiplying
    	// 2^k x 2^k matrices uses level k-1: [0] and [8] hold operand sums or
    	// differences, [1]..[7] hold the Strassen products M1..M7.
    	int *mats[10][9];
    	// Offset (in ints) of row n of a view inside its backing array: rows are
    	// 2^bxl ints apart and the view's top-left corner is (sx, sy). Only usable
    	// inside submatrix members, where sx/sy/bxl are in scope. The argument is
    	// parenthesised so that compound expressions expand correctly.
    	#define compute_offset(n) ((((n)+sx)<<bxl)+sy)
    	struct submatrix{
    		// A square view of side 2^xl into the row-major int array t, whose rows
    		// are 2^bxl ints wide; (sx, sy) is the view's top-left corner within t.
    		// The view does not own t: copying a submatrix copies the pointer only.
    		int *t,sx,sy,xl,bxl;
    		inline submatrix(){}
    		// View of side 2^xl (square: the same exponent is used for both dimensions).
    		inline submatrix(int* t,int sx,int sy,int xl,int bxl):t(t),sx(sx),sy(sy),xl(xl),bxl(bxl){}
    		inline void set(int* _t,int _sx,int _sy,int _xl,int _bxl){
    			t=_t,sx=_sx,sy=_sy,xl=_xl,bxl=_bxl;
    		}
    		// Half-size view of p (side 2^(p.xl-1)) whose corner is offset by
    		// (_sx, _sy) from p's corner; used to split a matrix into quadrants.
    		inline submatrix(const submatrix& p,int _sx,int _sy){
    			t=p.t;
    			sx=_sx+p.sx;
    			sy=_sy+p.sy;
    			xl=p.xl-1;
    			bxl=p.bxl;
    		}
    		// Pointer to the first element of row n of the view, so m[i][j] is element (i, j).
    		inline int* operator[](int n){
    			return t+compute_offset(n);
    		}
    		// Overwrite the view with the identity matrix.
    		inline void makeIdentity(){
    			int l=1<<xl;
    			for(int i=0;i<l;++i) for(int j=0;j<l;++j) (*this)[i][j]=0;
    			for(int i=0;i<l;++i) (*this)[i][i]=1;
    		}
    	};
    	typedef submatrix sm;
    	sm curs[10][9];
    	// Allocate the scratch buffers for levels 1..9 (level 0 is not allocated).
    	// Call after matrix_mem::init() and before multiply_strassen.
    	inline void init(){
    		for(int i=1;i<10;++i){
    			for(int j=0;j<9;++j){
    				mats[i][j]=matrix_mem::get(1<<(i<<1)); // (2^i)^2 ints
    				// if(mats[i][j]==NULL) printf("ERROR %d %d\n",i,j);
    				curs[i][j].set(mats[i][j],0,0,i,i);
    			}
    		}
    	}
    	// Reduce n into [0, qm); exact whenever 0 <= n < 2*qm (e.g. a sum of two reduced values).
    	inline int fitQm(int n){
    		if(n>=qm) return n-qm;
    		return n;
    	}
    	// Map n into [0, qm); exact whenever -qm <= n < qm (e.g. a difference of two reduced values).
    	inline int fitQm1(int n){
    		if(n<0) return n+qm;
    		return n;
    	}
    	// c = (a + b) mod qm element-wise. All three views must have the same size
    	// (it is taken from b.xl), and entries of a and b must lie in [0, qm).
    	inline void add(sm& c,sm& a,sm& b){
    		int l=1<<b.xl;
    		for(int i=0;i<l;++i) for(int j=0;j<l;++j) c[i][j]=fitQm(a[i][j]+b[i][j]);
    	}
    	// c = (a - b) mod qm element-wise, with the same size and range rules as add().
    	inline void sub(sm& c,sm& a,sm& b){
    		int l=1<<b.xl;
    		for(int i=0;i<l;++i) for(int j=0;j<l;++j) c[i][j]=fitQm1(a[i][j]-b[i][j]);
    	}
    	// Copy b into a; the size is taken from b.xl.
    	inline void transfer(sm& a,sm& b){
    		int l=1<<b.xl;
    		for(int i=0;i<l;++i) for(int j=0;j<l;++j) a[i][j]=b[i][j];
    	}
    	// Copy b into a, placing b's (0, 0) at row x, column y of a.
    	inline void transferPartial(sm& a,sm& b,int x,int y){
    		// Re-express the target region as a view of b's size inside a's storage.
    		submatrix target(a.t,a.sx+x,a.sy+y,b.xl,a.bxl);
    		transfer(target,b);
    	}
    	// NOTE(review): q is not referenced anywhere in the visible file.
    	int q;
    	// Blocks with xl <= naive_threshold (side <= 16) are multiplied directly.
    	#define naive_threshold 4
    	// Schoolbook multiply: out = a * b (mod qm) for square views of side 2^a.xl
    	// (any size; multiply_strassen uses it for xl <= naive_threshold). Entries
    	// must lie in [0, qm) so each product fits in long long. out must not overlap
    	// a or b, because it is cleared before they are read.
    	inline void multiply_limb(sm& out,sm& a,sm& b){
    		int l=1<<a.xl;
    		for(int i=0;i<l;++i) for(int j=0;j<l;++j) out[i][j]=0;
    		// The innermost j loop walks rows of b and out contiguously.
    		for(int k=0;k<l;++k) for(int i=0;i<l;++i) for(int j=0;j<l;++j) out[i][j]=(out[i][j]+(long long)a[i][k]*b[k][j])%qm;
    	}
    	// Debug helper: print the view row by row, entries separated by spaces.
    	inline void print(sm p){
    		for(int i=0,_=1<<p.xl;i<_;++i){
    			for(int j=0;j<_;++j) printf("%d ",p[i][j]);
    			putchar('\n');
    		}
    	}
    	// Shorthand macros; mult is also used by callers outside this namespace.
    	#define ntmp curs[lm]
    	#define mult multiply_strassen
    	// NOTE(review): half() is unused in the visible code; the quadrant
    	// constructor of submatrix builds the same view.
    	#define half(_t,a,x,y) _t.set(a.t,a.sx+(x),a.sy+(y),a.xl-1,a.bxl)
    	// c = a * b (mod qm) for square views of side 2^a.xl (b and c the same size).
    	// Above naive_threshold this is Strassen's algorithm: 7 recursive products of
    	// half-size blocks instead of 8. The recursion uses the scratch level
    	// curs[a.xl-1], so matrices::init() must have been called. c must not
    	// overlap a or b.
    	void multiply_strassen(sm c,sm a,sm b){
    		const int lm=a.xl-1;
    		// Small blocks go to the schoolbook kernel, and so do sizes for which no
    		// scratch level exists (curs only has levels 0..9): indexing curs[lm]
    		// there would run past the array.
    		if(a.xl<=naive_threshold || lm>=(int)(sizeof(curs)/sizeof(curs[0]))){
    			multiply_limb(c,a,b);
    			return;
    		}
    		sm* const tmp=curs[lm]; // tmp[0], tmp[8]: operand scratch; tmp[1..7]: M1..M7
    		const int h=1<<lm;      // side of each quadrant
    		submatrix A11(a,0,0),A12(a,0,h),A21(a,h,0),A22(a,h,h);
    		submatrix B11(b,0,0),B12(b,0,h),B21(b,h,0),B22(b,h,h);
    		submatrix C11(c,0,0),C12(c,0,h),C21(c,h,0),C22(c,h,h);
    		// M1 = A11(B12 - B22)
    		sub(tmp[0],B12,B22);
    		multiply_strassen(tmp[1],A11,tmp[0]);
    		// M2 = (A11 + A12)B22
    		add(tmp[0],A11,A12);
    		multiply_strassen(tmp[2],tmp[0],B22);
    		// M3 = (A21 + A22)B11
    		add(tmp[0],A21,A22);
    		multiply_strassen(tmp[3],tmp[0],B11);
    		// M4 = A22(B21 - B11)
    		sub(tmp[0],B21,B11);
    		multiply_strassen(tmp[4],A22,tmp[0]);
    		// M5 = (A11 + A22)(B11 + B22)
    		add(tmp[0],A11,A22);
    		add(tmp[8],B11,B22);
    		multiply_strassen(tmp[5],tmp[0],tmp[8]);
    		// M6 = (A12 - A22)(B21 + B22)
    		sub(tmp[0],A12,A22);
    		add(tmp[8],B21,B22);
    		multiply_strassen(tmp[6],tmp[0],tmp[8]);
    		// M7 = (A11 - A21)(B11 + B12)
    		sub(tmp[0],A11,A21);
    		add(tmp[8],B11,B12);
    		multiply_strassen(tmp[7],tmp[0],tmp[8]);
    		// The four quadrant writes below cover all of c, so c needs no clearing.
    		// C11 = M5 + M4 - M2 + M6
    		add(tmp[0],tmp[5],tmp[4]);
    		sub(tmp[8],tmp[0],tmp[2]);
    		add(C11,tmp[8],tmp[6]);
    		// C12 = M1 + M2
    		add(C12,tmp[1],tmp[2]);
    		// C21 = M3 + M4
    		add(C21,tmp[3],tmp[4]);
    		// C22 = M5 + M1 - M3 - M7
    		add(tmp[0],tmp[5],tmp[1]);
    		sub(tmp[8],tmp[0],tmp[3]);
    		sub(C22,tmp[8],tmp[7]);
    	}
    }
    using namespace matrices;
    // Demo: multiply two 32x32 identity matrices (lg = 5) with Strassen and print
    // the product, one row per line.
    int main(){
    	const int lg=5;          // matrices are 2^lg x 2^lg
    	const int side=1<<lg;
    	matrix_mem::init();
    	matrices::init();
    	submatrix a(matrix_mem::get(side*side),0,0,lg,lg);
    	submatrix b(matrix_mem::get(side*side),0,0,lg,lg);
    	submatrix c(matrix_mem::get(side*side),0,0,lg,lg);
    	// Defensive: bail out rather than write through a null view.
    	if(!a.t||!b.t||!c.t){
    		fputs("out of matrix memory\n",stderr);
    		return 1;
    	}
    	a.makeIdentity();
    	b.makeIdentity();
    	mult(c,a,b);
    	for(int i=0;i<side;++i){
    		for(int j=0;j<side;++j){
    			printf("%d ",c[i][j]);
    		}
    		putchar('\n');
    	}
    	return 0;
    }
    
  • 相关阅读:
    Golang 接口(interface)
    Golang 结构体(struct)
    Golang fmt包介绍
    Golang的函数(func)
    Golang数据类型 (map)
    Golang 指针(pointer)
    Golang数据类型 切片(slice)
    操作系统学习笔记(五) 页面置换算法
    Python 元组、列表
    操作系统学习笔记(四) 存储模型和虚拟内存
  • 原文地址:https://www.cnblogs.com/tmzbot/p/4924757.html
Copyright © 2020-2023  润新知