• CF1257G


    题目

    给定(n)个质数,设(S)代表这些质数的乘积所代表的数的所有正因子。对于一个有限正整数集合(D),如果任意(ain D)(bin D)(a eq b),都满足(a mid b),那么(D)就是好的。问(S)最大好子集是多大。

    题解

    假如这(n)个质数互不相等,显然答案就是(C(n,frac{n}{2}))。因为这样取的数互相不会整除彼此,而且是(C(n,k))中最大的数。那么如果有相等的数呢?类比一下,答案就是所有质因子个数(相同也算)为(frac{n}{2})的数的集合。它们显然互不整除,而且也是所有质因子个数相等的数的集合中最大的那个。

    这样问题就是一个简单背包dp了。一共有(m)个质数,设第(i)种质数(p_i)(c_i)个,(dp[pos][sum])代表前(i)种质数,剩余质因子个数为(sum)时的方案数。转移:

    [dp[i][j]=sumlimits_{k=0}^{c_i}{dp[i-1][j-k]} ]

    可以优化成(O(n^2)),用前缀和。btw,如果这里直接用(dp)存储前缀和,有(dp[0][k]=1),这样最后答案应该是(dp[m][frac{n}{2}]-dp[m][frac{n}{2}-1]);但是如果只令(dp[0][0]=1),即(dp[0])是差分的形式,然后直接按照前缀和优化来存储和计算(dp),最后答案直接就是(dp[m][frac{n}{2}]),因为一开始(dp[0])就是差分。

    显然会超时。如果用生成函数的思想,(p_i)(c_i)个,那么就构造多项式((x^0+x^1+...+x^{c_i})),然后将所有(p_i)的多项式乘起来,最后(x^{frac{n}{2}})的系数就是答案。直接分治+fft。

    vector作为返回值时是以作为右值返回,时间复杂度为(O(1)),常数不会太大,非常快,不用担心超时。将vector作为返回值问题不大。
    时间复杂度(O(nlog^2{n}))

    #include <bits/stdc++.h>
    
    #define endl '
    '
    #define IOS std::ios::sync_with_stdio(0); cin.tie(0); cout.tie(0)
    #define mp make_pair
    #define seteps(N) fixed << setprecision(N) 
    typedef long long ll;
    
    using namespace std;
    /*-----------------------------------------------------------------*/
    
    ll gcd(ll a, ll b) {return b ? gcd(b, a % b) : a;}
    #define INF 0x3f3f3f3f
    const int N = 3e5 + 10;
    const int M = 998244353;
    
    int rev[N];
    inline ll qpow(ll a, ll b, ll m) {
        ll res = 1;
        while (b) {
            if (b & 1)
                res = (res * a) % m;
    
            a = (a * a) % m;
            b = b >> 1;
        }
        return res;
    }
    
    void change(vector<ll>& y, int len) { // 蝴蝶变换
        for (int i = 0; i < len; ++i) {
            rev[i] = rev[i >> 1] >> 1;
            if (i & 1) {
                rev[i] |= len >> 1;
            }
        }
        for (int i = 0; i < len; ++i) {
            if (i < rev[i]) {
                swap(y[i], y[rev[i]]);
            }
        }
        return;
    }
    
    void ntt(vector<ll>& y, int len, int on) { // -1逆变换
        change(y, len);
        for (int h = 2; h <= len; h <<= 1) {
            ll gn = qpow(3, (M - 1) / h, M); // 原根为3
            if (on == -1)
                gn = qpow(gn, M - 2, M);
            for (int j = 0; j < len; j += h) {
                ll g = 1;
    
                for (int k = j; k < j + h / 2; k++) {
                    ll u = y[k];
                    ll t = g * y[k + h / 2] % M;
                    y[k] = (u + t) % M;
                    y[k + h / 2] = (u - t + M) % M;
                    g = g * gn % M;
                }
            }
        }
        if (on == -1) {
            ll inv = qpow(len, M - 2, M);
            for (int i = 0; i < len; i++) {
                y[i] = y[i] * inv % M;
            }
        }
    }
    
    int get(int x) {
        int res = 1;
        while(res < x) {
            res <<= 1;
        }
        return res;
    }
    
    int arr[N];
    vector<int> num;
    vector<int> pre;
    
    vector<ll> solve(int l, int r) {
        if(l == r) {
            vector<ll> f;
            for(int i = 0; i <= num[l]; i++) {
                f.push_back(1);
            }
            return f;
        }
        int mid = (l + r) / 2;
        vector<ll> f = solve(l, mid);
        vector<ll> g = solve(mid + 1, r);
        int nl = f.size(), nr = g.size();
        int len = get(nl + nr - 1);
        f.resize(len, 0);
        g.resize(len, 0);
        ntt(f, len, 1);
        ntt(g, len, 1);
        for(int i = 0; i < len; i++) f[i] = f[i] * g[i] % M;
        ntt(f, len, -1);
        f.resize(nl + nr - 1);
        return f;
    }
    
    int main() {
        IOS;
        int n;
        cin >> n;
        for(int i = 1; i <= n; i++) {
            cin >> arr[i];
        }
        sort(arr + 1, arr + 1 + n);
        int cnt = 1;
        num.push_back(0);
        for(int i = 2; i <= n + 1; i++) {
            if(i > n || arr[i] != arr[i - 1]) {
                num.push_back(cnt);
                cnt = 1;
            } else cnt++;
        }
        for(int i = 0; i < num.size(); i++) {
            pre.push_back(num[i]);
            if(i) pre[i] = (pre[i] + pre[i - 1]) % M;
        }
        vector<ll> ans = solve(1, num.size() - 1);
        cout << ans[n / 2] << endl;
    }
    
  • 相关阅读:
    Leetcode: Summary Ranges
    Leetcode: Kth Smallest Element in a BST
    Leetcode: Basic Calculator II
    Leetcode: Basic Calculator
    Leetcode: Count Complete Tree Nodes
    Leetcode: Implement Stack using Queues
    Leetcode: Maximal Square
    Leetcode: Contains Duplicate III
    Leetcode: Invert Binary Tree
    Leetcode: The Skyline Problem
  • 原文地址:https://www.cnblogs.com/limil/p/15350584.html
Copyright © 2020-2023  润新知