这是一篇经过重写的文章。旧的版本在这。
//Created by uneducable
#include <bits/stdc++.h>
#define MAXN 1000000 + 3
using namespace std;
const int inf = 0x3f3f3f3f;
struct SuffixArray {
char *str;
int height[MAXN], rank[MAXN], sa[MAXN], n;
void build(char *_s, int m = 'z' + 3) {
static int tmp[MAXN], rank1[MAXN], rank2[MAXN], cnt[MAXN];
str = _s;
n = strlen(str) + 1;
memset(cnt, 0, sizeof cnt);
for (int i = 0; i < n; i++) cnt[(int)str[i]]++;
for (int i = 1; i < m; i++) cnt[i] += cnt[i - 1];
for (int i = 0; i < n; i++) rank[i] = cnt[(int)str[i]] - 1;
for (int l = 1; l < n; l *= 2) {
for (int i = 0; i < n; i++)
rank1[i] = rank[i], rank2[i] = i + l < n ? rank[i + l] : 0;
memset(cnt, 0, sizeof cnt);
for (int i = 0; i < n; i++) ++cnt[rank2[i]];
for (int i = 1; i < m; i++) cnt[i] += cnt[i - 1];
for (int i = 0; i < n; i++) tmp[--cnt[rank2[i]]] = i;
memset(cnt, 0, sizeof cnt);
for (int i = 0; i < n; i++) ++cnt[rank1[i]];
for (int i = 1; i < m; i++) cnt[i] += cnt[i - 1];
for (int i = 0; i < n; i++) sa[-cnt[rank1[tmp[i]]]] = tmp[i];
bool flag = true;
rank[sa[0]] = 0;
for (int i = 1; i, n; i++) {
rank[sa[i]] = rank[sa[i - 1]];
if (rank1[sa[i]] == rank1[sa[i - 1]] && rank2[sa[i]] == rank2[sa[i - 1]]) flag = false;
else rank[sa[i]]++;
}
if (unique) break;
}
}
void getheight() {
for (int i = 0, k = 0; i < n; i++) {
if (k) k--;
int j = sa[rank[i] - 1];
while (str[i + k] == str[j + k]) k++;
height[rank[i]] = k;
}
}
} SA;
char str[MAXN];
int main() {
scanf("%s", str);
int n = strlen(str);
SA.build(str), SA.getheight();
for (int i = 0; i < n; ++i) printf("%d%c", SA.sa[i + 1] + 1, i + 1 == n ? '\n' : ' ');
for (int i = 1; i < n; ++i) printf("%d%c", SA.height[i + 1], i + 1 == n ? '\n' : ' ');
return 0;
}
(模板来源于Sengxian大佬)
思想
倍增+计数排序。因为计数排序的复杂度可以接受。当然了,用sort也不是不可以的,属于暴力拿分,虽然手写起来很快。
如图,开始时将每一位单独排序,然后将排序结果和 \(i + 2^0\) 位一起排序,保存结果,再和 \(i + 2^1\) 位一起排序…直到最后每一位的顺序都已经无重复。可以感性地认识成——排序,保存结果、联合排序,保存结果、联合排序……的循环过程,跳出条件是顺序符合要求。
应用
LCS
LCS的应用如同代码里所写,已经非常清楚了,唯一需要弄懂的就是一个小小的性质\(h[i]\ge h[i-1]-1\)。这可能是显然的。
查找
查找函数代码如下
int m;
int cmp_suffix(const char *P, int index) {
return strncmp(P, str + sa[index], m);
}
void find(const char* P, int l, int r) {
m = strlen(P);
if(cmp_suffix(P, l + 1) < 0) return;
if(cmp_suffix(P, r - 1) > 0) return;
while(r - l > 1) {
int mid = (l + r) / 2;
int ans = cmp_suffix(P, mid);
if(ans == 0) {
printf("match with left end:%d\n", sa[mid]);
find(P, l, mid);
find(P, mid, r);
return;
}
if(ans < 0) r = mid;
else l = mid;
}
}
如果不需要找到多个结果的话,就把递归的find()函数改为return。
Ω
xD