• 「2019 集训队互测 Day 1」最短路径 题解


    「2019 集训队互测 Day 1」最短路径 题解

    题目传送门

    算法标签: 分治,ntt。

    这题主要考察了对于分治的应用。

    首先考虑最简单的“树”的情况。很容易想到,可以点分治+卷积实现。

    然后只剩下环的情况了。

    设 $f_i$ 为环上第 $i$ 个点对应的多项式，其中 $[x^j]f_i$ 表示挂在环上第 $i$ 个点下（不经过环上其他点）、与该点距离为 $j$ 的点的个数（该点本身对应 $j=0$）。

    设环长为 $len$。

    我们将从任意一个位置破环成链。然后再复制一份贴到后面。

    则我们要算出 $\sum f_i \times f_j \times x^{j-i}$，其中 $j>i$，$0\leq i<len$，$j-i\leq \lfloor (len-1)/2 \rfloor$。特别地，当 $len$ 为偶数时，环上距离恰为 $len/2$ 的点对不在上式中，需要单独处理。

    然后我们可以将那个长度为 $2\cdot len$ 的序列分成若干段（$len$ 较大时约 4~5 段），每段长度为 $\lfloor (len-1)/2 \rfloor$，则段内部的点对一定满足 $j-i\leq \lfloor (len-1)/2 \rfloor$ 这个条件，所以段内直接分治计算就可以了。

    然后考虑相邻两段之间的贡献，同样类似 dp 决策单调性那样分治就可以了。

    code:

    为了减小常数，这份代码在段内采用带权分治（按各位置多项式长度之和尽量均匀地划分）。

    /*
    {
    ######################
    #       Author       #
    #        Gary        #
    #        2021        #
    ######################
    */
    #include <bits/stdc++.h>
    // Loop shorthands: rb = inclusive ascending loop a in [b, c],
    // rl = inclusive descending loop a in [b, c] (from b down to c),
    // rep = loop a in [0, b).
    #define rb(a,b,c) for(int a=b;a<=c;++a)
    #define rl(a,b,c) for(int a=b;a>=c;--a)
    // Type / member-name shorthands.
    #define LL long long
    #define IT iterator
    #define PB push_back
    #define II(a,b) make_pair(a,b)
    #define FIR first
    #define SEC second
    // Debug helper: redirects stdout to the file check.out.
    #define FREO freopen("check.out","w",stdout)
    #define rep(a,b) for(int a=0;a<b;++a)
    // Declares a Mersenne Twister named rng seeded from steady_clock.
    #define SRAND mt19937 rng(chrono::steady_clock::now().time_since_epoch().count())
    // NOTE(review): the argument is not parenthesized, so random(x + 1) expands
    // to rng() % x + 1; wrap the argument if this macro is ever used.
    #define random(a) rng()%a
    #define ALL(a) a.begin(),a.end()
    #define POB pop_back
    #define ff fflush(stdout)
    #define fastio ios::sync_with_stdio(false)
    // Textual substitution: `a` appears twice, so it is evaluated twice.
    #define check_min(a,b) a=min(a,b)
    #define check_max(a,b) a=max(a,b)
    using namespace std;
    // "Infinity" sentinel; 0x3f3f3f3f is small enough that INF + INF still fits in int.
    const int INF = 0x3f3f3f3f;
    // Pair of ints (used e.g. for segment [begin, end) bounds in main).
    typedef pair<int, int> mp;
    // Reads the next non-negative decimal integer from stdin, skipping any
    // non-digit characters in front of it.
    // Returns 0 if end of input is reached before any digit is seen.
    inline int read() {
        int x = 0;
        // int, not char: EOF must stay distinguishable from real input bytes.
        int ch = getchar();

        // Skip separators, but stop at EOF instead of spinning on it forever.
        while (ch != EOF && (ch < '0' || ch > '9')) {
            ch = getchar();
        }

        while (ch >= '0' && ch <= '9') {
            x = x * 10 + (ch - '0');
            ch = getchar();
        }

        return x;
    }
    // NTT template begins
    // NTT-friendly prime 998244353 = 119 * 2^23 + 1 and its primitive root 3.
    // Both are compile-time constants so every `% MOD` in the NTT loops can be
    // lowered to multiply/shift instead of a runtime division.
    const int MOD = 998244353;
    const int g = 3;
    // Current transform length (a power of two), assigned by getlen().
    int len;
    // Bit-reversal permutation table filled by butterfly(); holds up to 2^19 entries.
    int rev[1 << 19];
    void butterfly(vector<int> &v) {
        rep(i, len) {
            rev[i] = rev[i >> 1] >> 1;
    
            if (i & 1)
                rev[i] |= len >> 1;
        }
    
        rep(i, len) if (rev[i] > i)
            swap(v[i], v[rev[i]]);
    }
    // Returns A^B mod MOD for B >= 0 (1 when B == 0).
    // Scans the bits of B from most to least significant: square the
    // accumulator, then multiply by A when the current bit is set.
    int quick(int A, int B) {
        int acc = 1;

        for (int bit = 30; bit >= 0; --bit) {
            acc = 1ll * acc * acc % MOD;

            if ((B >> bit) & 1)
                acc = 1ll * acc * A % MOD;
        }

        return acc;
    }
    // Modular inverse of x modulo the prime MOD via Fermat's little theorem
    // (x^(MOD-2)); yields 0 for x == 0.
    int inv(int x) {
        const int fermatExponent = MOD - 2;
        return quick(x, fermatExponent);
    }
    // Radix-2 number-theoretic transform of v over the global length `len`
    // (a power of two, set by getlen()).
    // ty == -1 selects the inverse roots; any other value gives the forward
    // transform. The inverse result is NOT divided by len here (mul() does that).
    // v is taken by value and the transformed copy is returned.
    vector<int> ntt(vector<int> v, int ty) {
        for (int &x : v)
            x %= MOD;

        butterfly(v);

        for (int span = 2; span <= len; span <<= 1) {
            const int half = span / 2;
            int root = quick(g, (MOD - 1) / span);

            if (ty == -1)
                root = inv(root);

            for (int base = 0; base < len; base += span) {
                int w = 1;

                for (int off = 0; off < half; ++off) {
                    // Both operands are read before either slot is overwritten,
                    // so updating in place matches a separate output buffer.
                    const int lo = v[base + off];
                    const int hi = 1ll * w * v[base + half + off] % MOD;
                    v[base + off] = (lo + hi) % MOD;
                    v[base + half + off] = (lo - hi + MOD) % MOD;
                    w = 1ll * w * root % MOD;
                }
            }
        }

        // Once at least one level has run, the result has exactly len entries.
        if (len >= 2)
            v.resize(len);

        return v;
    }
    // Sets the global transform length `len` to the smallest power of two that
    // is >= x (1 when x <= 1).
    void getlen(int x) {
        // Keep doubling from 1 until len reaches x.
        for (len = 1; len < x; len *= 2) {
        }
    }
    vector<int> mul(vector<int> A, vector<int> B) {
        getlen(A.size() + B.size());
        A.resize(len);
        B.resize(len);
        A = ntt(A, 1);
        B = ntt(B, 1);
        rep(i, len) A[i] = 1ll * A[i] * B[i] % MOD;
        A = ntt(A, -1);
        int iv = inv(len);
        rep(i, len) {
            A[i] = 1ll * A[i] * iv % MOD;
        }
    
        while (!A.empty() && A.back() == 0)
            A.pop_back();
    
        return A;
    }
    void add(vector<int> &A, vector<int> B) {
        if (A.size() < B.size())
            A.resize(B.size());
    
        rep(i, B.size()) {
            (A[i] += B[i]) %= MOD;
        }
    }
    // Returns A with x zero coefficients prepended, i.e. the polynomial
    // multiplied by x^x. A non-positive x returns A unchanged.
    vector<int> right_shift(vector<int> A, int x) {
        // One insert instead of reverse + x push_backs + reverse.
        if (x > 0)
            A.insert(A.begin(), static_cast<size_t>(x), 0);

        return A;
    }
    //NTT template ends
    // Bound for vertex-indexed arrays; vertex ids index them directly.
    // NOTE(review): assumes every vertex id is < MAXN.
    const int MAXN = 1e5 + 233;
    // Traversal "wall" flags: set for cycle vertices, and temporarily for the
    // current centroid inside get(). calc / getsize / get_centroid never step
    // onto a flagged vertex.
    int on_cycle[MAXN];
    // n: number of vertices (the input also has n edges); k: exponent applied to each distance.
    int n, k;
    // Adjacency lists (self-loops are not stored).
    vector<int> gra[MAXN];
    // Subtree sizes filled by get_centroid().
    int siz[MAXN];
    // Size of the component currently being decomposed, counted by getsize().
    int sz = 0;
    // Cycle length, set in main(); stays 0 when the input has a self-loop.
    int anslen;
    // Returns a (near-)centroid of the component containing `now`; vertices
    // flagged in on_cycle[] act as walls. The global `sz` must already hold the
    // component size (get() computes it with getsize() first).
    //
    // Post-order DFS: siz[now] accumulates the subtree size, and w tracks the
    // largest piece left if `now` were removed (largest child subtree, or the
    // part above: sz - siz[now]). The first vertex to finish with
    // w <= sz / 2 + 3 is returned immediately and passed up the recursion.
    // 0 means "none found in this subtree".
    // NOTE(review): this relies on vertex ids starting at 1, so 0 is never a
    // real vertex.
    int get_centroid(int now, int fa = -1) {
        siz[now] = 1;
        int w = -INF;

        for (auto it : gra[now])
            if (it != fa && !on_cycle[it]) {
                // Note: this local `ret` shadows the global ret vector.
                int ret = get_centroid(it, now);

                // A descendant already qualified: stop and pass it up. siz[]
                // of the vertices on this path is left partially computed.
                if (ret)
                    return ret;

                siz[now] += siz[it];
                check_max(w, siz[it]);
            }

        // The piece "above" now (the rest of the component) once now is removed.
        check_max(w, sz - siz[now]);

        // The +3 slack also accepts vertices slightly heavier than a true
        // centroid; a true centroid (largest piece <= sz / 2) always qualifies.
        if (w <= sz / 2 + 3) {
            return now;
        }

        return 0;
    }
    // Vertices of the cycle found by findcycle(); main() later appends a second
    // copy so the cycle can be treated as a line of length 2 * anslen.
    vector<int> cycle;
    // Set when the input contains a self-loop. The self-loop is dropped, and the
    // remaining graph is handled as a tree (assuming the input is connected).
    bool ok = 0;
    // Visited flags and explicit DFS stack used by findcycle().
    bool vis[MAXN];
    stack<int> sta;
    // DFS from `now` that stops at the first back edge and stores the vertices
    // of the cycle in `cycle`. They are popped from the DFS stack, starting at
    // the vertex where the back edge was found and ending at the vertex it
    // points to. `pre` is the DFS parent (-1 for the root).
    //
    // Only the FIRST occurrence of `pre` in the adjacency list is treated as the
    // tree edge back to the parent. A second occurrence is a parallel edge, so
    // it closes a cycle of length 2. Skipping every occurrence of `pre` by
    // vertex id would miss that cycle and leave `cycle` empty.
    void findcycle(int now, int pre = -1) {
        if (cycle.size())
            return;

        vis[now] = true;
        sta.push(now);
        bool parentEdgeSkipped = false;

        for (auto it : gra[now]) {
            if (it == pre && !parentEdgeSkipped) {
                parentEdgeSkipped = true;
                continue;
            }

            if (vis[it]) {
                // Back edge to an ancestor: unwind the stack down to it.
                int top;

                do {
                    top = sta.top(), sta.pop();
                    cycle.PB(top);
                } while (top != it);

                return;
            }

            findcycle(it, now);

            if (cycle.size())
                return;
        }

        sta.pop();
    }
    // f[p][d]: number of vertices at depth d in the tree hanging from the p-th
    // cycle vertex (that vertex itself is at d = 0). main() copies entries so
    // that f[p + anslen] == f[p] on the doubled cycle.
    vector<int> f[MAXN * 2];
    // ret[d]: number of unordered vertex pairs at distance d, modulo MOD.
    vector<int> ret;
    // Depth histogram: for each vertex reachable from `now` without entering a
    // vertex flagged in on_cycle[] (and without stepping back to `pre`), adds 1
    // to v[d], where d is its distance from `now` plus the starting `depth`.
    // v is grown as needed. `now` itself is always counted.
    void calc(vector<int> &v, int now, int depth = 0, int pre = -1) {
        // Explicit stack of (vertex, parent, depth) in place of recursion.
        vector<tuple<int, int, int>> todo;
        todo.emplace_back(now, pre, depth);

        while (!todo.empty()) {
            int u, parent, d;
            tie(u, parent, d) = todo.back();
            todo.pop_back();

            if (v.size() <= static_cast<size_t>(d))
                v.resize(d + 1);

            ++v[d];

            for (int w : gra[u])
                if (w != parent && !on_cycle[w])
                    todo.emplace_back(w, u, d + 1);
        }
    }
    // Adds to the global counter `sz` the number of vertices reachable from
    // `now` without entering a vertex flagged in on_cycle[] or returning to
    // `pre`. get() resets sz to 0 before calling this.
    void getsize(int now, int pre = -1) {
        // Explicit stack of (vertex, parent) in place of recursion.
        vector<pair<int, int>> todo(1, make_pair(now, pre));

        while (!todo.empty()) {
            const pair<int, int> cur = todo.back();
            todo.pop_back();
            ++sz;

            for (int w : gra[cur.first])
                if (!on_cycle[w] && w != cur.second)
                    todo.push_back(make_pair(w, cur.first));
        }
    }
    void get(int now) {
        sz = 0;
        getsize(now);
        now = get_centroid(now);
        bool pre = on_cycle[now];
        on_cycle[now] = 1;
        vector<int> presum(1, 1);
    
        for (auto it : gra[now])
            if (!on_cycle[it]) {
                get(it);
                vector<int> tmp;
                calc(tmp, it);
                add(ret, mul(presum, right_shift(tmp, 1)));
                add(presum, right_shift(tmp, 1));
            }
    
        on_cycle[now] = pre;
    }
    // Segment length and largest cycle offset handled by the cycle divide &
    // conquer; main() sets it to floor((anslen - 1) / 2).
    int to;
    void solve(int ansl, int ansr, int l, int r) {
        check_min(ansr, anslen);
    
        if (ansl >= ansr || l >= r)
            return;
    
        vector<int> lp, rp;
        rb(i, ansl, ansr - 1) {
            int st = ansr - 1 - i;
    
            if (st + f[i].size() > lp.size())
                lp.resize(st + f[i].size());
    
            rep(j, f[i].size()) {
                (lp[j + st] += f[i][j]) %= MOD;
            }
        }
        rb(i, l, r - 1) {
            int st = i - l;
    
            if (st + f[i].size() > rp.size())
                rp.resize(st + f[i].size());
    
            rep(j, f[i].size()) {
                (rp[j + st] += f[i][j]) %= MOD;
            }
        }
        vector<int> tmp = mul(lp, rp);
        int gap = l - ansr + 1;
        rep(i, tmp.size()) {
            (ret[i + gap] += tmp[i]) %= MOD;
        }
    }
    // Divide & conquer over the right endpoints j in [l, r) of one segment,
    // pairing them with left endpoints in the preceding segment (which ends
    // at ansr).
    // When called for adjacent segments as in main(), it adds for each j the
    // pairs (i, j) with j - to <= i < ansr, i.e. all left endpoints within
    // cycle offset `to`.
    // The window slides right as j grows, so the left half of the j's sees
    // every i in [mid - to, ansr) in one batch.
    // ansl is only passed down; it never bounds a range directly.
    void divc(int ansl, int ansr, int l, int r) {
        if (r - l == 1) {
            // Single right endpoint: its whole window in one call.
            solve(l - to, ansr, l, r);
            return;
        }

        const int mid = (l + r) / 2;
        const int split = mid - to;
        divc(split, ansr, mid, r);
        divc(ansl, split, l, mid);
        // Every j < mid reaches every i >= mid - to.
        solve(split, ansr, l, mid);
    }
    void div1(int l, int r) {
        if (l == r - 1)
            return ;
    
        mp best = {INF, INF};
        int tot = 0;
        rb(i, l, r - 1) tot += f[i].size();
        tot /= 2;
        rb(i, l, r - 1) {
            tot -= f[i].size();
            check_min(best, II(abs(tot), i));
        }
        int mid = best.second + 1;
        solve(l, mid, mid, r);
        div1(l, mid);
        div1(mid, r);
    }
    main() {
        //  freopen("sub35.in","r",stdin);
        //  scanf("%d%d",&n,&k);
        n = read();
        k = read();
        rb(i, 1, n) {
            int u, v;
            //      scanf("%d%d",&u,&v);
            u = read();
            v = read();
    
            if (u == v) {
                ok = true;
                continue;
            }
    
            gra[u].PB(v), gra[v].PB(u);
        }
        int rest = 0;
        ret.resize(n + 1);
    
        if (ok) {
            get(1);
        } else {
            findcycle(1);
    
            for (auto it : cycle)
                on_cycle[it] = true;
    
            int now = 0;
    
            for (auto it : cycle) {
                on_cycle[it] = false;
                calc(f[now++], it), get(it);
                on_cycle[it] = true;
            }
    
            anslen = cycle.size();
            cycle.resize(anslen + anslen);
            rep(i, anslen) cycle[i + anslen] = cycle[i], f[i + anslen] = f[i];
            to = anslen / 2;
    
            if (anslen & 1);
            else {
                to--;
                vector<int> tmp;
    
                rep(i, anslen) if (i < (i + to + 1) % anslen)
                    add(tmp, mul(f[i], f[i + to + 1]));
    
                tmp = right_shift(tmp, to + 1);
                add(ret, tmp);
            }
    
            if (to) {
                vector<mp> each;
                int now = 0;
    
                while (now < anslen + anslen) {
                    int nex = min(anslen + anslen, now + to);
                    each.PB(II(now, nex));
                    div1(now, nex);
                    now = nex;
                }
    
                rb(i, 1, each.size() - 1) {
                    divc(each[i - 1].FIR, each[i - 1].SEC, each[i].FIR, each[i].SEC);
                }
            }
        }
    
        rb(i, 0, n) {
            rest += 1ll * ret[i] * quick(i, k) % MOD;
    
            if (rest >= MOD)
                rest -= MOD;
        }
        rest = 1ll * rest * inv(1ll * n * (n - 1) / 2 % MOD) % MOD;
        cout << rest << endl;
        return 0;
    }
    
  • 相关阅读:
    poj3181(Dollar Dayz)
    poj3666(Making the Grade)
    poj2392(Space Elevator)
    hdu5288(OO’s Sequence)
    hdu5289(Assignment)
    快学scala
    Spark Checkpointing
    Spark Performance Tuning (性能调优)
    Spark Memory Tuning (内存调优)
    Sparkstreaming and Kafka
  • 原文地址:https://www.cnblogs.com/gary-2005/p/14312616.html
Copyright © 2020-2023  润新知