Google’s Approach To Flexibility In Machine Learning

 4 years ago
source link: https://mc.ai/googles-approach-to-flexibility-in-machine-learning/

Photo by MusicFox Fx on Unsplash

When you think of machine learning, the first frameworks that come to mind are probably TensorFlow and PyTorch, currently the state-of-the-art frameworks for working with deep neural networks. Technology is changing rapidly and more flexibility is needed, so Google researchers are developing a new high-performance framework for the open source community: FLAX.

Instead of NumPy, FLAX builds its calculations on JAX, which is also a Google Research project. One of JAX’s biggest advantages is its use of XLA, a special compiler for linear algebra that also enables execution on GPUs and TPUs. For those who don’t know it, a TPU (Tensor Processing Unit) is a chip specifically optimized for machine learning. JAX reimplements parts of the NumPy API so that your functions can run on a GPU/TPU.
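A minimal sketch (not taken from the author’s notebook) of how closely jax.numpy mirrors NumPy: the hypothetical helper softplus_sum receives the array module as a parameter, so identical code runs once through NumPy on the CPU and once through jax.numpy, whose operations are executed via XLA on whichever backend is available.

```python
import numpy as np
import jax.numpy as jnp


def softplus_sum(xp, x):
    """Hypothetical helper: the same code works with `xp` = numpy or jax.numpy."""
    return xp.sum(xp.log1p(xp.exp(x)))


x = np.linspace(-3.0, 3.0, 7, dtype=np.float32)

numpy_result = softplus_sum(np, x)              # plain NumPy on the CPU
jax_result = softplus_sum(jnp, jnp.asarray(x))  # executed through XLA (CPU, GPU or TPU)

print(float(numpy_result), float(jax_result))
```

Both results agree up to floating-point rounding; only the array module changed.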

Flax focuses on key points such as:

  • easy-to-read code
  • preferring duplication over bad abstractions or bloated functions
  • helpful error messages (it seems they learned from TensorFlow’s error messages)
  • easy extensibility of basic implementations

Enough praise; now let’s start coding.

Because the MNIST example is getting boring, I will build an image classifier for the Simpsons family; unfortunately, Maggie is missing from the dataset.

Sample Images of the Dataset

First, we install the necessary libraries and unzip our dataset. Unfortunately, you still need TensorFlow at this point, because Flax lacks a good data input pipeline.

Now we import the libraries. Note that we have two “versions” of NumPy: the regular numpy library and the part of its API that JAX reimplements. The print statement outputs CPU, GPU or TPU, depending on the available hardware.
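The import cell might look roughly like the following sketch. The alias onp for the regular NumPy library is an assumption (a convention seen in early JAX code), and jax.default_backend() is one way to query the platform in current JAX releases; the original notebook may use a different call.

```python
import numpy as onp        # the regular NumPy library (alias is an assumption)
import jax
import jax.numpy as jnp    # the part of the NumPy API that JAX reimplements

# Reports the platform JAX runs on, e.g. "cpu", "gpu" or "tpu".
print(jax.default_backend())

# Both libraries understand the same array-creation calls.
host_array = onp.arange(4)
device_array = jnp.arange(4)
print(type(host_array).__name__, device_array.dtype)
```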

For training and evaluation, we first have to create two TensorFlow datasets and convert them to NumPy/JAX arrays, because FLAX doesn’t accept TF data types. This is currently a bit hacky, because the evaluation method doesn’t take batches. I had to create one large batch for the eval step and build a TF feature dictionary from it, which can then be parsed and fed to our eval step after each epoch.
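The notebook does this conversion with TensorFlow. As a TensorFlow-free sketch of the resulting format only, the following NumPy code yields (images, labels) training batches and packs the whole evaluation set into one dictionary. The helper names iterate_batches and make_eval_batch, the dictionary keys, the 64x64 RGB image shape and the synthetic data are all assumptions for illustration.

```python
import numpy as np


def iterate_batches(images, labels, batch_size, rng):
    """Hypothetical helper: yield shuffled (images, labels) tuples,
    so that batch[0] holds the image data and batch[1] the labels."""
    order = rng.permutation(len(images))
    for start in range(0, len(images), batch_size):
        idx = order[start:start + batch_size]
        yield images[idx], labels[idx]


def make_eval_batch(images, labels):
    """Hypothetical helper: the whole evaluation set as ONE large batch."""
    return {"image": np.asarray(images, dtype=np.float32),
            "label": np.asarray(labels, dtype=np.int32)}


# Synthetic stand-in data: 10 RGB images of 64x64 pixels, 4 classes.
rng = np.random.default_rng(0)
train_images = rng.random((10, 64, 64, 3), dtype=np.float32)
train_labels = rng.integers(0, 4, size=10)

batches = list(iterate_batches(train_images, train_labels, batch_size=4, rng=rng))
eval_batch = make_eval_batch(train_images, train_labels)
print([b[0].shape for b in batches], eval_batch["image"].shape)
```

The last training batch is simply smaller when the dataset size is not a multiple of the batch size.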

The Model

The CNN class contains our convolutional neural network; if you are familiar with TensorFlow or PyTorch, you will see it’s pretty straightforward. Every call of flax.nn.Conv defines a learnable kernel. I started from the MNIST example and extended it with some additional layers. At the end sits a Dense layer with four output neurons, because we have a four-class problem.
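To make the structure concrete without depending on a particular Flax version, here is a framework-free sketch of the same kind of network in plain JAX: each convolution owns a learnable kernel, and a final Dense layer maps to four outputs. The layer sizes, the 64x64 RGB input and the helper names init_cnn and cnn_apply are hypothetical, not the author’s exact architecture.

```python
import jax
import jax.numpy as jnp
from jax import lax


def init_cnn(rng, in_channels=3, num_classes=4, image_size=64):
    """Hypothetical initializer: two 3x3 conv kernels plus one Dense layer."""
    k1, k2, k3 = jax.random.split(rng, 3)
    flat = (image_size // 4) * (image_size // 4) * 64  # after two 2x2 poolings
    return {
        "conv1": jax.random.normal(k1, (3, 3, in_channels, 32)) * 0.1,  # HWIO layout
        "conv2": jax.random.normal(k2, (3, 3, 32, 64)) * 0.1,
        "dense_w": jax.random.normal(k3, (flat, num_classes)) * 0.01,
        "dense_b": jnp.zeros((num_classes,)),
    }


def conv(x, kernel):
    # 'SAME' padding keeps height and width; NHWC images, HWIO kernels.
    return lax.conv_general_dilated(x, kernel, window_strides=(1, 1), padding="SAME",
                                    dimension_numbers=("NHWC", "HWIO", "NHWC"))


def avg_pool_2x2(x):
    n, h, w, c = x.shape
    return x.reshape((n, h // 2, 2, w // 2, 2, c)).mean(axis=(2, 4))


def cnn_apply(params, x):
    x = avg_pool_2x2(jax.nn.relu(conv(x, params["conv1"])))
    x = avg_pool_2x2(jax.nn.relu(conv(x, params["conv2"])))
    x = x.reshape((x.shape[0], -1))                     # flatten
    logits = x @ params["dense_w"] + params["dense_b"]  # four output neurons
    return jax.nn.log_softmax(logits)


params = init_cnn(jax.random.PRNGKey(0))
images = jnp.ones((2, 64, 64, 3))
print(cnn_apply(params, images).shape)
```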

Unlike in TensorFlow, the activation function is called explicitly, which makes it very easy to test new or self-written activation functions. FLAX is based on the module abstraction, and both initializing and calling the network are done through the apply function.
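Because the activation is an ordinary function call, swapping in a self-written one is a one-line change. A minimal sketch in plain JAX, where my_swish and dense_block are hypothetical names chosen for illustration:

```python
import jax
import jax.numpy as jnp


def my_swish(x, beta=1.0):
    """Self-written activation (hypothetical example): x * sigmoid(beta * x)."""
    return x * jax.nn.sigmoid(beta * x)


def dense_block(w, b, x, activation):
    # The activation is passed in and called explicitly on the layer output.
    return activation(x @ w + b)


key = jax.random.PRNGKey(0)
w = jax.random.normal(key, (3, 2))
b = jnp.zeros(2)
x = jnp.array([[1.0, -2.0, 0.5]])

print(dense_block(w, b, x, jax.nn.relu))  # built-in activation
print(dense_block(w, b, x, my_swish))     # self-written activation
```

Since my_swish is written with jax.numpy operations, jax.grad can differentiate through it without any extra registration.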

Metrics in FLAX

Of course, we want to measure how well our network performs, so we compute metrics such as loss and accuracy. We compute the accuracy with JAX instead of NumPy, because JAX can run on a GPU/TPU.
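An accuracy computed entirely with jax.numpy could look like this sketch; it assumes the network returns one score per class with shape (batch, 4) and that labels are integer class ids. The function name accuracy is illustrative.

```python
import jax
import jax.numpy as jnp


def accuracy(log_probs, labels):
    """Fraction of examples whose highest-scoring class equals the label."""
    predicted = jnp.argmax(log_probs, axis=-1)
    return jnp.mean(predicted == labels)


log_probs = jax.nn.log_softmax(jnp.array([[2.0, 0.1, 0.1, 0.1],
                                          [0.1, 0.1, 3.0, 0.1],
                                          [0.5, 0.2, 0.1, 0.9]]))
labels = jnp.array([0, 2, 1])
print(float(accuracy(log_probs, labels)))  # two of the three predictions are correct
```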

To measure our loss, we use the cross-entropy loss. Unlike in TensorFlow, you have to compute it yourself, because FLAX does not provide ready-made loss objects yet. We use @jax.vmap as a function decorator for our loss function, which vectorizes the code so it runs efficiently on batches.

How does cross_entropy_loss work? @jax.vmap takes both arrays, logits and label, and applies cross_entropy_loss to each pair, which allows a whole batch to be computed in parallel. The cross-entropy formula for a single example is:

$$ L(y, \hat{y}) = -\sum_{i=1}^{4} y_i \log(\hat{y}_i) $$

Our ground truth y is one-hot: it is 1 for the correct class and 0 for the other three output neurons. Every term of the sum except the one for the correct label is therefore zero, so we do not need the sum in our code; we just compute the negative log(y_hat) of the correct label. Finally, we take the mean of the per-example losses because we work with batches.
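A minimal sketch of such a vmapped loss, assuming the model already outputs log-softmax values (log y_hat) rather than raw logits. For comparison, one_hot_cross_entropy (a hypothetical helper) evaluates the full sum with a one-hot y, which shows that indexing the correct label gives the same per-example values as the formula.

```python
import jax
import jax.numpy as jnp


@jax.vmap
def cross_entropy_loss(log_probs, label):
    # Written for ONE example; vmap maps it over the batch axis of both
    # arguments. With one-hot y, -sum(y * log y_hat) reduces to
    # the negative log-probability of the correct label.
    return -log_probs[label]


def one_hot_cross_entropy(log_probs, label, num_classes=4):
    # Full formula for comparison: -sum_i y_i * log(y_hat_i).
    y = jax.nn.one_hot(label, num_classes)
    return -jnp.sum(y * log_probs)


logits = jnp.array([[2.0, 0.5, 0.1, -1.0],
                    [0.3, 0.2, 1.5, 0.0]])
labels = jnp.array([0, 2])
log_probs = jax.nn.log_softmax(logits)

per_example = cross_entropy_loss(log_probs, labels)  # shape (2,)
loss = jnp.mean(per_example)                         # mean over the batch
reference = jnp.stack([one_hot_cross_entropy(lp, l) for lp, l in zip(log_probs, labels)])
print(float(loss), bool(jnp.allclose(per_example, reference)))
```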

Training

In our train step, we again use a function decorator, @jax.jit, to speed up our function. This works very similarly to TensorFlow. Keep in mind that batch[0] is our image data and batch[1] our labels.

The loss function loss_fn returns the loss for our current model, optimizer.target, and jax.grad() computes its gradient. We then apply the gradient, just like in TensorFlow.
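A minimal sketch of such a jitted train step, under stated assumptions: model_apply is a hypothetical one-layer stand-in for the CNN, and a plain SGD update with a made-up learning rate replaces the Flax optimizer (the article’s optimizer.target and its update call are not reproduced here).

```python
import jax
import jax.numpy as jnp

LEARNING_RATE = 0.1  # hypothetical value, not taken from the notebook


def model_apply(params, images):
    """Hypothetical stand-in model: one Dense layer on flattened images."""
    x = images.reshape((images.shape[0], -1))
    return jax.nn.log_softmax(x @ params["w"] + params["b"])


def loss_fn(params, images, labels):
    # Mean negative log-probability of the correct class (cross-entropy).
    log_probs = model_apply(params, images)
    return -jnp.mean(jnp.take_along_axis(log_probs, labels[:, None], axis=1))


@jax.jit
def train_step(params, batch):
    images, labels = batch[0], batch[1]  # batch[0]: image data, batch[1]: labels
    grads = jax.grad(loss_fn)(params, images, labels)  # gradient w.r.t. params only
    # Plain SGD update; stands in for the Flax optimizer used in the article.
    return jax.tree_util.tree_map(lambda p, g: p - LEARNING_RATE * g, params, grads)
```

Calling params = train_step(params, (images, labels)) once per batch is enough: jax.jit traces and compiles the function on the first call and reuses the compiled version for later batches with the same shapes and dtypes.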

The eval step in FLAX is simple and minimal. Note that the complete evaluation dataset is passed to this function.
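A sketch of an eval step that receives the complete evaluation set at once as a single dictionary. The keys "image" and "label" and the one-layer stand-in model_apply are assumptions; loss and accuracy are computed the same way as for one training batch.

```python
import jax
import jax.numpy as jnp


def model_apply(params, images):
    """Hypothetical stand-in model: one Dense layer on flattened images."""
    x = images.reshape((images.shape[0], -1))
    return jax.nn.log_softmax(x @ params["w"] + params["b"])


@jax.jit
def eval_step(params, eval_ds):
    # eval_ds holds the COMPLETE evaluation set as one batch:
    # {"image": (N, H, W, C) floats, "label": (N,) integer class ids}.
    log_probs = model_apply(params, eval_ds["image"])
    labels = eval_ds["label"]
    loss = -jnp.mean(jnp.take_along_axis(log_probs, labels[:, None], axis=1))
    accuracy = jnp.mean(jnp.argmax(log_probs, axis=-1) == labels)
    return {"loss": loss, "accuracy": accuracy}
```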

After 50 epochs, we reach a very high accuracy. Of course, we could continue to tweak the model and tune the hyperparameters.

For this experiment I used Google Colab, so if you want to test it yourself, create a new environment with a GPU/TPU and import my notebook from GitHub. Please note that FLAX does not currently work on Windows.

Conclusions

It is important to note that FLAX is currently still in alpha and is not an official Google product.

The work so far gives hope for a fast, lightweight and highly customizable ML framework. What is still completely missing is a data input pipeline, so TensorFlow still has to be used. The current set of optimizers is unfortunately limited to Adam and SGD with momentum. I especially liked how straightforward this framework is to use, as well as its high flexibility. My next plan is to develop some activation functions that are not yet available. A speed comparison between TensorFlow, PyTorch and FLAX would also be very interesting.

If you want to experiment with FLAX, check out the documentation and its GitHub page.

And if you want to download my example together with the dataset, just clone SimpsonsFaceRecognitionFlax.

