弱智题一眼秒,但是挺经典的概率dp解法,所以写一篇
但是调了n久......原因是发现自己公式里有个数写错了......
不知道为什么就这点人过
直接排个序,然后很容易就能想到\(dp[i][j]\)表示抽第\(i\)轮,这轮抽到第\(j\)张牌的概率
自然就有了\(O(n^3)\)的转移
\[dp[i][j] = \sum_{k = i - 1}^{a[k]<a[j]}{\frac{1}{n-i+1}dp[i-1][k]}
\]
这个打前缀和优化一下就搞定了
也有了个\(O(n^3)\)的答案
\[Ans = \sum_{i=1}^{n-1}\sum_{j=i}^{n-1}\sum_{k=j+1}^{n}\frac{1}{n-i}dp[i][j]*[a[j]==a[k]]
\]
这个更简单,记个cnt数组就优化掉了一个n......
甚至不用滚动数组,512mB还是很大方的
#include <bits/stdc++.h>
#define pii pair<int,int>
#define LL long long
#define l first
#define r second
#define MAXN 100000
using namespace std;
int n;
int a[5008];
LL dp[5008][5008];
LL sum[5008][5008];
LL inv[5008];
LL cnt[5008];
LL cntsum[5008];
const LL mod = 998244353;
LL qp(LL a, LL b)
{
LL ret = 1;
while (b)
{
if (b & 1)
ret = ret * a % mod;
a = a * a % mod;
b >>= 1;
}
return ret;
}
int main()
{
scanf("%d", &n);
for (int i = 1; i <= n; ++i)
{
scanf("%d", &a[i]);
cnt[a[i]]++;
inv[i] = qp(i, mod - 2);
}
for (int i = 1; i <= n; ++i)
{
cntsum[i] = cntsum[i - 1] + cnt[i];
}
sort(a + 1, a + n + 1);
for (int j = 1; j <= n; ++j)
{
dp[1][j] = inv[n];
sum[1][j] =(sum[1][j - 1] + dp[1][j]) % mod;
}
for (int i = 2; i <= n; ++i)
{
for (int j = i; j <= n; ++j)
{
dp[i][j] = (sum[i - 1][cntsum[a[j] - 1]] - sum[i - 1][i - 2] + mod) % mod * inv[n - i + 1] % mod;
sum[i][j] = (sum[i][j - 1] + dp[i][j]) % mod;
}
}
LL ans = 0;
for (int i = 1; i <= n - 1; ++i)
{
for (int j = i; j <= n - 1; ++j)
{
ans += inv[n - i] * dp[i][j] % mod * (cntsum[a[j]] - j) % mod;
ans %= mod;
}
}
printf("%I64d\n", ans * 2 % mod);
return 0;
}