LeetCode 222. Count Complete Tree Nodes. Google interview question.
Given a complete binary tree, count the number of nodes.
Note:
Definition of a complete binary tree from Wikipedia:
In a complete binary tree every level, except possibly the last, is completely filled, and all nodes in the last level are as far left as possible. It can have between 1 and 2^h nodes inclusive at the last level h.
Example:
Input:

       1
     /   \
    2     3
   / \   /
  4   5 6

Output: 6
Asked recently by Google, Amazon and Microsoft.
The trivial solution is to traverse the tree and count every node, which takes linear time, O(N). Here the interviewer is looking for something faster. Using our runtime-algorithm cheat sheet, we figured the interviewer is likely asking for a logarithmic-style algorithm built on some kind of binary-search-like halving (as it turns out, the solution below runs in O(log²(N))).
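For comparison, here is a minimal sketch of the trivial linear-time count. It assumes a LeetCode-style node with val/left/right fields; the TreeNode class below is our own stand-in, and count_nodes_linear is a hypothetical name.

```python
from typing import Optional


class TreeNode:
    # Minimal node type, assumed to mirror LeetCode's TreeNode.
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def count_nodes_linear(root: Optional[TreeNode]) -> int:
    # Visit every node exactly once: O(N) time, O(height) stack space.
    if root is None:
        return 0
    return 1 + count_nodes_linear(root.left) + count_nodes_linear(root.right)


# The example tree from the problem statement (6 nodes).
example = TreeNode(1,
                   TreeNode(2, TreeNode(4), TreeNode(5)),
                   TreeNode(3, TreeNode(6)))
print(count_nodes_linear(example))  # 6
```

This ignores the "complete" property entirely, which is exactly what the faster approach exploits.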
How many nodes does a completely-filled binary tree have?
A binary tree with every level completely filled has 2^n - 1 nodes, where n is the number of levels. E.g.
       1
     /   \
    2     3
   / \   / \
  4   5 6   7

Number of levels = 3, number of nodes = 2^3 - 1 = 7
Depth 0 has 2^0 nodes, depth 1 has 2^1 nodes, depth 2 has 2^2 nodes, and so on, which forms a geometric series. Applying the summation formula, a tree with n full levels has 2^0 + 2^1 + ... + 2^(n-1) = 2^n - 1 nodes in total.
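A quick numeric check of the geometric-series claim: sum 2^d over each depth d and compare with the closed form 2^n - 1. The function name perfect_tree_nodes is a hypothetical helper for illustration.

```python
def perfect_tree_nodes(levels: int) -> int:
    # Sum the geometric series 2^0 + 2^1 + ... + 2^(levels - 1).
    return sum(2 ** depth for depth in range(levels))


# The series matches the closed form 2^n - 1 for the first few heights.
for n in range(0, 6):
    assert perfect_tree_nodes(n) == 2 ** n - 1

print(perfect_tree_nodes(3))  # 7, the 3-level example tree
```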
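Knowing the tree in question is completely filled except for the last level, the key is to find the boundary between the filled and unfilled parts of the last level. As we traverse the tree, we can use depth as an indicator of whether a subtree is completely filled. Here depth means the number of levels found by following left children from a subtree's root, which in a complete tree equals its height. The depth of the left subtree is always ≥ the depth of the right subtree, and exceeds it by at most 1, because each level is filled from left to right.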
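To make the depth indicator concrete, here is a small sketch that measures depth along the leftmost path and checks, for every complete tree with 1 to 63 nodes, that the left subtree is never shallower than the right and at most one level deeper. The names build_complete and left_spine_depth are hypothetical helpers; build_complete lays nodes out in level order (heap indexing), which always produces a complete tree.

```python
from typing import List, Optional


class TreeNode:
    # Minimal node type, assumed to mirror LeetCode's TreeNode.
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def build_complete(n: int) -> Optional[TreeNode]:
    # Hypothetical test helper: children of index i are 2i+1 and 2i+2,
    # so nodes fill each level from left to right.
    if n == 0:
        return None
    nodes: List[TreeNode] = [TreeNode(i + 1) for i in range(n)]
    for i in range(n):
        if 2 * i + 1 < n:
            nodes[i].left = nodes[2 * i + 1]
        if 2 * i + 2 < n:
            nodes[i].right = nodes[2 * i + 2]
    return nodes[0]


def left_spine_depth(node: Optional[TreeNode]) -> int:
    # In a complete tree the leftmost path is a longest path,
    # so following .left counts the number of levels.
    depth = 0
    while node is not None:
        depth += 1
        node = node.left
    return depth


# Left depth >= right depth, differing by at most 1, for every size tried.
for n in range(1, 64):
    root = build_complete(n)
    left = left_spine_depth(root.left)
    right = left_spine_depth(root.right)
    assert 0 <= left - right <= 1
```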
Case 1: left subtree depth > right subtree depth
level 0:        1
              /   \
level 1:     2     3
            /
level 2:   4
Starting from 1, the left subtree is deeper than the right subtree, which means the last level ends inside the left subtree and the right subtree is completely filled down to level 1. We can calculate the number of nodes in the right subtree directly with the 2^n - 1 formula, then recurse into the left subtree.
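One step of Case 1 on the tree above, as a sketch: the right subtree plus the root are counted by formula, and only the left subtree still needs work. The helper names depth and count_linear are hypothetical; count_linear is a plain traversal used here only to stand in for the recursive call.

```python
class TreeNode:
    # Minimal node type, assumed to mirror LeetCode's TreeNode.
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def depth(node):
    # Levels counted along the leftmost path.
    return 0 if node is None else 1 + depth(node.left)


def count_linear(node):
    # Plain traversal, standing in for the recursive call on one subtree.
    if node is None:
        return 0
    return 1 + count_linear(node.left) + count_linear(node.right)


# Case 1 tree: 1 -> (2 -> 4), 3
root = TreeNode(1, TreeNode(2, TreeNode(4)), TreeNode(3))
left_depth, right_depth = depth(root.left), depth(root.right)  # 2, 1
assert left_depth > right_depth

# Right subtree is perfect with right_depth levels; +1 for the root.
from_formula = (2 ** right_depth - 1) + 1  # nodes 3 and 1
from_left = count_linear(root.left)        # nodes 2 and 4
print(from_formula + from_left)            # 4
```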
Case 2: left subtree depth == right subtree depth
level 0:        1
              /   \
level 1:     2     3
            / \   /
level 2:   4   5 6
In this case, the last level reaches into the right subtree, so the left subtree (rooted at 2) must be completely filled. We can calculate its number of nodes with the 2^n - 1 formula, then recurse into the right subtree.
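The matching step for Case 2, again as a sketch on the tree above: this time the left subtree plus the root come from the formula, and only the right subtree is left to count. The helpers depth and count_linear are hypothetical, with count_linear standing in for the recursive call.

```python
class TreeNode:
    # Minimal node type, assumed to mirror LeetCode's TreeNode.
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def depth(node):
    # Levels counted along the leftmost path.
    return 0 if node is None else 1 + depth(node.left)


def count_linear(node):
    # Plain traversal, standing in for the recursive call on one subtree.
    if node is None:
        return 0
    return 1 + count_linear(node.left) + count_linear(node.right)


# Case 2 tree: 1 -> (2 -> 4, 5), (3 -> 6)
root = TreeNode(1,
                TreeNode(2, TreeNode(4), TreeNode(5)),
                TreeNode(3, TreeNode(6)))
left_depth, right_depth = depth(root.left), depth(root.right)  # 2, 2
assert left_depth == right_depth

# Left subtree is perfect with left_depth levels; +1 for the root.
from_formula = (2 ** left_depth - 1) + 1  # nodes 2, 4, 5 and 1
from_right = count_linear(root.right)     # nodes 3 and 6
print(from_formula + from_right)          # 6
```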
Python implementation:
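class TreeNode:
    # Assumed to match LeetCode's built-in TreeNode definition.
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


class Solution:
    def countNodes(self, root):
        if not root:
            return 0
        leftDepth = self.getDepth(root.left)
        rightDepth = self.getDepth(root.right)
        if leftDepth == rightDepth:
            # Left subtree is perfect: 2^n - 1 nodes + 1 (current node) = 2^n
            return pow(2, leftDepth) + self.countNodes(root.right)
        else:
            # Right subtree is perfect: 2^n - 1 nodes + 1 (current node) = 2^n
            return pow(2, rightDepth) + self.countNodes(root.left)

    def getDepth(self, root):
        # Number of levels, found by following left children only.
        if not root:
            return 0
        return 1 + self.getDepth(root.left)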
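Because each call of countNodes descends into only one child, the same recursion can also be written as a loop. The following is a minimal sketch of that iterative variant (our own rewrite, not the original author's code), cross-checked against every complete tree with 0 to 100 nodes. The names depth, build_complete and count_nodes_iterative are hypothetical.

```python
from typing import List, Optional


class TreeNode:
    # Minimal node type, assumed to mirror LeetCode's TreeNode.
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def build_complete(n: int) -> Optional[TreeNode]:
    # Hypothetical test helper: heap-indexed complete tree with n nodes.
    nodes: List[TreeNode] = [TreeNode(i + 1) for i in range(n)]
    for i in range(n):
        if 2 * i + 1 < n:
            nodes[i].left = nodes[2 * i + 1]
        if 2 * i + 2 < n:
            nodes[i].right = nodes[2 * i + 2]
    return nodes[0] if nodes else None


def depth(node: Optional[TreeNode]) -> int:
    # Levels counted along the leftmost path.
    d = 0
    while node is not None:
        d += 1
        node = node.left
    return d


def count_nodes_iterative(root: Optional[TreeNode]) -> int:
    # Same two cases as the recursive solution, one child per step.
    total = 0
    node = root
    while node is not None:
        left_depth = depth(node.left)
        right_depth = depth(node.right)
        if left_depth == right_depth:
            # Left subtree is perfect: 2^left_depth - 1 nodes + this node.
            total += 1 << left_depth
            node = node.right
        else:
            # Right subtree is perfect: 2^right_depth - 1 nodes + this node.
            total += 1 << right_depth
            node = node.left
    return total


for n in range(0, 101):
    assert count_nodes_iterative(build_complete(n)) == n
```

The loop avoids Python's recursion overhead, though with only about log(N) levels the recursive version is not at risk of hitting the recursion limit either.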
Time Complexity
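Finding a depth takes O(log(N)) because a complete tree has about log(N) levels. Each call to countNodes computes two depths and then recurses into only one child, so there are at most O(log(N)) calls, one per level. The total run time is therefore O(log²(N)), compared with O(N) for the trivial traversal.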
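To see the O(log²(N)) bound in practice, this sketch instruments the same two-case descent and counts how many nodes the depth walks touch. For a perfect tree with h levels the count works out to h(h - 1), so 1023 nodes (10 levels) cost 90 steps instead of 1023 visits. The names build_complete and count_with_cost are hypothetical helpers.

```python
class TreeNode:
    # Minimal node type, assumed to mirror LeetCode's TreeNode.
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def build_complete(n):
    # Hypothetical test helper: heap-indexed complete tree with n nodes.
    nodes = [TreeNode(i + 1) for i in range(n)]
    for i in range(n):
        if 2 * i + 1 < n:
            nodes[i].left = nodes[2 * i + 1]
        if 2 * i + 2 < n:
            nodes[i].right = nodes[2 * i + 2]
    return nodes[0] if nodes else None


def count_with_cost(root):
    # Same descent as countNodes; also counts nodes touched by depth walks,
    # which dominate the running time.
    steps = 0

    def depth(node):
        nonlocal steps
        d = 0
        while node is not None:
            d += 1
            steps += 1
            node = node.left
        return d

    total = 0
    node = root
    while node is not None:
        left_depth = depth(node.left)
        right_depth = depth(node.right)
        if left_depth == right_depth:
            total += 1 << left_depth
            node = node.right
        else:
            total += 1 << right_depth
            node = node.left
    return total, steps


print(count_with_cost(build_complete(1023)))  # (1023, 90)

# Steps stay below levels^2 for every size, where levels = N.bit_length().
for n in range(1, 513):
    count, steps = count_with_cost(build_complete(n))
    assert count == n
    assert steps <= n.bit_length() ** 2
```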