• 洛谷P5050 【模板】多项式多点求值


    https://www.luogu.org/problemnew/show/P5050

    给定多项式A(x),求$A(x_l)$,$A(x_{l+1})$,..,$A(x_r)$

    分治:(如果r-l+1=1,直接O(deg(A))暴力求出即可)

    首先设$mid=lfloorfrac{l+r}{2} floor$,$P^{[0]}(x)=prod_{i=l}^{mid}(x-x_i)$,$P^{[1]}(x)=prod_{i=mid+1}^{r}(x-x_i)$

    以[l,mid]的求值为例:设$A^{[0]}(x)=A(x)\,mod\,P^{[0]}(x)$

    即$A(x)=P^{[0]}(x)B^{[0]}(x)+A^{[0]}(x)$($B^{[0]}$为某个多项式)

    可以发现,将$x_l$,$x_{l+1}$,..,$x_{mid}$带入$P^{[0]}(x)$,值都为0

    因此对于$l<=i<=mid$,$A(x_i)=A^{[0]}(x_i)$,递归下去算就行;[mid+1,r]的求值同理

    这个P可以在分治过程中处理出来

    时间复杂度大概是$O(n\,log^2\,n)$(未区分n=r-l+1,m=deg(A))

    版本1:基于版本1,加了小范围暴力,预处理了P方便快速插值

    注意:这种分治FFT的题,NTT里面wn需要预处理否则可能慢很多(听说多一个log,没有仔细分析)!

      1 #prag
      2 ma GCC optimize(2)
      3 #include<cstdio>
      4 #include<algorithm>
      5 #include<cstring>
      6 #include<vector>
      7 #include<cmath>
      8 using namespace std;
      9 #define fi first
     10 #define se second
     11 #define mp make_pair
     12 #define pb push_back
     13 typedef long long ll;
     14 typedef unsigned long long ull;
     15 const int md=998244353;
     16 const int N=131072;
     17 #define delto(a,b) ((a)-=(b),((a)<0)&&((a)+=md))
     18 inline int del(int a,int b)
     19 {
     20     a-=b;
     21     return a<0?a+md:a;
     22 }
     23 int rev[N];
     24 void init(int len)
     25 {
     26     int bit=0,i;
     27     while((1<<(bit+1))<=len)    ++bit;
     28     for(i=1;i<len;++i)
     29         rev[i]=(rev[i>>1]>>1)|((i&1)<<(bit-1));
     30 }
     31 ull poww(ull a,ull b)
     32 {
     33     ull ans=1;
     34     for(;b;b>>=1,a=a*a%md)
     35         if(b&1)
     36             ans=ans*a%md;
     37     return ans;
     38 }
     39 int inv[300011],pw_1[300011],pw_2[300011];
     40 void dft(int *a,int len,int idx)//要求len为2的幂
     41 {
     42     int i,j,k,t1,t2;ull wn,wnk;
     43     for(i=0;i<len;++i)
     44         if(i<rev[i])
     45             swap(a[i],a[rev[i]]);
     46     for(i=1;i<len;i<<=1)
     47     {
     48         wn=idx==1?pw_1[i]:pw_2[i];
     49         //wn=poww(idx==1?3:332748118,(md-1)/(i<<1));
     50         for(j=0;j<len;j+=(i<<1))
     51         {
     52             wnk=1;
     53             for(k=j;k<j+i;++k,wnk=wnk*wn%md)
     54             {
     55                 t1=a[k];t2=a[k+i]*wnk%md;
     56                 a[k]+=t2;
     57                 (a[k]>=md)&&(a[k]-=md);
     58                 a[k+i]=t1-t2;
     59                 (a[k+i]<0)&&(a[k+i]+=md);
     60             }
     61         }
     62     }
     63     if(idx==-1)
     64     {
     65         ull ilen=inv[len];
     66         for(i=0;i<len;++i)
     67             a[i]=a[i]*ilen%md;
     68     }
     69 }
     70 void p_inv(int *f,int *g,int len)//g=f^(-1);f,g数组的长度不小于2len(需要足够长用于临时存放元素);要求len是2的幂
     71 {
     72     static int t1[N],t2[N];
     73     g[0]=poww(f[0],md-2);
     74     for(int i=2,j;i<=len;i<<=1)
     75     {
     76         memcpy(t1,f,sizeof(int)*i);
     77         memcpy(t2,g,sizeof(int)*(i>>1));
     78         memset(t2+(i>>1),0,sizeof(int)*(i>>1));
     79         init(i);
     80         dft(t1,i,1);dft(t2,i,1);
     81         for(j=0;j<i;++j)
     82             t1[j]=ull(t1[j])*t2[j]%md;
     83         dft(t1,i,-1);
     84         for(j=0;j<(i>>1);++j)
     85             t1[j]=t1[j+(i>>1)];
     86         memset(t1+(i>>1),0,sizeof(int)*(i>>1));
     87         dft(t1,i,1);
     88         for(j=0;j<i;++j)
     89             t1[j]=ull(t1[j])*t2[j]%md;
     90         dft(t1,i,-1);
     91         for(j=i>>1;j<i;++j)
     92             g[j]=md-t1[j-(i>>1)];
     93     }
     94 }
     95 inline void p_de(int *f,int len)//derivative求导;f=f'
     96 {
     97     for(int i=0;i<len-1;++i)
     98         f[i]=ull(i+1)*f[i+1]%md;
     99     f[len-1]=0;
    100 }
    101 inline void p_in(int *f,int len)//integral积分;f=?f
    102 {
    103     for(int i=len-1;i>=1;--i)
    104         f[i]=ull(f[i-1])*inv[i]%md;
    105     f[0]=0;
    106 }
    107 void p_ln(int *f,int len)//要求len为2的幂,f[0]=1
    108 {
    109     static int t3[N];
    110     p_inv(f,t3,len);p_de(f,len);
    111     init(len<<1);
    112     dft(f,len<<1,1);dft(t3,len<<1,1);
    113     for(int i=0;i<(len<<1);++i)
    114         f[i]=ull(f[i])*t3[i]%md;
    115     dft(f,len<<1,-1);p_in(f,len);
    116 }
    117 void p_exp(int *f,int *g,int len)//要求len为2的幂,f[0]=0
    118 {
    119     static int t1[N],t2[N];
    120     g[0]=1;
    121     for(int i=2,j;i<=len;i<<=1)
    122     {
    123         memcpy(t1,g,sizeof(int)*(i>>1));
    124         memset(t1+(i>>1),0,sizeof(int)*(i>>1));
    125         p_ln(t1,i);
    126         for(j=0;j<(i>>1);++j)
    127             t1[j]=del(f[j+(i>>1)],t1[j+(i>>1)]);
    128         memset(t1+(i>>1),0,sizeof(int)*(i>>1));
    129         init(i);
    130         dft(t1,i,1);
    131         memcpy(t2,g,sizeof(int)*(i>>1));
    132         memset(t2+(i>>1),0,sizeof(int)*(i>>1));
    133         dft(t2,i,1);
    134         for(j=0;j<i;++j)
    135             t1[j]=ull(t1[j])*t2[j]%md;
    136         dft(t1,i,-1);
    137         for(j=i>>1;j<i;++j)
    138             g[j]=t1[j-(i>>1)];
    139     }
    140 }
    141 void p_div(int *a,int *b,int *c,int n,int m)//c=a/b;deg(a)=n,deg(b)=m,deg(c)=n-m;a,b无前导0;n>=m
    142 {
    143     reverse(a,a+n+1);reverse(b,b+m+1);
    144     int x=n-m+1,t=1;
    145     for(;t<x;t<<=1);
    146     memset(b+m+1,0,sizeof(int)*max(t-m-1,0));
    147     p_inv(b,c,t);
    148     memset(c+x,0,sizeof(int)*((t<<1)-x));
    149     memset(a+x,0,sizeof(int)*((t<<1)-x));
    150     init(t<<1);
    151     dft(a,t<<1,1);dft(c,t<<1,1);
    152     for(int i=0;i<(t<<1);++i)
    153         c[i]=ull(c[i])*a[i]%md;
    154     dft(c,t<<1,-1);
    155     memset(c+(n-m+1),0,sizeof(int)*((t<<1)-n+m-1));
    156     reverse(c,c+x);
    157 }
    158 void p_divmod(int *a,int *b,int *c,int *d,int n,int m)//c=a/b,d=a%b,deg(d)=(<=)m-1;其余同上
    159 {
    160     static int t1[N];
    161     memcpy(d,a,sizeof(int)*(m+1));
    162     int x=n+1,t=1;
    163     for(;t<x;t<<=1);
    164     memcpy(t1,b,sizeof(int)*(m+1));
    165     memset(t1+m+1,0,sizeof(int)*max(t-m-1,0));
    166     p_div(a,b,c,n,m);
    167     memcpy(a,c,sizeof(int)*(n-m+1));
    168     memset(a+n-m+1,0,sizeof(int)*(t-n+m-1));
    169     init(t);
    170     dft(a,t,1);dft(t1,t,1);
    171     for(int i=0;i<t;++i)
    172         t1[i]=ull(t1[i])*a[i]%md;
    173     dft(t1,t,-1);
    174     for(int i=0;i<=m;++i)
    175         delto(d[i],t1[i]);
    176 }
    177 namespace P_me
    178 {
    179     int *ta[N];//用线段树的方法给递归的每一层一个编号,ta[i]表示编号为i的层的P函数的各项系数
    180     int data[N*40],*tp;//内存池
    181     int *a,*x,*y;
    182 #define LC (u<<1)
    183 #define RC (u<<1|1)
    184     int mt1[N];
    185     const int T=200;//小范围暴力阀值
    186     void _p_me1(int l,int r,int u)//计算(x-x_l)(x-x_{l+1})..(x-x_r)并存下来
    187     {
    188         if(r-l<=T)
    189         {
    190             int i,j;
    191             tp[0]=1;
    192             for(i=l;i<=r;++i)
    193             {
    194                 tp[i-l+1]=tp[i-l];
    195                 for(j=i-l;j>=1;--j)
    196                 {
    197                     tp[j]=(ull(tp[j])*(md-x[i])+tp[j-1])%md;
    198                 }
    199                 tp[0]=ull(tp[0])*(md-x[i])%md;
    200             }
    201             ta[u]=tp;tp+=r-l+2;
    202             return;
    203         }
    204         int mid=(l+r)>>1;
    205         _p_me1(l,mid,LC);_p_me1(mid+1,r,RC);
    206         int x=r-l+2,t=1;//x=(mid-l+1)+(r-mid)+1
    207         for(;t<x;t<<=1);
    208         memcpy(mt1,ta[LC],sizeof(int)*(mid-l+2));
    209         memset(mt1+mid-l+2,0,sizeof(int)*(t-mid+l-2));
    210         memcpy(tp,ta[RC],sizeof(int)*(r-mid+1));
    211         memset(tp+r-mid+1,0,sizeof(int)*(t-r+mid-1));
    212         init(t);
    213         dft(mt1,t,1);dft(tp,t,1);
    214         for(int i=0;i<t;++i)
    215             tp[i]=ull(tp[i])*mt1[i]%md;
    216         dft(tp,t,-1);
    217         ta[u]=tp;tp+=r-l+2;
    218     }
    219     int mt2[N],mt3[N];
    220     void _p_me2(int *a,int n,int l,int r,int u)//a是A的系数,deg(A)<=n;求A(x_l)到A(x_r),放入y_l到y_r
    221     {
    222         if(r-l<=T)
    223         {
    224             int t,i,j;
    225             for(i=l;i<=r;++i)
    226             {
    227                 t=a[n];
    228                 for(j=n-1;j>=0;--j)
    229                     t=(ull(t)*x[i]+a[j])%md;
    230                 y[i]=t;
    231             }
    232             return;
    233         }
    234         int x=(n+1)<<1,t=1;
    235         for(;t<x;t<<=1);
    236         int mt4[t];//根据需要改成new?
    237         int mid=(l+r)>>1,n1;
    238         memcpy(mt1,a,sizeof(int)*(n+1));
    239         for(n1=n;n1>=0 && mt1[n1]==0;)    --n1;
    240         if(n1<0)
    241         {
    242             memset(y+l,0,sizeof(int)*(r-l+1));
    243             return;
    244         }
    245         memcpy(mt2,ta[LC],sizeof(int)*(mid-l+2));
    246         if(n1<mid-l+1)
    247         {
    248             memcpy(mt4,mt1,sizeof(int)*(n1+1));
    249             _p_me2(mt4,n1,l,mid,LC);
    250         }
    251         else
    252         {
    253             p_divmod(mt1,mt2,mt3,mt4,n1,mid-l+1);
    254             _p_me2(mt4,mid-l,l,mid,LC);    
    255         }
    256         memcpy(mt1,a,sizeof(int)*(n+1));
    257         for(n1=n;n1>=0 && mt1[n1]==0;)    --n1;
    258         memcpy(mt2,ta[RC],sizeof(int)*(r-mid+1));
    259         if(n1<r-mid)
    260         {
    261             memcpy(mt4,mt1,sizeof(int)*(n1+1));
    262             _p_me2(mt4,n1,mid+1,r,RC);
    263         }
    264         else
    265         {
    266             p_divmod(mt1,mt2,mt3,mt4,n1,r-mid);
    267             _p_me2(mt4,r-mid-1,mid+1,r,RC);
    268         }
    269     }
    270     void p_multieval(int *a0,int *x0,int *y0,int n,int m)//deg(a)=n,x有m个数
    271     {
    272         tp=data;
    273         a=a0;x=x0;y=y0;
    274         _p_me1(0,m-1,1);
    275         _p_me2(a,n,0,m-1,1);
    276     }
    277 }
    278 using P_me::p_multieval;
    279 int a[N],x[N],y[N];
    280 int n,m;
    281 int main()
    282 {
    283     int i;
    284     inv[1]=1;
    285     for(i=2;i<=300000;++i)
    286         inv[i]=ull(md-md/i)*inv[md%i]%md;
    287     for(i=1;i<300000;i<<=1)
    288     {
    289         pw_1[i]=poww(3,(md-1)/(i<<1));
    290         pw_2[i]=poww(332748118,(md-1)/(i<<1));
    291     }
    292     //n=100000;m=100000;
    293     scanf("%d%d",&n,&m);
    294     for(i=0;i<=n;++i)
    295         //a[i]=rand()%md;
    296         scanf("%d",a+i);
    297     for(i=0;i<m;++i)
    298         //x[i]=rand()%md;
    299         scanf("%d",x+i);
    300     p_multieval(a,x,y,n,m);
    301     for(i=0;i<m;++i)
    302         printf("%d
    ",y[i]);
    303     return 0;
    304 }
    View Code
  • 相关阅读:
    sqlplus -S选项说明
    oracle中常见set指令
    nohup详解
    centos64位编译32位程序
    【PHP系列】框架的抉择
    【PHP系列】PHP推荐标准之PSR-4,自动加载器策略
    【PHP系列】PHP推荐标准之PSR-3,日志记录器接口
    【PHP系列】PHP推荐标准之PSR-1,PSR-2
    【项目管理】管理工具的抉择 --- 持续更新中
    【CNMP系列】CentOS7.0下安装FTP服务
  • 原文地址:https://www.cnblogs.com/hehe54321/p/10616071.html
Copyright © 2020-2023  润新知