1. 题目描述
$A[i]$表示$i$的二进制表示中各位数字之和(即二进制中1的个数)。求满足$1 \le i < j \le n$并且$A[i]>A[j]$的$(i,j)$的总对数(答案对$998244353$取模)。
2. 基本思路
$n \le 10^{300}$。$n$这么大,显然只能用数位DP来做:先把十进制的$n$转换成二进制,再在二进制位上进行DP。
$dp[i][j][k]$表示已确定前$i$个二进制位时,两者$A$的差($A[x]-A[y]$)为$j$、状态为$k$的方案数。
不妨令$n$的二进制长度为$|n| = l$,则差值$j \in [-l, l]$,因此需要加上偏移量$l$,将$j$映射到$[0, 2l]$上。
再考虑$k$有多少种情况。不妨令$(x,y)$($x<y$)表示一对可行解,$Pref(\cdot)$表示当前已确定的二进制前缀:
(0) $Pref(x) < Pref(y),\ Pref(y) < Pref(n)$;
(1) $Pref(x) < Pref(y),\ Pref(y) = Pref(n)$;
(2) $Pref(x) = Pref(y),\ Pref(y) < Pref(n)$;
(3) $Pref(x) = Pref(y),\ Pref(y) = Pref(n)$;
上面4种情况分别对应$k \in [0, 3]$(二进制下,$k$的高位表示$Pref(x)=Pref(y)$,低位表示$Pref(y)=Pref(n)$),剩下的就是状态转移了,还是挺简单的。总对数就是
\[\sum_{j=l+1}^{2l}\left(dp[l][j][0]+dp[l][j][1]\right)\]
其中$j \ge l+1$对应$A[x]-A[y]>0$,只取$k \in \{0,1\}$保证了$x<y$。
可以使用滚动数组优化,其实也可以不使用。
3. 代码
/* 5632 */
#include <iostream>
#include <sstream>
#include <string>
#include <map>
#include <queue>
#include <set>
#include <stack>
#include <vector>
#include <deque>
#include <bitset>
#include <algorithm>
#include <cstdio>
#include <cmath>
#include <ctime>
#include <cstring>
#include <climits>
#include <cctype>
#include <cassert>
#include <functional>
#include <iterator>
#include <iomanip>
using namespace std;

const int mod = 998244353;
const int maxl = 305;     // n <= 10^300 has at most 301 decimal digits, plus '\0'
const int maxn = 1205;    // room for the bit length of n (about 997 bits for 10^300)
char ss[maxl];            // decimal input; destroyed by decimal_to_binary()
int a[maxn];              // binary digits of n, most significant bit first
int dp[2][maxn << 1][4];  // rolling DP table, see count_pairs()

// Converts the decimal digit string in ss into binary digits stored in a[],
// most significant bit first, by repeated long division by 2.
// ss is overwritten. Returns the number of bits, or 0 when the input is zero.
static int decimal_to_binary() {
    int len = strlen(ss);
    for (int i = 0; i < len; ++i)
        ss[i] -= '0';

    int b = 0;  // index of the first non-zero decimal digit
    while (b < len && ss[b] == 0)
        ++b;
    if (b >= len)
        return 0;

    int l = 0;
    while (b < len) {
        // The parity of the last decimal digit is the next bit (LSB first).
        a[l++] = ss[len - 1] & 1;
        // Divide the decimal number by 2 in place.
        int carry = 0;
        for (int i = b; i < len; ++i) {
            int cur = carry + ss[i];
            carry = (ss[i] & 1) ? 10 : 0;
            ss[i] = cur >> 1;
        }
        while (b < len && ss[b] == 0)
            ++b;
    }
    reverse(a, a + l);
    return l;
}

