一直想写个关于线段树的博客,奈何自己太懒惰了。

蓝后,今天开个坑。

0.1. 最普通的线段树

最普通的线段树,顾名思义就是最普通的,不加任何修饰的线段树,常见于各种教科书和菜鸡(比如我)的代码中。

基本思路应该就是从顶向底建树,从底向顶更新。

0.1.1. 单点修改&求区间最小/最大值

代码很短,很直观。

#include <bits/stdc++.h>
using namespace std;
const int MAX_NODE = (1 << 17) * 2, INF = 0x3f3f3f3f;
int n;

int mn[MAX_NODE];
#define ls (((o) << 1) + 1)                 //ls, rs: left children and right children
#define rs (((o) << 1) + 2)
#define mid (((l) + (r)) >> 1)


// [l, r) [a, b)
int query(int o, int l, int r, int a, int b) {     //o:node now; l,r:segTree's left and right; a, b: query sequence's left and right;
    if (r <= a || l >= b) return INF;
    if (l >= a && r <= b) return mn[o];
    else return min(query(ls, l, mid, a, b), query(rs, mid, r, a, b));
}

inline void pushUp(int o) {
    mn[o] = min(mn[ls], mn[rs]);    
}

inline void modify(int o, int l, int r, int pos, int v) {       //update
    if (r - l == 1) mn[o] = v;
    else {
        if (pos < mid) modify(ls, l, mid, pos, v);
        else modify(rs, mid, r, pos, v);
        pushUp(o);
    }
}

int main() {
    memset(mn, 0x3f, sizeof mn);
    modify(0, 0, 4, 0, 10);         //insert 10 to position 0 in seq 0 ~ 4. first '0' is an initial node
    modify(0, 0, 4, 2, 10);         //insert 10 to position 2 in seq 0 ~ 4. first '0' is an initial node
    for (int i = 0; i <= 7; i++ ){ 
        cout << mn[i] << endl;
    }
    cout << query(0, 0, 4, 1, 3);
    return 0;
}

最大值最小值好搞,那么如果是求和呢?很简单,把取最值操作更改成两者相加操作就OK。

下一波问题。

0.1.2. 区间加/减&区间求和

看到区间修改,芳心一动就想拿树状数组乱搞?

其实确实可以用树状数组乱搞,而且常数小,跑的也快…但是为了应题,还是考虑一下线段树的做法吧。

首先,区间的修改意味着不能再像基础线段树一样简单的递归。我们可以考虑用一个数组存储一下当前结点的信息。对于区间加来说,用数组表示区间的值的变更,因为是整个区间加上某个数,所以不用考虑具体到每个数,笼统一加就可以了。

下面是代码,用的是简单明了的lrj的板子。

//Created on 2017/8/18 by uneducable
//No Rights Reserved
/*
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <cstring>
*/
using namespace std;
typedef long long ll;
const int MAXN = 800020;

inline int read() {
    int data = 0, w = 1; char ch = 0;
    while(ch != '-' && (ch < '0' || ch > '9')) ch=getchar();
    if(ch == '-') w = -1,ch = getchar();
    while(ch >= '0' && ch <= '9') data = data * 10 + ch - '0',ch = getchar();
    return data * w;
}

int op, qL, qR, v;
ll _sum;

struct Tree {
    ll sumv[MAXN], addv[MAXN];

    void maintain(int o, int L, int R) {
        int lc = o * 2, rc = o * 2 + 1;
        sumv[o] = 0;
        if (R > L) {
            sumv[o] = sumv[lc] + sumv[rc];
        }
        if (addv[o]) {
            sumv[o] += addv[o] * (R - L + 1);
        }
    }
    void update(int o, int L, int R) {
        int lc = o * 2, rc = o * 2 + 1;
        if (qL <= L && qR >= R) addv[o] += v;
        else {
            int M = L + (R - L) / 2;
            if (qL <= M) update(lc, L, M);
            if (qR > M) update(rc, M + 1, R);
        }
        maintain(o, L, R);  
    }

