• 联合省选 2020 补题记录


    有点做不动题,就来写点东西。。。

    鸽子了 A 卷的保序回归 和 B 卷。

    通过翻游记大概搞到了题目顺序

    d1t1 : P6619 [省选联考 2020 A/B 卷] 冰火战士

    d1t2 : P6620 [省选联考 2020 A 卷] 组合数问题

    d2t1 : P6622 [省选联考 2020 A/B 卷] 信号传递

    d2t2 : P6623 [省选联考 2020 A 卷] 树

    d2t3 : P6624 [省选联考 2020 A 卷] 作业题

    我才发现顺序就是题目编号

    【upd】心情好,随手写了一道 B 卷题:P6626 [省选联考 2020 B 卷] 消息传递

    补题的时候我都不开 O2 以还原最真实的情景,所以不考虑开 O2 能过的做法


    冰火战士

    毕竟是补题,当时一群人被卡常大肆交流做法的时候也看到了 树状数组上二分 / 线段树上二分 这个算法标签。。。

    但是既然是补题,还是写已知的最优秀的做法吧 qaq。

    显然场地温度为 (d) 时,冰战士能量和是 温度 (le d) 的能量战士和,火战士能量和是 温度 (ge d) 的战士能量和。

    两边都带等不方便二分,不妨把所有火战士的温度都 (+1),变为 (>d) 的战士能量和。

    最后答案是两边能量和较小值的两倍。

    这个东西显然可能有一段是平的,不好三分。

    考虑二分两个位置,第一个是 最靠后的 冰能量 不大于 火能量 的位置 以及 第一次 冰能量 大于 火能量 的位置

    树状数组上二分的时候,火能量可以通过总和减去前缀,冰能量直接前缀计算即可。

    显然答案是这两个位置之一。

    求第一个位置可以直接树状数组上二分,但是第二个位置是“最靠前的”,不那么常规。

    仔细想想发现第二个位置的火能量是可以直接计算的,然后找到火战士 最靠后的值不小于火能量的位置 即可。

    第一次写树状数组上二分感觉非常 /tuu,码力完全不行,写了 3h 重构了一发又写了 1h+,考场上绝对完蛋。

    跑了 2.3s,质疑线段树二分能不能过去。。。

    【upd】Guidingstar 说他考场上线段树二分过了,那可能是洛谷慢?

    Code
    #include<bits/stdc++.h>
    using namespace std;
    #define fi first
    #define se second
    #define mkp(x,y) make_pair(x,y)
    #define pb(x) push_back(x)
    #define sz(v) (int)v.size()
    typedef long long LL;
    typedef double db;
    template<class T>bool ckmax(T&x,T y){return x<y?x=y,1:0;}
    template<class T>bool ckmin(T&x,T y){return x>y?x=y,1:0;}
    #define rep(i,x,y) for(int i=x,i##end=y;i<=i##end;++i)
    #define per(i,x,y) for(int i=x,i##end=y;i>=i##end;--i)
    inline int read(){
    	int x=0,f=1;char ch=getchar();
    	while(!isdigit(ch)){if(ch=='-')f=0;ch=getchar();}
    	while(isdigit(ch))x=x*10+ch-'0',ch=getchar();
    	return f?x:-x;
    }
    const int N = 2000005;
    int Q, lsh[N], len, sum2;
    int tr1[N], tr2[N];
    struct node {
    	int op, x, y, z;
    } a[N];
    inline void add1(int x, int d) {//ice, 前缀 
    	for(int i = x; i <= len; i += i & -i) tr1[i] += d;
    }
    inline void add2(int x, int d) {//fire, 后缀
    	sum2+=d; 
    	for(int i = x; i <= len; i += i & -i) tr2[i] += d;
    }
    inline int ask1(int x) {
    	int res = 0;
    	for(int i = x; i > 0; i -= i & -i) res += tr1[i];
    	return res;
    }
    inline int ask2(int x) {
    	int res = 0;
    	for(int i = x; i > 0; i -= i & -i) res += tr2[i];
    	return res;
    }
    signed main() {
    	Q = read();
    	rep(i, 1, Q) {
    		a[i].op = read(), a[i].x = read();
    		if(a[i].op == 1) a[i].y = read(), a[i].z = read(), lsh[++len] = a[i].y;
    	}
    	sort(lsh + 1, lsh + len + 1), len = unique(lsh + 1, lsh + len + 1) - lsh - 1;
    	rep(i, 1, Q) if(a[i].op == 1)
    		a[i].y = lower_bound(lsh + 1, lsh + len + 1, a[i].y) - lsh;
    	rep(i, 1, Q) {
    		int op = a[i].op, x = a[i].x;
    		if(op == 2) {
    			if(a[x].x == 0) add1(a[x].y, -a[x].z);
    			else add2(a[x].y + 1, -a[x].z);
    		} else {
    			if(a[i].x == 0) add1(a[i].y, a[i].z);
    			else add2(a[i].y + 1, a[i].z);
    		}
    		int r = 0, s1 = 0, s2 = sum2;
    		pair <int, int> ans1 = mkp(0, 0), ans2 = mkp(0, 0);
    		for(int i = 1 << 20; i; i >>= 1){
    			int tmp = r + i;
    			if(tmp <= len && s1 + tr1[tmp] <= s2 - tr2[tmp])
    				r = tmp, s1 += tr1[tmp], s2 -= tr2[tmp];
    		}
    		ans1 = mkp(s1, r);
    		if(r < len) {
    			ans2.fi = sum2 - ask2(r + 1);
    			s1 = 0, s2 = sum2, r = 0;
    			for(int i = 1 << 20; i; i >>= 1) {
    				int tmp = r + i, t1 = s1 + tr1[tmp], t2 = s2 - tr2[tmp];
    				if(tmp <= len && (t2 >= ans2.fi))
    					r = tmp, s1 = t1, s2 = t2;
    			}
    			ans2.se = r;
    		}
    		ans1 = max(ans1, ans2);
    		if(ans1.fi == 0) puts("Peace"); 
    		else printf("%d %lld
    ", lsh[ans1.se], ans1.fi * 2ll);
    	}
    }
    

    组合数问题

    上来扔给你一个式子叫你算感觉非常可怕。。。

    但是再看几眼发现这个东西非常的 naive。

    先做一点基础的化简,(f(k)) 必然要展开

    [sum_{k=0}^{n}sum_{i=0}^{m}a_ik^i x^kinom{n}{k}\ =sum_{i=0}^{m}a_isum_{k=0}^{n}k^ix^kinom{n}{k} ]

    注意到那个次幂是一维特别大一维特别小,想到了第二类斯特林数展开:

    [m ^ n = sum_{i} egin{Bmatrix}n\iend{Bmatrix}i!inom{m}{i} ]

    带进去

    [=sum_{i = 0} ^{m} a_i sum_{k = 0} ^ {n} sum_{j = 0} ^ {m}egin{Bmatrix} i \ j end{Bmatrix}j!inom{k}{j} x ^ k inom{n}{k}\ =sum_{i = 0} ^{m} a_i sum_{j = 0} ^ {m} egin{Bmatrix} i \ j end{Bmatrix}j! sum_{k = 0} ^ {n} x ^ k inom{n}{j}inom{n-j}{k-j}\ =sum_{i = 0} ^{m} a_i sum_{j = 0} ^ {m}inom{n}{j} egin{Bmatrix} i \ j end{Bmatrix}j! x ^ j sum_{k = 0} ^ {n - j} x ^ k inom{n-j}{k}\ =sum_{i = 0} ^{m} a_i sum_{j = 0} ^ {m}n^{underline{j}} egin{Bmatrix} i \ j end{Bmatrix} x ^ j (1+x)^{n-j} ]

    这 100 分大概是白送的?心态好一点勇敢想正解都能想出来吧!

    Code
    #include<bits/stdc++.h>
    using namespace std;
    #define fi first
    #define se second
    #define mkp(x,y) make_pair(x,y)
    #define pb(x) push_back(x)
    #define sz(v) (int)v.size()
    typedef long long LL;
    typedef double db;
    template<class T>bool ckmax(T&x,T y){return x<y?x=y,1:0;}
    template<class T>bool ckmin(T&x,T y){return x>y?x=y,1:0;}
    #define rep(i,x,y) for(int i=x,i##end=y;i<=i##end;++i)
    #define per(i,x,y) for(int i=x,i##end=y;i>=i##end;--i)
    inline int read(){
    	int x=0,f=1;char ch=getchar();
    	while(!isdigit(ch)){if(ch=='-')f=0;ch=getchar();}
    	while(isdigit(ch))x=x*10+ch-'0',ch=getchar();
    	return f?x:-x;
    }
    
    const int N = 1005;
    int n, x, mod, m, a[N], S2[N][N], fac[N], dwn[N], ans, pw1[N], pw2[N];
    inline int qpow(int n, int k) {
    	int res = 1;
    	for(; k; k >>= 1, n = 1ll * n * n % mod)
    		if(k & 1) res = 1ll * n * res % mod;
    	return res;
    }
    signed main() {
    	n = read(), x = read(), mod = read(), m = read();
    	rep(i, 0, m) a[i] = read();
    	S2[0][0] = 1;
    	rep(i, 1, m) {
    		rep(j, 1, i)
    			S2[i][j] = (1ll * S2[i - 1][j] * j % mod + S2[i - 1][j - 1]) % mod;
    	}
    	fac[0] = 1;
    	for(int i = 1; i <= m; ++i) fac[i] = 1ll * i * fac[i - 1] % mod;
    	dwn[0] = 1;
    	for(int i = 1; i <= m; ++i) dwn[i] = 1ll * (n - i + 1) * dwn[i - 1] % mod;
    	pw1[m] = qpow(1 + x, n - m);
    	for(int i = m - 1; i >= 0; --i) pw1[i] = 1ll * pw1[i + 1] * (1 + x) % mod;
    	pw2[0] = 1;
    	for(int i = 1; i <= m; ++i) pw2[i] = 1ll * pw2[i - 1] * x % mod;
    	for(int i = 0; i <= m; ++i) pw1[i] = 1ll * pw1[i] * pw2[i] % mod;
    	for(int i = 0; i <= m; ++i) {
    		int res = 0;
    		for(int j = 0; j <= i; ++j)
    			res = (res + 1ll * S2[i][j] * dwn[j] % mod * pw1[j] % mod) % mod;
    		ans = (ans + 1ll * res * a[i] % mod) % mod;
    	}
    	cout << ans << '
    ';
    }
    

    信号传递

    统计 (c(i,j)) 表示 (i) 塔向 (j) 塔传递了几次。

    (dp(msk)) 表示 ([1,popcount(msk)]) 使用 (msk) 里面为 (1) 的位置来填的最小贡献。

    每一次枚举一个 (i otin msk) 来转移。

    (A=msk,B={x|x otin A operatorname{and} x ot=i})

    那么

    (c(i,j)(jin A)) 的贡献是 (k imes c(i,j) imes (x_i+x_j)),其中 (x) 为下标。

    (c(i,j)(jin B)) 的贡献是 (c(i,j) imes (x_j-x_i))(显然 (j)(i) 后面所以不带绝对值)。

    由于每一次枚举 (i) 的时候 (x_j) 并不知道,所以对于 (jin A) 提前加贡献,对于 (jin B) 延后算贡献。

    综上,大概思路是,枚举 (i otin msk),枚举 (j ot=i)

    显然 (x_i=popcount(msk)+1)(当前集合大小加一)。

    如果 (jin A),则转移 (dp(msk|2^i)gets dp(msk) + k imes c(i,j) imes x_i + c(j,i) imes x_i)

    如果 (jin B),则转移 (dp(msk|2^i)gets dp(msk) + k imes c(j,i) imes x_i - c(i,j) imes x_i)

    复杂度 (O(2^mm^2)),可以得到 (70) 分。

    70 pts Code
    #include<bits/stdc++.h>
    using namespace std;
    #define fi first
    #define se second
    #define mkp(x,y) make_pair(x,y)
    #define pb(x) push_back(x)
    #define sz(v) (int)v.size()
    typedef long long LL;
    typedef double db;
    template<class T>bool ckmax(T&x,T y){return x<y?x=y,1:0;}
    template<class T>bool ckmin(T&x,T y){return x>y?x=y,1:0;}
    #define rep(i,x,y) for(int i=x,i##end=y;i<=i##end;++i)
    #define per(i,x,y) for(int i=x,i##end=y;i>=i##end;--i)
    inline int read(){
        int x=0,f=1;char ch=getchar();
        while(!isdigit(ch)){if(ch=='-')f=0;ch=getchar();}
        while(isdigit(ch))x=x*10+ch-'0',ch=getchar();
        return f?x:-x;
    }
    int dp[1 << 23], cnt[1 << 23], n, m, k, c[23][23], S[100000];
    signed main() {
    	n = read(), m = read(), k = read();
    	for(int i = 0; i < n; ++i) S[i] = read() - 1;
    	for(int i = 0; i < n - 1; ++i) ++c[S[i]][S[i + 1]];
    	memset(dp, 0x3f, sizeof(dp));
    	dp[0] = 0;
    	for(int msk = 0; msk < 1 << m; ++msk) {
    		cnt[msk] = cnt[msk >> 1] + (msk & 1);
    		for(int i = 0; i < m; ++i) {
    			if(msk >> i & 1) continue;
    			int t = 0, x = cnt[msk] + 1;
    			for(int j = 0; j < m; ++j) if(i != j) {
    				if(msk >> j & 1) t += x * k * c[i][j] + x * c[j][i];
    				else t += x * k * c[j][i] - c[i][j] * x;
    			}
    			ckmin(dp[msk | (1 << i)], t + dp[msk]);
    		}
    	}
    	cout << dp[(1 << m) - 1] << '
    ';
    	return 0;
    }
    

    大眼观察转移,注意到很多计算是重复的。具体来说,只有 (m imes 2^m)(f(msk,i)),可以提前预处理。

    注意到预处理的时候如果再枚举 (j),复杂度就会错掉,仍然是 (O(2^mm^2))

    观察一下,贡献分为两类:从 (i) 往别的塔转移;从别的塔往 (i) 转移。

    考虑分别预处理这两类贡献设为 (f(msk,i),g(msk,i))

    考虑到 (msk) 之间的依赖性,可以通过 (f(msk,i)=f(mskoperatorname{xor}operatorname{lowbit}(msk),i)+c(log(operatorname{lowbit}(msk)),i)) 来递推,(g) 同理。

    复杂度这样子就是 (O(2^mm)) 了,可惜空间也是 (O(2^mm)) 这个复杂度,过不去。

    还是给出这部分的代码吧,比较好理解后面的优化。

    MLE Code
    #include<bits/stdc++.h>
    using namespace std;
    #define fi first
    #define se second
    #define mkp(x,y) make_pair(x,y)
    #define pb(x) push_back(x)
    #define sz(v) (int)v.size()
    typedef long long LL;
    typedef double db;
    template<class T>bool ckmax(T&x,T y){return x<y?x=y,1:0;}
    template<class T>bool ckmin(T&x,T y){return x>y?x=y,1:0;}
    #define rep(i,x,y) for(int i=x,i##end=y;i<=i##end;++i)
    #define per(i,x,y) for(int i=x,i##end=y;i>=i##end;--i)
    inline int read(){
        int x=0,f=1;char ch=getchar();
        while(!isdigit(ch)){if(ch=='-')f=0;ch=getchar();}
        while(isdigit(ch))x=x*10+ch-'0',ch=getchar();
        return f?x:-x;
    }
    int dp[1 << 23], cnt[1 << 23], n, m, k, c[23][23], S[100000];
    int f[1 << 23][23], g[1 << 23][23], lg[1 << 23];
    inline void init() {
    	for(int i = 0; i < m; ++i) lg[1 << i] = i;
    	for(int msk = 1; msk < 1 << m; ++msk) {
    		for(int i = 0; i < m; ++i) {
    			int lb = msk & -msk;
    			f[msk][i] = f[msk ^ lb][i] + c[i][lg[lb]];
    			g[msk][i] = g[msk ^ lb][i] + c[lg[lb]][i];
    		}
    	}
    }
    signed main() {
    	n = read(), m = read(), k = read();
    	for(int i = 0; i < n; ++i) S[i] = read() - 1;
    	for(int i = 0; i < n - 1; ++i) ++c[S[i]][S[i + 1]];
    	init();
    	memset(dp, 0x3f, sizeof(dp));
    	dp[0] = 0;
    	for(int msk = 0, U = (1 << m) - 1; msk < U; ++msk) {
    		cnt[msk] = cnt[msk >> 1] + (msk & 1);
    		for(int i = 0; i < m; ++i) {
    			if(msk >> i & 1) continue;
    			int x = cnt[msk] + 1, o = x * k, res = 0, nx = msk | (1 << i);
    			res += o * f[msk][i];
    			res += x * g[msk][i];
    			res += o * g[U ^ nx][i];
    			res -= x * f[U ^ nx][i];
    			ckmin(dp[nx], res + dp[msk]);
    		}
    	}
    	cout << dp[(1 << m) - 1] << '
    ';
    	return 0;
    }
    

    既然都做到上面那一步通过 (operatorname{lowbit}) 来递推了。

    考虑用一个栈来保留当前具有公共前缀的 (msk)。举个例子可能好理解一些:

    比如 (msk=1101011),那么栈内应该存的是:(1000000,1100000,1101000,1101010,1101011)

    可以发现栈顶就是 (msk) 去掉 (operatorname{lowbit})的值,这样子我们可以根据栈顶 (f,g) 的值来递推出 (msk)(f,g) 值,再把 (msk) 压入栈中。

    但是这个优化使得我们不太方便维护 补集异或 (2^i) 的答案,可以考虑用全集的答案 减去当前的答案 再减去 (i) 的答案。

    根据压栈次数可以发现复杂度为 (2^m),而这部分的空间复杂度则降至 (O(m^2)),总空间复杂度变成了 (O(2^m))!!!

    然后就过去了。

    能自己想到这些优化还是很开心的,但是加起来搞了近三个小时,省选场上估计我没这个耐心也没这个时间,还是要完蛋啊。。。

    Code
    #include<bits/stdc++.h>
    using namespace std;
    #define fi first
    #define se second
    #define mkp(x,y) make_pair(x,y)
    #define pb(x) push_back(x)
    #define sz(v) (int)v.size()
    typedef long long LL;
    typedef double db;
    template<class T>bool ckmax(T&x,T y){return x<y?x=y,1:0;}
    template<class T>bool ckmin(T&x,T y){return x>y?x=y,1:0;}
    #define rep(i,x,y) for(int i=x,i##end=y;i<=i##end;++i)
    #define per(i,x,y) for(int i=x,i##end=y;i>=i##end;--i)
    inline int read(){
        int x=0,f=1;char ch=getchar();
        while(!isdigit(ch)){if(ch=='-')f=0;ch=getchar();}
        while(isdigit(ch))x=x*10+ch-'0',ch=getchar();
        return f?x:-x;
    }
    int dp[1 << 23], n, m, k, U, c[23][23], S[100000], lg[1 << 23];
    int f[23][23], g[23][23], stk[23], top, st[23], ed[23];
    signed main() {
    	n = read(), m = read(), k = read();
    	for(int i = 0; i < n; ++i) S[i] = read() - 1;
    	for(int i = 0; i < n - 1; ++i)
    		++c[S[i]][S[i + 1]], ++st[S[i]], ++ed[S[i + 1]];
    	for(int i = 0; i < m; ++i) lg[1 << i] = i;
    	U = (1 << m) - 1;
    	memset(dp, 0x3f, sizeof(dp));
    	dp[0] = 0;
    	for(int msk = 0; msk < U; ++msk) {
    		if(msk) {
    			int lb = msk & -msk, LG = lg[lb];
    			while(top && (stk[top] ^ lb) != msk) --top;
    			stk[++top] = msk;
    			for(int i = 0; i < m; ++i) {
    				f[top][i] = f[top - 1][i] + c[i][LG];
    				g[top][i] = g[top - 1][i] + c[LG][i];
    			}
    		}
    		int x = top + 1, o = x * k;
    		for(int i = 0; i < m; ++i) {
    			if(msk >> i & 1) continue;
    			int res = 0;
    			res += o * f[top][i];
    			res += x * g[top][i];
    			res += o * (ed[i] - g[top][i]);
    			res -= x * (st[i] - f[top][i]);
    			res -= o * c[i][i] - c[i][i] * x;
    			ckmin(dp[msk | (1 << i)], res + dp[msk]);
    		}
    	}
    	cout << dp[U] << '
    ';
    	return 0;
    }
    

    非常的裸,要你维护一个集合,支持全局加一,查询全局异或和,合并。

    直接用 Trie 树即可。

    因为要全局加一,考虑从低位往高位插入。这样加一的时候,对于遍历到的节点执行以下两个操作即可:

    • 交换左右子树

    • 往左子树递归

    原因很简单,左子树是 (0),变成 (1),右子树是 (1) 变成 (0) 并且往下一位进 (1)

    对于左子树就是一个子问题了。

    维护全局异或和,我的方法是维护每一层 (1) 的个数,这样子遍历每一层可以 (O(log V)) 查询。

    合并直接 Trie 启发式合并就好了。

    Code
    #include<bits/stdc++.h>
    using namespace std;
    #define fi first
    #define se second
    #define mkp(x,y) make_pair(x,y)
    #define pb(x) push_back(x)
    #define sz(v) (int)v.size()
    typedef long long LL;
    typedef double db;
    template<class T>bool ckmax(T&x,T y){return x<y?x=y,1:0;}
    template<class T>bool ckmin(T&x,T y){return x>y?x=y,1:0;}
    #define rep(i,x,y) for(int i=x,i##end=y;i<=i##end;++i)
    #define per(i,x,y) for(int i=x,i##end=y;i>=i##end;--i)
    inline int read(){
    	int x=0,f=1;char ch=getchar();
    	while(!isdigit(ch)){if(ch=='-')f=0;ch=getchar();}
    	while(isdigit(ch))x=x*10+ch-'0',ch=getchar();
    	return f?x:-x;
    }
    
    const int N = 600006;
    const int T = N * 23;
    int n, w[N], fa[N], cnt[N][21][2], tot, tr[T][2], rt[N], s[T];
    int hed[N], et;
    LL ans;
    struct edge { int nx, to; } e[N];
    inline void adde(int u, int v) {
    	e[++et].nx = hed[u], e[et].to = v, hed[u] = et;
    }
    void add(int u) {
    	int p = rt[u];
    	for(int i = 0; i < 21; ++i) {
    		cnt[u][i][0] += s[tr[p][1]];
    		cnt[u][i][1] -= s[tr[p][1]];
    		cnt[u][i][1] += s[tr[p][0]];
    		cnt[u][i][0] -= s[tr[p][0]];
    		swap(tr[p][0], tr[p][1]);
    		if(tr[p][0]) p = tr[p][0];
    		else break;
    	}
    }
    int merge(int x, int y) {
    	if(!x || !y) return x | y;
    	s[x] += s[y];
    	tr[x][0] = merge(tr[x][0], tr[y][0]);
    	tr[x][1] = merge(tr[x][1], tr[y][1]);
    	return x;
    }
    void insert(int u, int x) {
    	if(!rt[u]) rt[u] = ++tot;
    	int p = rt[u];
    	for(int i = 0; i < 21; ++i) {
    		int c = x >> i & 1;
    		if(!tr[p][c]) tr[p][c] = ++tot;
    		p = tr[p][c], ++cnt[u][i][c], ++s[p];
    	}
    }
    void dfs(int u) {
    	for(int i = hed[u]; i; i = e[i].nx) {
    		int v = e[i].to;
    		dfs(v);
    		add(v);
    		for(int j = 0; j < 21; ++j)
    			cnt[u][j][0] += cnt[v][j][0], cnt[u][j][1] += cnt[v][j][1];
    		rt[u] = merge(rt[u], rt[v]);
    	}
    	insert(u, w[u]);
    	for(int i = 0; i < 21; ++i)
    		if(cnt[u][i][1] & 1) ans += 1 << i;
    }
    signed main() {
    	n = read();
    	rep(i, 1, n) w[i] = read();
    	rep(i, 2, n) fa[i] = read(), adde(fa[i], i);
    	dfs(1);
    	cout << ans << '
    ';
    }
    

    作业题

    首先 (gcd) 反演掉,可以考虑简便的欧拉函数 (sumlimits_{d|n}varphi(d)=n)

    [ans=sum_{d}varphi(d)sum_{G={e|w_e\%d=0}}left(sum_{Tin G}sum_{ein T}w_e ight) ]

    注意到在枚举 (d) 的过程中,如果特判掉没有生成树的情况,剩下的情况数非常少。

    粗略估计最大上界:这个值域内每一个数因数个数最大为 (sqrt{152501}=144),边数最大为 (dfrac{n(n-1)}{2}=435),每一棵生成树至少需要 (n-1) 条边,所以总共计算图的生成树边权和次数最多为 (144*15=2160)

    很显然这个上界达不到,哪怕达到,你写个 (O(n^4)) 的东西都能过。

    那么现在的问题转化为,给你一张图,求其所有生成树边权和。

    如果你做过 P5296 [北京省选集训2019]生成树计数,你会发现这题是个弱化版,直接套上即可。

    大概思路是,设每一条边的边权为一个多项式 (1+w_ex),可以发现跑完矩阵树之后一次项系数就是答案。

    直接把边权设成 Poly,带进去求 det。注意到可以在 (mod x^2) 意义下计算,所以这多项式就是个常数,复杂度 (O(n^3))。千万不要在 (mod x^{n+1}) 意义下计算,因为这样子是 (O(n^5)) 的。

    如果你像我一样脑抽了可以写个插值,求出这个多项式的 (n) 个点值然后把多项式插出来。如果你不知道怎么快速插这个多项式可以看看 拉格朗日插值如何插出系数,或者写个高斯消元也不影响复杂度。复杂度 (O(n^4))

    因为一开始写的是第二种方法,有一个点没过去,加了一点剪枝,代码比较 shit。除了上面提到的没有生成树用并查集判掉,还加了一个记忆化。因为发现不同的 (d) 提出的边可能会相同,就哈希了一下。不过加上这玩意 (3s+ o 800ms) 还是比较震惊的。。。

    方法一非常的稳健,最大点 (80ms)

    两种方法不那么好写,可能是我码力太菜了,码了 2h 左右才过去。

    我感觉我考场上都没有勇气来想这种题的正解,更别说想到之后还有没有足够的时间码出来。

    Code for First Solution
    #include<bits/stdc++.h>
    using namespace std;
    #define fi first
    #define se second
    #define mkp(x,y) make_pair(x,y)
    #define pb(x) push_back(x)
    #define sz(v) (int)v.size()
    typedef long long LL;
    typedef double db;
    template<class T>bool ckmax(T&x,T y){return x<y?x=y,1:0;}
    template<class T>bool ckmin(T&x,T y){return x>y?x=y,1:0;}
    #define rep(i,x,y) for(int i=x,i##end=y;i<=i##end;++i)
    #define per(i,x,y) for(int i=x,i##end=y;i>=i##end;--i)
    inline int read(){
        int x=0,f=1;char ch=getchar();
        while(!isdigit(ch)){if(ch=='-')f=0;ch=getchar();}
        while(isdigit(ch))x=x*10+ch-'0',ch=getchar();
        return f?x:-x;
    }
    #define mod 998244353
    const LL P = 10000000000000061ll;
    inline int qpow(int n, int k) {
    	int res = 1;
    	for(; k; k >>= 1, n = 1ll * n * n % mod)
    		if(k & 1) res = 1ll * n * res % mod;
    	return res;
    }
    const int N = 32;
    const int M = 160005;
    int n, m, mp[N][N], x[N], y[N], f[N], ans;
    int phi[M], pri[M], pct;
    bool vis[M], bok[M];
    struct node {
    	int u, v, w;
    	node() { u = v = w = 0; }
    	node(int u_, int v_, int w_) { u = u_, v = v_, w = w_; }
    } e[N * N];
    struct poly {
    int a[2];
    poly(int x0 = 0, int x1 = 0) { a[0] = x0, a[1] = x1; }
    inline int& operator [] (const int &k) { return a[k]; }
    friend poly operator * (poly a, poly b) {
    	return poly(1ll * a[0] * b[0] % mod, (1ll * a[0] * b[1] + 1ll * a[1] * b[0]) % mod);
    }
    poly inv() {
    	int iv = qpow(a[0], mod - 2);
    	return poly(iv, 1ll * iv * (mod - iv) % mod * a[1] % mod);
    }
    poly& operator += (const poly &b) {
    	(a[0] += b.a[0]) %= mod, (a[1] += b.a[1]) %= mod;
    	return *this;
    }
    poly& operator -= (const poly &b) {
    	(a[0] += mod - b.a[0]) %= mod, (a[1] += mod - b.a[1]) %= mod;
    	return *this;
    }
    poly operator -() {
    	return poly(!a[0] ? 0 : mod - a[0], !a[1] ? 0 : mod - a[1]);
    }
    	
    } a[N][N];
    int F[N];
    inline int anc(int x) { return x == F[x] ? x : F[x] = anc(F[x]); }
    inline poly det(int n) {
    	poly res(1, 0);
    	for(int i = 0; i < n; ++i) {
    		for(int j = i + 1; j < n; ++j) {
    			poly tmp = a[j][i] * (a[i][i].inv());
    			for(int l = i; l < n; ++l)
    				a[j][l] -= a[i][l] * tmp;
    		}
    		res = res * a[i][i];
    	}
    	return res;
    }
    inline void init(const int&n = M - 1) {
    	phi[1] = 1;
    	for(int i = 2; i <= n; ++i) {
    		if(!vis[i]) pri[++pct] = i, phi[i] = i - 1;
    		for(int j = 1; j <= pct && i * pri[j] <= n; ++j) {
    			vis[i * pri[j]] = 1;
    			if(i % pri[j] == 0) {
    				phi[i * pri[j]] = phi[i] * pri[j];
    				break;
    			} else phi[i * pri[j]] = phi[i] * phi[pri[j]];
    		}
    	}
    }
    inline int qwq(int x, int*f, int n) {
    	int res = 0;
    	for(int i = n - 1; i >= 0; --i) res = (1ll * res * x % mod + f[i]) % mod;
    	return res;
    }
    map<LL, int> Map;
    inline int calc(int d) {
    	if(bok[d]) return 0;
    	memset(mp, 0, sizeof(mp));
    	int cnt = 0, h = 0;
    	rep(i, 0, n - 1) F[i] = i;
    	LL pw = 1;
    	for(int i = 1; i <= m; ++i, pw = 2ll * pw % P ) {
    		int x = e[i].u, y = e[i].v, w = e[i].w;
    		if(w % d) continue;
    		mp[x][y] = mp[y][x] = w, h = (h + pw) % P;
    		if(anc(x) != anc(y)) F[anc(x)] = anc(y), ++cnt;
    	}
    	if(cnt != n - 1) {
    		for(int j = d; j < M; j += d) bok[j] = 1;
    		return 0;
    	}
    	int tmp = Map[h];
    	if(tmp) return tmp;
    	for(int i = 0; i < n; ++i)
    		for(int j = 0; j < n; ++j)
    			a[i][j] = poly(0, 0);
    	for(int i = 0; i < n; ++i) {
    		for(int j = 0; j < n; ++j) {
    			if(!mp[i][j]) continue;
    			a[i][j] = poly(1, mp[i][j]);
    			a[i][i] += a[i][j], a[i][j] = -a[i][j];
    		}
    	}
    	return Map[h] = det(n - 1).a[1];
    }
    signed main() {
    	init();
    	n = read(), m = read();
    	rep(i, 1, m) e[i].u = read() - 1, e[i].v = read() - 1, e[i].w = read();
    	for(int i = 1; i < M; ++i) ans = (ans + 1ll * phi[i] * calc(i)) % mod;
    	cout << ans << '
    ';
    }
    
    Code for Second Solution
    #include<bits/stdc++.h>
    using namespace std;
    #define fi first
    #define se second
    #define mkp(x,y) make_pair(x,y)
    #define pb(x) push_back(x)
    #define sz(v) (int)v.size()
    typedef long long LL;
    typedef double db;
    template<class T>bool ckmax(T&x,T y){return x<y?x=y,1:0;}
    template<class T>bool ckmin(T&x,T y){return x>y?x=y,1:0;}
    #define rep(i,x,y) for(int i=x,i##end=y;i<=i##end;++i)
    #define per(i,x,y) for(int i=x,i##end=y;i>=i##end;--i)
    inline int read(){
        int x=0,f=1;char ch=getchar();
        while(!isdigit(ch)){if(ch=='-')f=0;ch=getchar();}
        while(isdigit(ch))x=x*10+ch-'0',ch=getchar();
        return f?x:-x;
    }
    #define mod 998244353
    const LL P = 10000000000000061ll;
    inline int qpow(int n, int k) {
    	int res = 1;
    	for(; k; k >>= 1, n = 1ll * n * n % mod)
    		if(k & 1) res = 1ll * n * res % mod;
    	return res;
    }
    const int N = 32;
    const int M = 160005;
    int n, m, mp[N][N], a[N][N], x[N], y[N], f[N], ans;
    int phi[M], pri[M], pct;
    bool vis[M], bok[M];
    struct node {
    	int u, v, w;
    	node() { u = v = w = 0; }
    	node(int u_, int v_, int w_) { u = u_, v = v_, w = w_; }
    } e[N * N];
    int F[N];
    inline int anc(int x) { return x == F[x] ? x : F[x] = anc(F[x]); }
    inline int det(int n) {
    	int res = 1, flg = 0;
    	for(int i = 0; i < n; ++i) {
    		for(int j = i + 1; j < n; ++j) {
    			while(a[j][i]) {
    				int d = a[i][i] / a[j][i];
    				for(int k = i; k < n; ++k)
    					a[i][k] = (a[i][k] + mod - 1ll * a[j][k] * d % mod) % mod,
    					swap(a[i][k], a[j][k]);
    				flg ^= 1;
    			}
    		}
    		if(!a[i][i]) return 0;
    		res = 1ll * res * a[i][i] % mod;
    	}
    	return flg ? mod - res : res;
    }
    inline void lagrange(int*f, int*x, int*y, int n) {
    	static int a[N], b[N], c[N];
    	memset(a, 0, n << 2);
    	memset(b, 0, (n + 1) << 2);
    	memset(c, 0, n << 2);
    	memset(f, 0, n << 2);
    	for(int i = 0; i < n; ++i) {
    		int A = 1;
    		for(int j = 0; j < n; ++j) if(i != j)
    			A = 1ll * A * (x[i] - x[j] + mod) % mod;
    		a[i] = 1ll * y[i] * qpow(A, mod - 2) % mod;
    	}
    	b[0] = 1;
    	for(int i = 0; i < n; ++i) {
    		for(int j = i + 1; j >= 1; --j)
    			b[j] = (b[j - 1] + 1ll * b[j] * (mod - x[i])) % mod;
    		b[0] = 1ll * b[0] * (mod - x[i]) % mod;
    	}
    	for(int i = 0; i < n; ++i) {
    		int iv = qpow(mod - x[i], mod - 2);
    		c[0] = 1ll * b[0] * iv % mod;
    		for(int j = 1; j < n; ++j)
    			c[j] = 1ll * (b[j] - c[j - 1] + mod) * iv % mod;
    		for(int j = 0; j < n; ++j)
    			f[j] = (f[j] + 1ll * c[j] * a[i]) % mod;
    	}
    }
    inline void init(const int&n = M - 1) {
    	phi[1] = 1;
    	for(int i = 2; i <= n; ++i) {
    		if(!vis[i]) pri[++pct] = i, phi[i] = i - 1;
    		for(int j = 1; j <= pct && i * pri[j] <= n; ++j) {
    			vis[i * pri[j]] = 1;
    			if(i % pri[j] == 0) {
    				phi[i * pri[j]] = phi[i] * pri[j];
    				break;
    			} else phi[i * pri[j]] = phi[i] * phi[pri[j]];
    		}
    	}
    }
    inline int qwq(int x, int*f, int n) {
    	int res = 0;
    	for(int i = n - 1; i >= 0; --i) res = (1ll * res * x % mod + f[i]) % mod;
    	return res;
    }
    map<LL, int> Map;
    inline int calc(int d) {
    	if(bok[d]) return 0;
    	memset(mp, 0, sizeof(mp));
    	int cnt = 0, h = 0;
    	rep(i, 0, n - 1) F[i] = i;
    	LL pw = 1;
    	for(int i = 1; i <= m; ++i, pw = 2ll * pw % P ) {
    		int x = e[i].u, y = e[i].v, w = e[i].w;
    		if(w % d) continue;
    		mp[x][y] = mp[y][x] = w, h = (h + pw) % P;
    		if(anc(x) != anc(y)) F[anc(x)] = anc(y), ++cnt;
    	}
    	if(cnt != n - 1) {
    		for(int j = d; j < M; j += d) bok[j] = 1;
    		return 0;
    	}
    	int tmp = Map[h];
    	if(tmp) return tmp;
    	for(int z = 0; z < n; ++z) {
    		memset(a, 0, sizeof(a));
    		for(int i = 0; i < n; ++i) {
    			for(int j = 0; j < n; ++j) {
    				if(!mp[i][j]) continue;
    				a[i][j] = 1 + (z + 1) * mp[i][j];
    				a[i][i] += a[i][j], a[i][j] = mod - a[i][j];
    			}
    		}
    		x[z] = z + 1, y[z] = det(n - 1);
    	}
    	lagrange(f, x, y, n);
    	return Map[h] = f[1];
    }
    signed main() {
    	init();
    	n = read(), m = read();
    	rep(i, 1, m) e[i].u = read() - 1, e[i].v = read() - 1, e[i].w = read();
    	for(int i = 1; i < M; ++i) ans = (ans + 1ll * phi[i] * calc(i)) % mod;
    	cout << ans << '
    ';
    }
    

    消息传递

    当作复习点分治板子了,居然写了 40min 才过去,我也是服了我了。。。

    是震波的弱化版吧,但是只需要维护距离恰好为 (k) 的,而不是 (le k) 的,把震波里的树状数组改成数组少一只 (log),复杂度变成 (O((n+m)log n)),就过去了。

    想玩一下,其实也是怕 vector 不开 O2 常数爆炸,写了个什么指针+内存池,结果少开了一位调了好一会。。。

    大概思路是容斥,维护 (f(u,k)) 表示 (u) 这个分治中心的分治子树内部离它距离为 (k) 的点的个数,(g(u,k)) 表示这个分治中心的分治子树内部离它在分治树上父亲距离为 (k) 的点的个数。

    查询 ((x,k)) 的时候,答案就是 (sum f(u,k-operatorname{dis}(u,x))-g(pa_u,k-operatorname{dis}(pa_u,x)))

    维护 (f,g) 要动态内存,(f) 只用存到分治树大小,(g) 要存到分治树大小加一,总大小是 (2nlog n) 的,带修加上点权也很方便,总之这题非常板子就是了。感觉在线写法比离线方便啊 qaq。

    Code
    #include<bits/stdc++.h>
    using namespace std;
    #define fi first
    #define se second
    #define mkp make_pair
    #define pb push_back
    #define sz(v) (int)(v).size()
    typedef long long LL;
    typedef double db;
    template<class T>bool ckmax(T&x,T y){return x<y?x=y,1:0;}
    template<class T>bool ckmin(T&x,T y){return x>y?x=y,1:0;}
    #define rep(i,x,y) for(int i=x,i##end=y;i<=i##end;++i)
    #define per(i,x,y) for(int i=x,i##end=y;i>=i##end;--i)
    inline int read(){
    	int x=0,f=1;char ch=getchar();
    	while(!isdigit(ch)){if(ch=='-')f=0;ch=getchar();}
    	while(isdigit(ch))x=x*10+ch-'0',ch=getchar();
    	return f?x:-x;
    }
    const int N = 100005;
    int n, m;
    int et, hed[N];
    struct edge { int nx, to; } e[N << 1];
    inline void adde(int u, int v) {
    	e[++et].to = v, e[et].nx = hed[u], hed[u] = et;
    }
    int ST[20][N << 1], tmr, dfn[N], dep[N], lg[N << 1];
    int rt, mx[N], siz[N], used[N], tsiz, vt[N], fsz[N];
    int *f[N], *g[N], pool[N * 40], *mem;
    void dfs(int u, int ft) {
    	dfn[u] = ++tmr, ST[0][tmr] = dep[u];
    	for(int i = hed[u]; i; i = e[i].nx) {
    		int v = e[i].to;
    		if(v == ft) continue;
    		dep[v] = dep[u] + 1, dfs(v, u), ST[0][++tmr] = dep[u];
    	}
    }
    inline void init_dis() {
    	dfs(1, 0);
    	lg[0] = -1;
    	for(int i = 1; i <= tmr; ++i) lg[i] = lg[i >> 1] + 1;
    	rep(i, 1, lg[tmr]) rep(j, 1, tmr - (1 << i) + 1)
    		ST[i][j] = min(ST[i - 1][j], ST[i - 1][j + (1 << (i - 1))]);
    }
    inline int dis(int x, int y) {
    	int l = dfn[x], r = dfn[y];
    	if(l > r) l ^= r ^= l ^= r;
    	int t = lg[r - l + 1];
    	return dep[x] + dep[y] - (min(ST[t][l], ST[t][r - (1 << t) + 1]) << 1);
    }
    void getrt(int u, int ft) {
    	siz[u] = 1, mx[u] = 0;
    	for(int i = hed[u]; i; i = e[i].nx) {
    		int v = e[i].to;
    		if(v == ft || used[v]) continue;
    		getrt(v, u), siz[u] += siz[v];
    		ckmax(mx[u], siz[v]);
    	}
    	ckmax(mx[u], tsiz - siz[u]);
    	if(mx[u] < mx[rt]) rt = u;
    }
    
    void divide(int u) {
    	fsz[u] = tsiz;
    	f[u] = mem, mem += tsiz;
    	for(int *i = f[u]; i != mem; ++i) *i = 0;
    	g[u] = mem, mem += tsiz + 1;
    	for(int *i = g[u]; i != mem; ++i) *i = 0;
    	used[u] = 1;
    	for(int i = u; i; i = vt[i]) {
    		++f[i][dis(u, i)];
    		if(vt[i]) ++g[i][dis(u, vt[i])];
    	}
    	for(int i = hed[u]; i; i = e[i].nx) {
    		int v = e[i].to;
    		if(used[v]) continue;
    		tsiz = siz[v] > siz[u] ? fsz[u] - siz[u] : siz[v];
    		rt = 0, getrt(v, 0), vt[rt] = u, divide(rt);
    	}
    }
    inline int query(int x, int k) {
    	int res = 0;
    	for(int i = x, d; i; i = vt[i]) {
    		d = k - dis(x, i);
    		if(0 <= d && d < fsz[i]) {
    			res += f[i][d];
    		}
    		if(vt[i]) {
    			d = k - dis(x, vt[i]);
    			if(0 <= d && d < fsz[i] + 1) {
    				res -= g[i][d];
    			}
    		}
    	}
    	return res;
    }
    inline void clear() {
    	et = 0;
    	tmr = 0;
    	mem = pool;
    	memset(vt, 0, sizeof(vt));
    	memset(hed, 0, sizeof(hed));
    	memset(used, 0, sizeof(used));
    }
    void Main() {
    	n = read(), m = read();
    	clear();
    	rep(i, 2, n) {
    		int x = read(), y = read();
    		adde(x, y), adde(y, x);
    	}
    	init_dis();
    	mx[rt = 0] = n, tsiz = n, getrt(1, 0), divide(rt);
    	while(m--) {
    		int x = read(), k = read();
    		printf("%d
    ", query(x, k));
    	}
    }
    signed main() {
    	for(int T = read(); T; --T) Main();
    }
    

    总结

    感觉整体难度不大,给我时间我都能做出来。

    问题就是,在这么有限的时间内把这么多东西打出来,还要勇于想正解,还是很考验决策能力以及心态的。

    自信一点吧!

    路漫漫其修远兮,吾将上下而求索
  • 相关阅读:
    java内存分析 栈 堆 常量池的区别
    了解struts2 action的一些原理
    条件语句的写法
    Library中的title与Name
    样式优先级、margin
    文件夹IsShow字段为空
    Ispostback
    HierarchicalDataBoundControl 错误
    DBNull与Null
    sharepoint中的YesNo字段
  • 原文地址:https://www.cnblogs.com/zzctommy/p/14614791.html
Copyright © 2020-2023  润新知