
Dual/Codual numbers for Forward/Reverse Automatic Differentiation

In my last two posts on automatic differentiation (AD), I described some basic primitives that implement the standard approach to forward mode AD using dual numbers, and then a dual representation of dual numbers that can compute in reverse mode. I'm calling these "codual" numbers, as they are the categorical dual of dual numbers.

It didn't click at the time, but this reverse mode representation seems to be novel. If it's not, please let me know! I haven't seen any equivalent of dual numbers capable of computing in reverse mode.

When reverse mode AD is needed, most introductions to AD go straight to building a graph/DAG representation of the computation in order to improve the sharing properties and run the computation backwards, but that isn't strictly necessary. I aim to show that there's a middle ground between dual numbers and the graph approach, even if it's only suitable for pedagogical purposes.

Review: Dual Numbers

Dual numbers augment standard numbers with an extra term for the derivative of that number:

real number x   =(dual number transform)=>   x + x'*ϵ

Then various identities over addition, multiplication and trig functions propagate the derivative throughout the function to the return value. Here is a simple type for dual numbers to illustrate:
public readonly struct Dual
{
    public readonly double Magnitude;

    public readonly double Derivative;

    // Constructor included so the Calculus examples below compile.
    public Dual(double magnitude, double derivative)
    {
        Magnitude = magnitude;
        Derivative = derivative;
    }

    public static Dual operator +(Dual lhs, Dual rhs) =>
        new Dual(lhs.Magnitude + rhs.Magnitude,
                 lhs.Derivative + rhs.Derivative);

    public static Dual operator *(Dual lhs, Dual rhs) =>
        new Dual(lhs.Magnitude * rhs.Magnitude,
                 lhs.Derivative * rhs.Magnitude + rhs.Derivative * lhs.Magnitude);

    public Dual Pow(int k) =>
        new Dual(Math.Pow(Magnitude, k),
                 k * Math.Pow(Magnitude, k - 1) * Derivative);
}
The standard calculation operates on the Magnitude field, and the Derivative field computes the derivative right alongside via the standard identities. This permits you to somewhat efficiently differentiate any function that accepts only values of type Dual, which is called forward mode AD:
public static class Calculus
{
    public static Dual DerivativeX0At(
        double x0, double x1, Func<Dual, Dual, Dual> f) =>
        f(new Dual(x0, 1), new Dual(x1, 0));

    public static Dual DerivativeX1At(
        double x0, double x1, Func<Dual, Dual, Dual> f) =>
        f(new Dual(x0, 0), new Dual(x1, 1));
}
To differentiate with respect to a parameter, we call the function with 1 in that parameter's derivative position and 0 in the others. This means forward mode can only take the derivative with respect to a single parameter at a time. If you need to differentiate with respect to multiple input parameters, you have to run the whole function repeatedly, once for each derivative (or, as I covered in my first post, use a vector of derivatives, which has high space complexity).
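As a concrete illustration, here's a minimal self-contained sketch (the Dual type and a hypothetical function F are repeated inline so the snippet compiles on its own) that differentiates f(x0, x1) = x0*x1 + x0*x0 with respect to x0:

```csharp
using System;

// The Dual type from above, trimmed to the operators this demo needs.
public readonly struct Dual
{
    public readonly double Magnitude;
    public readonly double Derivative;

    public Dual(double magnitude, double derivative)
    {
        Magnitude = magnitude;
        Derivative = derivative;
    }

    public static Dual operator +(Dual lhs, Dual rhs) =>
        new Dual(lhs.Magnitude + rhs.Magnitude,
                 lhs.Derivative + rhs.Derivative);

    public static Dual operator *(Dual lhs, Dual rhs) =>
        new Dual(lhs.Magnitude * rhs.Magnitude,
                 lhs.Derivative * rhs.Magnitude + rhs.Derivative * lhs.Magnitude);
}

public static class Program
{
    // f(x0, x1) = x0*x1 + x0*x0, so df/dx0 = x1 + 2*x0.
    public static Dual F(Dual x0, Dual x1) => x0 * x1 + x0 * x0;

    public static void Main()
    {
        // Seed x0's derivative with 1 to differentiate w.r.t. x0.
        var result = F(new Dual(3, 1), new Dual(2, 0));
        Console.WriteLine(result.Magnitude);   // 15
        Console.WriteLine(result.Derivative);  // 8 (= 2 + 2*3)
    }
}
```

A second call with the 1 seed moved to x1's position would be needed for the other partial derivative, which is exactly the inefficiency reverse mode removes.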

This is obviously inefficient if you often need the derivatives of multiple parameters, which is very common in machine learning. Fortunately, reverse mode AD is a dual of forward mode, which exhibits the dual of this property: it can efficiently compute the derivative of all parameters simultaneously.

Codual Numbers

With a basic dual number representation, we can take its categorical dual and get a type with the complete opposite properties of the original. However, the categorical dual of the Derivative field is not obvious, so let's modify Dual slightly to make the duality transformation clearer. First, we replace the Derivative value with its equivalent function:
public readonly struct Dual
{
    public readonly double Magnitude;

    public readonly Func<double> Derivative;

    public Dual(double magnitude, Func<double> derivative)
    {
        Magnitude = magnitude;
        Derivative = derivative;
    }

    public static Dual operator +(Dual lhs, Dual rhs) =>
        new Dual(lhs.Magnitude + rhs.Magnitude,
                 () => lhs.Derivative() + rhs.Derivative());

    public static Dual operator *(Dual lhs, Dual rhs) =>
        new Dual(lhs.Magnitude * rhs.Magnitude,
                 () => lhs.Derivative() * rhs.Magnitude + rhs.Derivative() * lhs.Magnitude);

    public Dual Pow(int k)
    {
        // Copy 'this' to a local: lambdas in a struct cannot capture 'this'.
        var self = this;
        return new Dual(Math.Pow(Magnitude, k),
                        () => k * Math.Pow(self.Magnitude, k - 1) * self.Derivative());
    }
}
This doesn't really change the semantics of dual numbers; we essentially just added one level of indirection. But now that the Derivative is a function, the duality is obvious: the dual of a function that accepts nothing and returns a value is a function that accepts a value and returns nothing:
public readonly struct Codual
{
    public readonly double Magnitude;

    internal readonly Action<double> Derivative;

    public Codual(double magnitude, Action<double> derivative)
    {
        Magnitude = magnitude;
        Derivative = derivative;
    }

    public static Codual operator +(Codual lhs, Codual rhs) =>
        new Codual(lhs.Magnitude + rhs.Magnitude, dx =>
        {
            lhs.Derivative(dx);
            rhs.Derivative(dx);
        });

    public static Codual operator *(Codual lhs, Codual rhs) =>
        new Codual(lhs.Magnitude * rhs.Magnitude, dx =>
        {
            lhs.Derivative(dx * rhs.Magnitude);
            rhs.Derivative(dx * lhs.Magnitude);
        });

    public Codual Pow(int k)
    {
        // Copy 'this' to a local: lambdas in a struct cannot capture 'this'.
        var lhs = this;
        return new Codual(Math.Pow(Magnitude, k),
                          dx => lhs.Derivative(k * Math.Pow(lhs.Magnitude, k - 1) * dx));
    }
}
The operations on the magnitude remain unchanged, but the operations on the derivative have been turned inside out. Let's consider the multiplication identity on dual numbers to see how this works:

<x, x'> * <y, y'> = <x*y, y'*x + y*x'>

If we are differentiating with respect to x, then x' will have some value in forward mode, and y' will equal 0. This simplifies to:

<x, x'> * <y, 0> = <x*y, y*x'>

If we are differentiating with respect to y, then y' will have some value in forward mode, and x' will equal 0. This simplifies to:

<x, 0> * <y, y'> = <x*y, y'*x>

If we are taking the dual of these identities, then we need only propagate the dependencies: y' depends only on x, and x' depends only on y, and that's exactly what you see in the Codual multiplication operator:
public static Codual operator *(Codual x, Codual y) =>
    new Codual(x.Magnitude * y.Magnitude, dz =>
    {
        x.Derivative(dz * y.Magnitude);
        y.Derivative(dz * x.Magnitude);
    });
You can think of the Derivative field as building an execution trace which enables us to run the same computation that's happening on Magnitude, just backwards from outputs to inputs. Differentiating a function using the Codual representation works like this:
// Result carries the function's value and the partial derivatives of both parameters:
public readonly record struct Result(
    double Magnitude, double DerivativeX0, double DerivativeX1);

public static Result DerivativeAt(
    double x0, double x1, Func<Codual, Codual, Codual> f)
{
    double dx0 = 0, dx1 = 0;
    var y = f(new Codual(x0, dy => dx0 += dy),
              new Codual(x1, dy => dx1 += dy));
    y.Derivative(1);
    return new Result(y.Magnitude, dx0, dx1);
}
Similar to the dual number representation, we pass in x0 and x1 for the ordinary value calculation. Similar to the operators on Codual, we pass in functions for the derivative, but these are slightly different because they are the "leaves" of the execution trace, responsible for setting the derivatives of the function's parameters. You can see that they operate as accumulators, summing the contributions to the derivative from every path through the executing function. If you only use a parameter once in your function, its leaf function executes only once. If you use it N times, it executes N times.

Finally, like the dual number representation, we start the reverse derivative computation with a 1 as input, and the execution trace that was built propagates the value back to the input parameters.
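Putting it all together, here's a self-contained sketch (the Codual type is repeated inline, minus Pow, so the snippet compiles on its own) that computes the value and both partial derivatives of f(x0, x1) = x0*x1 + x0*x0 in a single pass:

```csharp
using System;

// The Codual type from above, trimmed to the operators this demo needs.
public readonly struct Codual
{
    public readonly double Magnitude;
    internal readonly Action<double> Derivative;

    public Codual(double magnitude, Action<double> derivative)
    {
        Magnitude = magnitude;
        Derivative = derivative;
    }

    public static Codual operator +(Codual lhs, Codual rhs) =>
        new Codual(lhs.Magnitude + rhs.Magnitude, dx =>
        {
            lhs.Derivative(dx);
            rhs.Derivative(dx);
        });

    public static Codual operator *(Codual lhs, Codual rhs) =>
        new Codual(lhs.Magnitude * rhs.Magnitude, dx =>
        {
            lhs.Derivative(dx * rhs.Magnitude);
            rhs.Derivative(dx * lhs.Magnitude);
        });
}

public static class Program
{
    public static void Main()
    {
        // f(x0, x1) = x0*x1 + x0*x0, so df/dx0 = x1 + 2*x0 and df/dx1 = x0.
        double dx0 = 0, dx1 = 0;
        var x0 = new Codual(3, dy => dx0 += dy);
        var x1 = new Codual(2, dy => dx1 += dy);

        var y = x0 * x1 + x0 * x0;
        y.Derivative(1); // run the trace backwards from the output

        Console.WriteLine(y.Magnitude); // 15
        Console.WriteLine(dx0);         // 8
        Console.WriteLine(dx1);         // 3
    }
}
```

Note that both partials arrive from one forward evaluation plus one backward sweep, where the forward mode version would have needed two complete evaluations.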

Note that reverse mode AD is effectively a generalized form of backpropagation as used in machine learning. So if you've ever wanted to understand backpropagation, there it is in ~60 lines of code!

Summary

You can find a simple .NET automatic differentiation library that provides these types at this repo, and I've also uploaded a nuget package if you just want to play around with it. I wouldn't recommend using these abstractions for any sophisticated differentiation purposes, but they'll probably work fine for small applications and learning how AD works.

Edit: some earlier work noted that CPS style can be used for reverse mode AD, but likely for performance reasons, it went straight to using delimited continuations rather than implementing a codual number form as described here. So Codual is kinda-sorta novel, but probably not too surprising to a lot of people in this field.

Edit 2: Jules Jacobs noted exponential blowup with some naive programs using Codual numbers, like loop(N) { x = x + x }. Dual/Codual numbers will probably never be as efficient as more sophisticated approaches to AD, but some simple optimizations can handle the low-hanging fruit and avoid exponential blowup in naive expressions like this.
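To see the blowup concretely, here's a self-contained sketch (repeating a trimmed Codual with only the '+' operator) that counts leaf invocations for the x = x + x pattern. Each iteration doubles the execution trace, so the leaf runs 2^N times:

```csharp
using System;

// Trimmed-down Codual: only '+' is needed for this demo.
public readonly struct Codual
{
    public readonly double Magnitude;
    internal readonly Action<double> Derivative;

    public Codual(double magnitude, Action<double> derivative)
    {
        Magnitude = magnitude;
        Derivative = derivative;
    }

    public static Codual operator +(Codual lhs, Codual rhs) =>
        new Codual(lhs.Magnitude + rhs.Magnitude, dx =>
        {
            lhs.Derivative(dx);
            rhs.Derivative(dx);
        });
}

public static class Program
{
    public static void Main()
    {
        int leafCalls = 0;
        double dx = 0;
        var x = new Codual(1, dy => { leafCalls++; dx += dy; });

        // Each iteration doubles the size of the execution trace.
        for (int i = 0; i < 10; i++)
            x = x + x;

        x.Derivative(1);
        Console.WriteLine(x.Magnitude); // 1024
        Console.WriteLine(leafCalls);   // 1024 = 2^10 leaf invocations
        Console.WriteLine(dx);          // 1024, the correct derivative of 1024 * x
    }
}
```

The result is still correct; the problem is purely the 2^N leaf invocations, which a graph-based approach (or memoizing shared subtraces) reduces to linear work.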