    void query(int o, int L, int R, int add) {
        if (qL <= L && qR >= R) {
            _sum += sumv[o] + add * (R - L + 1);
        }
        else  {
            int M = L + (R - L)/2;
            if (qL <= M) query(o * 2, L, M, add + addv[o]);
            if (qR > M) query(o * 2 + 1, M + 1, R, add + addv[o]);
        }
    }
};

const int inf = 0x3f3f3f3f;

Tree tree;

int main() {
    int n, m;
    memset(&tree, 0, sizeof tree);
    n = read();
    for (int i = 1; i <= n; i++) {
        v = read();
        qL = qR = i;
        tree.update(1, 1, n);
    }
    m = read();
    while (m--) {
        op = read(); qL = read(); qR = read();
        if (op == 1) {
            v = read();
            tree.update(1, 1, n);
        }
        else {
            _sum = 0;
            tree.query(1, 1, n, 0);
            printf("%lld\n", _sum);
        }
    }
	return 0;
}

0.1.3. 区间赋值&区间求和

由于题目要求区间赋值,多个操作的顺序不同,其公共区间的值的最终结果也就可能不同。 所以引入pushdown操作,每次update时更新当前区间中没有打set标记的儿子,递归之后maintain操作拉上来就好。

// Fast Sequence Operations II

#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;

const int maxnode = 1<<17;

int _sum, _min, _max, op, qL, qR, v;

struct IntervalTree {
  int sumv[maxnode], minv[maxnode], maxv[maxnode], setv[maxnode];

  void maintain(int o, int L, int R) {
    int lc = o*2, rc = o*2+1;
    if(R > L) {
      sumv[o] = sumv[lc] + sumv[rc];
      minv[o] = min(minv[lc], minv[rc]);
      maxv[o] = max(maxv[lc], maxv[rc]);
    }
    if(setv[o] >= 0) { minv[o] = maxv[o] = setv[o]; sumv[o] = setv[o] * (R-L+1); }
  }

  void pushdown(int o) {
    int lc = o*2, rc = o*2+1;
    if(setv[o] >= 0) { 
      setv[lc] = setv[rc] = setv[o];
      setv[o] = -1; 
    }
  }

  void update(int o, int L, int R) {
    int lc = o*2, rc = o*2+1;
    if(qL <= L && qR >= R) { 
      setv[o] = v;
    } else {
      pushdown(o);
      int M = L + (R-L)/2;
      if(qL <= M) update(lc, L, M); else maintain(lc, L, M);
      if(qR > M) update(rc, M+1, R); else maintain(rc, M+1, R);
    }
    maintain(o, L, R);
  }

  void query(int o, int L, int R) {
    if(setv[o] >= 0) { 
      _sum += setv[o] * (min(R,qR)-max(L,qL)+1);
      _min = min(_min, setv[o]);
      _max = max(_max, setv[o]);
    } else if(qL <= L && qR >= R) { 
      _sum += sumv[o]; 
      _min = min(_min, minv[o]);
      _max = max(_max, maxv[o]);
    } else { 
      int M = L + (R-L)/2;
      if(qL <= M) query(o*2, L, M);
      if(qR > M) query(o*2+1, M+1, R);
    }
  }
};

const int INF = 1000000000;

IntervalTree tree;

int main() {
  int n, m;
  while(scanf("%d%d", &n, &m) == 2) {
    memset(&tree, 0, sizeof(tree));
    memset(tree.setv, -1, sizeof(tree.setv));
    tree.setv[1] = 0;
    while(m--) {
      scanf("%d%d%d", &op, &qL, &qR);
      if(op == 1) {
        scanf("%d", &v);
        tree.update(1, 1, n);
      } else {
        _sum = 0; _min = INF; _max = -INF;
        tree.query(1, 1, n);
        printf("%d %d %d\n", _sum, _min, _max);
      }
    }
  }
  return 0;
}

0.2. Tag&zkw线段树

普通的线段树自上而下递归建树、传递标记,自下而上更新信息。直到某一天一个上古时期的大佬想,我们可不可以用确定的数组来存储线段树而不用指针,也不用递归来更新线段树会不会更好?

