• [洛谷P3301][BZOJ3129][SDOI2013]方程(扩展Lucas+容斥)


    Solution

    • 先考虑\(n_1=0\)的情况
    • 那么只要考虑形如\(X_i>=A_i\)的限制
    • 注意求的是正整数解的个数,即对于\(i>n_2\)\(X_i>=1(A_i=1)\)
    • \(\sum_{i=1}^{n}B_i=m\)非负整数解的个数为\(C(m+n-1,m)\)
    • 解释:序列共\(m+n-1\)个位置,选\(n-1\)个位置出来当隔板,把序列分为长度之和为\(m\)\(n\)段(可能存在长度为\(0\)的段,即隔板相邻的情况)
    • 现在为了满足这些限制,令\(B_i=X_i-A_i\),则\(B_i\)非负整数解的个数就是原题的合法解的个数
    • 那么\(m\)要减掉\(\sum_{i=1}^{n}A_i\)
    • 考虑\(n_1>0\)的情况,用总方案数\(-\)存在\(X_i>=A_i+1(1<=i<=n_1)\)的情况
    • 即考虑容斥:不考虑前\(n_1\)个数的限制的方案数\(-\)\(n_1\)个数至少有\(1\)个不满足条件的方案数\(+\)\(n_1\)个数至少有\(2\)个不满足条件的方案数\(-\)……
    • 发现\(n,m\)很大,但任意一组数据的\(p\)都可以拆成\(\Pi_{i=1}^{k}pi^{qi}\),且\(p_i<=10007\),那么用扩展\(lucas\)求组合数取模即可

    Code

    #include <bits/stdc++.h>
    
    using namespace std;
    
    #define ll long long
    
    template <class t>
    inline void read(t & res)
    {
    	char ch;
    	while (ch = getchar(), !isdigit(ch));
    	res = ch ^ 48;
    	while (ch = getchar(), isdigit(ch))
    	res = res * 10 + (ch ^ 48);
    }
    
    const int o = 2000;
    int a[o], b[o], pk, p, c[o], d[o], tst, n1, n2, n, m, ans, h[o], now, f[20][10010];
    bool vis[o];
    ll tot;
    
    inline int exgcd(int a, int b, int &x, int &y)
    {
    	if (!b)
    	{
    		x = 1;
    		y = 0;
    		return a;
    	}
    	int ret = exgcd(b, a % b, x, y), tmp = x;
    	x = y;
    	y = tmp - a / b * y;
    	return ret;
    }
    
    inline int ksm(int x, ll y)
    {
    	int res = 1;
    	while (y)
    	{
    		if (y & 1) res = (ll)res * x % pk;
    		y >>= 1;
    		x = (ll)x * x % pk;
    	}
    	return res;
    }
    
    inline int fac(int n, int p, int k)
    {
    	if (n == 1 || n == 0) return 1;
    	ll cnt = n / p, bl = n / pk, res = fac(n / p, p, k), i, tmp;
    	tot += cnt;
    	tmp = f[now][pk - 1];
    	tmp = ksm(tmp, bl);
    	res = res * tmp % pk;
    	res = res * f[now][n % pk] % pk;
    	return res;
    }
    
    inline int solve(int n, int m, int id)
    {
    	int p = c[id], k = d[id];
    	pk = a[id];
    	tot = 0;
    	int ra = fac(m, p, k); ll ta = tot;
    	tot = 0;
    	int rb = fac(n - m, p, k); ll tb = tot;
    	tot = 0;
    	int rc = fac(n, p, k); ll tc = tot;
    	ll t = tc - ta - tb;
    	if (t < 0) t = (t % k + k) % k;
    	int ia, ib, xxx;
    	exgcd(ra, pk, ia, xxx);
    	exgcd(rb, pk, ib, xxx);
    	if (ia < 0) ia += pk;
    	if (ib < 0) ib += pk;
    	return (ll)rc * ia % pk * ib % pk * ksm(p, t) % pk;
    }
    
    inline void init()
    {
    	int i, s = sqrt(p), lp = p, j;
    	for (i = 2; i <= s; i++)
    	if (lp % i == 0)
    	{
    		int t = 0, r = 1;
    		while (lp % i == 0) 
    		{
    			t++;
    			r *= i;
    			lp /= i;
    		}
    		a[++a[0]] = r; 
    		c[a[0]] = i;
    		d[a[0]] = t;
    	}
    	if (lp != 1) 
    	{
    		a[++a[0]] = lp;
    		c[a[0]] = lp;
    		d[a[0]] = 1;
    	}
    	for (i = 1; i <= a[0]; i++)
    	{
    		f[i][0] = 1;
    		for (j = 1; j <= a[i]; j++)
    		if (j % c[i]) f[i][j] = (ll)f[i][j - 1] * j % a[i];
    		else f[i][j] = f[i][j - 1];
    	}
    }
    
    inline int cc(ll n, ll m, int p)
    {
    	if (n < m || m < 0) return 0;
    	int ans = 0, i;
    	for (i = 1; i <= a[0]; i++) 
    	{
    		now = i;
    		b[i] = solve(n, m, i);
    	}
    	for (i = 1; i <= a[0]; i++)
    	{
    		int mi = p / a[i], g, y, aa = a[i];
    		exgcd(mi, aa, g, y);
    		ans = (ans + (ll)mi * g % p * b[i] % p + p) % p;
    	}
    	return ans;
    }
    
    inline void add(int &x, int y)
    {
    	x += y;
    	if (x >= p) x -= p;
    }
    
    inline void pd()
    {
    	int i, tm = m, cnt = 0;
    	for (i = 1; i <= n1; i++)
    	if (vis[i])
    	{
    		cnt++;
    		tm -= h[i] + 1;
    	}
    	else tm--;
    	if (!cnt) return;
    	if (cnt & 1) add(ans, p - cc(tm + n - 1, tm, p));
    	else add(ans, cc(tm + n - 1, tm, p));
    } 
    
    inline void dfs(int k)
    {
    	if (k == n1 + 1)
    	{
    		pd();
    		return;
    	}
    	vis[k] = 0;
    	dfs(k + 1);
    	vis[k] = 1;
    	dfs(k + 1);
    }
    
    int main()
    {
    	int i;
    	read(tst); read(p);
    	init();
    	while (tst--)
    	{
    		read(n); 
    		read(n1); 
    		read(n2);
    		read(m);
    		int tmp = n1 + n2;
    		for (i = 1; i <= tmp; ++i) read(h[i]);
    		m -= n - n1 - n2;
    		for (i = n1 + 1; i <= n2 + n1; i++) m -= h[i];
    		int tm = m - n1;
    		ans = cc(tm + n - 1, tm, p);
    		dfs(1);
    		printf("%d\n", ans);
    	}
    	return 0;
    }
    
  • 相关阅读:
    c#个人记录常用方法(更新中)
    Newtonsoft.Json.dll解析json的dll文件使用
    组织http请求
    ado.net中的几个对象
    jquery-easyui使用
    aspx与mvc页面验证码
    aspx页面状态管理(查询字符串Request与Application)
    aspx页面状态管理Cookie和ViewState
    在网页中插入qq连接
    ASP.NET中上传图片检测其是否为真实的图片 防范病毒上传至服务器
  • 原文地址:https://www.cnblogs.com/cyf32768/p/12196441.html
Copyright © 2020-2023  润新知