
Easy Reverse Mode Automatic Differentiation in C#

Continuing from my last post on implementing forward-mode automatic differentiation (AD) using C# operator overloading, this is just a quick follow-up showing how easy reverse mode is to achieve, and why it's important.

Why Reverse Mode Automatic Differentiation?

As explained in the last post, the vector representation of forward-mode AD can compute the derivatives of all parameters simultaneously, but it does so at considerable space cost: each operation creates a vector holding the derivative with respect to each parameter. So N parameters across M operations would allocate O(N*M) space. It turns out this is unnecessary!

Reverse mode AD allocates only O(N+M) space to compute the derivatives of N parameters across M operations. In general, forward mode AD is best suited to differentiating functions of type:

R → R^N

That is, functions of 1 parameter that compute multiple outputs. Reverse mode AD is suited to the dual scenario:

R^N → R

That is, functions of many parameters that return a single real number. A lot of problems are better suited to reverse mode AD, and some modern machine learning frameworks now employ reverse mode AD internally (thousands of parameters, single output that's compared to a goal).

How does Reverse Mode Work?

The identities I described in the other article still apply since they're simply the chain rule, but reverse mode computes derivatives backwards. Forward-mode AD is easy to implement using dual numbers, in which the evaluation order matches C#'s normal evaluation order: just compute a second number corresponding to the derivative alongside the normal computation. Since reverse mode runs backwards, we have to do the computational dual: build a (restricted) continuation!

This is a rough sketch showcasing both forward mode and reverse mode and how they're duals. Forward mode AD using dual numbers will look something like this:

public readonly struct Fwd
{
    public readonly double Magnitude;
    public readonly double Derivative;

    public Fwd(double mag, double deriv)
    {
        this.Magnitude = mag;
        this.Derivative = deriv;
    }

    // power rule combined with the chain rule: d(x^k) = k*x^(k-1) * dx
    public Fwd Pow(int k) =>
        new Fwd(Math.Pow(Magnitude, k), k * Math.Pow(Magnitude, k - 1) * Derivative);

    // sum rule: d(x + y) = dx + dy
    public static Fwd operator +(Fwd lhs, Fwd rhs) =>
        new Fwd(lhs.Magnitude + rhs.Magnitude, lhs.Derivative + rhs.Derivative);

    // product rule: d(x * y) = dx*y + dy*x; note the magnitudes multiply
    public static Fwd operator *(Fwd lhs, Fwd rhs) =>
        new Fwd(lhs.Magnitude * rhs.Magnitude,
                lhs.Derivative * rhs.Magnitude + rhs.Derivative * lhs.Magnitude);

    // differentiate with respect to x: seed its derivative with 1
    public static Func<double, Fwd> Differentiate(Func<Fwd, Fwd> f) =>
        x => f(new Fwd(x, 1));

    // differentiate with respect to x0: seed x0 with 1 and x1 with 0
    public static Func<double, double, Fwd> DifferentiateX0(Func<Fwd, Fwd, Fwd> f) =>
        (x0, x1) => f(new Fwd(x0, 1), new Fwd(x1, 0));

    // differentiate with respect to x1: seed x0 with 0 and x1 with 1
    public static Func<double, double, Fwd> DifferentiateX1(Func<Fwd, Fwd, Fwd> f) =>
        (x0, x1) => f(new Fwd(x0, 0), new Fwd(x1, 1));
}
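
For example, differentiating f(x) = x^2 + x with Fwd would look something like this:

var df = Fwd.Differentiate(x => x.Pow(2) + x);
var y = df(3);
Console.WriteLine(y.Magnitude);   // 12
Console.WriteLine(y.Derivative);  // 7, since f'(x) = 2x + 1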

Translating this into reverse mode entails replacing Fwd.Derivative with a continuation like so:


public readonly struct Rev
{
    public readonly double Magnitude;
    readonly Action<double> Derivative;

    public Rev(double y, Action<double> dy)
    {
        this.Magnitude = y;
        this.Derivative = dy;
    }

    public Rev Pow(int e)
    {
        // copy fields into locals: lambdas in a struct can't capture 'this'
        var x = Magnitude;
        var k = Derivative;
        return new Rev(Math.Pow(Magnitude, e),
                       dx => k(e * Math.Pow(x, e - 1) * dx));
    }

    // sum rule: the incoming derivative flows unchanged to both operands
    public static Rev operator +(Rev lhs, Rev rhs) =>
        new Rev(lhs.Magnitude + rhs.Magnitude, dx =>
        {
            lhs.Derivative(dx);
            rhs.Derivative(dx);
        });

    // product rule: each operand's derivative is scaled by the other's magnitude
    public static Rev operator *(Rev lhs, Rev rhs) =>
        new Rev(lhs.Magnitude * rhs.Magnitude,
                dx =>
                {
                    lhs.Derivative(dx * rhs.Magnitude);
                    rhs.Derivative(dx * lhs.Magnitude);
                });

    // dx accumulates the parameter's derivative; y.Derivative(1) kicks off
    // the backwards pass with the seed dy/dy = 1
    public static Func<double, (double, double)> Differentiate(Func<Rev, Rev> f) =>
        x =>
        {
            double dx = 0;
            var y = f(new Rev(x, dy => dx += dy));
            y.Derivative(1);
            return (y.Magnitude, dx);
        };

    public static Func<double, double, (double, double, double)> Differentiate(Func<Rev, Rev, Rev> f) =>
        (x0, x1) =>
        {
            double dx0 = 0, dx1 = 0;
            var y = f(new Rev(x0, dy => dx0 += dy),
                      new Rev(x1, dy => dx1 += dy));
            y.Derivative(1);
            return (y.Magnitude, dx0, dx1);
        };

    public static Func<double, double, double, (double, double, double, double)> Differentiate(Func<Rev, Rev, Rev, Rev> f) =>
        (x0, x1, x2) =>
        {
            double dx0 = 0, dx1 = 0, dx2 = 0;
            var y = f(new Rev(x0, dy => dx0 += dy),
                      new Rev(x1, dy => dx1 += dy),
                      new Rev(x2, dy => dx2 += dy));
            y.Derivative(1);
            return (y.Magnitude, dx0, dx1, dx2);
        };
}
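
Using Rev is similar, except a single call returns every partial derivative at once:

var df = Rev.Differentiate((x0, x1) => x0 * x1 + x0);
var (y, dx0, dx1) = df(2, 3);
Console.WriteLine(y);    // 8
Console.WriteLine(dx0);  // 4, since df/dx0 = x1 + 1
Console.WriteLine(dx1);  // 2, since df/dx1 = x0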

As I mentioned in my last post, my goal here isn't the most efficient implementation for reverse mode AD, but to distill its essence to make it direct and understandable. This representation builds a whole new continuation on every invocation of the function being differentiated. More efficient representations would only compute this continuation once for any number of invocations, and there are plenty of other optimizations that can be applied to both forward and reverse mode representations.

Comments

Jules said…
Cool! I think you need to change it to dx0 += dy if you have functions that use a value multiple times. Still, if you do that it may be exponential time. Is that right? Do you know any way to fix that?
Sandro Magi said…
Thanks Jules, I've corrected the post. I don't see how it could be exponential; the built delegate chain is essentially an execution trace that matches 1:1 with every operation on the magnitude. If you use a variable N times, that's N operations on the magnitude and N delegates that update the final derivative. Am I missing some pathological case?

I'm not really concerned about the performance of this particular incarnation anyway; it's really a reference implementation, for clarity, that I'll use to test against more sophisticated implementations.

For instance, it should be possible to lift the trace delegate outside of the inner delegate returned by Differentiate by excluding loops from differentiable functions, ie. replacing loops with an explicit Sum() function.
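
Such a Sum() might look something like this (a rough sketch, as a member of Rev):

    public static Rev Sum(params Rev[] xs)
    {
        double sum = 0;
        foreach (var x in xs)
            sum += x.Magnitude;
        // one node for the whole sum: the incoming derivative flows
        // unchanged to every term
        return new Rev(sum, dx =>
        {
            foreach (var x in xs)
                x.Derivative(dx);
        });
    }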

Once you do that, it should then be possible to replace delegates with LINQ expressions and get a very efficient compiled equivalent with no virtual dispatch or allocation.
Sandro Magi said…
Come to think of it, Fwd is the classic dual numbers approach to AD, which I suppose would make the reverse mode representation "co-dual" numbers.

I don't think I've come across a dual representation for dual numbers, have you?
Jules said…
The Tiark Rompf stuff comes closest, but it's different, and I think you're right that this is novel :)

I thought differentiating this function would be a pathological case:

Rev f(Rev x)
{
    for (int i = 0; i < 100; i++) x = x + x;
    return x;
}
Jules said…
P.S. you could put x = (x + x)/2 to avoid overflow :)
Sandro Magi said…
That definitely incurs exponential updates, but it's also exponential in the number of references to x in both Dual and Codual forms. Some identities can handle the low-hanging fruit, but I don't think it's possible to avoid in all cases while still preserving the simplicity, eg. loop{x = x * 2 + x} will still be exponential.

I think you'd need a more sophisticated representation with lookahead to solve this fully.

A simple extension for partial lookahead could add a special type for multiplication/division by constants, which accumulates operations until it hits another Codual, then it performs the operation with two Coduals, and then apply the accumulated constants.
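
Hypothetically, something like this (just a sketch of the idea):

// accumulate constant factors instead of adding a delegate per
// constant multiplication/division (hypothetical extension)
public readonly struct ScaledRev
{
    internal readonly Rev Node;     // the underlying Codual
    internal readonly double Scale; // accumulated constant factor

    internal ScaledRev(Rev node, double scale)
    {
        this.Node = node;
        this.Scale = scale;
    }

    // constants fold into Scale; no new continuation is built
    public static ScaledRev operator *(ScaledRev x, double k) =>
        new ScaledRev(x.Node, x.Scale * k);

    public static ScaledRev operator /(ScaledRev x, double k) =>
        new ScaledRev(x.Node, x.Scale / k);
}

When a ScaledRev finally meets another Codual, the accumulated Scale would be applied once as part of that operation's derivative.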
Sandro Magi said…
Re: Rompf and AD, this looks like the earliest paper on it. It looks like section 2.2 basically covers this case, but in trying to automate this they jump right to delimited continuations rather than using the callbacks to accumulate derivatives directly.

So the specific accumulation form is novel, but not the general idea of using CPS to backpropagate. Thanks for the reference!
Jules said…
Classic backwards automatic differentiation does also handle this in linear time by keeping track of a list of operations that were performed. Building a DAG-shaped AST, basically. The Rompf method uses delimited continuations to do that. It's quite the overkill since this is a very special case of delimited continuations, but cool nonetheless.
Jules said…
So backward auto diff can differentiate any function R^n -> R in the same time complexity as it takes to run the function. This works for the loop example and all other examples. What the paper you linked to addresses is the memory complexity. Classic backwards auto diff needs to keep track of the list of operations that f did, so if f was originally O(f(n)) time and O(g(n)) memory, then differentiating will be O(f(n)) time and O(f(n)) memory. And f may be larger than g.

By contrast, forward diff can differentiate R -> R^n in the same time and memory complexity as the original function.
Sandro Magi said…
> So backward auto diff can differentiate any function R^n -> R in the same time complexity as it takes to run the function.

Solved! The identities I linked before + lifting multiplication/division by constants to a second field, like dual numbers, ensure that the identities are as aggressive as possible in culling the expression trace. This is definitely a novel representation. The expression loops we've looked at so far all seem to be addressed. Can you see any cases where this still falls down?

I think the only problem might be a slight change in floating-point precision, because I change the order of operations during division/multiplication by constants.
Sandro Magi said…
Cases of more than one variable are still tricky actually.
Jules said…
Classical backward automatic differentiation works like this:

internal Codual(double x)
{
    this.Magnitude = x;
    this.Derivative = 0;
}

public Codual Sin()
{
    var lhs = this;
    var y = new Codual(Math.Sin(Magnitude));
    trace.Add(() => lhs.Derivative += y.Derivative * Math.Cos(lhs.Magnitude));
    return y;
}

Where trace is a global list of callbacks. When you compute the derivative of a function, you first run the function, and then you iterate over the trace backwards and call all the callbacks. After that, the Derivative field of your variables will be set to the correct value.

The trick is that instead of accumulating the derivative only at the variables with +=, we do so at every node. Because we iterate over the trace backwards, y.Derivative has the correct final value before the callback in the code above gets executed.
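
The driver might look something like this (my sketch, assuming a global trace list and that Codual is a class with a mutable Derivative field):

static readonly List<Action> trace = new List<Action>();

public static (double, double) Differentiate(Func<Codual, Codual> f, double x)
{
    trace.Clear();
    var input = new Codual(x);
    var y = f(input);        // forward pass records a callback per operation
    y.Derivative = 1;        // seed the backwards pass with dy/dy = 1
    for (int i = trace.Count - 1; i >= 0; i--)
        trace[i]();          // replay the tape in reverse
    return (y.Magnitude, input.Derivative);
}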
Jules said…
You could potentially refactor that to this:

internal Codual(double x, Action<Codual> callback)
{
    this.Magnitude = x;
    this.Derivative = 0;
    var self = this;
    trace.Add(() => callback(self));
}

public Codual Sin()
{
    return new Codual(Math.Sin(Magnitude), Propagate(this, Math.Cos(this.Magnitude)));
}

Where

Propagate(x, d) = (y) => x.Derivative += y.Derivative * d.

For operations with two arguments you'd need Propagate2(x1,d1,x2,d2).
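
That is, something like:

Propagate2(x1, d1, x2, d2) = (y) => {
    x1.Derivative += y.Derivative * d1;
    x2.Derivative += y.Derivative * d2;
}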

Then you can use exactly the same code for the operations on dual and codual numbers, changing only the constructor and the definition of Propagate. E.g. for a single variable:

internal Dual(double x, double dx)
{
    this.Magnitude = x;
    this.Derivative = dx;
}

public Dual Sin()
{
    return new Dual(Math.Sin(Magnitude), Propagate(this, Math.Cos(this.Magnitude)));
}

Where Propagate(x, d) = x.Derivative * d.

By the way, just like you can make forward mode handle multiple inputs by changing Derivative to be an array, you can make reverse mode handle multiple outputs by making Derivative an array.


Jules said…
For multiple variables you could also define a compose operator on Propagators, instead of Propagate2/3/4/etc. For forward you'd have Compose(a,b) = a+b and for backward you'd have Compose(a,b) = (y) => { a(y); b(y) }.
Sandro Magi said…
I've been trying to avoid the use of an imperative tape because I feel there's a simpler semantics underlying it, without going all the way to Conal's categorical breakdown of AD. I have to think about it some more.

That's an interesting breakdown of Dual/Codual via Propagators though! If only more languages had better partial evaluation support so that more general code like that could be more common.
Sandro Magi said…
Here's a basic but complete tape-based implementation, with the corresponding differentiators (a more efficient version would replace the Gradient function with an enum). It's fine, just not quite what I was going for, ie. side-effects, execution serialized/not parallelizable, etc.

There must be a purely functional structure that can merge the parallel branches without just reimplementing a state monad for a tape. The more I think about it, the more convinced I am that this will probably just end up as some specialization of Conal's approach, so maybe I should just start there.
Jules said…
Cool! I like the tape method. It allows you to do bulk operations like matrix multiply and register only 1 thing in the tape instead of 1 per primitive operation. I believe Julia can do this automatically for some code.
Sandro Magi said…
A bytecode variant of the tape is interesting. It's a standard interpreter loop and so can benefit from all of the interpreter optimizations over the years, ie. direct/indirect threading, loop unrolling, super instructions, etc. I wonder at what scale it makes sense to generate a DynamicMethod instead of an interpreter loop to run the reverse accumulation.
