Skip to content

并查集

并查集 Disjoint-Set

  • Operations required in O(1)

    Merge(A, B): merge two sets.

    Query(A, B): query if A and B are in the same set.

  • Naive Algorithms

    • Number each element.

      Query is O(1), but Merge is O(N)

    • Balanced Binary Tree

      merge() may increase the depth of a tree.

      getroot() may be O(N), N is the tree's depth.

  • Path Compression (路径压缩)

    An improvement on the tree method: whenever a node is queried, every node on its path is re-attached directly to the root (this step may be O(n)). The next query on any of those nodes then needs only O(1).

    // parent[i] is the parent of node i in its tree; a root satisfies parent[i] == i.
    // NOTE(review): the capacity N is not defined in this snippet and must be supplied by the reader.
    int parent[N];
    
    // Makes each of the nodes 0 .. n-1 its own parent, i.e. the root of a
    // singleton set.
    void init(int n){
        int node = 0;
        while (node < n) {
            parent[node] = node;
            ++node;
        }
    }
    
    // Returns the root of the tree containing a and applies path compression:
    // afterwards every node on the path from a to the root points directly at
    // the root.
    // Extension point from the notes: any additional per-node data should be
    // updated here, while the path is being compressed.
    int getRoot(int a){
        // Pass 1: walk up to the root.
        int root = a;
        while (parent[root] != root) {
            root = parent[root];
        }
        // Pass 2: re-attach every node on the path directly to the root.
        int cur = a;
        while (cur != root) {
            const int next = parent[cur];
            parent[cur] = root;
            cur = next;
        }
        return root;
    }
    
    // add other data
    void merge(a, b){
        parent[getRoot(b)] = getRoot(a);
    }
    
    bool query(a, b){
        return getRoot(a) == getRoot(b);
    }
    

    getroot() 的平摊时间复杂度在 N 不是很大的时候可视为常数:N 次操作的总代价不超过约 4N(严格的 O(α(N)) 界需要路径压缩与按秩合并同时使用)。

  • Tips

    • "In the same set" (or "having the same root") means that the elements have a known relationship to each other, not necessarily that they belong to the same group.
    • Weighted Disjoint Set

      weights are used to show the relationship between this node and the root node. (not the parent node!!!)

      // p[i]: parent of node i; a root satisfies p[i] == i.
      // NOTE(review): the capacity N is not defined in this snippet.
      int p[N];
      // w[i]: relation weight of node i, kept modulo 3. It is relative to p[i];
      // after par(i) has run, p[i] is the root, so w[i] is relative to the root.
      int w[N];
      
      // Resets nodes 0 .. n-1: each becomes its own root with weight 0.
      void init(int n){
          int node = 0;
          while (node < n) {
              p[node] = node;
              w[node] = 0;
              ++node;
          }
      }
      
      // Returns the root of x with path compression. While unwinding, w[x] is
      // rewritten from "relative to the old parent" to "relative to the root"
      // by adding the (already compressed) weight of the old parent, modulo 3.
      int par(int x){
          const int oldParent = p[x];
          if (oldParent != x) {
              // Compress the ancestor first so w[oldParent] is root-relative.
              const int root = par(oldParent);
              w[x] = (w[x] + w[oldParent]) % 3;
              p[x] = root;
          }
          return p[x];
      }
      
      // Attaches the root of y's tree beneath the root of x's tree and sets the
      // attached root's weight so that the relation between x and y is
      // preserved (mod 3).
      //
      // If x and y already share a root the call is a no-op: overwriting the
      // weight of a root would leave it non-zero and every later par() call on
      // its children would add that stray weight again.
      void merge(int x, int y){
          const int fx = par(x);
          const int fy = par(y);
          if (fx == fy) return;
          p[fy] = fx;  // only a root's parent pointer is ever changed
          // After par(), w[x] and w[y] are root-relative; +3 keeps the result non-negative.
          w[fy] = (w[x] - w[y] + 3) % 3;
      }
      
  • Examples

    • POJ1611 the suspects
    • POJ1988 cube stacking

      sum[N],under[N]

      When to change under[N] to keep O(1) time complexity?

      Only change the (original) root's under when merging. When querying a child node's under, sum the values along the path to the root and apply path compression at the same time.

      #include <iostream>
      using namespace std;
      
      // Capacity of the arrays below; valid indices are 0 .. maxn-1.
      const int maxn = 30005;
      
      // M is the number of operations read by main(); N is not used by the code shown here.
      int N, M;
      // p[i]: parent of i (a root satisfies p[i] == i).
      // sum[r]: number of elements in the tree rooted at r (maintained for roots only).
      // under[i]: count accumulated from i towards its root; after par(i) it is
      //           the value printed by main() for a 'C' query.
      int p[maxn], sum[maxn], under[maxn];
      
      // Resets elements 1 .. N (inclusive): each becomes a singleton tree with
      // sum 1 and under 0. Because index N itself is written, N must not
      // exceed the last valid index of p/sum/under.
      void init(int N) {
          int id = 1;
          while (id <= N) {
              p[id] = id;
              sum[id] = 1;
              under[id] = 0;
              ++id;
          }
      }
      
      // Returns the root of x with path compression. While unwinding, under[x]
      // absorbs the (already compressed) under[] of its old parent, so that
      // after the call under[x] is accumulated all the way to the root.
      int par(int x) {
          const int oldParent = p[x];
          if (oldParent != x) {
              // Compress the ancestor first so under[oldParent] is complete.
              const int root = par(oldParent);
              under[x] += under[oldParent];
              p[x] = root;
          }
          return p[x];
      }
      
      // Attaches the root of y's tree beneath the root of x's tree.
      // The attached root records the size of x's tree in under[], and the
      // surviving root's sum grows by the size of the attached tree.
      //
      // If x and y are already in the same tree the call is a no-op; without
      // this guard the root's sum would be doubled and its under[] would
      // become non-zero, corrupting every later query.
      void merge(int x, int y) {
          const int fx = par(x);
          const int fy = par(y);
          if (fx == fy) return;
          p[fy] = fx;
          under[fy] = sum[fx];   // only the (former) root's under is changed
          sum[fx] += sum[fy];
      }
      
      // Operation code read by main(); 'C' is a query, anything else is treated as a merge.
      char c;
      // Operands of the current operation, read by main().
      int x, y;
      
      // Reads M operations from stdin.
      //   'C' x  : prints under[x] after compressing x's path.
      //   other  : reads x y and calls merge(y, x), i.e. x's tree is attached
      //            beneath y's tree.
      int main() {
          // init() writes indices 1..N inclusive, so the largest safe argument
          // is maxn - 1; passing maxn would write one element past the arrays.
          init(maxn - 1);
          cin >> M;
          for (int i = 0; i < M; i++) {
              cin >> c;
              if (c == 'C') {
                  cin >> x;
                  par(x);  // compress so under[x] is accumulated up to the root
                  // '\n' instead of endl: avoids flushing after every answer.
                  cout << under[x] << '\n';
              }
              else {
                  cin >> x >> y;
                  merge(y, x);
              }
          }
      }
      
    • POJ1182 food chains

      use vector to deduce the formula of merge.

      Relationships between objects usually form a cycle.

      fx <- ??? - fy
      ^           ^|
      |w[x]   w[y]||-w[y]     ===> ??? = (w[x] - w[y] + d - 1 + 3) %3
      |           |v                              // +3 for safety
      x  <-(d-1)-- y          // d=1 means the same, d=2 means x eats y so w'[y] is 1.
      
      #include <iostream>
      #include <cstring>
      using namespace std;
      
      // Capacity of the arrays below; valid indices are 0 .. maxN-1.
      const int maxN = 50005;
      
      // par[i]: parent of i (a root satisfies par[i] == i).
      int par[maxN];
      // dis[i]: relation of i to par[i], modulo 3; relative to the root once
      // parent(i) has compressed the path.
      int dis[maxN];
      
      // Makes each of the nodes 0 .. n (inclusive) its own root.
      // dis[] is not touched here.
      void init(int n) {
          int node = 0;
          while (node <= n) {
              par[node] = node;
              ++node;
          }
      }
      
      // Returns the root of x with path compression. While unwinding, dis[x]
      // is rewritten to be relative to the root by adding the (already
      // compressed) dis[] of the old parent, modulo 3.
      int parent(int x) {
          const int oldParent = par[x];
          if (oldParent != x) {
              // Compress the ancestor first so dis[oldParent] is root-relative.
              const int root = parent(oldParent);
              dis[x] = (dis[x] + dis[oldParent]) % 3;
              par[x] = root;
          }
          return par[x];
      }
      
      // Attaches the root of y's tree beneath the root of x's tree so that the
      // statement (d, x, y) holds afterwards. The new root offset is
      // (dis[x] - dis[y] + (d - 1) + 3) % 3, written as "+ d + 2" to keep the
      // value non-negative before the modulo.
      void merge(int d, int x, int y) {
          const int rootX = parent(x);
          const int rootY = parent(y);
          const int offset = (dis[x] - dis[y] + d + 2) % 3;
          par[rootY] = rootX;
          dis[rootY] = offset;
      }
      
      // Returns true when x and y share a root. Both paths are compressed as a
      // side effect, so dis[x] and dis[y] are root-relative afterwards.
      bool query(int x, int y) {
          const int rootX = parent(x);
          const int rootY = parent(y);
          return rootX == rootY;
      }
      
      // Reads N and K, then K statements "D X Y", and prints how many of them
      // are counted as false:
      //   - X or Y greater than N;
      //   - X and Y already related, and the recorded relation contradicts
      //     D == 1 (equal offsets expected) or D == 2 (dis[Y] expected to be
      //     dis[X] + 1 modulo 3).
      // Statements about unrelated X and Y are recorded with merge().
      int main() {
          int N, K, D, X, Y;
          int lies = 0;
          std::cin >> N >> K;
          init(N);
          memset(dis, 0, sizeof(dis));
          for (int stmt = 0; stmt < K; ++stmt) {
              std::cin >> D >> X >> Y;
              if (X > N || Y > N) {
                  ++lies;
                  continue;
              }
              if (!query(X, Y)) {
                  // No known relation yet: accept the statement.
                  merge(D, X, Y);
                  continue;
              }
              // query() compressed both paths, so dis[] is root-relative here.
              bool contradiction = false;
              if (D == 1) {
                  contradiction = (dis[X] != dis[Y]);
              }
              else if (D == 2) {
                  contradiction = (dis[Y] != (dis[X] + 1) % 3);
              }
              if (contradiction) {
                  ++lies;
              }
          }
          std::cout << lies << std::endl;
      }
      
    • A bug's Life

      ditto but mod 2 for same or different.

      #include <iostream>
      #include <cstring>
      using namespace std;
      
      // Capacity of the arrays below; valid indices are 0 .. maxN-1.
      const int maxN = 2005;
      
      // p[i]: parent of i (a root satisfies p[i] == i).
      int p[maxN];
      // w[i]: parity (0 or 1) of i relative to p[i]; relative to the root once
      // parent(i) has compressed the path.
      int w[maxN];
      
      // Resets nodes 1 .. n (inclusive): each becomes its own root with parity 0.
      void init(int n) {
          int node = 1;
          while (node <= n) {
              p[node] = node;
              w[node] = 0;
              ++node;
          }
      }
      
      // Returns the root of x with path compression. While unwinding, w[x] is
      // rewritten to be relative to the root by adding the (already
      // compressed) parity of the old parent, modulo 2.
      int parent(int x) {
          const int oldParent = p[x];
          if (oldParent != x) {
              // Compress the ancestor first so w[oldParent] is root-relative.
              const int root = parent(oldParent);
              w[x] = (w[x] + w[oldParent]) % 2;
              p[x] = root;
          }
          return p[x];
      }
      
      // Attaches the root of y's tree beneath the root of x's tree and sets the
      // attached root's parity to (w[x] - w[y] + 1) % 2, which makes x and y
      // end up with different parities relative to the common root.
      void merge(int x, int y) {
          const int rootX = parent(x);
          const int rootY = parent(y);
          const int parity = (w[x] - w[y] + 1) % 2;
          p[rootY] = rootX;
          w[rootY] = parity;
      }
      
      // Returns true when x and y share a root. Both paths are compressed as a
      // side effect, so w[x] and w[y] are root-relative afterwards.
      bool query(int x, int y) {
          const int rootX = parent(x);
          const int rootY = parent(y);
          return rootX == rootY;
      }
      
      // Processes S scenarios. Each scenario gives N and I followed by I pairs
      // "a b". A scenario is flagged as soon as a pair is found whose members
      // are already related with equal parity; remaining pairs of that scenario
      // are still read but ignored. Scenarios are separated by a blank line.
      int main() {
          int S, N, I, a, b;
          std::cin >> S;
          int s = 1;
          while (s <= S) {
              std::cin >> N >> I;
              init(N);
              bool suspicious = false;
              // Separator line before every scenario except the first.
              if (s > 1) std::cout << std::endl;
              for (int remaining = I; remaining > 0; --remaining) {
                  std::cin >> a >> b;
                  if (suspicious) continue;  // keep consuming input only
                  if (!query(a, b)) {
                      merge(a, b);
                  }
                  else if (w[a] == w[b]) {
                      // Same tree and same parity: contradiction.
                      suspicious = true;
                  }
              }
              std::cout << "Scenario #" << s << ":" << std::endl;
              if (suspicious) {
                  std::cout << "Suspicious bugs found!" << std::endl;
              }
              else {
                  std::cout << "No suspicious bugs found!" << std::endl;
              }
              ++s;
          }
      }
      
    • Mayor's poster

      奇妙并查集。不使用merge,而仅仅把并查集当做一个链表,也可以用来高效描述区间的一个性质(是否被占据),而不用遍历区间(功能相当于最简化的线段树)。

      #include <iostream>
      #include <cstdio>
      #include <algorithm>
      #include <vector>
      #include <set>
      using namespace std;
      
      // Capacity of par[]: two coordinates (start and end) per poster, plus slack.
      const int maxn = 20005; // 2*maxPosters, start+end
      // Number of posters in the current test case, read by main().
      int N;
      
      // par[i] == i means slot i is still free; otherwise par[i] links towards
      // the next candidate free slot to the right.
      int par[maxn];
      
      // Marks slots 0 .. n (inclusive) as their own parent.
      void init(int n) {
          int slot = 0;
          while (slot <= n) {
              par[slot] = slot;
              ++slot;
          }
      }
      // Follows par[] links from n to the first slot that is its own parent and
      // returns it, re-pointing every slot visited on the way directly at that
      // result (path compression).
      int findpar(int n) {
          // Pass 1: locate the representative.
          int root = n;
          while (par[root] != root) {
              root = par[root];
          }
          // Pass 2: compress the path.
          int cur = n;
          while (cur != root) {
              const int next = par[cur];
              par[cur] = root;
              cur = next;
          }
          return root;
      }
      
      // Discretization data, rebuilt by main() for every test case.
      // order: the distinct poster coordinates, sorted ascending.
      vector<int> order;
      // posters: one (first, second) pair per poster; main() replaces the raw
      // coordinates with their indices in order.
      vector<pair<int, int>> posters;
      
      // Returns the index of the first element of order that is not less than
      // key (order.size() if there is none). order must be sorted ascending.
      int getindex(vector<int>& order, int key) {
          vector<int>::iterator first = order.begin();
          vector<int>::iterator pos = lower_bound(first, order.end(), key);
          return static_cast<int>(pos - first);
      }
      
      // For each test case, reads N posters [x, y] (pasted in input order, later
      // ones on top) and prints how many posters remain at least partly visible.
      //
      // Posters are processed from last to first. par[] is used as a "next free
      // cell" linked list: a poster is visible if it claims at least one cell
      // that no later poster has claimed.
      //
      // Discretization uses half-open intervals [x, y + 1): cell k stands for
      // the coordinate range [order[k], order[k + 1]). Discretizing the closed
      // endpoints x and y directly loses the gap between non-adjacent
      // coordinates (e.g. [1,10], [1,4], [6,10] would report 2 instead of 3).
      // There are still at most 2 * N distinct coordinates, so par[] is large
      // enough.
      int main() {
          int cas, x, y;
          cin >> cas;
          while (cas--) {
              cin >> N;
              order.clear();
              posters.clear();
              for (int i = 0; i < N; i++) {
                  cin >> x >> y;
                  order.push_back(x);
                  order.push_back(y + 1);
                  posters.push_back(pair<int, int>(x, y + 1));
              }
              sort(order.begin(), order.end());
              // unique returns the new logical end; erase drops the duplicates.
              order.erase(unique(order.begin(), order.end()), order.end());
              // Map each poster to the inclusive range of cells it covers.
              for (int i = 0; i < N; i++) {
                  posters[i].first = getindex(order, posters[i].first);
                  posters[i].second = getindex(order, posters[i].second) - 1;
              }
              init(static_cast<int>(order.size()));
              int res = 0;
              for (int i = N - 1; i >= 0; i--) {
                  bool visible = false;
                  // Jump from free cell to free cell inside the poster's range.
                  for (int j = findpar(posters[i].first); j <= posters[i].second; j = findpar(j)) {
                      visible = true;
                      par[j] = j + 1;  // claim the cell; link it to its right neighbour
                  }
                  if (visible) res++;
              }
              // '\n' instead of endl: avoids flushing after every test case.
              cout << res << '\n';
          }
      }