Building a computational graph: part 2

This is the second post in a series on computational graphs. You can go back to the previous post or ahead to the next post.


If you’d like to see what we are working towards in these posts, here is the GitHub link:


Last time we

  • looked at computational graphs and their use in autodiff packages.
  • looked at the autodiff problem and the structure of the grad function in autograd.
  • showed how Python breaks down expressions to create computational graphs.
  • created a simple graph manually using a simplified Node class.

How do we automatically create a computational graph for a function? Last time we built one by hand, but we need to be able to do it automatically for any function. That’s what we cover here.

As a running example we’ll use this logistic function throughout:

import numpy as np

def logistic(z):  return 1 / (1 + np.exp(-z))

Primitives

Loosely speaking, a primitive is a basic operation, like $+, \times, /, \exp$ or $\log$. For each primitive we want a function that adds itself to a computation graph whenever it is called. Something like this:

def add_new(x,y): 
    # add to computation graph 
    print('Add to graph!')
    return x+y

NumPy already implements well-tested functions for each primitive, like np.add, np.multiply or np.exp. Rather than writing our own versions from scratch, it’d be great to reuse that work. So that’s what we’ll do.

We create a function primitive that

  • takes a function f as an input (which will be a numpy function)
  • returns a new function that, when called, adds the operation to our computation graph as a Node and then calls f to compute the result.

Here’s the basic structure of primitive, just with placeholder code for the computational-graph adding bit.

def primitive(f): 
    def inner(*args, **kwargs): 
        """This is a nested function"""
        # add to graph
        print("add to graph!")
        return f(*args, **kwargs)
    return inner

Use it like this.

mult_new = primitive(np.multiply)
print(mult_new(1,4))
add to graph!
4

Since primitive is a function that returns a function, we can also use it as a decorator. I’ve written a separate post on decorators if you want to know more.

# another way to use it 
@primitive 
def mult_new2(*args, **kwargs): return np.multiply(*args, **kwargs)
print(mult_new2(1,4))
add to graph!
4
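To make the equivalence concrete, here is a minimal sketch showing that the @primitive line is just shorthand for passing the function to primitive and rebinding the name. Plain Python multiplication stands in for np.multiply so it runs without numpy, and mult_a and mult_b are hypothetical names used only for this illustration.

```python
def primitive(f):
    """Simplified version of the post's primitive: log a message, then call f."""
    def inner(*args, **kwargs):
        print("add to graph!")
        return f(*args, **kwargs)
    return inner

# Decorator syntax ...
@primitive
def mult_a(x, y):
    return x * y

# ... is shorthand for defining the function, passing it to primitive,
# and rebinding the name to the returned wrapper
def mult_b(x, y):
    return x * y
mult_b = primitive(mult_b)

print(mult_a(1, 4))                      # prints "add to graph!", then 4
print(mult_b(1, 4))                      # same behaviour
print(mult_a.__name__, mult_b.__name__)  # both are the nested function: inner inner
```

Both names end up bound to the nested inner wrapper, which is where the metadata problem in the next paragraph comes from.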

One problem with this version is that we lose the metadata of the numpy function we wrap in primitive, like its name and documentation. None of it gets copied over; instead, the new function has the metadata of the nested function inner inside primitive.

print("Name of new function:", mult_new.__name__)
print("Doc of new function:", mult_new.__doc__)
Name of new function: inner
Doc of new function: This is a nested function

We obviously don’t want this. We can get around it by putting the @wraps(f) decorator from the functools module above inner in the primitive definition. This copies the name, docstring, module and a few other attributes of the numpy function over to our version, so we don’t lose the documentation.

from functools import wraps 
def primitive(f): 
    @wraps(f)
    def inner(*args, **kwargs): 
        """This is a nested function"""
        # add to graph
        print("add to graph!")
        return f(*args, **kwargs)
    return inner

mult_new3 = primitive(np.multiply) 
mult_new3.__name__  # multiply
print(mult_new3.__doc__[0:300])
multiply(x1, x2, /, out=None, *, where=True, casting='same_kind', order='K', dtype=None, subok=True[, signature, extobj])

Multiply arguments element-wise.

Parameters
----------
x1, x2 : array_like
    Input arrays to be multiplied.
out : ndarray, None, or tuple of ndarray and None, optional
    A 
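Here is a small stdlib-only sketch of what @wraps changes. A hypothetical multiply function stands in for np.multiply so it runs without numpy. It also shows the __wrapped__ attribute that functools.wraps sets, which gives a way back to the original, unwrapped function.

```python
from functools import wraps

def primitive(f):
    @wraps(f)  # copy __name__, __doc__, __module__, __qualname__, ... from f onto inner
    def inner(*args, **kwargs):
        """This is a nested function"""
        print("add to graph!")
        return f(*args, **kwargs)
    return inner

def multiply(x, y):
    """Multiply two numbers."""
    return x * y

mult = primitive(multiply)
print(mult.__name__)                 # multiply
print(mult.__doc__)                  # Multiply two numbers.
print(mult.__wrapped__ is multiply)  # True: wraps also keeps a reference to the original
```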

Creating primitives

Last time we created a Node class. Remember, a Node holds an operation/primitive (the fun attribute), the value computed at that point, and its parents in the graph.

Below is the same Node class. I have just added a __repr__ method to make debugging a bit easier.

class Node:
    """A node in a computation graph."""
    def __init__(self, value, fun, parents):
        self.parents = parents
        self.value = value
        self.fun = fun 
        
    def __repr__(self): 
        """A (very) basic string representation"""
        if self.value is None: str_val = 'None'
        else:                  str_val = str(round(self.value,3))
        return   "\n" + "Fun: " + str(self.fun) +\
                " Value: "+ str_val + \
                " Parents: " + str(self.parents) 
    

Let’s create some primitives. There are a few differences from before:

  • inner doesn’t return the function value f(*args, **kwargs) directly, but a Node that stores it as the value attribute: Node(value, f, parents), where parents are the arguments that are themselves Nodes.
  • Sometimes Nodes interact with plain numbers, for example when adding the constant 1 to a Node. The extra code below handles that case: it extracts the value attribute of any Node arguments, passes constants through unchanged, and saves the results in argvals and kwargvals for use in f.

from functools import wraps
def primitive(f): 
    @wraps(f)
    def inner(*args, **kwargs):
        ## Code to add operation/primitive to computation graph
        
        # Arguments can be Nodes or plain constants (e.g. when adding 1 to a Node),
        # so unwrap the value of each Node and pass constants through unchanged.
        def getval(o):      return o.value if isinstance(o, Node) else o
        argvals   = [getval(o) for o in args]
        kwargvals = {k: getval(o) for k, o in kwargs.items()}

        # The parents of the new node are the Node arguments; constants are not recorded
        all_args = list(args) + list(kwargs.values())
        parents = [o for o in all_args if isinstance(o, Node)]
        
        value = f(*argvals, **kwargvals)
        print("add", "'" + f.__name__ + "'", "to graph with value",value)
        return Node(value, f, parents)
    return inner
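Here is a stripped-down, stdlib-only sketch of that unwrapping step. It assumes operator.add as a stand-in for np.add and leaves out the print and __repr__. It shows that when a Node is combined with a constant, the constant is used in the computation but only the Node is recorded as a parent.

```python
import operator
from functools import wraps

class Node:
    """Minimal node: a value, the function that produced it, and its parent nodes."""
    def __init__(self, value, fun, parents):
        self.value, self.fun, self.parents = value, fun, parents

def primitive(f):
    @wraps(f)
    def inner(*args, **kwargs):
        # Unwrap Node arguments; pass constants through unchanged
        def getval(o): return o.value if isinstance(o, Node) else o
        argvals = [getval(o) for o in args]
        kwargvals = {k: getval(o) for k, o in kwargs.items()}
        # Only Node arguments become parents; constants leave no trace in the graph
        parents = [o for o in list(args) + list(kwargs.values()) if isinstance(o, Node)]
        return Node(f(*argvals, **kwargvals), f, parents)
    return inner

add_new = primitive(operator.add)   # operator.add stands in for np.add
z = Node(1.5, None, [])
t = add_new(z, 1)                   # Node + constant
print(t.value)                      # 2.5
print(len(t.parents), t.parents[0] is z)  # 1 True
```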

Now wrap some basic numpy functions with primitive to get computational-graph versions of these functions:

add_new = primitive(np.add)
mul_new = primitive(np.multiply)
div_new = primitive(np.divide)
sub_new = primitive(np.subtract)
neg_new = primitive(np.negative)
exp_new = primitive(np.exp)

Let’s try it out! We can’t use our logistic function yet, because it is written with operators like $+$, $/$ and unary $-$ rather than np.add, np.divide and np.negative, and we haven’t done any operator overloading. But we can write the logistic function out step by step with our new primitives and see if it works. For $z = 1.5$ we should get a final value of $1/(1+e^{-1.5}) \approx 0.818$ (and indeed we do).

def start_node(value = None): 
    """A function to create an empty node to start off the graph"""
    fun,parents = lambda x: x, []
    return Node(value, fun, parents)

z = start_node(1.5)
t1 = mul_new(z, -1)
t2 = exp_new(t1)
t3 = add_new(t2, 1)
y = div_new(1,t3)
print("Final answer:", round(y.value,3))  # correct final output 
print(y)
add 'multiply' to graph with value -1.5
add 'exp' to graph with value 0.22313016014842982
add 'add' to graph with value 1.22313016014843
add 'true_divide' to graph with value 0.8175744761936437
Final answer: 0.818

Fun: <ufunc 'true_divide'> Value: 0.818 Parents: [
Fun: <ufunc 'add'> Value: 1.223 Parents: [
Fun: <ufunc 'exp'> Value: 0.223 Parents: [
Fun: <ufunc 'multiply'> Value: -1.5 Parents: [
Fun: <function start_node.<locals>.<lambda> at 0x10fea27b8> Value: 1.5 Parents: []]]]]
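The nested printout lists the graph from the output back to the input. Below is a minimal sketch of a hypothetical helper, trace_order, that walks the parents depth-first and returns the nodes in the order they were computed, with each node appearing after all of its parents. It assumes the simplified Node and primitive shown here (positional arguments only), and math and operator functions stand in for the numpy ones so it runs without numpy.

```python
import math
import operator
from functools import wraps

class Node:
    def __init__(self, value, fun, parents):
        self.value, self.fun, self.parents = value, fun, parents

def primitive(f):
    @wraps(f)
    def inner(*args):
        vals = [a.value if isinstance(a, Node) else a for a in args]
        return Node(f(*vals), f, [a for a in args if isinstance(a, Node)])
    return inner

mul_new, exp_new = primitive(operator.mul), primitive(math.exp)
add_new, div_new = primitive(operator.add), primitive(operator.truediv)

def trace_order(output):
    """Hypothetical helper: list nodes so each appears after all of its parents."""
    order, seen = [], set()
    def visit(node):
        if id(node) in seen:      # shared nodes are listed only once
            return
        seen.add(id(node))
        for p in node.parents:    # visit the inputs first...
            visit(p)
        order.append(node)        # ...then the node that consumes them
    visit(output)
    return order

z = Node(1.5, None, [])
y = div_new(1, add_new(exp_new(mul_new(z, -1)), 1))
for node in trace_order(y):
    name = node.fun.__name__ if node.fun else "input"
    print(name, round(node.value, 3))
# prints: input 1.5 / mul -1.5 / exp 0.223 / add 1.223 / truediv 0.818 (one per line)
```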

Operator overloading

We want these functions to be used whenever the common operators are applied to Nodes. In other words, if we define a function def f(x,y): return x+y and pass two Node objects to f as x and y, we want the + inside f to call our add_new function.

Let’s do this. All we have to do is redefine Node so that it implements the relevant dunder (double-underscore) methods, such as __add__ and __mul__:

class Node:
    """A node in a computation graph."""
    def __init__(self, value, fun, parents):
        self.parents = parents
        self.value = value
        self.fun = fun 
        
    def __repr__(self): 
        """A (very) basic string representation"""
        if self.value is None: str_val = 'None'
        else:                  str_val = str(round(self.value,3))
        return   "\n" + "Fun: " + str(self.fun) +\
                " Value: "+ str_val + \
                " Parents: " + str(self.parents) 
    
    ## Code to overload operators
    # Don't put self.value or other.value in the arguments of these functions, 
    # otherwise you won't be able to access the Node object to create the 
    # computational graph. 
    # Instead, pass the whole node through. And to prevent recursion errors, 
    # extract the value inside the `primitive` function. 
    def __add__(self, other): return add_new(self, other)
    def __radd__(self, other): return add_new(other, self)
    def __sub__(self, other): return sub_new(self, other)
    def __rsub__(self, other): return sub_new(other, self)
    def __truediv__(self, other): return div_new(self, other)
    def __rtruediv__(self, other): return div_new(other, self)
    def __mul__(self, other): return mul_new(self, other)
    def __rmul__(self, other): return mul_new(other, self)
    def __neg__(self): return neg_new(self)
    # Python has no exp operator (and __exp__ is not a special method Python
    # calls), so exp must be applied explicitly with exp_new(node).
    

Now we can add nodes using $+$, divide them with $/$ and so on. Here is a basic example of adding Nodes with $+$:

val_z = 1.5 
z = Node(val_z, None, [])
val_t1 = 4
t1 = Node(val_t1, None, [])
y = z + t1 
add 'add' to graph with value 5.5

Here is the graph of y:

print(y)
Fun: <ufunc 'add'> Value: 5.5 Parents: [
Fun: None Value: 1.5 Parents: [], 
Fun: None Value: 4 Parents: []]
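Why does Node need both __add__ and __radd__? The toy sketch below uses a hypothetical Traced class that is not part of the post's code. It shows how Python picks a method: for 1 + node, int's __add__ doesn't know about our class and returns NotImplemented, so Python falls back to the right operand's __radd__. The same fallback sends 1 / (...) in the logistic function to __rtruediv__.

```python
class Traced:
    """Toy class (hypothetical) that reports which special method Python called."""
    def __add__(self, other):
        return "__add__"
    def __radd__(self, other):
        return "__radd__"
    def __rtruediv__(self, other):
        return "__rtruediv__"
    def __neg__(self):
        return "__neg__"

t = Traced()
print(t + 1)            # __add__: the left operand handles it
print(1 + t)            # __radd__: int can't add a Traced, so Python asks the right operand
print(1 / t)            # __rtruediv__: same fallback for division
print(-t)               # __neg__: unary minus
print((1).__add__(t))   # NotImplemented: the signal that triggers the fallback
```

This is also why the reflected methods in Node pass their arguments as (other, self). For 1 / node, __rtruediv__ calls div_new(1, node), which keeps the operands in the right order for non-commutative operations like subtraction and division.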

Let’s try it out on a modified version of the logistic function that uses our exp_new function, since Python has no operator for $\exp$ that we could overload.

def logistic2(z):  return 1 / (1 + exp_new(-z))
y = logistic2(start_node(value = 1.5))
add 'negative' to graph with value -1.5
add 'exp' to graph with value 0.22313016014842982
add 'add' to graph with value 1.22313016014843
add 'true_divide' to graph with value 0.8175744761936437

The graph of y:

print(y)
Fun: <ufunc 'true_divide'> Value: 0.818 Parents: [
Fun: <ufunc 'add'> Value: 1.223 Parents: [
Fun: <ufunc 'exp'> Value: 0.223 Parents: [
Fun: <ufunc 'negative'> Value: -1.5 Parents: [
Fun: <function start_node.<locals>.<lambda> at 0x10fe90f28> Value: 1.5 Parents: []]]]]

Sweet! It is working. Now let’s try a multivariate function.

def somefun(x,y):  return (x*y + exp_new(x)*exp_new(y))/(4*y)
def somefun2(x,y):  return (x*y + np.exp(x)*np.exp(y))/(4*y)  # plain numpy version, for comparison
val_x, val_y = 3,4 
ans = somefun(start_node(val_x), start_node(val_y))
add 'multiply' to graph with value 12
add 'exp' to graph with value 20.085536923187668
add 'exp' to graph with value 54.598150033144236
add 'multiply' to graph with value 1096.6331584284585
add 'add' to graph with value 1108.6331584284585
add 'multiply' to graph with value 16
add 'true_divide' to graph with value 69.28957240177866

Graph of ans:

print(ans)
Fun: <ufunc 'true_divide'> Value: 69.29 Parents: [
Fun: <ufunc 'add'> Value: 1108.633 Parents: [
Fun: <ufunc 'multiply'> Value: 12 Parents: [
Fun: <function start_node.<locals>.<lambda> at 0x11c566f28> Value: 3 Parents: [], 
Fun: <function start_node.<locals>.<lambda> at 0x11c566730> Value: 4 Parents: []], 
Fun: <ufunc 'multiply'> Value: 1096.633 Parents: [
Fun: <ufunc 'exp'> Value: 20.086 Parents: [
Fun: <function start_node.<locals>.<lambda> at 0x11c566f28> Value: 3 Parents: []], 
Fun: <ufunc 'exp'> Value: 54.598 Parents: [
Fun: <function start_node.<locals>.<lambda> at 0x11c566730> Value: 4 Parents: []]]], 
Fun: <ufunc 'multiply'> Value: 16 Parents: [
Fun: <function start_node.<locals>.<lambda> at 0x11c566730> Value: 4 Parents: []]]

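The result looks complex, but that is because our __repr__ method is basic. It doesn’t indent nested nodes, and a node that feeds into several operations (like the start nodes for x and y, which you can recognize by their memory addresses) is printed once for every path that reaches it. Still, all the information is there, and we have created a computational graph successfully.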
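One way to make the printout easier to read is a hypothetical show helper that prints one node per line and indents each node's parents beneath it. This sketch assumes only that every node has the value, fun and parents attributes of our Node class. The small graph is built by hand (computing $xy/y$ for $x=3, y=4$) so the example runs on its own, with operator.mul and operator.truediv standing in for the numpy ufuncs.

```python
import operator

class Node:
    def __init__(self, value, fun, parents):
        self.value, self.fun, self.parents = value, fun, parents

def show(node, depth=0):
    """Hypothetical helper: one node per line, each parent indented under the node it feeds."""
    if node.parents:
        name = getattr(node.fun, "__name__", str(node.fun))
    else:
        name = "input"             # leaves are the starting nodes of the graph
    lines = ["  " * depth + f"{name}: {round(node.value, 3)}"]
    for parent in node.parents:
        lines.append(show(parent, depth + 1))
    return "\n".join(lines)

x = Node(3, None, [])
y = Node(4, None, [])
xy = Node(x.value * y.value, operator.mul, [x, y])
out = Node(xy.value / y.value, operator.truediv, [xy, y])
print(show(out))
# truediv: 3.0
#   mul: 12
#     input: 3
#     input: 4
#   input: 4
```

Note that y appears twice. Like our __repr__, show prints a shared node once per path, so for large graphs you would want to track which nodes have already been visited.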

Next steps

At this point we can create functions using common operators and automatically trace their computation graph. Nice!

But we aren’t quite there yet. There are a few things missing:

  • we don’t want to replace np.add with add_new, np.exp with exp_new, etc. everywhere. That’s a pain, especially if we have a lot of code to change.
  • currently we have to implement primitives for every numpy function we want. Is there a way to get them all?
  • how do we handle non-differentiable functions?

We’ll cover these in the next post!