zkw说,要有更imba的线段树,于是便有了zkw线段树

重口味线段树之所以叫重口味线段树是因为发明它的大佬叫zkw,他的ppt在这:统计的力量

参考下列伪代码脑补出整个zkw线段树吧…一言半语是讲不清楚它的核心思想的。建议好好读一下统计的力量,读完你就会有醍醐灌顶的感觉,但要注意原ppt上代码有错误…


#define MAXN 100000
#define lson(x) ((x) << 1)
#define rson(x) ((x) << 1 | 1)

using namespace std;
const int inf = 0x3f3f3f3f;
int N, M;

inline void maintain(int x) {
    tree[x] = tree[lson(x)] + tree[rson(x)];
}

inline void build() {
    for (M = 1; M < N; M <<= 1);
    for (int i = M + 1; i <= M + N; i++) tree[i] = read();
    for (int i = M - 1; i; i--) maintain(i);
}
//-----------------------------单点修改区间求和---------------------------

int sum[MAXN], tree[MAXN];
inline void update(int pos, int v) {
    pos += M;
    tree[pos] = v;
    for (pos >>= 1; pos; pos >>= 1) maintain(pos);
}

inline int sum(int l, int r) {
    int ans = 0;
    for (l = l + M - 1, r = r + M + 1; l ^ r ^ 1; l >>= 1, r >>= 1) {
        if (~l & 1) ans += tree[l ^ 1];
        if (r & 1) ans += tree[r ^ 1];
    }
    return ans;
}
//-----------------------------------------------------------------------
/*
    sum(x) sum of the interval x
    lson(x) x * 2 (left son of node x)
    rson(x) x * 2 + 1 (right son of node x)
    L(x) left end of interval x
    R(x) right end of interval x
    tag(x) lazy tag of node x
    maintain(x) maintain the node x, collect information from sons(left and right)

*/
int tag[MAXN], L[MAXN], R[MAXN];
//replace L(i) with l[i], R(i) with r[i], sum(i) with sum[i], tag(i) with tag[i]

inline void pushdown(int x) {
    if (tag(x) && x < M) {
        tag(lson(x)) += tag(x);
        tag(rson(x)) += tag(x);
        sum(lson(x)) += (R(rson(x)) - L(rson(x)) + 1) * tag(x);
        sum(rson(x)) += (R(rson(x)) - L(rson(x)) + 1) * tag(x);
        tag(x) = 0;
    }
}

inline void build() {
    for (M = 1; M < N; M <<= 1);
    for (int i = M + 1; i <= M + N; i++) {
        cin >> sum(i);
        L(i) = R(i) = i - M;
    }
    for (int i = M - 1; i; i--) {
        sum(i) = sum(lson(i)) + sum(rson(i));
        L(i) = L(lson(i));
        R(i) = R(rson(i));
    }
}
//-----------------------------------------------------------------------
//-----------------------------区间修改区间求和---------------------------
inline void applytag(int x) {
    stack<int> s;
    while (x) {
        stack.push(x);
        x >>= 1;
    }
    while (s.size()) {
        pushdown(stack.top());
        stack.pop();
    }
}

inline void range_update(int l, int r, int v) {
    bool vl = false, vr = false;
    int x;
    int sl, sr;
    for (l = l + M - 1, r = r + M + 1; l ^ r ^ 1; l >>= 1, r >>= 1) {
        if (~l & 1) {
            x = l ^ 1;
            if (!vl) {
                vl = true;
                sl = x;
                applytag(x);
            }
            tag(x) += v;
            sum(x) += (R(x) - L(x) + 1) * v;
        }
        if (r & 1) {
            x = r ^ 1;
            if (!vr) {
                vr = true;
                sr = x;
                applytag(x);
            }
            tag(x) += v;
            sum(x) += (R(x) - L(x) + 1) * v;
        }
    }
    for (sl >>= 1; sl; sl >>= 1) {
        maintain(sl);
    }
    for (sr >>= 1; sr; sr >>= 1) {
        maintain(sr);
    }
}

