• 【XSY2744】信仰圣光 分治FFT 多项式exp 容斥原理


    题目描述

      有一个(n)个元素的置换,你要选择(k)个元素,问有多少种方案满足:对于每个轮换,你都选择了其中的一个元素。

      对(998244353)取模。

      (kleq nleq 152501)

    题解

    吐槽

      为什么一道FFT题要把(n)设为(150000)

    解法一

      先把轮换拆出来。

      直接DP。

      设(f_{i,j})为前(i)个轮换选择了(j)个元素,且每个轮换都选择了至少一个元素的方案数。

    [f_{i,j}=sum_{k=1}^{a_i}f_{i-1,j-k}inom{a_i}{k} ]

      时间复杂度为(O(n^2)),因为枚举的是第(i)组和前(i-1)组的配对,而任意两个元素之间最多被配对一次。

      可以分治FFT做到(O(nlog^2 n))

    解法二

      考虑容斥。

      设(m)为轮换个数。

      枚举有哪些轮换(S)中可能有被选中的元素,容斥系数就是({(-1)}^{m-|S|})(sum)为这些轮换的大小总和):

      或者枚举哪些轮换(S)中没有被选中的元素,容斥系数就是({(-1)}^{|S|})

    [egin{align} s&=sum_{S}{(-1)}^{m-|S|}inom{sum}{k}\ s&=sum_{S}{(-1)}^{|S|}inom{n-sum}{k}\ end{align} ]

      现在我们要对于每一个(i),计算(f_i=sum_{S,sum=i}{(-1)}^{|S|})

      构造生成函数(A_i(x)=1-x^{a_i}),那么(F(x)=prod_{i=1}^mA_i(x))

      直接做还是(O(nlog^2n))的。我们需要一些优化。

    [egin{align} F(x)&=prod_{i=1}^m1-x^{a_i}\ ln(F(x))&=sum_{i=1}^nln(1-x^{a_i})\ ln(F(x))&=sum_{i=1}^nsum_{j=a_i}-frac{x^{ja_i}}{j} end{align} ]

      那么可以在(O(nlog n))内算出(ln(F(x))),然后(exp)一下。

      时间复杂度:(O(nlog n))

      由于常数过大,所以要用下面那条式子(因为只用计算到(x^{n-k}))。

    解法一

    #include<cstdio>
    #include<cstring>
    #include<algorithm>
    #include<utility>
    #include<iostream>
    using namespace std;
    typedef long long ll;
    typedef pair<int,int> pii;
    void open(const char *s)
    {
    #ifndef ONLINE_JUDGE
    	char str[100];
    	sprintf(str,"%s.in",s);
    	freopen(str,"r",stdin);
    	sprintf(str,"%s.out",s);
    	freopen(str,"w",stdout);
    #endif
    }
    int rd()
    {
    	int s=0,c;
    	while((c=getchar())<'0'||c>'9');
    	s=c-'0';
    	while((c=getchar())>='0'&&c<='9')
    		s=s*10+c-'0';
    	return s;
    }
    const int p=998244353;
    const int g=3;
    ll fp(ll a,ll b)
    {
    	ll s=1;
    	for(;b;b>>=1,a=a*a%p)
    		if(b&1)
    			s=s*a%p;
    	return s;
    }
    ll inv[200010];
    ll fac[200010];
    ll ifac[200010];
    int a[200010];
    int n,m,k;
    int c[200010];
    int b[200010];
    ll getc(int x,int y)
    {
    	return fac[x]*ifac[y]%p*ifac[x-y]%p;
    }
    ll *f[500010];
    int len[500010];
    int cnt;
    int a1[600010];
    int a2[600010];
    int rev[600010];
    void ntt(int *a,int n,int t)
    {
    	for(int i=1;i<n;i++)
    	{
    		rev[i]=(rev[i>>1]>>1)|(i&1?n>>1:0);
    		if(i>rev[i])
    			swap(a[i],a[rev[i]]);
    	}
    	for(int i=2;i<=n;i<<=1)
    	{
    		int wn=fp(g,(p-1)/i*(t==1?1:i-1));
    		for(int j=0;j<n;j+=i)
    		{
    			int w=1;
    			for(int k=j;k<j+i/2;k++)
    			{
    				int u=a[k];
    				int v=(ll)a[k+i/2]*w%p;
    				a[k]=(u+v)%p;
    				a[k+i/2]=(u-v)%p;
    				w=(ll)w*wn%p;
    			}
    		}
    	}
    	if(t==-1)
    	{
    		int inv=fp(n,p-2);
    		for(int i=0;i<n;i++)
    			a[i]=(ll)a[i]*inv%p;
    	}
    }
    void solve(int &now,int l,int r)
    {
    	now=++cnt;
    	if(l==r)
    	{
    		len[now]=min(a[l],k);
    		f[now]=new ll[len[now]+1];
    		f[now][0]=0;
    		for(int i=1;i<=len[now];i++)
    			f[now][i]=ifac[i]*ifac[a[l]-i]%p;
    		return;
    	}
    	int ls,rs,mid=(l+r)>>1;
    	solve(ls,l,mid);
    	solve(rs,mid+1,r);
    	len[now]=min(len[ls]+len[rs],k);
    	f[now]=new ll[len[now]+1];
    	int v=1;
    	while(v<=len[ls]+len[rs])
    		v<<=1;
    	for(int i=0;i<v;i++)
    		a1[i]=(i<=len[ls]?f[ls][i]:0);
    	for(int i=0;i<v;i++)
    		a2[i]=(i<=len[rs]?f[rs][i]:0);
    	ntt(a1,v,1);
    	ntt(a2,v,1);
    	for(int i=0;i<v;i++)
    		a1[i]=(ll)a1[i]*a2[i]%p;
    	ntt(a1,v,-1);
    	for(int i=0;i<=len[now];i++)
    		f[now][i]=a1[i];
    	delete [] f[ls];
    	delete [] f[rs];
    }
    void solve()
    {
    //	scanf("%d%d",&n,&k);
    	n=rd();
    	k=rd();
    	for(int i=1;i<=n;i++)
    		c[i]=rd();
    //		scanf("%d",&c[i]);
    	if(k==n)
    	{
    		printf("1
    ");
    		return;
    	}
    	m=0;
    	cnt=0;
    	memset(b,0,sizeof b);
    	memset(a,0,sizeof a);
    	for(int i=1;i<=n;i++)
    		if(!b[i])
    		{
    			m++;
    			for(int j=i;!b[j];j=c[j])
    			{
    				b[j]=1;
    				a[m]++;
    			}
    		}
    	if(k<m)
    	{
    		printf("0
    ");
    		return;
    	}
    	int rt;
    	solve(rt,1,m);
    	ll ans=f[rt][k];
    	ans=ans*fp(getc(n,k),p-2)%p;
    	for(int i=1;i<=m;i++)
    		ans=ans*fac[a[i]]%p;
    	ans=(ans+p)%p;
    	printf("%lld
    ",ans);
    }
    int main()
    {
    	open("a");
    	inv[1]=fac[0]=fac[1]=ifac[0]=ifac[1]=1;
    	for(int i=2;i<=200000;i++)
    	{
    		inv[i]=-p/i*inv[p%i]%p;
    		fac[i]=fac[i-1]*i%p;
    		ifac[i]=ifac[i-1]*inv[i]%p;
    	}
    	int t;
    //	scanf("%d",&t);
    	t=rd();
    	while(t--)
    		solve();
    	return 0;
    }
    

    解法二

    #include<cstdio>
    #include<cstring>
    #include<algorithm>
    #include<utility>
    #include<iostream>
    using namespace std;
    typedef long long ll;
    typedef pair<int,int> pii;
    int rd()
    {
    	int s=0,c;
    	while((c=getchar())<'0'||c>'9');
    	s=c-'0';
    	while((c=getchar())>='0'&&c<='9')
    		s=s*10+c-'0';
    	return s;
    }
    void open(const char *s)
    {
    #ifndef ONLINE_JUDGE
    	char str[100];
    	sprintf(str,"%s.in",s);
    	freopen(str,"r",stdin);
    	sprintf(str,"%s.out",s);
    	freopen(str,"w",stdout);
    #endif
    }
    const int p=998244353;
    const int g=3;
    ll fp(ll a,ll b)
    {
    	ll s=1;
    	for(;b;b>>=1,a=a*a%p)
    		if(b&1)
    			s=s*a%p;
    	return s;
    }
    ll inv[300010];
    ll fac[300010];
    ll ifac[300010];
    namespace ntt
    {
    	int rev[600000];
    	void ntt(int *a,int n,int t)
    	{
    		for(int i=1;i<n;i++)
    		{
    			rev[i]=(rev[i>>1]>>1)|(i&1?n>>1:0);
    			if(i>rev[i])
    				swap(a[i],a[rev[i]]);
    		}
    		for(int i=2;i<=n;i<<=1)
    		{
    			int wn=fp(g,(p-1)/i*(t==1?1:i-1));
    			for(int j=0;j<n;j+=i)
    			{
    				int w=1;
    				for(int k=j;k<j+i/2;k++)
    				{
    					int u=a[k];
    					int v=(ll)a[k+i/2]*w%p;
    					a[k]=(u+v)%p;
    					a[k+i/2]=(u-v)%p;
    					w=(ll)w*wn%p;
    				}
    			}
    		}
    		if(t==-1)
    		{
    			int inv=fp(n,p-2);
    			for(int i=0;i<n;i++)
    				a[i]=(ll)a[i]*inv%p;
    		}
    	}
    	void getinv(int *a,int *b,int n)
    	{
    		if(n==1)
    		{
    			b[0]=fp(a[0],p-2);
    			return;
    		}
    		getinv(a,b,n>>1);
    		static int a1[600000],a2[600000];
    		for(int i=0;i<n;i++)
    			a1[i]=a[i];
    		for(int i=n;i<n<<1;i++)
    			a1[i]=0;
    		for(int i=0;i<n>>1;i++)
    			a2[i]=b[i];
    		for(int i=n>>1;i<n<<1;i++)
    			a2[i]=0;
    		ntt(a1,n<<1,1);
    		ntt(a2,n<<1,1);
    		for(int i=0;i<n<<1;i++)
    			a1[i]=a2[i]*(2-(ll)a1[i]*a2[i]%p)%p;
    		ntt(a1,n<<1,-1);
    		for(int i=0;i<n;i++)
    			b[i]=a1[i];
    	}
    	void getln(int *a,int *b,int n)
    	{
    		static int a1[600000],a2[600000];
    		for(int i=1;i<n;i++)
    			a1[i-1]=(ll)a[i]*i%p;
    		a1[n-1]=0;
    		getinv(a,a2,n);
    		for(int i=n;i<n<<1;i++)
    			a1[i]=a2[i]=0;
    		ntt(a1,n<<1,1);
    		ntt(a2,n<<1,1);
    		for(int i=0;i<n<<1;i++)
    			a1[i]=(ll)a1[i]*a2[i]%p;
    		ntt(a1,n<<1,-1);
    		for(int i=1;i<n;i++)
    			b[i]=(ll)a1[i-1]*inv[i]%p;
    		b[0]=0;
    	}
    	void getexp(int *a,int *b,int n)
    	{
    		if(n==1)
    		{
    			b[0]=1;
    			return;
    		}
    		getexp(a,b,n>>1);
    		static int a1[600000],a2[600000],a3[600000];
    		for(int i=n>>1;i<n;i++)
    			b[i]=0;
    		getln(b,a3,n);
    		for(int i=0;i<n>>1;i++)
    		{
    			a1[i]=b[i];
    			a2[i]=(a[i+(n>>1)]-a3[i+(n>>1)])%p;
    		}
    		for(int i=n>>1;i<n;i++)
    			a1[i]=a2[i]=0;
    		ntt(a1,n,1);
    		ntt(a2,n,1);
    		for(int i=0;i<n;i++)
    			a1[i]=(ll)a1[i]*a2[i]%p;
    		ntt(a1,n,-1);
    		for(int i=0;i<n>>1;i++)
    			b[i+(n>>1)]=a1[i];
    	}
    }
    int a[200010];
    int n,m,k;
    int c[200010];
    int b[200010];
    int cnt;
    ll ans;
    int d[300010];
    int s[300010];
    int f[300010];
    ll getc(int x,int y)
    {
    	if(y>x||y<0)
    		return 0;
    	return fac[x]*ifac[y]%p*ifac[x-y]%p;
    }
    void dfs(int x,int y,int v)
    {
    	if(x>m)
    	{
    		ans=(ans+v*getc(y,k))%p;
    		return;
    	}
    	dfs(x+1,y,v);
    	dfs(x+1,y+a[x],-v);
    }
    void solve()
    {
    //	scanf("%d%d",&n,&k);
    	n=rd();
    	k=rd();
    	for(int i=1;i<=n;i++)
    		c[i]=rd();
    //		scanf("%d",&c[i]);
    	if(k==n)
    	{
    		printf("1
    ");
    		return;
    	}
    	m=0;
    	cnt=0;
    	memset(b,0,sizeof b);
    	memset(a,0,sizeof a);
    	for(int i=1;i<=n;i++)
    		if(!b[i])
    		{
    			m++;
    			for(int j=i;!b[j];j=c[j])
    			{
    				b[j]=1;
    				a[m]++;
    			}
    		}
    	if(k<m)
    	{
    		printf("0
    ");
    		return;
    	}
    	memset(d,0,sizeof d);
    	memset(s,0,sizeof s);
    	for(int i=1;i<=m;i++)
    		d[a[i]]++;
    	for(int i=1;i<=n;i++)
    		if(d[i])
    			for(int j=1;i*j<=n;j++)
    				s[i*j]=(s[i*j]-inv[j]*d[i])%p;
    	int l=1;
    	while(l<=n-k)
    		l<<=1;
    	s[0]=1;
    	ntt::getexp(s,f,l);
    	ans=0;
    	for(int i=0;i<=n-k;i++)
    		ans=(ans+f[i]*getc(n-i,k))%p;
    //		ans=(ans+f[i]*getc(i,k))%p;
    	ans=ans*fp(getc(n,k),p-2)%p;
    //	if(m&1)
    //		ans=-ans;
    	ans=(ans+p)%p;
    	printf("%lld
    ",ans);
    }
    int main()
    {
    	open("a");
    	inv[1]=fac[0]=fac[1]=ifac[0]=ifac[1]=1;
    	for(int i=2;i<=300000;i++)
    	{
    		inv[i]=-p/i*inv[p%i]%p;
    		fac[i]=fac[i-1]*i%p;
    		ifac[i]=ifac[i-1]*inv[i]%p;
    	}
    	int t;
    //	scanf("%d",&t);
    	t=rd();
    	while(t--)
    		solve();
    	return 0;
    }
    
  • 相关阅读:
    Java多线程系列 JUC锁03 公平锁(一)
    Java多线程系列 JUC锁02 互斥锁ReentrantLock
    JDBC课程3--通过ResultSet执行查询操作
    JDBC课程2--实现Statement(用于执行SQL语句)--使用自定义的JDBCTools的工具类静态方法,包括insert/update/delete三合一
    JDBC_通过DriverManager获得数据库连接
    JDBC课程1-实现Driver接口连接mysql数据库、通用的数据库连接方法(使用文件jdbc.properties)
    [终章]进阶20-流程控制结构--if/case/while结构
    MySQL进阶19--函数的创建(举例)/设置mysql的创建函数的权限/查看(show)/删除(drop) / 举4个栗子
    MySQL进阶18- 存储过程- 创建语句-参数模式(in/out/inout-对应三个例子) -调用语法-delimiter 结束标记'$'- 删除/查看/修改-三个练习
    SQL进阶17-变量的声明/使用(输出)--全局变量/会话变量--用户变量/局部变量
  • 原文地址:https://www.cnblogs.com/ywwyww/p/8561048.html
Copyright © 2020-2023  润新知