Python - Tree traversal question

orokusaki · Jun 6, 2011 · Viewed 11.9k times

I have a hard time with tree traversal, and so avoid it like the plague... normally.

I have a class that looks roughly like this (slightly simplified here, but functionally the same):

class Branch(object):
    def __init__(self, title, parent=None):
        self.title = title
        self.parent = parent

I have a dictionary of Branch instances, keyed by each instance's title:

tree = {'Foo Branch': foo, 'Sub-Foo Branch': sub_foo, 'Bar Branch': bar}

Now, I know that there are complex algorithms for making traversal efficient (e.g. MPTT, et al), particularly for use with database-driven projects where efficiency matters the most. I'm not using the database at all, only simple in-memory objects.

Given the title of a Branch, I need to get a list of all descendants of that branch (children, children's children, and so on) from tree, so:

  1. Would you still recommend using a complicated (for my algo-less brain :) algorithm like MPTT for efficiency in my case, or is there a simple way to achieve this in a single function?
  2. If so, which one would you recommend, knowing I'm not using a database?
  3. Can you provide an example, or is this much larger than I'm thinking?

Note: This isn't a homework assignment. I'm not in school. I'm really just this bad at algorithms. I've used Django MPTT for a project which required DB-stored trees... but still don't understand it very well.

Answer

ninjagecko · Jun 6, 2011

http://en.wikipedia.org/wiki/Depth-first_search

http://en.wikipedia.org/wiki/Tree_traversal

You traverse in two passes:

  • First pass: Search for the query node with the appropriate key. (You already have a hashmap of all the nodes in the tree, which is good, so this pass reduces to a single dictionary lookup.)

  • Second pass: Call a modified version of the algorithm on the query node, but this time, whenever you visit a node, yield it (or append it to a nonlocal accumulator variable).

However, your situation is a bit odd, because normally trees have pointers to children as well, sort of like a doubly linked list. Unfortunately we don't have that information... but fortunately it's easy to add:

nodes = tree.values()
# give every node a children list, so leaves end up with an empty one
for node in nodes:
    node.children = []
# register each node with its parent
for node in nodes:
    if node.parent is not None:
        node.parent.children.append(node)
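As a quick check of what the loop above produces, here is a minimal sketch on a three-node tree. The Branch class is the one from the question; the parent links between foo, sub_foo and bar are assumed for illustration, since the question doesn't say how they are related. The sketch gives every node a children list up front, so that leaves carry an empty list rather than no attribute at all.

```python
class Branch(object):
    def __init__(self, title, parent=None):
        self.title = title
        self.parent = parent

# Assumed relationships: sub_foo is a child of foo; bar is a separate root.
foo = Branch('Foo Branch')
sub_foo = Branch('Sub-Foo Branch', parent=foo)
bar = Branch('Bar Branch')
tree = {'Foo Branch': foo, 'Sub-Foo Branch': sub_foo, 'Bar Branch': bar}

# Give every node a children list, then register each node with its parent.
for node in tree.values():
    node.children = []
for node in tree.values():
    if node.parent is not None:
        node.parent.children.append(node)

# Map each title to the titles of its direct children.
child_titles = {title: [c.title for c in node.children]
                for title, node in tree.items()}
print(child_titles)
```

Only 'Foo Branch' should end up with a non-empty list here, containing 'Sub-Foo Branch'.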

Now we can proceed with our example:

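def traverse(root, callback):
    """
        Perform callback on all nodes in depth-first (pre-order) order
        e.g. traverse(root, lambda x: print(x))
    """
    yield root, callback(root)
    # getattr guards against leaf nodes that never got a children list
    for child in getattr(root, 'children', []):
        # pass the callback down and re-yield whatever the subtree yields
        for pair in traverse(child, callback):
            yield pair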

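def getAllDescendents(title):
    queryNode = titlesToNodes[title]  # titlesToNodes is what you call 'tree'
    for node, _ in traverse(queryNode, lambda x: None):
        # traverse yields the query node first; it is not its own descendant
        if node is not queryNode:
            yield node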
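
If you'd rather have everything in a single function, the sketch below is one possible variation on the idea above, not the code above itself. It builds the parent-to-children index in a local defaultdict instead of adding a children attribute to the nodes, and it uses an explicit stack instead of recursion. The name get_descendants and the sample tree are assumptions made for illustration.

```python
from collections import defaultdict

class Branch(object):
    def __init__(self, title, parent=None):
        self.title = title
        self.parent = parent

def get_descendants(tree, title):
    """Return all descendants of tree[title] in depth-first (pre-order) order."""
    # Index children by parent node; the Branch objects are left unmodified.
    children = defaultdict(list)
    for node in tree.values():
        if node.parent is not None:
            children[node.parent].append(node)

    result = []
    # Start from the query node's children, so the node itself is excluded.
    stack = list(reversed(children[tree[title]]))
    while stack:
        node = stack.pop()
        result.append(node)
        # Push children reversed so the first child is visited first.
        stack.extend(reversed(children[node]))
    return result

# Hypothetical sample tree: Root has children A and B; A has child A1.
root = Branch('Root')
a = Branch('A', parent=root)
b = Branch('B', parent=root)
a1 = Branch('A1', parent=a)
sample = {n.title: n for n in (root, a, b, a1)}

titles = [n.title for n in get_descendants(sample, 'Root')]
print(titles)
```

The index is rebuilt on every call, which costs one pass over the dictionary. For a small in-memory tree that is usually fine; if you query often, you could build it once and reuse it.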