Saturday, March 11, 2017

Build a Neural Network from Scratch in 60 lines of OCaml Code

People have been asking me about the current development state of Owl (a numerical library in OCaml). I think it is about time to show what Owl can actually do at the moment with its newly added AD (Algorithmic Differentiation) module.

I will demonstrate how to build a small two-layer neural network from scratch to learn the hand-written digits in the MNIST dataset. First, open `utop`, load the `Owl` library, then type `Dataset.download_all ();;` to download all the datasets used in the example.

The following code snippet defines a simple two-layer neural network with `tanh` and `softmax` as the activation functions for the first and second layers respectively. Remember to open the `Owl` and `Algodiff.AD` modules.
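To make the shape of such a network concrete, here is a minimal sketch of a two-layer forward pass in plain OCaml. It deliberately uses only the standard library rather than Owl, so none of the names below (`layer`, `affine`, `run_layer`, `run_network`) should be read as Owl's API; they are hypothetical, and the toy sizes are chosen only to keep the example small.

```ocaml
(* Plain-OCaml sketch of a two-layer forward pass. No Owl calls are used;
   all names here are hypothetical and chosen for illustration only. *)

type layer = {
  w : float array array;           (* weights, shape: inputs x outputs *)
  b : float array;                 (* bias, length: outputs *)
  a : float array -> float array;  (* activation applied to the output vector *)
}

(* Row vector times matrix plus bias: y.(j) = b.(j) + sum_i x.(i) * w.(i).(j) *)
let affine x w b =
  Array.mapi
    (fun j bj ->
      let acc = ref bj in
      Array.iteri (fun i xi -> acc := !acc +. xi *. w.(i).(j)) x;
      !acc)
    b

(* Element-wise tanh, used by the first layer *)
let tanh_act v = Array.map tanh v

(* Softmax, used by the second layer; subtracting the maximum avoids overflow *)
let softmax v =
  let m = Array.fold_left max neg_infinity v in
  let e = Array.map (fun z -> exp (z -. m)) v in
  let s = Array.fold_left ( +. ) 0. e in
  Array.map (fun z -> z /. s) e

let run_layer x l = l.a (affine x l.w l.b)
let run_network x layers = Array.fold_left run_layer x layers

(* Toy sizes: 3 inputs -> 4 hidden units -> 2 classes.
   For MNIST the input would have 784 entries and the output 10. *)
let () = Random.init 42

let init rows cols =
  Array.init rows (fun _ -> Array.init cols (fun _ -> Random.float 0.2 -. 0.1))

let l0 = { w = init 3 4; b = Array.make 4 0.; a = tanh_act }
let l1 = { w = init 4 2; b = Array.make 2 0.; a = softmax }

(* The result is a probability vector over the two classes *)
let probs = run_network [| 0.5; -1.0; 2.0 |] [| l0; l1 |]
```

Because the last layer is a softmax, `probs` contains positive entries that sum to one, which is what lets us read the output as class probabilities.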


Defining a network seems trivial, but how about the core component of all neural networks: back propagation? It turns out that writing back propagation in Owl takes only about a dozen lines of code. Well, actually 12 lines of code in total :)


The reason for this brevity is that algorithmic differentiation is a generalisation of back propagation. The `Owl.Algodiff` module relieves us from manually deriving the derivatives of the activation functions, which is a laborious and tedious task.
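To illustrate the idea, here is a toy reverse-mode differentiator in plain OCaml. It is only a sketch of the general technique and makes no claim about how `Owl.Algodiff` is implemented; every name in it (`node`, `tape`, `var`, `add`, `mul`, `tanh_`, `backward`) is hypothetical. Each primitive records how to pass its adjoint back to its inputs, so the gradient of any composition comes out of one backward sweep.

```ocaml
(* Toy reverse-mode algorithmic differentiation on scalars.
   A sketch of the technique only, not Owl's implementation. *)

type node = {
  v : float;            (* value computed in the forward pass *)
  mutable adj : float;  (* adjoint: d(output) / d(this node) *)
  push : node -> unit;  (* sends this node's adjoint to its inputs *)
}

(* Nodes in reverse creation order, i.e. most recent first *)
let tape : node list ref = ref []

let mk v push =
  let n = { v; adj = 0.; push } in
  tape := n :: !tape;
  n

(* An input variable has nothing to propagate to *)
let var v = mk v (fun _ -> ())

let add a b =
  mk (a.v +. b.v) (fun n ->
      a.adj <- a.adj +. n.adj;
      b.adj <- b.adj +. n.adj)

let mul a b =
  mk (a.v *. b.v) (fun n ->
      a.adj <- a.adj +. n.adj *. b.v;
      b.adj <- b.adj +. n.adj *. a.v)

(* d tanh(z) / dz = 1 - tanh(z)^2, stated once here and reused everywhere *)
let tanh_ a =
  mk (tanh a.v) (fun n -> a.adj <- a.adj +. n.adj *. (1. -. n.v *. n.v))

(* Seed the output with adjoint 1, then sweep the tape from newest to oldest.
   Creation order is a topological order, so each node's adjoint is complete
   before it is pushed to its inputs. *)
let backward out =
  out.adj <- 1.;
  List.iter (fun n -> n.push n) !tape

(* One neuron: y = tanh (w * x + b) *)
let w = var 0.5
let x = var 2.0
let b = var (-0.3)
let y = tanh_ (add (mul w x) b)
let () = backward y
(* w.adj and b.adj now hold dy/dw and dy/db *)
```

Back propagation through a network is this same backward sweep applied to the loss, with matrices in place of scalars.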

Now, you can use the following code in `utop` to train the model and then test it on the test dataset.
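The testing half of this step boils down to comparing the most probable class with the true label. A minimal sketch in plain OCaml follows; `argmax`, `accuracy`, and the stand-in `predict` function are hypothetical names rather than Owl functions, and the four samples are made up purely to exercise the code.

```ocaml
(* Index of the largest entry in a non-empty array (first one wins on ties) *)
let argmax v =
  let best = ref 0 in
  Array.iteri (fun i x -> if x > v.(!best) then best := i) v;
  !best

(* Fraction of samples whose predicted class equals the label *)
let accuracy predict samples labels =
  let hits = ref 0 in
  Array.iteri
    (fun i x -> if argmax (predict x) = labels.(i) then incr hits)
    samples;
  float_of_int !hits /. float_of_int (Array.length samples)

(* Hypothetical stand-in for a trained model: it returns its input
   unchanged, as if the input were already a probability vector *)
let predict x = x

(* Made-up data: four "predictions" over two classes and their labels *)
let samples = [| [| 0.1; 0.9 |]; [| 0.8; 0.2 |]; [| 0.3; 0.7 |]; [| 0.6; 0.4 |] |]
let labels = [| 1; 0; 0; 0 |]

(* Three of the four predicted classes match their labels *)
let acc = accuracy predict samples labels
```

With a real model, `predict` would be the network's forward pass and `samples` the rows of the test set.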


You should be able to see the following output in your terminal. It seems this small neural network works just fine. For example, our model predicts the following hand-written digit as 6, which is correct!



How about more complicated ones such as convolutional networks, recurrent neural networks, etc.? Well, you can either define them yourself with the `Owl.Algodiff` module, or wait for me to wrap everything up and add a new module in Owl specifically for neural networks.

In general, `Owl` just makes my life so easy when dealing with these numerical tasks in OCaml. I hope you also find it useful.
