I’m excited to announce the first official release of the backprop library (currently at version 0.1.3.0 on hackage)! backprop is a library that allows you to write functions on your heterogeneous values as you normally would, and then (using reverse-mode automatic differentiation) automatically generates functions computing their gradients. backprop differs from the related ad library by working with functions that use and transform different types, instead of only one monomorphic scalar type.
This has been something I’ve been working on for a while (trying to find a good API for heterogeneous automatic differentiation), and I’m happy to finally find something that I feel good about, with the help of a lens-based API.
As a quick demonstration, this post will walk through the creation of a simple neural network implementation (inspired by the Tensorflow Tutorial for beginners) to learn handwritten digit recognition for the MNIST data set. To help tell the story, we’re going to be implementing it “normally”, using the hmatrix library API, and then re-write the same thing using backprop and hmatrix-backprop (a drop-in replacement for hmatrix).
For this network, we’re not going to be doing anything super fancy. Our “neural network” will just be a simple series of matrix multiplications, vector additions, and activation functions. We’re going to make a neural network with a single hidden layer using normal Haskell data types, parameterized by two weight matrices and two bias vectors.
The purpose of the MNIST challenge is to take a vector of pixel data (28x28, so 784 elements total) and classify it as one of ten digits (0 through 9). To do this, we’re going to be building and training a model that takes in a 784-vector of pixel data and produces a 10-item one-hot vector of categorical predictions (which is supposed to be 0 everywhere, except for 1 in the category we predict the input picture to be in).
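Here is a minimal sketch that makes this input/output encoding concrete. It uses plain Haskell lists rather than hmatrix vectors, which is a simplification for illustration. oneHot and predictDigit are hypothetical helper names and are not part of the post's code. The sketch encodes a digit label as a 10-item one-hot vector, and decodes a vector of category scores back into a digit by taking the index of the largest entry.

```haskell
import Data.List (maximumBy)
import Data.Ord  (comparing)

-- Hypothetical helper: encode a digit 0..9 as a 10-item one-hot list
-- (1 at the digit's index, 0 everywhere else).
oneHot :: Int -> [Double]
oneHot d = [ if i == d then 1 else 0 | i <- [0 .. 9] ]

-- Hypothetical helper: decode category scores into a digit by taking
-- the index of the largest score.
predictDigit :: [Double] -> Int
predictDigit scores = fst (maximumBy (comparing snd) (zip [0 ..] scores))

main :: IO ()
main = do
  print (oneHot 3)
  print (predictDigit [0.01, 0.02, 0.05, 0.80, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02])
```

Here oneHot 3 places its single 1 at index 3. The example score list peaks at index 3, so predictDigit returns 3.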
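For our types, our imports are pretty simple: mostly hmatrix’s statically-sized linear algebra module, plus lens and GHC Generics.

Our Net type will just be a simple collection of all of the matrices and vectors we want to optimize: two weight matrices and two bias vectors.

We’re using the matrix types from hmatrix’s Numeric.LinearAlgebra.Static module. An L 250 784 is a 250×784 matrix or, as we are using it, a linear transformation $\mathbb{R}^{784} \to \mathbb{R}^{250}$. An R 250 is a 250-vector, etc.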
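Via the lens library, four lenses are generated: weights1, bias1, weights2, and bias2. These lenses give us ways to access components of our data type. For example, n ^. weights1 gets the first weight matrix of a network n.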
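Below is a sketch of what the Net type could look like, based on the shapes and lens names described above. The constructor name N and the underscore-prefixed field names are assumptions. They follow the usual makeLenses convention, which strips the underscore to produce the lens names. The sketch also shows reading a field with ^. and overwriting one with lens's .~.

```haskell
{-# LANGUAGE DataKinds       #-}
{-# LANGUAGE DeriveGeneric   #-}
{-# LANGUAGE TemplateHaskell #-}

import           Control.Lens                 (makeLenses, (&), (.~), (^.))
import           GHC.Generics                 (Generic)
import qualified Numeric.LinearAlgebra        as LA
import           Numeric.LinearAlgebra.Static (L, R, extract)

-- Assumed constructor/field names; makeLenses strips the leading
-- underscores to generate weights1, bias1, weights2, and bias2.
data Net = N { _weights1 :: L 250 784   -- linear map R^784 -> R^250
             , _bias1    :: R 250
             , _weights2 :: L 10 250    -- linear map R^250 -> R^10
             , _bias2    :: R 10
             }
  deriving (Generic)

makeLenses ''Net

-- Every entry zero: fromInteger on L and R fills the whole matrix/vector.
zeroNet :: Net
zeroNet = N 0 0 0 0

-- Overwrite one field with (.~); the other fields are untouched.
bumped :: Net
bumped = zeroNet & bias2 .~ 1

main :: IO ()
main = print (LA.toList (extract (bumped ^. bias2)))  -- read a field with (^.)
```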
I’m also going to define Num and Fractional instances for our network. This makes it really easy to write code to “update” our network, because we can just add and scale our networks with each other. To do this, I’m going to be using one-liner-instances to make a Num instance automatically using GHC Generics:
```haskell
-- source: https://github.com/mstksg/inCode/tree/master/code-samples/backprop/intro-normal.hs#L61-L73
-- The g* methods come from Numeric.OneLiner (one-liner-instances);
-- they require Net to derive Generic and every field to be Num/Fractional.
instance Num Net where
    (+)         = gPlus
    (-)         = gMinus
    (*)         = gTimes
    negate      = gNegate
    abs         = gAbs
    signum      = gSignum
    fromInteger = gFromInteger

instance Fractional Net where
    (/)          = gDivide
    recip        = gRecip
    fromRational = gFromRational
```
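The generated instances let us “add and scale” whole networks. To see why, here is a hand-written version of the same idea on a hypothetical two-field record, Params. Params is an illustrative stand-in and is not part of the post. Every operation is applied field by field, and a numeric literal becomes that constant in every field. As a result, an expression like 0.5 * g scales each field of g by 0.5.

```haskell
-- Illustration only: a hypothetical record with a hand-written,
-- field-wise Num/Fractional instance, mirroring what generic
-- field-wise methods do for each field of a larger record.
data Params = Params { pa :: Double, pb :: Double }
  deriving (Show, Eq)

instance Num Params where
  Params a1 b1 + Params a2 b2 = Params (a1 + a2) (b1 + b2)
  Params a1 b1 - Params a2 b2 = Params (a1 - a2) (b1 - b2)
  Params a1 b1 * Params a2 b2 = Params (a1 * a2) (b1 * b2)
  negate (Params a b)         = Params (negate a) (negate b)
  abs    (Params a b)         = Params (abs a) (abs b)
  signum (Params a b)         = Params (signum a) (signum b)
  -- An integer literal becomes that constant in every field.
  fromInteger n               = Params (fromInteger n) (fromInteger n)

instance Fractional Params where
  Params a1 b1 / Params a2 b2 = Params (a1 / a2) (b1 / b2)
  -- A fractional literal like 0.5 also becomes a constant in every field.
  fromRational r              = Params (fromRational r) (fromRational r)

-- 0.5 * g means Params 0.5 0.5 * g, i.e. a field-wise scaling of g.
step :: Params -> Params -> Params
step p g = p - 0.5 * g

main :: IO ()
main = print (step (Params 1 2) (Params 2 4))
```

The generic g* methods apply the same field-wise pattern to each matrix and vector in Net. That is why a learning-rate literal can multiply an entire network.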
First, let’s look at the picture if we just try to compute the error function for our network directly.
Running our network is pretty textbook:
runNet takes a network and produces the R 784 -> R 10 function it encodes. (#>) :: L m n -> R n -> R m is the matrix-vector multiplication operator from hmatrix (its static module). We can also just use + (from Num) to add vectors together.
We can define the logistic function using only numeric operations (from Num, Fractional, and Floating), which operate component-wise for hmatrix types.
softMax requires us to use two functions from hmatrix: norm_1, to get the absolute sum of all items in a vector, and konst, to generate a vector of a single item repeated. Even so, it is pretty much a straightforward implementation of the mathematical definitions.
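Below is a sketch of the plain (non-backprop) forward pass described above. runNet has the same structure as the backprop version shown later, with ^. in place of ^^.. Several details are assumptions: the logistic definition (the standard sigmoid), the constructor name N, and the field names. With an all-zero network, every output category gets probability 1/10.

```haskell
{-# LANGUAGE DataKinds       #-}
{-# LANGUAGE TemplateHaskell #-}

import           Control.Lens                 (makeLenses, (^.))
import           Numeric.LinearAlgebra        (norm_1)
import qualified Numeric.LinearAlgebra        as LA
import           Numeric.LinearAlgebra.Static

-- Assumed constructor/field names (see the lenses described above).
data Net = N { _weights1 :: L 250 784
             , _bias1    :: R 250
             , _weights2 :: L 10 250
             , _bias2    :: R 10
             }

makeLenses ''Net

-- Standard logistic sigmoid: only Num/Fractional/Floating operations,
-- so on R n it is applied component-wise.
logistic :: Floating a => a -> a
logistic x = 1 / (1 + exp (negate x))

-- Softmax: exponentiate, then divide by the absolute sum of the entries.
softMax :: R 10 -> R 10
softMax x = expx / konst (norm_1 expx)
  where
    expx = exp x

runNet :: Net -> R 784 -> R 10
runNet n x = z
  where
    y = logistic $ (n ^. weights1) #> x + (n ^. bias1)
    z = softMax  $ (n ^. weights2) #> y + (n ^. bias2)

main :: IO ()
main = print (LA.toList (extract (runNet (N 0 0 0 0) 0)))
```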
This neural network now makes predictions. However, in order to train a network, we actually need a scalar error function that we want to minimize. This is a function on the network that, given an input and its expected output, computes how “bad” the current network is. It computes the error between the output of the network and the expected output, as a single number.
To do this, we will be using the cross entropy between the target output and the network output. This is a standard error function for classification problems; smaller cross-entropies indicate “better” predictions.
Computing the cross entropy involves using <.> (the dot product) from hmatrix. Other than that, we can just use log (from Floating) and negation (from Num).
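Written out, the crossEntropy definition computes the following, where t is the one-hot target and p is the network's output. When t is one-hot at the true category k, only one term of the sum survives. This shows why smaller values mean better predictions. The probabilities in the second line are made-up examples:

$$
\mathrm{CE}(t, p) = -\sum_{i=0}^{9} t_i \log p_i = -(\log p \cdot t)
$$

$$
t = e_k \;\Rightarrow\; \mathrm{CE}(t, p) = -\log p_k, \qquad -\log 0.8 \approx 0.223, \qquad -\log 0.1 \approx 2.303
$$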
At this point, we are supposed to find a way to compute the gradient of our error function. It’s a function that computes the direction of greatest change of all of the components in our network, with respect to our error function.
The gradient will take our Net -> Double error function and, given a current network, produce a “gradient” Net. Each component of this Net contains the derivative of the error with respect to the corresponding component of the network. It tells us how to “nudge” each component to increase the error function. Training a neural network involves moving in the opposite direction of the gradient, which causes the error to go down.
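Here is a one-dimensional toy sketch (not from the post) of “moving in the opposite direction of the gradient”. It uses a single parameter w, the error function (w − 3)², and its hand-derived gradient 2(w − 3). Computing that gradient by hand is easy in this case, but it is not practical for something as large as Net.

```haskell
-- Toy error function with its minimum at w = 3.
err :: Double -> Double
err w = (w - 3) ^ (2 :: Int)

-- Its gradient, derived by hand: d/dw (w - 3)^2 = 2 * (w - 3).
errGrad :: Double -> Double
errGrad w = 2 * (w - 3)

-- One descent step: move against the gradient, scaled by a learning rate.
descend :: Double -> Double -> Double
descend rate w = w - rate * errGrad w

main :: IO ()
main = do
  let ws = take 5 (iterate (descend 0.25) 0)
  mapM_ (\w -> putStrLn (show w ++ "  error: " ++ show (err w))) ws
```

Starting from w = 0 with a learning rate of 0.25, the iterates are 0, 1.5, 2.25, 2.625, and 2.8125, and the error shrinks at every step.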
Looking at netErr’s definition, it is not obvious how to compute our gradient function. Doing so involves some careful multi-variable vector calculus and linear algebra based on our knowledge of the operations we used. For simple situations we often do it by hand, but for more complicated situations, this becomes impractical. That’s where automatic differentiation comes into play.
We’ve gone as far as we can go now, so let’s drop into the world of backprop and see what it can offer us!
Let’s see what happens if we compute our error function using backprop, instead!
We’ll switch out our imports very slightly:
First, we add Numeric.Backprop, the module where the magic happens.

Second, we switch from Numeric.LinearAlgebra.Static to Numeric.LinearAlgebra.Static.Backprop (from hmatrix-backprop). This module exports the exact same¹ API as Numeric.LinearAlgebra.Static, except with numeric operations that are “lifted” to work with backprop. It’s meant to act as a drop-in replacement, and, because of this, most of our actual code will be more or less identical.
Writing functions that can be used with backprop involves tweaking the types slightly. Instead of working directly with values of type a, we work with BVars (backpropagatable variables) containing a, of type BVar s a.
For example, let’s look at a version of softMax that works with backprop.
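Here is a sketch of the backprop version of softMax, following the description that comes next. It uses the lifted konst and a vector-only version of norm_1 from Numeric.LinearAlgebra.Static.Backprop; that vector-only norm is assumed here to be named norm_1V. exp and / come from BVar's numeric instances. The main function runs the sketch with evalBP on an arbitrary example vector.

```haskell
{-# LANGUAGE DataKinds #-}

import           Numeric.Backprop                      (BVar, Reifies, W, evalBP)
import qualified Numeric.LinearAlgebra                 as LA
import qualified Numeric.LinearAlgebra.Static          as H
import           Numeric.LinearAlgebra.Static.Backprop (konst, norm_1V)

-- Lifted softmax: exp and (/) come from BVar's Floating/Fractional
-- instances; konst and norm_1V are the lifted hmatrix-backprop versions.
softMax :: Reifies s W => BVar s (H.R 10) -> BVar s (H.R 10)
softMax x = expx / konst (norm_1V expx)
  where
    expx = exp x

main :: IO ()
main = print (LA.toList (H.extract (evalBP softMax (H.vector [1 .. 10]))))
```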
Instead of R 10 -> R 10, its type signature is now BVar s (R 10) -> BVar s (R 10). Instead of working directly with R 10s (10-vectors), we work with BVar s (R 10)s (BVars containing 10-vectors).
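Numeric.LinearAlgebra.Static.Backprop provides konst and a vector-only version of norm_1, both lifted to work with BVars. BVars also have Num, Fractional, and Floating instances, so exp and / already work out-of-the-box.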
With only a minimal and mechanical change in our code, softMax is now automatically differentiable!
One neat trick: because of BVar’s numeric instances, we can actually re-use our original implementation of logistic!
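Here is a sketch of that reuse in action, using a standard logistic definition that only needs Floating (the exact definition is an assumption). The same polymorphic function is run on a plain Double through evalBP and differentiated through gradBP. At 0, the value is 0.5 and the derivative is σ(0)(1 − σ(0)) = 0.25.

```haskell
import Numeric.Backprop (evalBP, gradBP)

-- Standard logistic sigmoid (assumed definition); it only uses
-- Num/Fractional/Floating, so it also works on BVars unchanged.
logistic :: Floating a => a -> a
logistic x = 1 / (1 + exp (negate x))

main :: IO ()
main = do
  print (evalBP logistic (0 :: Double))  -- value at 0: 0.5
  print (gradBP logistic (0 :: Double))  -- derivative at 0: 0.25
```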
To run our network, things look pretty similar:
```haskell
-- source: https://github.com/mstksg/inCode/tree/master/code-samples/backprop/intro-backprop.hs#L71-L79
runNet
    :: Reifies s W
    => BVar s Net
    -> BVar s (R 784)
    -> BVar s (R 10)
runNet n x = z
  where
    y = logistic $ (n ^^. weights1) #> x + (n ^^. bias1)
    z = softMax  $ (n ^^. weights2) #> y + (n ^^. bias2)
```
Again, pretty much the same, except with the lifted type signature. One notable difference, however, is how we access the weights and biases. Instead of using ^. for lens access, we can use ^^. for lens access into a BVar.
Some insight may be gleaned from a comparison of their (simplified) type signatures: (^.) :: a -> Lens' a b -> b and (^^.) :: BVar s a -> Lens' a b -> BVar s b. ^. is access to a value using a lens, and ^^. is access to a value inside a BVar using a lens.
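Here is a small, self-contained sketch of ^^. on a hypothetical Point record (not from the post). area reads both fields out of a BVar, and gradBP returns the gradient as another Point. The sketch is written against the Num-based API described in this post, so Point gets a field-wise Num instance.

```haskell
{-# LANGUAGE TemplateHaskell #-}

import Control.Lens     (makeLenses)
import Numeric.Backprop (BVar, Reifies, W, gradBP, (^^.))

-- Hypothetical two-field record, only for illustration.
data Point = Point { _px :: Double, _py :: Double }
  deriving (Show, Eq)

makeLenses ''Point

-- The Num-based API needs a Num instance on backpropagated types.
instance Num Point where
  Point a b + Point c d = Point (a + c) (b + d)
  Point a b - Point c d = Point (a - c) (b - d)
  Point a b * Point c d = Point (a * c) (b * d)
  negate (Point a b)    = Point (negate a) (negate b)
  abs    (Point a b)    = Point (abs a) (abs b)
  signum (Point a b)    = Point (signum a) (signum b)
  fromInteger n         = Point (fromInteger n) (fromInteger n)

-- area(p) = x * y, reading both fields out of a BVar with (^^.).
area :: Reifies s W => BVar s Point -> BVar s Double
area p = (p ^^. px) * (p ^^. py)

main :: IO ()
main = print (gradBP area (Point 3 4))
```

For area(x, y) = x · y at (3, 4), the gradient is (y, x) = (4, 3).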
Using lenses like this gives us essentially frictionless usage of BVars, allowing us to access items inside data types in a natural way. We can also set items using .~~ (to parallel .~), and access constructors in sum types using ^^? (which can be used to implement pattern matching). We can also get matches for multiple targets of a traversal.
Because of these, our translation from our normal runNet to our backprop runNet is more or less completely mechanical.
At this point, the implementation of our updated error function should not be too surprising:
```haskell
-- source: https://github.com/mstksg/inCode/tree/master/code-samples/backprop/intro-backprop.hs#L64-L87
crossEntropy
    :: Reifies s W
    => BVar s (R 10)
    -> BVar s (R 10)
    -> BVar s Double
crossEntropy targ res = -(log res <.> targ)

netErr
    :: Reifies s W
    => BVar s (R 784)
    -> BVar s (R 10)
    -> BVar s Net
    -> BVar s Double
netErr x targ n = crossEntropy targ (runNet n x)
```
Apart from their type signatures, both of these implementations are textually identical to our original ones. The only difference is that <.> now comes from Numeric.LinearAlgebra.Static.Backprop. Other than that, we can simply re-use log and negation.
Time to gradient descend!
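Here is a sketch that puts the pieces together: a complete, self-contained backprop version of the model with a single gradient-descent step. A few names are assumptions: stepNet, the sample data, and the constructor and field names. The other definitions follow the code and descriptions in this post. stepNet subtracts 0.02 times the gradient that gradBP computes from netErr (constVar x) (constVar targ). The sketch targets the Num-based backprop 0.1.x API this post describes. Later backprop versions use a Backprop class instead of Num, so the instances would need adjusting there.

```haskell
{-# LANGUAGE DataKinds       #-}
{-# LANGUAGE DeriveGeneric   #-}
{-# LANGUAGE TemplateHaskell #-}

import           Control.Lens                          (makeLenses)
import           GHC.Generics                          (Generic)
import           Numeric.Backprop                      (BVar, Reifies, W, constVar, evalBP, gradBP, (^^.))
import qualified Numeric.LinearAlgebra.Static          as H
import           Numeric.LinearAlgebra.Static.Backprop ((#>), (<.>), konst, norm_1V)
import           Numeric.OneLiner

-- Assumed constructor/field names; shapes as described in the post.
data Net = N { _weights1 :: H.L 250 784
             , _bias1    :: H.R 250
             , _weights2 :: H.L 10 250
             , _bias2    :: H.R 10
             }
  deriving (Generic)

makeLenses ''Net

-- Field-wise arithmetic via one-liner-instances, as in the post.
instance Num Net where
    (+)         = gPlus
    (-)         = gMinus
    (*)         = gTimes
    negate      = gNegate
    abs         = gAbs
    signum      = gSignum
    fromInteger = gFromInteger

instance Fractional Net where
    (/)          = gDivide
    recip        = gRecip
    fromRational = gFromRational

-- Standard logistic sigmoid (assumed definition).
logistic :: Floating a => a -> a
logistic x = 1 / (1 + exp (negate x))

softMax :: Reifies s W => BVar s (H.R 10) -> BVar s (H.R 10)
softMax x = expx / konst (norm_1V expx)
  where
    expx = exp x

runNet :: Reifies s W => BVar s Net -> BVar s (H.R 784) -> BVar s (H.R 10)
runNet n x = z
  where
    y = logistic $ (n ^^. weights1) #> x + (n ^^. bias1)
    z = softMax  $ (n ^^. weights2) #> y + (n ^^. bias2)

crossEntropy :: Reifies s W => BVar s (H.R 10) -> BVar s (H.R 10) -> BVar s Double
crossEntropy targ res = -(log res <.> targ)

netErr :: Reifies s W => BVar s (H.R 784) -> BVar s (H.R 10) -> BVar s Net -> BVar s Double
netErr x targ n = crossEntropy targ (runNet n x)

-- One gradient-descent step on a single (input, target) pair;
-- stepNet is an assumed name, and 0.02 is the learning rate from the text.
stepNet :: H.R 784 -> H.R 10 -> Net -> Net
stepNet x targ net0 = net0 - 0.02 * gr
  where
    gr :: Net
    gr = gradBP (netErr (constVar x) (constVar targ)) net0

-- Hypothetical dummy data: a constant input and a one-hot target for digit 0.
sampleX :: H.R 784
sampleX = H.konst 0.5

sampleTarg :: H.R 10
sampleTarg = H.vector (1 : replicate 9 0)

-- Error of a network on the dummy data point.
errAt :: Net -> Double
errAt = evalBP (netErr (constVar sampleX) (constVar sampleTarg))

main :: IO ()
main = do
  let net0 = N 0 0 0 0
  print (errAt net0, errAt (stepNet sampleX sampleTarg net0))
```

On the dummy input, an all-zero network assigns 1/10 to every class, so its error is log 10. After one step, the error on that same point should be lower.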
To break this down:
To train our network, we move in the opposite direction of our gradient. That means computing net0 - 0.02 * gr: we subtract the gradient from our network, scaled by 0.02. This scaling factor is a learning rate, which ensures we don’t overshoot our goal.
Recall that we implemented scaling and subtraction of Nets when we wrote its Num and Fractional instances.
To compute our gradient, we use gradBP. If we ignore the RankN type/Reifies syntax noise, its type can be read as gradBP :: (BVar a -> BVar b) -> a -> a. This says “give a function from a to b, get the gradient function, from a to its gradient”.

This can be contrasted with evalBP :: (BVar a -> BVar b) -> a -> b, which “runs” the actual a -> b function that the BVar s a -> BVar s b encodes.

We want to use gradBP with our Net -> Double error function (or, more accurately, our BVar s Net -> BVar s Double function). That’s exactly what netErr (constVar x) (constVar targ) is, where x and targ are an input and its expected output. constVar simply lifts a value into a BVar, knowing that we don’t care about its gradient. This means that we have netErr (constVar x) (constVar targ) :: BVar s Net -> BVar s Double.

We can pass this function to gradBP to get the gradient of the error with respect to the network, as a Net.
That’s really the entire gradient computation and descent code!
Kind of anti-climactic, isn’t it?
Taking it for a spin
In the source code I’ve included some basic code for loading the MNIST data set and training the network, with some basic evaluations.

The source file is a stack script, so running it causes the program to compile itself. Stack installs the necessary GHC (if needed) and automatically downloads the dependencies from hackage. backprop manages the automatic differentiation, and stack manages the automatic dependency management :)

If you are following along at home, you can download the MNIST data set files and uncompress them into a folder, and run it all with:
```
$ ./intro-backprop PATH_TO_DATA
Loaded data.
[Epoch 1]
(Batch 1)
Trained on 5000 points.
Training error: 13.26%
Validation error: 13.44%
(Batch 2)
Trained on 5000 points.
Training error: 9.74%
Validation error: 11.08%
(Batch 3)
Trained on 5000 points.
Training error: 6.84%
Validation error: 8.71%
(Batch 4)
Trained on 5000 points.
Training error: 6.84%
Validation error: 8.53%
(Batch 5)
Trained on 5000 points.
Training error: 5.80%
Validation error: 7.55%
(Batch 6)
Trained on 5000 points.
Training error: 5.20%
Validation error: 6.77%
(Batch 7)
Trained on 5000 points.
Training error: 4.44%
Validation error: 5.85%
```
After about 35000 training points, we get down to 94% accuracy on our test set. Neat!
A More Nuanced Look
That’s the high-level overview. Now let’s look a bit at the details that might be helpful before you go strike out on your own.
The main API revolves around writing a BVar s a -> BVar s b function (representing an a -> b one), and then using one of the three runners:
```haskell
-- Return the result and gradient
backprop :: (Num a, Num b)
         => (forall s. Reifies s W => BVar s a -> BVar s b)
         -> a
         -> (b, a)

-- Return the result
evalBP :: (forall s. Reifies s W => BVar s a -> BVar s b)
       -> a
       -> b

-- Return the gradient
gradBP :: (Num a, Num b)
       => (forall s. Reifies s W => BVar s a -> BVar s b)
       -> a
       -> a
```
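Here is a quick sketch of the three runners on a hypothetical scalar function f(x) = x² + 3x, written once as an ordinary Num-polymorphic function. At x = 2, the result is 10 and the derivative 2x + 3 is 7. backprop returns both, with the result first.

```haskell
import Numeric.Backprop (backprop, evalBP, gradBP)

-- Hypothetical example function, written once for any Num type.
f :: Num a => a -> a
f x = x * x + 3 * x

main :: IO ()
main = do
  print (evalBP   f (2 :: Double))  -- result only
  print (gradBP   f (2 :: Double))  -- gradient only
  print (backprop f (2 :: Double))  -- (result, gradient)
```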
evalBP comes with virtually zero performance overhead (about 4%) over writing your functions directly. So there’s pretty much no harm in writing your entire application or library in terms of BVars.
gradBP, however, carries measurable performance overhead over writing your gradient code “manually”, but this heavily depends on exactly how complex the code you are backpropagating is. The overhead comes from two potential sources: the building of the function call graph, and also potentially from the mechanical automatic differentiation process generating different operations than what you might write by hand. See the README for a deeper analysis.
You might have also noticed the RankN type signature (the forall s. ...) that I glossed over earlier. This is here because backprop uses the RankN type trick (from Control.Monad.ST and the ad library) for two purposes:
- To prevent leakage of variables from the function. You can’t use evalBP to get a BVar out in the end, just like you can’t use runST to get an STRef out in the end. The type system prevents these variables from leaking out of the backprop/ST world.
- The Reifies s W constraint allows backprop to build a Wengert Tape of your computation, which it uses internally to perform the reverse-mode automatic differentiation (the W stands for Wengert).
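The first point can be illustrated with a small sketch (not from the post). square runs a BVar function with evalBP and returns a plain Double. leak is kept commented out: it tries to smuggle a BVar out through the result type, and the type checker should reject it, in the same way runST rejects returning an STRef.

```haskell
import Numeric.Backprop (evalBP)

-- Fine: the result type (Double) never mentions the inner s.
square :: Double -> Double
square = evalBP (\x -> x * x)

-- Rejected if uncommented: the result would be a BVar tied to the inner s,
-- letting a variable escape the backprop world (compare returning an
-- STRef from runST). It would also need BVar and constVar imported.
-- leak :: Double -> BVar s Double
-- leak = evalBP constVar

main :: IO ()
main = print (square 3)
```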
Discussion on Num
Note that at the moment, gradBP, (^^.), and most BVar-based operations all require a Num instance on the things being backpropagated. This API decision is a compromise between different options, and the README has a deeper discussion on this.
For the most part, writing a Num instance for your types is quick and easy boilerplate if your type derives Generic and we can use one-liner-instances. We saw this above with the Num instance for Net.
One potential drawback is that requiring a Num instance means you can’t directly backpropagate tuples. This can be an issue because tuples are used so pervasively for currying/uncurrying, and also because automatically generated prisms use tuples for constructors with multiple fields.
To mitigate this issue, the library exports some convenient tuples-with-Num-instances in Numeric.Backprop.Tuple. If you are writing an application, you might also consider using the orphan instances from the NumInstances package.
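Here is a sketch that illustrates the tuple problem and the usual workaround. It uses a hypothetical P2 pair type with hand-written lenses. P2 is a stand-in for illustration and is not the type exported by Numeric.Backprop.Tuple. gradBP would reject a plain (Double, Double) because it has no Num instance. P2 supplies a field-wise Num instance, so an “uncurried” two-argument function can be differentiated.

```haskell
import Numeric.Backprop (BVar, Reifies, W, gradBP, (^^.))

-- Hypothetical tuple stand-in with a field-wise Num instance.
-- (gradBP on a plain (Double, Double) would not typecheck: no Num instance.)
data P2 a b = P2 a b
  deriving (Show, Eq)

instance (Num a, Num b) => Num (P2 a b) where
  P2 a b + P2 c d = P2 (a + c) (b + d)
  P2 a b - P2 c d = P2 (a - c) (b - d)
  P2 a b * P2 c d = P2 (a * c) (b * d)
  negate (P2 a b) = P2 (negate a) (negate b)
  abs    (P2 a b) = P2 (abs a) (abs b)
  signum (P2 a b) = P2 (signum a) (signum b)
  fromInteger n   = P2 (fromInteger n) (fromInteger n)

-- Hand-written van Laarhoven lenses onto each component.
p2fst :: Functor f => (a -> f a) -> P2 a b -> f (P2 a b)
p2fst f (P2 a b) = (\a' -> P2 a' b) <$> f a

p2snd :: Functor f => (b -> f b) -> P2 a b -> f (P2 a b)
p2snd f (P2 a b) = (\b' -> P2 a b') <$> f b

-- g(x, y) = x^2 * y, taking both arguments as one backpropagatable pair.
g :: Reifies s W => BVar s (P2 Double Double) -> BVar s Double
g t = x * x * y
  where
    x = t ^^. p2fst
    y = t ^^. p2snd

main :: IO ()
main = print (gradBP g (P2 3 4))
```

For g(x, y) = x²y at (3, 4), the gradient is (2xy, x²) = (24, 9).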
Lifting your own functions
Of course, all of this would be useless unless you had a way to manipulate BVars. The library does provide lens-based accessors/setters. It also provides Num, Fractional, and Floating instances for BVars, so you can manipulate a BVar s a just like an a, using its numeric instances. We leveraged this heavily by using +, /, etc. We even went as far as re-using our entire logistic implementation, because it only relied on numeric operations.
However, for our domain-specific operations (like matrix multiplication, norms, and dot products), we needed to somehow lift those operations into backprop-land, to work with BVars.
End-users of the library shouldn’t be expected to do this. Ideally, library maintainers and authors would do it, so that users can use their types and operations with backprop. However, writing them is not magical: it just requires providing the result and the gradient with respect to a final total derivative. For example, let’s look at the implementation of the lifted dot product, (<.>). To lift it, we provide a function that, given its inputs x and y, gives the result (x H.<.> y, using hmatrix’s ordinary, unlifted dot product). The function also gives its gradient with respect to the total derivative of the result. For more details on the math, see the backprop documentation.
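Here is a sketch of lifting a dot product by hand, following the description above. op2 pairs the forward result x H.<.> y with a function from the incoming derivative d to the gradients for x and y. liftOp2 then turns that pair into a BVar function. The name dot and the test vectors are hypothetical. op2 and liftOp2 are backprop's lifting helpers; the way they are used here assumes their 0.1.x signatures.

```haskell
{-# LANGUAGE DataKinds #-}

import           GHC.TypeLits                 (KnownNat)
import           Numeric.Backprop             (BVar, Reifies, W, constVar, gradBP, liftOp2, op2)
import qualified Numeric.LinearAlgebra        as LA
import qualified Numeric.LinearAlgebra.Static as H

-- Hand-lifted dot product: the forward value, plus a function from the
-- incoming total derivative d to the gradients for x and y.
dot :: (Reifies s W, KnownNat n) => BVar s (H.R n) -> BVar s (H.R n) -> BVar s Double
dot = liftOp2 . op2 $ \x y ->
  ( x H.<.> y
  , \d -> (H.konst d * y, x * H.konst d)   -- d(x.y)/dx = y, d(x.y)/dy = x
  )

-- Hypothetical test vectors.
u, v :: H.R 3
u = H.vector [1, 2, 3]
v = H.vector [4, 5, 6]

main :: IO ()
main = print (LA.toList (H.extract (gradBP (\x -> dot x (constVar v)) u)))
```

Since x · v has gradient v with respect to x, differentiating at u should give back v = [4, 5, 6].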
If you’re interested in writing your own lifted operations, take a look at the source of the lifted hmatrix module, which lifts most of the functionality of hmatrix for backprop. (And if you’re good at computing gradients, check out the module notes for some of the currently unimplemented operators; any PRs would definitely be appreciated!)
The world is now your oyster! Go out and feel emboldened to numerically optimize everything you can get your hands on!
If you want to see an application to a more complex neural network type, I wrote a quick write-up on how to apply type-level dependent programming techniques to backprop (also available in literate Haskell). It should also help if you’re curious about how to implement the more “extensible” neural network types from my blog series on extensible neural networks.
Really, though, the goal of backprop is to allow you to automatically differentiate and optimize things you have already written (or plan to write, if only you had the ability to optimize them). Over the next few weeks I’ll be lifting operations from other libraries in the ecosystem. Let me know if there are any that you might want me to look at first! Also be on the lookout for some other posts I’ll be writing on applying backprop to optimize things other than neural networks.
If you have any questions, feel free to leave a comment. You can also give me a shout on twitter (I’m @mstk), on freenode’s #haskell (where I am usually idling as jle`), or on the DataHaskell gitter (where I hang out as @mstksg).
Please let me know if you end up doing anything interesting with the library — I’d love to hear about it! And, until next time, happy Haskelling!
¹ More or less. See the module documentation for more information.