Let us consider the following two problems to understand the segment tree.
Problem 1 - Sum of Given Range
We have an array arr[0 . . . n-1]. We should be able to:
1. Find the sum of elements from index l to r, where 0 <= l <= r <= n-1.
2. Change the value of a specified element of the array, arr[i] = x, where 0 <= i <= n-1.
A simple solution is to run a loop from l to r and calculate the sum of the elements in the given range. To update a value, simply set arr[i] = x. The first operation takes O(n) time and the second takes O(1) time.
Another solution is to create a second array that stores, at index i, the sum of the elements from the start through index i. The sum of a given range can then be calculated in O(1) time, but an update now takes O(n) time, because every stored sum from index i onward must change. This works well when queries are frequent and updates are rare.
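The prefix-sum approach described above can be sketched as follows. This is a minimal illustration, not part of the original implementation; the class name PrefixSums and its method names are hypothetical.

```java
// Hypothetical helper illustrating the prefix-sum approach (names are assumptions).
class PrefixSums {
    private final int[] arr;     // private copy of the input values
    private final long[] prefix; // prefix[i] = arr[0] + ... + arr[i]

    PrefixSums(int[] input) {
        arr = input.clone();
        prefix = new long[arr.length];
        long running = 0;
        for (int i = 0; i < arr.length; i++) {
            running += arr[i];
            prefix[i] = running;
        }
    }

    // O(1): sum of arr[l..r], assuming 0 <= l <= r <= n-1.
    long rangeSum(int l, int r) {
        return prefix[r] - (l > 0 ? prefix[l - 1] : 0);
    }

    // O(n): every prefix sum from index i onward has to be adjusted.
    void update(int i, int x) {
        long diff = (long) x - arr[i];
        arr[i] = x;
        for (int k = i; k < prefix.length; k++) {
            prefix[k] += diff;
        }
    }
}
```

The query is a single subtraction, while the update loop is what makes this approach expensive when values change often.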
We can use a segment tree to perform both operations in O(log n) time.
Problem 2 - Range Minimum Query
We have an array arr[0 . . . n-1]. We should be able to efficiently find the minimum value from index qs (query start) to qe (query end), where 0 <= qs <= qe <= n-1. The array is static (elements are not inserted or deleted during the series of queries).
A simple solution is to run a loop from qs to qe and find the minimum element in the given range. This solution takes O(n) time in the worst case.
Another solution is to create a 2D array where entry [i, j] stores the minimum value in the range arr[i..j]. The minimum of a given range can now be found in O(1) time, but preprocessing takes O(n^2) time. This approach also needs O(n^2) extra space, which may become huge for large input arrays.
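A minimal sketch of the 2D lookup table is shown below, assuming a non-empty static array. The class name MinTable and its method are hypothetical and only illustrate the idea.

```java
// Hypothetical illustration of the O(n^2) lookup table (names are assumptions).
class MinTable {
    // min[i][j] = minimum of arr[i..j] for i <= j; entries with i > j are unused.
    private final int[][] min;

    MinTable(int[] arr) {
        int n = arr.length;
        min = new int[n][n];
        for (int i = 0; i < n; i++) {
            min[i][i] = arr[i];
            for (int j = i + 1; j < n; j++) {
                // Extend the range arr[i..j-1] by one element.
                min[i][j] = Math.min(min[i][j - 1], arr[j]);
            }
        }
    }

    // O(1) lookup, assuming 0 <= qs <= qe <= n-1.
    int query(int qs, int qe) {
        return min[qs][qe];
    }
}
```

Each row is filled by extending the previous range by one element, which gives the O(n^2) preprocessing time and space mentioned above.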
With a segment tree, preprocessing takes O(n) time and a range minimum query takes O(log n) time. The extra space required to store the segment tree is O(n).
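For Problem 2, a node-based segment tree might look like the sketch below. It is one possible design under the assumption of a static array; the class name MinSegmentTree and its members are hypothetical.

```java
// Hypothetical range-minimum segment tree (names and layout are assumptions).
class MinSegmentTree {
    private static final class Node {
        final int start, end; // this node covers arr[start..end]
        int min;              // minimum value in that range
        Node left, right;

        Node(int start, int end) {
            this.start = start;
            this.end = end;
        }
    }

    private final Node root;

    MinSegmentTree(int[] arr) {
        root = arr.length == 0 ? null : build(arr, 0, arr.length - 1);
    }

    // O(n) construction: each array element becomes one leaf.
    private Node build(int[] arr, int start, int end) {
        Node node = new Node(start, end);
        if (start == end) {
            node.min = arr[start];
            return node;
        }
        int mid = start + (end - start) / 2;
        node.left = build(arr, start, mid);
        node.right = build(arr, mid + 1, end);
        node.min = Math.min(node.left.min, node.right.min);
        return node;
    }

    // O(log n) query for the minimum of arr[qs..qe].
    int query(int qs, int qe) {
        return query(root, qs, qe);
    }

    private int query(Node cur, int qs, int qe) {
        // No overlap: return a neutral value that never wins Math.min.
        if (cur == null || qs > cur.end || qe < cur.start) return Integer.MAX_VALUE;
        // Node range lies entirely inside the query range.
        if (qs <= cur.start && cur.end <= qe) return cur.min;
        // Partial overlap: combine the answers of both children.
        return Math.min(query(cur.left, qs, qe), query(cur.right, qs, qe));
    }
}
```

The structure mirrors a sum tree, with Math.min as the combining operation and Integer.MAX_VALUE as the neutral value for ranges that do not overlap the query.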
Implementation (for Problem 1 only)
1. Create a class STNode to store the nodes.
    package Tree;

    /*
     * O(log n) to
     * - Find the sum from array[l] to array[r]
     * - Change value of array[i] to x
     */
    public class SegmentTree {
        private int[] array = null;
        STNode root;

        // A node covers the index range [start, end] and stores the sum over it.
        class STNode {
            int start, end;
            int sum;
            STNode left, right;

            public STNode(int start, int end) {
                this.start = start;
                this.end = end;
            }
        }

        public SegmentTree(int[] arr) {
            array = arr;
            int start = 0, end = array.length - 1;
            root = construct(array, start, end);
        }

        // Recursively builds the tree for arr[start..end].
        public STNode construct(int[] arr, int start, int end) {
            if (start > end) return null;
            STNode node = new STNode(start, end);
            if (start == end) {
                // Leaf: holds a single array element.
                node.sum = arr[start];
                return node;
            }
            int mid = (end - start) / 2 + start;
            node.left = construct(arr, start, mid);
            node.right = construct(arr, mid + 1, end);
            node.sum = node.left.sum + node.right.sum;
            return node;
        }

        // Returns the sum of the elements in the index range [left, right].
        public int getSum(STNode cur, int left, int right) {
            if (cur == null) return 0;
            // The node's range lies completely inside the query range.
            if (left <= cur.start && right >= cur.end) {
                return cur.sum;
            }
            // The node's range does not overlap the query range.
            if (left > cur.end || right < cur.start)
                return 0;
            int mid = (cur.end - cur.start) / 2 + cur.start;
            if (left <= mid && mid < right)
                // The query spans both children: split it at mid.
                return getSum(cur.left, left, mid) + getSum(cur.right, mid + 1, right);
            else if (left > mid) return getSum(cur.right, left, right);
            else if (right <= mid) return getSum(cur.left, left, right);
            return 0;
        }

        // Sets the element at idx to newVal in the tree.
        // Returns (old value - new value) so that each ancestor can adjust its sum.
        public int updateVal(STNode cur, int newVal, int idx) {
            if (cur.start == idx && cur.end == idx) {
                int ret = cur.sum - newVal;
                cur.sum = newVal;
                return ret;
            }
            int mid = (cur.end - cur.start) / 2 + cur.start, diff = 0;
            if (idx <= mid) {
                diff = updateVal(cur.left, newVal, idx);
            } else {
                diff = updateVal(cur.right, newVal, idx);
            }
            cur.sum -= diff;
            return diff;
        }

        public static void main(String[] args) {
            int[] arr = {1, 3, 5, 7, 9, 11};
            SegmentTree st = new SegmentTree(arr);
            System.out.println(st.getSum(st.root, 0, 4)); // prints 25
            st.updateVal(st.root, 8, 1);                  // element at index 1: 3 -> 8
            System.out.println(st.root.sum);              // prints 41
        }
    }
2. Use an array to store the nodes.