题意:你的任务是计算满足如下性质的数组的数量
1.每个数组包含n个元素
2.每个元素的范围是([1, m])
3.对于每个数组,有一对相同的元素
4.对于每个数组,存在着一个下标i,使得左边的元素严格递增,右边的元素严格递减
分析:每个数组包含n个元素,并且有一对相同,意味着有n - 1个不同的数,每个元素的范围是[1, m],对于第一个元素有m种选法,第二个元素有m - 1种选法,第三个元素有m - 2种选法(dots)再除以内部的顺序,总共有(C_{m}^{n - 1})种选法。然后最大的元素在峰顶,剩下有(n - 2)个元素,其中两个是重复的元素,它们分别在峰的左右两端(因为每一侧都是严格的),然后可以从(n - 2)个元素中选取一个作为重复的元素放在左右两侧,有((n - 2))种选法。剩下的元素可以考虑在左边或者右边,然后有(2^{n - 3})种选法。总的选法是(C_{m}^{n - 1} * 2^{n - 3} * (n - 2))。
做法:求组合数取模可以使用乘法逆元或者卢卡斯定理,然后(2^{n - 3})可以使用快速幂。
#include <iostream>
#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;
using LL = long long;
const int N = 400005;
const int mod = 998244353;
int n, m;
int fact[N], infact[N];
int qmi(int a, int k, int p)
{
int res = 1;
while (k)
{
if (k & 1) res = (LL)res * a % p;
a = (LL)a * a % p;
k >>= 1;
}
return res;
}
int C(int a, int b)
{
return (LL)fact[a] * infact[a - b] % mod * infact[b] % mod;
}
int main()
{
scanf("%d%d", &n, &m);
fact[0] = infact[0] = 1;
for (int i = 1; i < N; ++i)
{
fact[i] = (LL)fact[i - 1] * i % mod;
infact[i] = (LL)infact[i - 1] * qmi(i, mod - 2, mod) % mod;
}
//特判
if (n == 2)
{
printf("%d
", 0);
return 0;
}
int res = (LL)C(m, n - 1) * (n - 2) % mod;
int p = qmi(2, n - 3, mod) % mod;
printf("%d
", (LL)res * p % mod);
return 0;
}