题意
N个数的圆环上有多少种方案可以使得选出来的一段数是K的倍数(N<=50000, K<=200, a[i]<=1000).
思路
多校第七场1004。圆上的DP……大脑太简单处理一些细节时总是短路= =……
我处理圆的方法是把圆环展开成2*N个数。我们设以第i个数结尾的所有段,按照 mod k 的余数分类的方案数为hash[i][r],r 表示余数。那么如果我们在段末尾加上i+1,我们就可以仅根据余数来确定这些段在添加了第i+1个数之后mod K的余数是(r*exp(10,digits[i+1]) + number[i+1])%k。(一开始样例只有1位数我就直接把digits当1了我是有多逗……)。
注意的是我们需要排除段长度大于N的情况,所以当我们处理到i+N时,需要减去从i到i+N-1的情况,这个可以事先计算出来用dp[i]表示。还要注意一点是要注意我们在处理区间[N+1,2*N]时相当于又把[1,N]中的情况算了一遍,所以最后要减去。(跨区间的计算,即首尾连接的计算没有多余,不用减)。
然后交上去超时了擦。。。目测是卡了常数。。。标程处理圆环的方法是先把以首数字为结尾的情况算出来,然后直接在[1,N]上处理即可,这样不会有重复的计算。于是只好各种优化……精简的精简、再把(r*exp(10,digits[i+1]) + number[i+1])%k这个预处理用数组存了一下,然后终于过了。。。
代码
【我的代码】
[cpp]
#include <iostream>
#include <cstdio>
#include <cmath>
#include <algorithm>
#include <string>
#include <cstring>
#include <vector>
#include <set>
#include <map>
#include <stack>
#include <queue>
#define MID(x,y) ((x+y)/2)
#define MEM(a,b) memset(a,b,sizeof(a))
#define REP(i, begin, end) for (int i = begin; i <= end; i ++)
using namespace std;
typedef long long LL;
typedef vector <int> VI;
typedef set <int> SETI;
typedef queue <int> QI;
typedef stack <int> SI;
int dp[50005];
int a[50005];
int bit[50005];
int e[200005];
int hash[205], tmp[250];
int mulmod[205][5][205];
inline int get_bit(int x){
int num = 0;
while(x){
num ++;
x /= 10;
}
return num;
}
inline void init(int n, int k, int length){
e[0] = 1%k;
REP(i, 1, length){
e[i] = e[i-1]*10 % k;
}
REP(i, 0, k){
REP(j, 1, 4){
REP(m, 0, k){
mulmod[i][j][m] = (i*e[j]+m)%k;
}
}
}
dp[1] = 0; int len = -bit[1];
REP(i, 1, n){
len += bit[i];
dp[1] = ( dp[1]*e[bit[i]] + a[i] ) % k;
}
REP(i, 2, n){
dp[i] = ( ( ( dp[i-1] - a[i-1]*e[len] )*e[bit[i-1]] + a[i-1]) % k + k ) % k;
len = len + bit[i-1] - bit[i];
}
}
int main(){
int n, k;
while(scanf("%d %d", &n, &k) != EOF){
int length = 0;
REP(i, 1, n){
scanf("%d", &a[i]);
bit[i] = get_bit(a[i]);
length += bit[i];
a[i] = a[i] % k;
}
init(n, k, length);
//main
MEM(hash, 0);
int res = 0, rn = 0;
REP(i, 1, n){
res += hash[0];
MEM(tmp, 0);
tmp[a[i]] ++;
for (int j = 0; j < k; j ++){
if (hash[j] > 0) tmp[mulmod[j][bit[i]][a[i]]] += hash[j];
}
for (int j = 0; j < k; j ++){
hash[j] = tmp[j];
}
}
res += hash[0];
rn = res;
REP(i, n+1, 2*n){
res += hash[0];
if (i == n + 1) rn = res;
hash[dp[i-n]] --;
MEM(tmp, 0);
tmp[a[i-n]] ++;
for (int j = 0; j < k; j ++){
if (hash[j] > 0) tmp[mulmod[j][bit[i-n]][a[i-n]]] += hash[j];
}
for (int j = 0; j < k; j ++){
hash[j] = tmp[j];
}
}
res += hash[0];
printf("%d
", res-rn);
}
return 0;
}
[/cpp]
【标程】
[cpp]
#include <cstdio>
#include <cstring>
//Assume n<=3*10^4, mod<=10^3
const int maxn = 50005, maxmod=205;
int n, mod, number[maxn], digits[maxn], e[maxn * 3];
int count[maxn][maxmod];
inline int count_digits(int number) {
if (!number) return 1;
int ret = 0;
while (number) ret++, number /= 10;
return ret;
}
int main() {
//freopen("input.txt", "r", stdin);
//freopen("output.txt","w",stdout);
e[0] = 1;
while (scanf("%d%d", &n, &mod) != EOF) {
memset(count,0,sizeof(int) * maxmod * n);
for (int i = 1; i < n * 3; i++)
e[i] = e[i - 1] * 10 % mod;
for (int i = 0; i < n; i++) {
scanf("%d", number + i);
digits[i] = count_digits(number[i]);
}
number[n] = number[0];
digits[n] = digits[0];
int s = 0, length = 0, answer = 0;
for (int i = n; i; i--) {
s = (s + number[i] * e[length]) % mod;
length += digits[i];
count[0][s]++;
}
answer += count[0][0];
for (int i = 1; i < n; i++) {
for (int r = 0; r < mod; r++)
count[i][(r * e[digits[i]] + number[i]) % mod] += count[i - 1][r];
s = (s * e[digits[i]] + number[i]) % mod;
count[i][s]--;
count[i][number[i] % mod]++;
s = ((s - number[i] * e[length]) % mod + mod) % mod;
answer += count[i][0];
}
printf("%d
", answer);
}
}
[/cpp]