先说下暴力做法,如果[l1,r1]和[l2,r2]子串相等等价于两个区间内每个数对应相等。那么可以用并查集暴力维护,把对应相等的数的位置维护到同一个集合里去,最后答案其实就是把每个集合可以放的数个数乘起来就行了。注意:最高位不为0,如果有num个集合,则答案为9 * 10^(num – 1)。
暴力维护复杂度为nm,每次询问枚举每个区间内的点,即n个点;查询集合个数复杂度为n,故总时间复杂度为nm + n ≈O(n²)
实际评测30分。
#include<cstdio> #include<cmath> #define N 100010 #define ll long const ll mod = 1000000007; using namespace std; ll fa[N],n,m,l1,r1,l2,r2; ll qpow(ll a,ll b)//快速幂 { ll ret = 1; while(b) { if(b & 1) ret = ret * a % mod; a = a * a % mod; b >>= 1; } return ret; } void reset_fa() { for(ll i = 1;i <= n;i++) fa[i] = i; return; } ll find(ll x) { if(fa[x] == x) return fa[x]; else return fa[x] = find(fa[x]); } int main() { scanf("%lld %lld",&n,&m); reset_fa(); for(ll i = 1;i <= m;i++) { scanf("%lld %lld %lld %lld",&l1,&r1,&l2,&r2); for(ll j = 0;j <= r1 - l1;j++) fa[find(l1 + j)] = find(l2 + j);//并查集暴力维护 } ll num = 0; for(ll i = 1;i <= n;i++) if(fa[i] == i) num++;//查找集合个数 printf("%lld",9 * qpow(10,num - 1) % mod); return 0; }
然后考虑优化,看数据范围,显然应该从nm部分下手,原本是一个一个点暴力维护,我们可以考虑对区间二进制拆分,拆成多个区间进行合并,复杂度就可以降到nlogn。或者从ST表角度理解也行,本质相同
#include<cstdio> #include<cmath> #define N 100010 #define ll long long const ll mod = 1000000007; using namespace std; ll fa[N][30]; ll n,m,l1,r1,l2,r2; ll qpow(ll a,ll b) { ll ret = 1; while(b) { if(b & 1) ret = ret * a % mod; a = a * a % mod; b >>= 1; } return ret; } void reset_fa() { for(ll i = 1;i <= n;i++) for(ll j = 0;j <= 20;j++) fa[i][j] = i; return; } ll find(ll x,ll y) { if(fa[x][y] == x) return fa[x][y]; else return fa[x][y] = find(fa[x][y],y); } void merge(ll x,ll y,ll j) { if(find(x,j) != find(y,j)) { fa[fa[x][j]][j] = fa[y][j];//区间[x,x + 2^j]和[y,y + 2^j]的合并 } return; } int main() { scanf("%lld %lld",&n,&m); reset_fa(); for(ll i = 1;i <= m;i++) { scanf("%lld %lld %lld %lld",&l1,&r1,&l2,&r2); for(ll j = 20;j >= 0;j--) { if(l1 + ((1 << j) - 1) <= r1) { merge(l1,l2,j); l1 += (1 << j);//二进制拆分 l2 += (1 << j); } } } for(ll j = 20;j > 0;j--) { for(ll i = 1;i + (1 << j) - 1 <= n;i++) { merge(i,find(i,j),j - 1); merge(i + (1 << (j - 1)),fa[i][j] + (1 << (j - 1)),j - 1);//把拆分的区间再合并回来。 } } ll num = 0; for(ll i = 1;i <= n;i++) { if(find(i,0) == i) num++;//查找区间个数 } printf("%lld",9 * qpow(10,num - 1) % mod); return 0; }