Find K-th Smallest Element in BST

Posted in

Vinay Khatri
Last updated on September 17, 2024

Problem

Given a Binary Search Tree (BST) and an integer k, find the k-th smallest element in the BST.

Sample Input

A BST (the tree diagram is not reproduced here) and `K = 2`

Sample Output

`10`

Approach

An inorder traversal of a BST visits the nodes in sorted (ascending) order. Therefore, we can perform an inorder traversal of the given tree while counting the nodes visited so far; the node at which this count reaches k is the k-th smallest element.

Complexity Analysis

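The time complexity is O(N) in the worst case. More precisely, it is O(H + k), where H is the height of the tree, because the traversal stops as soon as the k-th node is found. The space complexity is O(H) for the recursion stack, which becomes O(N) for a completely skewed tree.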

C++ Programming

```#include <iostream>
using namespace std;

// A BST node: an integer key plus links to the left and right children
// (null when the child is absent).
struct TreeNode {
    int data;
    TreeNode* left;
    TreeNode* right;

    // Creates a leaf node holding `val`.
    TreeNode(int val) : data(val), left(nullptr), right(nullptr) {}
};

// Inserts `val` into the BST rooted at `root` and returns the (possibly new)
// root. A key that is already present is ignored, so the tree never holds
// duplicates.
TreeNode* create(TreeNode* root, int val)
{
    if (root == nullptr)
        return new TreeNode(val);

    // Walk down from the root until the empty child slot where `val` belongs.
    TreeNode* cursor = root;
    for (;;) {
        if (val == cursor->data)
            return root; // duplicate key: leave the tree unchanged

        TreeNode*& slot = (val < cursor->data) ? cursor->left : cursor->right;
        if (slot == nullptr) {
            slot = new TreeNode(val);
            return root;
        }
        cursor = slot;
    }
}

// In-order search for the k-th smallest node of the BST rooted at `root`.
// `k` is a countdown shared by every recursive call: it is decremented once
// per node visited in order, and the node that brings it to zero is returned.
// Returns null if the tree has fewer than k nodes (or k < 1 on entry).
TreeNode* dfs(TreeNode* root, int& k)
{
    if (root == nullptr)
        return nullptr;

    // Every key in the left subtree is smaller, so exhaust it first.
    if (TreeNode* hit = dfs(root->left, k))
        return hit;

    // Visit the current node.
    if (--k == 0)
        return root;

    // Only larger keys remain.
    return dfs(root->right, k);
}

// Prints the k-th smallest key (1-based) of the BST rooted at `root`.
// When k is out of range (k < 1 or larger than the number of nodes) dfs
// returns null; a message is printed instead of dereferencing it.
void solve(TreeNode* root, int k)
{
    // dfs counts down this function's local copy of k, one step per visited node.
    TreeNode* res = dfs(root, k);
    if (res == nullptr) {
        cout << "K-th Smallest Element does not exist\n";
        return;
    }
    cout << "K-th Smallest Element is " << res->data << '\n';
}

int main()
{
TreeNode* root = NULL;

int keys[] = { 1, 2, 3, 4, 5 };

for (int val : keys)
root = create(root, val);

int k = 3;

solve(root, k);

return 0;
}```

Output

```K-th Smallest Element is 3
```

Java Programming

```import java.io.*;

class TreeNode {
    // Key stored in this node.
    int val;
    // Child links; null when the child is absent.
    TreeNode left;
    TreeNode right;

    // Creates a leaf node holding the key itr.
    TreeNode(int itr) {
        this.val = itr;
        this.left = null;
        this.right = null;
    }
}

class Solution {
    // Number of nodes dfs has visited so far; solve() resets it before each
    // search. Because it is static, dfs is not reentrant and must not be run
    // concurrently from several threads.
    static int count = 0;

    // Recursive function to insert a key into the BST rooted at root.
    // Returns the (possibly new) root; keys already present are ignored.
    public static TreeNode create(TreeNode root, int itr)
    {
        if (root == null)
            return new TreeNode(itr);

        if (itr < root.val)
            root.left = create(root.left, itr);
        else if (itr > root.val)
            root.right = create(root.right, itr);

        return root;
    }

    // In-order search for the k-th smallest node (1-based). Each visited node
    // increments count; the node at which count reaches k is returned.
    // Returns null when the tree has fewer than k nodes or k < 1.
    // Relies on count having been reset to 0 before the top-level call.
    public static TreeNode dfs(TreeNode root, int k)
    {
        // base case
        if (root == null)
            return null;

        // search in left subtree: all of its keys are smaller than root's
        TreeNode left = dfs(root.left, k);
        if (left != null)
            return left;

        // visit the current node
        count++;
        if (count == k)
            return root;

        // search in right subtree
        return dfs(root.right, k);
    }

    // Prints the k-th smallest element of the BST. When k is out of range
    // (k < 1 or larger than the number of nodes) a message is printed
    // instead of dereferencing a null result.
    public static void solve(TreeNode root, int k)
    {
        count = 0;

        TreeNode ans = dfs(root, k);
        if (ans == null) {
            System.out.println("K-th smallest element does not exist");
            return;
        }
        System.out.println("K-th smallest element is " + ans.val);
    }

    public static void main(String[] args) {
        TreeNode root = null;

        int[] nodes = {1, 2, 3, 4};
        for (int itr : nodes)
            root = create(root, itr);

        int k = 3;
        solve(root, k);
    }
}
```

Output

```K-th smallest element is 3
```

Python Programming

```class TreeNode:
def __init__(self, node):
self.data = node
self.left = None
self.right = None

# Recursive function to insert a key into the BST
def create(root, itr):
    """Insert ``itr`` into the BST rooted at ``root`` and return the root.

    A new ``TreeNode`` is returned when ``root`` is None. Keys that are
    already present are ignored, so the tree never stores duplicates.
    """
    if root is None:
        return TreeNode(itr)

    if itr < root.data:
        root.left = create(root.left, itr)
    elif itr > root.data:
        root.right = create(root.right, itr)
    return root

def dfs(root):
    """Return the k-th smallest node of the BST rooted at ``root``.

    The module-level ``k`` is used as a countdown: it is decremented once per
    node visited in order, and the node that brings it to zero is returned.
    Returns None if the tree has fewer than ``k`` nodes (or ``k < 1``).
    Note that ``k`` is left modified after the call.
    """
    global k

    if root is None:
        return None

    # Search in left subtree: every key there is smaller than root's
    left = dfs(root.left)
    if left is not None:
        return left

    # Visit the current node
    k -= 1
    if k == 0:
        return root

    # Search in right subtree
    return dfs(root.right)

def solve(root):
    """Print the k-th smallest element of the BST, using the module-level ``k``.

    When ``k`` is out of range, dfs returns None; a message is printed
    instead of raising AttributeError on ``None.data``.
    """
    res = dfs(root)
    if res is None:
        print("K-th smallest element does not exist")
        return
    print("K-th smallest element is ", res.data)

root = None

nodes = [ 1, 2, 3, 4 ]

for itr in nodes:
root = create(root, itr)

k = 3

solve(root)```

Output

`K-th smallest element is  3`

People are also reading: