• 【LOJ】#6289. 花朵


    题解

    我当时连(n^2)的树背包都搞不明白,这道题稳稳的爆零啊= =

    然后听说这道题需要FFT……我当时FFT的板子都敲不对,然后这道题就扔了

    然后,我去考了thusc……好吧,令人不愉快的经历,听说我要是把这道题做了我大概就能A了D2T2……生无可恋.jpg

    还有一个月,加油吧,NOI2018可能是我最后能去thu的机会了TAT

    设dp[u][0 / 1][i]为以u为根的子树,没选u还是选了u,一共选了i个点
    转移就是从所有子树里选出大小为i的独立集更新,转移可以类似树背包
    这道题dp方程写出来卷积优化就是显然的,关键是怎么优化

    我们把这个树给树链剖分了,设g[u][0 / 1][i]为u这个点除了u的重儿子以外的子树,没选u还是选了u,独立集大小为i的值(把i当成指数,把这个数组当成一个多项式)这是我们用来卷积的多项式

    我们从深度最深的链开始,由于我们希望一下子算出一条链,汇总到链顶,而不关心链上每个点的dp值,用分治FFT把一条链的答案算出来,具体就是存四个多项式,记录这条链的头尾选或没选,然后合并起来

    合并到父亲的时候我们对于每个点的所有轻儿子也分治乘起来,如果一个个乘起来会达到(n ^ 2)

    分治一个链的复杂度是(O(size(p) log^{2} size(p)))p是重链顶端,然后因为轻重链剖分,所以(sum size(p) = O(N log N))复杂度为(O(N log^3 N))

    代码

    #include <iostream>
    #include <cstdio>
    #include <vector>
    #include <algorithm>
    #include <cmath>
    #include <cstring>
    #include <map>
    //#define ivorysi
    #define pb push_back
    #define space putchar(' ')
    #define enter putchar('
    ')
    #define mp make_pair
    #define pb push_back
    #define fi first
    #define se second
    #define mo 974711
    #define MAXN 80005
    #define RG register
    using namespace std;
    typedef long long int64;
    typedef double db;
    template<class T>
    void read(T &res) {
        res = 0;char c = getchar();T f = 1;
        while(c < '0' || c > '9') {
    	if(c == '-') f = -1;
    	c = getchar();
        }
        while(c >= '0' && c <= '9') {
    	res = res * 10 + c - '0';
    	c = getchar();
        }
        res *= f;
    }
    template<class T>
    void out(T x) {
        if(x < 0) {putchar('-');x = -x;}
        if(x >= 10) {
    	out(x / 10);
        }
        putchar('0' + x % 10);
    }
    const int MOD = 998244353,L = (1 << 18);
    int W[L + 5],N,M,B[MAXN];
    int siz[MAXN],dep[MAXN],fa[MAXN],son[MAXN];
    int top[MAXN],Line[MAXN],tot,cnt,lsiz[MAXN],dfn[MAXN];
    vector<int> f[MAXN][2],zero,g[2][MAXN];
    struct node {
        int to,next;
    }E[MAXN * 2];
    struct res_node {
        vector<int> f00,f01,f10,f11;
    };
    int head[MAXN],sumE;
    void add(int u,int v) {
        E[++sumE].to = v;
        E[sumE].next = head[u];
        head[u] = sumE;
    }
    int mul(int a,int b) {
        return 1LL * a * b % MOD;
    }
    int inc(int a,int b) {
        a = a + b;
        if(a >= MOD) a -= MOD;
        return a;
    }
    int fpow(int x,int c) {
        int res = 1,t = x;
        while(c) {
    	if(c & 1) res = mul(res,t);
    	t = mul(t,t);
    	c >>= 1;
        }
        return res;
    }
    void NTT(vector<int> &a,int len,int on) {
        a.resize(len);
        for(int i = 1 , j = len / 2 ; i < len - 1 ; ++i) {
    	if(i < j) swap(a[i],a[j]);
    	int k = len / 2;
    	while(j >= k) {
    	    j -= k;
    	    k >>= 1;
    	}
    	j += k;
        }
        for(int h = 2 ; h <= len ; h <<= 1) {
    	int wn = W[(L + on * L / h) % L];
    	for(int k = 0 ; k < len ; k += h) {
    	    int w = 1;
    	    for(int j = k ; j < k + h / 2 ; ++j) {
    		int u = a[j],t = mul(a[j + h / 2],w);
    		a[j] = inc(u,t);
    		a[j + h / 2] = inc(u,MOD - t);
    		w = mul(w,wn);
    	    }
    	}
        }
        if(on == -1) {
    	int InvL = fpow(len,MOD - 2);
    	for(int i = 0 ; i < len ; ++i) a[i] = mul(a[i],InvL);
        }
    }
    vector<int> operator - (vector<int> a,vector<int> b) {
        int s = max(a.size(),b.size());
        a.resize(s);b.resize(s);
        vector<int> c;c.clear();
        for(int i = 0 ; i < s ; ++i) c.pb(inc(a[i],MOD - b[i]));
        return c;
    }
    vector<int> operator + (vector<int> a,vector<int> b) {
        int s = max(a.size(),b.size());
        a.resize(s);b.resize(s);
        vector<int> c;c.clear();
        for(int i = 0 ; i < s ; ++i) c.pb(inc(a[i],b[i]));
        return c;
    }
    vector<int> operator * (vector<int> a,vector<int> b) {
        int t = a.size() + b.size() - 2,T = 1;
        while(T <= t) T <<= 1;
        vector<int> c;c.clear();
        NTT(a,T,1);NTT(b,T,1);
        for(int i = 0 ; i < T ; ++i) c.pb(mul(a[i],b[i]));
        NTT(c,T,-1);
        if(T > M + 1) c.resize(M + 1),T = M + 1;
        for(int i = T - 1 ; i > 0 ; --i) {
    	if(!c[i]) c.pop_back();
    	else break;
        }
        return c;
    }
    void dfs1(int u) {
        dep[u] = dep[fa[u]] + 1;
        siz[u] = 1;
        for(int i = head[u] ; i ; i = E[i].next) {
    	int v = E[i].to;
    	if(v != fa[u]) {
    	    fa[v] = u;
    	    dfs1(v);
    	    siz[u] += siz[v];
    	    if(siz[v] > siz[son[u]]) son[u] = v;
    	}
        }
    }
    void dfs2(int u) {
        dfn[u] = ++tot;Line[tot] = u;
        ++cnt;
        if(!top[u]) top[u] = u;
        if(son[u]) {
    	top[son[u]] = top[u];
    	dfs2(son[u]);
        }
        else {
    	lsiz[top[u]] = cnt;
    	cnt = 0;
        }
        for(int i = head[u] ; i ; i = E[i].next) {
    	int v = E[i].to;
    	if(v != son[u] && v != fa[u]) dfs2(v);
        }
    }
    void Init() {
        W[0] = 1;W[1] = fpow(3,(MOD - 1) / L);
        for(int i = 2 ; i < L ; ++i) W[i] = mul(W[i - 1],W[1]);
        read(N);read(M);
        for(int i = 1 ; i <= N ; ++i) {
    	read(B[i]);
    	f[i][0].pb(1);
    	f[i][1].pb(0),f[i][1].pb(B[i]);
        }
        int u,v;
        for(int i = 1 ; i < N ; ++i) {
    	read(u);read(v);add(u,v);add(v,u);
        }
        dfs1(1);
        dfs2(1);
    }
    res_node DC(int l,int r) {
        if(l == r) {
    	int u = Line[l];
    	return (res_node){f[u][0],zero,zero,f[u][1]};
        }
        int mid = (l + r) >> 1;
        res_node wl = DC(l,mid),wr = DC(mid + 1,r);
        return (res_node){
    	(wl.f00 + wl.f01) * (wr.f10 + wr.f00) - wl.f01 * wr.f10,
    	(wl.f00 + wl.f01) * (wr.f11 + wr.f01) - wl.f01 * wr.f11,
    	(wl.f10 + wl.f11) * (wr.f10 + wr.f00) - wl.f11 * wr.f10,
    	(wl.f10 + wl.f11) * (wr.f01 + wr.f11) - wl.f11 * wr.f11,
        };
    }
    vector<int> mul(vector<int> *g,int l,int r) {
        if(l == r) return g[l];
        int mid = (l + r) >> 1;
        return mul(g,l,mid) * mul(g,mid + 1,r);
    }
    void Solve() {
        res_node t;
        for(int i = N ; i >= 1 ; --i) {
    	int u = Line[i];
    	if(top[u] == u) {
    	    for(int j = dfn[u] ; j <= dfn[u] + lsiz[u] - 1 ; ++j) {
    		int tot = 0;
    		int c = Line[j];
    		for(int k = head[c] ; k ; k = E[k].next) {
    		    int v = E[k].to;
    		    if(v != fa[c] && v != son[c]) g[0][++tot] = f[v][0] + f[v][1],g[1][tot] = f[v][0];
    		}
    		if(!tot) continue;
    		f[c][0] = mul(g[0],1,tot);
    		f[c][1] = f[c][1] * mul(g[1],1,tot);
    	    }
    	    t = DC(dfn[u],dfn[u] + lsiz[u] - 1);
    	    f[u][0] = t.f00 + t.f01;
    	    f[u][1] = t.f10 + t.f11;
    	}
        }
        f[1][0].resize(M + 1);f[1][1].resize(M + 1);
        out(inc(f[1][0][M],f[1][1][M]));enter;
    }
    int main() {
    #ifdef ivorysi
        freopen("f1.in","r",stdin);
    #endif
        Init();
        Solve();
        return 0;
    }
    
  • 相关阅读:
    HUST第八届程序设计竞赛-G小乐乐打游戏(双bfs)
    HDU-1575-Tr A(矩阵快速幂模板)
    HDU-1061-Rightmost Digit (快速幂模板)
    HihoCoder 1142-三分求极值(三分模板)
    Aizu ITP2_6_A(二分模板)
    Codeforces-938D-Buy a Ticket(最短路设虚拟节点+Dijk优先队列优化)
    POJ-1797-Heavy Transportation(最短路变形)
    HDU-5137-How Many Maos Does the Guanxi Worth(最短路删点)
    POJ-1094-Sorting It All Out (拓扑排序)(判断环和排名是否唯一)
    HDU-1869-六度分离(多源到多源最短路)
  • 原文地址:https://www.cnblogs.com/ivorysi/p/9143127.html
Copyright © 2020-2023  润新知