
[Paper Summary] Distilling the Knowledge in a Neural Network

source link: https://towardsdatascience.com/paper-summary-distilling-the-knowledge-in-a-neural-network-dc8efd9813cc?gi=d9629638a901



Problem(s) addressed

The authors start the paper with a very interesting analogy to explain the notion that the requirements for training and inference can be very different.

The analogy given is that of a larva and its adult form, and the fact that the nutritional requirements of the two forms are quite different.

We can easily appreciate that during training the priority is to solve the problem at hand, i.e. to learn the parameters of the model, and we end up employing a multitude of techniques and tricks to achieve that goal.

For example, you could

  • use an ensemble of networks, which has proven to work for many different kinds of problems
  • use dropout to generalize better
  • increase the depth of the network
  • use a larger dataset, etc.

It is also important to appreciate that, during this quest to learn, the mechanics of machine learning are such that we explore various paths which, while crucial for learning, may not be needed during the inference phase. In other words, this extra information could be considered redundant from the inference perspective.

This brings us to the requirements for inference, where along with accuracy, runtime performance, i.e. the speed of prediction, plays an important role as well.

If your product is not usable because it is slow, then it does not matter how accurate it is. Usability wins over accuracy in most cases!

The paper addresses the challenge of how to run models with a network architecture that has a smaller number of parameters, without sacrificing too much accuracy.

Prior art and its limitations

This is not the first time the problem has been discussed. The notion of training a simple network using the knowledge of a cumbersome model was demonstrated by Rich Caruana et al. in 2006 in a paper titled Model Compression.

A cumbersome model is a model that has a lot of parameters or is an ensemble of models, and is generally difficult to set up and run on devices with limited computing resources.

In this paper, Hinton et al. refer to Model Compression to credit it with showing that it is possible to extract the knowledge from a cumbersome model and transfer it to a simpler model.

In the Model Compression paper, the technique was to minimize the distance between the two models in logit space using a squared-error (RMSE) loss. Hinton et al. argue that they build on that insight and propose a more general solution; in other words, the Model Compression technique of Caruana et al. is a special case of the one proposed by Hinton et al.
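As a rough sketch of that earlier approach (PyTorch is used here purely for illustration, and the function name is my own), logit matching amounts to regressing the student's logits directly onto the teacher's with a squared-error objective:

```python
import torch.nn.functional as F

def logit_matching_loss(student_logits, teacher_logits):
    # Caruana-style model compression: train the small model to reproduce
    # the large model's logits with a squared-error loss.
    return F.mse_loss(student_logits, teacher_logits)
```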

Required Background knowledge to understand the Key Insights

To appreciate the key insights from this paper, you should have a good intuition as well as a mathematical understanding of what the softmax activation function does!

Here I am showing a typical classification network with 3 neurons in the last output layer. This means that we have 3 classes. The activation function used in typical classification problems is the softmax function (in the last layer). For our discussion, it does not matter what activation functions are used in the hidden layers.

[Animation: the softmax activation applied to the last layer of a 3-class network. Source: author of the article]

The animation above shows the softmax activation formula, i.e.

q_i = exp(z_i) / sum(exp(z_j)), where j = 1 to 3

where q_i corresponds to the softmax output of neuron i in the last layer. The numerator is the exponentiated logit produced by that neuron, whereas the denominator is the sum of all the logits in exponential space.

But why softmax in the first place? This can be answered by splitting the question into two sub-questions:

  • Why do we exponentiate (the numerator part)?
  • Why do we normalize (the denominator part)?

The answer to both these questions is that we desire a probability distribution as the output of the network.

Now any probability distribution has to respect two important properties:

  • All entries in the distribution should be positive
  • All entries should sum up to 1

The exponential function converts any real number, including negative ones, into a positive real, which addresses the first requirement. Next, the normalization (i.e. dividing each entry by the sum of all entries) makes the entries sum to 1, which turns them into a distribution.
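As a quick sanity check, here is a minimal NumPy sketch (the logit values are purely illustrative) showing that exponentiation followed by normalization yields a valid probability distribution:

```python
import numpy as np

# Illustrative logits from the 3 output neurons; one is negative on purpose.
logits = np.array([2.0, -1.0, 0.5])

exp_logits = np.exp(logits)             # exponentiation makes every entry positive
probs = exp_logits / exp_logits.sum()   # normalization makes the entries sum to 1

print(np.round(probs, 2))  # [0.79 0.04 0.18] -> all positive
print(probs.sum())         # ~1.0             -> sums to one
```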

But now you may wonder why we need a probability distribution as the output.

In classification problems, the ground truth labels are one-hot encoded vectors. A one-hot encoded vector is nothing but a probability distribution where a single entry gets all the probability mass. The task (objective) is therefore to compare this ground truth distribution with the predicted distribution, which is why we want the output of the network to be the probability distribution provided by the softmax activation function. The comparison between the two distributions is done using the cross-entropy loss function.
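Here is a minimal sketch (again with purely illustrative numbers) of that comparison between a one-hot ground truth vector and a softmax output:

```python
import numpy as np

def softmax(z):
    e = np.exp(z - z.max())  # subtract the max for numerical stability
    return e / e.sum()

y_true = np.array([1.0, 0.0, 0.0])            # ground truth: class 0, one-hot encoded
y_pred = softmax(np.array([2.0, -1.0, 0.5]))  # predicted distribution from the logits

# Only the true-class term survives, because every other entry of the one-hot vector is 0.
cross_entropy = -np.sum(y_true * np.log(y_pred))
print(round(cross_entropy, 2))  # ~0.24
```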

The Key Insights

Softmax works well with the cross-entropy loss, but it has one issue: in the process of giving importance to the most likely class, it pushes the probabilities of the remaining classes towards very small values.

The paper takes an example from handwritten digit classification: an image of a 2 may be more similar to an image of a 3 than it is to an image of a 7.

How close or similar the examples are is very important for understanding what the network has actually learned!

The key insight here is that the softmax function tends to hide this relative similarity between the other classes, and this information, if available, could play a vital role in training the distilled network.

The second key insight is about how to expose these relative similarities between classes while remaining in the realm of softmax.

The authors observed that if we make the values of the logits (i.e. the outputs of the neurons of the last layer) smaller before passing them to the exponential function, then we get a smoother distribution.

Smoother here means that, unlike with regular softmax, there is no big spike corresponding to a single entry.

To make the logits smaller, you need a number to divide them by. That number is denoted T, and the authors call it the temperature. The higher the temperature, the smoother the distribution.

They modified the softmax function as shown below

q_i = exp(z_i / T) / sum(exp(z_j / T))


The analogy here is that of distillation, where you use temperature to separate out the impurities. However, the value of T is of high importance and is something that you have to find by experimentation. This is why it is a hyper-parameter.

Below is a code snippet showing the impact of different values of T on the output of the softmax function.

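It is a minimal NumPy sketch, and the logit values are purely illustrative:

```python
import numpy as np

def softmax_with_temperature(logits, T=1.0):
    z = np.asarray(logits, dtype=float) / T   # divide the logits by the temperature
    e = np.exp(z - z.max())                   # subtract the max for numerical stability
    return e / e.sum()

logits = [5.0, 2.0, 1.0]  # illustrative logits for a 3-class problem

for T in (1, 3, 10):
    print(f"T={T}: {np.round(softmax_with_temperature(logits, T), 3)}")

# T=1:  [0.936 0.047 0.017]  -> sharp spike on the most likely class
# T=3:  [0.613 0.225 0.162]  -> softer
# T=10: [0.415 0.307 0.278]  -> much smoother
```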

As you can see, the higher the temperature (T), the smoother the resulting distribution. The value T=1 corresponds to the regular softmax behavior.

How does it work?

First, let me introduce some new terminology here.

Teacher model. The original (cumbersome) model is called the teacher model since we are extracting the knowledge from it.

Student model. The new model with fewer parameters is called a student model since we are distilling information into it.

Soft labels. The outputs of the teacher model when softmax with a temperature greater than 1 (T > 1) is used.

Soft predictions. The outputs of the student model when softmax with a temperature greater than 1 (T > 1) is used.

Hard predictions. The outputs of the student model when the regular softmax (T = 1) is used.

Hard labels. The ground truth label in a one-hot encoded vector form.

The setup of the training process of the student is explained in the figure below.

[Figure: the knowledge distillation training setup with teacher and student models]

Source — https://nervanasystems.github.io/distiller/knowledge_distillation.html

You essentially end up with two loss terms. The first loss term uses the soft labels (from the teacher) and the soft predictions (from the student), and the second loss term uses the hard predictions (from the student) and the hard labels. You can configure the relative contribution of the two terms.
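A minimal PyTorch-style sketch of such a combined loss is shown below. The function name and the default values of T and alpha are my own choices; the T*T scaling of the soft term follows the paper's suggestion for keeping the gradients of the two terms comparable in magnitude.

```python
import torch.nn.functional as F

def distillation_loss(student_logits, teacher_logits, labels, T=4.0, alpha=0.7):
    # First term: compare the student's soft predictions with the teacher's soft labels.
    soft_labels = F.softmax(teacher_logits / T, dim=1)
    soft_preds = F.log_softmax(student_logits / T, dim=1)
    # The T*T factor keeps the gradients of this term comparable in
    # magnitude to those of the hard-label term, as suggested in the paper.
    soft_loss = F.kl_div(soft_preds, soft_labels, reduction="batchmean") * (T * T)

    # Second term: ordinary cross-entropy between hard predictions (T = 1) and hard labels.
    hard_loss = F.cross_entropy(student_logits, labels)

    # alpha controls the relative contribution of the two terms.
    return alpha * soft_loss + (1 - alpha) * hard_loss
```

In practice, the teacher's logits would be obtained with a forward pass of the already trained teacher (with gradients disabled) on the same batch of inputs.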

The authors conducted their experiments on MNIST and on speech recognition, and obtained excellent results.

[Figure: results from the paper for voice recognition on Android]

Various links and details

Is there an open-source implementation of the paper?

There are multiple implementations on GitHub, and the technique is very simple to implement. Here is a link to a repository where multiple KD losses are implemented:

https://github.com/karanchahal/distiller

Was the paper published in a conference?

Yes. This paper was presented at the NIPS 2014 Deep Learning Workshop and has over 3,000 citations.

Link to the paper: https://arxiv.org/abs/1503.02531

Is there a video explaining the paper?

Yes. I have created a YouTube video for this paper.

My opinions and takeaways

  • This is the foundational paper that jumpstarted the research area of Knowledge Distillation.
  • The paper is well written, and if you have a good intuition for and understanding of the mathematical properties of the softmax function, then the notion of temperature makes sense. This is why I have dedicated a good portion of this article to explaining the significance of softmax in classification networks.
  • The technique (i.e. smoother softmax) is still in use and is often complemented by other methods as per the requirements of different problems and architectures.

I will be writing more paper summary articles on Knowledge Distillation as follow-ups to this article.

I hope you enjoyed the summary. It is possible that I have misunderstood or misinterpreted parts of the paper; any such mistakes are mine and not those of the original authors.

