1. Description

A Segment Tree is a binary tree that efficiently supports:

  • range queries (e.g. sum, max, min) in O(log n)
  • point updates in O(log n)

useful materials:

Youtube: Segment Tree Data Structure - Min Max Queries - Java source code

cp-algorithm: segment tree

2. Data Structure

Assume the range has size n. The tree can then be stored as an array of size 2 * n or 4 * n, depending on how you traverse the tree.

2.1. bottom up iterative approach (tree as a flat array of size 2 * n)

Internal nodes are stored in the first half of the array (the root is at index 1; index 0 is unused), while leaf nodes are stored in the second half. We usually start an update/query at the leaf nodes and propagate to the parent nodes, stopping once we reach the root.

The following is the template to update:

vector<int> arr(n);      // the original values, one per position in the range
vector<int> tree(2*n);   // flat segment tree: internal nodes in [1, n), leaves in [n, 2*n)

// Point update for the bottom-up (2 * n) segment tree.
//   tree  : flat tree array; the leaf for position i lives at tree[size + i]
//   index : 0-based position in the original array
//   incr  : amount added to that position
//   size  : number of leaves (n)
// After bumping the leaf, every ancestor up to the root (index 1) is recomputed.
void update(vector<int>& tree, int index, int incr, int size){
  int node = index + size;
  tree[node] += incr;

  // climb one level per iteration; node 0 is unused, so stop after the root
  for(node >>= 1; node >= 1; node >>= 1){
    const int lhs = node << 1;
    const int rhs = lhs | 1;
    tree[node] = tree[lhs] + tree[rhs];     // <<< range sum here; swap in another combine operation if needed
  }
}

// Range-sum query for the bottom-up (2 * n) segment tree.
// Returns the sum over the half-open range [left, right)    // <<< right is excluded
//   tree : flat tree array with leaves starting at tree[size]
//   size : number of leaves (n)
int query(vector<int>& tree, int left, int right, int size) {
  int lo = left + size;
  int hi = right + size;
  int total = 0;

  for(; lo < hi; lo /= 2, hi /= 2){
    // lo is a right child: its parent also covers elements outside the range,
    // so take this node on its own and step to the next subtree
    if(lo & 1) {
      total += tree[lo++];                 // <<< range sum here; swap in another combine operation if needed
    }
    // hi is exclusive; when it is a right child, its left sibling (hi - 1)
    // is the last node inside the range and must be taken separately
    if(hi & 1) {
      total += tree[--hi];                 // <<< range sum here; swap in another combine operation if needed
    }
    // an even boundary needs nothing here: it is accounted for one level up
  }

  return total;
}

How to build the tree? The first way is to call the update function n times to push the initial array into the tree. The other way is the following:

vector<int> arr(n);      // the initial values to load into the tree
vector<int> tree(2*n);   // flat segment tree: internal nodes in [1, n), leaves in [n, 2*n)

// Build the bottom-up (2 * n) segment tree from the initial array in O(n).
//   tree : flat tree array of length 2 * size; it is overwritten from index 1 onward
//   arr  : initial values; expected to hold exactly `size` elements
//   size : number of leaves (n)
void build(vector<int>& tree, vector<int>& arr, int size) {
  // copy arr into the second half of the tree array (the leaves)
  copy(arr.begin(), arr.end(), tree.begin()+size);

  // fill the internal nodes from right to left, so both children of a node
  // are already final when the node itself is computed; index 1 is the root
  for(int i=size-1; i>=1; i--){
    tree[i] = tree[2*i] + tree[2*i+1];                   // <<< range sum here; swap in another combine operation if needed
  }
}

2.2. top down recursive approach (tree size 4 * n)

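Template (the root is node 0 and covers the inclusive range [0, n - 1], so the initial call passes node = 0, left = 0, right = n - 1):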

// Point update for the top-down (4 * n) segment tree.
//   tree        : tree array; children of `node` are 2*node+1 (left half) and 2*node+2 (right half)
//   node        : index of the current tree node (0 for the root)
//   left, right : inclusive range of positions covered by `node`
//   index       : 0-based position in the original array to update
//   incr        : amount added to that position
void update(vector<int>& tree, int node, int left, int right, int index, int incr) {
  if(left == right) {
    // reached the leaf for `index`; the value is stored at the tree node, not at tree[index]
    tree[node] += incr;
    return;
  }

  int mid = (left + right) / 2;
  // descend into the half that contains `index`, using the same child layout
  // as query/build: left half [left, mid] -> 2*node+1, right half [mid+1, right] -> 2*node+2
  if(index <= mid) {
    update(tree, 2*node+1, left, mid, index, incr);
  }
  else {
    update(tree, 2*node+2, mid+1, right, index, incr);
  }

  // recompute this node from its two (now up-to-date) children
  tree[node] = tree[2*node+1] + tree[2*node+2];
}

// Range-sum query for the top-down (4 * n) segment tree.
// Returns the sum over the inclusive range [q_left, q_right].
//   node        : index of the current tree node (0 for the root)
//   left, right : inclusive range of positions covered by `node`
int query(vector<int>& tree, int node, int left, int right, int q_left, int q_right){
  // an empty request contributes nothing
  const bool is_empty = q_right < q_left;
  if(is_empty) {
    return 0;
  }

  // the request matches this node's range exactly: the stored value is the answer
  const bool is_exact = (left == q_left) && (right == q_right);
  if(is_exact) {
    return tree[node];
  }

  const int split = (left + right) / 2;
  const int lhs = 2*node + 1;
  const int rhs = lhs + 1;

  // clip the request to each half; a half that does not overlap ends up empty
  const int left_end = (q_right < split) ? q_right : split;
  const int right_begin = (q_left > split + 1) ? q_left : split + 1;

  int total = query(tree, lhs, left, split, q_left, left_end);
  total += query(tree, rhs, split + 1, right, right_begin, q_right);
  return total;
}

// Build the top-down (4 * n) segment tree recursively.
//   tree        : tree array; children of `node` are 2*node+1 and 2*node+2
//   node        : index of the current tree node (0 for the root)
//   left, right : inclusive range of positions covered by `node`
//   data        : the initial values, indexed by position
void build(vector<int>& tree, int node, int left, int right, vector<int>& data){
  if(left != right) {
    const int split = (left + right) / 2;
    const int lhs = 2*node + 1;
    const int rhs = lhs + 1;

    // build both halves first, then combine them into this node
    build(tree, lhs, left, split, data);
    build(tree, rhs, split + 1, right, data);
    tree[node] = tree[lhs] + tree[rhs];
  }
  else {
    // a single-position range is a leaf and stores the value itself
    tree[node] = data[left];
  }
}