一直想写个关于线段树的博客,奈何自己太懒惰了。
蓝后,今天开个坑。
0.1. 最普通的线段树
最普通的线段树,顾名思义就是最普通的,不加任何修饰的线段树,常见于各种教科书和菜鸡(比如我)的代码中。
基本思路应该就是从顶向底建树,从底向顶更新。
0.1.1. 单点修改&求区间最小/最大值
代码很短,很直观。
#include <bits/stdc++.h>
using namespace std;
const int MAX_NODE = (1 << 17) * 2, INF = 0x3f3f3f3f;
int n;
int mn[MAX_NODE];
#define ls (((o) << 1) + 1) //ls, rs: left children and right children
#define rs (((o) << 1) + 2)
#define mid (((l) + (r)) >> 1)
// [l, r) [a, b)
int query(int o, int l, int r, int a, int b) { //o:node now; l,r:segTree's left and right; a, b: query sequence's left and right;
if (r <= a || l >= b) return INF;
if (l >= a && r <= b) return mn[o];
else return min(query(ls, l, mid, a, b), query(rs, mid, r, a, b));
}
inline void pushUp(int o) {
mn[o] = min(mn[ls], mn[rs]);
}
inline void modify(int o, int l, int r, int pos, int v) { //update
if (r - l == 1) mn[o] = v;
else {
if (pos < mid) modify(ls, l, mid, pos, v);
else modify(rs, mid, r, pos, v);
pushUp(o);
}
}
int main() {
memset(mn, 0x3f, sizeof mn);
modify(0, 0, 4, 0, 10); //insert 10 to position 0 in seq 0 ~ 4. first '0' is an initial node
modify(0, 0, 4, 2, 10); //insert 10 to position 2 in seq 0 ~ 4. first '0' is an initial node
for (int i = 0; i <= 7; i++ ){
cout << mn[i] << endl;
}
cout << query(0, 0, 4, 1, 3);
return 0;
}
最大值最小值好搞,那么如果是求和呢?很简单,把取最值操作更改成两者相加操作就OK。
下一波问题。
0.1.2. 区间加/减&区间求和
看到区间修改,芳心一动就想拿树状数组乱搞?
其实确实可以用树状数组乱搞,而且常数小,跑的也快…但是为了应题,还是考虑一下线段树的做法吧。
首先,区间的修改意味着不能再像基础线段树一样简单的递归。我们可以考虑用一个数组存储一下当前结点的信息。对于区间加来说,用数组表示区间的值的变更,因为是整个区间加上某个数,所以不用考虑具体到每个数,笼统一加就可以了。
下面是代码,用的是简单明了的lrj的板子。
//Created on 2017/8/18 by uneducable
//No Rights Reserved
/*
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <cstring>
*/
using namespace std;
typedef long long ll;
const int MAXN = 800020;
inline int read() {
int data = 0, w = 1; char ch = 0;
while(ch != '-' && (ch < '0' || ch > '9')) ch=getchar();
if(ch == '-') w = -1,ch = getchar();
while(ch >= '0' && ch <= '9') data = data * 10 + ch - '0',ch = getchar();
return data * w;
}
int op, qL, qR, v;
ll _sum;
struct Tree {
ll sumv[MAXN], addv[MAXN];
void maintain(int o, int L, int R) {
int lc = o * 2, rc = o * 2 + 1;
sumv[o] = 0;
if (R > L) {
sumv[o] = sumv[lc] + sumv[rc];
}
if (addv[o]) {
sumv[o] += addv[o] * (R - L + 1);
}
}
void update(int o, int L, int R) {
int lc = o * 2, rc = o * 2 + 1;
if (qL <= L && qR >= R) addv[o] += v;
else {
int M = L + (R - L) / 2;
if (qL <= M) update(lc, L, M);
if (qR > M) update(rc, M + 1, R);
}
maintain(o, L, R);
}
void query(int o, int L, int R, int add) {
if (qL <= L && qR >= R) {
_sum += sumv[o] + add * (R - L + 1);
}
else {
int M = L + (R - L)/2;
if (qL <= M) query(o * 2, L, M, add + addv[o]);
if (qR > M) query(o * 2 + 1, M + 1, R, add + addv[o]);
}
}
};
const int inf = 0x3f3f3f3f;
Tree tree;
int main() {
int n, m;
memset(&tree, 0, sizeof tree);
n = read();
for (int i = 1; i <= n; i++) {
v = read();
qL = qR = i;
tree.update(1, 1, n);
}
m = read();
while (m--) {
op = read(); qL = read(); qR = read();
if (op == 1) {
v = read();
tree.update(1, 1, n);
}
else {
_sum = 0;
tree.query(1, 1, n, 0);
printf("%lld\n", _sum);
}
}
return 0;
}
0.1.3. 区间赋值&区间求和
由于题目要求区间赋值,多个操作的顺序不同,其公共区间的值的最终结果也就可能不同。 所以引入pushdown操作,每次update时更新当前区间中没有打set标记的儿子,递归之后maintain操作拉上来就好。
// Fast Sequence Operations II
#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;
const int maxnode = 1<<17;
int _sum, _min, _max, op, qL, qR, v;
struct IntervalTree {
int sumv[maxnode], minv[maxnode], maxv[maxnode], setv[maxnode];
void maintain(int o, int L, int R) {
int lc = o*2, rc = o*2+1;
if(R > L) {
sumv[o] = sumv[lc] + sumv[rc];
minv[o] = min(minv[lc], minv[rc]);
maxv[o] = max(maxv[lc], maxv[rc]);
}
if(setv[o] >= 0) { minv[o] = maxv[o] = setv[o]; sumv[o] = setv[o] * (R-L+1); }
}
void pushdown(int o) {
int lc = o*2, rc = o*2+1;
if(setv[o] >= 0) {
setv[lc] = setv[rc] = setv[o];
setv[o] = -1;
}
}
void update(int o, int L, int R) {
int lc = o*2, rc = o*2+1;
if(qL <= L && qR >= R) {
setv[o] = v;
} else {
pushdown(o);
int M = L + (R-L)/2;
if(qL <= M) update(lc, L, M); else maintain(lc, L, M);
if(qR > M) update(rc, M+1, R); else maintain(rc, M+1, R);
}
maintain(o, L, R);
}
void query(int o, int L, int R) {
if(setv[o] >= 0) {
_sum += setv[o] * (min(R,qR)-max(L,qL)+1);
_min = min(_min, setv[o]);
_max = max(_max, setv[o]);
} else if(qL <= L && qR >= R) {
_sum += sumv[o];
_min = min(_min, minv[o]);
_max = max(_max, maxv[o]);
} else {
int M = L + (R-L)/2;
if(qL <= M) query(o*2, L, M);
if(qR > M) query(o*2+1, M+1, R);
}
}
};
const int INF = 1000000000;
IntervalTree tree;
int main() {
int n, m;
while(scanf("%d%d", &n, &m) == 2) {
memset(&tree, 0, sizeof(tree));
memset(tree.setv, -1, sizeof(tree.setv));
tree.setv[1] = 0;
while(m--) {
scanf("%d%d%d", &op, &qL, &qR);
if(op == 1) {
scanf("%d", &v);
tree.update(1, 1, n);
} else {
_sum = 0; _min = INF; _max = -INF;
tree.query(1, 1, n);
printf("%d %d %d\n", _sum, _min, _max);
}
}
}
return 0;
}
0.2. Tag&zkw线段树
普通的线段树自上而下递归建树、传递标记,自下而上更新信息。直到某一天一个上古时期的大佬想,我们可不可以用确定的数组来存储线段树而不用指针,也不用递归来更新线段树会不会更好?
zkw说,要有更imba的线段树,于是便有了zkw线段树
重口味线段树之所以叫重口味线段树是因为发明它的大佬叫zkw,他的ppt在这:统计的力量。
参考下列伪代码脑补出整个zkw线段树吧…一言半语是讲不清楚它的核心思想的。建议好好读一下统计的力量,读完你就会有醍醐灌顶的感觉,但要注意原ppt上代码有错误…
#define MAXN 100000
#define lson(x) ((x) << 1)
#define rson(x) ((x) << 1 | 1)
using namespace std;
const int inf = 0x3f3f3f3f;
int N, M;
inline void maintain(int x) {
tree[x] = tree[lson(x)] + tree[rson(x)];
}
inline void build() {
for (M = 1; M < N; M <<= 1);
for (int i = M + 1; i <= M + N; i++) tree[i] = read();
for (int i = M - 1; i; i--) maintain(i);
}
//-----------------------------单点修改区间求和---------------------------
int sum[MAXN], tree[MAXN];
inline void update(int pos, int v) {
pos += M;
tree[pos] = v;
for (pos >>= 1; pos; pos >>= 1) maintain(pos);
}
inline int sum(int l, int r) {
int ans = 0;
for (l = l + M - 1, r = r + M + 1; l ^ r ^ 1; l >>= 1, r >>= 1) {
if (~l & 1) ans += tree[l ^ 1];
if (r & 1) ans += tree[r ^ 1];
}
return ans;
}
//-----------------------------------------------------------------------
/*
sum(x) sum of the interval x
lson(x) x * 2 (left son of node x)
rson(x) x * 2 + 1 (right son of node x)
L(x) left end of interval x
R(x) right end of interval x
tag(x) lazy tag of node x
maintain(x) maintain the node x, collect information from sons(left and right)
*/
int tag[MAXN], L[MAXN], R[MAXN];
//replace L(i) with l[i], R(i) with r[i], sum(i) with sum[i], tag(i) with tag[i]
inline void pushdown(int x) {
if (tag(x) && x < M) {
tag(lson(x)) += tag(x);
tag(rson(x)) += tag(x);
sum(lson(x)) += (R(rson(x)) - L(rson(x)) + 1) * tag(x);
sum(rson(x)) += (R(rson(x)) - L(rson(x)) + 1) * tag(x);
tag(x) = 0;
}
}
inline void build() {
for (M = 1; M < N; M <<= 1);
for (int i = M + 1; i <= M + N; i++) {
cin >> sum(i);
L(i) = R(i) = i - M;
}
for (int i = M - 1; i; i--) {
sum(i) = sum(lson(i)) + sum(rson(i));
L(i) = L(lson(i));
R(i) = R(rson(i));
}
}
//-----------------------------------------------------------------------
//-----------------------------区间修改区间求和---------------------------
inline void applytag(int x) {
stack<int> s;
while (x) {
stack.push(x);
x >>= 1;
}
while (s.size()) {
pushdown(stack.top());
stack.pop();
}
}
inline void range_update(int l, int r, int v) {
bool vl = false, vr = false;
int x;
int sl, sr;
for (l = l + M - 1, r = r + M + 1; l ^ r ^ 1; l >>= 1, r >>= 1) {
if (~l & 1) {
x = l ^ 1;
if (!vl) {
vl = true;
sl = x;
applytag(x);
}
tag(x) += v;
sum(x) += (R(x) - L(x) + 1) * v;
}
if (r & 1) {
x = r ^ 1;
if (!vr) {
vr = true;
sr = x;
applytag(x);
}
tag(x) += v;
sum(x) += (R(x) - L(x) + 1) * v;
}
}
for (sl >>= 1; sl; sl >>= 1) {
maintain(sl);
}
for (sr >>= 1; sr; sr >>= 1) {
maintain(sr);
}
}
inline int query (int l, int r) {
bool vl, vr;
vl = vr = false;
for (l = l + M - 1, r = r + M + 1; l ^ r ^ 1; l >>= 1, r >>= 1) {
if (~l & 1) {
if (!vl) applytag(l ^ 1);
ans += sum(l ^ 1);
vl = true;
}
if (r & 1) {
if (!vr) applytag(r ^ 1);
ans += sum(r ^ 1);
vr = true;
}
}
return ans;
}
zkw其实还可以做到更多,包括但不局限于查询区间最值等等。
0.3. 复杂的线段树问题
接下来的代码都不使用重口味线段树,而是使用相对好看一点的带懒标记的小清新线段树。名字并不是我瞎起的
0.3.1. 区间取模类
首先,知道一个结论:每个数最多被膜 \(\log x\) 次,再膜就膜不下去了。证明十分simple:一个数被膜一次,起码要减少一半以上。这个证明告诉我们了什么?告诉我们了如果用优化一点点的暴力瞎搞一下,复杂度还是可以接受的。所以,构建一个线段树,每一段维护一个区间最大值,如果当前膜操作到了这个区间,而被膜的数却比区间最大值都还要大,那么就跳过区间,而不是继续夏姬八膜。当然了,头铁的除外。
在修改方面(如果还要对值进行操作的话),单点修改和区间修改都是可以用这种暴力做一做。至于更高明的方法,嘿嘿,我还不会。
#include <bits/stdc++.h>
#define MAXN 400005
using namespace std;
const int maxn = 1e5 + 100;
long long a[maxn], m, n;
struct Node {
long long _max, sum;
Node(long long a = 0, long long b = 0) : _max(a), sum(b) {}
Node operator + (const Node& rhs) {
return Node(max(_max, rhs._max), sum + rhs.sum);
}
} I[maxn * 4];
inline int read() {
int f = 1, w = 0;
char ch = 0;
ch = getchar();
while ((ch > '9' || ch < '0') && ch != '-') ch = getchar();
if (ch == '-') f = -1, ch = getchar();
while (ch >= '0' && ch <= '9') w = w * 10 + ch - '0', ch = getchar();
return f * w;
}
void build(int o = 1, int L = 1, int R = n) {
if(L == R)
I[o] = Node(a[L], a[L]);
else {
int M = (L + R) >> 1;
build(o << 1, L, M);
build(o << 1 | 1, M + 1, R);
I[o] = I[o << 1] + I[o << 1 | 1];
}
}
long long ql, qr, q;
void query(int o = 1, int L = 1, int R = n) {
if (ql <= L && R <= qr)
q += I[o].sum;
else {
int M = (L + R) >> 1;
if(ql <= M) query(o << 1, L, M);
if(M < qr) query(o << 1 | 1, M + 1, R);
}
}
void update(int o = 1, int L = 1, int R = n) {
if(L == R)
I[o] = Node(I[o].sum % q, I[o].sum % q);
else {
int M = (L + R) >> 1;
if(ql <= M && I[o << 1]._max >= q) update(o << 1, L, M);
if(M < qr && I[o << 1 | 1]._max >= q) update(o << 1 | 1, M + 1, R);
I[o] = I[o << 1] + I[o << 1 | 1];
}
}
void setv(int o = 1, int L = 1, int R = n) {
if(L == ql && ql == R)
I[o] = Node(q, q);
else {
int M = (L + R) >> 1;
if(ql <= M) setv(o << 1, L, M);
else setv(o << 1 | 1, M + 1, R);
I[o] = I[o << 1] + I[o << 1 | 1];
}
}
void solve() {
n = read(); m = read();
for(int i = 1; i <= n; i++) a[i] = read();
build();
while(m--) {
int type; type = read();
if(type == 1) {
ql = read(); qr = read(); q = 0; query();
printf("%lld\n", q);
} else if(type == 2) {
ql = read(); qr = read(); q = read(); update();
} else {
ql = read(); q = read(); setv();
}
}
}
int main() {
//freopen("mod.in", "r", stdin);
//freopen("mod.out", "w", stdout);
solve();
return 0;
}
0.3.2. 区间开根类
这道题的思路和上道题差不多,也是更优雅地暴力。
对于区间取模来说,比这个区间最大值还大的膜数就不用再膜了。对于区间开根来说,喜欢按计算器的朋友们都知道,再大的数开着开着也会为1。我们拿BZOJ 3211 为例。题目给出的数据是 \(0 \le data[i] \le 10^9\) 。10的9次方是不是看起来很大?按5次根号就变成1.9几了。不过对于计算机的精度处理,谁信谁睿智,所以我们认为它最多操作6次便成为了1。
所以对于线段树上每段,我们维护一个变量记录这个区间已经被开了几次根,如果开了6次以上,那就跳过。同样,头铁的可以不判。对于BZOJ3211,这道题还有树状数组+并查集的做法。
线段树代码:
#include <bits/stdc++.h>
#define MAXN 1000100
int n, m;
struct shit {
int L, R, t;
long long num;
} s[4 * MAXN];
int w[MAXN];
void push_up(int p) {
s[p].num = s[p << 1].num + s[p << 1 | 1].num;
return ;
}
void build(int p, int l, int r) {
s[p].L = l;
s[p].R = r;
if(l == r) {
s[p].num = w[l];
return ;
}
int mid = (l + r) >> 1;
build(p << 1, l, mid);
build(p << 1 | 1, mid + 1, r);
push_up(p);
return ;
}
void fuck(int p, int a, int b) {
if(a <= s[p].L && s[p].R <= b) {
s[p].t++;
if(s[p].L == s[p].R) {
s[p].num = sqrt(s[p].num);
return ;
}
}
int mid = (s[p].L + s[p].R) >> 1;
if(s[p << 1].t < 6 && a <= mid)fuck(p << 1, a, b);
if(s[p << 1 | 1].t < 6 && b > mid)fuck(p << 1 | 1, a, b);
push_up(p);
return ;
}
long long Q(int p, int a, int b) {
if(a <= s[p].L && s[p].R <= b) {
return s[p].num;
}
long long ans = 0;
int mid = (s[p].L + s[p].R) >> 1;
if(a <= mid) ans += Q(p << 1, a, b);
if(b > mid)ans += Q(p << 1 | 1, a, b);
return ans;
}
int main() {
int a, b, f;
scanf("%d", &n);
for(int i = 1; i <= n; i++)scanf("%d", w + i);
scanf("%d", &m);
build(1, 1, n);
while(m--) {
scanf("%d%d%d", &f, &a, &b);
if(f - 1)fuck(1, a, b);
else printf("%lld\n", Q(1, a, b));
}
return 0;
}
0.3.3. 区间排序类
这个还是有点恐怖的,起码对于我当时第一次在雅礼见的时候还是很不可做。现在依旧不可做。好像是Codeforce上的原题。
因为是对字母进行排序操作,所以很不自然地想到建26棵线段树,每个线段树保存一个字母,记录在某一区间出现的次数。
#include <iostream>
#include <cstdio>
#include <cmath>
#include <cstring>
#define out(i) <<#i<<"="<<(i)<<" "
#define OUT1(a1) cout out(a1) <<endl
#define OUT2(a1,a2) cout out(a1) out(a2) <<endl
#define OUT3(a1,a2,a3) cout out(a1) out(a2) out(a3)<<endl
#define maxn 100007
#define ls l,m,rt<<1
#define rs m+1,r,rt<<1|1
using namespace std;
int n, q;
char str[maxn];
struct Node {
int d[26];
int D;
bool sorted;
bool Inc;
void Insert(int v) {
memset(d, 0, sizeof(d));
D = d[v] = 1;
sorted = false;
}
void Take(bool Left, int N) {
D = N;
if(Left) {
for(int i = 0; i < 26; ++i) {
if(N >= d[i]) N -= d[i];
else {
d[i] = N;
N = 0;
}
}
} else {
for(int i = 25; i >= 0; --i) {
if(N >= d[i]) N -= d[i];
else {
d[i] = N;
N = 0;
}
}
}
}
void Drop(bool Left, int N) {
D = D - N;
if(Left) {
for(int i = 0; i < 26; ++i) {
if(N >= d[i]) N -= d[i], d[i] = 0;
else {
d[i] -= N;
N = 0;
}
}
} else {
for(int i = 25; i >= 0; --i) {
if(N >= d[i]) N -= d[i], d[i] = 0;
else {
d[i] -= N;
N = 0;
}
}
}
}
Node operator +(const Node &B)const {
Node C;
C.sorted = false;
for(int i = 0; i < 26; ++i) C.d[i] = d[i] + B.d[i];
C.D = D + B.D;
return C;
}
void show() {
printf("D=%d\n", D);
for(int i = 0; i < 26; ++i) {
if(d[i]) printf("d[%d]=%d\n", i, d[i]);
}
}
} ST[maxn << 2];
void PushDown(int rt) {
Node &L = ST[rt << 1], &R = ST[rt << 1 | 1];
if(ST[rt].sorted) {
int N = L.D;
L = ST[rt];
L.Take(ST[rt].Inc, N);
N = R.D;
R = ST[rt];
R.Take(!ST[rt].Inc, N);
ST[rt].sorted = false;
}
}
void Build(int l, int r, int rt) {
if(l == r) {
ST[rt].Insert(str[l] - 'a');
return;
}
int m = (l + r) >> 1;
Build(ls);
Build(rs);
ST[rt] = ST[rt << 1] + ST[rt << 1 | 1];
}
Node Query(int L, int R, int l, int r, int rt) {
if(L <= l && r <= R) {
return ST[rt];
}
PushDown(rt);
int m = (l + r) >> 1;
Node LANS, RANS;
int X = 0;
if(L <= m) LANS = Query(L, R, ls), X += 1;
if(R > m) RANS = Query(L, R, rs), X += 2;
if(X == 1) return LANS;
if(X == 2) return RANS;
return LANS + RANS;
}
Node Sum;
void Update(int L, int R, int l, int r, int rt) {
if(L <= l && r <= R) {
int N = ST[rt].D;
ST[rt] = Sum;
ST[rt].Take(Sum.Inc, N);
Sum.Drop(Sum.Inc, N);
return;
}
int m = (l + r) >> 1;
if(L <= m) Update(L, R, ls);
if(R > m) Update(L, R, rs);
ST[rt] = ST[rt << 1] + ST[rt << 1 | 1];
}
void Sort(int L, int R, int Inc) {
Sum = Query(L, R, 1, n, 1);
Sum.sorted = true;
Sum.Inc = Inc;
Update(L, R, 1, n, 1);
}
void Down(int l, int r, int rt) {
if(l == r) {
for(int i = 0; i < 26; ++i) {
if(ST[rt].d[i]) {
str[l] = i + 'a';
break;
}
}
return;
}
PushDown(rt);
int m = (l + r) >> 1;
Down(ls);
Down(rs);
}
int main(void) {
while(~scanf("%d%d", &n, &q)) {
scanf("%s", str + 1);
Build(1, n, 1);
for(int i = 0; i < q; ++i) {
int x, y, k;
scanf("%d%d%d", &x, &y, &k);
Sort(x, y, k);
}
Down(1, n, 1);
printf("%s\n", str + 1);
}
return 0;
}
0.4. 小结
暂时就想到这么多,做个记录以便日后回顾。代码和语言纰漏之处请指出。
xD;