Segment Tree
静态线段树
仅支持区间最值查询,不支持修改。(静态区间求和用前缀和(cumsum)即可,无需线段树)
(然而对于静态区间最值(RMQ)问题,ST表(Sparse Table)更好写)
// balanced lineup
#define _CRT_SECURE_NO_WARNINGS
#include <iostream>
#include <cstdio>
#include <algorithm>
#include <cstring>
// Shorthand for a pair of ints (not referenced by the code visible in this snippet).
#define P pair<int,int>
// Array indices of the left / right child of node `rt` in the 0-based heap layout.
// The macros expand textually, so they only work where a variable named `rt` is in scope.
#define lc rt*2+1
#define rc rt*2+2
using namespace std;
// Capacity of the input array h; the tree array is declared with 4 * maxn nodes.
const int maxn = 50005;
// Largest 32-bit signed value; main() uses it as the starting value of the running minimum.
const int inf = 0x7fffffff;
// Input values, filled 0-based by main().
int h[maxn];
// N: number of values read into h; Q: number of queries.
int N, Q;
// One segment-tree node: the closed index range [l, r] it covers, together
// with the maximum and minimum of h over that range.
struct node {
    int l, r;    // inclusive bounds of the covered range
    int mx, mn;  // maximum / minimum of h[l..r]

    // Split point: build() gives [l, m()] to the left child and [m() + 1, r] to the right.
    int m() {
        const int total = l + r;
        return total / 2;
    }
} tr[maxn*4];
void build(int rt, int l, int r) {
tr[rt].l = l;
tr[rt].r = r;
if (l == r) {
tr[rt].mx = tr[rt].mn = h[l];
return;
}
int m = (l + r) / 2;
build(lc, l, m);
build(rc, m + 1, r);
tr[rt].mx = max(tr[lc].mx, tr[rc].mx);
tr[rt].mn = min(tr[lc].mn, tr[rc].mn);
}
int mx, mn;
void query(int rt, int l, int r) {
if (l == tr[rt].l && r == tr[rt].r) {
mx = max(mx, tr[rt].mx);
mn = min(mn, tr[rt].mn);
return;
}
int m = tr[rt].m();
if (r <= m) query(lc, l, r);
else if (l > m) query(rc, l, r);
else {
query(lc, l, m);
query(rc, m + 1, r);
}
}
int x, y;
int main() {
scanf("%d%d", &N, &Q);
for (int i = 0; i < N; i++) scanf("%d", h + i);
build(0, 0, N - 1);
for (int i = 0; i < Q; i++) {
scanf("%d%d", &x, &y);
mn = inf, mx = 0;
query(0, x - 1, y - 1);
printf("%d\n", mx - mn);
}
}
动态线段树:区间最值替换
把某个区间内的所有元素替换(赋值)为某个数,并查询区间最大值。不是给区间加上某个数(见下文“动态线段树:区间加减”),也不是取当前值与目标值的较大值(用这里的bool型lazy标记无法实现)。
lazy表示该节点的区间是否被整体赋值、且尚未下传给子节点,只取0/1(语义上是bool,代码中用int存储)。
区间最值也应该使用lazy更新,才能保证复杂度。pushdown操作即分解区间,把lazy标志(叶节点标志)下移一层。pushup即用子区间反推父区间的值。modify和query在分解子区间时都需要pushdown。modify最后还需要调用pushup。
// Upper bound on the number of leaves; the node array below reserves four
// slots per leaf.
static const int maxn = 10005;

// Node of the range-assign / range-max segment tree.
struct node {
    int l;     // first index covered (inclusive)
    int r;     // last index covered (inclusive)
    int mx;    // maximum over [l, r]
    int lazy;  // non-zero: [l, r] was assigned mx as a whole and the children
               // have not been updated yet

