• 【XSY2680】玩具谜题 NTT 牛顿迭代


    题目描述

      小南一共有(n)种不同的玩具小人,每种玩具小人的数量都可以被认为是无限大。每种玩具小人都有特定的血量,第(i)种玩具小人的血量就是整数(i)。此外,每种玩具小人还有自己的攻击力,攻击力可以是任意非负整数,且两种不同的玩具小人的攻击力可以相同。我们把第(i)种玩具小人的血量和攻击力表示成(a_i)(b_i)

      为了让玩具小人们进行战斗,小南打算把一些小人选出来,编成队伍。一个队伍可以表示成一个由玩具小人组成的序列:((p_1,p_2,ldots,p_l)),其中(p_i)表示队伍中第(i)个玩具小人的种类,(l)为队伍的长度。对于不同的(i)(p_i)可以相同。两个队伍被认为相同,当且仅当长度相同,且每个位置的玩具小人种类都分别相同。

      一个队伍也有血量和攻击力两个属性,记为(a_t,b_t)。队伍的血量就是每个玩具小人的血量之和,而队伍攻击力可能会由于队伍内部产生矛盾而减小,对于长度为(l)的队伍,队伍的攻击力为每个玩具小人的攻击力之乘积除以(l)的阶乘。同时,当(l)大于等于某个常数(c)时,攻击力会有一个额外的加成:乘以((1+frac{l!}{(l−c)!}))。也就是说:

    [a_t=sum_{i=1}^la_{p_i}\ b_t=egin{cases} frac{1}{l!}sum_{i=1}^lb_{p_i}&,l<c\ (frac{1}{l!}+frac{1}{(l-c)!})sum_{i=1}^lb_{p_i}&,lgeq c end{cases} ]

      然而,小南的玩具小人们对小南的独裁统治感到愤怒,准备联合起来发起民主运动。为了旗帜鲜明地反对动乱,小南必须了解清楚玩具小人们的战斗力。不幸的是,由于玩具小人数量过多,小南已经忘记每种玩具小人的战斗力具体是多少了。现在,小南掌握的情报只有对于每个(1)(n)之间的整数(i),所有血量等于(i)的不同队伍的战斗力之和对(998244353)取模的值是多少((s_i))。他希望你根据已有的情报,还原出每种玩具小人的战斗力对(998244353)取模的结果 。如果镇压成功了,小南会请你到北京去做一回总书记(当然是北京玩具协会的总书记)。

      (nleq 60000,0leq cleq n)

    题解

      设(F=sum_{igeq 1}b_i,S=sum_{igeq 0}s_i),如果(c=0),那么(s_0=2)

    [egin{align} sum_{igeq 0}frac{F^i}{i!}+sum_{igeq 0}frac{F^i}{i!}&=S\ 2e^F&=S\ F=lnfrac{S}{2} end{align} ]

      否则(s_0=1)

    [egin{align} sum_{igeq 1}frac{F^i}{i!}+sum_{igeq c}frac{F^i}{(i-c)!}&=S-1\ sum_{igeq 1}frac{F^i}{i!}+F^csum_{igeq0}frac{F^i}{i!}&=S-1\ (F^c+1)e^F&=S end{align} ]

      然后就是牛顿迭代解方程。我们需要满足

    [g(F)=(F^c+1)e^F-S=0 ]

      的(F)。设当前求出了

    [g(F_0)equiv0pmod {x^{frac{n}{2}}} ]

      的(F_0),现在我们要求(F)满足

    [g(F)equiv 0pmod {x^n} ]

      考虑在(F_0)出对(g)泰勒展开

    [g(F)=g(F_0)+g'(F_0)(F-F_0)+frac{g''(F_0)}{2}{(F-F_0)}^2+cdots ]

      后面的项都是(0),因为(F-F_0)的最小非零项的次数至少是(frac{n}{2}),所以后面的部分在模(x^n)意义下一定会被消掉。

      式子就变成了

    [egin{align} g(F)&equiv g(F_0)+g'(F_0)(F-F_0)pmod {x^n}\ F&equiv F_0-frac{g(F_0)}{g'(F_0)}pmod {x^n}\ F&equiv F_0-frac{({F_0}^c+1)e^{F_0}-S}{(c{F_0}^{c-1}+{F_0}^c+1)e^{F_0}}pmod {x^n} end{align} ]

      套各种多项式算法可以做到

    [T(n)=T(frac{n}{2})+O(nlog n)=O(nlog n) ]

      常数巨大。

    代码

    #include<cstdio>
    #include<cstring>
    #include<algorithm>
    #include<cstdlib>
    #include<ctime>
    #include<utility>
    #include<cmath>
    #include<functional>
    using namespace std;
    typedef long long ll;
    typedef unsigned long long ull;
    typedef pair<int,int> pii;
    typedef pair<ll,ll> pll;
    void sort(int &a,int &b)
    {
    	if(a>b)
    		swap(a,b);
    }
    void open(const char *s)
    {
    #ifndef ONLINE_JUDGE
    	char str[100];
    	sprintf(str,"%s.in",s);
    	freopen(str,"r",stdin);
    	sprintf(str,"%s.out",s);
    	freopen(str,"w",stdout);
    #endif
    }
    int rd()
    {
    	int s=0,c;
    	while((c=getchar())<'0'||c>'9');
    	do
    	{
    		s=s*10+c-'0';
    	}
    	while((c=getchar())>='0'&&c<='9');
    	return s;
    }
    void put(int x)
    {
    	if(!x)
    	{
    		putchar('0');
    		return;
    	}
    	static int c[20];
    	int t=0;
    	while(x)
    	{
    		c[++t]=x%10;
    		x/=10;
    	}
    	while(t)
    		putchar(c[t--]+'0');
    }
    int upmin(int &a,int b)
    {
    	if(b<a)
    	{
    		a=b;
    		return 1;
    	}
    	return 0;
    }
    int upmax(int &a,int b)
    {
    	if(b>a)
    	{
    		a=b;
    		return 1;
    	}
    	return 0;
    }
    const ll p=998244353;
    const ll g=3;
    const int maxn=65536;
    ll fp(ll a,ll b)
    {
    	ll s=1;
    	for(;b;b>>=1,a=a*a%p)
    		if(b&1)
    			s=s*a%p;
    	return s;
    }
    ll inv[200000];
    namespace ntt
    {
    	int rev[200000];
    	int m;
    	void ntt(ll *a,int n,int t)
    	{
    		ll u,v,w,wn;
    		int i,j,k;
    		if(n!=m)
    		{
    			m=n;
    			rev[0]=0;
    			for(i=1;i<n;i++)
    				rev[i]=(rev[i>>1]>>1)|(i&1?n>>1:0);
    		}
    		for(i=0;i<n;i++)
    			if(rev[i]<i)
    				swap(a[i],a[rev[i]]);
    		for(i=2;i<=n;i<<=1)
    		{
    			wn=fp(g,(p-1)/i);
    			if(t==-1)
    				wn=fp(wn,p-2);
    			for(j=0;j<n;j+=i)
    			{
    				w=1;
    				for(k=j;k<j+i/2;k++)
    				{
    					u=a[k];
    					v=a[k+i/2]*w%p;
    					a[k]=(u+v)%p;
    					a[k+i/2]=(u-v)%p;
    					w=w*wn%p;
    				}
    			}
    		}
    		if(t==-1)
    		{
    			ll inv=fp(n,p-2);
    			for(i=0;i<n;i++)
    				a[i]=a[i]*inv%p;
    		}
    	}
    	void getinv(ll *a,ll *b,int n)
    	{
    		if(n==1)
    		{
    			b[0]=fp(a[0],p-2);
    			return;
    		}
    		getinv(a,b,n>>1);
    		static ll a1[200000],a2[200000];
    		int i;
    		for(i=0;i<n;i++)
    			a1[i]=a[i];
    		for(;i<n<<1;i++)
    			a1[i]=0;
    		for(i=0;i<n>>1;i++)
    			a2[i]=b[i];
    		for(;i<n<<1;i++)
    			a2[i]=0;
    		ntt(a1,n<<1,1);
    		ntt(a2,n<<1,1);
    		for(i=0;i<n<<1;i++)
    			a1[i]=(2*a2[i]-a1[i]*a2[i]%p*a2[i])%p;
    		ntt(a1,n<<1,-1);
    		for(i=0;i<n;i++)
    			b[i]=a1[i];
    	}
    	void getln(ll *a,ll *b,int n)
    	{
    		static ll a1[200000],a2[200000];
    		int i;
    		for(i=1;i<n;i++)
    			a1[i-1]=a[i]*i%p;
    		a1[n-1]=0;
    		getinv(a,a2,n);
    		for(i=n;i<n<<1;i++)
    			a1[i]=a2[i]=0;
    		ntt(a1,n<<1,1);
    		ntt(a2,n<<1,1);
    		for(i=0;i<n<<1;i++)
    			a1[i]=a1[i]*a2[i]%p;
    		ntt(a1,n<<1,-1);
    		b[0]=0;
    		for(i=1;i<n;i++)
    			b[i]=a1[i-1]*inv[i]%p;
    	}
    	void getexp(ll *a,ll *b,int n)
    	{
    		if(n==1)
    		{
    			b[0]=1;
    			return;
    		}
    		getexp(a,b,n>>1);
    		static ll a1[200000],a2[200000];
    		int i;
    		for(i=0;i<n>>1;i++)
    			a1[i]=b[i];
    		for(;i<n<<1;i++)
    			a1[i]=0;
    		for(i=n>>1;i<n;i++)
    			b[i]=0;
    		getln(b,a2,n);
    		for(i=0;i<n;i++)
    			a2[i]=-a2[i];
    		for(i=n;i<n<<1;i++)
    			a2[i]=0;
    		a2[0]++;
    		for(i=0;i<n;i++)
    			a2[i]=(a2[i]+a[i])%p;
    		ntt(a1,n<<1,1);
    		ntt(a2,n<<1,1);
    		for(i=0;i<n<<1;i++)
    			a1[i]=a1[i]*a2[i]%p;
    		ntt(a1,n<<1,-1);
    		for(i=0;i<n;i++)
    			b[i]=a1[i];
    	}
    	void getpow(ll *a,ll *b,int n,ll k)
    	{
    		int d=0;
    		while(d<n&&!a[d])
    			d++;
    		int i;
    		if(d>=n)
    		{
    			for(i=0;i<n;i++)
    				b[i]=0;
    			if(!k)
    				b[0]=1;
    			return;
    		}
    		static ll a1[200000],a2[200000];
    		ll c=a[d];
    		ll e=fp(c,p-2);
    		for(i=0;i<n;i++)
    			if(i+d<n)
    				a1[i]=a[i+d]*e%p;
    			else
    				a1[i]=0;
    		getln(a1,a2,n);
    		for(i=0;i<n;i++)
    			a2[i]=a2[i]*k%p;
    		getexp(a2,a1,n);
    		for(i=0;i<n&&i<d*k;i++)
    			b[i]=0;
    		c=fp(c,k);
    		for(i=d*k;i<n;i++)
    			b[i]=a1[i-d*k]*c%p;
    	}
    	void mul(ll *a,ll *b,ll *c,int n)
    	{
    		int i;
    		static ll a1[200000],a2[200000];
    		for(i=0;i<n;i++)
    		{
    			a1[i]=a[i];
    			a2[i]=b[i];
    		}
    		for(;i<n<<1;i++)
    			a1[i]=a2[i]=0;
    		ntt(a1,n<<1,1);
    		ntt(a2,n<<1,1);
    		for(i=0;i<n<<1;i++)
    			a1[i]=a1[i]*a2[i]%p;
    		ntt(a1,n<<1,-1);
    		for(i=0;i<n;i++)
    			c[i]=a1[i];
    	}
    }
    using namespace ntt;
    ll a[200000],b[200000];
    void init()
    {
    	int i;
    	inv[0]=inv[1]=1;
    	for(i=2;i<=maxn;i++)
    		inv[i]=-p/i*inv[p%i]%p;
    }
    int c;
    void gao(ll *a,ll *b,int n)
    {
    	if(n==1)
    	{
    		b[0]=0;
    		return;
    	}
    	gao(a,b,n>>1);
    	int i;
    	for(i=n>>1;i<n;i++)
    		b[i]=0;
    	static ll a1[200000],a2[200000],a3[200000],a4[200000],a5[200000],a6[200000],a7[200000];
    	//a1=F^(c-1)
    	getpow(b,a1,n,c-1);
    	//a2=F^c=a1F
    	mul(a1,b,a2,n);
    	//a3=e^F
    	getexp(b,a3,n);
    	for(i=0;i<n;i++)
    		a4[i]=a2[i];
    	a4[0]++;
    	mul(a4,a3,a5,n);
    	for(i=0;i<n;i++)
    		a5[i]=(a5[i]-a[i])%p;
    	for(i=0;i<n;i++)
    		a6[i]=(a2[i]+c*a1[i])%p;
    	a6[0]++;
    	mul(a6,a3,a7,n);
    	getinv(a7,a6,n);
    	mul(a6,a5,a7,n);
    	for(i=0;i<n;i++)
    		b[i]=(b[i]-a7[i])%p;
    }
    void gao2(ll *a,ll *b,int n)
    {
    	int i;
    	for(i=0;i<n;i++)
    		a[i]=a[i]*inv[2]%p;
    	getln(a,b,n);
    }
    int n;
    int main()
    {
    	init();
    	open("c");
    	scanf("%d%d",&n,&c);
    	int m=1;
    	while(m<=n)
    		m<<=1;
    	int i;
    	for(i=1;i<=n;i++)
    		scanf("%lld",&a[i]);
    	for(i=n+1;i<m;i++)
    		a[i]=0;
    	if(!c)
    	{
    		a[0]=2;
    		gao2(a,b,m);
    	}
    	else
    	{
    		a[0]=1;
    		gao(a,b,m);
    	}
    	for(i=1;i<=n;i++)
    	{
    		b[i]=(b[i]+p)%p;
    		printf("%lld
    ",b[i]);
    	}
    	return 0;
    }
    
  • 相关阅读:
    Java之美[从菜鸟到高手演变]之设计模式
    常见JAVA框架
    每周一荐:学习ACE一定要看的书
    YUV格式&像素
    关于makefile
    socket通信
    [理论篇]一.JavaScript中的死连接`javascript:void(0)`和空连接`javascript:;`
    [应用篇]第三篇 JSP 标准标签库(JSTL)总结
    [应用篇]第一篇 EL表达式入门
    KVM基本实现原理
  • 原文地址:https://www.cnblogs.com/ywwyww/p/8513585.html
Copyright © 2020-2023  润新知