inline int query (int l, int r) {
    bool vl, vr;
    vl = vr = false;
    for (l = l + M - 1, r = r + M + 1; l ^ r ^ 1; l >>= 1, r >>= 1) {
        if (~l & 1) {
            if (!vl) applytag(l ^ 1);
            ans += sum(l ^ 1);
            vl = true;
        }
        if (r & 1) {
            if (!vr) applytag(r ^ 1);
            ans += sum(r ^ 1);
            vr = true;
        }
    }
    return ans;
}

zkw其实还可以做到更多,包括但不局限于查询区间最值等等。

0.3. 复杂的线段树问题

接下来的代码都不使用重口味线段树,而是使用相对好看一点的带懒标记的小清新线段树。名字不是我瞎起的

0.3.1. 区间取模类

首先,知道一个结论:每个数最多被膜 \(\log x\) 次,再膜就膜不下去了。证明十分simple:一个数被膜一次,起码要减少一半以上。这个证明告诉我们了什么?告诉我们了如果用优化一点点的暴力瞎搞一下,复杂度还是可以接受的。所以,构建一个线段树,每一段维护一个区间最大值,如果当前膜操作到了这个区间,而被膜的数却比区间最大值都还要大,那么就跳过区间,而不是继续夏姬八膜。当然了,头铁的除外。

在修改方面(如果还要对值进行操作的话),单点修改和区间修改都是可以用这种暴力做一做。至于更高明的方法,嘿嘿,我还不会。


#include <bits/stdc++.h>
#define MAXN 400005
using namespace std;
const int maxn = 1e5 + 100;
long long a[maxn], m, n;

struct Node {
	long long _max, sum;
	Node(long long a = 0, long long b = 0) : _max(a), sum(b) {}
	Node operator + (const Node& rhs) {
		return Node(max(_max, rhs._max), sum + rhs.sum);
	}
} I[maxn * 4];

inline int read() {
	int f = 1, w = 0;
	char ch = 0;
	ch = getchar();
	while ((ch > '9' || ch < '0') && ch != '-') ch = getchar();
	if (ch == '-') f = -1, ch = getchar();
	while (ch >= '0' && ch <= '9') w = w * 10 + ch - '0', ch = getchar();
	return f * w;
}


void build(int o = 1, int L = 1, int R = n) {
	if(L == R)
		I[o] = Node(a[L], a[L]);
	else {
		int M = (L + R) >> 1;
		build(o << 1, L, M);
		build(o << 1 | 1, M + 1, R);
		I[o] = I[o << 1] + I[o << 1 | 1];
	}
}
long long ql, qr, q;
void query(int o = 1, int L = 1, int R = n) {
	if (ql <= L && R <= qr)
		q += I[o].sum;
	else {
		int M = (L + R) >> 1;
		if(ql <= M) query(o << 1, L, M);
		if(M < qr) query(o << 1 | 1, M + 1, R);
	}
}
void update(int o = 1, int L = 1, int R = n) {
	if(L == R)
		I[o] = Node(I[o].sum % q, I[o].sum % q);
	else {
		int M = (L + R) >> 1;
		if(ql <= M && I[o << 1]._max >= q) update(o << 1, L, M);
		if(M < qr && I[o << 1 | 1]._max >= q) update(o << 1 | 1, M + 1, R);
		I[o] = I[o << 1] + I[o << 1 | 1];
	}
}
void setv(int o = 1, int L = 1, int R = n) {
	if(L == ql && ql == R)
		I[o] = Node(q, q);
	else {
		int M = (L + R) >> 1;
		if(ql <= M) setv(o << 1, L, M);
		else setv(o << 1 | 1, M + 1, R);
		I[o] = I[o << 1] + I[o << 1 | 1];
	}
}
void solve() {
	n = read(); m = read();
	for(int i = 1; i <= n; i++) a[i] = read();
	build();
	while(m--) {
		int type; type = read();
		if(type == 1) {
			ql = read(); qr = read(); q = 0; query();
			printf("%lld\n", q);
		} else if(type == 2) {
			ql = read(); qr = read(); q = read(); update();
		} else {
			ql = read(); q = read(); setv();
		}
	}
}

