这显然可以(Theta(n^3))枚举统计.
也显然可以(Theta(n))处理前缀和然后(Theta(n^2))枚举统计.
然后我们发现,前缀和之后,我们就把问题转化成了这样:
给定一个三元组序列,求有多少对((i,j))满足对应位置的三元组每个位置的差都相等.
即((j_1-i_1=j_2-i_2=j_3-i_3).)
然后我们发现,这个东西其实不需要这样.
我们考虑能对答案造成一个贡献的情况是怎样的,设左端点为(l),右端点为(r).
则充要条件为
[s_{r,0}-s_{l-1,0}=s_{r,1}-s_{l-1,1}=s_{r,2}-s_{l-1,2}
]
我们两两考虑:
则条件变为:
[s_{r,0}-s_{l-1,0}=s_{r,1}-s_{l-1,1}
]
且
[s_{r,1}-s_{l-1,1}=s_{r,2}-s_{l-1,2}
]
分别移项,得:
[s_{r,0}-s_{r,1}=s_{l-1,0}-s_{l-1,1}
]
且
[s_{r,1}-s_{r,2}=s_{l-1,1}-s_{l-1,2}
]
这样我们就将每一项都变得只与自己有关.
所以我们就可以把每一项变成一个二元组:((s_{i,0}-s_{i,1},s_{i,1}-s_{i,2}))
然后就只需要统计有多少对相同的二元组即可,要注意空集.
(Code:)
#include <algorithm>
#include <iostream>
#include <cstdlib>
#include <cstring>
#include <cstdio>
#include <string>
#include <vector>
#include <queue>
#include <cmath>
#include <ctime>
#include <map>
#include <set>
#define MEM(x,y) memset ( x , y , sizeof ( x ) )
#define rep(i,a,b) for (int i = (a) ; i <= (b) ; ++ i)
#define per(i,a,b) for (int i = (a) ; i >= (b) ; -- i)
#define pii pair < int , int >
#define one first
#define two second
#define rint read<int>
#define int long long
#define pb push_back
using std::queue ;
using std::set ;
using std::pair ;
using std::max ;
using std::min ;
using std::priority_queue ;
using std::vector ;
using std::swap ;
using std::sort ;
using std::unique ;
using std::greater ;
using std::map ;
template < class T >
inline T read () {
T x = 0 , f = 1 ; char ch = getchar () ;
while ( ch < '0' || ch > '9' ) {
if ( ch == '-' ) f = - 1 ;
ch = getchar () ;
}
while ( ch >= '0' && ch <= '9' ) {
x = ( x << 3 ) + ( x << 1 ) + ( ch - 48 ) ;
ch = getchar () ;
}
return f * x ;
}
const int N = 1e6 + 100 ;
int n , sum[N][3] , ans ;
pii v[N] ; char s[N] ;
map < pii , int > mk ;
signed main (int argc , char * argv[]) {
scanf ("%s" , s + 1 ) ; n = strlen ( s + 1 ) ;
rep ( i , 1 , n ) {
sum[i][0] = sum[i-1][0] ;
sum[i][1] = sum[i-1][1] ;
sum[i][2] = sum[i-1][2] ;
if ( s[i] == 'A' || s[i] == 'B' || s[i] == 'C' )
sum[i][s[i]-'A'] = sum[i-1][s[i]-'A'] + 1 ;
}
rep ( i , 1 , n ) {
v[i].one = sum[i][0] - sum[i][1] ;
v[i].two = sum[i][1] - sum[i][2] ;
}
sort ( v + 1 , v + n + 1 ) ;
rep ( i , 0 , n ) {
if ( ! mk[v[i]] ) ++ mk[v[i]] ;
else { ans += mk[v[i]] ; ++ mk[v[i]] ; }
}
printf ("%lld
" , ans ) ;
return 0 ;
}