    // Index at which the range is split between the two children.
    int m() {
        const int total = l + r;
        return total / 2;
    }
} seg[4 * maxn];
void build(int rt, int l, int r) {
seg[rt].l = l;
seg[rt].r = r;
seg[rt].mx = 0;
seg[rt].lazy = 0;
if (l == r) return;
build(2 * rt + 1, l, (l + r) / 2);
build(2 * rt + 2, (l + r) / 2 + 1, r);
}
// Recomputes the maximum of node rt from the maxima of its two children.
void pushup(int rt){
    const int leftMax = seg[2 * rt + 1].mx;
    const int rightMax = seg[2 * rt + 2].mx;
    seg[rt].mx = leftMax < rightMax ? rightMax : leftMax;
}
void pushdown(int rt) {
if (seg[rt].lazy) {
//cout<<"push down "<<rt<<" "<< seg[rt].l << "-" << seg[rt].r << endl;
seg[2 * rt + 1].mx = seg[rt].mx;
seg[2 * rt + 2].mx = seg[rt].mx;
seg[2 * rt + 2].lazy = seg[2 * rt + 1].lazy = 1;
seg[rt].lazy = 0;
}
}
// Returns the maximum over [l, r] (inclusive), searching from node rt.
// Assumes l <= r and that [l, r] lies inside the range covered by rt; the
// code does not check this.
int query(int rt, int l, int r) {
    if (seg[rt].l == l && seg[rt].r == r) {
        // Exact cover: the stored maximum is the answer.
        return seg[rt].mx;
    }
    // The children are about to be read, so any pending assignment must reach them first.
    pushdown(rt);
    const int mid = seg[rt].m();
    const int left = 2 * rt + 1;
    const int right = 2 * rt + 2;
    if (r <= mid) return query(left, l, r);
    if (l > mid) return query(right, l, r);
    // The range straddles the split point: take the larger of the two halves.
    const int leftMax = query(left, l, mid);
    const int rightMax = query(right, mid + 1, r);
    return leftMax < rightMax ? rightMax : leftMax;
}
void modify(int rt, int l, int r, int v) {
if (l == seg[rt].l && r == seg[rt].r) {
seg[rt].mx = v; // set to v, not add v.
seg[rt].lazy = 1;
return;
}
// push down
pushdown(rt);
// interval decomp
int m = seg[rt].m();
if (r <= m) modify(2 * rt + 1, l, r, v);
else if (l > m) modify(2 * rt + 2, l, r, v);
else {
modify(2 * rt + 1, l, m, v);
modify(2 * rt + 2, m + 1, r, v);
}
// push up
pushup(rt);
}
LeetCode 699 Falling Squares
// LeetCode 699 "Falling Squares".
//
// Each square (left, side) occupies the integer positions
// [left, left + side - 1]. Those endpoints are coordinate-compressed and a
// range-assign / range-max segment tree over the compressed indices tracks
// the height of the stack at every position.
class Solution {
public:
    // Number of leaves the node storage is pre-sized for.
    const static int maxn = 10005;

    // Segment-tree node covering the closed range [l, r] of compressed indices.
    struct node {
        int l, r;
        int mx;    // highest stack top within [l, r]
        int lazy;  // non-zero: [l, r] was assigned mx as a whole and the
                   // children have not been updated yet
        int m() { return (l + r) / 2; }
    };

    // Node storage in heap layout (children of i are 2i+1 and 2i+2).
    // Pre-sized to 4 * maxn nodes; fallingSquares() enlarges it when the
    // input has more distinct coordinates than that allows for.
    std::vector<node> seg = std::vector<node>(4 * maxn);

    // Initialises the subtree rooted at rt to cover [l, r] with height 0
    // everywhere and nothing pending. seg must already be large enough.
    void build(int rt, int l, int r) {
        seg[rt].l = l;
        seg[rt].r = r;
        seg[rt].mx = 0;
        seg[rt].lazy = 0;
        if (l == r) return;
        build(2 * rt + 1, l, (l + r) / 2);
        build(2 * rt + 2, (l + r) / 2 + 1, r);
    }

