https://www.luogu.org/problemnew/show/P5205
按道理说,多项式开根可以有多个解(根据常数项不同有不同的解)。此题只需要输出常数项为1的解(题面漏了)
首先,可以直接多项式快速幂做(2对998244353的逆元)次幂(直接做只能在输入常数项为1时)(我不是很懂为什么能起效,不过的确能AC)
版本1:基于版本1
1 #prag 2 ma GCC optimize(2) 3 #include<cstdio> 4 #include<algorithm> 5 #include<cstring> 6 #include<vector> 7 #include<cmath> 8 using namespace std; 9 #define fi first 10 #define se second 11 #define mp make_pair 12 #define pb push_back 13 typedef long long ll; 14 typedef unsigned long long ull; 15 const int md=998244353; 16 const int N=262144; 17 #define delto(a,b) ((a)-=(b),((a)<0)&&((a)+=md)) 18 inline int del(int a,int b) 19 { 20 a-=b; 21 return a<0?a+md:a; 22 } 23 int rev[N]; 24 void init(int len) 25 { 26 int bit=0,i; 27 while((1<<(bit+1))<=len) ++bit; 28 for(i=1;i<len;++i) 29 rev[i]=(rev[i>>1]>>1)|((i&1)<<(bit-1)); 30 } 31 ull poww(ull a,ull b) 32 { 33 ull ans=1; 34 for(;b;b>>=1,a=a*a%md) 35 if(b&1) 36 ans=ans*a%md; 37 return ans; 38 } 39 int inv[300011]; 40 void dft(int *a,int len,int idx)//要求len为2的幂 41 { 42 int i,j,k,t1,t2;ull wn,wnk; 43 for(i=0;i<len;++i) 44 if(i<rev[i]) 45 swap(a[i],a[rev[i]]); 46 for(i=1;i<len;i<<=1) 47 { 48 wn=poww(idx==1?3:332748118,(md-1)/(i<<1)); 49 for(j=0;j<len;j+=(i<<1)) 50 { 51 wnk=1; 52 for(k=j;k<j+i;++k,wnk=wnk*wn%md) 53 { 54 t1=a[k];t2=a[k+i]*wnk%md; 55 a[k]+=t2; 56 (a[k]>=md)&&(a[k]-=md); 57 a[k+i]=t1-t2; 58 (a[k+i]<0)&&(a[k+i]+=md); 59 } 60 } 61 } 62 if(idx==-1) 63 { 64 ull ilen=inv[len]; 65 for(i=0;i<len;++i) 66 a[i]=a[i]*ilen%md; 67 } 68 } 69 void p_inv(int *f,int *g,int len)//g=f^(-1);f,g数组的长度不小于2len(需要足够长用于临时存放元素);要求len是2的幂 70 { 71 static int t1[N],t2[N]; 72 g[0]=poww(f[0],md-2); 73 for(int i=2,j;i<=len;i<<=1) 74 { 75 memcpy(t1,f,sizeof(int)*i); 76 memcpy(t2,g,sizeof(int)*(i>>1)); 77 memset(t2+(i>>1),0,sizeof(int)*(i>>1)); 78 init(i); 79 dft(t1,i,1);dft(t2,i,1); 80 for(j=0;j<i;++j) 81 t1[j]=ull(t1[j])*t2[j]%md; 82 dft(t1,i,-1); 83 for(j=0;j<(i>>1);++j) 84 t1[j]=t1[j+(i>>1)]; 85 memset(t1+(i>>1),0,sizeof(int)*(i>>1)); 86 dft(t1,i,1); 87 for(j=0;j<i;++j) 88 t1[j]=ull(t1[j])*t2[j]%md; 89 dft(t1,i,-1); 90 for(j=i>>1;j<i;++j) 91 g[j]=md-t1[j-(i>>1)]; 92 } 93 } 94 inline void p_de(int *f,int len)//derivative求导;f=f' 95 { 96 for(int i=0;i<len-1;++i) 97 f[i]=ull(i+1)*f[i+1]%md; 98 f[len-1]=0; 99 } 100 inline void p_in(int *f,int len)//integral积分;f=?f 101 { 102 for(int i=len-1;i>=1;--i) 103 f[i]=ull(f[i-1])*inv[i]%md; 104 f[0]=0; 105 } 106 void p_ln(int *f,int len)//要求len为2的幂,f[0]=1 107 { 108 static int t3[N]; 109 p_inv(f,t3,len);p_de(f,len); 110 init(len<<1); 111 dft(f,len<<1,1);dft(t3,len<<1,1); 112 for(int i=0;i<(len<<1);++i) 113 f[i]=ull(f[i])*t3[i]%md; 114 dft(f,len<<1,-1);p_in(f,len); 115 } 116 void p_exp(int *f,int *g,int len)//要求len为2的幂,f[0]=0 117 { 118 static int t1[N],t2[N]; 119 g[0]=1; 120 for(int i=2,j;i<=len;i<<=1) 121 { 122 memcpy(t1,g,sizeof(int)*(i>>1)); 123 memset(t1+(i>>1),0,sizeof(int)*(i>>1)); 124 p_ln(t1,i); 125 for(j=0;j<(i>>1);++j) 126 t1[j]=del(f[j+(i>>1)],t1[j+(i>>1)]); 127 memset(t1+(i>>1),0,sizeof(int)*(i>>1)); 128 init(i); 129 dft(t1,i,1); 130 memcpy(t2,g,sizeof(int)*(i>>1)); 131 memset(t2+(i>>1),0,sizeof(int)*(i>>1)); 132 dft(t2,i,1); 133 for(j=0;j<i;++j) 134 t1[j]=ull(t1[j])*t2[j]%md; 135 dft(t1,i,-1); 136 for(j=i>>1;j<i;++j) 137 g[j]=t1[j-(i>>1)]; 138 } 139 } 140 inline void p_pow_1(int *f,int *g,int len,int b)//要求len为2的幂,常数项为1 141 { 142 p_ln(f,len); 143 for(int i=0;i<len;++i) 144 f[i]=ull(f[i])*b%md; 145 p_exp(f,g,len); 146 } 147 void p_pow(int *f,int *g,int len,int b)//g=f^b;要求len为2的幂 148 { 149 int i;ll p=-1; 150 for(i=0;i<len;++i) 151 if(f[i]) 152 { 153 p=i; 154 break; 155 } 156 if(p==-1) return; 157 for(i=0;i<len-p;++i) 158 f[i]=f[i+p]; 159 memset(f+len-p,0,sizeof(int)*p); 160 int t=f[0],t1=poww(t,md-2),t2=poww(t,b); 161 for(i=0;i<len;++i) 162 f[i]=ull(f[i])*t1%md; 163 p_pow_1(f,g,len,b); 164 for(i=0;i<len;++i) 165 g[i]=ull(g[i])*t2%md; 166 p*=b; 167 for(i=len-1;i>=p;--i) 168 g[i]=g[i-p]; 169 memset(g,0,sizeof(int)*min(ll(len),p)); 170 } 171 int a[N],b[N]; 172 int n,n1; 173 int main() 174 { 175 int i,t; 176 inv[1]=1; 177 for(i=2;i<=300000;++i) 178 inv[i]=ull(md-md/i)*inv[md%i]%md; 179 scanf("%d",&n);n1=n; 180 for(i=0;i<n;++i) 181 scanf("%d",a+i); 182 for(t=1;t<n;t<<=1); 183 n=t; 184 p_pow(a,b,n,499122177); 185 for(i=0;i<n1;++i) 186 printf("%d ",b[i]); 187 return 0; 188 }
也可以直接牛顿迭代做。设$g(f(x))=f(x)^2-A(x)$
$f(x)=f_0(x)-frac{f_0(x)^2-A(x)}{2f_0(x)}=frac{A(x)}{2f_0(x)}+frac{f_0(x)}{2}$
版本2:基于版本2
1 #prag 2 ma GCC optimize(2) 3 #include<cstdio> 4 #include<algorithm> 5 #include<cstring> 6 #include<vector> 7 #include<cmath> 8 using namespace std; 9 #define fi first 10 #define se second 11 #define mp make_pair 12 #define pb push_back 13 typedef long long ll; 14 typedef unsigned long long ull; 15 const int md=998244353; 16 const int N=262144; 17 #define addto(a,b) ((a)+=(b),((a)>=md)&&((a)-=md)) 18 #define delto(a,b) ((a)-=(b),((a)<0)&&((a)+=md)) 19 inline int del(int a,int b) 20 { 21 a-=b; 22 return a<0?a+md:a; 23 } 24 int rev[N]; 25 void init(int len) 26 { 27 int bit=0,i; 28 while((1<<(bit+1))<=len) ++bit; 29 for(i=1;i<len;++i) 30 rev[i]=(rev[i>>1]>>1)|((i&1)<<(bit-1)); 31 } 32 ull poww(ull a,ull b) 33 { 34 ull ans=1; 35 for(;b;b>>=1,a=a*a%md) 36 if(b&1) 37 ans=ans*a%md; 38 return ans; 39 } 40 int inv[300011]; 41 void dft(int *a,int len,int idx)//要求len为2的幂 42 { 43 int i,j,k,t1,t2;ull wn,wnk; 44 for(i=0;i<len;++i) 45 if(i<rev[i]) 46 swap(a[i],a[rev[i]]); 47 for(i=1;i<len;i<<=1) 48 { 49 wn=poww(idx==1?3:332748118,(md-1)/(i<<1)); 50 for(j=0;j<len;j+=(i<<1)) 51 { 52 wnk=1; 53 for(k=j;k<j+i;++k,wnk=wnk*wn%md) 54 { 55 t1=a[k];t2=a[k+i]*wnk%md; 56 a[k]+=t2; 57 (a[k]>=md)&&(a[k]-=md); 58 a[k+i]=t1-t2; 59 (a[k+i]<0)&&(a[k+i]+=md); 60 } 61 } 62 } 63 if(idx==-1) 64 { 65 ull ilen=inv[len]; 66 for(i=0;i<len;++i) 67 a[i]=a[i]*ilen%md; 68 } 69 } 70 void p_inv(int *f,int *g,int len)//g=f^(-1);f,g数组的长度不小于2len(需要足够长用于临时存放元素);要求len是2的幂 71 { 72 static int t1[N],t2[N]; 73 g[0]=poww(f[0],md-2); 74 for(int i=2,j;i<=len;i<<=1) 75 { 76 memcpy(t1,f,sizeof(int)*i); 77 memcpy(t2,g,sizeof(int)*(i>>1)); 78 memset(t2+(i>>1),0,sizeof(int)*(i>>1)); 79 init(i); 80 dft(t1,i,1);dft(t2,i,1); 81 for(j=0;j<i;++j) 82 t1[j]=ull(t1[j])*t2[j]%md; 83 dft(t1,i,-1); 84 for(j=0;j<(i>>1);++j) 85 t1[j]=t1[j+(i>>1)]; 86 memset(t1+(i>>1),0,sizeof(int)*(i>>1)); 87 dft(t1,i,1); 88 for(j=0;j<i;++j) 89 t1[j]=ull(t1[j])*t2[j]%md; 90 dft(t1,i,-1); 91 for(j=i>>1;j<i;++j) 92 g[j]=md-t1[j-(i>>1)]; 93 } 94 } 95 inline void p_de(int *f,int len)//derivative求导;f=f' 96 { 97 for(int i=0;i<len-1;++i) 98 f[i]=ull(i+1)*f[i+1]%md; 99 f[len-1]=0; 100 } 101 inline void p_in(int *f,int len)//integral积分;f=?f 102 { 103 for(int i=len-1;i>=1;--i) 104 f[i]=ull(f[i-1])*inv[i]%md; 105 f[0]=0; 106 } 107 void p_ln(int *f,int len)//要求len为2的幂,f[0]=1 108 { 109 static int t3[N]; 110 p_inv(f,t3,len);p_de(f,len); 111 init(len<<1); 112 dft(f,len<<1,1);dft(t3,len<<1,1); 113 for(int i=0;i<(len<<1);++i) 114 f[i]=ull(f[i])*t3[i]%md; 115 dft(f,len<<1,-1);p_in(f,len); 116 } 117 void p_exp(int *f,int *g,int len)//要求len为2的幂,f[0]=0 118 { 119 static int t1[N],t2[N]; 120 g[0]=1; 121 for(int i=2,j;i<=len;i<<=1) 122 { 123 memcpy(t1,g,sizeof(int)*(i>>1)); 124 memset(t1+(i>>1),0,sizeof(int)*(i>>1)); 125 p_ln(t1,i); 126 for(j=0;j<(i>>1);++j) 127 t1[j]=del(f[j+(i>>1)],t1[j+(i>>1)]); 128 memset(t1+(i>>1),0,sizeof(int)*(i>>1)); 129 init(i); 130 dft(t1,i,1); 131 memcpy(t2,g,sizeof(int)*(i>>1)); 132 memset(t2+(i>>1),0,sizeof(int)*(i>>1)); 133 dft(t2,i,1); 134 for(j=0;j<i;++j) 135 t1[j]=ull(t1[j])*t2[j]%md; 136 dft(t1,i,-1); 137 for(j=i>>1;j<i;++j) 138 g[j]=t1[j-(i>>1)]; 139 } 140 } 141 void p_sqrt(int *f,int *g,int len)//g=sqrt(f);要求len为2的幂,f[0]=1 142 { 143 static int t1[N],t2[N]; 144 g[0]=1; 145 for(int i=2,j;i<=len;i<<=1) 146 { 147 memcpy(t1,g,sizeof(int)*(i>>1)); 148 memset(t1+(i>>1),0,sizeof(int)*(i>>1)); 149 for(j=0;j<i;++j) 150 addto(t1[j],t1[j]); 151 p_inv(t1,t2,i); 152 memset(t2+i,0,sizeof(int)*i); 153 memcpy(t1,f,sizeof(int)*i); 154 memset(t1+i,0,sizeof(int)*i); 155 init(i<<1); 156 dft(t1,i<<1,1);dft(t2,i<<1,1); 157 for(j=0;j<(i<<1);++j) 158 t1[j]=ull(t1[j])*t2[j]%md; 159 dft(t1,i<<1,-1); 160 for(j=0;j<(i>>1);++j) 161 g[j]=(ull(g[j])*499122177+t1[j])%md; 162 memcpy(g+(i>>1),t1+(i>>1),sizeof(int)*(i>>1)); 163 } 164 } 165 int a[N],b[N]; 166 int n,n1; 167 int main() 168 { 169 int i,t; 170 inv[1]=1; 171 for(i=2;i<=300000;++i) 172 inv[i]=ull(md-md/i)*inv[md%i]%md; 173 scanf("%d",&n);n1=n; 174 for(i=0;i<n;++i) 175 scanf("%d",a+i); 176 for(t=1;t<n;t<<=1); 177 n=t; 178 p_sqrt(a,b,n); 179 for(i=0;i<n1;++i) 180 printf("%d ",b[i]); 181 return 0; 182 }