原题地址:https://www.luogu.org/problem/P5072
题目简述
给定一个序列,每次查询一个区间[l,r]中所有子序列分别去重后的和mod p
思路
我们考虑每个数的贡献。即该区间内含有这个数的子序列个数。用补集转化为不含这个数的子序列个数。
那么,假设这个数在[l,r]内出现了kk次,则一共有2(r-l+1)-2(r-l+1-k)个子序列包含这个数。
本题可以离线,因此选择使用莫队,过程中维护cnt[k]表示区间内恰好出现k次的数字个数,维护sum[j]表示区间内恰好出现j次的数字之和(区间内出现次数相同的数,对于这些数,区间中包含这些数的子序列个数都相同,因此存数字之和就行)。
然而这样时间复杂度为O(询问次数*单次询问复杂度)=O(n*max(sqrt(n),n))=O(nm),并不可行。我们发现时间瓶颈不在莫队的sqrt(n),而是在单次查询中求解的复杂度n。
有2个套路可供使用:出现次数大于sqrt(n)的数不超过sqrt(n)个,值不为0的cnt[k]少于2*sqrt(n)个(反证易得,本质类似)。
- 对于第一个套路,我们分类讨论:出现次数小于等于sqrt(n),则统计每个出现次数的数字之和;大于sqrt(n)的用哈希表(unordered_set,C++11)存下具体的数字和其出现次数。这样每次查询是sqrt(n)。
- 笔者使用的则是第二个套路:val[x]表示出现次数恰好为x的数字之和(同上文的sum[j])。随着莫队l,r指针的移动,把所有可能变为非0的val[x]记下来,指针移动完毕后再对其进行筛选,把确实非0的val[x]保留,其他去除。这样计算单次答案的复杂度就等同于单次查询中莫队指针移动的平均步数:都是sqrt(n)级别。这样做不需要用到哈希表之类的,常数小了很多,甚至不需要读入优化也能轻松过。
还没完。我们发现模数是不定的,为了保证单次查询的复杂度压在sqrt(n)以内,我们还有最后一件事情要做:在sqrt(n)的时间内求出2(r-l+1)和所有的2(r-l+1-k)。这里安利一个神奇的方式:每次查询只需要做一次时间复杂度为sqrt(n)的预处理就可以O(1)查询了。
假设查询区间长度为len(len=r-l+1),我们记siz=sqrt(len),而后计算20,21,22...2sqrt(len),存在数组pow1中;再计算2sqrt(len),2(2*sqrt(len)),2(3*sqrt(len)),2(4*sqrt(len))...,2^(sqrt(len)*sqrt(len)),存在数组pow[2]中。以上计算都在mod p意义下进行。
这样求2的任意次方都可以O(1)出解:2k=2(k/siz)*2^(k%siz)=pow2[k/siz]*pow1[k%siz](记得模p)。
代码
#include <bits/stdc++.h>
using namespace std;
#define MAXN 100005
int n,m;
int a[MAXN],val[MAXN*2],cnt[MAXN],ans[MAXN];
int tot,blosiz,powsiz;
int bel[MAXN],pow1[MAXN],pow2[MAXN];
bool calced[MAXN];
long long sum[MAXN];
struct query {
int id,l,r,p;
bool operator<(const query &sb) const {
return bel[l]!=bel[sb.l] ? bel[l]<bel[sb.l] : (bel[l]&1 ? r<sb.r : r>sb.r);
}
} q[MAXN];
void add(int x)
{
sum[cnt[x]]-=x;
cnt[x]++;
sum[cnt[x]]+=x;
++tot;
val[tot]=cnt[x];
}
void del(int x)
{
sum[cnt[x]]-=x;
cnt[x]--;
sum[cnt[x]]+=x;
++tot;
val[tot]=cnt[x];
}
int power(int x,int p)
{
return 1LL*pow1[x%powsiz]*pow2[x/powsiz]%p;
}
int main()
{
scanf("%d%d",&n,&m);
blosiz=sqrt(n);
for (int i=1;i<=n;i++)
scanf("%d",&a[i]),bel[i]=(i-1)/blosiz;
for (int i=1;i<=m;i++)
scanf("%d%d%d",&q[i].l,&q[i].r,&q[i].p),q[i].id=i;
sort(q+1,q+m+1);
int l=1,r=0;
for (int i=1;i<=m;i++) {
while (l>q[i].l) add(a[--l]);
while (r<q[i].r) add(a[++r]);
while (l<q[i].l) del(a[l++]);
while (r>q[i].r) del(a[r--]);
int newtot=0;
int len=r-l+1;
for (int j=1;j<=tot;j++)
if (val[j]&&sum[val[j]]!=0&&!calced[val[j]])
calced[val[j]]=1,val[++newtot]=val[j];
tot=newtot;
powsiz=sqrt(len)+1;
pow1[0]=pow2[0]=1;
for (int j=1;j<=powsiz;j++)
pow1[j]=(pow1[j-1]+pow1[j-1])%q[i].p;
for (int j=1;j*powsiz<=len;j++)
pow2[j]=1LL*pow2[j-1]*pow1[powsiz]%q[i].p;
int powLen=power(len,q[i].p);
for (int j=1;j<=tot;j++) {
long long num=sum[val[j]]%q[i].p;
ans[q[i].id]=(ans[q[i].id]+num*(powLen-power(len-val[j],q[i].p)))%q[i].p;
calced[val[j]]=0;
}
ans[q[i].id]+=q[i].p;
ans[q[i].id]%=q[i].p;
}
for (int i=1;i<=m;i++)
printf("%d
",ans[i]);
return 0;
}