公式:$f(x)=sum_{i=1}^{n} y_{i} prod_{i eq j} frac{x-x_{j}}{x_{i}-x_{j}}$.
这个式子正常算的话是 $O(n^2)$ 的,如果遇到 $x$ 是连续的情况可以优化到 $O(n log n)$.
但是有些时候我们只知道 $f(x)$ 在 $x=k$ 时的点值是不够的,有时必须求出这个多项式每一位系数.
多项式快速插值可以做到 $O(n log^2 n)$,但是快速插值非常非常难写,用处并不多.
相比之下,有一种简易的写法可以在 $O(n^2)$ 的时间复杂度内通过 $n$ 个不同的点来还原一个 $n-1$ 次多项式.
插值公式中 $prod_{j eq i} (x-x_{j})$ 是比较难求的,其他地方由于都是基于整数的运算,所以比较简单.
先令 $f_{i,j}$ 表示考虑前 $i$ 个点 $(x,y)$,$x^j$ 前的系数.
那么有转移:$f_{i,j}=f_{i-1,j-1}+f_{i-1,j} imes (-x_{i})$ 即分别表示当前位的贡献为 $x^1 / -x_{i}$.
求出这个后,我们枚举 $i$,然后想 $O(n)$ 计算 $h(x)=prod_{i eq j} (x-x_{j})$.
令 $k1[i],k2[i]$ 分别表示 $h(x)$ 的 $x^i$ 前的系数,强制让第 $i$ 位贡献 $x^1$ 时 $x^i$ 前的系数.
由于有 $k2$ 这个强制贡献的状态,转移就比较简单:
$k1[i] leftarrow k2[i+1]$
$f_{n,i}=k1[i] imes (-x_{i}) +k2[i] Rightarrow k2[i]=f_{n,i}+k1[i] imes (x_{i})$.
算出 $k2$ 后把 $y_{i}$ 及插值公式中分母的贡献乘上然后累加到答案数组中即可.
应用:
求 $sum_{i=1}^{n} i^k$.
这是一个关于 $n$ 的 $k+1$ 次多项式.
所以可以取 $k+2$ 个点带进去,然后用拉格朗日插值法来求值.
具体,$f(k)=sum_{i=1}^{n} y_{i} prod_{}^{i eq j}frac{k-x_{j}}{x_{i}-x_{j}}$
由于点可以做到取 $x$ 连续的,所以提前预处理前缀/后缀积极可以做到 $O(n log n)$.
code:
#include <cstdio> #include <vector> #include <cstring> #include <algorithm> #define N 1000009 #define ll long long #define mod 1000000007 #define setIO(s) freopen(s".in","r",stdin) using namespace std; int f[N]; int ifac[N],fac[N],pre[N],suf[N],inv[N],n,K; int qpow(int x,int y) { int tmp=1; for(;y;y>>=1,x=(ll)x*x%mod) if(y&1) tmp=(ll)tmp*x%mod; return tmp; } int INV(int x) { return qpow(x,mod-2); } void init() { ifac[0]=fac[0]=inv[1]=1; for(int i=2;i<N;++i) inv[i]=(ll)(mod-mod/i)*inv[mod%i]%mod; inv[0]=1; for(int i=1;i<N;++i) { fac[i]=(ll)fac[i-1]*i%mod; ifac[i]=(ll)ifac[i-1]*inv[i]%mod; } pre[0]=1,suf[n+1]=1; for(int i=1;i<=n;++i) pre[i]=(ll)pre[i-1]*(K-i+mod)%mod; for(int i=n;i>=1;--i) suf[i]=(ll)suf[i+1]*(K-i+mod)%mod; } int sol() { int ans=0; for(int i=1;i<=n;++i) { int a1=(ll)ifac[i-1]*ifac[n-i]%mod; if((n-i)&1) a1=(ll)a1*(mod-1)%mod; int a2=(ll)pre[i-1]*suf[i+1]%mod; (ans+=(ll)f[i]*a1%mod*a2%mod)%=mod; } return ans; } int main() { // setIO("input"); scanf("%d%d",&K,&n),n+=2; init(); for(int i=1;i<=n;++i) { f[i]=(ll)(f[i-1]+qpow(i,n-2))%mod; } printf("%d ",sol()); return 0; }
还原多项式系数
#include <cstdio> #include <cstring> #include <algorithm> #define N 2008 #define ll long long #define mod 998244353 #define setIO(s) freopen(s".in","r",stdin) using namespace std; int f[N][N],k1[N],k2[N],s[N],n; struct point { int x,y; point(int x=0,int y=0):x(x),y(y){} }a[N]; int ADD(int x,int y) { return (ll)(x+y)%mod; } int DEC(int x,int y) { return (ll)(x-y+mod)%mod; } int MUL(int x,int y) { return (ll)x*y%mod; } int qpow(int x,int y) { int tmp=1; for(;y;y>>=1,x=MUL(x,x)) if(y&1) { tmp=MUL(tmp,x); } return tmp; } int get_inv(int x) { return qpow(x,mod-2); } void init() { f[0][0]=1; for(int i=1;i<=n;++i) { for(int j=1;j<=i;++j) { f[i][j]=ADD(f[i-1][j-1],MUL(mod-a[i].x,f[i-1][j])); } f[i][0]=MUL(f[i-1][0],mod-a[i].x); } } int main() { // setIO("input"); int X; scanf("%d%d",&n,&X); for(int i=1;i<=n;++i) { scanf("%d%d",&a[i].x,&a[i].y); } init(); for(int i=1;i<=n;++i) { for(int j=0;j<=n;++j) k2[j]=f[n][j]; for(int j=n-1;j>=0;--j) { k1[j]=k2[j+1]; k2[j]=ADD(k2[j],MUL(k1[j],a[i].x)); } int inv=1; for(int j=1;j<=n;++j) if(i!=j) { inv=(ll)inv*(a[i].x-a[j].x+mod)%mod; } inv=get_inv(inv); for(int j=0;j<=n-1;++j) { (s[j]+=(ll)inv*a[i].y%mod*k1[j]%mod)%=mod; } } int ans=0; for(int i=n-1;i>=0;--i) { ans=(ll)((ll)ans*X%mod+s[i])%mod; } printf("%d ",ans); return 0; }
例题
CF917D Stranger Trees
给你一颗树,求 $n$ 个点有多少个生成树满足该生成树与给定树有 $k$ 条边是重合的.
题解:
先对完全图构建矩阵,然后将原树上的边 $(x,y)$ 在矩阵中的边权标记成 $x^1$,其余边权为 $1$.
矩阵树定理求的是所有生成树边权乘积之和,那么要是可以对含 $x$ 的矩阵求行列式的话可以直接得出答案.
但是复杂度太高,而且难写(写不了)
所以用 $n$ 个不同的整数来替换那个 $x^1$,然后跑出来 $n$ 个结果,用拉格朗日插值还原出多项式的系数即可.
#include <cstdio> #include <vector> #include <cstring> #include <algorithm> #define N 103 #define ll long long #define mod 1000000007 #define setIO(s) freopen(s".in","r",stdin) using namespace std; int n; int A[N],B[N]; int f[N][N],k1[N],k2[N],ans[N]; int deg[N][N],con[N][N],a[N][N]; struct point { int x,y; point(int x=0,int y=0):x(x),y(y){} }p[N]; int qpow(int x,int y) { int tmp=1; for(;y;y>>=1,x=(ll)x*x%mod) if(y&1) { tmp=(ll)tmp*x%mod; } return tmp; } int get_inv(int x) { return qpow(x,mod-2); } int ADD(int x,int y) { return (ll)(x+y)%mod; } int DEC(int x,int y) { return (ll)(x-y+mod)%mod; } int MUL(int x,int y) { return (ll)x*y%mod; } int gauss() { int ans=1; for(int i=1;i<n;++i) { for(int j=i+1;j<n;++j) { while(a[j][i]) { int t=a[i][i]/a[j][i]; for(int k=i;k<n;++k) { a[i][k]=DEC(a[i][k],MUL(t,a[j][k])); } swap(a[j],a[i]); ans=(ll)ans*(mod-1)%mod; } } if(!a[i][i]) { return 0; } } for(int i=1;i<n;++i) { ans=(ll)ans*a[i][i]%mod; } return ans; } int cal(int val) { for(int i=1;i<=n;++i) { for(int j=1;j<=n;++j) { a[i][j]=mod-1; } } for(int i=1;i<=n;++i) { a[i][i]=n-1; } for(int i=1;i<n;++i) { int x=A[i],y=B[i]; a[x][x]=(ll)(DEC(a[x][x],1)+val)%mod; a[y][y]=(ll)(DEC(a[y][y],1)+val)%mod; a[x][y]=(ll)(a[x][y]+1-val+mod)%mod; a[y][x]=(ll)(a[y][x]+1-val+mod)%mod; } return gauss(); } void init() { f[0][0]=1; for(int i=1;i<=n;++i) { for(int j=1;j<=i;++j) f[i][j]=ADD(f[i-1][j-1],MUL(f[i-1][j],mod-p[i].x)); f[i][0]=(ll)f[i-1][0]*(mod-p[i].x)%mod; } } int main() { // setIO("input"); scanf("%d",&n); int x,y,z; for(int i=1;i<n;++i) { scanf("%d%d",&A[i],&B[i]); } for(int i=1;i<=n;++i) { p[i].x=i; p[i].y=cal(i); } init(); for(int i=1;i<=n;++i) { int inv=1; for(int j=1;j<=n;++j) { if(i!=j) inv=(ll)inv*(p[i].x-p[j].x+mod)%mod; } inv=get_inv(inv); for(int j=0;j<=n;++j) { k2[j]=f[n][j]; } for(int j=n-1;j>=0;--j) { k1[j]=k2[j+1]; k2[j]=ADD(k2[j],MUL(p[i].x,k1[j])); } for(int j=0;j<=n-1;++j) { ans[j]=ADD(ans[j],(ll)k1[j]*inv%mod*p[i].y%mod); } } for(int i=0;i<n;++i) { printf("%d ",ans[i]); } return 0; }
LuoguP4463 [集训队互测2012] calc
朴素的 DP 非常好列:$f[i][j]$ 表示选了 $i$ 个数,且值域为 $[1,j]$ 的总价值和.
那么有 $f[i][j]=f[i-1][j-1] imes j+f[i][j-1]$,直接算的话复杂度是 $O(nD)$ 的.
但是我们可以猜测这是一个关于 $j$ 的 $g_{i}$ 次多项式.
有一个结论:对于 $n$ 次多项式 $h(x)$,满足 $h(x)-h(x-1)$ 是 $n-1$ 次多项式.
那么有 $f[i][j]-f[i][j-1]=f[i-1][j-1] imes j$.
将 $g$ 带入,有 $g_{i}-1=g_{i-1}+1$.
即 $g_{i}=g_{i-1}+2$,说明这是一个关于 $j$ 的 $2 imes i$ 次多项式.
那么我们就求出 $f[n][1...2n+1]$ 后将值带入,然后拉格朗日插值来插一下就行了.
code:
#include <cstdio> #include <cstring> #include <algorithm> #define N 2002 #define ll long long #define setIO(s) freopen(s".in","r",stdin) using namespace std; int D,n,mod,tot,f[N][N],fac[N]; void init() { fac[0]=1; for(int i=1;i<N;++i) { fac[i]=(ll)fac[i-1]*i%mod; } } struct point { int x,y; point(int x=0,int y=0):x(x),y(y){} }a[N]; int qpow(int x,int y) { int tmp=1; for(;y;y>>=1,x=(ll)x*x%mod) if(y&1) tmp=(ll)tmp*x%mod; return tmp; } int get_inv(int x) { return qpow(x,mod-2); } int calc() { int ans=0; for(int i=1;i<=tot;++i) { int inv=1,up=1; for(int j=1;j<=tot;++j) { if(i==j) continue; up=(ll)up*(D-a[j].x+mod)%mod; inv=(ll)inv*(a[i].x-a[j].x+mod)%mod; } inv=get_inv(inv); (ans+=(ll)a[i].y*up%mod*inv%mod)%=mod; } return ans; } int main() { // setIO("input"); scanf("%d%d%d",&D,&n,&mod); init(); for(int i=0;i<=2*n+1;++i) f[0][i]=1; for(int i=1;i<=n;++i) { for(int j=1;j<=2*n+1;++j) { f[i][j]=(ll)(f[i][j-1]+(ll)f[i-1][j-1]*j%mod)%mod; } } for(int i=1;i<=2*n+1;++i) { a[++tot]=point(i,f[n][i]); } printf("%d ",(ll)calc()*fac[n]%mod); return 0; }