• AtCoder AGC036C GP 2 (组合计数)


    题目链接

    https://atcoder.jp/contests/agc036/tasks/agc036_c

    题解

    终于有时间补agc036的题了。
    这题其实不难的来着……我太菜了考场上没想出来

    首先转化一下题目: 一个序列可以被按题目的操作方式生成当且仅当它长度为(N), 总和为(3M), 且最大数不超过(2M), 奇数的个数不超过(M).
    必要性显然,充分性归纳易证。

    然后考虑怎么计数: 先不考虑第二个条件,定义(f(n,m,k))表示长度为(n)总和为(m)奇数不超过(k)个的方案数,那么枚举奇数的个数(i), 剩下的偶数和为(m-1), 有(f(n,m,k)=sum^{k}_{iequiv m(mod 2)}{nchoose i}{frac{m-i}{2}+n-1choose n-1}).
    考虑第二个条件,补集转化,最大数大于(2M)意味着剩下的所有数和小于(M), 那么不要把和式写出来然后无脑推式子!固定下最大的数的位置(1),给第一个数减去(2M) (这是个偶数所以不影响奇数那个条件),就是要求(N)个数和为(M), 第一个数大于(0),一共有不超过(M)个奇数的方案数。这个因为有奇数个数的限制所以枚举很麻烦,那就再补集转化!转化为((N-1))个数和为(M)且奇数不超过(M)个。

    因此最后答案就是(f(N,3M,M)-N(f(N,M,M)-f(N-1,M,M))).

    时间复杂度(O(N+M)).

    代码

    #include<cstdio>
    #include<cstdlib>
    #include<cstring>
    #include<cassert>
    #include<iostream>
    #define llong long long
    using namespace std;
    
    inline int read()
    {
    	int x=0; bool f=1; char c=getchar();
    	for(;!isdigit(c);c=getchar()) if(c=='-') f=0;
    	for(; isdigit(c);c=getchar()) x=(x<<3)+(x<<1)+(c^'0');
    	if(f) return x;
    	return -x;
    }
    
    const int N = 2e6;
    const int P = 998244353;
    llong fact[N+3],finv[N+3];
    
    llong quickpow(llong x,llong y)
    {
    	llong cur = x,ret = 1ll;
    	for(int i=0; y; i++)
    	{
    		if(y&(1ll<<i)) {y-=(1ll<<i); ret = ret*cur%P;}
    		cur = cur*cur%P;
    	}
    	return ret;
    }
    llong comb(llong x,llong y) {return x<0||y<0||x<y ? 0ll : fact[x]*finv[y]%P*finv[x-y]%P;}
    
    llong calc(llong n,llong m,llong k)
    {
    	llong ret = 0ll;
    	for(int i=0; i<=k; i++)
    	{
    		if((m-i)&1) continue;
    		llong tmp = comb(n,i)*comb(((m-i)>>1)+n-1,n-1)%P;
    		ret = (ret+tmp)%P;
    	}
    //	printf("calc %lld %lld %lld=%lld
    ",n,m,k,ret);
    	return ret;
    }
    
    int n,m;
    
    int main()
    {
    	fact[0] = 1ll; for(int i=1; i<=N; i++) fact[i] = fact[i-1]*i%P;
    	finv[N] = quickpow(fact[N],P-2); for(int i=N-1; i>=0; i--) finv[i] = finv[i+1]*(i+1)%P;
    	scanf("%d%d",&n,&m);
    	llong ans = calc(n,3*m,m);
    	ans = (ans-n*(calc(n,m,m)-calc(n-1,m,m)+P)%P+P)%P;
    	printf("%lld
    ",ans);
    	return 0;
    }
    
  • 相关阅读:
    Java 8特性
    11成最多体积的容器
    MySQL数据库理解
    java范型
    ArrayList源码分析
    1.面试题
    jvm简单了解
    121. 买卖股票的最佳时机
    有效的括号
    java如何判断一个字符串中某个字符有几个
  • 原文地址:https://www.cnblogs.com/suncongbo/p/11297768.html
Copyright © 2020-2023  润新知