我们大家都知道xor卷积有个很好的做法:FWT.FWT的变换形式是很好看的
// 说明一下Vector可以向量化运算,也可以当做数组来slice与concat
Vector tf(A,2^n){
Vector A0=A.slice(0,2^n/2-1);
Vector A1=A.slice(2^n/2,2^n-1);
A0=tf(A0,2^n/2);
A1=tf(A1,2^n/2);
return concat(A0+A1,A0-A1);
}
Array itf(A,2^n){
Vector A0=A.slice(0,2^n/2-1);
Vector A1=A.slice(2^n/2,2^n-1);
return concat(
itf((A0+A1)/2),
itf((A0-A1)/2)
);
}
但是为什么是这个形式呢,这个形式是怎么创造出来的呢..?注意到二进制上的xor相当于modular addition,那么FWT其实是在做一个高维的循环卷积.我们大家都知道可以用FFT来做一维的循环卷积,那么有没有想过怎样把FFT推广到高维呢..?
考虑FFT是什么样的变换:将一个函数用系数表示法转换成点值表示法,转换为哪些点值呢..?是x^n=1的复根.这样的转换具有很好的对称性因而可以极大地加速,同时这个转换本质上是对一个列向量乘上一个矩阵(求值的原理),我们发现它的逆矩阵长得也很规整也可以加速计算.
我们具体地表示出来..
Omega(n,m)=cos(2*pi*m/n)+i*sin(2*pi*m/n)
DFT(n,f)=PolyCoeff(f(Omega(n,i)) for i in range(0,n)])
IDFT(n,f)=PolyCoeff(f(Omega(n,n-i))/n for i in range(0,n)])
# 这种程度的相似性简直可以给满分
# 至于那个时域频域的转换其实两边Omega是共轭的..你可以感受一下那个正弦波的抵消过程..(其实我不是很懂这方面有可能说的有问题)
仿照类似的定义,我们发现它可以推广到高维(突然想到了这个问题2333 https://www.zhihu.com/question/54063300 )
ω(n)=cos(2*pi/n)+i*sin(2*pi/n) # 打Omega真的累..
DFT(d1,d2,...dn,f)=
ret=arraydim(d1,d2,...dn)
for i1,i2,...in in [0..d1-1],[0..d2-1],..[0,dn-1]:#那个in真心尴尬..凑合着看吧..还有那个中括号是总所周知的闭区间
ret[i1,i2...in]=f(ω(d1)^i1,ω(d2)^i2,...ω(dn)^in)
return MPolyCoeff(ret) # M是multivariate..话说为什么要在意这些细节
# IDFT类似的定义下吧,累死了
IDFT=Likewise()
那对于所有d都是二维的情况..
DFT(n,f)=
if(n==1) return f
f0=f.bind(1)
f1=f.bind(-1)
#对n-1个变量的函数求值
F0=MCoeff(DFT(n-1,f0))
F1=MCoeff(DFT(n-1,f1))
return MPolyCoeff([F0,F1])
IDFT=Likewise()
就这么简单..?是的,然而这个递归函数仍然不具备快速计算FWT的能力..我们需要更进一步
以下我们假设我们拥有的函数都是以系数数组的形式存储的(不然写MCoeff和MPolyCoeff好累啊)
我们看看这样的DFT究竟干了什么..
f(1,...)=f[0](...)+f[1](...)
f(-1,...)=f[0](...)+ω(2)*f[1](...)=f[0](...)-f[1](...)
真是尴尬..注意这里就可以递归的做FWT了..
然后IFWT太类似了..没什么说的必要了..
我们可以顺手类比一下,推广到高进制的情况..
f(1,...)=f[0](...)+f[1](...)+...+f[n-1](...)
f(ω(n),...)=f[0](...)+ω(n)f[1](...)+...+ω(n)^(n-1)f[n-1](...)
vdots # <= 乱入的AMSMath
Likewise # <= 乱入的懒得写..
然后我们发现这样每层要消耗d^2*n/d=n*d
的时间来合并..好废时间的感觉..在二维FFT上瞬间就上升到n1.5的复杂度了..总复杂度类似x维n(1+1/x)
然而这个合并本身也是一个FFT..经(bu)过(yong)一(nao)小(zi)点(dou)思(zhi)考(dao)我们可以用本身FFT的方式优化到一层d*n/d*log(d)
即F(n)=dF(n)+n*log(d)=n*log(d)*log_d(n)=n*log(d)*log(n)/log(d)=n*log(n)
然后就意识流的做完了..?(没有什么脑子的东西强行水一篇真浪费时间)