    // Moves a pending whole-range assignment from rt down to its children.
    void pushdown(int rt) {
        if (seg[rt].lazy) {
            seg[2 * rt + 1].mx = seg[rt].mx;
            seg[2 * rt + 2].mx = seg[rt].mx;
            seg[2 * rt + 2].lazy = seg[2 * rt + 1].lazy = 1;
            seg[rt].lazy = 0;
        }
    }

    // Returns the maximum height over [l, r] (inclusive). Assumes l <= r and
    // that [l, r] lies inside the range covered by rt.
    int query(int rt, int l, int r) {
        if (l == seg[rt].l && r == seg[rt].r)
            return seg[rt].mx;
        pushdown(rt);
        int m = seg[rt].m();
        if (r <= m) return query(2 * rt + 1, l, r);
        else if (l > m) return query(2 * rt + 2, l, r);
        else return std::max(query(2 * rt + 1, l, m), query(2 * rt + 2, m + 1, r));
    }

    // Assigns height v to every position in [l, r] (inclusive). Same
    // preconditions as query().
    void modify(int rt, int l, int r, int v) {
        if (l == seg[rt].l && r == seg[rt].r) {
            seg[rt].mx = v;
            seg[rt].lazy = 1;
            return;
        }
        pushdown(rt);
        int m = seg[rt].m();
        if (r <= m) modify(2 * rt + 1, l, r, v);
        else if (l > m) modify(2 * rt + 2, l, r, v);
        else {
            modify(2 * rt + 1, l, m, v);
            modify(2 * rt + 2, m + 1, r, v);
        }
        // Refresh this node's maximum from the updated children.
        seg[rt].mx = std::max(seg[2 * rt + 1].mx, seg[2 * rt + 2].mx);
    }

    // Tetris-style stacking update: a block of height v lands on top of the
    // highest point currently inside [l, r], and the whole range takes the
    // resulting height.
    void blockmodify(int rt, int l, int r, int v) {
        modify(rt, l, r, v + query(rt, l, r));
    }

    // For every square in drop order, returns the height of the tallest stack
    // after that square has landed. positions[i] is (left edge, side length);
    // side lengths are assumed to be at least 1.
    std::vector<int> fallingSquares(std::vector<std::pair<int, int>>& positions) {
        std::vector<int> ans;
        if (positions.empty()) return ans;

        // Coordinate compression: only the first and last covered position of
        // each square matter, because two squares overlap exactly when they
        // share one of those endpoints.
        std::vector<int> xs;
        xs.reserve(2 * positions.size());
        for (const auto& sq : positions) {
            xs.push_back(sq.first);
            xs.push_back(sq.first + sq.second - 1);
        }
        std::sort(xs.begin(), xs.end());
        xs.erase(std::unique(xs.begin(), xs.end()), xs.end());
        const int last = static_cast<int>(xs.size()) - 1;

        // Four nodes per leaf are enough for this heap layout; grow the
        // storage instead of writing past a fixed-size array.
        if (seg.size() < 4 * xs.size()) seg.resize(4 * xs.size());
        build(0, 0, last);

        ans.reserve(positions.size());
        for (const auto& sq : positions) {
            const int lo = indexOf(xs, sq.first);
            const int hi = indexOf(xs, sq.first + sq.second - 1);
            blockmodify(0, lo, hi, sq.second);
            ans.push_back(query(0, 0, last));
        }
        return ans;
    }

private:
    // Rank of x inside the sorted, duplicate-free vector xs (x must be present).
    static int indexOf(const std::vector<int>& xs, int x) {
        return static_cast<int>(std::lower_bound(xs.begin(), xs.end(), x) - xs.begin());
    }
};
动态线段树:区间加减
支持区间加减修改和区间求和查询。(下面的代码只维护了sum;若要查询区间最值,需在节点中另外维护最值字段。)
懒标记是代码中的inc字段(long long型),表示整个区间内每个元素还没有加到sum上的增量。
#define _CRT_SECURE_NO_WARNINGS
#include <iostream>
#include <iomanip>
#include <cstring>
#include <algorithm>
#include <string>
#include <queue>
#include <cmath>
#include <vector>
#include <stack>
using namespace std;
// Shorthand for long long, the type used for element values and sums.
#define ll long long
// Capacity of the input array; the node array is declared with maxn << 2 entries.
const int maxn = 100005;
// Input values, filled 0-based by main().
ll arr[maxn];
// N: number of values read into arr; Q: number of operations.
int N, Q;
// Node of the range-add / range-sum segment tree.
struct node {
    int l;    // first index covered (inclusive)
    int r;    // last index covered (inclusive)
    ll sum;   // sum over [l, r], excluding this node's own pending inc
    ll inc;   // amount still to be added to every element of [l, r]

