题目
题解
首先考虑暴力怎么做。直观感觉就是选择最小的后缀拼起来。但是当前的选择会受到后面字符串的影响。因此,考虑从后往前选择。假设当前选到第(i)个串((s[i])),第(i+1)到(n)后缀拼接的最小字符串为(t)。那么如果考虑到当前第(i)个串,最优选择就是字符串(s[i]+t)的至少包含(s[i])最后一个字符的最小后缀。证明略。
因此问题就转换为快速求一个串的最小后缀,而且要求如果在首部插入字符可以快速地实时更新新串的最小后缀。后缀数组或许可以,但是它是(O(nlog n))求得所有后缀的排序,而这里只需要最小的后缀即可,并且后缀数组不支持更新,每插入一个新字符就要重新求。
因此这里使用字符串哈希求最小后缀。对于后缀(i),(j)(即它们首字符在串中位置),二分lcp长度,从而找到第一个值不同的位置比较即可。字符串哈希区间判断相等是(O(1)),二分时间复杂度(O(log n))。维护字符串哈希可以将字符串逆序,然后头插变为尾插,这样就可以(O(1))更新。
每次只在遍历新插入的字符串,总时间复杂度为(O(nlog n))
#include <bits/stdc++.h>
#define endl '
'
#define IOS std::ios::sync_with_stdio(0); cin.tie(0); cout.tie(0)
#define mp make_pair
#define seteps(N) fixed << setprecision(N)
typedef long long ll;
typedef unsigned long long ull;
using namespace std;
/*-----------------------------------------------------------------*/
ll gcd(ll a, ll b) {return b ? gcd(b, a % b) : a;}
#define INF 0x3f3f3f3f
const int N = 1e6 + 10;
const int M = 1e9 + 7;
const double eps = 1e-5;
string ans;
vector<ull> hs;
string s[N];
const int base = 30;
ull pw[N], rpw[N];
inline ll qpow(ll a, ll b, ll m) {
ll res = 1;
while(b) {
if(b & 1) res = (res * a) % m;
a = (a * a) % m;
b = b >> 1;
}
return res;
}
void insert(string t) {
reverse(t.begin(), t.end());
ans += t;
ull res = 0, cb = pw[hs.size()];
if(!hs.empty()) res = hs.back();
for(int i = 0; i < t.size(); i++) {
res += (t[i] - 'a') * cb % M;
cb = cb * base % M;
hs.push_back(res);
}
res %= M;
}
ull cal(int l, int r) {
ull res = hs[r];
if(l - 1 >= 0) res -= hs[l - 1];
res = (res + M) * rpw[l] % M;
return res;
}
bool cmp(int i, int j) { // <
int len = min(i, j) + 1;
int l = 1, r = len;
while(l <= r) {
int mid = (l + r) / 2;
if(cal(i - mid + 1, i) == cal(j - mid + 1, j)) l = mid + 1;
else r = mid - 1;
}
if(l > len) return i < j;
return ans[i - l + 1] < ans[j - l + 1];
}
int main() {
IOS;
rpw[0] = pw[0] = 1;
for(int i = 1; i < N; i++) {
pw[i] = pw[i - 1] * base % M;
rpw[i] = rpw[i - 1] * qpow(base, M - 2, M) % M;
}
int t;
cin >> t;
while(t--) {
ans.clear();
hs.clear();
int n;
cin >> n;
for(int i = 1; i <= n; i++) cin >> s[i];
for(int i = n; i >= 1; i--) {
int tar = ans.size();
insert(s[i]);
for(int j = tar + 1; j < ans.size(); j++) {
if(cmp(j, tar)) {
tar = j;
}
}
ans.erase(ans.begin() + tar + 1, ans.end());
hs.erase(hs.begin() + tar + 1, hs.end());
}
reverse(ans.begin(), ans.end());
cout << ans << endl;
}
}