// 题解:建议看官方题解 (Solution: refer to the official editorial.)
#include <bits/stdc++.h>
using namespace std;
// Short aliases used throughout the solution.
#define fi first
#define se second
using LL = long long;
using pii = pair<int, int>;
// Capacity bound for the per-position arrays; the automaton pool holds twice this many states.
constexpr int maxn = 3000050;
// One suffix-automaton state.
//   len  - length assigned to the state (set to predecessor len + 1 on creation)
//   link - suffix link; -1 marks the root
//   nex  - nex[c] is the target on letter 'a' + c; 0 means "no transition"
//          (state 0 is the root and is never created as a transition target)
struct state {
    int len;
    int link;
    int nex[26];
};

// State pool; static storage duration makes every field start at zero.
state st[maxn * 2];
// Number of states allocated in `st`, and the state that sam_extend extends from.
int sz;
int last;

// Resets the automaton to a lone root state 0 (len 0, no suffix link).
// Transitions of state 0 are not cleared; this relies on `st` being zero-initialized.
void sam_init() {
    sz = 1;
    last = 0;
    st[0].link = -1;
    st[0].len = 0;
}
// Monotonic stack of string positions (1-based; main places the sentinel at stk[1]),
// its top index, and pos[i] = automaton state reached after processing position i.
int stk[maxn];
int top;
int pos[maxn];
// Appends letter `c` (0..25) after the state currently in `last`, allocating a new
// state (plus a clone when a split is required). Sets `last` to the new state and
// returns it. Callers may point `last` at any existing state beforehand.
//
// NOTE: if `last` already has a transition on `c`, the new state gets the same len
// as its suffix link, so it contributes 0 to sum(len - link.len).
int sam_extend(int c) {
    const int fresh = sz++;
    int walker = last;
    st[fresh].len = st[walker].len + 1;
    last = fresh;

    // Add the missing c-transitions along the suffix-link chain.
    for (; walker != -1 && st[walker].nex[c] == 0; walker = st[walker].link) {
        st[walker].nex[c] = fresh;
    }

    if (walker == -1) {
        st[fresh].link = 0;
        return fresh;
    }

    const int target = st[walker].nex[c];
    if (st[target].len == st[walker].len + 1) {
        st[fresh].link = target;
        return fresh;
    }

    // Split `target`: the clone inherits its transitions and suffix link, but takes
    // the shorter length.
    const int split = sz++;
    st[split] = st[target];
    st[split].len = st[walker].len + 1;

    // Redirect the c-transitions that pointed to `target` onto the clone.
    for (; walker != -1 && st[walker].nex[c] == target; walker = st[walker].link) {
        st[walker].nex[c] = split;
    }

    st[target].link = split;
    st[fresh].link = split;
    return fresh;
}
// Input string; characters are mapped to automaton letters via c - 'a'
// (presumably lowercase only — TODO confirm against the problem statement).
string s;

// Builds one suffix automaton over a family of strings, one per position, then
// prints sum(len - link.len) over all non-root states. For a suffix automaton this
// equals the number of distinct non-empty substrings of the inserted strings.
//
// Positions are processed right to left. A monotonic stack yields j, the nearest
// position to the right with s[j] >= s[i] (or the sentinel n). The string for i is
// the string for j followed by (j - i) copies of s[i], so its state is obtained by
// extending pos[j] that many times.
//
// NOTE(review): stk/pos have maxn entries and st has 2 * maxn. This assumes n < maxn
// and that the total number of sam_extend calls (sum of j - i) stays below maxn.
// TODO confirm against the problem limits.
int main() {
    std::ios::sync_with_stdio(false);
    cin >> s;
    int n = s.size();

    // Sentinel at index n: 'z' + 1 is greater than every lowercase letter, so it is
    // never popped. Append it, because writing s[n] (the terminator slot) is UB.
    s.push_back('z' + 1);

    pos[n] = 0;  // the sentinel corresponds to the root (empty string)
    stk[1] = n;
    top = 1;
    sam_init();

    for (int i = n - 1; i >= 0; i--) {
        // Pop positions whose letter is smaller than s[i]; check `top` before indexing.
        while (top && s[i] > s[stk[top]]) {
            top--;
        }
        last = pos[stk[top]];
        for (int j = i; j < stk[top]; j++) {
            last = sam_extend(s[i] - 'a');
        }
        pos[i] = last;
        stk[++top] = i;
    }

    LL ans = 0;
    for (int i = 1; i < sz; i++) ans += st[i].len - st[st[i].link].len;
    cout << ans << '\n';
}