题目
题目链接:http://noip.ybtoj.com.cn/contest/90/problem/1
给定一个长度为 \(n\) 的 \(01\) 序列 \(a_1\sim a_n\)。
请你求出有多少个整数三元组 \((l,r,p)\),使得 \(1\leq l<p<r\leq n\) 并且 \(a_p=1\),且 \([l,p]\) 和 \([p,r]\) 中 \(1\) 的个数相同。
思路
显然题目要求的就是区间和为奇数并且中间的 \(1\) 不是区间端点的区间个数。
那么先求出有多少个区间和为 \(1\)。预处理出 \(sum\) 表示前缀异或和,以及 \(nxt[i]\) 表示 \(i\) 后面第一个 \(1\) 的位置。
枚举区间左端点,假设 \(sum[i]=1\),那么和为 \(1\) 的区间就是 \(sum\) 在 \(i\) 后面为 \(0\) 的数量。可以在右移左端点的同时计算。
然后减去中间的 \(1\) 在端点的情况,当 \(sum[i]=1\) 时,显然到下一个 \(1\) 之前所有区间都是不合法的,所以答案减去 \((nxt[i]-i)\);当 \(sum[i]=0\) 时,不合法区间只有 \([i,nxt[i]]\) 一个,答案减一。
时间复杂度 \(O(n)\)。
代码
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const int N=1000010;
int type,n,cnt[2],a[N],nxt[N],sum[N];
ll ans;
int main()
{
freopen("puzzle.in","r",stdin);
freopen("puzzle.out","w",stdout);
scanf("%d%d",&type,&n);
for (int i=1;i<=n;i++)
{
scanf("%1d",&a[i]);
sum[i]=sum[i-1]^a[i];
cnt[sum[i]]++;
}
for (int i=n,last=n+1;i>=1;i--)
{
nxt[i]=last;
if (a[i]==1) last=i;
}
for (int i=1;i<=n;i++)
{
if (i>1) cnt[sum[i-1]]--;
ans+=cnt[sum[i-1]^1];
if (!a[i] && nxt[i]<=n) ans--;
if (a[i]) ans-=(nxt[i]-i);
}
printf("%lld",ans);
return 0;
}