
Description
Given the root of a binary search tree (BST) and an integer k, find the k-th smallest value in the tree, counting from k = 1. Because every node's left subtree holds smaller values and its right subtree holds larger ones, the nodes can be visited in ascending order without any sorting. The goal is to return the k-th value in that ascending order.
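To pin down what "k-th smallest" means, here is a brute-force reference: collect every value, sort, and index with k - 1. It is a sketch that assumes a minimal `TreeNode` class with `val`, `left`, and `right` fields (the shape the solutions below rely on); `collect_values` and `kth_smallest_brute_force` are hypothetical helper names. It ignores the BST ordering and costs O(n log n), so it is only useful as a correctness baseline.

```python
from typing import List, Optional


class TreeNode:
    """Minimal binary tree node (assumed shape: val, left, right)."""

    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def collect_values(node: Optional[TreeNode]) -> List[int]:
    """Gather every value in the tree, in no particular order."""
    if node is None:
        return []
    return [node.val] + collect_values(node.left) + collect_values(node.right)


def kth_smallest_brute_force(root: Optional[TreeNode], k: int) -> int:
    """Return the k-th smallest value (k counts from 1), or -1 if k is out of range."""
    values = sorted(collect_values(root))
    if 1 <= k <= len(values):
        return values[k - 1]
    return -1


# Example BST:   3
#               / \
#              1   4
#               \
#                2
root = TreeNode(3, TreeNode(1, None, TreeNode(2)), TreeNode(4))
print(kth_smallest_brute_force(root, 1))  # 1
print(kth_smallest_brute_force(root, 3))  # 3
```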
Intuition
An in-order traversal of a BST visits nodes in ascending order. Thus, the k-th smallest element corresponds to the k-th node visited during an in-order traversal. Performing that traversal iteratively with an explicit stack lets us stop as soon as the k-th node is visited, instead of traversing the whole tree.
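The ascending-order property can be checked directly with a recursive in-order generator. This is a sketch of a swapped-in recursive variant, not the stack-based method used below: it relies on Python recursion (so very deep, skewed trees can hit the default recursion limit) and uses `itertools.islice` to stop pulling values after the k-th one. The `TreeNode` class and the `inorder` / `kth_smallest_recursive` names are illustrative assumptions.

```python
from itertools import islice
from typing import Iterator, Optional


class TreeNode:
    """Minimal binary tree node (assumed shape: val, left, right)."""

    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def inorder(node: Optional[TreeNode]) -> Iterator[int]:
    """Yield left subtree, then the node, then right subtree (ascending in a BST)."""
    if node is None:
        return
    yield from inorder(node.left)
    yield node.val
    yield from inorder(node.right)


def kth_smallest_recursive(root: Optional[TreeNode], k: int) -> int:
    """Take only the first k in-order values; return -1 if k is out of range."""
    if k < 1:
        return -1  # islice rejects negative stops, and k = 0 has no answer
    taken = list(islice(inorder(root), k))
    return taken[-1] if len(taken) == k else -1


#          5
#         / \
#        3   6
#       / \
#      2   4
#     /
#    1
root = TreeNode(5, TreeNode(3, TreeNode(2, TreeNode(1)), TreeNode(4)), TreeNode(6))
print(list(inorder(root)))              # [1, 2, 3, 4, 5, 6]
print(kth_smallest_recursive(root, 3))  # 3
```

Because the generator is lazy, values after the k-th are never produced; the explicit-stack version achieves the same early exit without depending on the recursion limit.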
Approach
- Initialization:
  - Initialize a stack to simulate the in-order traversal iteratively.
  - Use a counter n to keep track of the number of nodes visited.
- In-Order Traversal:
  - Starting from the current node (initially the root), follow left children, pushing each node onto the stack until reaching a null child; the last node pushed is the leftmost node of that subtree.
  - Pop the top node from the stack (the smallest value not yet visited) and increment the counter n.
  - If n equals k, return the value of the current node as the k-th smallest element.
  - Otherwise, move to the popped node's right subtree and repeat.
- Termination:
  - If k is valid (1 ≤ k ≤ number of nodes), the function returns the k-th smallest element from inside the loop.
  - If the loop completes without reaching the k-th node, k was out of range, and the function returns a default value (e.g., -1).
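To make the push/pop sequence concrete, here is the same iterative procedure instrumented to log each step. It is a sketch for illustration: the `kth_smallest_traced` name and the `(action, value)` log format are assumptions, and the traversal logic otherwise mirrors the Python solution in the Code section.

```python
from typing import List, Optional, Tuple


class TreeNode:
    """Minimal binary tree node (assumed shape: val, left, right)."""

    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def kth_smallest_traced(root: Optional[TreeNode], k: int) -> Tuple[int, List[Tuple[str, int]]]:
    """Iterative in-order traversal that also returns a log of stack operations."""
    log: List[Tuple[str, int]] = []
    stack: List[TreeNode] = []
    current = root
    n = 0  # number of nodes visited (popped) so far
    while current or stack:
        # Walk down the left spine, pushing every node on the way.
        while current:
            stack.append(current)
            log.append(("push", current.val))
            current = current.left
        # The top of the stack is the smallest value not yet visited.
        current = stack.pop()
        log.append(("visit", current.val))
        n += 1
        if n == k:
            return current.val, log
        # Continue with the right subtree of the node just visited.
        current = current.right
    return -1, log


#      3
#     / \
#    1   4
#     \
#      2
root = TreeNode(3, TreeNode(1, None, TreeNode(2)), TreeNode(4))
answer, log = kth_smallest_traced(root, 2)
print(answer)  # 2
for step in log:
    print(step)
```

For k = 2 the log reads push 3, push 1, visit 1, push 2, visit 2; node 4 is never touched, which is where the early exit saves work.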
Complexity
Time Complexity:
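O(h + k), where h is the height of the BST. Every pushed node is either one of the k popped nodes or still on the stack when the loop exits (at most h nodes), so at most h + k nodes are touched. In the worst case (a skewed tree, or k equal to the number of nodes) every node is visited once, giving O(n), where n here is the number of nodes in the BST, not the visit counter used in the code.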
Space Complexity:
O(h), where h is the height of the BST. The stack only ever holds nodes from a single root-to-leaf path, so its size is bounded by the height: O(log n) for a balanced BST and O(n) for a completely skewed tree.
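The two bounds can be checked empirically by counting pushes and the peak stack size on a balanced tree and on a skewed one. This is a sketch under assumptions: `traversal_cost`, `build_balanced`, and `build_left_skewed` are hypothetical helpers introduced only for this measurement, and `traversal_cost` reuses the loop from the solutions below with counters added.

```python
from typing import List, Optional, Tuple


class TreeNode:
    """Minimal binary tree node (assumed shape: val, left, right)."""

    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def build_balanced(values: List[int]) -> Optional[TreeNode]:
    """Build a height-balanced BST from sorted values (middle element as root)."""
    if not values:
        return None
    mid = len(values) // 2
    return TreeNode(values[mid], build_balanced(values[:mid]), build_balanced(values[mid + 1:]))


def build_left_skewed(values: List[int]) -> Optional[TreeNode]:
    """Build a BST in which every node has only a left child."""
    root = None
    for v in values:  # ascending input: each larger value becomes the new root
        root = TreeNode(v, root, None)
    return root


def traversal_cost(root: Optional[TreeNode], k: int) -> Tuple[int, int]:
    """Run the iterative k-th-smallest loop; return (pushes, peak stack size)."""
    stack: List[TreeNode] = []
    current = root
    n = pushes = peak = 0
    while current or stack:
        while current:
            stack.append(current)
            pushes += 1
            peak = max(peak, len(stack))
            current = current.left
        current = stack.pop()
        n += 1
        if n == k:
            break
        current = current.right
    return pushes, peak


values = list(range(1, 1024))  # 1023 nodes
print(traversal_cost(build_balanced(values), 1))      # (10, 10)
print(traversal_cost(build_balanced(values), 1023))   # (1023, 10)
print(traversal_cost(build_left_skewed(values), 1))   # (1023, 1023)
```

With k = 1 the balanced tree touches only its 10-node left spine, while the skewed tree pushes all 1023 nodes. With k = 1023 the balanced tree visits every node, yet its stack never exceeds 10 entries, which is the O(h) space bound.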
Code
C++
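#include <stack>

// TreeNode (with val, left, right) is assumed to be defined elsewhere.
class Solution {
public:
    int kthSmallest(TreeNode* root, int k) {
        int n = 0;                  // nodes visited so far
        std::stack<TreeNode*> stk;  // nodes whose left subtree is still being explored
        TreeNode* current = root;
        while (current || !stk.empty()) {
            // Push the left spine of the current subtree.
            while (current) {
                stk.push(current);
                current = current->left;
            }
            // Visit the smallest unvisited node.
            current = stk.top();
            stk.pop();
            n += 1;
            if (n == k) return current->val;
            // Move on to its right subtree.
            current = current->right;
        }
        return -1;  // only reached if k is out of range
    }
};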
Python
from typing import Optional

# TreeNode (with val, left, right) is assumed to be defined elsewhere.
class Solution:
    def kthSmallest(self, root: Optional[TreeNode], k: int) -> int:
        stack = []
        current = root
        n = 0  # nodes visited so far
        while current or stack:
            # Push the left spine of the current subtree.
            while current:
                stack.append(current)
                current = current.left
            # Visit the smallest unvisited node.
            current = stack.pop()
            n += 1
            if n == k:
                return current.val
            # Move on to its right subtree.
            current = current.right
        # Only reached if k is out of range (k < 1 or k > number of nodes).
        return -1
Java
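import java.util.ArrayDeque;
import java.util.Deque;

// TreeNode (with val, left, right) is assumed to be defined elsewhere.
class Solution {
    public int kthSmallest(TreeNode root, int k) {
        // A typed Deque replaces the raw Stack, whose pop() would return Object.
        Deque<TreeNode> stack = new ArrayDeque<>();
        TreeNode current = root;
        int n = 0;  // nodes visited so far
        while (current != null || !stack.isEmpty()) {
            // Push the left spine of the current subtree.
            while (current != null) {
                stack.push(current);
                current = current.left;
            }
            // Visit the smallest unvisited node.
            current = stack.pop();
            n++;
            if (n == k) return current.val;
            // Move on to its right subtree.
            current = current.right;
        }
        return -1;  // only reached if k is out of range
    }
}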
JavaScript
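// TreeNode (with val, left, right) is assumed to be defined elsewhere.
var kthSmallest = function(root, k) {
    let n = 0;  // nodes visited so far
    const stack = [];
    let current = root;
    while (current || stack.length > 0) {
        // Push the left spine of the current subtree.
        while (current) {
            stack.push(current);
            current = current.left;
        }
        // Visit the smallest unvisited node.
        current = stack.pop();
        n += 1;
        if (n === k) return current.val;
        // Move on to its right subtree.
        current = current.right;
    }
    return -1;  // only reached if k is out of range, matching the other versions
};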