#451 Div2 F
题意
给出一个由数字组成的字符串,要求添加一个加号和等号,满足数字无前导 0 且等式成立。
分析
对于这种只有数字的字符串,可以快速计算某一区间的字符串变成数字后并取模的值,首先从右到左,将字符串转化为数字并取模,那么 (h[i]) 表示字符串 (S[i...len]) 转化成数字后并取模的值,如果要求区间 ([i, j]) 所表示的数字是多少,首先求出 (h[i] - h[j + 1]),后面的 0 可以除掉,求一下逆元即可。
然后枚举一下,比一下就行了。
code
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const int MOD = 1e9 + 7;
const int N = 1e6 + 10;
int len;
char s[N], z[N];
ll h[N], b[N], inv[N];
ll POW(ll x, ll k) {
ll r = 1;
while(k) {
if(k & 1) r = r * x % MOD;
x = x * x % MOD;
k >>= 1;
}
return r;
}
ll getval(int l, int r) {
return ((h[l] - h[r + 1]) * inv[len - r - 1] % MOD + MOD) % MOD;
}
int solve(int l1, int l2, int l3) {
int pos1 = 0, pos2 = l1, pos3 = len - l3;
if((!s[pos1] && l1 != 1) || (!s[pos2] && l2 != 1) || (!s[pos3] && l3 != 1)) return 0;
if((getval(pos1, pos2 - 1) + getval(pos2, pos3 - 1)) % MOD == getval(pos3, len - 1)) {
int mnl = min(l1, l2), mxl = max(l1, l2), nl = 0, y = 0;
for(int i = 0; i < mnl; i++) {
z[nl] = (s[pos2 - 1 - i] + s[pos3 - 1 - i] + y) % 10;
y = (s[pos2 - 1 - i] + s[pos3 - 1 - i] + y) / 10;
if(len - nl - 1 < 0 || s[len - nl - 1] != z[nl]) return 0;
nl++;
}
int np = (l1 < l2 ? pos3 : pos2);
for(int i = mnl; i < mxl; i++) {
z[nl] = (s[np - 1 - i] + y) % 10;
y = (s[np - 1 - i] + y) / 10;
if(len - nl - 1 < 0 || s[len - nl - 1] != z[nl]) return 0;
nl++;
}
if(y) {
z[nl] = y;
if(len - nl - 1 < 0 || s[len - nl - 1] != z[nl]) return 0;
nl++;
}
if(nl == l3) {
for(int i = 0; i < l1; i++) printf("%d", s[i]); printf("+");
for(int i = 0; i < l2; i++) printf("%d", s[pos2 + i]); printf("=");
for(int i = 0; i < l3; i++) printf("%d", s[pos3 + i]); printf("
");
return 1;
}
}
return 0;
}
int main() {
scanf("%s", s);
len = strlen(s);
b[0] = 1;
inv[0] = 1;
s[0] -= '0';
for(int i = 1; i < len; i++) {
b[i] = b[i - 1] * 10 % MOD;
s[i] -= '0';
inv[i] = POW(b[i], MOD - 2);
}
for(int i = len - 1; i >= 0; i--) {
h[i] = (h[i + 1] + b[len - 1 - i] * s[i]) % MOD;
}
for(int i = len / 3; i < len; i++) {
int l3 = i, l2 = i, l1 = len - 2 * i;
if(l1 > 0 && l3 >= l1 && l3 >= l2) {
if(solve(l1, l2, l3)) return 0;
if(solve(l2, l1, l3)) return 0;
}
l2 = i - 1, l1 = len - l3 - l2;
if(l1 > 0 && l3 >= l1 && l3 >= l2) {
if(solve(l1, l2, l3)) return 0;
if(solve(l2, l1, l3)) return 0;
}
}
return 0;
}