// Counts pairs (x, y) with 0 <= x < y <= n and A[x] > A[y] (A = number of set
// bits), modulo mod, via a digit DP over the l bits of n held in a[].
// x = 0 never contributes because A[0] = 0, so this equals the count for 1 <= x.
//
// dp[.][j][k]: j = A(prefix of x) - A(prefix of y) + l, shifted into [0, 2l];
//              k = ((prefix of x == prefix of y) << 1) | (prefix of y == prefix of n).
static int count_pairs(int l) {
    int l2 = l + l;
    int p = 0, q = 1;

    memset(dp, 0, sizeof(dp));

    // First (most significant) bit. x's bit may not exceed y's bit, and y's bit
    // may not exceed n's bit.
    for (int ii = 0; ii <= a[0]; ++ii) {
        for (int jj = ii; jj <= a[0]; ++jj) {
            int nj = ii - jj + l;
            int nk = (ii == jj) ? ((jj == a[0]) | 2) : (jj == a[0]);
            ++dp[p][nj][nk];
        }
    }

    for (int i = 1; i < l; ++i) {
        for (int j = 0; j <= l2; ++j) {
            // x < y is already decided: x's bit is free, and y's bit is bounded by
            // n's bit only while y is still tight against n.
            for (int k = 0; k < 2; ++k) {
                int cur = dp[p][j][k];
                if (!cur)
                    continue;

                int ymax = (k & 1) ? a[i] : 1;
                for (int ii = 0; ii <= 1; ++ii) {
                    for (int jj = 0; jj <= ymax; ++jj) {
                        int nj = j + ii - jj;
                        int nk = (k == 1) && (jj == a[i]);
                        if (nj >= 0)
                            dp[q][nj][nk] = (dp[q][nj][nk] + cur) % mod;
                    }
                }
            }
            // The prefixes of x and y are still equal: x's bit may not exceed y's bit.
            for (int k = 2; k < 4; ++k) {
                int cur = dp[p][j][k];
                if (!cur)
                    continue;

                int mx = (k & 1) ? a[i] : 1;
                for (int ii = 0; ii <= mx; ++ii) {
                    for (int jj = ii; jj <= mx; ++jj) {
                        int nj = j + ii - jj;
                        int nk;
                        if (k == 2)
                            nk = (ii < jj) ? 0 : 2;
                        else
                            nk = (ii < jj) ? (jj == a[i]) : ((jj == a[i]) | 2);
                        if (nj >= 0)
                            dp[q][nj][nk] = (dp[q][nj][nk] + cur) % mod;
                    }
                }
            }
        }
        // Roll the table: the layer just filled becomes current, the old one is cleared.
        p ^= 1;
        q ^= 1;
        memset(dp[q], 0, sizeof(dp[q]));
    }

    // j >= l + 1 means A[x] > A[y]; k in {0, 1} means x < y.
    int ans = 0;
    for (int j = l + 1; j <= l2; ++j)
        for (int k = 0; k < 2; ++k)
            ans = (ans + dp[p][j][k]) % mod;
    return ans;
}

// Solves one test case whose decimal n is in ss; prints the answer on its own line.
void solve() {
    int l = decimal_to_binary();
    printf("%d\n", l ? count_pairs(l) : 0);
}

int main() {
#ifndef ONLINE_JUDGE
    freopen("data.in", "r", stdin);
    freopen("data.out", "w", stdout);
#endif

    int t;

    if (scanf("%d", &t) != 1)
        return 0;
    while (t--) {
        // Width 304 == maxl - 1 keeps the read inside ss.
        if (scanf("%304s", ss) != 1)
            break;
        solve();
    }

#ifndef ONLINE_JUDGE
    printf("time = %d.\n", (int)clock());
#endif

    return 0;
}
4. 数据生成器
import sys
import string
from random import randint, shuffle


def GenData(fileName, t=10, minLen=200, maxLen=300):
    """Write a random input file for the problem.

    The first line holds the case count ``t``. Each following line is a random
    decimal number with between ``minLen`` and ``maxLen`` digits (inclusive)
    and no leading zero.

    Raises:
        ValueError: if ``minLen`` < 1 or ``minLen`` > ``maxLen``.
    """
    if minLen < 1:
        raise ValueError("minLen must be at least 1")
    with open(fileName, "w") as fout:
        fout.write("%d\n" % t)
        for _ in range(t):
            length = randint(minLen, maxLen)
            digits = [randint(0, 9) for _ in range(length)]
            digits[0] = randint(1, 9)  # no leading zero
            fout.write("".join(map(str, digits)) + "\n")


def MovData(srcFileName, desFileName):
    """Copy the text file ``srcFileName`` to ``desFileName``."""
    with open(srcFileName, "r") as fin:
        lines = fin.readlines()
    with open(desFileName, "w") as fout:
        fout.write("".join(lines))


def CompData(srcFileName=r"F:\Qt_prj\hdoj\data.out",
             desFileName=r"F:\workspace\cpp_hdoj\data.out"):
    """Compare two solution outputs line by line.

    Each file is expected to hold one integer answer per line. The last line
    of the shorter file is skipped (presumably the ``time = ...`` line the
    local build prints). Prints the index of every mismatching line and
    returns those indices as a list.
    """
    print("comp")
    with open(srcFileName, "r") as fin:
        srcLines = fin.readlines()
    with open(desFileName, "r") as fin:
        desLines = fin.readlines()
    n = min(len(srcLines), len(desLines)) - 1
    wrong = []
    for i in range(n):
        # Any difference is a mismatch, not only when the first answer is larger.
        if int(srcLines[i]) != int(desLines[i]):
            print("%d: wrong" % i)
            wrong.append(i)
    return wrong


if __name__ == "__main__":
    # Author's local Windows paths; adjust for your machine.
    srcFileName = r"F:\Qt_prj\hdoj\data.in"
    desFileName = r"F:\workspace\cpp_hdoj\data.in"
    GenData(srcFileName)
    MovData(srcFileName, desFileName)