int main() {
  //freopen("mod.in", "r", stdin);
  //freopen("mod.out", "w", stdout);
	solve();
	return 0;
}

0.3.2. 区间开根类

这道题的思路和上道题差不多,也是更优雅地暴力。

对于区间取模来说,比这个区间最大值还大的膜数就不用再膜了。对于区间开根来说,喜欢按计算器的朋友们都知道,再大的数开着开着也会为1。我们拿BZOJ 3211 为例。题目给出的数据是 \(0 \le data[i] \le 10^9\) 。10的9次方是不是看起来很大?按5次根号就变成1.9几了。不过对于计算机的精度处理,谁信谁睿智,所以我们认为它最多操作6次便成为了1。

所以对于线段树上每段,我们维护一个变量记录这个区间已经被开了几次根,如果开了6次以上,那就跳过。同样,头铁的可以不判。对于BZOJ3211,这道题还有树状数组+并查集的做法。

线段树代码:


#include <bits/stdc++.h>
#define MAXN 1000100

int n, m;
struct shit {
    int L, R, t;
    long long num;
} s[4 * MAXN];
int w[MAXN];
void push_up(int p) {
    s[p].num = s[p << 1].num + s[p << 1 | 1].num;
    return ;
}
void build(int p, int l, int r) {
    s[p].L = l;
    s[p].R = r;
    if(l == r) {
        s[p].num = w[l];
        return ;
    }
    int mid = (l + r) >> 1;
    build(p << 1, l, mid);
    build(p << 1 | 1, mid + 1, r);
    push_up(p);
    return ;
}
void fuck(int p, int a, int b) {
    if(a <= s[p].L && s[p].R <= b) {
        s[p].t++;
        if(s[p].L == s[p].R) {
            s[p].num = sqrt(s[p].num);
            return ;
        }
    }
    int mid = (s[p].L + s[p].R) >> 1;
    if(s[p << 1].t < 6 && a <= mid)fuck(p << 1, a, b);
    if(s[p << 1 | 1].t < 6 && b > mid)fuck(p << 1 | 1, a, b);
    push_up(p);
    return ;
}
long long Q(int p, int a, int b) {
    if(a <= s[p].L && s[p].R <= b) {
        return s[p].num;
    }
    long long ans = 0;
    int mid = (s[p].L + s[p].R) >> 1;
    if(a <= mid) ans += Q(p << 1, a, b);
    if(b > mid)ans += Q(p << 1 | 1, a, b);
    return ans;
}
int main() {
    int a, b, f;
    scanf("%d", &n);
    for(int i = 1; i <= n; i++)scanf("%d", w + i);
    scanf("%d", &m);
    build(1, 1, n);
    while(m--) {
        scanf("%d%d%d", &f, &a, &b);
        if(f - 1)fuck(1, a, b);
        else printf("%lld\n", Q(1, a, b));
    }
    return 0;
}

0.3.3. 区间排序类

这个还是有点恐怖的,起码对于我当时第一次在雅礼见的时候还是很不可做。现在依旧不可做。好像是Codeforce上的原题

因为是对字母进行排序操作,所以很自然地想到建26棵线段树,每个线段树保存一个字母,记录在某一区间出现的次数。


#include <iostream>
#include <cstdio>
#include <cmath>
#include <cstring>
#define out(i) <<#i<<"="<<(i)<<" "
#define OUT1(a1) cout out(a1) <<endl
#define OUT2(a1,a2) cout out(a1) out(a2) <<endl
#define OUT3(a1,a2,a3) cout out(a1) out(a2) out(a3)<<endl
#define maxn 100007
#define ls l,m,rt<<1
#define rs m+1,r,rt<<1|1

