假设每一次方向变化时(包括结束时)所抓的最后一个人位置(绝对值)依次为$a_{1},a_{2},...,a_{k}$,则不难得到答案为$a_{k}+\sum_{i=1}^{k-1}(3^{k-i}+3^{k-i-1})a_{i}$
另外,$a_{1},a_{2},...,a_{k}$必然是左右交替的,但首尾并不能确定,可以暴力分类讨论
进一步的,考虑每一个$a_{i}$的贡献,最终可得$L_{i}$对总答案的贡献为($R_{i}$类似)
$$
\frac{4}{3}\sum_{i=1}^{n-1}L_{i}\sum_{x=1}^{n}\left(4{m-1\choose x-1}+{m-1\choose x-2}+3{m-1\choose x}\right)\sum_{j=1}^{x-1}9^{j}{n-i-1\choose j-1}{i-1\choose x-j-1}\\\sum_{x=1}^{n}{n-1\choose x-1}\left(5{m-1\choose x-1}+{m-1\choose x-2}+4{m-1\choose x}\right)L_{n}
$$
下式显然可以直接计算,上式中三个组合数相加不妨用${m-1\choose x-1}$代替,并交换$x$和$j$的枚举顺序,即
$$
\frac{4}{3}\sum_{i=1}^{n-1}L_{i}\sum_{j=1}^{n-1}9^{j}{n-i-1\choose j-1}\sum_{x=j+1}^{n}{i-1\choose x-j-1}{m-1\choose x-1}
$$
考虑关于$x$的枚举,将${i-1\choose x-j-1}$改为${i-1\choose i-(x-j)}$,根据${n+m\choose k}=\sum{n\choose x}{m\choose k-x}$,即${i+m-2\choose i+j-1}$
(另外,简单分析可得$x$取到所有${i-1\choose i-(x-j)}$非0的情况)
将组合数均用阶乘展开,并简单分类,即
$$
\frac{4}{3}\sum_{i=1}^{n-1}L_{i}(n-i-1)!(i+m-2)!\sum_{j=1}^{n-1}\frac{9^{j}}{(j-1)!(m-j-1)!}\cdot \frac{1}{(n-i-j)!(i+j-1)!}
$$
显然可以用多项式乘法处理,再使用ntt优化即可
时间复杂度为$o(n\log n)$,可以通过
1 #include<bits/stdc++.h> 2 using namespace std; 3 #define N (1<<19) 4 #define mod 998244353 5 #define ll long long 6 #define Add(x,y) (x+y<mod ? x+y : x+y-mod) 7 #define Dec(x,y) (x>=y ? x-y : x-y+mod) 8 int n,m,ans,mi[N],rev[N],fac[N],inv[N],a[N],A[N],B[N],g[N]; 9 int C(int n,int m){ 10 if (n<m)return 0; 11 return (ll)fac[n]*inv[m]%mod*inv[n-m]%mod; 12 } 13 int qpow(int n,int m){ 14 int s=n,ans=1; 15 while (m){ 16 if (m&1)ans=(ll)ans*s%mod; 17 s=(ll)s*s%mod,m>>=1; 18 } 19 return ans; 20 } 21 void ntt(int *a,int p=0){ 22 for(int i=0;i<N;i++) 23 if (i<rev[i])swap(a[i],a[rev[i]]); 24 for(int i=2,l=0;i<=N;i<<=1,l++){ 25 int s=qpow(3,(mod-1)/i); 26 if (p)s=qpow(s,mod-2); 27 g[0]=1; 28 for(int k=1;k<(i>>1);k++)g[k]=(ll)g[k-1]*s%mod; 29 for(int j=0;j<N;j+=i) 30 for(int k=0;k<(i>>1);k++){ 31 int x=a[j+k],y=(ll)a[j+k+(i>>1)]*g[k]%mod; 32 a[j+k]=Add(x,y),a[j+k+(i>>1)]=Dec(x,y); 33 } 34 } 35 if (p){ 36 int s=qpow(N,mod-2); 37 for(int i=0;i<N;i++)a[i]=(ll)a[i]*s%mod; 38 } 39 } 40 int calc(){ 41 int ans=0; 42 memset(A,0,sizeof(A)); 43 for(int i=1;i<n;i++)A[i]=(ll)a[i]*fac[n-i-1]%mod*fac[i+m-2]%mod; 44 ntt(A); 45 46 memset(B,0,sizeof(B)); 47 for(int i=1;i<min(n,m);i++)B[i]=(ll)mi[i]*inv[i-1]%mod*inv[m-i-1]%mod; 48 ntt(B); 49 for(int i=0;i<N;i++)B[i]=(ll)A[i]*B[i]%mod; 50 ntt(B,1); 51 for(int i=1;i<=n;i++)ans=(ans+4LL*B[i]*inv[n-i]%mod*inv[i-1])%mod; 52 53 memset(B,0,sizeof(B)); 54 for(int i=1;i<min(n,m+1);i++)B[i]=(ll)mi[i]*inv[i-1]%mod*inv[m-i]%mod; 55 ntt(B); 56 for(int i=0;i<N;i++)B[i]=(ll)A[i]*B[i]%mod; 57 ntt(B,1); 58 for(int i=2;i<=n;i++)ans=(ans+(ll)B[i]*inv[n-i]%mod*inv[i-2])%mod; 59 60 memset(B,0,sizeof(B)); 61 for(int i=1;i<min(n,m-1);i++)B[i]=(ll)mi[i]*inv[i-1]%mod*inv[m-i-2]%mod; 62 ntt(B); 63 for(int i=0;i<N;i++)B[i]=(ll)A[i]*B[i]%mod; 64 ntt(B,1); 65 for(int i=0;i<=n;i++)ans=(ans+3LL*B[i]*inv[n-i]%mod*inv[i])%mod; 66 67 ans=4LL*(mod+1)/3*ans%mod; 68 for(int x=1;x<=n;x++)ans=(ans+(ll)5*C(n-1,x-1)*C(m-1,x-1)%mod*a[n]%mod+mod)%mod; 69 for(int x=2;x<=n;x++)ans=(ans+(ll)C(n-1,x-1)*C(m-1,x-2)%mod*a[n]%mod+mod)%mod; 70 for(int x=1;x<=n;x++)ans=(ans+(ll)4*C(n-1,x-1)*C(m-1,x)%mod*a[n])%mod; 71 return ans; 72 } 73 int main(){ 74 mi[0]=fac[0]=inv[0]=inv[1]=1; 75 for(int i=1;i<N;i++)mi[i]=(ll)9*mi[i-1]%mod; 76 for(int i=0;i<N;i++)rev[i]=(rev[i>>1]>>1)+((i&1)*(N>>1)); 77 for(int i=1;i<N;i++)fac[i]=(ll)fac[i-1]*i%mod; 78 for(int i=2;i<N;i++)inv[i]=(ll)(mod-mod/i)*inv[mod%i]%mod; 79 for(int i=1;i<N;i++)inv[i]=(ll)inv[i-1]*inv[i]%mod; 80 scanf("%d%d",&n,&m); 81 for(int i=1;i<=n;i++)scanf("%d",&a[i]); 82 ans=calc(); 83 for(int i=1;i<=m;i++)scanf("%d",&a[i]); 84 swap(n,m),ans=(ans+calc())%mod; 85 printf("%d\n",ans); 86 return 0; 87 }