Skip to content

Segment Tree

use

\(O(\log N)\) 区间更新,区间查询。

definition

The root represents an interval [a, b]; its left child represents [a, (a+b)/2] and its right child represents [(a+b)/2+1, b]. Each leaf represents a single element.

The depth is \(\lceil \log_2(b-a+1) \rceil + 1\).

Implement

  • linked nodes
// Pointer-linked segment-tree node for the closed interval [L, R].
struct node {
    int L;        // left end of the covered interval
    int R;        // right end of the covered interval
    node *left;   // link to the child for the left half
    node *right;  // link to the child for the right half
    int data;     // value stored for this interval
};
  • one-dim array

    Not a complete tree, but nearly.

\[ \displaylines{ 2 \cdot 2^{\lceil \log_2 n \rceil} - 1 \le 4n - 1 } \]

So allocating an array of [4 * maxn] nodes is always enough, and the children of node \(i\) can then be addressed as \(2i+1\) and \(2i+2\) instead of through node* pointers.

Operations

  • 区间分解

    递归从根节点开始分解,找到若干个终止节点。每层最多两个终止节点。

    Time Complexity: \(O(\log N)\)

  • 区间查询本质就是区间分解
  • 区间更新: 区间分解+lazy updating

Examples

  • Balanced Lineup POJ 3264

    #define _CRT_SECURE_NO_WARNINGS
    #include <cstdio>
    #include <iostream>
    #include <cstring>
    #include <string>
    #include <assert.h>
    #include <vector>
    #include <algorithm>
    #define ll long long
    using namespace std;
    
    // Capacity bound for the input; arr holds maxn values and tree 4 * maxn nodes.
    const int maxn = 50005;
    // Largest 32-bit signed value; starting point for the running minimum.
    const int inf = 0x7fffffff;
    int N, Q;     // element count and query count, read first in main
    int a, b, d;  // a, b: 1-based query bounds read in main; d is not used in this program
    char c;       // not used in this program
    
    // Running extremes filled in by query(); main resets both before every query.
    int MIN = inf, MAX = 0;
    
    // Segment-tree node for the closed index range [L, R]; ma and mi hold the
    // largest and smallest array value found inside that range.
    struct node {
        int L;
        int R;
        int ma;
        int mi;

        // Last index that belongs to the left child.
        int mid() {
            const int ends = L + R;
            return ends / 2;
        }
    };
    
    // Node storage; build() puts the children of node i at 2*i+1 and 2*i+2.
    node tree[4 * maxn];
    // Input values, stored 0-based by main.
    int arr[maxn];
    
    void build(int root, int l, int r) {
      tree[root].L = l;
      tree[root].R = r;
      if (l == r) {
          tree[root].mi = tree[root].ma = arr[l];
          return;
      }
      build(2 * root + 1, l, (l + r) / 2);
      build(2 * root + 2, (l + r) / 2 + 1, r);
      tree[root].mi = min(tree[2 * root + 1].mi, tree[2 * root + 2].mi);
      tree[root].ma = max(tree[2 * root + 1].ma, tree[2 * root + 2].ma);
    }
    
    void query(int root, int l, int r) {
      //cout << "query " << root << " " << l << "-" << r << endl;
      if (tree[root].L == l && tree[root].R == r) {
          MIN = min(MIN, tree[root].mi);
          MAX = max(MAX, tree[root].ma);
          return;
      }
      int mid = tree[root].mid();
      if (r <= mid)
          query(2 * root + 1, l, r);
      else if (l > mid)
          query(2 * root + 2, l, r);
      else {
          query(2 * root + 1, l, mid);
          query(2 * root + 2, mid + 1, r);
      }
    }
    
    int main() {
      scanf("%d%d", &N, &Q);
      for (int i = 0; i < N; i++){
          scanf("%d", &arr[i]);
      }
      build(0, 0, N - 1);
      for (int i = 0; i < Q; i++) {
          MIN = inf, MAX = 0;
          scanf("%d%d", &a, &b);
          query(0, a - 1, b - 1);
          cout << MAX - MIN << endl;
      }
    }
    
  • A simple problem with integers POJ 3468

    #include <iostream>
    #include <cstring>
    #include <string>
    #include <assert.h>
    #include <vector>
    #include <algorithm>
    #define ll long long
    
    using namespace std;
    
    // Capacity bound for the number of elements; the tree array is sized 4 * maxn.
    const int maxn = 100005;
    int N, Q;  // element count and operation count, read first in main
    int a, b;  // 1-based inclusive bounds of the current operation
    ll d;      // value just read: an initial element, or the amount to add to a range
    char c;    // operation code; main treats 'Q' as a query and anything else as an update
    
    // Segment-tree node for the closed index range [L, R].
    //   sum - total of the range, not counting this node's own pending inc
    //   inc - amount that still has to be added to every element of the range
    struct node {
        int L;
        int R;
        ll sum;
        ll inc;

        // Last index that belongs to the left child.
        int mid() {
            const int ends = L + R;
            return ends / 2;
        }
    };
    
    // Node storage; the children of node i live at 2*i+1 and 2*i+2.
    node tree[4 * maxn];
    
    void build(int root, int l, int r) {
      tree[root].L = l;
      tree[root].R = r;
      tree[root].sum = 0;
      tree[root].inc = 0;
      if (l == r) return;
      build(2 * root + 1, l, (l + r) / 2);
      build(2 * root + 2, (l + r) / 2 + 1, r);
    }
    
    void modify(int root, int l, int r, ll v) {
      if (tree[root].L == l && tree[root].R == r) {
          tree[root].inc += v;
          return;
      }
      tree[root].sum += (r - l + 1) * v;
      int mid = tree[root].mid();
      if (r <= mid)
          modify(2 * root + 1, l, r, v);
      else if (l > mid)
          modify(2 * root + 2, l, r, v);
      else {
          modify(2 * root + 1, l, mid, v);
          modify(2 * root + 2, mid + 1, r, v);
      }
    }
    
    // Returns the sum of the elements in [l, r]; [l, r] has to lie inside the
    // range of node `root`. A pending increment on a node that is only partly
    // needed is folded into its sum and handed to both children before descending.
    ll query(int root, int l, int r) {
        node &cur = tree[root];
        if (cur.L == l && cur.R == r) {
            return cur.sum + cur.inc * (r - l + 1);
        }
        const int lhs = 2 * root + 1;
        const int rhs = 2 * root + 2;
        if (cur.inc) {
            const ll pending = cur.inc;
            cur.sum += (cur.R - cur.L + 1) * pending;
            tree[lhs].inc += pending;
            tree[rhs].inc += pending;
            cur.inc = 0;
        }
        const int split = cur.mid();
        if (r <= split) {
            return query(lhs, l, r);
        }
        if (l > split) {
            return query(rhs, l, r);
        }
        return query(lhs, l, split) + query(rhs, split + 1, r);
    }
    
    int main() {
      cin >> N >> Q;
      build(0, 0, N - 1);
      for (int i = 0; i < N; i++){
          cin >> d;
          modify(0, i, i, d);
      }
      for (int i = 0; i < Q; i++) {
          cin >> c;
          if (c == 'Q') {
              cin >> a >> b;
              cout << query(0, a - 1, b - 1) << endl;
          }
          else {
              cin >> a >> b >> d;
              modify(0, a - 1, b - 1, d);
          }
      }
    }
    
  • Lost Cows

    • 倒序更新,查找VIS。
    #include <iostream>
    #include <cstring>
    #include <algorithm>
    using namespace std;
    // Capacity bound for the input size.
    const int maxn = 8005;
    
    // Node for the closed index range [L, R]; len counts the positions of that
    // range that have not been handed out yet.
    struct node {
        int L;
        int R;
        int len;
    };
    
    // Node storage; the children of node i live at 2*i+1 and 2*i+2.
    node tree[maxn * 4];
    
    void build(int rt, int l, int r) {
      tree[rt].L = l;
      tree[rt].R = r;
      if (l == r) {
          tree[rt].len = 1;
          return;
      }
      int m = (l + r) / 2;
      build(2 * rt + 1, l, m);
      build(2 * rt + 2, m + 1, r);
      tree[rt].len = tree[2 * rt + 1].len + tree[2 * rt + 2].len;
    }
    
    // Walks down from node rt to the leaf holding the k-th (0-based) position
    // that is still counted in len, decrementing len on every node along the
    // way, and returns that leaf's index.
    int query(int rt, int k) {
        int at = rt;
        int rank = k;
        for (;;) {
            tree[at].len--;
            if (tree[at].L == tree[at].R) {
                return tree[at].L;
            }
            const int lhs = 2 * at + 1;
            const int left_count = tree[lhs].len;
            if (left_count > rank) {
                at = lhs;
            } else {
                // Skip everything counted in the left child.
                rank -= left_count;
                at = lhs + 1;
            }
        }
    }
    
    int N;            // number of entries
    int arr[maxn];    // arr[i], read for i = 1..N-1; arr[0] stays 0
    int ans[maxn];    // result for entry i
    
    // Reads N and N-1 counts, then resolves the entries from last to first:
    // entry i receives the arr[i]-th (0-based) position that is still free,
    // reported 1-based.
    int main() {
        cin >> N;
        build(0, 0, N - 1);
        memset(arr, 0, sizeof(arr));
        int i = 1;
        while (i < N) {
            cin >> arr[i];
            ++i;
        }
        i = N - 1;
        while (i >= 0) {
            const int slot = query(0, arr[i]);
            ans[i] = slot + 1;
            --i;
        }
        i = 0;
        while (i < N) {
            cout << ans[i] << endl;
            ++i;
        }
    }
    
    • 二分查找的BIT
    #include <iostream>
    #include <cstring>
    #include <algorithm>
    using namespace std;
    // Capacity bound for the input size; every array below has maxn entries.
    const int maxn = 8005;
    // arr: input counts, 1-based (main reads arr[2..N]); ans: result per index;
    // vis: per-position counter updated by modify (never read in this program);
    // bit: binary indexed tree over positions 1..N.
    int arr[maxn], ans[maxn], vis[maxn], bit[maxn];
    int N;  // number of entries, read in main
    
    // Isolates the lowest set bit of x (yields 0 when x is 0).
    int lowbit(int x) {
        const int negated = -x;
        return x & negated;
    }
    
    // Adds v at position i (1-based): updates vis[i] and every bit[] entry
    // reached by repeatedly adding lowbit, up to N.
    void modify(int i, int v) {
        vis[i] += v;
        while (i <= N) {
            bit[i] += v;
            i += lowbit(i);
        }
    }
    
    // Returns the prefix total over positions 1..i (0 when i <= 0).
    int getsum(int i) {
        int total = 0;
        while (i > 0) {
            total += bit[i];
            i -= lowbit(i);
        }
        return total;
    }
    
    // Reads N and arr[2..N], then resolves the entries from last to first.
    // For entry i a binary search over 1..N looks for the value m at which
    // exactly arr[i] smaller values are still unassigned; the value found is
    // stored in ans[i] and marked as assigned in the BIT.
    int main() {
        memset(bit, 0, sizeof(bit));
        memset(vis, 0, sizeof(vis));
        arr[0] = 0;
        cin >> N;
        for (int idx = 2; idx <= N; ++idx) {
            cin >> arr[idx];
        }
        for (int i = N; i >= 1; --i) {
            int l = 1;
            int r = N;
            while (l < r) {
                const int m = (r + l + 1) / 2;  // upper middle, so l = m makes progress
                const int below = m - 1;
                // getsum(below) counts the assigned values smaller than m.
                const int reached = getsum(below) + arr[i];
                if (reached == below) {
                    l = m;
                } else if (reached > below) {
                    l = m + 1;
                } else {
                    r = m - 1;
                }
            }
            ans[i] = l;
            modify(l, 1);
        }
        for (int idx = 1; idx <= N; ++idx) {
            cout << ans[idx] << endl;
        }
    }
    
  • Mayor's Posters

    动态构造线段树,防止超出内存限制

    #define _CRT_SECURE_NO_WARNINGS
    #include <cstdio>
    #include <iostream>
    #include <cstring>
    #include <algorithm>
    // Shorthand for the left / right child pointer of the node that rt points to.
    #define lc(rt) rt->ll
    #define rc(rt) rt->rr
    using namespace std;
    
    // Bound used to size the node pool (the pool holds maxl << 2 nodes).
    const int maxl = 10000005;
    // Capacity bound for the number of posters in one test case.
    const int maxn = 10005;
    int n;                 // poster count of the current test case
    int a[maxn], b[maxn];  // 1-based inclusive left / right end of poster i
    // Set by occupy() when it marks an exactly matching node that was not yet
    // marked; main clears it before each poster.
    bool flag = false;
    
    // Node for the closed range [L, R]. occ is non-zero once the whole range is
    // covered; ll / rr point to the children, or are NULL while not yet created.
    struct node {
        int L;
        int R;
        int occ;
        node *ll;
        node *rr;
    };
    
    // Pool that nodes are taken from; t[0] serves as the root.
    // NOTE(review): all maxl << 2 slots are reserved statically even though
    // nodes are only handed out on demand - confirm this fits the memory limit.
    node t[maxl << 2];
    
    // Index of the most recently handed-out pool slot.
    int cnt = 0;
    
    // Initialises *rt as an uncovered node for [l, r] that has no children yet.
    void build(node* rt, int l, int r) {
        rc(rt) = NULL;
        lc(rt) = NULL;
        rt->occ = 0;
        rt->R = r;
        rt->L = l;
    }
    
    // Covers the cells l..r (inclusive, inside rt's range). Sets the global
    // flag when a node matching a piece of the range exactly was not yet
    // marked as fully covered. Missing children are created on the way down,
    // taking their slots from the pool t through cnt.
    void occupy(node* rt, int l, int r) {
        if (rt->L == l && rt->R == r) {
            if (rt->occ == 0) {
                rt->occ = 1;
                flag = true;
            }
            return;
        }
        const int mid = (rt->L + rt->R) / 2;
        // Create whichever child does not exist yet.
        if (lc(rt) == NULL) {
            node* fresh = t + (++cnt);
            build(fresh, rt->L, mid);
            lc(rt) = fresh;
        }
        if (rc(rt) == NULL) {
            node* fresh = t + (++cnt);
            build(fresh, mid + 1, rt->R);
            rc(rt) = fresh;
        }
        node* lhs = lc(rt);
        node* rhs = rc(rt);
        // Push down: a fully covered node implies fully covered children.
        if (rt->occ) {
            lhs->occ = 1;
            rhs->occ = 1;
        }
        if (r <= mid) {
            occupy(lhs, l, r);
        } else if (l > mid) {
            occupy(rhs, l, r);
        } else {
            occupy(lhs, l, mid);
            occupy(rhs, mid + 1, r);
        }
        // Push up: the node is fully covered once both halves are.
        if (lhs->occ && rhs->occ) {
            rt->occ = 1;
        }
    }
    
    // For each test case: reads n posters, then places them from last to first
    // on a fresh tree spanning cells 0..(largest right end - 1). A poster is
    // counted when occupy() reports that it covered something new.
    int main() {
        int cas;
        scanf("%d", &cas);
        while (cas-- != 0) {
            scanf("%d", &n);
            int ml = 0;
            for (int i = 0; i < n; ++i) {
                scanf("%d%d", &a[i], &b[i]);
                if (ml < b[i]) ml = b[i];
            }
            // Resetting cnt lets this test case reuse the pool from slot 1 on.
            cnt = 0;
            build(t, 0, ml - 1);
            int res = 0;
            int i = n;
            while (i-- > 0) {
                flag = false;
                occupy(t, a[i] - 1, b[i] - 1);
                res += flag ? 1 : 0;
            }
            cout << res << endl;
        }
    }
    

    Discretization

    #define _CRT_SECURE_NO_WARNINGS
    #include <iostream>
    #include <map>
    #include <iomanip>
    #include <cstring>
    #include <algorithm>
    #include <string>
    #include <queue>
    #include <cmath>
    #include <vector>
    #include <stack>
    using namespace std;
    
    // 9:36 suspicious bugs
    // 9:47 apple tree (add in[x], not x)
    // 10:02 trie (match forgotten, and build before match!)
    // 10:20 popular cows
    // 10:30 currency exchange (so, BF needs no inf check?)
    // 10:40 ------
    
    // Capacity bound for the number of posters in one test case.
    const int maxn = 10005;
    // p[i][0] / p[i][1]: left / right end of poster i as read from the input.
    int p[maxn][2];
    int N;  // poster count of the current test case
    
    // Node for the closed index range [l, r]; occ is true once the whole range
    // is covered.
    struct node {
        int l;
        int r;
        bool occ;

        // Last index that belongs to the left child.
        int mid() {
            const int ends = l + r;
            return ends / 2;
        }
    };
    
    // Node storage; the children of node i live at 2*i+1 and 2*i+2.
    node tr[maxn * 8];
    
    // Indices of the left / right child of the node whose index is held in a
    // variable named rt. The expansions are not parenthesised and only make
    // sense where such an rt is in scope.
    #define lc 2*rt+1
    #define rc 2*rt+2
    
    // Recomputes node rt from its children: covered only if both halves are.
    void pushup(int rt) {
        const bool left_full = tr[lc].occ;
        const bool right_full = tr[rc].occ;
        tr[rt].occ = left_full && right_full;
    }
    
    // Propagates "fully covered" from node rt to both of its children.
    void pushdown(int rt) {
        if (!tr[rt].occ) {
            return;
        }
        tr[rc].occ = 1;
        tr[lc].occ = 1;
    }
    
    void build(int rt, int l, int r) {
     tr[rt].l = l;
     tr[rt].r = r;
     tr[rt].occ = 0;
     if (l == r) return;
     int m = tr[rt].mid();
     build(lc, l, m);
     build(rc, m + 1, r);
    }
    
    bool add(int rt, int l, int r) {
     if (tr[rt].l == l && tr[rt].r == r) {
         if (tr[rt].occ == 0) {
             tr[rt].occ = 1;
             return true;
         }
         else return false;
     }
     pushdown(rt);
     int m = tr[rt].mid();
     bool flag;
     if (l > m) flag = add(rc, l, r);
     else if (r <= m) flag = add(lc, l, r);
     else flag = add(lc, l, m) | add(rc, m + 1, r);
     pushup(rt);
     return flag;
    }
    
    // For each test case: reads N posters [a, b] (inclusive cell ranges),
    // compresses the coordinates, places the posters from last to first and
    // prints how many of them covered something new, i.e. remain visible.
    //
    // Compression works on half-open boundaries {a, b + 1} and the tree is
    // built over the elementary segments between consecutive boundaries.
    // Compressing the raw endpoints {a, b} instead makes cells that are not
    // adjacent on the wall look adjacent, which hides uncovered gaps: for
    // [1,10], [1,4], [6,10] cell 5 stays visible (answer 3), but endpoint-only
    // compression maps the posters to [0,3], [0,1], [2,3] and reports 2.
    int main() {
        int cas;
        cin >> cas;
        while (cas--) {
            cin >> N;
            vector<int> xs;
            map<int, int> m;
            for (int i = 0; i < N; i++) {
                cin >> p[i][0] >> p[i][1];
                xs.push_back(p[i][0]);
                xs.push_back(p[i][1] + 1);  // exclusive right boundary
            }
            if (xs.empty()) {
                // No posters: nothing to build, nothing is visible.
                cout << 0 << endl;
                continue;
            }
            sort(xs.begin(), xs.end());
            int uN = unique(xs.begin(), xs.end()) - xs.begin();
            for (int i = 0; i < uN; i++) m[xs[i]] = i;
            // uN boundaries delimit uN - 1 elementary segments, numbered
            // 0..uN-2; that is fewer than 2 * N leaves, so tr is large enough.
            build(0, 0, uN - 2);
            int ans = 0;
            for (int i = N - 1; i >= 0; i--) {
                // Poster [a, b] spans the segments from boundary a up to, but
                // not including, boundary b + 1.
                if (add(0, m[p[i][0]], m[p[i][1] + 1] - 1))
                    ans++;
            }
            cout << ans << endl;
        }
    }