    // Index at which the range is split between the two children.
    int mid() {
        const int total = l + r;
        return total / 2;
    }
} tr[maxn<<2];
// Array indices of the left / right child of node `rt` in the 0-based heap layout.
// The macros expand textually, so they only work where a variable named `rt` is in scope.
#define lc 2*rt+1
#define rc 2*rt+2
void pushup(int rt) {
tr[rt].sum = tr[lc].sum + tr[rc].sum;
}
void pushdown(int rt) {
if (tr[rt].inc) {
tr[rt].sum += (tr[rt].r - tr[rt].l + 1)*tr[rt].inc;
tr[lc].inc += tr[rt].inc;
tr[rc].inc += tr[rt].inc;
tr[rt].inc = 0;
}
}
void build(int rt, int l, int r) {
tr[rt].l = l;
tr[rt].r = r;
tr[rt].inc = 0;
if (l == r) {
tr[rt].sum = arr[l];
return;
}
int m = (l + r) / 2;
build(lc, l, m);
build(rc, m + 1, r);
pushup(rt);
}
// Adds v to every element in [l, r] (inclusive), starting at node rt.
// Assumes l <= r and that [l, r] lies inside the range covered by rt; the
// code does not check this.
void add(int rt, int l, int r, int v) {
    if (tr[rt].l == l && tr[rt].r == r) {
        // Exact cover: remember the increment here instead of touching the children.
        tr[rt].inc += v;
        return;
    }
    // Only part of this node is affected, so account for the change in its
    // sum right away. Widen before multiplying: (r - l + 1) * v can exceed
    // the range of int even though the result is stored in a long long.
    tr[rt].sum += static_cast<long long>(r - l + 1) * v;
    const int m = tr[rt].mid();
    const int left = 2*rt+1;
    const int right = 2*rt+2;
    if (l > m) add(right, l, r, v);
    else if (r <= m) add(left, l, r, v);
    else {
        add(left, l, m, v);
        add(right, m + 1, r, v);
    }
}
// Returns the sum of the elements in [l, r] (inclusive), searching from node
// rt. Assumes l <= r and that [l, r] lies inside the range covered by rt.
ll query(int rt, int l, int r) {
    node& cur = tr[rt];
    if (cur.l == l && cur.r == r) {
        // Exact cover: stored sum plus the increment that is still pending here.
        const ll pending = cur.inc * (r - l + 1);
        return cur.sum + pending;
    }
    // The children are about to be read, so hand the pending increment down first.
    pushdown(rt);
    const int m = cur.mid();
    const int left = 2*rt+1;
    const int right = 2*rt+2;
    if (l > m) return query(right, l, r);
    if (r <= m) return query(left, l, r);
    const ll leftPart = query(left, l, m);
    const ll rightPart = query(right, m + 1, r);
    return leftPart + rightPart;
}
char c;
int a, b, d;
int main() {
cin >> N >> Q;
for (int i = 0; i < N; i++) cin >> arr[i];
build(0, 0, N - 1);
for (int i = 0; i < Q; i++) {
cin >> c;
if (c == 'Q') {
cin >> a >> b;
cout << query(0, a-1, b-1) << endl;
}
else {
cin >> a >> b >> d;
add(0, a-1, b-1, d);
}
}
}