using namespace std;
int n, q;
char str[maxn];
struct Node {
    int d[26];
    int D;
    bool sorted;
    bool Inc;
    void Insert(int v) {
        memset(d, 0, sizeof(d));
        D = d[v] = 1;
        sorted = false;
    }
    void Take(bool Left, int N) { 
        D = N;
        if(Left) {
            for(int i = 0; i < 26; ++i) {
                if(N >= d[i]) N -= d[i];
                else {
                    d[i] = N;
                    N = 0;
                }
            }
        } else {
            for(int i = 25; i >= 0; --i) {
                if(N >= d[i]) N -= d[i];
                else {
                    d[i] = N;
                    N = 0;
                }
            }
        }
    }
    void Drop(bool Left, int N) { 
        D = D - N;
        if(Left) {
            for(int i = 0; i < 26; ++i) {
                if(N >= d[i]) N -= d[i], d[i] = 0;
                else {
                    d[i] -= N;
                    N = 0;
                }
            }
        } else {
            for(int i = 25; i >= 0; --i) {
                if(N >= d[i]) N -= d[i], d[i] = 0;
                else {
                    d[i] -= N;
                    N = 0;
                }
            }
        }
    }
    Node operator +(const Node &B)const { 
        Node C;
        C.sorted = false;
        for(int i = 0; i < 26; ++i) C.d[i] = d[i] + B.d[i];
        C.D = D + B.D;
        return C;
    }
    void show() {
        printf("D=%d\n", D);
        for(int i = 0; i < 26; ++i) {
            if(d[i]) printf("d[%d]=%d\n", i, d[i]);
        }
    }
} ST[maxn << 2];
void PushDown(int rt) { 
    Node &L = ST[rt << 1], &R = ST[rt << 1 | 1];
    if(ST[rt].sorted) {
        int N = L.D;
        L = ST[rt];
        L.Take(ST[rt].Inc, N);
        N = R.D;
        R = ST[rt];
        R.Take(!ST[rt].Inc, N);
        ST[rt].sorted = false;
    }
}
void Build(int l, int r, int rt) { 
    if(l == r) {
        ST[rt].Insert(str[l] - 'a');
        return;
    }
    int m = (l + r) >> 1;
    Build(ls);
    Build(rs);
    ST[rt] = ST[rt << 1] + ST[rt << 1 | 1];
}
Node Query(int L, int R, int l, int r, int rt) { 
    if(L <= l && r <= R) {
        return ST[rt];
    }
    PushDown(rt);
    int m = (l + r) >> 1;
    Node LANS, RANS;
    int X = 0;
    if(L <= m) LANS = Query(L, R, ls), X += 1;
    if(R >  m) RANS = Query(L, R, rs), X += 2;
    if(X == 1) return LANS;
    if(X == 2) return RANS;
    return LANS + RANS;
}
Node Sum;
void Update(int L, int R, int l, int r, int rt) {
    if(L <= l && r <= R) {
        int N = ST[rt].D;
        ST[rt] = Sum;
        ST[rt].Take(Sum.Inc, N);
        Sum.Drop(Sum.Inc, N);
        return;
    }
    int m = (l + r) >> 1;
    if(L <= m) Update(L, R, ls);
    if(R >  m) Update(L, R, rs);
    ST[rt] = ST[rt << 1] + ST[rt << 1 | 1];
}
void Sort(int L, int R, int Inc) {
    Sum = Query(L, R, 1, n, 1); 
    Sum.sorted = true;
    Sum.Inc = Inc; 
    Update(L, R, 1, n, 1); 
}
void Down(int l, int r, int rt) { 
    if(l == r) {
        for(int i = 0; i < 26; ++i) {
            if(ST[rt].d[i]) {
                str[l] = i + 'a';
                break;
            }
        }
        return;
    }
    PushDown(rt);
    int m = (l + r) >> 1;
    Down(ls);
    Down(rs);
}
int main(void) {
    while(~scanf("%d%d", &n, &q)) {
        scanf("%s", str + 1);
        Build(1, n, 1); 
        for(int i = 0; i < q; ++i) {
            int x, y, k;
            scanf("%d%d%d", &x, &y, &k);
            Sort(x, y, k); 
        }
        Down(1, n, 1); 
        printf("%s\n", str + 1);
    }
    return 0;
}

0.4. 小结

暂时就想到这么多,做个记录以便日后回顾。代码和语言纰漏之处请指出。

xD;