树状数组(binary index tree) 的主要作用是单点修改与区间查询(区间和等),而借用差分的思想,树状数组可以达到区间修改与区间查询的功能。而又由于BIT的常数比线段树小很多,所以很多用线段树解决的问题反而可以用BIT乱搞一通说不懂更快。

基本思想

//基本的BIT操作不再赘述

修改区间的时候,引用一个delta[]数组,查询前缀和的时候算上delta[]数组的贡献即可:

令sum[i]为数组a[]的前i项和,则:

\[\begin{split} sum[i] = a[1] + a[2] + … + a[i - 1] + a[i] + delta[1] \times i \\+ delta[2] \times (i - 1) + … + delta[i - 1] \times 2 + delta[i] \times 1 \end{split}\]

其中,delta[$i$]为区间$[i,n]$的共同增量,而修改区间$[l, r]$时则可以将delta[$l$] += $x$, delta[$r + 1$] -= $x$.

单靠上面的式子肯定是不够的,进行变换一下: \begin{equation} \begin{aligned} sum[i] &=\sum_{x=1}^ia[x]+\sum_{x=1}^idelta[x] \times (i + 1 - x) \\&= \sum_{x=1}^ia[x] + (i + 1) \times \sum_{x=1}^i delta[x] - \sum_{x=1}^idelta[x] \times x \end{aligned} \end{equation}
这样就很OK了。

然后,注意到式子中的delta[x], delta[x] * x我们只需要维护两个树状数组delta与deltai,分别记录delta[x]与delta[x]*x的前缀和。而sum[]数组则可以用O(n)来维护,因为其不需要单独修改。

如此一来,借用差分的思想,用数组记录区间修改带来的影响,然后用“增减量”来表示区间的和。

这样,查询区间[l, r]的和则可以用sum[r] - sum[l - 1]来求之。

实现

这样操作完以后,我们就实现了利用树状数组的区间修改与区间查询。

可惜大多数题目根本不会这么裸

模板题:CODEVS 1082 线段树练习

附上AC代码(比线段树快):

//#include </Users/lntr/stdc++.h>

#include <bits/stdc++>

#define MAXN 200000 + 10

using namespace std;
typedef long long ll;

inline int read() {
    int data = 0, w = 1; char ch = getchar();
    while (ch != '-' && (ch < '0' || ch > '9')) ch = getchar();
    if (ch == '-') w = -1, ch = getchar();
    while (ch >= '0' && ch <= '9') data = data * 10 + ch - '0', ch = getchar();
    return w * data;
}

int n, m;

ll delta[MAXN];
ll deltai[MAXN];
ll sum[MAXN];

ll lowbit(int x) {
    return x & (-x);
}

ll query(ll *a, int x) {
    ll sum = 0ll;
    while (x > 0) {
        sum += a[x];
        x -= lowbit(x);
    }
    return sum;
}

void add(ll *a, int x, int d) {
    while (x <= n) {
        a[x] += d;
        x += lowbit(x);
    }
}

int main() {
    n = read();
    for (int i = 1, t; i <= n; i++) {
        t = read();
        sum[i] = sum[i - 1] + t;
    }
    m = read();
    while (m--) {
        int sign = read();
        if (sign == 1) {
            int l = read(), r = read(), x = read();
            add(delta, l, x);
            add(delta, r + 1, -x);
            add(deltai, l, x * l);
            add(deltai, r + 1, -x * (r + 1));
        }
        else {
            int l = read(), r = read();
            ll sum1 = sum[l - 1] + l * query(delta, l - 1) - query(deltai, l - 1);
            ll sum2 = sum[r] + (r + 1) * query(delta, r) - query(deltai, r);
            printf("%lld\n", sum2 - sum1);
        }
    }
    return 0;
}