嘟嘟嘟
这是今天做的第二道九条可怜的题,现在对他的题的印象是:表面清真可做,实则毒瘤坑人。
首先要感谢吉司机,我期望学的特烂,好在样例直接告诉我们期望怎么求了。
令(b_i)表示第(i)个不同的数的出现次数,那么期望就是
所以我们只要让分母尽量小就行了。
这个问题也不难,想想就知道,只要让所有(b_i)尽可能接近就行了。
证明很简单,以两个数为例:如果(a + b = x),问(a, b)取什么值的时候(a! * b!)最小。
设(a_1 + b_1 = x, a_2 + b_2 = x),其中(a_1 geqslant b_1, a_2 geqslant b_2, a_1 > a_2, b_1 < b_2),那么(a_1! * b_1! = a_2! * b_2! * frac{(a_2 + 1) * (a_2 + 2) * ldots * a_1}{(b_1 + 1) * (b_1 + 2) * ldots * b_2})。因为(a_1 = a_2 + Delta t, b_1 = b_2 - Delta t),所以后面的分数的分子分母项数相同,而分子的每一项都大于分母,所以整个分数必定大于1.
尽量接近,翻译过来就是最大的最小,那么自然想到二分。然后似乎就没了。
但为什么裸二分不对呢,因为有的(b_i)可能在原数列中存在,且非常大,那么二分的时候就不能考虑他了,必须在后面单独算贡献。
令二分的答案是(x),我们再列一个二元一次方程组解出来哪些(b_i)等于(x),哪些(b_i)是(x - 1)。(手算一下就行)
这题最后一个坑点就是用map会超时,秉着死也不写哈希表的精神,离散化后直接(O(nlogn))预处理加上吸氧,总算给过了(loj就可以过)。虽然都带个(log),但map常数似乎真不小……
代码奉上
#include<cstdio>
#include<iostream>
#include<cmath>
#include<algorithm>
#include<cstring>
#include<cstdlib>
#include<cctype>
#include<vector>
#include<stack>
#include<queue>
#include<map>
using namespace std;
#define enter puts("")
#define space putchar(' ')
#define Mem(a, x) memset(a, x, sizeof(a))
#define In inline
//#define int long long
typedef long long ll;
typedef double db;
const int INF = 0x3f3f3f3f;
const db eps = 1e-8;
const int maxn = 2e6 + 5;
const int maxN = 1.2e7 + 5;
const ll mod = 998244353;
const int base = 999979;
inline ll read()
{
ll ans = 0;
char ch = getchar(), last = ' ';
while(!isdigit(ch)) last = ch, ch = getchar();
while(isdigit(ch)) ans = (ans << 1) + (ans << 3) + ch - '0', ch = getchar();
if(last == '-') ans = -ans;
return ans;
}
inline void write(ll x)
{
if(x < 0) x = -x, putchar('-');
if(x >= 10) write(x / 10);
putchar(x % 10 + '0');
}
int n, _n, m, l, r;
ll a[maxn], b[maxn], c[maxn], cnt = 0, tot = 0;
ll bnum[maxn], cnum[maxn];
In ll quickpow(ll a, ll b)
{
ll ret = 1;
for(; b; b >>= 1, a = a * a % mod)
if(b & 1) ret = ret * a % mod;
return ret;
}
ll fac[maxN], inv[maxN];
In void init()
{
fac[0] = inv[0] = 1;
for(int i = 1; i < maxN; ++i) fac[i] = fac[i - 1] * i % mod;
inv[maxN - 1] = quickpow(fac[maxN - 1], mod - 2);
for(int i = maxN - 2; i; --i) inv[i] = inv[i + 1] * (i + 1) % mod;
}
In bool judge(int x)
{
ll sum = cnt * x;
for(int i = 1; i <= tot; ++i) if(cnum[i] < x) sum += x - cnum[i];
return sum >= m;
}
In ll solve()
{
n = read(), m = read(), l = read(), r = read();
cnt = r - l + 1, tot = 0;
for(int i = 1; i <= n; ++i)
{
a[i] = b[i] = read();
bnum[i] = 0;
}
sort(b + 1, b + n + 1);
_n = unique(b + 1, b + n + 1) - b - 1;
for(int i = 1; i <= n; ++i)
++bnum[lower_bound(b + 1, b + _n + 1, a[i]) - b];
for(int i = 1; i <= _n; ++i)
if(b[i] >= l && b[i] <= r) c[++tot] = b[i], cnum[tot] = bnum[i], --cnt;
ll L = 0, R = m + n;
while(L < R)
{
int mid = (L + R) >> 1;
if(judge(mid)) R = mid;
else L = mid + 1;
}
ll cnt2 = cnt, sum = 0;
for(int i = 1; i <= tot; ++i) if(cnum[i] < L) ++cnt2, sum += cnum[i];
ll ans = fac[n + m];
if(cnt2 * L == m + sum) ans = ans * quickpow(inv[L], cnt2) % mod;
else
{
ll x = (1 - L) * cnt2 + m + sum, y = cnt2 * L - m - sum;
ll tp1 = quickpow(inv[L], x), tp2 = quickpow(inv[L - 1], y);
ans = ans * tp1 % mod * tp2 % mod;
}
for(int i = 1; i <= _n; ++i) if(b[i] < l || b[i] > r) ans = ans * inv[bnum[i]] % mod;
for(int i = 1; i <= tot; ++i) if(cnum[i] >= L) ans = ans * inv[cnum[i]] % mod;
return ans;
}
int main()
{
init();
int T = read();
while(T--) write(solve()), enter;
return 0;
}