Easy Automatic Differentiation in C#

I've recently been researching optimization and automatic differentiation (AD), and decided to take a crack at distilling its essence in C#.

Note that automatic differentiation (AD) is different from numerical differentiation. Math.NET already provides excellent support for numerical differentiation.

C# doesn't seem to have many options for automatic differentiation: mainly an F# library with an interop layer, or paid libraries. Neither is well suited to learning how AD works.

So here's a simple C# implementation of AD that relies on only two things: C#'s operator overloading, and arrays to represent the derivatives, which I think makes it pretty easy to understand. It's not particularly efficient, but it's simple! See the "Optimizations" section at the end if you want a very efficient specialization of this technique.

What is Automatic Differentiation?

Simply put, automatic differentiation is a technique for calculating the derivative of a numerical function with roughly constant time and space overhead relative to computing the function normally.

AD is typically preferred over numerical differentiation as the latter introduces imprecision due to the approximations employed and the limits of floating point precision inherent to our number types. AD is also often preferred to symbolic differentiation, as the latter can sometimes yield large and awkward terms which results in inefficient execution times, without any appreciable increase in precision.
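To make the imprecision of numerical differentiation concrete, here's a small sketch of my own (not part of the library) that approximates the derivative of sin at 1.0 with a central difference and compares it against the known answer, cos(1.0). The helper name CentralDifference and the step sizes are just illustrative choices:

```csharp
using System;

// Central-difference approximation of f'(x) with step size h.
static double CentralDifference(Func<double, double> f, double x, double h) =>
    (f(x + h) - f(x - h)) / (2 * h);

double exact = Math.Cos(1.0); // d/dx sin(x) = cos(x)

// A large step suffers from truncation error, a tiny step from rounding error.
foreach (var h in new[] { 1e-1, 1e-5, 1e-13 })
{
    double approx = CentralDifference(Math.Sin, 1.0, h);
    Console.WriteLine($"h = {h:E0}: error = {Math.Abs(approx - exact):E2}");
}
```

The error typically shrinks as h gets smaller and then grows again once floating point rounding dominates, so there is no step size that is safe for every function. AD sidesteps the choice of h entirely.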

Differentiation is critical to all sorts of optimization problems in many domains, so this is quite a useful technique!

Automatic Differentiation Basics

The most straightforward implementation of AD is based on dual numbers, where each regular number is augmented with an extra term corresponding to its derivative:

real number x   =(dual number transform)=>   x + x'*ϵ

Arithmetic and other mathematical functions then have translations to operating on these extended number types as follows:

Operator                 Translated
<x, x'> + <y, y'>        <x + y, x' + y'>
<x, x'> - <y, y'>        <x - y, x' - y'>
<x, x'> * <y, y'>        <x*y, x*y' + y*x'>
<x, x'> / <y, y'>        <x / y, (x'*y - x*y') / y^2>
<x, x'>^k                <x^k, k*x^(k-1)*x'>

See the Wikipedia page on automatic differentiation for more transformations, such as the standard trig functions.
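The multiplication rule can be checked by hand. Expanding the product of two dual numbers and using the defining property of dual numbers, ϵ² = 0 (a standard fact that I haven't stated above), gives:

```latex
\begin{aligned}
(x + x'\epsilon)(y + y'\epsilon)
  &= xy + x\,y'\epsilon + x'y\,\epsilon + x'y'\epsilon^2 \\
  &= xy + (x\,y' + x'y)\,\epsilon
\end{aligned}
```

The coefficient of ϵ is exactly the product rule. For instance, with x = 3, x' = 1, y = 4, y' = 0, the result is 12 + 4ϵ: the value of x*y and its derivative with respect to x.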

Dual Numbers from Operator Overloading

Invoking a function "f" with dual numbers operates like this, in math notation:

f(x0 + ϵ_x0, x1 + ϵ_x1, x2 + ϵ_x2)

Each parameter carries an extra 'ϵ' value corresponding to its derivative, and this extra value is distinct from the 'ϵ' values of all other parameters. However, as you can see in the translation table, these derivatives interact with one another in some operators, so a general number type carries a vector holding one ϵ coefficient per parameter. Here's the basic number type:

public readonly struct Number
{
    public readonly double Magnitude;
    public readonly double[] Derivatives;

    internal Number(double m, params double[] d)
    {
        this.Magnitude = m;
        this.Derivatives = d;
    }
}

A differentiable function of 3 parameters would have this signature:

Number Function(Number x0, Number x1, Number x2)

Internally, differentiation invokes the function like this:

public static Number DifferentiateAt(
    double x0, double x1, double x2, Func<Number, Number, Number, Number> func) =>
    func(new Number(x0, 1, 0, 0),
         new Number(x1, 0, 1, 0),
         new Number(x2, 0, 0, 1));

Each parameter is initialized with zeroes everywhere except for a 1 in the derivative slot corresponding to its position in the argument list. This is the necessary starting configuration for automatic differentiation in order to compute the derivatives for any of the parameters.
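As a worked trace (my own example, not taken from the implementation), take f(x0, x1, x2) = x0*x1 + x2 and push the seeded values through the product and sum rules, writing each number as a pair of magnitude and derivative vector:

```latex
\begin{aligned}
\langle x_0, (1,0,0)\rangle \cdot \langle x_1, (0,1,0)\rangle
  &= \langle x_0 x_1,\; x_1\,(1,0,0) + x_0\,(0,1,0)\rangle
   = \langle x_0 x_1,\; (x_1, x_0, 0)\rangle \\
\langle x_0 x_1, (x_1, x_0, 0)\rangle + \langle x_2, (0,0,1)\rangle
  &= \langle x_0 x_1 + x_2,\; (x_1, x_0, 1)\rangle
\end{aligned}
```

The final vector (x1, x0, 1) is the gradient of f: each slot holds the partial derivative with respect to the parameter that was seeded with a 1 in that slot.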

Operators are now pretty straightforward, corresponding to operations on the magnitude and pairwise operations on each index of the array:

public static Number operator +(Number lhs, Number rhs) =>
    new Number(lhs.Magnitude + rhs.Magnitude,
               lhs.Derivatives + rhs.Derivatives);

public static Number operator *(Number lhs, Number rhs) =>
    new Number(lhs.Magnitude * rhs.Magnitude,
               lhs.Derivatives * rhs.Magnitude + rhs.Derivatives * lhs.Magnitude);

Obviously you can't add two double[] or scale one by a double like I've shown here, but the actual implementation hides the array behind a Derivatives type that exposes the arithmetic operators:

public static Derivatives operator +(Derivatives lhs, Derivatives rhs)
{
    // Element-wise sum of the two derivative vectors.
    var v = new double[lhs.vector.Length];
    for (int i = 0; i < v.Length; ++i)
        v[i] = lhs.vector[i] + rhs.vector[i];
    return new Derivatives(v);
}

public static Derivatives operator *(Derivatives lhs, double rhs)
{
    // Scale every derivative by a magnitude, as the product rule requires.
    var v = new double[lhs.vector.Length];
    for (int i = 0; i < v.Length; ++i)
        v[i] = lhs.vector[i] * rhs;
    return new Derivatives(v);
}

Once you have your return value of type Number, you can access the derivatives of each parameter by its index:

var y = Calculus.DifferentiateAt(x0, x1, function);
Console.WriteLine("x0' = " + y.Derivatives[0]);
Console.WriteLine("x1' = " + y.Derivatives[1]);
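To see the pieces fit together, here's a minimal self-contained sketch. It's a reconstruction under assumptions rather than the actual implementation: the Derivatives constructor and indexer, the two-parameter DifferentiateAt overload, and the test function are my own guesses at a reasonable shape. It scales the derivative vector by a double, which is what the product rule in Number's operator * needs.

```csharp
using System;

// Differentiate f(x0, x1) = x0 * x1 + x0 at (3, 4).
var y = Calculus.DifferentiateAt(3, 4, (x0, x1) => x0 * x1 + x0);
Console.WriteLine("f   = " + y.Magnitude);       // 15
Console.WriteLine("x0' = " + y.Derivatives[0]);  // x1 + 1 = 5
Console.WriteLine("x1' = " + y.Derivatives[1]);  // x0 = 3

// Assumed wrapper around double[]; the real type may look different.
public readonly struct Derivatives
{
    private readonly double[] vector;
    public Derivatives(params double[] v) => vector = v;
    public double this[int i] => vector[i];

    // Element-wise sum of two derivative vectors.
    public static Derivatives operator +(Derivatives lhs, Derivatives rhs)
    {
        var v = new double[lhs.vector.Length];
        for (int i = 0; i < v.Length; ++i)
            v[i] = lhs.vector[i] + rhs.vector[i];
        return new Derivatives(v);
    }

    // Scale every derivative by a plain double.
    public static Derivatives operator *(Derivatives lhs, double rhs)
    {
        var v = new double[lhs.vector.Length];
        for (int i = 0; i < v.Length; ++i)
            v[i] = lhs.vector[i] * rhs;
        return new Derivatives(v);
    }
}

public readonly struct Number
{
    public readonly double Magnitude;
    public readonly Derivatives Derivatives;

    internal Number(double m, Derivatives d)
    {
        Magnitude = m;
        Derivatives = d;
    }

    // Sum rule: magnitudes add, derivative vectors add.
    public static Number operator +(Number lhs, Number rhs) =>
        new Number(lhs.Magnitude + rhs.Magnitude,
                   lhs.Derivatives + rhs.Derivatives);

    // Product rule: x' * y + y' * x, applied to every slot.
    public static Number operator *(Number lhs, Number rhs) =>
        new Number(lhs.Magnitude * rhs.Magnitude,
                   lhs.Derivatives * rhs.Magnitude + rhs.Derivatives * lhs.Magnitude);
}

public static class Calculus
{
    // Seed each parameter with a 1 in its own slot and 0 elsewhere.
    public static Number DifferentiateAt(
        double x0, double x1, Func<Number, Number, Number> func) =>
        func(new Number(x0, new Derivatives(1, 0)),
             new Number(x1, new Derivatives(0, 1)));
}
```

Every operator allocates a fresh array, which is where the inefficiency I mentioned at the start comes from.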

Hopefully, that's clear enough for most people to understand the magic behind automatic differentiation. It's an incredibly useful tool that I'm only now exploring, and it deserves to be more widely known!

Optimizations

Most presentations of automatic differentiation show examples where you differentiate a function with respect to only a single parameter, but the technique described above computes every derivative simultaneously. Obviously that's more general, but you typically don't need all of the derivatives, which makes this technique a little wasteful.

So as a first optimization, review the starting configuration for automatic differentiation described above and consider what happens when you're interested in the derivative of x0 only:

public static Number Differentiate_X0(
    double x0, double x1, double x2,
    Func<Number, Number, Number, Number> func) =>
    func(new Number(x0, 1, 0, 0),
         new Number(x1, 0, 0, 0),
         new Number(x2, 0, 0, 0));

When you only want one of the derivatives, the ϵ coefficient of all other parameters would be zero, and all of those array slots filled with zeroes would stay zero throughout the whole computation. So toss them out!

Create a specialized Number type that doesn't incur any array allocations at all by replacing Derivatives with a single System.Double corresponding to the one parameter that's being differentiated. That parameter gets a 1 as the extra term when differentiating; the rest all start with 0.

So while you can only differentiate with respect to one variable at a time with this more specialized Number type, you only need to carry around an extra double for each step in the calculation. This would be very space and time efficient!
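Here's a minimal sketch of what that specialized type might look like. The name Dual, its members, and the test function are hypothetical choices of mine, and the operators simply follow the translation rules for sums, differences, products, quotients and constant powers:

```csharp
using System;

// Differentiate f(x0, x1) = x0 * x1 + x0 with respect to x0 at (3, 4).
Func<Dual, Dual, Dual> f = (x0, x1) => x0 * x1 + x0;
var y = f(new Dual(3, 1), new Dual(4, 0)); // seed: 1 for x0, 0 for x1
Console.WriteLine("f   = " + y.Magnitude);   // 15
Console.WriteLine("x0' = " + y.Derivative);  // x1 + 1 = 5

// Hypothetical single-derivative number: no arrays, just one extra double.
public readonly struct Dual
{
    public readonly double Magnitude;
    public readonly double Derivative;

    public Dual(double m, double d)
    {
        Magnitude = m;
        Derivative = d;
    }

    // Constants have a zero derivative.
    public static implicit operator Dual(double c) => new Dual(c, 0);

    public static Dual operator +(Dual x, Dual y) =>
        new Dual(x.Magnitude + y.Magnitude, x.Derivative + y.Derivative);

    public static Dual operator -(Dual x, Dual y) =>
        new Dual(x.Magnitude - y.Magnitude, x.Derivative - y.Derivative);

    // Product rule: x * y' + y * x'.
    public static Dual operator *(Dual x, Dual y) =>
        new Dual(x.Magnitude * y.Magnitude,
                 x.Magnitude * y.Derivative + y.Magnitude * x.Derivative);

    // Quotient rule: (x' * y - x * y') / y^2.
    public static Dual operator /(Dual x, Dual y) =>
        new Dual(x.Magnitude / y.Magnitude,
                 (x.Derivative * y.Magnitude - x.Magnitude * y.Derivative)
                     / (y.Magnitude * y.Magnitude));

    // Power rule for a constant exponent k: k * x^(k-1) * x'.
    public static Dual Pow(Dual x, double k) =>
        new Dual(Math.Pow(x.Magnitude, k),
                 k * Math.Pow(x.Magnitude, k - 1) * x.Derivative);
}
```

To get the derivative with respect to x1 instead, run the function again with the seeds swapped: new Dual(3, 0) and new Dual(4, 1).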
