Find median in binary search tree

Natali Ayoub · May 1, 2015 · Viewed 15.6k times

Write the implementation of the function T ComputeMedian() const that computes the median value in the tree in O(n) time. Assume that the tree is a BST but is not necessarily balanced. Recall that the median x of n numbers is defined as follows: if n is odd, the number of values smaller than x equals the number of values greater than x; if n is even, one plus the number of values smaller than x equals the number of values greater than x. For example, given the numbers 8, 7, 2, 5, 9, the median is 7, because there are two values smaller than 7 and two values larger than 7. If we add the number 3 to the set, the median becomes 5.
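For reference, the definition above picks the element at 0-based index (n-1)/2 of the sorted values (integer division). A minimal sketch that checks this on a plain vector follows; `MedianBySorting` is a hypothetical helper, not part of the assignment, and it runs in O(n log n), so it only illustrates the definition rather than meeting the O(n) requirement.

```cpp
#include <algorithm>
#include <cstddef>
#include <vector>

// Hypothetical helper (not part of the assignment): returns the median as
// defined in the question. Assumes `values` is non-empty.
// The vector is taken by value so the caller's data is left untouched.
template <class T>
T MedianBySorting(std::vector<T> values)
{
    std::sort(values.begin(), values.end());
    // 0-based index (n - 1) / 2: the middle element for odd n,
    // the lower of the two middle elements for even n.
    return values[(values.size() - 1) / 2];
}
```

Called on the two sets from the example, this returns 7 for {8, 7, 2, 5, 9} and 5 once 3 is added.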

Here is the binary search tree node class:

template <class T>
class BSTNode
{
public:
    BSTNode(T& val, BSTNode* left, BSTNode* right);
    ~BSTNode();
    T GetVal();
    BSTNode* GetLeft();
    BSTNode* GetRight();

private:
    T val;
    BSTNode* left;
    BSTNode* right;
    BSTNode* parent; //ONLY INSERT IS READY TO UPDATE THIS MEMBER DATA
    int depth, height;
    friend class BST<T>;
};

Binary search tree class:

template <class T>
class BST
{
public:
    BST();
    ~BST();

    bool Search(T& val);
    bool Search(T& val, BSTNode<T>* node);
    void Insert(T& val);
    bool DeleteNode(T& val);

    void BFT(void);
    void PreorderDFT(void);
    void PreorderDFT(BSTNode<T>* node);
    void PostorderDFT(BSTNode<T>* node);
    void InorderDFT(BSTNode<T>* node);
    void ComputeNodeDepths(void);
    void ComputeNodeHeights(void);
    bool IsEmpty(void);
    void Visit(BSTNode<T>* node);
    void Clear(void);

private:
    BSTNode<T> *root;
    int depth;
    int count;
    BSTNode<T> *med; // I've added this member data.

    void DelSingle(BSTNode<T>*& ptr);
    void DelDoubleByCopying(BSTNode<T>* node);
    void ComputeDepth(BSTNode<T>* node, BSTNode<T>* parent);
    void ComputeHeight(BSTNode<T>* node);
    void Clear(BSTNode<T>* node);
};

I know I should count the nodes of the tree first and then do an inorder traversal until I reach (n/2)th node and return it. I just have no clue how.

Answer

amit · May 1, 2015

As you mentioned, it is fairly easy to first find the number of nodes, doing any traversal:

findNumNodes(node):
   if node == null:
       return 0
   return findNumNodes(node.left) + findNumNodes(node.right) + 1
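
In C++ the counting step could look roughly like the sketch below. `Node` is a hypothetical minimal stand-in for the question's `BSTNode<T>`, and `CountNodes` is an illustrative name, not part of the given class. If the `count` member of `BST` is already kept up to date by Insert and DeleteNode (an assumption the question does not confirm), this pass can be skipped.

```cpp
#include <cstddef>

// Hypothetical minimal node, standing in for BSTNode<T> from the question.
template <class T>
struct Node {
    T val;
    Node* left;
    Node* right;
    explicit Node(const T& v) : val(v), left(nullptr), right(nullptr) {}
};

// Counts the nodes in the subtree rooted at `node`.
// Every node is visited exactly once, so the running time is O(n).
template <class T>
std::size_t CountNodes(const Node<T>* node)
{
    if (node == nullptr) return 0;
    return CountNodes(node->left) + CountNodes(node->right) + 1;
}
```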

Then do an in-order traversal that stops at the node with 0-based index (n-1)/2. With integer division this is the middle node when n is odd and the lower of the two middle nodes when n is even, which matches the definition in the question (using n/2 would return 7 instead of 5 for the six-value example):

// n is the node count computed above
// index is a global variable / class variable, or any other variable that is shared between all calls
index = 0
findMedian(node):
   if node == null:
       return null
   cand = findMedian(node.left)
   if cand != null:
       return cand
   if index == (n-1)/2:
       return node
   index = index + 1
   return findMedian(node.right)

The idea is that an in-order traversal of a BST visits the nodes in sorted order, so the ith node you process is the ith smallest. This holds in particular for i == (n-1)/2, so you return the node as soon as its index reaches that value.
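
Putting the two passes together, a C++ version might look like the following sketch. It uses a hypothetical minimal `Node` instead of the question's `BSTNode<T>`, free functions instead of the `T ComputeMedian() const` member, and a counter passed by reference in place of a global or class variable; `FindByRank` is an illustrative name. The target rank is (n-1)/2, which follows the median definition in the question.

```cpp
#include <cstddef>

// Hypothetical minimal node, standing in for BSTNode<T> from the question.
template <class T>
struct Node {
    T val;
    Node* left;
    Node* right;
    explicit Node(const T& v) : val(v), left(nullptr), right(nullptr) {}
};

// First pass: count the nodes, O(n).
template <class T>
std::size_t CountNodes(const Node<T>* node)
{
    if (node == nullptr) return 0;
    return CountNodes(node->left) + CountNodes(node->right) + 1;
}

// Second pass: in-order walk that stops at the node whose 0-based rank
// equals `target`. `index` is passed by reference so that all recursive
// calls share a single counter.
template <class T>
const Node<T>* FindByRank(const Node<T>* node, std::size_t& index, std::size_t target)
{
    if (node == nullptr) return nullptr;
    if (const Node<T>* found = FindByRank(node->left, index, target)) return found;
    if (index == target) return node;
    ++index;
    return FindByRank(node->right, index, target);
}

// Returns the median value. Assumes the tree is non-empty.
template <class T>
T ComputeMedian(const Node<T>* root)
{
    const std::size_t n = CountNodes(root);
    std::size_t index = 0;
    return FindByRank(root, index, (n - 1) / 2)->val;
}
```

Both passes touch each node at most once, so the total time is O(n). The recursion depth equals the tree height, which can reach n for a degenerate (list-like) tree.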


As a side note, you can extend the BST to find the ith element in O(h) time, where h is the tree's height, by turning it into an order statistic tree, in which every node also stores the size of its subtree.
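
A rough sketch of the lookup in such a tree is shown below. `SizedNode`, `SizeOf` and `Select` are hypothetical names, and the sketch assumes that insertion and deletion keep each node's `size` field correct, which the question's class does not currently do.

```cpp
#include <cstddef>

// Hypothetical node that additionally stores the size of its subtree.
template <class T>
struct SizedNode {
    T val;
    std::size_t size;   // number of nodes in the subtree rooted here
    SizedNode* left;
    SizedNode* right;
    explicit SizedNode(const T& v) : val(v), size(1), left(nullptr), right(nullptr) {}
};

// Size of a possibly empty subtree.
template <class T>
std::size_t SizeOf(const SizedNode<T>* node)
{
    return node ? node->size : 0;
}

// Returns the node with 0-based rank k in sorted order,
// or nullptr if k is out of range. Follows one root-to-leaf path, so O(h).
template <class T>
const SizedNode<T>* Select(const SizedNode<T>* node, std::size_t k)
{
    while (node != nullptr) {
        const std::size_t leftSize = SizeOf(node->left);
        if (k < leftSize) {
            node = node->left;       // target lies in the left subtree
        } else if (k == leftSize) {
            return node;             // exactly leftSize values are smaller
        } else {
            k -= leftSize + 1;       // skip the left subtree and this node
            node = node->right;
        }
    }
    return nullptr;
}
```

With this in place, the median under the question's definition would be `Select(root, (root->size - 1) / 2)` for a non-empty tree.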