(n) 和 (m) 和 (k) 和序列 (b_i(1le ile m,1le b_ile b_{i+1}le frac n2))。对于字符串 (s),如果选其 (b_i) 前缀和 (b_i) 后缀翻转并交换能变成字符串 (s'),则 (s,s') 是相等的。求有多少个长度为 (n) 的字符串,字符集大小为 (k),互不相等。答案 (mod 998244353)。
数据范围:(2le nle 10^9),(1le mle min(frac n2,2cdot 10^5)),(1le kle 10^9)。
听神仙说这是群论,是 ( t Polya),小蒟蒻没有听说过,但是手推出来这题了。
既然蒟蒻啥都不会,那么就慢慢做这题。
把字符串分成 (2m+1) 段。
如 (m=2) 字符串就是 (overrightarrow{A}overrightarrow{B}overrightarrow{C}overrightarrow{D}overrightarrow{E}):
(overrightarrow{S}) 指的是一个有顺序的字符串。
(overrightarrow{A}) 表示 (s_{1 o b_1})。
(overrightarrow{B}) 表示 (s_{b_1+1 o b_2})。
(overrightarrow{C}) 表示 (s_{b_2+1 o n-b_2})。
(overrightarrow{D}) 表示 (s_{n-b_2+1 o n-b_1})。
(overrightarrow{E}) 表示 (s_{n-b_1+1 o n})。
- 解决 (m=1) 的情况:
如果 (overrightarrow{A}=overleftarrow{C}):(k^{n-b_1}) 种。
如果 (overrightarrow{A} ot=overleftarrow{C}):((k^n-k^{n-b_1})cdot 2^{-1}) 种。
答案为 (k^{n-b_1}+(k^n-k^{n-b_1})cdot 2^{-1}=(k^n+k^{n-b_1})cdot 2^{-1}) 种。
- 解决 (m=2) 的情况:
如果 (overrightarrow{A}=overleftarrow{E}&overrightarrow{B}=overleftarrow{D}):(k^{n-b_2}) 种。
如果 (overrightarrow{A} ot=overleftarrow{E}&overrightarrow{B}=overleftarrow{D}):((k^{n-b_2+b_1}-k^{n-b_2})cdot 2^{-1}) 种。
如果 (overrightarrow{A}=overleftarrow{E}&overrightarrow{B} ot=overleftarrow{D}):((k^{n-b_1}-k^{n-b_2})cdot 2^{-1}) 种。
如果 (overrightarrow{A} ot=overleftarrow{E}&overrightarrow{B} ot=overleftarrow{D}):((k^n-k^{n-b_1}-k^{n-b_2+b_1}+k^{n-b_2})cdot 2^{-2}) 种。
答案为 ((k^n+k^{n-b_1}+k^{n-b_2+b_1}+k^{n-b_2})cdot 2^{-2}) 种。
小蒟蒻于是就发现规律了:
令集合 (st={b_1,b_2-b_1,...,b_m-b_{m-1}})。
看似诡异,但是其实这个一把这个式子裂开就发现是可以合并的:
然后就做完了。因为还要 (mod 998244353),所以时间复杂度为 (Theta(mlog mod))。
- 代码
#include <bits/stdc++.h>
using namespace std;
//Start
typedef long long ll;
typedef double db;
#define mp(a,b) make_pair(a,b)
#define x first
#define y second
#define b(a) a.begin()
#define e(a) a.end()
#define sz(a) int((a).size())
#define pb(a) push_back(a)
const int inf=0x3f3f3f3f;
const ll INF=0x3f3f3f3f3f3f3f3f;
//Data
const int N=2e5;
const int mod=998244353;
int n,m,k,b[N+7];
//Pow
int Pow(int a,int x){
if(!a) return 0; int res=1;
for(;x;a=(ll)a*a%mod,x>>=1)if(x&1) res=(ll)res*a%mod;
return res;
}
//Main
int main(){
scanf("%d%d%d",&n,&m,&k),k%=mod; //这句不加 Wrong on test 11
for(int i=1;i<=m;i++) scanf("%d",&b[i]);
int res=1;
for(int i=1;i<=m;i++) res=(ll)res*(Pow(k,b[i]-b[i-1])+1)%mod;
res=(ll)res*Pow(k,n-b[m])%mod;
res=(ll)res*Pow(Pow(2,m),mod-2)%mod;
printf("%d
",res);
return 0;
}
祝大家学习愉快!