• FWT 学习总结


    我理解的FWT是在二元运算意义下的卷积

    目前比较熟练掌握的集合对称差卷积

    对于子集卷积和集合并卷积掌握不是很熟练(挖坑ing)

    那么就先来谈一谈集合对称差卷积吧

    所谓集合对称差卷积

    就是h(i)=sigma(g(j)*f(k))(j^k=i)

    首先一个很显然的事情是如下结论:

    证明就是如果S是空集,答案为1,否则设存在元素v,则(S交T)和(S交T^v)两两相消配对

    答案为0

    由于j^k=i,则一定存在j^k^i=0,所以我们可以用上面的式子化简卷积

    式子的化简显然是正确的,就是将判断符号带入之后利用乘法分配律分配

    这样我们定义:

    显然利用反演我们可以得到

    至于证明,我们将上面的式子带入下面的式子就可以啦

    之后我们再下面的式子代入最上面的式子

    可以得到

    这也就意味着如果我们可以完成快速沃尔什变换和逆变换

    我们就可以在O(n)的时间内把他们乘起来

    至于变换的过程,我们利用递推的思想可以得到:

    说的简单的一点就是我们FWT的时候利用倍增思想

    设左边得到的集合幂级数为tf0

    设右边得到的集合幂级数为tf1

    合并之后得到的集合幂级数为(tf0+tf1,tf0-tf1)

    至于逆变换

    得到的集合幂级数为((tf0+tf1)/2,(tf0-tf1)/2)

    由于有除法,所以FWT不能对特征为2的集合幂级数使用

    以上图片截图自vfk论文

    然后一道众人皆知的例题是SRM 518 Nim

    在[1,L]中可以重复的选取k个质数,问有多少种方案使得他们异或和为0

    我们可以构造集合幂级数,之后我们会发现选取两个我们只需要FWT一次就可以了

    选取k个我们就可以使用快速幂了,注意这里并不是对FWT做快速幂

    由于集合幂级数并没有次数界的问题,所以可以直接对每个点做快速幂

    之后再FWT还原即可

    #include<cstdio>
    #include<cstring>
    #include<iostream>
    #include<algorithm>
    #include<cstdlib>
    using namespace std;
    
    typedef long long LL;
    const int maxn=200010;
    const int mod=1e9+7;
    int k,L,m,inv;
    bool vis[maxn];
    int p[maxn],cnt=0;
    int f[maxn];
    
    int pow_mod(int v,int p){
    	int tmp=1;
    	while(p){
    		if(p&1)tmp=1LL*tmp*v%mod;
    		v=1LL*v*v%mod;p>>=1;
    	}return tmp;
    }
    void Get_prime(){
    	for(int i=2;i<=L;++i){
    		if(!vis[i])p[++cnt]=i;
    		for(int j=1;j<=cnt;++j){
    			if(1LL*p[j]*i>L)break;
    			vis[i*p[j]]=true;
    			if(i%p[j]==0)break;
    		}
    	}return;
    }
    void FWT(int *A,int n,int flag){
    	for(int k=1;k<n;k<<=1){
    		int mk=(k<<1);
    		for(int j=0;j<n;j+=mk){
    			for(int i=0;i<k;++i){
    				int x=A[i+j],y=A[i+j+k];
    				A[i+j]=1LL*(x+y)*flag%mod;
    				A[i+j+k]=1LL*(x-y+mod)*flag%mod;
    			}
    		}
    	}return;
    }
    
    int main(){
    	freopen("srmnim.in","r",stdin);
    	freopen("srmnim.out","w",stdout);
    	scanf("%d%d",&k,&L);
    	Get_prime();inv=pow_mod(2,mod-2);
    	for(m=1;m<=L;m<<=1);
    	for(int i=1;i<=cnt;++i)f[p[i]]=1;
    	FWT(f,m,1);
    	for(int i=0;i<m;++i)f[i]=pow_mod(f[i],k);
    	FWT(f,m,inv);printf("%d
    ",f[0]);
    	return 0;
    }
    

    之后是2015年ACM北京赛区的题目

    The Celebration of Rabbits

    要求你先给每个数赋值在[0,m]范围内,之后再给每个数都加上同一个x,x在[L,R]范围内

    求有多少种方案使得最后的数的异或和为0(数的数量为奇数)

    首先我们证明一件事情:对于每一种方案中的x都是唯一的

    采用反证法,设不是唯一的

    设现在序列为a1,a2,a3……,且满足a1^a2^a3……=0

    若不是唯一的,则一定存在序列a1-k,a2-k,a3-k……,且他们异或和为0(k!=0)

    我们考虑最后一位,由于a1^a2^a3……=0,又因为数的数量为奇数

    所以末位为1的有偶数个,末位为0的有奇数个

    若k的末位为1,则序列的异或和不为0(因为会转换成末位为1的奇数个,末位为0的有偶数个)

    所以k的末位为0,对上一位不会有进位影响,同理我们可以用同样方法考虑上一位

    最终得到k=0,与假设不符,证毕

    既然x都是唯一的,那么我们可以枚举x,之后问题转换成在[x,x+m]中可以重复的取若干个数使得异或和为0

    FWT即可

    #include<cstdio>
    #include<cstring>
    #include<iostream>
    #include<cstdlib>
    #include<algorithm>
    using namespace std;
    
    typedef long long LL;
    const int mod=1e9+7;
    int n,m,L,R;
    int ans=0,N,inv;
    int f[20010];
    int pow_mod(int v,int p){
    	int tmp=1;
    	while(p){
    		if(p&1)tmp=1LL*tmp*v%mod;
    		v=1LL*v*v%mod;p>>=1;
    	}return tmp;
    }
    void FWT(int *A,int n,int flag){
    	for(int k=1;k<n;k<<=1){
    		int mk=(k<<1);
    		for(int j=0;j<n;j+=mk){
    			for(int i=0;i<k;++i){
    				int x=A[i+j],y=A[i+j+k];
    				A[i+j]=1LL*(x+y)*flag%mod;
    				A[i+j+k]=1LL*(x-y+mod)*flag%mod;
    			}
    		}
    	}return;
    }
    
    int main(){
    	inv=pow_mod(2,mod-2);
    	while(scanf("%d%d%d%d",&n,&m,&L,&R)==4){
    		ans=0;
    		for(int i=L;i<=R;++i){
    			for(N=1;N<=i+m;N<<=1);
    			for(int j=0;j<N;++j){
    				if(j>=i&&j<=i+m)f[j]=1;
    				else f[j]=0;
    			}
    			FWT(f,N,1);
    			for(int j=0;j<N;++j)f[j]=pow_mod(f[j],(n<<1)+1);
    			FWT(f,N,inv);
    			ans=ans+f[0];
    			if(ans>=mod)ans-=mod;
    		}printf("%d
    ",ans);
    	}return 0;
    }
    

    codeforces 259 div1 D

    有2^m个点,每个点上有信息,每个时刻每个点会分别向跟他海明码距离为k的点传b(k)次他自己的信息

    问t时刻后每个点的信息量

    考试题的加强版本,而且数据范围非常全

    首先模数的问题可以同考试题的做法解决

    首先做法1,我们考虑每个时刻的传输并用集合幂级数表示

    设集合幂级数f(i)表示输入数组

    设集合幂级数g(j),满足g(j)=b(num(j)) num(j)即j的二进制表示中1的个数

    则下一时刻可得h(k)=sigma(f(i)*g(j))(i^j=k)

    之后t时刻的话快速幂即可,由于要写快速乘,时间复杂度O(nlog^2n)

    用一些奇技淫巧优化快速乘做到O(nlogn)

    #include<cstdio>
    #include<cstring>
    #include<cstdlib>
    #include<iostream>
    #include<algorithm>
    using namespace std;
    
    typedef long long LL;
    int n,m;
    int Num[1050010];
    LL t,mod;
    LL b[22];
    LL f[1050010],g[1050010];
    
    inline read(LL &num){
    	num=0;char ch=getchar();
    	while(ch<'!')ch=getchar();
    	while(ch>='0'&&ch<='9')num=num*10+ch-'0',ch=getchar();
    }
    inline LL mul(LL a,LL b) {
        LL tmp=(a*b-(LL)((long double)a/mod*b+1e-8)*mod);
        return tmp<0?tmp+mod:tmp;
    }
    LL pow_mod(LL v,LL p){
    	LL tmp=1;
    	while(p){
    		if(p&1)tmp=mul(tmp,v);
    		v=mul(v,v);p>>=1;
    	}return tmp;
    }
    void FWT(LL *A,int n,int flag){
    	for(int k=1;k<n;k<<=1){
    		int mk=(k<<1);
    		for(int j=0;j<n;j+=mk){
    			for(int i=0;i<k;++i){
    				LL x=A[i+j],y=A[i+j+k];
    				A[i+j]=(x+y)>>flag;
    				A[i+j+k]=(x-y+mod)>>flag;
    				if(A[i+j]>=mod)A[i+j]-=mod;
    				if(A[i+j+k]>=mod)A[i+j+k]-=mod;
    			}
    		}
    	}return;
    }
    
    int main(){
    	scanf("%d",&m);read(t);read(mod);
    	n=(1<<m);mod*=n;
    	for(int i=0;i<n;++i)read(f[i]),f[i]%=mod;
    	for(int i=0;i<=m;++i)read(b[i]),b[i]%=mod;
    	for(int i=0;i<n;++i){
    		Num[i]=Num[i>>1]+(i&1);
    		g[i]=b[Num[i]];
    	}
    	FWT(f,n,0);FWT(g,n,0);
    	for(int i=0;i<n;++i){
    		g[i]=pow_mod(g[i],t);
    		f[i]=mul(f[i],g[i]);
    	}
    	FWT(f,n,1);mod/=n;
    	for(int i=0;i<n;++i)printf("%I64d
    ",f[i]%mod);
    	return 0;
    }
    

    我们还有更优的做法

    由于每个点的贡献是可分的,我们不妨考虑每个点t时刻内的贡献系数

    同考试题可以证明根据海明码距离我们可以分出m+1个等价类

    问题在于我们要搞出单次转移的系数,之后矩阵乘法就可以了

    我们考虑海明码距离j对海明码距离i的转移,显然i中对海明码距离贡献为1的点有i个,贡献为0的点有m-i个

    设k=i-j,我们选取贡献为1的点x个并取反,选取贡献为0的点y个并取反,一定满足y=x-k

    这样操作后两个点的海明码距离为x+y

    系数的贡献显然就是b(x+y)*C(i,x)*C(m-i,y)

    其中C是组合数,我们枚举x算贡献就可以了

    我们矩阵乘法搞出系数来了之后同考试题的做法一样做FWT就可以了

    时间复杂度O(m^3logn+nlogn)

    #include<cstdio>
    #include<cstring>
    #include<iostream>
    #include<algorithm>
    #include<cstdlib>
    using namespace std;
    
    typedef long long LL;
    int m,n;
    int Num[1050010];
    LL t,mod;
    LL C[22][22];
    LL f[1050010],g[1050010];
    LL b[22];
    struct Matrix{
    	LL a[22][22];
    	void clear(){memset(a,0,sizeof(a));}
    }A,ans;
    
    void pre_C(){
    	C[0][0]=1;
    	for(int i=1;i<=m;++i){
    		C[i][0]=C[i][i]=1;
    		for(int j=1;j<i;++j){
    			C[i][j]=C[i-1][j-1]+C[i-1][j];
    			if(C[i][j]>=mod)C[i][j]-=mod;
    		}
    	}return;
    }
    LL mul(LL a,LL b){
    	LL s=0;
    	while(b){
    		if(b&1){
    			s=s+a;
    			if(s>=mod)s-=mod;
    		}
    		a<<=1;
    		if(a>=mod)a-=mod;
    		b>>=1;
    	}return s;
    }
    Matrix operator *(const Matrix &A,const Matrix &B){
    	Matrix C;C.clear();
    	for(int i=0;i<=m;++i){
    		for(int j=0;j<=m;++j){
    			for(int k=0;k<=m;++k){
    				C.a[i][j]=C.a[i][j]+mul(A.a[i][k],B.a[k][j]);
    				if(C.a[i][j]>=mod)C.a[i][j]-=mod;
    			}
    		}
    	}return C;
    }
    Matrix pow_mod(Matrix v,LL p){
    	Matrix tmp;tmp.clear();
    	for(int i=0;i<=m;++i)tmp.a[i][i]=1;
    	while(p){
    		if(p&1)tmp=tmp*v;
    		v=v*v;p>>=1;
    	}return tmp;
    }
    void FWT(LL *A,int n,int flag){
    	for(int k=1;k<n;k<<=1){
    		int mk=(k<<1);
    		for(int j=0;j<n;j+=mk){
    			for(int i=0;i<k;++i){
    				LL x=A[i+j],y=A[i+j+k];
    				A[i+j]=(x+y)>>flag;
    				A[i+j+k]=(x-y+mod)>>flag;
    				if(A[i+j]>=mod)A[i+j]-=mod;
    				if(A[i+j+k]>=mod)A[i+j+k]-=mod;
    			}
    		}
    	}return;
    }
    
    int main(){
    	scanf("%d",&m);n=(1<<m);
    	scanf("%I64d%I64d",&t,&mod);mod*=n;
    	pre_C();
    	for(int i=0;i<n;++i)scanf("%I64d",&f[i]),f[i]%=mod;
    	for(int i=0;i<=m;++i)scanf("%I64d",&b[i]),b[i]%=mod;
    	for(int i=0;i<=m;++i){
    		for(int j=0;j<=m;++j){
    			int k=i-j;
    			for(int x=0;x<=i;++x){
    				int y=x-k;
    				if(y<0||y>m-i)continue;
    				int d=x+y;
    				A.a[j][i]+=mul(b[d],mul(C[i][x],C[m-i][y]));
    				if(A.a[j][i]>=mod)A.a[j][i]-=mod;
    			}
    		}
    	}
    	A=pow_mod(A,t);
    	ans.a[0][0]=1;
    	ans=ans*A;
    	for(int i=0;i<n;++i){
    		Num[i]=Num[i>>1]+(i&1);
    		g[i]=ans.a[0][Num[i]];
    	}
    	FWT(f,n,0);FWT(g,n,0);
    	for(int i=0;i<n;++i)f[i]=mul(f[i],g[i]);
    	FWT(f,n,1);mod/=n;
    	for(int i=0;i<n;++i)printf("%I64d
    ",f[i]%mod);
    	return 0;
    }
    

    总结:如何想到FWT?

    1、题目中的形式是类似卷积一样的东西,朴素做法O(n^2)算贡献

    2、贡献过程是二元运算关系

    3、通常情况下算异或和为0神马的QAQ

  • 相关阅读:
    mysql去重
    java 实现一套流程管理、流转的思路(伪工作流)
    js模块加载框架 sea.js学习笔记
    使用js命名空间进行模块式开发
    二叉树的基本操作实现(数据结构实验)
    学生信息管理系统-顺序表&&链表(数据结构第一次作业)
    计算表达式的值--顺序栈(数据结构第二次实验)
    使用seek()方法报错:“io.UnsupportedOperation: can't do nonzero cur-relative seeks”错误的原因
    seek()方法的使用
    python中如何打印某月日历
  • 原文地址:https://www.cnblogs.com/joyouth/p/5512848.html
Copyright © 2020-2023  润新知