[toc]
@description@
给定一个长度为 n 的字符串 s,保证只包含前 8 个小写字母 'a', 'b', ... 'h'。
根据该字符串建一个图。两个点 p, q 之间有连边要么 |p - q| = 1,要么 s[p] = s[q]。
求该图直径的长度(所有点对之间的最短距离的最大值),以及直径的数量。
Input 第一行一个整数 n,表示字符串长度。 第二行一个字符串 s。保证只包含前 8 个小写字母。
Output 输出直径长度与直径数量。
Examples Input 3 abc Output 2 1
Input 7 aaabaaa Output 2 4
@solution@
虽然 cf 的难度系统不太准,但至少难度 > 3000 都是我不会做的题.jpg。
先考虑如何快速求两个点 i, j 之间的最短路。 首先注意到最短路上不能出现不相邻的相同字符(即类似于 'a' → ... → 'a'),否则我可以直接从第一个相同字符跳到最后一个相同字符。 这意味着最短路径一定 ⇐ 2*8。
假如不经过相邻的相同字符(即不经过 s[p] = s[q] 类型的边),最短路径长度为 |i - j|。 否则,我们以某种字符为中转,向两边求到 i, j 的最短路,两者之和即 i->j 的最短路。 即如果记 d[c][i] 表示 i 到达某一个字符 c 的最短路,此时最短路径为 min(d[c][i] + d[c][j])。 那么 i, j 之间的最短路一定为 min(|i - j|, min(d[c][i] + d[c][j]))。
怎么求 d[c][i]?我们可以 bfs 搞定。 只是需要注意由于相同字符构成了一个完全图,假如这个字符的所有点已经全部进入队列,我们需要打上 tag 防止之后反复访问。不然时间就炸了。
考虑通用的解法:枚举 i,求以 i 为起点的最长路径及路径数量。但是这样子还是 O(n^2) 的。 首先,只有 i 前面 16 个以及后面 16 个是可能取 |i - j| 为最小值的,这些直接暴算。 那么剩下的 j 只可能取 min(d[c][i] + d[c][j]),我们再研究怎么简化这一部分的复杂度。
注意到这个只跟 d[c][j] 有关。我们或许可以将 d[0][j], d[1][j], ... 相同的 j 放在一起处理。 具体怎么操作?注意到相同字符 x 对应的 j, k,总有 |d[c][j] - d[c][k]| ⇐ 1(因为它们之间有边连接)。 记 mnd[c][x] = min(d[c][j]),那么对于字符 x 对应的任意 j,有 d[c][j] = mnd[c][x] 或 d[c][j] = mnd[c][x] + 1。 我们可以用一个 8 位的二进制状态将 d[c][j] 压缩,并存储该二进制状态对应的数量,就可以实现我们的目的了。
枚举 i 过后,再枚举 x 以及 8 位的二进制状态得到 d[0][j], d[1][j], ...,然后枚举中转字符 c,根据式子算。 注意我们需要把 i 前面 16 个以及后面 16 个的贡献先消掉。不然会计算重复。
@accepted code@
#include <queue>
#include <cstdio>
#include <vector>
#include <algorithm>
using namespace std;
const int MAXN = 100000;
int clr[MAXN + 5], n;
vector<int>v[MAXN + 5];
int d[10][MAXN + 5];
bool tag[10];
void get_dist(int x) {
queue<int>que;
for(int i=0;i<8;i++) tag[i] = false;
for(int i=1;i<=n;i++) d[x][i] = MAXN + 5;
for(int i=0;i<v[x].size();i++)
d[x][v[x][i]] = 0, que.push(v[x][i]);
while( !que.empty() ) {
int f = que.front(); que.pop();
if( !tag[clr[f]] ) {
tag[clr[f]] = true;
for(int i=0;i<v[clr[f]].size();i++) {
int u = v[clr[f]][i];
if( d[x][u] > d[x][f] + 1 )
d[x][u] = d[x][f] + 1, que.push(u);
}
}
if( f != 1 && d[x][f-1] > d[x][f] + 1 )
d[x][f-1] = d[x][f] + 1, que.push(f-1);
if( f != n && d[x][f+1] > d[x][f] + 1 )
d[x][f+1] = d[x][f] + 1, que.push(f+1);
}
}
int mnd[10][10], bts[MAXN + 5], cnt[10][1<<10];
void get_mask(int x) {
for(int i=0;i<8;i++) {
mnd[x][i] = MAXN + 5;
for(int j=0;j<v[x].size();j++)
mnd[x][i] = min(mnd[x][i], d[i][v[x][j]]);
for(int j=0;j<v[x].size();j++)
bts[v[x][j]] |= ((d[i][v[x][j]] - mnd[x][i])<<i);
}
for(int j=0;j<v[x].size();j++)
cnt[x][bts[v[x][j]]]++;
}
int ans1; long long ans2;
void update(int x, int t) {
if( x == ans1 ) ans2 += t;
else if( x > ans1 ) ans1 = x, ans2 = t;
}
char s[MAXN + 5];
int abs(int x) {return x >= 0 ? x : -x;}
int main() {
scanf("%d%s", &n, s + 1);
for(int i=1;i<=n;i++)
v[s[i]-'a'].push_back(i), clr[i] = s[i] - 'a';
for(int i=0;i<8;i++) get_dist(i);
for(int i=0;i<8;i++) get_mask(i);
ans1 = 0, ans2 = 0;
int t = (1<<8);
for(int i=1;i<=n;i++) {
for(int j=max(1,i-16);j<=min(i+16,n);j++) {
int mn = abs(i - j);
for(int k=0;k<8;k++)
mn = min(mn, d[k][i] + 1 + d[k][j]);
cnt[clr[j]][bts[j]]--, update(mn, 1);
}
for(int j=0;j<8;j++) {
for(int s=0;s<t;s++) {
int mn = MAXN + 5;
for(int k=0;k<8;k++)
mn = min(mn, mnd[j][k] + ((s>>k) & 1) + d[k][i] + 1);
if( cnt[j][s] )
update(mn, cnt[j][s]);
}
}
for(int j=max(1,i-16);j<=min(i+16,n);j++)
cnt[clr[j]][bts[j]]++;
}
printf("%d %lld
", ans1, ans2/2);
}
@detail@
因为卡在第一步(求两个点之间的最短路)所以。。。也没什么好说的吧。。。
感觉自己代码能力还是有所提升的样子~~(好久没有那种一道题调一天的快感了)~~。 这个 detail 的板块最近基本也没用来提示代码细节,反倒是在吐槽了 2333。