Delete Nodes And Return Forest

By | November 28, 2021

Problem

Given a binary tree and a list of node values, remove every node whose value appears in the list and return the resulting forest. A forest is a disjoint set of trees. Print each tree in the forest using its in-order traversal.

Sample Input

arr = {10, 5} 

      10
      /  \
    20    30
   /  \     \
  4    5     7

Sample Output

4 20
30 7

Explanation

Once 10 and 5 are removed, we are left with two trees:

   20      and    30
  /                 \
 4                   7

Approach

We can observe that if the parent of a node is present in the input array (i.e. the parent is deleted) while the node itself is kept, then this node becomes the root of a new tree in the forest. We apply the same check to every node and start a new tree whenever this condition holds.

Below is the algorithm to solve this problem:

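  1. Perform a postorder traversal of the binary tree, so that each node is processed after both of its children.
  2. Check each node to see whether its value is one of the values to be removed.
  3. If it is, save each of its non-null children as the root of a tree in the forest, and detach the node from its parent.
  4. If the original root was not removed, add it as a root as well. Then print the in-order traversal of every tree in the forest, starting from the roots collected in the previous steps.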

Complexity Analysis

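The time complexity is O(N), where N is the number of nodes, because every node is visited a constant number of times. The space complexity is also O(N): the recursion stack is O(H) for a tree of height H, which is N in the worst case of a skewed tree, plus O(K) for the lookup table of the K values to delete.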

Vamware

C++ Programming

#include <bits/stdc++.h>
using namespace std;

// Lookup table of the keys to delete. Only key presence matters: deleteNode()
// tests membership with find(), and the bool stored by solve() is never read.
unordered_map<int, bool> mp{};

struct TreeNode {
    int key;
    struct TreeNode *left, *right;
};

// Heap-allocates a leaf node carrying `key` (both child pointers NULL) and
// returns it. Nothing in this program deletes the node afterwards.
TreeNode* newNode(int key)
{
    TreeNode* node = new TreeNode;
    node->left = NULL;
    node->right = NULL;
    node->key = key;
    return node;
}

// Returns true when `val` is one of the keys recorded in `mp`, i.e. the node
// carrying it has to be removed from the tree.
bool deleteNode(int val)
{
    return mp.count(val) > 0;
}

// Prunes the subtree rooted at `root` in post-order.
//
// Every node whose key is marked for deletion is unlinked from its parent. Its
// surviving (non-NULL) children are appended to `ans`, left before right, as
// roots of new trees. The removed node's memory is not released.
//
// Returns the new root of this subtree, or NULL when `root` was NULL or was
// itself deleted.
TreeNode* PruneTree(TreeNode* root, vector<TreeNode*>& ans)
{
    if (!root)
        return NULL;

    // Children first, so a node is only examined once its subtrees are pruned.
    root->left = PruneTree(root->left, ans);
    root->right = PruneTree(root->right, ans);

    if (!deleteNode(root->key))
        return root;

    // This node disappears; each child that is still present starts a new tree.
    for (TreeNode* child : { root->left, root->right }) {
        if (child)
            ans.push_back(child);
    }
    return NULL;
}

// Prints the keys of the tree rooted at `root` in in-order (left subtree,
// node, right subtree), each key followed by one space. A NULL tree prints
// nothing.
void inOrder(TreeNode* root)
{
    if (root != NULL) {
        inOrder(root->left);
        cout << root->key << " ";
        inOrder(root->right);
    }
}

// Removes every node whose key is among arr[0..n-1] from the tree rooted at
// `root`. Then prints the in-order traversal of each tree in the resulting
// forest, one tree per line.
//
// `mp` is a global, so it is cleared first. Without that, keys passed to an
// earlier call would still be treated as deleted here.
void solve(TreeNode* root, int arr[], int n)
{
    mp.clear();
    for (int i = 0; i < n; i++) {
        mp[arr[i]] = true;
    }

    vector<TreeNode*> ans;

    // PruneTree collects the subtrees split off by deleted nodes. The original
    // root is a forest root only if it was not deleted itself.
    if (PruneTree(root, ans))
        ans.push_back(root);

    // Print the in-order traversal of every tree in the forest.
    for (TreeNode* tree : ans) {
        inOrder(tree);
        cout << "\n";
    }
}

/*
 * Demo driver: builds the tree below, deletes the node whose key is 1, and
 * prints the in-order traversal of every remaining tree.
 *
 * The diagram lives in a block comment on purpose. A `//` comment whose last
 * character is a backslash is spliced with the following line by the
 * preprocessor, which silently turns that next line into comment text.
 *
 *         1
 *       /   \
 *      2     3
 *           /  \
 *          4    5
 */
int main()
{
    TreeNode* root = newNode(1);
    root->left = newNode(2);
    root->right = newNode(3);
    root->right->left = newNode(4);
    root->right->right = newNode(5);

    int arr[] = { 1 };
    int n = sizeof(arr) / sizeof(arr[0]);
    solve(root, arr, n);
}

Output

2 
4 3 5

Java Programming

import java.util.*;

class Solution{

static HashMap<Integer,Boolean> mp = new HashMap<>();

// Respresents each node
static class TreeNode
{
    int val;
    TreeNode left, right;
};

// Function to create a new node
static TreeNode newNode(int val)
{
    TreeNode temp = new TreeNode();
    temp.val = val;
    temp.left = temp.right = null;
    return (temp);
}

// Function to check whether the node
// needs to be deleted or not
static boolean deleteNode(int nodeVal)
{
    return mp.containsKey(nodeVal);
}

// Function to perform tree pruning
static TreeNode PruneTree(TreeNode root, Vector<TreeNode> result)
{
    if (root == null)
        return null;

    root.left = PruneTree(root.left, result);
    root.right = PruneTree(root.right, result);

    // If the node needs to be deleted
    if (deleteNode(root.val))
    {
        // Store the its subtree
        if (root.left != null)
        {
            result.add(root.left);
        }

        if (root.right != null)
        {
            result.add(root.right);
        }
        return null;
    }
    return root;
}

// Perform Inorder Traversal
static void inOrder(TreeNode root)
{
    if (root == null)
        return;

    inOrder(root.left);
    System.out.print(root.val + " ");
    inOrder(root.right);
}

// Function to print the forests
static void solve(TreeNode root, int arr[], int n)

{
    for (int i = 0; i < n; i++)
    {
        mp.put(arr[i], true);
    }

    Vector<TreeNode> result = new Vector<>();

    if (PruneTree(root, result) != null)
        result.add(root);

    // Print the inorder traversal of each tree
    for (int i = 0; i < result.size(); i++)
    {
        inOrder(result.get(i));
        System.out.println();
    }
}

public static void main(String[] args)
{
/*
         1
       /   \
      2     3
           / \
          4   5
*/
    TreeNode root = newNode(1);
    root.left = newNode(2);
    root.right = newNode(3);
    root.right.left = newNode(4);
    root.right.right = newNode(5);

    int arr[] = { 1 };
    int n = arr.length;
    solve(root, arr, n);

}
}

Output

2 
4 3 5

Python Programming

# Module-level deletion table: keys are the node values to remove. The stored
# values are never read; deleteNode() only checks key membership.
mp = {}

class TreeNode:
    """Binary-tree node holding a value and optional left/right children."""

    def __init__(self, val):
        self.val = val
        # No children yet; callers attach subtrees by assigning to these directly.
        self.left = self.right = None

def newNode(val):
    """Return a fresh leaf ``TreeNode`` holding ``val``."""
    return TreeNode(val)

def deleteNode(nodeVal):
    """Return True if ``nodeVal`` is one of the values marked for deletion in ``mp``."""
    return nodeVal in mp

def PruneTree(root, result):
    """Prune the subtree rooted at ``root`` in post-order.

    Every node whose value is marked for deletion is unlinked from its parent.
    Its surviving children are appended to ``result``, left before right, as
    roots of new trees.

    Returns the new root of this subtree, or None if ``root`` was None or was
    itself deleted.
    """
    if root is None:
        return None

    # Children first, so a node is only examined once its subtrees are pruned.
    root.left = PruneTree(root.left, result)
    root.right = PruneTree(root.right, result)

    if not deleteNode(root.val):
        return root

    # The node is removed; each child that is still present starts a new tree.
    for child in (root.left, root.right):
        if child:
            result.append(child)
    return None

def inOrder(root):
    """Print the tree's values in in-order (left, node, right), each followed by a space.

    An empty tree (``None``) prints nothing.
    """
    if root is not None:
        inOrder(root.left)
        print(root.val, end=' ')
        inOrder(root.right)

def solve(root, arr, n):
    """Delete the nodes whose values are among ``arr[0..n-1]`` and print the forest.

    The tree is modified in place by ``PruneTree``. Each remaining tree is then
    printed on its own line using its in-order traversal. Subtrees split off by
    deleted nodes come first; the original root, if kept, comes last.

    Args:
        root: root of the tree to prune (may be None).
        arr: sequence of values to delete.
        n: number of leading entries of ``arr`` to use.
    """
    # ``mp`` is module-level: clear it so values from an earlier call are not
    # still treated as deleted in this one.
    mp.clear()
    for i in range(n):
        mp[arr[i]] = True

    # Roots of the trees that make up the forest.
    result = []

    # The original root survives only if it was not deleted itself.
    if PruneTree(root, result):
        result.append(root)

    # Print the in-order traversal of every tree, one per line.
    for tree in result:
        inOrder(tree)
        print()
        
"""
        1
       /   \
      2     3
           / \
          4   5
"""
# Demo: build the tree drawn above, delete the node with value 1 and print
# the in-order traversal of every remaining tree.
root = newNode(1)
root.left, root.right = newNode(2), newNode(3)
root.right.left, root.right.right = newNode(4), newNode(5)
arr = [1]
n = len(arr)
solve(root, arr, n)

Output

2 
4 3 5

People are also reading:

Author: Vinay

I am a Full Stack Developer with a Bachelor's Degree in Computer Science, who also loves to write technical articles that can help fellow developers.

Leave a Reply

Your email address will not be published. Required fields are marked *