Backprop and topological sorting

Neural networks learn by updating their weights using the gradient of a loss function with respect to those weights. Rather than derive these gradients by hand, today we usually use automatic differentiation packages and tools. We call these autodiff systems.

To do their job, autodiff systems construct computational graphs of the function they’re trying to find the gradient of. These computational graphs are key to backpropagation. I won’t go into how these graphs are constructed here.

Nodes make up a computational graph. Here’s an example of a sequence of nodes making up a graph. This might have been constructed by the autodiff system, or maybe the user did it manually.

Here it is in tree form:

An arrow from z to t1 means “z is the parent of t1”: t1 holds a reference to z as its parent. Note that it doesn’t mean z holds a reference to t1. There’s nothing denoting children in this structure, only parents.

Once the graph is constructed, we backpropagate through it, starting at the end node (the one for the loss function) and working back through the other nodes towards the inputs. This is how reverse-mode differentiation works.

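The order in which you visit the nodes matters. A node’s gradient is the sum of contributions from all of its children, so a node can only be processed once every one of its children has been. A topological sort puts every node before its children, so you want the reverse of that order. Autodiff systems therefore topologically sort the computation graph, then reverse the ordering.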
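To make the ordering argument concrete, here’s a minimal sketch of reverse-mode backprop on a tiny graph. The `Var` class and the `mul`, `add`, `reverse_topo_order` and `backprop` helpers are hypothetical names for illustration, not Autodidact’s API, and the ordering here comes from a depth-first post-order traversal rather than the child-counting approach discussed below. In y = x*x + x the node x has two children, so its gradient is only complete after both of them have passed their contributions back.

```python
class Var:
    """Hypothetical graph node: a value plus (parent, local_gradient) pairs."""

    def __init__(self, value, parents=()):
        self.value = value
        # each entry is (parent_var, d(self)/d(parent_var))
        self.parents = list(parents)


def mul(a, b):
    # d(a*b)/da = b, d(a*b)/db = a
    return Var(a.value * b.value, [(a, b.value), (b, a.value)])


def add(a, b):
    # d(a+b)/da = d(a+b)/db = 1
    return Var(a.value + b.value, [(a, 1.0), (b, 1.0)])


def reverse_topo_order(end):
    """Depth-first post-order puts every parent before its children
    (a topological order); reversing it puts every node after all its children."""
    order, seen = [], set()

    def visit(node):
        if node in seen:
            return
        seen.add(node)
        for parent, _ in node.parents:
            visit(parent)
        order.append(node)

    visit(end)
    return order[::-1]


def backprop(end):
    """Return a dict mapping each node to d(end)/d(node)."""
    grads = {end: 1.0}
    for node in reverse_topo_order(end):
        # by now every child of `node` has already added its contribution
        g = grads.get(node, 0.0)
        for parent, local in node.parents:
            grads[parent] = grads.get(parent, 0.0) + g * local
    return grads


x = Var(3.0)
y = add(mul(x, x), x)  # y = x*x + x
print(backprop(y)[x])  # 7.0, i.e. 2*x + 1 at x = 3
```

Visiting x first (plain topological order) would read its gradient before any contribution had reached it. The recursion keeps the sketch short but limits it to graphs shallower than Python’s recursion limit.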

In the code for Autodidact, a somewhat simplified autodiff library, there’s a generator function that yields nodes in exactly this order: reverse topological order, starting from the end node. It looks like this:

def toposort(end_node):
    child_counts = {}
    stack = [end_node]
    while stack:
        node = stack.pop()
        if node in child_counts:
            child_counts[node] += 1
        else:
            child_counts[node] = 1
            stack.extend(node.parents)

    childless_nodes = [end_node]
    while childless_nodes:
        node = childless_nodes.pop()
        yield node
        for parent in node.parents:
            if child_counts[parent] == 1:
                childless_nodes.append(parent)
            else:
                child_counts[parent] -= 1
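The child counts matter when a node is shared. Below, the generator above is copied with the same logic and run on a hypothetical diamond-shaped graph in which `a` is a parent of both `b` and `c`; `SimpleNode` is an assumed stand-in class with just a name and a parents list. The shared node `a` is yielded exactly once, and only after both of its children.

```python
class SimpleNode:
    """Hypothetical stand-in node: just a name and a list of parents."""

    def __init__(self, name, parents=()):
        self.name = name
        self.parents = list(parents)


def toposort(end_node):
    # Same logic as the Autodidact generator above.
    child_counts = {}
    stack = [end_node]
    while stack:
        node = stack.pop()
        if node in child_counts:
            child_counts[node] += 1
        else:
            child_counts[node] = 1
            stack.extend(node.parents)

    childless_nodes = [end_node]
    while childless_nodes:
        node = childless_nodes.pop()
        yield node
        for parent in node.parents:
            if child_counts[parent] == 1:
                childless_nodes.append(parent)
            else:
                child_counts[parent] -= 1


# Diamond: a feeds both b and c, which both feed e.
a = SimpleNode('a')
b = SimpleNode('b', [a])
c = SimpleNode('c', [a])
e = SimpleNode('e', [b, c])

print([n.name for n in toposort(e)])  # ['e', 'c', 'b', 'a']
```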

The Autodidact version is a bit complicated, so I’ve rewritten it and added some comments to the code to explain how it works.

from collections import defaultdict

def get_counts(end_node):
    """
    Coming down the tree, from the scalar output (like a loss function) to the inputs:
    how many children does each node have?
    Useful because when we do our topological sort, we'll want to only return a
    node's parent once all of its children have already been returned.

    Params:
        end_node: scalar value / loss function: the thing at the end of the graph
    """
    # How many nodes have that node as a parent,
    # e.g. node_counts['a'] == 2 means two nodes have 'a' as a parent.
    # Note: children with no path to the final scalar output aren't counted.
    # (end_node has no children but still gets a count of 1 from its initial push.)
    node_counts = defaultdict(lambda: 0)

    stack = [end_node]  # list used as a stack, initialised with the end node
    while stack:  # keep going while there's still stuff in the stack
        node = stack.pop()  # current node is the one at the back of the stack
        if node.name not in node_counts:
            # First visit only: push its parents once. Pushing them on every visit
            # would count paths back to the end node rather than children.
            stack.extend(node.parents)
        node_counts[node.name] += 1  # one count per child that pushed this node
    return node_counts
    
def toposort(end_node):
    """Yield nodes starting from end_node, each one only after all of its children,
    i.e. in reverse topological order.
    See https://www.youtube.com/watch?v=eL-KzMXSXXI for a tutorial on what topological order is.
    """
    counts = get_counts(end_node)  # see above
    stack = [end_node]  # tracks where we are up to in the tree
    while stack:  # keep going while there's stuff here
        node = stack.pop()  # take from the back of the list
        yield node  # swap for a print here to inspect the node
        # iterate through the node's parents, making sure each is returned only once
        for parent in node.parents:
            p = parent.name
            # We can only return a node's parent after all its children have been returned.
            # Otherwise the sort is out of order, or, if a node is a parent of more than one
            # node, it'll be returned more than once (which is bad).
            if counts[p] == 1:
                stack.append(parent)
            else:
                counts[p] -= 1
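One detail worth checking in the counting step is whether a node’s parents are pushed every time the node is popped, or only the first time. The sketch below uses hypothetical names (`N`, `count_children`, `sort_with_counts`) that mirror get_counts and toposort above, and compares both choices on a diamond whose shared node `a` has its own parent `x`. Pushing parents on every visit counts paths back to the end node rather than children, so `x` gets a count of 2, waits for a second child that never comes, and is never yielded.

```python
from collections import Counter


class N:
    """Hypothetical minimal node: a name and a list of parents."""

    def __init__(self, name, parents=()):
        self.name = name
        self.parents = list(parents)


def count_children(end_node, first_visit_only):
    counts = Counter()
    stack = [end_node]
    while stack:
        node = stack.pop()
        # first_visit_only=False re-pushes parents on every visit, counting paths
        if not first_visit_only or node.name not in counts:
            stack.extend(node.parents)
        counts[node.name] += 1
    return counts


def sort_with_counts(end_node, counts):
    counts = dict(counts)  # work on a copy so the caller's counts are untouched
    stack = [end_node]
    while stack:
        node = stack.pop()
        yield node
        for parent in node.parents:
            if counts[parent.name] == 1:
                stack.append(parent)
            else:
                counts[parent.name] -= 1


x = N('x')
a = N('a', [x])
b = N('b', [a])
c = N('c', [a])
e = N('e', [b, c])

for first_only in (True, False):
    counts = count_children(e, first_visit_only=first_only)
    order = [n.name for n in sort_with_counts(e, counts)]
    print(first_only, counts['x'], order)
# True 1 ['e', 'c', 'b', 'a', 'x']
# False 2 ['e', 'c', 'b', 'a']
```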

How can we use this code? Let’s define a class called Node that we can do this topological sort on.

class Node:
    def __init__(self, name, parents):
        self.name = name
        # a single parent Node gets wrapped in a list; None becomes an empty list
        if isinstance(parents, Node):
            parents = [parents]
        self.parents = list(parents) if parents is not None else []

    def __str__(self):   # for printing
        if not self.parents:
            n = "None"
        else:
            n = ', '.join(p.name for p in self.parents)
        return "name: " + self.name + " parents: " + n

    def __len__(self):   # for use in printing
        return 1

Now we can create this computational graph ourselves using Node objects.

z = Node('z', [])
t1 = Node('t1', z)
t2 = Node('t2', t1)
one = Node('1', [])
t3 = Node('t3', [one, t2])
y = Node('y', t3)

Then toposort(y) yields the nodes in reverse topological order, which is what we want. Because toposort is a generator, nothing happens until we iterate over it, so to check the order we can loop over it and print each node’s name. We get

for node in toposort(y):
    print(node.name)
# y
# t3
# t2
# t1
# z
# 1

This was a pretty simple example, but the approach works as long as the graph is a DAG (directed acyclic graph), i.e. it has no cycles.
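If you build graphs by hand, as we did here, it can be worth checking the DAG condition before sorting. A minimal sketch, assuming nodes expose only a `parents` list as `Node` does, is a depth-first search that marks each node as in progress or done: reaching an in-progress node again means the parent links loop back on themselves. `has_cycle` and the `N` class are hypothetical names for illustration.

```python
class N:
    """Hypothetical minimal node: a name and a list of parents."""

    def __init__(self, name, parents=()):
        self.name = name
        self.parents = list(parents)


def has_cycle(end_node):
    """Return True if following parent links from end_node ever loops back."""
    IN_PROGRESS, DONE = 1, 2  # nodes missing from `state` are unvisited
    state = {}

    def visit(node):
        state[node] = IN_PROGRESS
        for parent in node.parents:
            s = state.get(parent)
            if s == IN_PROGRESS:  # parent is on the current path: a cycle
                return True
            if s is None and visit(parent):
                return True
        state[node] = DONE
        return False

    return visit(end_node)


# The example graph from above: a DAG.
z = N('z')
t1 = N('t1', [z])
t2 = N('t2', [t1])
one = N('1')
t3 = N('t3', [one, t2])
y = N('y', [t3])
print(has_cycle(y))  # False

# Add a loop: t1 becomes a parent of z, and z is already a parent of t1.
z.parents.append(t1)
print(has_cycle(y))  # True
```

Recursion keeps the sketch short; for very deep graphs an explicit stack would avoid Python’s recursion limit.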