Skip to content

Segment Tree

静态线段树

只支持区间最值查询,没有修改操作。(静态区间求和直接用前缀和(cumsum)即可。)

(不过对于静态区间最值问题(RMQ),用 ST 表(Sparse Table)实现更好写,且单次查询为 O(1)。)

// balanced lineup
#define _CRT_SECURE_NO_WARNINGS
#include <iostream>
#include <cstdio>
#include <algorithm>
#include <cstring>
#define P pair<int,int>
#define lc rt*2+1
#define rc rt*2+2
using namespace std;

// Capacity: maximum number of array elements.
const int maxn = 50005;
// Sentinel above every height; used to reset the running minimum.
const int inf = 0x7fffffff;
// Input heights, stored 0-based.
int h[maxn];
// Number of elements and number of queries.
int N;
int Q;

// Segment-tree node over the closed index range [l, r]; caches the maximum
// and minimum of h over that range.
struct node {
    int l;
    int r;
    int mx;
    int mn;
    // Split point: the left child covers [l, m()], the right [m()+1, r].
    int m() { return (l + r) / 2; }
} tr[maxn * 4];

// Build the subtree rooted at `rt` over h[l..r]: record each node's range and
// cache the maximum and minimum of its elements.
void build(int rt, int l, int r) {
    auto &cur = tr[rt];
    cur.l = l;
    cur.r = r;
    if (l == r) {
        // Leaf: max and min are the single element itself.
        cur.mx = h[l];
        cur.mn = h[l];
        return;
    }
    const int mid = (l + r) / 2;
    const int leftChild = rt * 2 + 1;
    const int rightChild = rt * 2 + 2;
    build(leftChild, l, mid);
    build(rightChild, mid + 1, r);
    cur.mx = max(tr[leftChild].mx, tr[rightChild].mx);
    cur.mn = min(tr[leftChild].mn, tr[rightChild].mn);
}

// Running maximum / minimum accumulated by query(). The caller must reset
// them (mx to a value <= every height, mn to inf) before each query.
int mx, mn;

// Fold the max/min of h[l..r] into the globals mx/mn.
// [l, r] must lie inside node rt's range.
void query(int rt, int l, int r) {
    auto &cur = tr[rt];
    if (l == cur.l && r == cur.r) {
        if (cur.mx > mx) mx = cur.mx;
        if (cur.mn < mn) mn = cur.mn;
        return;
    }
    const int mid = cur.m();
    const int leftChild = rt * 2 + 1;
    const int rightChild = rt * 2 + 2;
    if (r <= mid) {
        query(leftChild, l, r);
        return;
    }
    if (l > mid) {
        query(rightChild, l, r);
        return;
    }
    // The range straddles the split point: visit both halves.
    query(leftChild, l, mid);
    query(rightChild, mid + 1, r);
}

int x, y;

int main() {
    scanf("%d%d", &N, &Q);
    for (int i = 0; i < N; i++) scanf("%d", h + i);
    build(0, 0, N - 1);
    for (int i = 0; i < Q; i++) {
        scanf("%d%d", &x, &y);
        mn = inf, mx = 0;
        query(0, x - 1, y - 1);
        printf("%d\n", mx - mn);
    }
}

带修改的线段树:区间赋值 + 区间最大值查询

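把某个区间内的所有值替换(赋值)为某个数,并支持查询区间最大值。注意这不是把区间内的值都加上某个数(区间加减见下一节);也不是把每个值更新为它与目标值中的较大者(chmax)。后者不能用本节这种赋值型 lazy 标记实现:若只需查询区间最大值,可改用以 max 合并的 lazy 标记;若还需要区间和,则需要 Segment Tree Beats 等技巧。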

lazy 表示该节点的整个区间已被整体赋值为 mx,但这个赋值还没有下传给子节点。它作为布尔标志使用,代码中以 int 存储(0/1)。

区间赋值也必须配合 lazy 标记,才能保证每次操作的复杂度为 O(log n)。pushdown:在进入子区间之前,把当前节点的赋值(mx 与 lazy 标志)下传一层到两个子节点,并清除当前节点的标志。pushup:用两个子区间的最大值更新父区间。modify 和 query 在分解到子区间之前都需要先 pushdown;modify 返回前还需要调用 pushup。

// Capacity: number of leaf positions the tree can index.
static const int maxn = 10005;

// Segment-tree node over the closed index range [l, r].
struct node {
    int l, r;   // covered range
    int mx;     // maximum value in [l, r]
    int lazy;   // non-zero: whole range was assigned mx; children are stale
    int m() { return (l + r) / 2; }  // left child covers [l, m()]
} seg[4 * maxn];

// Initialise the subtree rooted at `rt` to cover [l, r], with every value 0
// and no pending assignment.
void build(int rt, int l, int r) {
    auto &cur = seg[rt];
    cur.l = l;
    cur.r = r;
    cur.mx = 0;
    cur.lazy = 0;
    if (l != r) {
        const int mid = (l + r) / 2;
        build(2 * rt + 1, l, mid);
        build(2 * rt + 2, mid + 1, r);
    }
}

// Recompute node rt's maximum from its two children.
void pushup(int rt) {
    const int leftMax = seg[2 * rt + 1].mx;
    const int rightMax = seg[2 * rt + 2].mx;
    seg[rt].mx = leftMax < rightMax ? rightMax : leftMax;
}

// If node rt carries a pending assignment, copy its value into both children,
// mark them as assigned, and clear rt's own flag.
void pushdown(int rt) {
    if (seg[rt].lazy == 0) return;
    for (int child = 2 * rt + 1; child <= 2 * rt + 2; ++child) {
        seg[child].mx = seg[rt].mx;
        seg[child].lazy = 1;
    }
    seg[rt].lazy = 0;
}

// Return the maximum over [l, r]; [l, r] must lie inside node rt's range.
int query(int rt, int l, int r) {
    if (l == seg[rt].l && r == seg[rt].r)
        return seg[rt].mx;
    // Children may be stale while rt holds an assignment tag.
    pushdown(rt);
    const int mid = seg[rt].m();
    const int leftChild = 2 * rt + 1;
    const int rightChild = 2 * rt + 2;
    if (r <= mid) return query(leftChild, l, r);
    if (l > mid) return query(rightChild, l, r);
    const int leftMax = query(leftChild, l, mid);
    const int rightMax = query(rightChild, mid + 1, r);
    return leftMax < rightMax ? rightMax : leftMax;
}

// Assign v to every position in [l, r]; [l, r] must lie inside node rt's
// range. Fully covered nodes are tagged lazily instead of being descended.
void modify(int rt, int l, int r, int v) {
    auto &cur = seg[rt];
    if (l == cur.l && r == cur.r) {
        cur.mx = v;    // set to v, not add v
        cur.lazy = 1;  // children are now stale
        return;
    }
    // Make the children current before splitting the update across them.
    pushdown(rt);
    const int mid = cur.m();
    const int leftChild = 2 * rt + 1;
    const int rightChild = 2 * rt + 2;
    if (r <= mid) {
        modify(leftChild, l, r, v);
    } else if (l > mid) {
        modify(rightChild, l, r, v);
    } else {
        modify(leftChild, l, mid, v);
        modify(rightChild, mid + 1, r, v);
    }
    // Rebuild this node's maximum from the updated children.
    pushup(rt);
}
LeetCode 699 Falling Squares
class Solution {
public:
    // Capacity: maximum number of compressed coordinates the tree can index.
    const static int maxn = 10005;

    // Segment-tree node over the closed compressed range [l, r].
    struct node {
        int l, r;   // covered range
        int mx;     // tallest stack height anywhere in [l, r]
        int lazy;   // non-zero: whole range was set to mx; children are stale
        int m() { return (l + r) / 2; }  // left child covers [l, m()]
    } seg[4 * maxn];

    // Reset the subtree rooted at rt to cover [l, r] with height 0 everywhere.
    void build(int rt, int l, int r) {
        auto &cur = seg[rt];
        cur.l = l;
        cur.r = r;
        cur.mx = 0;
        cur.lazy = 0;
        if (l == r) return;
        const int mid = (l + r) / 2;
        build(2 * rt + 1, l, mid);
        build(2 * rt + 2, mid + 1, r);
    }

    // Hand a pending assignment on rt down to both children.
    void pushdown(int rt) {
        if (seg[rt].lazy == 0) return;
        for (int child = 2 * rt + 1; child <= 2 * rt + 2; ++child) {
            seg[child].mx = seg[rt].mx;
            seg[child].lazy = 1;
        }
        seg[rt].lazy = 0;
    }

    // Tallest height over [l, r]; [l, r] must lie inside rt's range.
    int query(int rt, int l, int r) {
        if (l == seg[rt].l && r == seg[rt].r) return seg[rt].mx;
        pushdown(rt);  // children must be current before they are read
        const int mid = seg[rt].m();
        const int leftChild = 2 * rt + 1;
        const int rightChild = 2 * rt + 2;
        if (r <= mid) return query(leftChild, l, r);
        if (l > mid) return query(rightChild, l, r);
        const int leftMax = query(leftChild, l, mid);
        const int rightMax = query(rightChild, mid + 1, r);
        return leftMax < rightMax ? rightMax : leftMax;
    }

    // Set every height in [l, r] to v (assignment, not addition).
    void modify(int rt, int l, int r, int v) {
        auto &cur = seg[rt];
        if (l == cur.l && r == cur.r) {
            cur.mx = v;
            cur.lazy = 1;
            return;
        }
        pushdown(rt);
        const int mid = cur.m();
        const int leftChild = 2 * rt + 1;
        const int rightChild = 2 * rt + 2;
        if (r <= mid) {
            modify(leftChild, l, r, v);
        } else if (l > mid) {
            modify(rightChild, l, r, v);
        } else {
            modify(leftChild, l, mid, v);
            modify(rightChild, mid + 1, r, v);
        }
        const int a = seg[leftChild].mx;
        const int b = seg[rightChild].mx;
        cur.mx = a < b ? b : a;
    }

    // Tetris-style drop: a block of height v lands on the tallest point under
    // [l, r], so the whole range becomes that height plus v.
    void blockmodify(int rt, int l, int r, int v) {
        const int base = query(rt, l, r);
        modify(rt, l, r, base + v);
    }

    // positions[i] = (left edge, side length). Returns the tallest stack
    // height after each square has fallen.
    vector<int> fallingSquares(vector<pair<int, int>>& positions) {
        // Coordinate compression: square i occupies the closed interval
        // [left, left + side - 1]; only these endpoints matter for overlap.
        vector<int> coords;
        coords.reserve(2 * positions.size());
        for (const auto &sq : positions) {
            coords.push_back(sq.first);
            coords.push_back(sq.first + sq.second - 1);
        }
        sort(coords.begin(), coords.end());
        coords.erase(unique(coords.begin(), coords.end()), coords.end());
        const int uN = static_cast<int>(coords.size());
        // Compressed index of a coordinate that is known to be in `coords`.
        auto indexOf = [&coords](int x) {
            return static_cast<int>(lower_bound(coords.begin(), coords.end(), x) - coords.begin());
        };

        build(0, 0, uN);
        vector<int> ans;
        ans.reserve(positions.size());
        for (const auto &sq : positions) {
            blockmodify(0, indexOf(sq.first), indexOf(sq.first + sq.second - 1), sq.second);
            ans.push_back(query(0, 0, uN));
        }
        return ans;
    }
};

带修改的线段树:区间加减 + 区间求和

支持区间加上某个数和区间求和。(本实现只维护区间和;若还要查询区间最值,需要另外维护 max 字段,并在下传增量时一起更新。)

lazy 标记即代码中的 inc 字段(long long 型)。它表示已经加到整个区间上、但还没有累加进本节点 sum、也还没有下传给子节点的增量。

#define _CRT_SECURE_NO_WARNINGS
#include <iostream>
#include <iomanip>
#include <cstring>
#include <algorithm>
#include <string>
#include <queue>
#include <cmath>
#include <vector>
#include <stack>
using namespace std;

#define ll long long

// Capacity: maximum number of array elements.
const int maxn = 100005;
// Initial values, stored 0-based.
ll arr[maxn];
// Array length and number of commands.
int N, Q;

// Segment-tree node over the closed index range [l, r].
//   sum: sum of the range NOT counting this node's own pending `inc`
//        (it does include increments applied to strict sub-ranges).
//   inc: amount added to every element of [l, r] that has not yet been
//        folded into `sum` nor passed to the children.
// The true sum of [l, r] is sum + inc * (r - l + 1), plus whatever
// increments are still pending on ancestors.
struct node {
    int l, r;
    ll sum, inc;
    int mid() const { return (l + r) / 2; }
} tr[maxn<<2];

// Child indices of node `rt` (0-based heap layout). The bodies are
// parenthesized so the macros stay correct inside larger expressions
// such as `lc * 2`.
#define lc (2*rt+1)
#define rc (2*rt+2)

// Recompute node rt's sum from its two children's stored sums.
void pushup(int rt) {
    const int leftChild = 2 * rt + 1;
    const int rightChild = 2 * rt + 2;
    tr[rt].sum = tr[leftChild].sum + tr[rightChild].sum;
}

// Fold node rt's pending increment into its own sum, hand the same increment
// to both children as pending, and clear it on rt.
void pushdown(int rt) {
    auto &cur = tr[rt];
    const auto pending = cur.inc;
    if (pending == 0) return;
    cur.sum += (cur.r - cur.l + 1) * pending;
    tr[2 * rt + 1].inc += pending;
    tr[2 * rt + 2].inc += pending;
    cur.inc = 0;
}

// Build the subtree rooted at rt over arr[l..r] with no pending increments.
void build(int rt, int l, int r) {
    auto &cur = tr[rt];
    cur.l = l;
    cur.r = r;
    cur.inc = 0;
    if (l != r) {
        const int mid = (l + r) / 2;
        build(2 * rt + 1, l, mid);
        build(2 * rt + 2, mid + 1, r);
        pushup(rt);
    } else {
        cur.sum = arr[l];  // leaf: the element itself
    }
}


// Add v to every element of [l, r]; [l, r] must lie inside node rt's range.
// A node whose range equals [l, r] only records v in its pending `inc`. Every
// node passed through on the way down has its `sum` raised by the total
// amount added beneath it. add() never pushes down: pending increments stay
// where they are and are folded in later by query()/pushdown().
void add(int rt, int l, int r, int v) {
    if (tr[rt].l == l && tr[rt].r == r) {
        tr[rt].inc += v;
        return;
    }
    // Widen to long long BEFORE multiplying. Computed in int, (r - l + 1) * v
    // overflows (undefined behaviour) once the product exceeds INT_MAX,
    // e.g. a range of 1e5 elements with v = 3e4.
    tr[rt].sum += 1LL * (r - l + 1) * v;
    int m = tr[rt].mid();
    if (l > m) add(rc, l, r, v);
    else if (r <= m) add(lc, l, r, v);
    else {
        add(lc, l, m, v);
        add(rc, m + 1, r, v);
    }
}


// Return the sum of [l, r]; [l, r] must lie inside node rt's range.
ll query(int rt, int l, int r) {
    if (tr[rt].l == l && tr[rt].r == r)
        return tr[rt].sum + tr[rt].inc * (r - l + 1);  // include own pending inc
    // Push rt's pending increment down so the children's values are current.
    pushdown(rt);
    const int mid = tr[rt].mid();
    if (l > mid) return query(2 * rt + 2, l, r);
    if (r <= mid) return query(2 * rt + 1, l, r);
    const auto leftPart = query(2 * rt + 1, l, mid);
    const auto rightPart = query(2 * rt + 2, mid + 1, r);
    return leftPart + rightPart;
}

char c;
int a, b, d;

int main() {
    cin >> N >> Q;
    for (int i = 0; i < N; i++) cin >> arr[i];
    build(0, 0, N - 1);
    for (int i = 0; i < Q; i++) {
        cin >> c;
        if (c == 'Q') {
            cin >> a >> b;
            cout << query(0, a-1, b-1) << endl;
        }
        else {
            cin >> a >> b >> d;
            add(0, a-1, b-1, d);
        }
    }
}