「SCOI2016」萌萌哒
题目描述
一个长度为 (n) 的大数,用 (S_1S_2S_3 ldots S_n) 表示,其中 (S_i) 表示数的第 (i) 位,(S_1) 是数的最高位,告诉你一些限制条件,每个条件表示为四个数 $(l_1, r_1, l_2, r_2) $,即两个长度相同的区间,表示子串 $S_{l_1}S_{l_1 + 1}S_{l_1 + 2} ldots S_{r_1} $与 (S_{l_2}S_{l_2 + 1}S_{l_2 + 2} ldots S_{r_2})完全相同。
比如 (n = 6) 时,某限制条件 $(l_1 = 1, r_1 = 3, l_2 = 4, r_2 = 6) $,那么 (123123)、(351351) 均满足条件,但是 (12012)、(131141) 不满足条件,前者数的长度不为 (6),后者第二位与第五位不同。问满足以上所有条件的数有多少个。
(1 leq n leq 10^5, 1 leq m leq 10^5,1 leq li_1,ri_1,li_2,ri_2 leq n) 并且保证 ${r_i}_1 - {l_i}_1 = {r_i}_2 - {l_i}_2 $
解题思路 :
观察发现,限制条件的本质是使得一些位置只能填同样的字符,我们不妨把限制在一起的位置看做点,在它们之间连一条边
那么统一联通块里面的位置就只能填同一字符了,所以设联通快数为 (x) ,那么答案就是 (9 imes 10^{x-1}) (第一位不能填 (0) ,所以少一种选择)
于是就有一个暴力的做法,用并查集维护联通块,每次对于一组限制 (l_1, r_1, l_2, r_2) ,暴力将两个区间的对应点合并,复杂度 (O(n^2logn))
考虑怎么优化并查集的合并,由于是区间问题,所以很容易就想到用线段树或者 (st) 表来维护
每次把可以询问区间拆成 (log) 个区间,区间与区间之间进行连边,难点在于最后怎么将区间之间的合并转化到点上
由于题目只需要最终询问一次,不妨利用 (lazytag) 的思想,对每一种长度的区间用一个并查集来维护,最后算答案的时候将合并信息下传
具体来讲,考虑 (st) 表的做法:设 (fa(i,j)) 表示左端点为 (i) 的长度为 (2 ^ j) 的区间所在集合的 (root) 的左端点
那么下传信息的时候只需要 ((i, j-1), (i+2^{j-1},j-1)) 分别和 ((fa(i, j), j - 1)) 合并即可,最后统计一下 ((i, 0)) 的联通块数 ,总复杂度 (O(nlog^2n))
/*program by mangoyang*/
#include<bits/stdc++.h>
#define inf (0x7f7f7f7f)
#define Max(a, b) ((a) > (b) ? (a) : (b))
#define Min(a, b) ((a) < (b) ? (a) : (b))
typedef long long ll;
using namespace std;
template <class T>
inline void read(T &x){
int f = 0, ch = 0; x = 0;
for(; !isdigit(ch); ch = getchar()) if(ch == '-') f = 1;
for(; isdigit(ch); ch = getchar()) x = x * 10 + ch - 48;
if(f) x = -x;
}
const int N = 500005, Mod = 1000000007;
int fa[N][21], n, m;
inline int ask(int x, int y){
return x == fa[x][y] ? x : fa[x][y] = ask(fa[x][y], y);
}
int main(){
read(n), read(m);
for(int i = 1; i <= n; i++)
for(int j = 0; j <= 20; j++) fa[i][j] = i;
while(m--){
int l1, r1, l2, r2;
read(l1), read(r1), read(l2), read(r2);
int ps1 = l1, ps2 = l2;
for(int i = 20; ~i; i--)
if(ps1 + (1 << i) - 1 <= r1){
int p = ask(ps1, i), q = ask(ps2, i);
if(p != q) fa[p][i] = q;
ps1 += (1 << i), ps2 += (1 << i);
}
}
for(int j = 20; j; j--)
for(int i = 1; i + (1 << j) - 1 <= n; i++){
int p = ask(i, j), q = ask(i, j - 1);
fa[q][j-1] = ask(p, j - 1);
p = ask(i, j), q = ask(i + (1 << j - 1), j - 1);
fa[q][j-1] = ask(p + (1 << j - 1), j - 1);
}
ll res = 1ll;
for(int i = 1; i <= n; i++)
if(fa[i][0] == i) (res *= (res == 1ll) ? 9ll : 10ll) %= Mod;
cout << res;
return 0;
}