题目
点这里看题目。
分析
设前缀和(s_r=sum_{i=1}^r [S_i='1'])
考虑满足要求的子串((l,r])的要求:
[exists kin N_+, r-l=k(s_r-s_l)
]
单独计算并不好算,考虑一个分块的优化。设置阈值(T)。
对于(1le kle T)的(k),对要求进行变形:
[r-l=k(s_r-s_l)Rightarrow ks_r-r=ks_l-l
]
设(f(i,k)=ks_i-i),那么当(k)固定的时候,设(c(x,k))为(f(i,k)=x)的数量,答案就是:
[sum_{k=1}^Tsum_{x=0}^{nT}inom{c(x,k)}{2}
]
后面的那个循环,由于有值的(c(x,k))不会超过(n)个,因此这一部分可以(O(nT))计算。
对于(T< kle n)的(k),继续变形:
[r-l=k(s_r-s_l)Rightarrow frac{r-l}k=s_r-s_l
]
由于(k>T),所以有(s_r-s_l< frac nT)。也就是说,这种情况下子串内 1 的数量比较少。那么我们就可以直接枚举子串的起点(即式子中的(l))和子串内 1 的数量来计算。
当我们确定了 1 的数量的时候,我们也就可以确定此时右端点的范围((x,y]),那么也就知道了区间长度的范围((x-i,y-i])。此时的询问就变成求区间内的(k)的倍数的数量,可以(O(1))。
需要注意的是,由于这一部分不能前面的记重,因此应该满足:
[egin{aligned}
k>T&Rightarrow frac{r-l}{s_r-s_l}>T\
&Rightarrow r-l>T(s_r-s_l)
end{aligned}]
即区间长度必须大于(T(s_r-s_l)),特判一下即可。
时间复杂度是(O(nT+frac{n^2}T)),取(T=sqrt n)最优。
代码
#include <cmath>
#include <cstdio>
#include <cstring>
typedef long long LL;
const int MAXN = 2e5 + 5, MAXT = 505;
template<typename _T>
void read( _T &x )
{
x = 0;char s = getchar();int f = 1;
while( s > '9' || s < '0' ){if( s == '-' ) f = -1; s = getchar();}
while( s >= '0' && s <= '9' ){x = ( x << 3 ) + ( x << 1 ) + ( s - '0' ), s = getchar();}
x *= f;
}
template<typename _T>
void write( _T x )
{
if( x < 0 ){ putchar( '-' ); x = ( ~ x ) + 1; }
if( 9 < x ){ write( x / 10 ); }
putchar( x % 10 + '0' );
}
template<typename _T>
_T MAX( const _T a, const _T b )
{
return a > b ? a : b;
}
int buc[MAXN * MAXT];
int s[MAXN], nxt[MAXN];
int N, T;
char S[MAXN];
int main()
{
LL ans = 0;
scanf( "%s", S + 1 ), N = strlen( S + 1 ), T = ceil( sqrt( N ) );
for( int i = 1 ; i <= N ; i ++ ) s[i] = s[i - 1] + S[i] - '0';
if( s[N] == 0 ) { puts( "0" ); return 0; }
for( int i = N ; ~ i ; i -- )
{
if( s[i] ^ s[i + 1] ) nxt[i] = i + 1;
else nxt[i] = nxt[i + 1];
}
for( int k = 1, t ; k <= T ; k ++ )
{
for( int i = 0 ; i <= N ; i ++ ) buc[s[i] * k - i + N] ++;
for( int i = 0 ; i <= N ; i ++ ) t = buc[s[i] * k - i + N], ans += 1ll * t * ( t - 1 ) / 2, buc[s[i] * k - i + N] = 0;
}
int l, r, tl, tr;
for( int i = 0 ; i < N ; i ++ )
{
l = i, r = nxt[i];
for( int k = 1 ; k <= N / T && s[i] + k <= s[N] ; k ++ )
{
l = nxt[l], r = nxt[r];
tr = r - i - 1, tl = l - i;
tl = MAX( tl, k * T + 1 );
if( tl <= tr ) ans += tr / k - ( tl - 1 ) / k;
}
}
write( ans ), putchar( '
' );
return 0;
}