Skip to content

Binary Indexed Tree

Basic

  • Definition

    // build C from a, i starts from 1.
    C[i] = a[i - lowbit(i) + 1] + a[i-lowbit(i)+2] + ... + a[i]
    
    // lowbit operator
    // lowbit(x) keeps only the lowest set bit of x and clears every other bit.
    // e.g. x = 00001101, -x = 11110011, lowbit(x) = 00000001
    // x2 = x + lowbit(x) = 00001110, -x2 = 11110010, lowbit(x2) = 00000010
    // x3 = x2 + lowbit(x2) = 00010000
    lowbit(x) = x & (-x)  
    
  • Usage

    \(O(log\ N)\): 单点更新,区间求和

    证明:

    \(\sum_{k=i}^{j} a[k] = sum(j) - sum(i-1)\)

    \(sum(k) = C[n_1] + C[n_2] + ... + C[k - lowbit(k)] + C[k]\)

    expanding \(C[\cdot]\) can prove the correctness.

    \(k-lowbit(k)\) eliminates the rightmost (lowest) 1 of k.

    so there are at most \(log\ N\) elements.

    updating \(a[i]\) will lead to updating of :

    \(C[i], C[i + lowbit(i)], ..., C[n_m]\)

    until \(n_m + lowbit(n_m) > length(a)\)

    there are at most \(log\ N\) elements again.

    \(O(N)\): build time complexity

\[ \displaylines{ C[k] = sum(k)-sum(k-lowbit(k)) } \]
  • Implementation

    // Fenwick tree nodes; add() and getsum() below address them with 1-based indices.
    int bit[maxn];
    // Upper index bound used by add(): nodes above N are never updated.
    int N;
    
    // Isolates the lowest set bit of x (e.g. 12 -> 4); yields 0 for x == 0.
    int lowbit(int x) {
      const int negated = -x;
      return x & negated;
    }
    
    // Point update: adds v to every node on the upward chain
    // i, i + lowbit(i), ... that does not exceed N.
    void add(int i, int v) {
      while (i <= N) {
        bit[i] += v;
        i += lowbit(i);
      }
    }
    
    // Prefix query: sum of arr[1..i], accumulated along the downward chain
    // i, i - lowbit(i), ... ; returns 0 when i <= 0.
    int getsum(int i) {
      int total = 0;
      while (i > 0) {
        total += bit[i];
        i -= lowbit(i);
      }
      return total;
    }
    

    Variant (reversed BIT) :

    // Nodes of the reversed Fenwick tree; indices are 1-based.
    int bit[maxn];
    // Upper index bound used by getsum(): its upward walk stops once it passes N.
    int N;
    
    // Isolates the lowest set bit of x (e.g. 12 -> 4); yields 0 for x == 0.
    int lowbit(int x) {
      const int negated = -x;
      return x & negated;
    }
    
    // Point update for the reversed tree: adds v along the downward chain
    // i, i - lowbit(i), ... while the index stays positive.
    void add(int i, int v) {
        while (i > 0) {
            bit[i] += v;
            i -= lowbit(i);
        }
    }
    
    // Suffix query: sum of arr[i..N], accumulated along the upward chain
    // i, i + lowbit(i), ... while the index does not exceed N.
    int getsum(int i) {
      int total = 0;
      while (i <= N) {
        total += bit[i];
        i += lowbit(i);
      }
      return total;
    }
    
  • Examples

    • POJ 3321 Apple Tree

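      The key point is how to number the nodes so that every subtree maps to one contiguous index range of the BIT. Running a DFS and recording, for each node, the index at which it is entered (S) and the last index handed out inside its subtree (E) achieves this: the subtree of a node is exactly the range [S, E], so a subtree query becomes a range sum.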

      #include <iostream>
      #include <algorithm>
      #include <cstring>
      #include <string>
      #include <vector>
      using namespace std;
      
      // Capacity of every per-node array below.
      const int maxN = 100000 + 5;
      // Operation token read for each request; main() treats "Q" as a query.
      string q;
      // N: node count, M: request count, a/b: scratch operands read by main().
      int N, M, a, b;
      // t[u] holds the nodes that main() pushes as children of u.
      vector<int> t[maxN];
      // S[n] / E[n]: first and last DFS index assigned inside n's subtree by dfs().
      int S[maxN], E[maxN];
      
      // arr[n]: current 0/1 state of node n (set to 1 at start-up, toggled by main()).
      int arr[maxN];
      // Fenwick tree nodes, indexed 1..N.
      int bit[maxN];
      
      // Isolates the lowest set bit of x (e.g. 12 -> 4); yields 0 for x == 0.
      int lowbit(int x) {
          const int negated = -x;
          return x & negated;
      }
      
      // Point update: adds v to every Fenwick node on the upward chain
      // i, i + lowbit(i), ... that does not exceed N.
      void modify(int i, int v) {
          for (int node = i; node <= N; node += lowbit(node)) {
              bit[node] += v;
          }
      }
      
      // Prefix query: sum of positions 1..i; returns 0 when i <= 0.
      int getsum(int i) {
          int total = 0;
          for (int node = i; node > 0; node -= lowbit(node)) {
              total += bit[node];
          }
          return total;
      }
      
      
      int ti = 1;
      void dfs(int n) {
          S[n] = ti;
          for (int to : t[n]) {
              ti++;
              dfs(to);
          }
          E[n] = ti;
      }
      
      // Reads the tree (N nodes, N - 1 "parent child" edges) and M requests.
      // "Q a" prints the sum of the states in a's subtree; any other token
      // toggles the state of node a between 1 and 0.
      int main() {
          memset(bit, 0, sizeof(bit));
          cin >> N;
          // Every position starts in state 1, both in arr and in the Fenwick tree.
          for (int i = 1; i <= N; i++) arr[i] = 1;
          for (int i = 1; i <= N; i++) modify(i, 1);
          for (int i = 1; i < N; i++) {
              cin >> a >> b;
              t[a].push_back(b);
          }
          dfs(1);
          cin >> M;
          for (int i = 0; i < M; i++) {
              cin >> q >> a;
              if (q == "Q") {
                  // The subtree of a occupies DFS indices S[a]..E[a].
                  int res = getsum(E[a]) - getsum(S[a] - 1);
                  // '\n' instead of endl: avoids flushing the stream on every query.
                  cout << res << '\n';
              }
              else {
                  int value = arr[a] == 0 ? 1 : -1;
                  arr[a] += value;
                  modify(S[a], value);
              }
          }
      }
      

Generalized BIT

use C[k] to store Maximum, Minimum, ...

(复杂度不优秀,代码量也不低,还是用线段树吧)

// maximum BIT
// bit[i] caches max(arr[i - lowbit(i) + 1 .. i]); all indices are 1-based.
const int maxn = 100;
// Largest live index. update() walks ancestors up to N, and init() does not
// set it, so the caller must assign N before calling update().
int N;
int arr[maxn], bit[maxn];

// Isolates the lowest set bit of x; yields 0 for x == 0.
int lowbit(int x) { return x & (-x); }

// Recomputes bit[i] from arr[i] and its child nodes bit[i - 1], bit[i - 2],
// bit[i - 4], ... (one child per power of two below lowbit(i)).
// The children must already be up to date.
static void refresh(int i) {
    bit[i] = arr[i];
    for (int j = 1; j < lowbit(i); j *= 2)
        if (bit[i - j] > bit[i]) bit[i] = bit[i - j];
}

// Builds bit[1..n] from arr[1..n]; ascending order guarantees every child
// node is final before its parent reads it.
void init(int n) {
    for (int i = 1; i <= n; i++) refresh(i);
}

// O(lg^2 n)
// Sets arr[i] = x and rebuilds every node whose range contains i.
// Each ancestor is re-seeded from its OWN element arr[i], not from x:
// seeding with x would drop the ancestor's own element from its maximum.
void update(int i, int x) {
    arr[i] = x;
    for (; i <= N; i += lowbit(i)) refresh(i);
}

// O(lg^2 n)
// Returns max(arr[i..j]) for 1 <= i <= j; an empty range (j < i) returns 0.
int query(int i, int j) {
    if (j < i) return 0;
    // Seed from a real element so all-negative ranges are answered correctly.
    int ans = arr[j];
    while (j >= i) {
        if (arr[j] > ans) ans = arr[j];
        // Consume whole nodes while their range stays inside [i, j].
        for (j -= 1; j - lowbit(j) >= i; j -= lowbit(j))
            if (bit[j] > ans) ans = bit[j];
    }
    return ans;
}
  • Examples

    • 最长上升子序列

      DP is \(O(N^2)\). We need \(O(Nlog\ N)\):

    // binary_search based method, also NlogN.
    #define _CRT_SECURE_NO_WARNINGS
    #include <iostream>
    #include <cstdio>
    #include <algorithm>
    #include <cstring>
    using namespace std;
    
    // Capacity of arr and dp.
    const int maxn = 300005;
    // Largest int value; main() fills dp with it to mark unused slots.
    const int inf = 0x7fffffff;
    // Number of input elements, read by main().
    int N;
    // arr: the input sequence; dp: sorted work array searched with lower_bound in main().
    int arr[maxn], dp[maxn];
    
    // Reads N and then N integers, and prints the length of the longest
    // strictly increasing subsequence. dp[k] holds the smallest value that
    // can end an increasing subsequence of length k + 1 (inf when none exists).
    int main() {
      scanf("%d", &N);
      for (int k = 0; k < N; ++k) {
          scanf("%d", &arr[k]);
          dp[k] = inf;
      }
      int length = 0;
      for (int k = 0; k < N; ++k) {
          // First slot whose value is >= arr[k]; overwriting it keeps dp sorted.
          int* slot = lower_bound(dp, dp + N, arr[k]);
          *slot = arr[k];
          if (slot - dp == length) ++length;  // wrote the first unused slot
      }
      cout << length << endl;
    }
    

2-Dim Tree-like BIT

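the tree structure of a 2D BIT in 10x10 (the entry at row x, column y is lowbit(x) * lowbit(y), i.e. the number of cells covered by bit[x][y]):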

  1   2   1   4   1   2   1   8   1   2
  2   4   2   8   2   4   2  16   2   4
  1   2   1   4   1   2   1   8   1   2
  4   8   4  16   4   8   4  32   4   8
  1   2   1   4   1   2   1   8   1   2
  2   4   2   8   2   4   2  16   2   4
  1   2   1   4   1   2   1   8   1   2
  8  16   8  32   8  16   8  64   8  16
  1   2   1   4   1   2   1   8   1   2
  2   4   2   8   2   4   2  16   2   4
  • Examples: Mobile phones

    #include <iostream>
    #include <algorithm>
    #include <cstring>
    #include <string>
    #include <vector>
    using namespace std;
    
    // Capacity of each dimension of bit.
    const int maxN = 1024 + 5;
    // Side length of the active square; read by main() on instruction 0.
    int N;
    
    // 2D Fenwick tree nodes, addressed with 1-based row/column indices.
    int bit[maxN][maxN];
    
    // Isolates the lowest set bit of x (e.g. 12 -> 4); yields 0 for x == 0.
    int lowbit(int x) {
      const int negated = -x;
      return x & negated;
    }
    
    // Point update: adds v to cell (x, y) by touching every node on the upward
    // chains of both coordinates, bounded by N.
    void modify(int x, int y, int v) {
      int row = x;
      while (row <= N) {
          int col = y;
          while (col <= N) {
              bit[row][col] += v;
              col += lowbit(col);
          }
          row += lowbit(row);
      }
    }
    
    // Prefix query: sum of all cells in the rectangle (1, 1)..(x, y);
    // returns 0 when x <= 0 or y <= 0.
    int getsum(int x, int y) {
      int total = 0;
      int row = x;
      while (row > 0) {
          int col = y;
          while (col > 0) {
              total += bit[row][col];
              col -= lowbit(col);
          }
          row -= lowbit(row);
      }
      return total;
    }
    
    // Sum of the rectangle (x1, y1)..(x2, y2), both corners inclusive, obtained
    // from four prefix sums by inclusion-exclusion.
    int getsquare(int x1, int y1, int x2, int y2) {
      const int whole = getsum(x2, y2);
      const int lowY = getsum(x2, y1 - 1);
      const int lowX = getsum(x1 - 1, y2);
      const int corner = getsum(x1 - 1, y1 - 1);
      return whole - lowY - lowX + corner;
    }
    
    // Debug helper: dumps the raw Fenwick nodes bit[1..N][1..N], one row per line.
    void print() {
      int x = 1;
      while (x <= N) {
          int y = 1;
          while (y <= N) {
              cout << bit[x][y] << " ";
              ++y;
          }
          cout << endl;
          ++x;
      }
    }
    
    // Processes instructions until input ends or instruction 3 arrives:
    //   0 S        -> set the table size N = S
    //   1 X Y A    -> add A to cell (X, Y)
    //   2 L B R T  -> print the sum over the rectangle (L, B)..(R, T)
    //   3          -> stop
    // Input coordinates are shifted by one because the Fenwick tree is 1-based.
    int main() {
      memset(bit, 0, sizeof(bit));
      int n, a, b, c, d;
      while (cin >> n) {
          switch (n) {
          case 0:
              cin >> N;
              break;
          case 1:
              cin >> a >> b >> c;
              modify(a+1, b+1, c);
              break;
          case 2:
              cin >> a >> b >> c >> d;
              // '\n' instead of endl: avoids flushing the stream on every query.
              cout << getsquare(a+1, b+1, c+1, d+1) << '\n';
              break;
          case 3:
              return 0;
          }
      }
    }