题目
(话说这个题好像同时也是 Codeforces 1262F2 )
翻译
描述
「你的程序又挂了。这次 WA233 。」
这是这道题的难版。这个版本中,(1leq nleq 2cdot 10^5) 。如果你锁了这道题,你可以叉这道题。但只有你锁了这道题和上一道题,你才能叉上一道题。(译者注:这段话是对当时比赛规则的介绍,与题意无关。)
这道题是要完成 (n) 个单选题。每个问题有 (k) 个选项,只有一个选项是正确的。第 (i) 题的答案是 (h_i) 。如果你第 (i) 题的答案是 (h_i) ,在这个问题上你将得 (1) 分,否则得 (0) 分。在这道题中 (h_1,h_2,dots, h_n) 是已知的。
然而,你的程序中有一个错误导致答案会顺时针移动!考虑 (n) 个答案写成一个环。由于程序中的这个错误,它们会循环移动一个位置。
形式化地,这个错误导致第 (i) 题的答案变成了第 (i+1mod n) 题的答案。第 (1) 题的答案变成了第 (2) 题的答案,第 (2) 题的答案变成了第 (3) 题的答案,…… ,第 (n) 题的答案变成了第 (1) 题的答案。
我们把 (n) 个答案合称为一个答案组。共有 (k^n) 个可能的答案组。
你想知道,有多少个答案组满足这个性质:顺时针移动 (1) 位后,新答案组的总分严格大于旧答案组的总分。你需要求出答案模 (998244353) 。
例如,如果 (n=5) ,你的答案组是 (a=[1,2,3,4,5]) ,在提交时由于程序的错误会变成 (a'=[5,1,2,3,4]) 。如果正确答案组是 (h=[5,2,2,3,4]) ,答案组 (a) 得 (1) 分而答案组 (a') 得 (4) 分。既然 (4>1) ,那么答案组 (a=[1,2,3,4,5]) 应该被计入最终的结果。
输入
第一行包含两个整数 (n,k(1leq n leq 2cdot 10^5,1leq k leq 10^9)) —— 问题的数量和每题的选项数量。
下一行包含 (n) 个整数 (h_1,h_2,dots,h_n(1leq h_ileq k)) —— 每道题的答案。
输出
输出一个整数:满足给定性质的答案组数模 (998244353) 。
分析
为什么这么水的题能放 Div2 最后一题啊我都一眼秒了。
那为什么我还不能稳定黄名啊打一场就回紫了我好菜啊。
一个很显然的性质是,每题的答案具有一定的独立性。也就是说,如果第 (i) 题选了一个特定的选项,那么无论别的题怎么选,「第 (i) 题的答案转到第 (i+1) 题」这件事对分数的影响是一定的。
具体来说,如果第 (i) 题和第 (i+1) 题正确答案相同,那么无论第 (i) 题选什么,移位都不会对分数有影响;否则,就分为选了第 (i) 题的正确答案(移位后分数减 (1) )、选了第 (i+1) 题的正确答案(移位后分数加 (1) )和选了其他答案(移位后分数不变)三种情况。
那么问题就变成了:有 (n) 个格子,每个格子要么有 (k) 种方法 比如正着写、反着写、先正着画半圈再反着画半圈等 填 (0) ,要么有 (1) 种方法填 (-1) 、(1) 种方法填 (1) 、(k-2) 种方法填 (0) ,求有多少种填法使 (n) 个数的和大于 (0) 。
当然此时肯定有一个朴素的 dp 思路:(f_{i,j}) 表示前 (i) 个空的和为 (j) 的方案数。不过这个 dp 看起来没什么优化的空间。恭喜你 A 掉了 1227F1
这每一步转移看起来都差不多啊你 d 个鬼 p 。
既然格子的顺序没有关系,只跟多少个填 (1) 多少个填 (-1) 有关,那么直接枚举有 (a) 个填了 (-1) ,那么只需要钦定 (a+1) 个点填 (1) ,剩下的随便填 (1) (如果可以)或 (0) 就行了。
代码
果然 A 了题过几个月再写博客的缺点就在于自己分析了一通发现和代码的实现不一样 ……
我当时的做法是枚举有 (a) 个格子填 (1) 或 (-1) 。如果 (a) 是偶数就首先排除掉 (1) 的数量和 (-1) 的数量相等的情况。然后,每一种非法方案全部取相反数就会变成合法方案,所以合法方案就是总方案数 (2^a) ( (a) 是奇数)或 (2^a-C_a^{frac{a}{2}}) ( (a) 是偶数)的一半。
#include <cstdio>
#include <algorithm>
#include <cstring>
using namespace std;
namespace zyt
{
typedef long long ll;
const int N = 2e5 + 10, P = 998244353;
int n, k, h[N], fac[N], finv[N];
int power(int a, int b)
{
int ans = 1;
while (b)
{
if (b & 1)
ans = (ll)ans * a % P;
a = (ll)a * a % P;
b >>= 1;
}
return ans;
}
int inv(const int a)
{
return power(a, P - 2);
}
void init()
{
fac[0] = 1;
for (int i = 1; i < N; i++)
fac[i] = (ll)fac[i - 1] * i % P;
finv[N - 1] = inv(fac[N - 1]);
for (int i = N - 1; i > 0; i--)
finv[i - 1] = (ll)finv[i] * i % P;
}
int C(const int n, const int m)
{
return (ll)fac[n] * finv[m] % P * finv[n - m] % P;
}
int work()
{
init();
scanf("%d%d", &n, &k);
if (k == 1)
{
puts("0");
return 0;
}
for (int i = 1; i <= n; i++)
scanf("%d", &h[i]);
h[n + 1] = h[1];
int num = 0, ans = 0;
for (int i = 1; i <= n; i++)
if (h[i] != h[i + 1])
++num;
for (int i = 1; i <= num; i++)
{
int tmp = power(2, i);
if (!(i & 1))
tmp = (tmp - C(i, i >> 1) + P) % P;
ans = (ans + (ll)power(k - 2, num - i) * tmp % P * inv(2) % P * C(num, i) % P) % P;
}
printf("%lld", (ll)ans * power(k, n - num) % P);
return 0;
}
}
int main()
{
return zyt::work();
}