Dual numbers
To make forward-mode autodiff work, we replace all floating-point numbers with so-called dual numbers $z = x + yN$, where $N$ is a nilpotent element, i.e., $N^2 = 0$. The space of dual numbers $\mathbb{D}$ is a real algebra with addition and multiplication defined as follows. Let $z_1 = x_1 + y_1 N$ and $z_2 = x_2 + y_2 N$; then
$$z_1 + z_2 = (x_1 + x_2) + (y_1 + y_2)N \quad\text{and}\quad z_1 \cdot z_2 = x_1 x_2 + (x_1 y_2 + x_2 y_1)N.$$
Notice how this product encodes the product rule for derivatives. Every element $z = x + yN \in \mathbb{D}$ with $x \neq 0$ has a multiplicative inverse $z^{-1} = x^{-1} - y x^{-2} N$, since $z \cdot z^{-1} = 1 + \left(x \cdot (-y x^{-2}) + x^{-1} y\right)N = 1$. In C++ we can define the dual numbers using a class and overload the arithmetic operators as follows.
In the header file dual.hpp, we declare a class Dual with members x and y.
#ifndef DUAL_HPP_
#define DUAL_HPP_

#include <iostream>

class Dual {
public:
    // public methods
    constexpr Dual() : x(0.0), y(0.0) {}
    constexpr Dual(const Dual & z) : x(z.x), y(z.y) {}
    constexpr Dual(double x) : x(x), y(0.0) {}
    constexpr Dual(double x, double y) : x(x), y(y) {}
    const Dual & operator=(const Dual & z);
    const Dual & operator=(double x);
    const Dual & operator+=(const Dual & z);
    const Dual & operator-=(const Dual & z);
    const Dual & operator*=(const Dual & z);
    const Dual & operator/=(const Dual & z);
    Dual operator+(const Dual & z) const;
    Dual operator-(const Dual & z) const;
    Dual operator-() const;
    Dual operator*(const Dual & z) const;
    Dual operator/(const Dual & z) const;
    // members
    double x;
    double y;
};

constexpr Dual Nil(0.0, 1.0); // the Nilpotent element

std::ostream & operator<<(std::ostream & os, const Dual & z);

#endif
In the source file dual.cpp, we define all methods using the rules for dual arithmetic described above.
#include "dual.hpp"

// assigning a real number x gives the dual number x + 0N
const Dual & Dual::operator=(double x) {
    this->x = x;
    y = 0.0;
    return *this;
}
const Dual & Dual::operator=(const Dual & z) {
    if ( this != &z ) {
        x = z.x;
        y = z.y;
    }
    return *this;
}
const Dual & Dual::operator+=(const Dual & z) {
    x += z.x;
    y += z.y;
    return *this;
}
const Dual & Dual::operator-=(const Dual & z) {
    x -= z.x;
    y -= z.y;
    return *this;
}
// (x1 + y1 N)(x2 + y2 N) = x1 x2 + (x1 y2 + x2 y1) N
// y is updated first because it needs the old value of x
const Dual & Dual::operator*=(const Dual & z) {
    y = y*z.x + x*z.y;
    x *= z.x;
    return *this;
}
// (x1 + y1 N)/(x2 + y2 N) = x1/x2 + (y1/x2 - x1 y2/x2^2) N, requires x2 != 0
const Dual & Dual::operator/=(const Dual & z) {
    y = y/z.x - x*z.y/(z.x*z.x);
    x /= z.x;
    return *this;
}
Dual Dual::operator+(const Dual & z) const {
    Dual w(*this);
    w += z;
    return w;
}
Dual Dual::operator-(const Dual & z) const {
    Dual w(*this);
    w -= z;
    return w;
}
// unary minus: -(x + yN) = -x - yN
Dual Dual::operator-() const {
    Dual w;
    w -= *this;
    return w;
}
Dual Dual::operator*(const Dual & z) const {
    Dual w(*this);
    w *= z;
    return w;
}
Dual Dual::operator/(const Dual & z) const {
    Dual w(*this);
    w /= z;
    return w;
}
std::ostream & operator<<(std::ostream & os, const Dual & z) {
    os << z.x << " + " << z.y << " Nil";
    return os;
}
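As a quick check of the arithmetic rules, here is a minimal self-contained sketch. It assumes a hypothetical stripped-down struct D (and hypothetical helpers inverse and p) instead of the Dual class, so that it compiles on its own. Seeding y = 1 makes the N-component of a polynomial equal to its derivative, and z times its inverse comes out as exactly 1 + 0N.

```cpp
#include <cassert>

// Hypothetical stripped-down dual number z = x + yN, for illustration only;
// it follows the same rules as the Dual class in dual.cpp.
struct D {
    double x;  // real part
    double y;  // coefficient of the nilpotent element N
};

// (x1 + y1 N) + (x2 + y2 N) = (x1 + x2) + (y1 + y2) N
D operator+(D a, D b) { return {a.x + b.x, a.y + b.y}; }

// (x1 + y1 N)(x2 + y2 N) = x1 x2 + (x1 y2 + x2 y1) N, because N^2 = 0
D operator*(D a, D b) { return {a.x * b.x, a.x * b.y + b.x * a.y}; }

// z^{-1} = x^{-1} - y x^{-2} N, defined only when x != 0
D inverse(D z) {
    assert(z.x != 0.0);
    return {1.0 / z.x, -z.y / (z.x * z.x)};
}

// Hypothetical polynomial p(z) = z^3 + 2z, evaluated with dual arithmetic only
D p(D z) {
    D two{2.0, 0.0};  // constants carry y = 0
    return z * z * z + two * z;
}
```

For example, p(D{2.0, 1.0}) returns x = 12 = p(2) and y = 14 = p'(2) = 3*2^2 + 2, without any derivative being written down explicitly.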
Function evaluation
The magic of dual numbers comes from the way we can extend differentiable functions on $\mathbb{R}$ to the dual numbers. For any differentiable function $f:\mathbb{R}\to\mathbb{R}$, we can extend $f$ to a function $f:\mathbb{D}\to\mathbb{D}$ by defining
$$f(x + yN) = f(x) + y f'(x) N,$$
where $f'$ is the derivative of $f$. To see why this is a natural choice, we take the Taylor expansion of $f$ around $x$:
$$f(x + yN) = f(x) + f'(x)\,yN + \tfrac{1}{2} f''(x)(yN)^2 + \dots$$
However, as $N^2 = 0$, all the terms of order $y^2$ and higher vanish. This allows us to overload a number of standard mathematical functions such that they work on both double and Dual variables.
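As a small worked check, take $f(x) = x^2$, so $f'(x) = 2x$. Computing $f(x + yN)$ with the multiplication rule of $\mathbb{D}$ gives exactly the extension rule, and the same truncation of the series explains the rule for the exponential:
$$(x + yN)^2 = x^2 + 2xy\,N = f(x) + y f'(x) N,$$
$$e^{x + yN} = e^x e^{yN} = e^x\left(1 + yN + \tfrac{1}{2}(yN)^2 + \dots\right) = e^x + y e^x N.$$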
In the files dualmath.hpp and dualmath.cpp, we define Dual versions of some standard functions, like $\sin(x)$ and $\cos(x)$:
#ifndef DUALMATH_HPP_
#define DUALMATH_HPP_

#include "dual.hpp"

Dual sin(const Dual & z);
Dual cos(const Dual & z);
Dual exp(const Dual & z);
Dual log(const Dual & z);
Dual sqrt(const Dual & z);
Dual pow(const Dual & z, const Dual & w);

#endif
The corresponding functions for double values are declared in the <cmath> header, which is included in the file dualmath.cpp.
#include "dualmath.hpp"
#include <cmath>

// Each function applies f(x + yN) = f(x) + y f'(x) N.
// The std:: qualification guarantees that the double versions from <cmath>
// are called, rather than (recursively) the Dual overloads declared above.
Dual sin(const Dual & z) {
    return Dual(std::sin(z.x), z.y * std::cos(z.x));
}
Dual cos(const Dual & z) {
    return Dual(std::cos(z.x), -z.y * std::sin(z.x));
}
Dual exp(const Dual & z) {
    double expx = std::exp(z.x);
    return Dual(expx, z.y * expx);
}
// requires z.x > 0
Dual log(const Dual & z) {
    return Dual(std::log(z.x), z.y / z.x);
}
// the derivative part requires z.x > 0
Dual sqrt(const Dual & z) {
    double sqrtx = std::sqrt(z.x);
    return Dual(sqrtx, 0.5*z.y / sqrtx);
}
// z^w = exp(w log z), valid for z.x > 0
Dual pow(const Dual & z, const Dual & w) {
    return exp(log(z) * w);
}
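The following minimal self-contained sketch illustrates the same per-function rules, assuming a hypothetical struct D in place of the Dual class so that it compiles on its own. It shows that composing these rules yields the chain rule automatically, and that pow, defined through log exactly as in dualmath.cpp, inherits the domain restriction of log.

```cpp
#include <cassert>
#include <cmath>

// Hypothetical minimal dual number x + yN, standing in for Dual.
struct D {
    double x;  // value
    double y;  // N-coefficient (derivative part)
};

D operator*(D a, D b) { return {a.x * b.x, a.x * b.y + b.x * a.y}; }

// f(x + yN) = f(x) + y f'(x) N, written out per function
D sin(D z) { return {std::sin(z.x), z.y * std::cos(z.x)}; }

D exp(D z) {
    double e = std::exp(z.x);
    return {e, z.y * e};
}

D log(D z) {
    assert(z.x > 0.0);  // log, and hence pow below, needs a positive real part
    return {std::log(z.x), z.y / z.x};
}

// z^w = exp(log(z) * w), the same identity used in dualmath.cpp
D pow(D z, D w) { return exp(log(z) * w); }
```

For instance, sin(exp(D{0.0, 1.0})) has y component cos(1), which is the derivative of sin(e^x) at x = 0, and pow(D{2.0, 1.0}, D{3.0, 0.0}) gives approximately 8 + 12N, matching the derivative 3x^2 of x^3 at x = 2. Because pow goes through log, the base must have a positive real part; for integer exponents, repeated multiplication avoids that restriction.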
Templated functions
Of course, in the functions defined in dualmath.cpp, we had to "manually" calculate the derivative to evaluate the functions on $\mathbb{D}$ (with the exception of the pow function). However, the point of defining the dual numbers in C++ was to calculate derivatives of functions "automatically". To see how this works, suppose that we define a templated C++ function fun<T>(T x); then we can use either fun<double> or fun<Dual>. The derivative of fun can then be evaluated at a real number $x$ using a dual number $z = x + yN$ with non-zero $y$: the Dual::y component of fun(z) equals $y\,\mathrm{fun}'(x)$, so the natural choice is $y = 1$.
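Here is a minimal self-contained sketch of a single templated function used with both real and dual arguments. The struct D and the function g are hypothetical stand-ins for Dual and fun. The using-declaration inside g lets T = double resolve to std::exp, while T = D finds exp(D) through argument-dependent lookup. The sketch also shows that seeding y = 2 doubles the N-component.

```cpp
#include <cassert>
#include <cmath>

// Hypothetical minimal dual number x + yN, standing in for Dual.
struct D {
    double x;
    double y;
};

D operator*(D a, D b) { return {a.x * b.x, a.x * b.y + b.x * a.y}; }

// exp(x + yN) = exp(x) + y exp(x) N
D exp(D z) {
    double e = std::exp(z.x);
    return {e, z.y * e};
}

// Hypothetical generic function g(x) = x e^x, with g'(x) = (1 + x) e^x.
template <class T>
T g(T x) {
    using std::exp;  // T = double uses std::exp; T = D finds exp(D) via ADL
    return x * exp(x);
}

// Usage: the N-component of g(x + yN) is y g'(x).
void example() {
    const double e = std::exp(1.0);
    assert(g(1.0) == e);                  // plain double evaluation
    assert(g(D{1.0, 1.0}).y == 2.0 * e);  // y = 1 gives g'(1) = 2e
    assert(g(D{1.0, 2.0}).y == 4.0 * e);  // y = 2 gives 2 g'(1)
}
```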
In the following example, we "automatically" compute the derivative of tan(x).
#include "dualmath.hpp"
#include <cmath>
#include <iostream>

template<class T>
T fun(T x) {
    // T = double picks std::sin/std::cos,
    // T = Dual finds the overloads from dualmath.hpp
    using std::sin;
    using std::cos;
    return sin(x) / cos(x); // tan(x)
}

int main() {
    for ( double x = 0; x < 10; x += 0.01 ) {
        // promote x to dual
        Dual xe(x, 1.0); // set y = 1
        // the derivative of tan(x) is 1/cos^2(x)
        std::cout << x << " " << fun(xe).y << " " << 1/(std::cos(x) * std::cos(x)) << std::endl;
    }
    return 0;
}
$ g++ ad_example1.cpp dual.cpp dualmath.cpp -o ad_example1
$ ./ad_example1
0 1 1
0.01 1.0001 1.0001
0.02 1.0004 1.0004
0.03 1.0009 1.0009
0.04 1.0016 1.0016
0.05 1.0025 1.0025
0.06 1.00361 1.00361
0.07 1.00492 1.00492
0.08 1.00643 1.00643
...

The second column (the automatically computed derivative) agrees with the third (the exact value $1/\cos^2(x)$), and we can use this output to make the following graph.
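To see why this works, we write $z = x + 1N$ and mimic the calculation done by the computer:
$$\tan(z) = \frac{\sin(z)}{\cos(z)} = \frac{\sin(x) + \cos(x)N}{\cos(x) - \sin(x)N},$$
where the extensions of sin and cos to $\mathbb{D}$ are defined in dualmath.cpp. Then, using the rules for division and multiplication of dual numbers defined in dual.cpp, we get
$$\frac{\sin(x) + \cos(x)N}{\cos(x) - \sin(x)N} = \left(\sin(x) + \cos(x)N\right)\left(\frac{1}{\cos(x)} + \frac{\sin(x)}{\cos^2(x)}N\right) = \frac{\sin(x)}{\cos(x)} + \left(1 + \frac{\sin^2(x)}{\cos^2(x)}\right)N,$$
which, since $1 + \sin^2(x)/\cos^2(x) = 1/\cos^2(x)$, simplifies to
$$\tan(z) = \tan(x) + \frac{1}{\cos^2(x)}N.$$
Hence, the Dual::y component of fun(z) is equal to $\tan'(x)$.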
Create derivatives and gradients with a functor
Modern C++ makes it easy to write functional programs, and we can use this to define a function that, given a function $f$, returns a functor (function object) computing the derivative $f'$. For this, we make use of the <functional> header. The signature has the following form:

std::function<double(double)> derivative(std::function<Dual(Dual)> f);

meaning that derivative takes a function $f:\mathbb{D}\to\mathbb{D}$ as argument and returns a function $f':\mathbb{R}\to\mathbb{R}$.
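A minimal self-contained sketch of this idea follows, assuming a hypothetical struct D in place of Dual; the function derivative has the same shape as the signature above but is an illustration, not the autograd.cpp code itself. Capturing f by value in the lambda keeps the returned function valid after derivative has returned.

```cpp
#include <cassert>
#include <functional>

// Hypothetical minimal dual number x + yN, standing in for Dual.
struct D {
    double x;
    double y;
};

D operator*(D a, D b) { return {a.x * b.x, a.x * b.y + b.x * a.y}; }

// Takes f: D -> D and returns a function object computing f': R -> R.
std::function<double(double)> derivative(std::function<D(D)> f) {
    // capture f by value so the returned function outlives this call
    return [f](double x) { return f(D{x, 1.0}).y; };
}

// Usage: d/dx x^3 = 3x^2
void example() {
    std::function<double(double)> dcube = derivative([](D z) { return z * z * z; });
    assert(dcube(2.0) == 12.0);
    assert(dcube(-1.0) == 3.0);
}
```

Because the result is an ordinary double-to-double function, derivative cannot simply be applied to its own output to obtain a second derivative; that would require dual numbers whose components are themselves dual numbers.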
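Things become much more useful in higher dimensions. In the example below, we also show how to automatically compute the gradient of a function $F:\mathbb{R}^n\to\mathbb{R}$ and, more efficiently, directional derivatives $\nabla F(x)\cdot y$: the gradient requires $n$ evaluations of $F$ (one per coordinate), whereas a directional derivative requires only one. The 1D derivative, gradient, and directional derivative are declared in the following header file: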
#ifndef AUTOGRAD_HPP_
#define AUTOGRAD_HPP_

#include "dual.hpp"
#include <functional>
#include <vector>

// simple 1D derivative
std::function<double(double)> derivative(std::function<Dual(Dual)> f);

// some shorthands
typedef std::vector<double> RealVec;
typedef std::vector<Dual> DualVec;

// gradient
std::function<RealVec(const RealVec&)> gradient(std::function<Dual(const DualVec&)> F);

// directional derivative
std::function<double(const RealVec&, const RealVec&)> grad_vec_prod(std::function<Dual(const DualVec&)> F);

#endif
#include "autograd.hpp"
#include <algorithm> // for std::transform
#include <cassert>
#include <cstddef>

std::function<double(double)> derivative(std::function<Dual(Dual)> f) {
    // use a lambda expression that captures f by value
    return [f](double x) -> double { return f(Dual(x,1)).y; };
}

// gradient
std::function<RealVec(const RealVec&)>
gradient(std::function<Dual(const DualVec&)> F) {
    auto Df = [F](const RealVec & x) -> RealVec {
        std::size_t n = x.size(); // dimension
        // promote x to DualVec
        DualVec xe(x.begin(), x.end());
        // vector containing the result
        RealVec Dfx(n);
        // compute all components of the gradient by setting xe[i].y = 1
        for ( std::size_t i = 0; i < n; ++i ) {
            xe[i].y = 1;
            Dfx[i] = F(xe).y;
            // set xe[i].y back to zero for next iteration
            xe[i].y = 0;
        }
        return Dfx;
    };
    return Df;
}

// directional derivative
std::function<double(const RealVec&, const RealVec&)>
grad_vec_prod(std::function<Dual(const DualVec&)> F) {
    auto Df = [F](const RealVec & x, const RealVec & y) -> double {
        assert(x.size() == y.size()); // x and y must have equal size
        DualVec z(x.size());
        // fill z with dual numbers x[i] + y[i]*Nil
        std::transform(x.begin(), x.end(), y.begin(), z.begin(),
                       [](double xi, double yi){ return Dual(xi, yi); });
        return F(z).y;
    };
    return Df;
}
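To make the cost difference concrete, here is a minimal self-contained sketch that mirrors gradient and grad_vec_prod. The struct D, the alias Field, and the functions gradient_sketch and directional_sketch are hypothetical names used only for illustration, not the autograd.hpp API. Counting calls to F shows that the gradient needs one evaluation per coordinate, while the directional derivative needs a single one.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <vector>

// Hypothetical minimal dual number x + yN, standing in for Dual.
struct D {
    double x;
    double y;
};

D operator*(D a, D b) { return {a.x * b.x, a.x * b.y + b.x * a.y}; }

// Hypothetical alias for a scalar field F: D^n -> D.
using Field = std::function<D(const std::vector<D>&)>;

// Gradient: n evaluations of F, the i-th seeded with the unit vector e_i.
std::vector<double> gradient_sketch(const Field& F, const std::vector<double>& x) {
    std::vector<D> z;
    for (double xi : x) z.push_back(D{xi, 0.0});
    std::vector<double> g(x.size());
    for (std::size_t i = 0; i < x.size(); ++i) {
        z[i].y = 1.0;
        g[i] = F(z).y;
        z[i].y = 0.0;  // reset the seed for the next coordinate
    }
    return g;
}

// Directional derivative grad F(x) . v: one evaluation of F seeded with v.
double directional_sketch(const Field& F, const std::vector<double>& x,
                          const std::vector<double>& v) {
    assert(x.size() == v.size());
    std::vector<D> z;
    for (std::size_t i = 0; i < x.size(); ++i) z.push_back(D{x[i], v[i]});
    return F(z).y;
}
```

With $F(x) = x_0^2 x_1$ at $x = (1, 2)$, gradient_sketch returns (4, 1) after two evaluations of F, and directional_sketch along v = (1, 1) returns 5 after one evaluation, in agreement with $\nabla F(x)\cdot v = 4 + 1$.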
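The functions declared in autograd.hpp are implemented in autograd.cpp. The following example program uses them to compute the gradient of the Euclidean norm and its product with a vector: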
#include "autograd.hpp"
#include "dualmath.hpp"
#include <cmath>
#include <iostream>
#include <vector>

template<class T>
T norm(const std::vector<T> & x) {
    // T = double uses std::sqrt; T = Dual finds sqrt(const Dual&)
    using std::sqrt;
    T u(0);
    for ( const T & xi : x ) {
        u += xi*xi;
    }
    return sqrt(u);
}

int main() {
    // create the gradient of the norm function
    auto grad_norm = gradient(norm<Dual>);
    // or the gradient of the norm times a vector
    auto grad_norm_vec_prod = grad_vec_prod(norm<Dual>);
    // compute some stuff
    RealVec x = {0.5, 1.5};
    double Nx = norm<double>(x);
    RealVec DNx = grad_norm(x);
    RealVec u = {0.0, 1.0};
    double DNxu = grad_norm_vec_prod(x, u);
    // print results
    std::cout << "norm of x = " << Nx << std::endl;
    std::cout << "gradient of norm(x) = [" << DNx[0] << ", " << DNx[1] << "]" << std::endl;
    std::cout << "gradient of norm(x) times u = " << DNxu << std::endl;
    return 0;
}
$ g++ ad_example2.cpp dual.cpp dualmath.cpp autograd.cpp -o ad_example2
$ ./ad_example2
norm of x = 1.58114
gradient of norm(x) = [0.316228, 0.948683]
gradient of norm(x) times u = 0.948683
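These numbers can be checked by hand, since the gradient of the Euclidean norm has the closed form $x/\|x\|$ for $x \neq 0$ (at $x = 0$ the norm is not differentiable, and the dual sqrt would divide by zero):
$$x = (0.5, 1.5), \qquad \|x\| = \sqrt{2.5} \approx 1.58114,$$
$$\nabla\|x\| = \frac{x}{\|x\|} = \frac{(1, 3)}{\sqrt{10}} \approx (0.316228,\ 0.948683), \qquad \nabla\|x\|\cdot u = \frac{3}{\sqrt{10}} \approx 0.948683 \ \text{ for } u = (0, 1).$$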