https://www.luogu.org/problemnew/show/P5050
给定多项式A(x),求$A(x_l)$,$A(x_{l+1})$,..,$A(x_r)$
分治:(如果r-l+1=1,直接O(deg(A))暴力求出即可)
首先设$mid=lfloorfrac{l+r}{2} floor$,$P^{[0]}(x)=prod_{i=l}^{mid}(x-x_i)$,$P^{[1]}(x)=prod_{i=mid+1}^{r}(x-x_i)$
以[l,mid]的求值为例:设$A^{[0]}(x)=A(x)\,mod\,P^{[0]}(x)$
即$A(x)=P^{[0]}(x)B^{[0]}(x)+A^{[0]}(x)$($B^{[0]}$为某个多项式)
可以发现,将$x_l$,$x_{l+1}$,..,$x_{mid}$带入$P^{[0]}(x)$,值都为0
因此对于$l<=i<=mid$,$A(x_i)=A^{[0]}(x_i)$,递归下去算就行;[mid+1,r]的求值同理
这个P可以在分治过程中处理出来
时间复杂度大概是$O(n\,log^2\,n)$(未区分n=r-l+1,m=deg(A))
版本1:基于版本1,加了小范围暴力,预处理了P方便快速插值
注意:这种分治FFT的题,NTT里面wn需要预处理否则可能慢很多(听说多一个log,没有仔细分析)!
1 #prag 2 ma GCC optimize(2) 3 #include<cstdio> 4 #include<algorithm> 5 #include<cstring> 6 #include<vector> 7 #include<cmath> 8 using namespace std; 9 #define fi first 10 #define se second 11 #define mp make_pair 12 #define pb push_back 13 typedef long long ll; 14 typedef unsigned long long ull; 15 const int md=998244353; 16 const int N=131072; 17 #define delto(a,b) ((a)-=(b),((a)<0)&&((a)+=md)) 18 inline int del(int a,int b) 19 { 20 a-=b; 21 return a<0?a+md:a; 22 } 23 int rev[N]; 24 void init(int len) 25 { 26 int bit=0,i; 27 while((1<<(bit+1))<=len) ++bit; 28 for(i=1;i<len;++i) 29 rev[i]=(rev[i>>1]>>1)|((i&1)<<(bit-1)); 30 } 31 ull poww(ull a,ull b) 32 { 33 ull ans=1; 34 for(;b;b>>=1,a=a*a%md) 35 if(b&1) 36 ans=ans*a%md; 37 return ans; 38 } 39 int inv[300011],pw_1[300011],pw_2[300011]; 40 void dft(int *a,int len,int idx)//要求len为2的幂 41 { 42 int i,j,k,t1,t2;ull wn,wnk; 43 for(i=0;i<len;++i) 44 if(i<rev[i]) 45 swap(a[i],a[rev[i]]); 46 for(i=1;i<len;i<<=1) 47 { 48 wn=idx==1?pw_1[i]:pw_2[i]; 49 //wn=poww(idx==1?3:332748118,(md-1)/(i<<1)); 50 for(j=0;j<len;j+=(i<<1)) 51 { 52 wnk=1; 53 for(k=j;k<j+i;++k,wnk=wnk*wn%md) 54 { 55 t1=a[k];t2=a[k+i]*wnk%md; 56 a[k]+=t2; 57 (a[k]>=md)&&(a[k]-=md); 58 a[k+i]=t1-t2; 59 (a[k+i]<0)&&(a[k+i]+=md); 60 } 61 } 62 } 63 if(idx==-1) 64 { 65 ull ilen=inv[len]; 66 for(i=0;i<len;++i) 67 a[i]=a[i]*ilen%md; 68 } 69 } 70 void p_inv(int *f,int *g,int len)//g=f^(-1);f,g数组的长度不小于2len(需要足够长用于临时存放元素);要求len是2的幂 71 { 72 static int t1[N],t2[N]; 73 g[0]=poww(f[0],md-2); 74 for(int i=2,j;i<=len;i<<=1) 75 { 76 memcpy(t1,f,sizeof(int)*i); 77 memcpy(t2,g,sizeof(int)*(i>>1)); 78 memset(t2+(i>>1),0,sizeof(int)*(i>>1)); 79 init(i); 80 dft(t1,i,1);dft(t2,i,1); 81 for(j=0;j<i;++j) 82 t1[j]=ull(t1[j])*t2[j]%md; 83 dft(t1,i,-1); 84 for(j=0;j<(i>>1);++j) 85 t1[j]=t1[j+(i>>1)]; 86 memset(t1+(i>>1),0,sizeof(int)*(i>>1)); 87 dft(t1,i,1); 88 for(j=0;j<i;++j) 89 t1[j]=ull(t1[j])*t2[j]%md; 90 dft(t1,i,-1); 91 for(j=i>>1;j<i;++j) 92 g[j]=md-t1[j-(i>>1)]; 93 } 94 } 95 inline void p_de(int *f,int len)//derivative求导;f=f' 96 { 97 for(int i=0;i<len-1;++i) 98 f[i]=ull(i+1)*f[i+1]%md; 99 f[len-1]=0; 100 } 101 inline void p_in(int *f,int len)//integral积分;f=?f 102 { 103 for(int i=len-1;i>=1;--i) 104 f[i]=ull(f[i-1])*inv[i]%md; 105 f[0]=0; 106 } 107 void p_ln(int *f,int len)//要求len为2的幂,f[0]=1 108 { 109 static int t3[N]; 110 p_inv(f,t3,len);p_de(f,len); 111 init(len<<1); 112 dft(f,len<<1,1);dft(t3,len<<1,1); 113 for(int i=0;i<(len<<1);++i) 114 f[i]=ull(f[i])*t3[i]%md; 115 dft(f,len<<1,-1);p_in(f,len); 116 } 117 void p_exp(int *f,int *g,int len)//要求len为2的幂,f[0]=0 118 { 119 static int t1[N],t2[N]; 120 g[0]=1; 121 for(int i=2,j;i<=len;i<<=1) 122 { 123 memcpy(t1,g,sizeof(int)*(i>>1)); 124 memset(t1+(i>>1),0,sizeof(int)*(i>>1)); 125 p_ln(t1,i); 126 for(j=0;j<(i>>1);++j) 127 t1[j]=del(f[j+(i>>1)],t1[j+(i>>1)]); 128 memset(t1+(i>>1),0,sizeof(int)*(i>>1)); 129 init(i); 130 dft(t1,i,1); 131 memcpy(t2,g,sizeof(int)*(i>>1)); 132 memset(t2+(i>>1),0,sizeof(int)*(i>>1)); 133 dft(t2,i,1); 134 for(j=0;j<i;++j) 135 t1[j]=ull(t1[j])*t2[j]%md; 136 dft(t1,i,-1); 137 for(j=i>>1;j<i;++j) 138 g[j]=t1[j-(i>>1)]; 139 } 140 } 141 void p_div(int *a,int *b,int *c,int n,int m)//c=a/b;deg(a)=n,deg(b)=m,deg(c)=n-m;a,b无前导0;n>=m 142 { 143 reverse(a,a+n+1);reverse(b,b+m+1); 144 int x=n-m+1,t=1; 145 for(;t<x;t<<=1); 146 memset(b+m+1,0,sizeof(int)*max(t-m-1,0)); 147 p_inv(b,c,t); 148 memset(c+x,0,sizeof(int)*((t<<1)-x)); 149 memset(a+x,0,sizeof(int)*((t<<1)-x)); 150 init(t<<1); 151 dft(a,t<<1,1);dft(c,t<<1,1); 152 for(int i=0;i<(t<<1);++i) 153 c[i]=ull(c[i])*a[i]%md; 154 dft(c,t<<1,-1); 155 memset(c+(n-m+1),0,sizeof(int)*((t<<1)-n+m-1)); 156 reverse(c,c+x); 157 } 158 void p_divmod(int *a,int *b,int *c,int *d,int n,int m)//c=a/b,d=a%b,deg(d)=(<=)m-1;其余同上 159 { 160 static int t1[N]; 161 memcpy(d,a,sizeof(int)*(m+1)); 162 int x=n+1,t=1; 163 for(;t<x;t<<=1); 164 memcpy(t1,b,sizeof(int)*(m+1)); 165 memset(t1+m+1,0,sizeof(int)*max(t-m-1,0)); 166 p_div(a,b,c,n,m); 167 memcpy(a,c,sizeof(int)*(n-m+1)); 168 memset(a+n-m+1,0,sizeof(int)*(t-n+m-1)); 169 init(t); 170 dft(a,t,1);dft(t1,t,1); 171 for(int i=0;i<t;++i) 172 t1[i]=ull(t1[i])*a[i]%md; 173 dft(t1,t,-1); 174 for(int i=0;i<=m;++i) 175 delto(d[i],t1[i]); 176 } 177 namespace P_me 178 { 179 int *ta[N];//用线段树的方法给递归的每一层一个编号,ta[i]表示编号为i的层的P函数的各项系数 180 int data[N*40],*tp;//内存池 181 int *a,*x,*y; 182 #define LC (u<<1) 183 #define RC (u<<1|1) 184 int mt1[N]; 185 const int T=200;//小范围暴力阀值 186 void _p_me1(int l,int r,int u)//计算(x-x_l)(x-x_{l+1})..(x-x_r)并存下来 187 { 188 if(r-l<=T) 189 { 190 int i,j; 191 tp[0]=1; 192 for(i=l;i<=r;++i) 193 { 194 tp[i-l+1]=tp[i-l]; 195 for(j=i-l;j>=1;--j) 196 { 197 tp[j]=(ull(tp[j])*(md-x[i])+tp[j-1])%md; 198 } 199 tp[0]=ull(tp[0])*(md-x[i])%md; 200 } 201 ta[u]=tp;tp+=r-l+2; 202 return; 203 } 204 int mid=(l+r)>>1; 205 _p_me1(l,mid,LC);_p_me1(mid+1,r,RC); 206 int x=r-l+2,t=1;//x=(mid-l+1)+(r-mid)+1 207 for(;t<x;t<<=1); 208 memcpy(mt1,ta[LC],sizeof(int)*(mid-l+2)); 209 memset(mt1+mid-l+2,0,sizeof(int)*(t-mid+l-2)); 210 memcpy(tp,ta[RC],sizeof(int)*(r-mid+1)); 211 memset(tp+r-mid+1,0,sizeof(int)*(t-r+mid-1)); 212 init(t); 213 dft(mt1,t,1);dft(tp,t,1); 214 for(int i=0;i<t;++i) 215 tp[i]=ull(tp[i])*mt1[i]%md; 216 dft(tp,t,-1); 217 ta[u]=tp;tp+=r-l+2; 218 } 219 int mt2[N],mt3[N]; 220 void _p_me2(int *a,int n,int l,int r,int u)//a是A的系数,deg(A)<=n;求A(x_l)到A(x_r),放入y_l到y_r 221 { 222 if(r-l<=T) 223 { 224 int t,i,j; 225 for(i=l;i<=r;++i) 226 { 227 t=a[n]; 228 for(j=n-1;j>=0;--j) 229 t=(ull(t)*x[i]+a[j])%md; 230 y[i]=t; 231 } 232 return; 233 } 234 int x=(n+1)<<1,t=1; 235 for(;t<x;t<<=1); 236 int mt4[t];//根据需要改成new? 237 int mid=(l+r)>>1,n1; 238 memcpy(mt1,a,sizeof(int)*(n+1)); 239 for(n1=n;n1>=0 && mt1[n1]==0;) --n1; 240 if(n1<0) 241 { 242 memset(y+l,0,sizeof(int)*(r-l+1)); 243 return; 244 } 245 memcpy(mt2,ta[LC],sizeof(int)*(mid-l+2)); 246 if(n1<mid-l+1) 247 { 248 memcpy(mt4,mt1,sizeof(int)*(n1+1)); 249 _p_me2(mt4,n1,l,mid,LC); 250 } 251 else 252 { 253 p_divmod(mt1,mt2,mt3,mt4,n1,mid-l+1); 254 _p_me2(mt4,mid-l,l,mid,LC); 255 } 256 memcpy(mt1,a,sizeof(int)*(n+1)); 257 for(n1=n;n1>=0 && mt1[n1]==0;) --n1; 258 memcpy(mt2,ta[RC],sizeof(int)*(r-mid+1)); 259 if(n1<r-mid) 260 { 261 memcpy(mt4,mt1,sizeof(int)*(n1+1)); 262 _p_me2(mt4,n1,mid+1,r,RC); 263 } 264 else 265 { 266 p_divmod(mt1,mt2,mt3,mt4,n1,r-mid); 267 _p_me2(mt4,r-mid-1,mid+1,r,RC); 268 } 269 } 270 void p_multieval(int *a0,int *x0,int *y0,int n,int m)//deg(a)=n,x有m个数 271 { 272 tp=data; 273 a=a0;x=x0;y=y0; 274 _p_me1(0,m-1,1); 275 _p_me2(a,n,0,m-1,1); 276 } 277 } 278 using P_me::p_multieval; 279 int a[N],x[N],y[N]; 280 int n,m; 281 int main() 282 { 283 int i; 284 inv[1]=1; 285 for(i=2;i<=300000;++i) 286 inv[i]=ull(md-md/i)*inv[md%i]%md; 287 for(i=1;i<300000;i<<=1) 288 { 289 pw_1[i]=poww(3,(md-1)/(i<<1)); 290 pw_2[i]=poww(332748118,(md-1)/(i<<1)); 291 } 292 //n=100000;m=100000; 293 scanf("%d%d",&n,&m); 294 for(i=0;i<=n;++i) 295 //a[i]=rand()%md; 296 scanf("%d",a+i); 297 for(i=0;i<m;++i) 298 //x[i]=rand()%md; 299 scanf("%d",x+i); 300 p_multieval(a,x,y,n,m); 301 for(i=0;i<m;++i) 302 printf("%d ",y[i]); 303 return 0; 304 }