前置知识
(以下内容并不严谨,可以参考论文《转置原理的简单介绍》)
对于一个算法,其为线性算法当且仅当仅包含以下操作:
1.$read i$,将$r_{i}$的值赋为(下一个)读入的元素
2.$write i$,将$r_{i}$的值赋给(下一个)输出的元素
3.$update i j c$,将$r_{i}$的值修改为$r_{i}+ccdot r_{j}$(其中$c$为常数)
其中$r_{i}$表示第$i$个变量,初值均为0
由此,线性算法即可用一个操作序列描述
假设有$n$次$read$操作和$m$次$write$操作,不妨强制这$n$次$read$和$m$次$write$操作分别严格在$update$之前和之后(以下$n,m$意义默认为此意义,且保证此性质)
结论1:对于线性算法,恰存在一个$m imes n$的矩阵$A$——若将读入的$n$个元素和输出的$m$个元素分别构成$n imes 1$和$m imes 1$的矩阵$x$和$y$,则$y=Ax$
归纳每一个变量都是$x_{i}$的线性组合即可
进一步的,称该线性算法为矩阵$A$的线性算法
结论2:对于矩阵$A$的线性算法,将其操作序列翻转并将操作依次变为$write i,read i$和$update j i c$,则得到的线性算法为$A^{T}$的线性算法
操作可以看作不断乘上矩阵$E_{1},E_{2},...,E_{t}$,根据定义$A=prod_{i=1}^{t}E_{i}$
操作翻转和变化即变为乘上$E_{t}^{T},E_{t-1}^{T},...,E_{1}^{T}$,而注意到$A^{T}=prod_{i=t}^{1}E_{i}^{T}$,也即成立
题解
令$f_{k}(x)=prod_{i=1}^{k}(x+b_{i})$,则问题即求$sum_{i=1}^{n}c_{i}f_{k}(a_{i})$
将其以多项式的形式展开,即$sum_{i=1}^{n}c_{i}sum_{j=0}^{n}[x^{j}]f_{k}(x)cdot a_{i}^{j}$
调换枚举顺序,即$sum_{j=0}^{n}[x^{j}]f_{k}(x)sum_{i=1}^{n}c_{i}a_{i}^{j}$
记后者为$S_{j}$,注意到其可以看作$sum_{i=1}^{n}c_{i}cdot [x^{j}]frac{1}{1-a_{i}x}=[x^{j}]sum_{i=1}^{n}frac{c_{i}}{1-a_{i}x}$,那么考虑求该多项式,可以对其分治计算,并以分式的形式存储即可(最后再求逆展开)
由此,问题即求$sum_{i=0}^{n}S_{i}cdot [x^{i}]f_{k}(x)$
构造矩阵$A_{i,k}=[x^{i}]f_{k}(x)$,并将$S$看作输入矩阵(预处理过程显然独立),那么即需要实现$A$的线性算法,根据转置原理(结论2)不妨去实现$A^{T}$的线性算法
关于$A^{T}$的线性算法,假设读入为$Q$,代入式子即求$sum_{i=1}^{n}Q_{i}A_{k,i}=[x^{k}]sum_{i=1}^{n}Q_{i}prod_{j=1}^{i}(x+b_{j})$
分治,并求出$H(x)=prod_{j=l}^{r}(x+b_{j})$和$F(x)=sum_{i=l}^{r}Q_{i}prod_{j=l}^{i}(x+b_{j})$,则转移如下
$$
egin{cases}H(x)=H_{l}(x)H_{r}(x)\F(x)=F_{l}(x)+H_{l}(x)F_{r}(x)end{cases}
$$
进而$A$的线性算法即将过程"转置",具体做法如下——
$H(x)$与输入$S$无关,因此可以先预处理出来
$A^{T}$的线性算法中,最后执行的是输出$[x^{i}]F(x)$,那么即需要读入$[x^{i}]F(x)$(注意要求$AS$,即读入为$S$)
分治的形式是从底向上做,那么反过来即要自顶向下做(先执行操作再递归)
下面,来考虑操作的翻转,从后往前依次考虑操作:
1.$[x^{i}]F(x)=[x^{i}]F_{l}(x)+[x^{i}]G(x)$(其中$G(x)=H_{l}(x)F_{r}(x)$),将其表示为形如$update i j c$的操作,再变换后也即$[x^{i}]F_{l}(x)=[x^{i}]G(x)=[x^{i}]F(x)$
2.$ntt(G,-1),[x^{i}]G(x)=[x^{i}]H_{l}(x)cdot [x^{i}]F_{r}(x),ntt(F_{r},1)$
(注意这只是将整体的顺序调过来,内部并没有翻转)
关于$ntt$内部如何转置,注意到$ntt(a,p)$可以看作将$a$乘上矩阵$A_{i,j}=omega^{pij}$,联系结论2的证明即需要将该矩阵转置,而显然转置后仍为自身,因此不需要变化
(但注意$ntt(*,-1)$中的除法虽然不在矩阵乘法的范围内,但不取出处理也是正确的)
接下来即仅有第2步,注意到其中$[x^{i}]H_{l}(x)$看作是常数,因此需要先对其ntt(可以看作预处理),因此也即变为$[x^{i}]F_{r}(x)=[x^{i}]H_{l}(x)cdot [x^{i}]G(x)$
$A^{T}$的线性算法中,最先执行的是将读入$[x^{1}]F(x)=S_{i}$和$[x^{0}]F(x)=b_{i}cdot [x^{1}]F(x)$,那么将其转置后即需输出$b_{i}cdot [x^{0}]F(x)+[x^{1}]F(x)$
时间复杂度为$o(nlog^{2}n)$,可以通过
1 #include<bits/stdc++.h> 2 using namespace std; 3 #define N 250005 4 #define mod 998244353 5 #define ll long long 6 #define vi vector<int> 7 #define L (k<<1) 8 #define R (L+1) 9 #define mid (l+r>>1) 10 int n,a[N],b[N],c[N],rev[1<<20]; 11 vi A[N<<2],B[N<<2],F[N<<2]; 12 int Log(int m){ 13 int n=1; 14 while (n<m)n<<=1; 15 return n; 16 } 17 int qpow(int n,int m){ 18 int s=n,ans=1; 19 while (m){ 20 if (m&1)ans=(ll)ans*s%mod; 21 s=(ll)s*s%mod; 22 m>>=1; 23 } 24 return ans; 25 } 26 void init(int n){ 27 for(int i=0;i<n;i++)rev[i]=(rev[i>>1]>>1)+((i&1)*(n>>1)); 28 } 29 void ntt(vi &a,int n,int p=0){ 30 a.resize(n); 31 for(int i=0;i<n;i++) 32 if (i<rev[i])swap(a[i],a[rev[i]]); 33 for(int i=2;i<=n;i<<=1){ 34 int s=qpow(3,(mod-1)/i); 35 if (p)s=qpow(s,mod-2); 36 for(int j=0;j<n;j+=i) 37 for(int k=0,ss=1;k<(i>>1);k++,ss=(ll)ss*s%mod){ 38 int x=a[j+k],y=(ll)ss*a[j+k+(i>>1)]%mod; 39 a[j+k]=(x+y)%mod,a[j+k+(i>>1)]=(x-y+mod)%mod; 40 } 41 } 42 if (p){ 43 int s=qpow(n,mod-2); 44 for(int i=0;i<n;i++)a[i]=(ll)a[i]*s%mod; 45 } 46 } 47 vi mul(vi a,vi b,int m){ 48 int n=Log(m<<1); 49 if (m<0)m=a.size()+b.size()-1,n=Log(m); 50 init(n); 51 for(int i=m;i<a.size();i++)a[i]=0; 52 for(int i=m;i<b.size();i++)b[i]=0; 53 ntt(a,n),ntt(b,n); 54 vi ans; 55 for(int i=0;i<n;i++)ans.push_back((ll)a[i]*b[i]%mod); 56 ntt(ans,n,1); 57 for(int i=m;i<n;i++)ans.pop_back(); 58 return ans; 59 } 60 vi inv(vi a,int n){ 61 vi ans; 62 if (n==1){ 63 ans.push_back(qpow(a[0],mod-2)); 64 return ans; 65 } 66 vi s=inv(a,(n>>1)); 67 ans=mul(a,s,n); 68 for(int i=0;i<n;i++)ans[i]=mod-ans[i]; 69 ans[0]+=2; 70 return mul(s,ans,n); 71 } 72 void get_S(int k,int l,int r){ 73 if (l==r){ 74 A[k].push_back(c[l]); 75 B[k].push_back(1),B[k].push_back(mod-a[l]); 76 return; 77 } 78 get_S(L,l,mid),get_S(R,mid+1,r); 79 int m=B[L].size()+B[R].size()-1,n=Log(m); 80 init(n); 81 ntt(A[L],n),ntt(A[R],n),ntt(B[L],n),ntt(B[R],n); 82 for(int i=0;i<n;i++){ 83 A[k].push_back(((ll)A[L][i]*B[R][i]+(ll)A[R][i]*B[L][i])%mod); 84 B[k].push_back((ll)B[L][i]*B[R][i]%mod); 85 } 86 ntt(A[k],n,1),ntt(B[k],n,1); 87 for(int i=m;i<n;i++){ 88 A[k].pop_back(); 89 B[k].pop_back(); 90 } 91 } 92 void get_H(int k,int l,int r){ 93 if (l==r){ 94 B[k].clear(); 95 B[k].push_back(b[l]),B[k].push_back(1); 96 return; 97 } 98 get_H(L,l,mid),get_H(R,mid+1,r); 99 B[k]=mul(B[L],B[R],-1); 100 } 101 void get_F(int k,int l,int r){ 102 if (l==r){ 103 printf("%d ",((ll)b[l]*F[k][0]+F[k][1])%mod); 104 return; 105 } 106 F[L]=F[R]=F[k]; 107 int m=B[L].size(),n=F[L].size(); 108 for(int i=m;i<n;i++)F[L].pop_back(); 109 m+=B[R].size()-1,n=Log(m); 110 init(n); 111 ntt(B[L],n),ntt(F[R],n,1); 112 for(int i=0;i<n;i++)F[R][i]=(ll)F[R][i]*B[L][i]%mod; 113 ntt(F[R],n); 114 m=B[R].size(); 115 for(int i=m;i<n;i++)F[R].pop_back(); 116 get_F(L,l,mid),get_F(R,mid+1,r); 117 } 118 int main(){ 119 scanf("%d",&n); 120 for(int i=1;i<=n;i++)scanf("%d",&a[i]); 121 for(int i=1;i<=n;i++)scanf("%d",&b[i]); 122 for(int i=1;i<=n;i++)scanf("%d",&c[i]); 123 get_S(1,1,n); 124 A[1]=mul(A[1],inv(B[1],Log(n+1)),n+1); 125 get_H(1,1,n); 126 F[1]=A[1],get_F(1,1,n); 127 return 0; 128 }