Implementing a depth-first tree iterator in Python

norman · Oct 1, 2014

I'm trying to implement an iterator class for not-necessarily-binary trees in Python. After the iterator is constructed with a tree's root node, its next() method can be called repeatedly to traverse the tree depth-first, in pre-order (each node before its children), finally returning None when there are no nodes left.

Here is the basic Node class for a tree:

class Node(object):

    def __init__(self, title, children=None):
        self.title = title
        self.children = children or []
        self.visited = False   

    def __str__(self):
        return self.title

As you can see above, I introduced a visited property to the nodes for my first approach, since I didn't see a way around it. With that extra measure of state, the Iterator class looks like this:

class Iterator(object):

    def __init__(self, root):
        self.stack = []
        self.current = root

    def next(self):
        if self.current is None:
            return None

        self.stack.append(self.current)
        self.current.visited = True

        # Root case
        if len(self.stack) == 1:
            return self.current

        while self.stack:
            self.current = self.stack[-1] 
            for child in self.current.children:
                if not child.visited:
                    self.current = child
                    return child

            self.stack.pop()

This is all well and good, but I want to get rid of the need for the visited property, without resorting to recursion or any other alterations to the Node class.

All the state I need should be taken care of in the iterator, but I'm at a loss about how that can be done. Keeping a visited list for the whole tree is non-scalable and out of the question, so there must be a clever way to use the stack.

What especially confuses me is this: since next() returns after every call, how can I remember where I've been without marking anything or using excess storage? Intuitively, I think of looping over the children, but that loop's progress is lost as soon as next() returns!
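One way around this, sketched here under the assumption that the question's next()-returns-None interface is kept, is to push an iterator over each node's children onto the stack instead of the node itself. A partially consumed children-iterator is exactly the "loop over children" that seems to be lost when next() returns, so no visited flag is needed, and the stack only grows with the depth of the tree, not its size. StackIterator is a hypothetical name; Node is the question's class with visited removed.

```python
class Node(object):
    """The question's Node class, minus the visited flag."""

    def __init__(self, title, children=None):
        self.title = title
        self.children = children or []

    def __str__(self):
        return self.title


class StackIterator(object):
    """Hypothetical pre-order iterator that keeps all state on its own stack.

    Each stack entry is an iterator over one node's children, so a
    half-finished loop over children survives between calls to next().
    """

    def __init__(self, root):
        self.pending_root = root  # returned by the first call, then cleared
        self.stack = []           # one children-iterator per open ancestor

    def next(self):
        if self.pending_root is not None:
            node, self.pending_root = self.pending_root, None
            self.stack.append(iter(node.children))
            return node
        while self.stack:
            # Built-in next() with a default: None once these children run out.
            child = next(self.stack[-1], None)
            if child is None:
                # Children exhausted: drop this level and resume the parent's loop.
                self.stack.pop()
            else:
                # Descend, remembering our position in the child's own children.
                self.stack.append(iter(child.children))
                return child
        return None  # traversal finished
```

It is driven exactly like the question's Iterator: call next() until it returns None. Further calls after that keep returning None rather than restarting at the root.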

UPDATE: Here is a small test:

tree = Node(
    'A', [
        Node('B', [
            Node('C', [
                Node('D')
                ]),
            Node('E'),
            ]),
        Node('F'),
        Node('G'),
        ])

it = Iterator(tree)

out = object()
while out:
    out = it.next()
    print(out)
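As a side note on the test loop: because next() returns None when the traversal is done, Python's built-in two-argument form iter(callable, sentinel) can drive it directly. It keeps calling the callable until the result equals the sentinel, which replaces the out = object() bookkeeping and skips the trailing None that the loop above prints. The block repeats the question's classes so that it runs on its own; the variable name it is just chosen to avoid shadowing the built-in iter.

```python
class Node(object):
    # The question's Node class, including the visited flag its Iterator relies on.
    def __init__(self, title, children=None):
        self.title = title
        self.children = children or []
        self.visited = False

    def __str__(self):
        return self.title


class Iterator(object):
    # The question's visited-flag Iterator, unchanged apart from comments and spacing.
    def __init__(self, root):
        self.stack = []
        self.current = root

    def next(self):
        if self.current is None:
            return None
        self.stack.append(self.current)
        self.current.visited = True
        if len(self.stack) == 1:  # root case
            return self.current
        while self.stack:
            self.current = self.stack[-1]
            for child in self.current.children:
                if not child.visited:
                    self.current = child
                    return child
            self.stack.pop()
        # Falls through (returning None implicitly) once the stack is empty.


tree = Node('A', [
    Node('B', [Node('C', [Node('D')]), Node('E')]),
    Node('F'),
    Node('G'),
])

it = Iterator(tree)
# iter(callable, sentinel) calls it.next() until it returns None,
# so the final None is never printed.
for node in iter(it.next, None):
    print(node)
```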

Answer

mgilson · Oct 1, 2014

If you really must avoid recursion, this iterator works:

from collections import deque

def node_depth_first_iter(node):
    stack = deque([node])
    while stack:
        # Pop out the first element in the stack
        node = stack.popleft()
        yield node
        # push children onto the front of the stack.
        # Note that deque.extendleft inserts items one at a time at the left,
        # which reverses their order, so we push the children in reverse.
        stack.extendleft(reversed(node.children))
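The deque here is used purely as a stack at its left end (popleft plus extendleft), and extendleft's one-at-a-time left appends are why the children are reversed first. A plain Python list gives the same traversal if both operations happen at the right end instead, since list.pop() and list.extend() work there. This is a sketch of that variant; list_stack_depth_first_iter is a hypothetical name, not something from the answer.

```python
class Node(object):
    def __init__(self, title, children=None):
        self.title = title
        self.children = children or []

    def __str__(self):
        return self.title


def list_stack_depth_first_iter(node):
    """Non-recursive pre-order traversal using a plain list as a LIFO stack."""
    stack = [node]
    while stack:
        node = stack.pop()  # take the most recently pushed node
        yield node
        # Push children reversed so the leftmost child is on top and popped next.
        stack.extend(reversed(node.children))


tree = Node('A', [Node('B', [Node('C', [Node('D')]), Node('E')]), Node('F'), Node('G')])
print([str(n) for n in list_stack_depth_first_iter(tree)])
# ['A', 'B', 'C', 'D', 'E', 'F', 'G']
```

Either container works here; the deque would only matter if you also needed cheap access to the other end (for example, to switch to breadth-first order).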

With that said, I think you're overthinking this. A good ol' (recursive) generator also does the trick:

class Node(object):

    def __init__(self, title, children=None):
        self.title = title
        self.children = children or []

    def __str__(self):
        return self.title

    def __iter__(self):
        yield self
        for child in self.children:
            for node in child:
                yield node
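On Python 3.3 and later, the inner loop in __iter__ can be collapsed with yield from, which delegates to the child's own generator. This sketch behaves the same as the recursive generator above; it still creates one nested generator per level of depth, so on very deep trees it can run into the interpreter's recursion limit (see sys.getrecursionlimit()), which is one concrete reason to prefer the explicit-stack version.

```python
class Node(object):
    def __init__(self, title, children=None):
        self.title = title
        self.children = children or []

    def __str__(self):
        return self.title

    def __iter__(self):
        # Pre-order: this node first, then each subtree in turn.
        yield self
        for child in self.children:
            # Python 3.3+: delegate to the child's generator.
            yield from child


tree = Node('A', [Node('B', [Node('C', [Node('D')]), Node('E')]), Node('F'), Node('G')])
print([str(n) for n in tree])  # ['A', 'B', 'C', 'D', 'E', 'F', 'G']
```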

Both of these pass your test:

expected = ['A', 'B', 'C', 'D', 'E', 'F', 'G']
# Test recursive generator using Node.__iter__
assert [str(n) for n in tree] == expected

# Test the non-recursive generator function
assert [str(n) for n in node_depth_first_iter(tree)] == expected

And you can easily make Node.__iter__ use the non-recursive form if you prefer:

def __iter__(self):
    return node_depth_first_iter(self)
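Putting the pieces together, a sketch of the full class with the delegating __iter__ might look like the following. The deep, single-branch chain at the end is a hypothetical test case (not from the question) that shows the practical payoff: because the traversal keeps its own stack, tree depth is limited by memory rather than by CPython's default recursion limit of 1000.

```python
from collections import deque


def node_depth_first_iter(node):
    """The answer's non-recursive pre-order generator."""
    stack = deque([node])
    while stack:
        node = stack.popleft()
        yield node
        stack.extendleft(reversed(node.children))


class Node(object):
    def __init__(self, title, children=None):
        self.title = title
        self.children = children or []

    def __str__(self):
        return self.title

    def __iter__(self):
        # Delegate to the explicit-stack generator: no recursion involved.
        return node_depth_first_iter(self)


# A 5000-level chain, deeper than the default recursion limit.
deep = Node('0')
for i in range(1, 5000):
    deep = Node(str(i), [deep])

titles = [str(n) for n in deep]
print(len(titles), titles[0], titles[-1])  # 5000 4999 0
```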