Segment Tree
Use
\(O(\log N)\) 区间更新、区间查询 (range update and range query in \(O(\log N)\) per operation)。
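Definition
The root covers an interval [a, b]. A node covering [l, r] with l < r has left child \([l, \lfloor (l+r)/2 \rfloor]\) and right child \([\lfloor (l+r)/2 \rfloor + 1, r]\). Leaves are single elements [i, i].
The depth (number of levels) is \(\lceil \log_2(b-a+1) \rceil + 1\).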
Implementation
- linked nodes
// One node of a pointer-linked segment tree.
struct node {
    int L;        // left end of the interval covered by this node
    int R;        // right end of the interval covered by this node
    node* left;   // child covering the left half of [L, R]
    node* right;  // child covering the right half of [L, R]
    int data;     // per-interval value (e.g. sum/min/max; depends on the problem)
};
-
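one-dim array
The tree is not a complete binary tree, but it fits inside a perfect binary tree of the same depth \(\lceil \log_2 n \rceil + 1\) (n = number of elements). Therefore the number of array slots needed is at most
\[
\displaylines{
2 \cdot 2^{\lceil \log_2 n \rceil} - 1 \le 4n - 1
}
\]
because \(2^{\lceil \log_2 n \rceil} < 2n\). So it is safe to allocate an array of 4 * maxn
nodes. With 0-based indices, the children of node i are at \(2i+1\) and \(2i+2\), which replace the node* pointers.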
Operations
-
区间分解
递归从根节点开始分解,找到若干个终止节点。每层最多两个终止节点。
Time Complexity: \(O(\log N)\)
- 区间查询本质就是区间分解
- 区间更新: 区间分解+lazy updating
Examples
-
Balanced Lineup POJ 3264
#define _CRT_SECURE_NO_WARNINGS
#include <cstdio>
#include <iostream>
#include <cstring>
#include <string>
#include <assert.h>
#include <vector>
#include <algorithm>
#include <climits>
#define ll long long
using namespace std;

// Balanced Lineup (POJ 3264): for each query a b (1-based, inclusive) print
// max(arr[a..b]) - min(arr[a..b]).

const int maxn = 50005;
const int inf = 0x7fffffff;
int N, Q;
int a, b;
// Running min / max accumulated by query(); reset before every query.
int MIN = inf, MAX = INT_MIN;

struct node {
    int L, R;    // interval [L, R] covered by this node
    int ma, mi;  // max / min of arr[L..R]
    int mid() { return (L + R) / 2; }
};
node tree[4 * maxn];  // 4 * maxn slots suffice for a 0-based array segment tree
int arr[maxn];

// Build the subtree rooted at `root` over arr[l..r], filling mi/ma bottom-up.
void build(int root, int l, int r) {
    tree[root].L = l;
    tree[root].R = r;
    if (l == r) {
        tree[root].mi = tree[root].ma = arr[l];
        return;
    }
    build(2 * root + 1, l, (l + r) / 2);
    build(2 * root + 2, (l + r) / 2 + 1, r);
    tree[root].mi = min(tree[2 * root + 1].mi, tree[2 * root + 2].mi);
    tree[root].ma = max(tree[2 * root + 1].ma, tree[2 * root + 2].ma);
}

// Fold the min / max of arr[l..r] into the globals MIN / MAX.
// Precondition: [l, r] lies inside tree[root]'s interval.
void query(int root, int l, int r) {
    if (tree[root].L == l && tree[root].R == r) {
        MIN = min(MIN, tree[root].mi);
        MAX = max(MAX, tree[root].ma);
        return;
    }
    int mid = tree[root].mid();
    if (r <= mid)
        query(2 * root + 1, l, r);
    else if (l > mid)
        query(2 * root + 2, l, r);
    else {
        query(2 * root + 1, l, mid);
        query(2 * root + 2, mid + 1, r);
    }
}

int main() {
    scanf("%d%d", &N, &Q);
    for (int i = 0; i < N; i++) {
        scanf("%d", &arr[i]);
    }
    build(0, 0, N - 1);
    for (int i = 0; i < Q; i++) {
        // INT_MIN rather than 0, so the answer is right for non-positive values too.
        MIN = inf;
        MAX = INT_MIN;
        scanf("%d%d", &a, &b);
        query(0, a - 1, b - 1);
        // long long difference cannot overflow. printf avoids flushing stdout
        // on every line, which cout << endl did.
        printf("%lld\n", (ll)MAX - MIN);
    }
    return 0;
}
-
A simple problem with integers POJ 3468
#include <iostream>
#include <cstring>
#include <string>
#include <assert.h>
#include <vector>
#include <algorithm>
#define ll long long
using namespace std;

// A Simple Problem with Integers (POJ 3468):
//   "C a b d" adds d to every element of A[a..b],
//   "Q a b"   prints the sum of A[a..b]  (both 1-based, inclusive).
// Range add + range sum with lazy increments.

const int maxn = 100005;
int N, Q;
int a, b;
ll d;
char c;
struct node {
    int L, R;
    // Sum of [L, R], excluding this node's own pending `inc` and any
    // ancestor's pending `inc`.
    ll sum;
    // Pending amount added to every element of [L, R] and not yet folded into `sum`.
    ll inc;
    int mid() { return (L + R) / 2; }
};
node tree[4 * maxn];

// Build the subtree at `root` over [l, r] with all sums and increments zero.
void build(int root, int l, int r) {
    tree[root].L = l;
    tree[root].R = r;
    tree[root].sum = 0;
    tree[root].inc = 0;
    if (l == r) return;
    build(2 * root + 1, l, (l + r) / 2);
    build(2 * root + 2, (l + r) / 2 + 1, r);
}

// Add v to every element of [l, r]. Precondition: [l, r] lies inside tree[root]'s interval.
void modify(int root, int l, int r, ll v) {
    if (tree[root].L == l && tree[root].R == r) {
        tree[root].inc += v;  // node fully covered: record lazily
        return;
    }
    tree[root].sum += (r - l + 1) * v;  // only the overlapped part changes this node
    int mid = tree[root].mid();
    if (r <= mid)
        modify(2 * root + 1, l, r, v);
    else if (l > mid)
        modify(2 * root + 2, l, r, v);
    else {
        modify(2 * root + 1, l, mid, v);
        modify(2 * root + 2, mid + 1, r, v);
    }
}

// Return the sum of [l, r]. Precondition: [l, r] lies inside tree[root]'s interval.
ll query(int root, int l, int r) {
    if (tree[root].L == l && tree[root].R == r)
        return tree[root].sum + tree[root].inc * (r - l + 1);
    // Push this node's pending increment down before descending.
    if (tree[root].inc) {
        tree[root].sum += (tree[root].R - tree[root].L + 1) * tree[root].inc;
        tree[2 * root + 1].inc += tree[root].inc;
        tree[2 * root + 2].inc += tree[root].inc;
        tree[root].inc = 0;
    }
    int mid = tree[root].mid();
    if (r <= mid)
        return query(2 * root + 1, l, r);
    else if (l > mid)
        return query(2 * root + 2, l, r);
    else
        return query(2 * root + 1, l, mid) + query(2 * root + 2, mid + 1, r);
}

int main() {
    // Up to ~2*10^5 reads/writes: unsync and untie the streams, and do not
    // flush after every answer (endl did).
    ios::sync_with_stdio(false);
    cin.tie(0);
    cin >> N >> Q;
    build(0, 0, N - 1);
    for (int i = 0; i < N; i++) {
        cin >> d;
        modify(0, i, i, d);
    }
    for (int i = 0; i < Q; i++) {
        cin >> c;
        if (c == 'Q') {
            cin >> a >> b;
            cout << query(0, a - 1, b - 1) << '\n';
        } else {
            cin >> a >> b >> d;
            modify(0, a - 1, b - 1, d);
        }
    }
    return 0;
}
-
Lost Cows
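- 倒序更新,查找 VIS:从最后一头牛往前处理,第 i 头牛的编号是尚未使用的编号中第 arr[i]+1 小的 (process the cows in reverse; each tree node's len counts the still-unused brands in its interval).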
#include <iostream>
#include <cstring>
#include <algorithm>
using namespace std;

// Lost Cows (POJ 2182), segment-tree variant.
// Each node stores `len`, the number of still-unassigned positions in [L, R].
// Cows are handled from last to first: cow i receives the (arr[i]+1)-th
// smallest unassigned position, which is then removed from the tree.

const int maxn = 8005;

struct node {
    int L, R, len;
} tree[maxn << 2];

// Build the subtree at index rt over [l, r]; every position starts unassigned.
void build(int rt, int l, int r) {
    node& cur = tree[rt];
    cur.L = l;
    cur.R = r;
    if (l != r) {
        const int half = (l + r) / 2;
        const int lhs = 2 * rt + 1;
        const int rhs = 2 * rt + 2;
        build(lhs, l, half);
        build(rhs, half + 1, r);
        cur.len = tree[lhs].len + tree[rhs].len;
    } else {
        cur.len = 1;
    }
}

// Remove and return the k-th (0-based) unassigned position below index rt.
int query(int rt, int k) {
    int cur = rt;
    while (true) {
        --tree[cur].len;  // the chosen position lies somewhere in this subtree
        if (tree[cur].L == tree[cur].R) return tree[cur].L;
        const int lhs = 2 * cur + 1;
        if (k < tree[lhs].len) {
            cur = lhs;
        } else {
            k -= tree[lhs].len;
            cur = lhs + 1;
        }
    }
}

int N;
int arr[maxn], ans[maxn];

int main() {
    cin >> N;
    build(0, 0, N - 1);
    fill(arr, arr + maxn, 0);  // the first cow has no predecessors: arr[0] = 0
    for (int i = 1; i < N; ++i) cin >> arr[i];
    for (int i = N - 1; i >= 0; --i) ans[i] = 1 + query(0, arr[i]);
    for (int i = 0; i < N; ++i) cout << ans[i] << endl;
}
- 二分查找的BIT
#include <iostream>
#include <cstring>
#include <algorithm>
using namespace std;

// Lost Cows (POJ 2182), Fenwick tree + binary search variant.
// arr[i] (i = 2..N) = number of earlier cows whose brand is smaller than cow i's.
// Working backwards, cow i's brand is the (arr[i]+1)-th smallest brand not yet
// assigned to a later cow.

const int maxn = 8005;
int arr[maxn], ans[maxn];
int bit[maxn];  // Fenwick tree over brands 1..N; counts brands already assigned
int N;

int lowbit(int x) { return x & (-x); }

// Add v at brand i (1-based).
void modify(int i, int v) {
    for (; i <= N; i += lowbit(i)) bit[i] += v;
}

// Number of assigned brands among 1..i.
int getsum(int i) {
    int res = 0;
    for (; i > 0; i -= lowbit(i)) res += bit[i];
    return res;
}

int main() {
    cin >> N;
    for (int i = 2; i <= N; i++) cin >> arr[i];  // arr[1] stays 0 (global storage)
    for (int i = N; i >= 1; i--) {
        // With s = getsum(m-1), the number of free brands below m is (m-1) - s.
        // The brands m with exactly arr[i] free brands below them form a run
        // ending at the wanted free brand, so search for the LARGEST such m.
        int l = 1, r = N;
        while (l < r) {
            int m = (l + r + 1) / 2;  // upper mid, so that `l = m` always makes progress
            int s = getsum(m - 1);
            if (s + arr[i] == m - 1)
                l = m;
            else if (s + arr[i] > m - 1)
                l = m + 1;  // too few free brands below m
            else
                r = m - 1;  // too many free brands below m
        }
        ans[i] = l;
        modify(l, 1);  // brand l is now taken
    }
    for (int i = 1; i <= N; i++) cout << ans[i] << endl;
}
-
Mayor's Posters
动态构造线段树,防止超出内存限制。
#define _CRT_SECURE_NO_WARNINGS
#include <cstdio>
#include <iostream>
#include <cstring>
#include <algorithm>
#define lc(rt) rt->ll
#define rc(rt) rt->rr
using namespace std;

// Mayor's Posters (POJ 2528) with a dynamically built segment tree over wall
// cells 0..ml-1. Children are only allocated when a node must be split.
// Posters are processed from last to first. A poster is visible iff it covers
// at least one cell not already covered by a later poster.

const int maxn = 10005;  // max number of posters
// Node pool size. For one poster, at most 2 nodes per level are split, and each
// split allocates at most 2 children. With walls of length <= 10^7 there are at
// most 24 splitting levels, so each poster needs <= 96 new nodes (100 for slack).
// The previous pool of (10^7 << 2) nodes was ~1.28 GB of static storage.
const int maxnodes = maxn * 100 + 5;
int n;
int a[maxn], b[maxn];
bool flag = false;  // set by occupy() when at least one free cell gets covered
struct node {
    int L, R, occ;  // interval [L, R]; occ == 1 iff every cell in it is covered
    node *ll, *rr;  // children; NULL until first needed
} t[maxnodes];
int cnt = 0;  // index of the last node taken from the pool t[]

// Initialise *rt as an uncovered, childless node for [l, r] (does not recurse).
void build(node* rt, int l, int r) {
    rt->L = l;
    rt->R = r;
    rt->occ = 0;
    lc(rt) = NULL;
    rc(rt) = NULL;
}

// Cover cells [l, r] (must lie inside rt's interval); sets `flag` if any was free.
void occupy(node* rt, int l, int r) {
    // Already fully covered: nothing below can become newly covered, so stop
    // here instead of allocating children.
    if (rt->occ) return;
    if (rt->L == l && rt->R == r) {
        rt->occ = 1;
        flag = true;
        return;
    }
    int mid = (rt->L + rt->R) / 2;
    // Lazily create the children. No push-down is needed because rt->occ == 0 here.
    if (lc(rt) == NULL) {
        cnt++;
        build(t + cnt, rt->L, mid);
        lc(rt) = t + cnt;
    }
    if (rc(rt) == NULL) {
        cnt++;
        build(t + cnt, mid + 1, rt->R);
        rc(rt) = t + cnt;
    }
    if (r <= mid)
        occupy(lc(rt), l, r);
    else if (l > mid)
        occupy(rc(rt), l, r);
    else {
        occupy(lc(rt), l, mid);
        occupy(rc(rt), mid + 1, r);
    }
    if (lc(rt)->occ && rc(rt)->occ) rt->occ = 1;
}

int main() {
    int cas;
    scanf("%d", &cas);
    while (cas--) {
        scanf("%d", &n);
        int ml = 0;
        for (int i = 0; i < n; i++) {
            scanf("%d%d", &a[i], &b[i]);
            ml = max(ml, b[i]);
        }
        cnt = 0;  // reuse the pool for every test case
        build(t, 0, ml - 1);
        int res = 0;
        for (int i = n - 1; i >= 0; i--) {
            flag = false;
            occupy(t, a[i] - 1, b[i] - 1);
            if (flag) res++;
        }
        cout << res << endl;
    }
}
Discretization
#define _CRT_SECURE_NO_WARNINGS
#include <iostream>
#include <map>
#include <iomanip>
#include <cstring>
#include <algorithm>
#include <string>
#include <queue>
#include <cmath>
#include <vector>
#include <stack>
using namespace std;
// 9:36 suspicious bugs
// 9:47 apple tree (add in[x], not x)
// 10:02 trie (match forgotten, and build before match!)
// 10:20 popular cows
// 10:30 currency exchange (so, BF needs no inf check?)
// 10:40 ------

// Mayor's Posters (POJ 2528) with coordinate compression.
// Posters are added from last to first. A poster counts as visible iff add()
// covers at least one not-yet-covered compressed cell.
//
// Pitfall fixed here: mapping the sorted endpoints to consecutive indices
// erases the cells that lie strictly between two non-adjacent endpoints.
// E.g. [1,10], [1,4], [6,10] would report 2 instead of 3, because cell 5
// disappears. One extra point is inserted into every gap wider than one cell.

const int maxn = 10005;
int p[maxn][2];
int N;
// Leaves: <= 2N endpoints + < 2N gap points < 4N. An array segment tree needs
// < 4 * leaves slots, so 16 * maxn is enough.
struct node {
    int l, r;
    bool occ;  // true iff every compressed cell in [l, r] is covered
    int mid() { return (l + r) / 2; }
} tr[maxn * 16];
#define lc 2*rt+1
#define rc 2*rt+2

// A node is fully covered iff both children are.
void pushup(int rt) {
    tr[rt].occ = tr[lc].occ & tr[rc].occ;
}

// Propagate "fully covered" to the children before descending.
void pushdown(int rt) {
    if (tr[rt].occ) {
        tr[lc].occ = 1;
        tr[rc].occ = 1;
    }
}

// Build an all-uncovered tree at index rt over compressed cells [l, r].
void build(int rt, int l, int r) {
    tr[rt].l = l;
    tr[rt].r = r;
    tr[rt].occ = 0;
    if (l == r) return;
    int m = tr[rt].mid();
    build(lc, l, m);
    build(rc, m + 1, r);
}

// Cover [l, r]. Returns true iff at least one cell in it was previously uncovered.
bool add(int rt, int l, int r) {
    if (tr[rt].l == l && tr[rt].r == r) {
        if (tr[rt].occ == 0) {
            tr[rt].occ = 1;
            return true;
        } else
            return false;
    }
    pushdown(rt);
    int m = tr[rt].mid();
    bool flag;
    if (l > m)
        flag = add(rc, l, r);
    else if (r <= m)
        flag = add(lc, l, r);
    else
        flag = add(lc, l, m) | add(rc, m + 1, r);  // non-short-circuit: both halves must be covered
    pushup(rt);
    return flag;
}

int main() {
    int cas;
    cin >> cas;
    while (cas--) {
        cin >> N;
        vector<int> xs;
        xs.reserve(4 * N);
        for (int i = 0; i < N; i++) {
            cin >> p[i][0] >> p[i][1];
            xs.push_back(p[i][0]);
            xs.push_back(p[i][1]);
        }
        sort(xs.begin(), xs.end());
        xs.erase(unique(xs.begin(), xs.end()), xs.end());
        // Keep a representative for every non-empty gap between consecutive endpoints.
        int distinct = (int)xs.size();
        for (int i = 1; i < distinct; i++)
            if (xs[i] - xs[i - 1] > 1) xs.push_back(xs[i - 1] + 1);
        sort(xs.begin(), xs.end());
        int uN = (int)xs.size();
        build(0, 0, uN - 1);
        int ans = 0;
        for (int i = N - 1; i >= 0; i--) {
            int lo = (int)(lower_bound(xs.begin(), xs.end(), p[i][0]) - xs.begin());
            int hi = (int)(lower_bound(xs.begin(), xs.end(), p[i][1]) - xs.begin());
            if (add(0, lo, hi)) ans++;
        }
        cout << ans << endl;
    }
}