Question
Given a binary search tree, write a function kthSmallest to find the kth smallest element in it.
Note:
You may assume k is always valid, 1 ≤ k ≤ BST's total elements.
Follow up
What if the BST is modified (insert/delete operations) often and you need to find the kth smallest frequently? How would you optimize the kthSmallest routine?
Hint:
- Try to utilize the property of a BST.
- What if you could modify the BST node's structure?
- The optimal runtime complexity is O(height of BST).
Solution 1 -- Inorder Traversal
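We use the property that an inorder traversal of a BST visits the nodes in ascending order, so the k-th node visited is the k-th smallest. Time complexity is O(n) in the worst case, where n is the number of nodes (more precisely O(h + k), since the traversal stops at the k-th node), and the stack uses O(h) space. This solution is not the best fit for the follow-up, because every query walks the tree again from the root.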
```java
import java.util.ArrayDeque;
import java.util.Deque;

/**
 * Definition for a binary tree node.
 * public class TreeNode {
 *     int val;
 *     TreeNode left;
 *     TreeNode right;
 *     TreeNode(int x) { val = x; }
 * }
 */
public class Solution {
    public int kthSmallest(TreeNode root, int k) {
        TreeNode current = root;
        // Explicit stack replaces recursion for the inorder walk.
        Deque<TreeNode> stack = new ArrayDeque<>();
        while (current != null || !stack.isEmpty()) {
            if (current != null) {
                // Go as far left as possible, remembering the path.
                stack.push(current);
                current = current.left;
            } else {
                // Visit the next node in ascending order.
                TreeNode tmp = stack.pop();
                k--;
                if (k == 0) {
                    return tmp.val;
                }
                current = tmp.right;
            }
        }
        return -1; // only reached if k exceeds the number of nodes
    }
}
```
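The same inorder idea can also be written recursively, with a counter that stops the walk once the k-th node has been visited. This is a minimal sketch; the class name RecursiveInorder is hypothetical, and TreeNode is redefined here only so the block compiles on its own. It uses O(h) call-stack space instead of an explicit stack.

```java
// Minimal TreeNode, mirroring the LeetCode-style definition above.
class TreeNode {
    int val;
    TreeNode left;
    TreeNode right;
    TreeNode(int x) { val = x; }
}

// Hypothetical helper class: recursive inorder traversal with early exit.
class RecursiveInorder {
    private int remaining; // nodes still to visit before reaching the answer
    private int result;

    public int kthSmallest(TreeNode root, int k) {
        remaining = k;
        result = -1; // stays -1 if k exceeds the number of nodes
        inorder(root);
        return result;
    }

    // Visit left subtree, then the node, then the right subtree;
    // stop as soon as the k-th node has been found.
    private void inorder(TreeNode node) {
        if (node == null || remaining == 0) return;
        inorder(node.left);
        if (remaining == 0) return;
        remaining--;
        if (remaining == 0) {
            result = node.val;
            return;
        }
        inorder(node.right);
    }
}
```

For the tree 5(3(2(1), 4), 6), kthSmallest(root, 3) returns 3.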
Solution 2 -- Augmented Tree
The idea is to maintain the rank of each node. While building the tree, we can keep track of how many elements are in the subtree of every node. Since we need the K-th smallest element, it is enough to store, in every node, the number of elements in its left subtree.
Assume the root has N nodes in its left subtree. If K = N + 1, the root is the K-th node. If K ≤ N, we continue the search (recursively) for the K-th smallest element in the left subtree of the root. If K > N + 1, we continue the search in the right subtree for the (K - N - 1)-th smallest element. Note that we only need the count of elements in the left subtree.
Time complexity: O(h), where h is the height of the tree.
(reference: GeeksforGeeks)
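A minimal sketch of the left-count variant described above, assuming distinct keys and a plain (unbalanced) BST. RankNode and RankBST are hypothetical names. The count is maintained while building the tree: every node whose left subtree receives the new key gets its leftCount incremented. The query then follows the three cases (K = N + 1, K ≤ N, K > N + 1) iteratively.

```java
// Hypothetical node type: stores how many nodes are in its left subtree.
class RankNode {
    int val;
    int leftCount; // number of nodes in the left subtree
    RankNode left, right;
    RankNode(int val) { this.val = val; }
}

// Hypothetical BST that keeps leftCount up to date on insertion.
class RankBST {
    private RankNode root;

    // Insert a key (assumed distinct), bumping leftCount on every node
    // whose left subtree receives the new key.
    public void insert(int val) {
        RankNode node = new RankNode(val);
        if (root == null) { root = node; return; }
        RankNode cur = root;
        while (true) {
            if (val < cur.val) {
                cur.leftCount++;
                if (cur.left == null) { cur.left = node; return; }
                cur = cur.left;
            } else {
                if (cur.right == null) { cur.right = node; return; }
                cur = cur.right;
            }
        }
    }

    // K-th smallest (1-based) in O(h), following the three cases above.
    public int kthSmallest(int k) {
        RankNode cur = root;
        while (cur != null) {
            int n = cur.leftCount;
            if (k == n + 1) return cur.val;   // current node is the K-th
            if (k <= n) {
                cur = cur.left;               // answer lies in the left subtree
            } else {
                k -= n + 1;                   // skip left subtree and current node
                cur = cur.right;
            }
        }
        return -1; // k larger than the number of nodes
    }
}
```

Inserting 5, 3, 6, 2, 4, 1 and calling kthSmallest(4) returns 4.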
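Here, we construct the tree the way it is usually taught in an algorithms class.
"size" is an attribute that holds the number of nodes in the subtree rooted at that node.
Time complexity: constructing the tree takes O(n); finding the K-th smallest number takes O(h). Because kthSmallest below rebuilds the augmented tree on every call, each call still costs O(n) overall; the O(h) query bound only pays off when the augmented tree is kept and updated across insertions and deletions.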
In pseudocode, where root.leftElements is the number of nodes in the left subtree of root:

```
start:
    if K == root.leftElements + 1:
        root is the K-th node
        goto stop
    else if K > root.leftElements:
        K = K - (root.leftElements + 1)
        root = root.right
        goto start
    else:
        root = root.left
        goto start
stop
```
```java
/**
 * Definition for a binary tree node.
 * public class TreeNode {
 *     int val;
 *     TreeNode left;
 *     TreeNode right;
 *     TreeNode(int x) { val = x; }
 * }
 */
class ImprovedTreeNode {
    int val;
    int size; // number of nodes in the subtree rooted at this node
    ImprovedTreeNode left;
    ImprovedTreeNode right;
    public ImprovedTreeNode(int value) { val = value; }
}

public class Solution {

    // Construct the augmented tree recursively (post-order),
    // computing each node's size from its children.
    public ImprovedTreeNode createAugmentedBST(TreeNode root) {
        if (root == null)
            return null;
        ImprovedTreeNode newHead = new ImprovedTreeNode(root.val);
        ImprovedTreeNode left = createAugmentedBST(root.left);
        ImprovedTreeNode right = createAugmentedBST(root.right);
        newHead.size = 1;
        if (left != null)
            newHead.size += left.size;
        if (right != null)
            newHead.size += right.size;
        newHead.left = left;
        newHead.right = right;
        return newHead;
    }

    public int findKthSmallest(ImprovedTreeNode root, int k) {
        if (root == null)
            return -1;
        int leftSize = (root.left != null) ? root.left.size : 0;
        if (leftSize + 1 == k)
            return root.val;                                   // root is the k-th node
        else if (leftSize + 1 > k)
            return findKthSmallest(root.left, k);              // answer is in the left subtree
        else
            return findKthSmallest(root.right, k - leftSize - 1); // skip left subtree and root
    }

    public int kthSmallest(TreeNode root, int k) {
        if (root == null)
            return -1;
        ImprovedTreeNode newRoot = createAugmentedBST(root);
        return findKthSmallest(newRoot, k);
    }
}
```
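For the follow-up, the augmented tree has to be kept alive and its size fields updated on every insert and delete, instead of being rebuilt per call. The sketch below is one way to do that, assuming distinct keys and an unbalanced BST; SizedNode and DynamicOrderStatisticBST are hypothetical names. Each operation recomputes size along the path back to the root, so insert, delete and kthSmallest all run in O(h). Guaranteeing h = O(log n) would additionally need a self-balancing tree (for example a red-black tree), which is not shown.

```java
// Hypothetical node type mirroring ImprovedTreeNode: size = nodes in this subtree.
class SizedNode {
    int val;
    int size = 1;
    SizedNode left, right;
    SizedNode(int val) { this.val = val; }
}

// Hypothetical BST that keeps subtree sizes correct under insert and delete.
class DynamicOrderStatisticBST {
    private SizedNode root;

    private static int size(SizedNode n) { return n == null ? 0 : n.size; }

    // Recompute a node's size from its children after a structural change.
    private static SizedNode update(SizedNode n) {
        n.size = 1 + size(n.left) + size(n.right);
        return n;
    }

    public void insert(int val) { root = insert(root, val); }

    private SizedNode insert(SizedNode n, int val) {
        if (n == null) return new SizedNode(val);
        if (val < n.val) n.left = insert(n.left, val);
        else if (val > n.val) n.right = insert(n.right, val);
        // equal keys are ignored (keys assumed distinct)
        return update(n);
    }

    public void delete(int val) { root = delete(root, val); }

    private SizedNode delete(SizedNode n, int val) {
        if (n == null) return null;                  // key not present
        if (val < n.val) n.left = delete(n.left, val);
        else if (val > n.val) n.right = delete(n.right, val);
        else {
            if (n.left == null) return n.right;      // zero or one child
            if (n.right == null) return n.left;
            // Two children: copy the inorder successor's value, then delete it.
            SizedNode succ = n.right;
            while (succ.left != null) succ = succ.left;
            n.val = succ.val;
            n.right = delete(n.right, succ.val);
        }
        return update(n);
    }

    // Iterative form of the pseudocode above, using left.size as leftElements.
    public int kthSmallest(int k) {
        SizedNode cur = root;
        while (cur != null) {
            int leftSize = size(cur.left);
            if (k == leftSize + 1) return cur.val;
            if (k <= leftSize) {
                cur = cur.left;
            } else {
                k -= leftSize + 1;
                cur = cur.right;
            }
        }
        return -1; // k larger than the number of nodes
    }
}
```

After inserting 5, 3, 6, 2, 4, 1 and then deleting 3, kthSmallest(3) returns 4.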