E - Shuffle and Swap
Time limit
配点 : 1700 点
問題文
0 と 1 からなる同じ長さの二つの文字列 A=A1A2…An と B=B1B2…Bn があります。 A,B に含まれる 1 の個数は等しいです。
あなたは次のアルゴリズムによって A を変化させることにしました。
- a1, a2, ..., ak を、A で 1 が出現する位置の添字とする。
- b1, b2, ..., bk を、B で 1 が出現する位置の添字とする。
- a,b の要素をそれぞれ無作為に並び替える。これらの無作為な並び替えは一様かつ独立である。
- 1 から k までの各 i に対して、順に Aai と Abi の値を入れ替える。
この手順のあと、文字列 A,B が一致する確率を P とします。
さらに、Z=P×(k!)2 とします。明らかに、Z は整数です。
Z を 998244353 で割った余りを求めてください。
制約
- 1≤|A|=|B|≤10,000
- A,B は 0 と 1 からなる。
- A,B に含まれる 1 の個数は等しい。
- A,B には 1 が少なくとも一つ含まれる。
部分点
- 1≤|A|=|B|≤500 を満たすデータセットに正解すると、1200 点が与えられる。
入力
入力は以下の形式で標準入力から与えられる。
A B
出力
Z を 998244353 で割った余りを出力せよ。
入力例 1
1010 1100
出力例 1
3
最初の二つのステップで、a=[1,3], b=[1,2] となります。a,b の要素を無作為に並び替えたあとの結果としてありうるものは次の 4 つです。
- a=[1,3], b=[1,2]。はじめ A=1010。A1 と A1 の入れ替え後 A=1010。A3 と A2 の入れ替え後 A=1100。
- a=[1,3], b=[2,1]。はじめ A=1010。A1 と A2 の入れ替え後 A=0110。A3 と A1 の入れ替え後 A=1100。
- a=[3,1], b=[1,2]。はじめ A=1010。A3 と A1 の入れ替え後 A=1010。A1 と A2 の入れ替え後 A=0110。
- a=[3,1], b=[2,1]。はじめ A=1010。A3 と A2 の入れ替え後 A=1100。A1 と A1 の入れ替え後 A=1100。
この 4 つの結果のうち、3 つで A=B となっています。よって、P=3 / 4 であり、Z=3 となります。
入力例 2
01001 01001
出力例 2
4
A の要素の入れ替えによって A が変化することはなく、したがって必ず A=B となります。
入力例 3
101010 010101
出力例 3
36
三回の A の要素の入れ替えがどのように起こっても、A に含まれる 1 は適切な位置に移動します。
入力例 4
1101011011110 0111101011101
出力例 4
932171449
Score : 1700 points
Problem Statement
You have two strings A=A1A2…An and B=B1B2…Bn of the same length consisting of 0 and 1. The number of 1's in A and B is equal.
You've decided to transform A using the following algorithm:
- Let a1, a2, ..., ak be the indices of 1's in A.
- Let b1, b2, ..., bk be the indices of 1's in B.
- Replace a and b with their random permutations, chosen independently and uniformly.
- For each i from 1 to k, in order, swap Aai and Abi.
Let P be the probability that strings A and B become equal after the procedure above.
Let Z=P×(k!)2. Clearly, Z is an integer.
Find Z modulo 998244353.
Constraints
- 1≤|A|=|B|≤10,000
- A and B consist of 0 and 1.
- A and B contain the same number of 1's.
- A and B contain at least one 1.
Partial Score
- 1200 points will be awarded for passing the testset satisfying 1≤|A|=|B|≤500.
Input
Input is given from Standard Input in the following format:
A B
Output
Print the value of Z modulo 998244353.
Sample Input 1
1010 1100
Sample Output 1
3
After the first two steps, a=[1,3] and b=[1,2]. There are 4 possible scenarios after shuffling a and b:
- a=[1,3], b=[1,2]. Initially, A=1010. After swap(A1, A1), A=1010. After swap(A3, A2), A=1100.
- a=[1,3], b=[2,1]. Initially, A=1010. After swap(A1, A2), A=0110. After swap(A3, A1), A=1100.
- a=[3,1], b=[1,2]. Initially, A=1010. After swap(A3, A1), A=1010. After swap(A1, A2), A=0110.
- a=[3,1], b=[2,1]. Initially, A=1010. After swap(A3, A2), A=1100. After swap(A1, A1), A=1100.
Out of 4 scenarios, 3 of them result in A=B. Therefore, P=3 / 4, and Z=3.
Sample Input 2
01001 01001
Sample Output 2
4
No swap ever changes A, so we'll always have A=B.
Sample Input 3
101010 010101
Sample Output 3
36
Every possible sequence of three swaps puts the 1's in A into the right places.
Sample Input 4
1101011011110 0111101011101
Sample Output 4
932171449
真的是非常好的一道题,在此将官方题解的解释简单总结一下,并加入一些式子的推导过程。
首先,将ai、bi随机打乱的过程等价于 进行2个步骤。
1、将ai、bi匹配 共k!种
2、将匹配后形成的pair 排列 也是k!种
假设当前已经进行完上述步骤1,接下来需要求2中的排列数,使得进行后A=B。为此进行一个转化,建立一个点数为len(A、B字符串长度)的图,从ai连向bi一条有向边。这样点i的出度即为Ai,入度即为Bi
将无边连入连出的点去掉。这样余下的边形成的图就是由一些环、一些路径组成的。注意到环不论操作顺序如何,均不改变A、B的匹配状态(因为全是1),只需要考虑链。
而链中的点在A中的1/0状态一定为 111111……0 在B中的状态一定为 0111111…… 故只有唯一的一种操作顺序使得A、B进行完毕后相等。
设有2e个 Ai、Bi中只有一个为1的i, m个Ai、Bi均为1的点。易知图中总共有e+m条边,且有e条链。问题由此转化为了将m个点分配到e条链中(每条链分到的个数>=0)的情况数。
用f(i,j)表示考虑了前i条链,放恰好j个点的情况数。(这里的情况数的意义包含了边出现的顺序)
边界条件f(0,0)=1 f(0,j)=0(j>0) 有递推
这是因为第i条链加进了u个点,总共u+1条边,而其出现顺序应该是固定的,故需要除以(u+1)!
最后所求即为∑f(e,j)*e!*m!*(e+m)! (j从0到m)
这个式子是因为: 枚举放入链中的点的个数 将所有边全排列乘以f(e,j) 即为恰有j个点放入链的情况个数,同一个链中的边出现的顺序是固定的,这里(e+m)!*f(e,j) 确定了同一个链中边在所有边中出现的顺序,内部还需要自己排序(每条链首尾固定,只有中间需要排列)故乘以m!再对余下的链首全排列,故需要再乘以e!
先不看后面统一乘的阶乘,由f(i,j)的递推式可以看出恰为卷积形式,且其中一个多项式是固定的,又f(0,0)=1,故所求恰为一个多项式的幂次,只需要进行快速幂计算即可。
这里进行NTT时要注意每次只保留前m+1项,不然其余本就应该舍弃的项可能会影响最终计算结果。
1 #include <cstdio> 2 #include <iostream> 3 #include <algorithm> 4 #include <vector> 5 #include <set> 6 #include <map> 7 #include <string> 8 #include <cstring> 9 #include <stack> 10 #include <queue> 11 #include <cmath> 12 #include <ctime> 13 #include<bitset> 14 #include <utility> 15 #include <assert.h> 16 using namespace std; 17 #define rank rankk 18 #define mp make_pair 19 #define pb push_back 20 #define xo(a,b) ((b)&1?(a):0) 21 //#define LL ll 22 typedef unsigned long long ull; 23 typedef pair<int,int> pii; 24 typedef long long ll; 25 typedef pair<ll,int> pli; 26 const int INF=0x3f3f3f3f; 27 const ll INFF=0x3f3f3f3f3f3f3f3fll; 28 const int MAX=2e4+5; 29 const int MAX_N=MAX; 30 const ll MOD=998244353; 31 const long double pi=acos(-1.0); 32 //const double eps=0.00000001; 33 int gcd(int a,int b){return b?gcd(b,a%b):a;} 34 template<typename T>inline T abs(T a) {return a>0?a:-a;} 35 template<class T> inline 36 void read(T& num) { 37 bool start=false,neg=false; 38 char c; 39 num=0; 40 while((c=getchar())!=EOF) { 41 if(c=='-') start=neg=true; 42 else if(c>='0' && c<='9') { 43 start=true; 44 num=num*10+c-'0'; 45 } else if(start) break; 46 } 47 if(neg) num=-num; 48 } 49 inline ll powMM(ll a,ll b,ll M){ 50 ll ret=1; 51 a%=M; 52 // b%=M; 53 while (b){ 54 if (b&1) ret=ret*a%M; 55 b>>=1; 56 a=a*a%M; 57 } 58 return ret; 59 } 60 void open() 61 { 62 freopen("1009.in","r",stdin); 63 freopen("out.txt","w",stdout); 64 } 65 const int N = 1 << 19; 66 const ll P = 998244353; 67 const int G = 3;//原根 68 const int NUM = 20; 69 70 ll wn[NUM]; 71 ll A[N], B[N],C[N]; 72 73 ll quick_mod(ll a, ll b, ll m) 74 { 75 ll ans = 1; 76 a %= m; 77 while(b) 78 { 79 if(b & 1) 80 { 81 ans = ans * a % m; 82 b--; 83 } 84 b >>= 1; 85 a = a * a % m; 86 } 87 return ans; 88 } 89 90 void GetWn()//预处理原根的幂次 91 { 92 for(int i = 0; i < NUM; i++) 93 { 94 int t = 1 << i; 95 wn[i] = quick_mod(G, (P - 1) / t, P); 96 } 97 } 98 99 void Rader(ll a[], int len) 100 { 101 int j = len >> 1; 102 for(int i = 1; i < len - 1; i++) 103 { 104 if(i < j) swap(a[i], a[j]); 105 int k = len >> 1; 106 while(j >= k) 107 { 108 j -= k; 109 k >>= 1; 110 } 111 if(j < k) j += k; 112 } 113 } 114 void NTT(ll a[], int len, int on=1)//NTT的数组 下标从0开始 数组长度len 115 { 116 Rader(a, len); 117 int id = 0; 118 for(int h = 2; h <= len; h <<= 1) 119 { 120 id++; 121 for(int j = 0; j < len; j += h) 122 { 123 ll w = 1; 124 for(int k = j; k < j + h / 2; k++) 125 { 126 ll u = a[k] % P; 127 ll t = w * a[k + h / 2] % P; 128 a[k] = (u + t) % P; 129 a[k + h / 2] = (u - t + P) % P; 130 w = w * wn[id] % P; 131 } 132 } 133 } 134 if(on == -1) 135 { 136 for(int i = 1; i < len / 2; i++) 137 swap(a[i], a[len - i]); 138 ll inv = quick_mod(len, P - 2, P); 139 for(int i = 0; i < len; i++) 140 a[i] = a[i] * inv % P; 141 } 142 } 143 void Conv(ll a[], ll b[], int n)//多项式乘法 NTT 与其还原 144 { 145 NTT(a, n, 1);NTT(b, n, 1); 146 for(int i = 0; i < n; i++) 147 a[i] = a[i] * b[i] % P; 148 NTT(a, n, -1); 149 } 150 char a[MAX],b[MAX]; 151 int e,m; 152 ll inv[MAX],fac[MAX<<2]; 153 int main() 154 { 155 GetWn(); 156 scanf("%s",a);scanf("%s",b); 157 int le=strlen(a); 158 for(int i=0;i<le;i++){ 159 if(a[i]=='1'||b[i]=='1'){ 160 if(a[i]==b[i])++m; 161 else ++e; 162 } 163 } 164 e>>=1; 165 fac[0]=1;for(int i=1;i<=20001;i++)fac[i]=(ll)i*fac[i-1]%MOD; 166 inv[10001]=powMM(fac[10001],MOD-2,MOD); 167 for(int i=10001;i>=1;i--)inv[i-1]=(ll)i*inv[i]%MOD; 168 for(int i=0;i<=m;i++) 169 C[i]=inv[i+1]; 170 A[0]=1; 171 ll f=e; 172 int len=1; 173 while((1<<len)<=(m<<1))++len; 174 for(ll t=e;t;t>>=1) 175 { 176 if(t&1LL) 177 { 178 memset(B,0,sizeof(B)); 179 for(int i=0;i<=m;i++)B[i]=C[i]; 180 Conv(A,B,1<<len); 181 for(int i=m+1;i<(1<<len);i++)A[i]=0; 182 } 183 NTT(C, 1<<len, 1); 184 for(int i = 0; i <(1<<len); i++) 185 C[i] = C[i] * C[i] % P; 186 NTT(C, 1<<len, -1); 187 for(int i=m+1;i<(1<<len);i++)C[i]=0; 188 } 189 ll an=0; 190 ll tem=fac[m]*fac[f]%MOD*fac[f+m]%MOD; 191 for(int i=0;i<=m;i++) 192 an=(an+A[i])%MOD; 193 printf("%lld ",an*tem%MOD); 194 }