
Gaining insights on transfer learning with FlashTorch


Source: Deep Learning on Medium

Open source feature visualisation toolkit for neural networks in PyTorch

[Figure: Visualisation of what AlexNet “sees” in these images of birds, using FlashTorch. Source]

Setting the scene

In my last post, I gave an overview of feature visualisation as a field of research and introduced FlashTorch, an open source feature visualisation toolkit for neural networks built in PyTorch.

The package is available to install via pip. Check out the GitHub repo for the source code. You can also play around with it in this notebook hosted on Google Colab, without needing to install anything!

Feature visualisation with FlashTorch

Previously, I showed you how to use FlashTorch to visualise what convolutional neural networks (CNNs) “perceive” within input images. I did so by creating saliency maps from AlexNet, pre-trained on the ImageNet classification task. I picked three classes from the task (great grey owl, peacock and toucan) and used images of these classes to inspect what AlexNet has learnt to focus on most when identifying these objects.
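The core idea behind a saliency map fits in a few lines of plain PyTorch. The sketch below is vanilla gradient saliency (the gradient of one class score with respect to the input pixels), not FlashTorch's own code; FlashTorch packages this idea, plus extras such as guided backpropagation, in its `Backprop` class. The helper `saliency_map` and the tiny untrained network are hypothetical stand-ins so the example runs without downloading AlexNet weights.

```python
import torch
import torch.nn as nn


def saliency_map(model: nn.Module, image: torch.Tensor, target_class: int) -> torch.Tensor:
    """Vanilla gradient saliency for a single image.

    image: tensor of shape (1, C, H, W).
    Returns an (H, W) tensor: the absolute gradient of the target class score
    with respect to each pixel, keeping the largest value across channels.
    """
    model.eval()
    # Copy the input and ask autograd to track gradients with respect to it.
    x = image.clone().detach().requires_grad_(True)
    scores = model(x)  # shape (1, num_classes)
    model.zero_grad()
    # Backpropagate from the chosen class score only.
    scores[0, target_class].backward()
    # Large absolute gradients mark pixels whose change most affects the score.
    return x.grad.detach().abs().max(dim=1).values[0]


# Hypothetical stand-in for AlexNet: a tiny, untrained CNN with 10 classes.
torch.manual_seed(0)
tiny_cnn = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.AdaptiveAvgPool2d(4),
    nn.Flatten(),
    nn.Linear(8 * 4 * 4, 10),
)
image = torch.rand(1, 3, 32, 32)  # placeholder for a real, preprocessed photo
sal = saliency_map(tiny_cnn, image, target_class=3)
print(sal.shape)  # torch.Size([32, 32])
```

Overlaying a map like `sal` on the original photo gives the kind of heat map shown in the bird figure above.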

In essence, I took a network that I knew would perform well on what I asked it to do and examined its perception. This is interesting in its own right, and it reflects a shift I’d like us to make: from looking only at test accuracy to asking what the network is actually doing.

What is the neural net perceiving?

Why does it behave the way it does?

How can we explain its decisions/predictions?

I created FlashTorch to make answering questions like these easier. And these questions are not just for when you have well-performing networks!

In reality, it is far more common to have a network that doesn’t perform as well as you want it to. But here again, we are often haunted by accuracy. We tend to jump straight to training when we see poor performance, without spending much time understanding why the network performs so poorly.

FlashTorch can help you do that, and I’d like to demonstrate it with an example in the context of transfer learning.

Transfer learning

Transfer learning in machine learning is a form of knowledge transfer: a method where a model trained on one task is reused, often as a starting point, for another task. How much additional training the new task needs depends on factors such as how similar the original and new tasks are and how much training data is available.
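One common way to put this into practice is to freeze the layers learnt on the original task and train only a new output layer for the new task. The sketch below assumes a generic PyTorch backbone; `build_transfer_model` and the toy backbone are hypothetical, and whether to freeze everything or fine-tune some layers is a judgement call that depends on how similar the two tasks are.

```python
import torch
import torch.nn as nn


def build_transfer_model(backbone: nn.Module, feature_dim: int, num_new_classes: int) -> nn.Module:
    """Reuse a (pre-trained) backbone as a frozen feature extractor and
    attach a fresh, trainable classification head for the new task."""
    for param in backbone.parameters():
        param.requires_grad = False  # keep the knowledge from the original task fixed
    head = nn.Linear(feature_dim, num_new_classes)  # only this layer will learn
    return nn.Sequential(backbone, head)


# Hypothetical backbone: in practice it would carry weights from the original task.
torch.manual_seed(0)
backbone = nn.Sequential(
    nn.Conv2d(3, 16, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),  # -> (batch, 16) feature vector
)
model = build_transfer_model(backbone, feature_dim=16, num_new_classes=5)

trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
frozen = sum(p.numel() for p in model.parameters() if not p.requires_grad)
print(trainable, frozen)  # 85 448
```

Only the 85 parameters of the new head would be updated during training; the 448 backbone parameters keep what was learnt before.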

[Figure: Traditional Learning vs Transfer Learning. Source]

Transfer learning is often employed in computer vision and natural language processing tasks, because building on previous training saves compute and time.

For example, a network trained on ImageNet (1000 classes) can be repurposed as a dog identifier without much additional training. Or, word embeddings trained on a large corpus of text (such as Word2Vec from Google) can be fed into another deep neural network to extract vector representations of words from a new corpus.
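In PyTorch, pre-trained word vectors can be plugged into a new network through `nn.Embedding.from_pretrained`. This is a minimal sketch: the four-word vocabulary and its 3-dimensional vectors are made up for illustration, standing in for vectors loaded from something like Word2Vec.

```python
import torch
import torch.nn as nn

# Hypothetical pre-trained vectors: 4 words, 3 dimensions each. In practice these
# would be loaded from embeddings trained on a large corpus (e.g. Word2Vec).
vocab = {"flower": 0, "petal": 1, "owl": 2, "toucan": 3}
pretrained_vectors = torch.tensor([
    [0.9, 0.1, 0.0],
    [0.8, 0.2, 0.1],
    [0.0, 0.7, 0.6],
    [0.1, 0.6, 0.7],
])

# freeze=True keeps the transferred vectors fixed while the rest of the network trains.
embedding = nn.Embedding.from_pretrained(pretrained_vectors, freeze=True)

# Look up vector representations for tokens coming from a new corpus.
token_ids = torch.tensor([vocab["petal"], vocab["owl"]])
vectors = embedding(token_ids)
print(vectors.shape)  # torch.Size([2, 3])
```

Passing `freeze=False` instead would let the new task fine-tune the transferred vectors.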

ImageNet → flower classifier

To test the power of transfer learning, I decided to turn a DenseNet pre-trained on the ImageNet task into a flower classifier using the 102 Category Flower Dataset.

It turns out that the model, without any further training, performs really poorly: a whopping test accuracy of 0.1%! If you’ve done the maths already, you’ll know that random guessing across 102 classes would score about 1% (1/102), so I would be better off guessing myself.
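For reference, here is the arithmetic behind that comparison, assuming the usual top-1 accuracy (the post doesn't name the metric). `top1_accuracy` is a hypothetical helper.

```python
import torch


def top1_accuracy(logits: torch.Tensor, labels: torch.Tensor) -> float:
    """Fraction of samples whose highest-scoring class equals the true label."""
    return (logits.argmax(dim=1) == labels).float().mean().item()


# Toy check: the first prediction is right (class 0), the second is wrong.
demo_logits = torch.tensor([[2.0, 0.1, 0.3], [0.2, 0.1, 3.0]])
demo_labels = torch.tensor([0, 1])
print(top1_accuracy(demo_logits, demo_labels))  # 0.5

num_classes = 102
chance = 1 / num_classes  # expected accuracy of uniform random guessing
print(f"chance level: {chance:.2%}")  # chance level: 0.98%
print(f"reported:     {0.001:.2%}")   # reported:     0.10%
```

The untrained model scores roughly a tenth of what uniform guessing would be expected to achieve.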

Intuitively, this perhaps makes sense. There are only a handful of flower classes in the original ImageNet dataset, so it’s not too difficult to imagine that asking the model to identify 102 species of flowers is a stretch.

Intuition is nice, but I want to make this concrete before moving on to training.

Let’s use FlashTorch to create saliency maps and visualise what the network is (not) seeing. We’re going to use this image of foxgloves as an example.
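Two small steps surround the saliency computation: the photo must be normalised the same way the network's ImageNet training images were, and the resulting map is usually rescaled to [0, 1] before being overlaid as a heat map. FlashTorch takes care of this for you; the sketch below is a plain-PyTorch approximation that assumes the standard ImageNet channel statistics used by torchvision's pre-trained models. `preprocess` and `to_heatmap` are hypothetical helpers, and the random 4x4 image stands in for the foxglove photo.

```python
import torch

# Standard ImageNet channel statistics used by torchvision's pre-trained models.
IMAGENET_MEAN = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
IMAGENET_STD = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)


def preprocess(image_hwc_uint8: torch.Tensor) -> torch.Tensor:
    """(H, W, 3) uint8 image -> (1, 3, H, W) float tensor normalised like ImageNet."""
    x = image_hwc_uint8.permute(2, 0, 1).unsqueeze(0).float() / 255.0
    return (x - IMAGENET_MEAN) / IMAGENET_STD


def to_heatmap(saliency: torch.Tensor) -> torch.Tensor:
    """Min-max scale a saliency map to [0, 1] so it can be overlaid on the image."""
    lo, hi = saliency.min(), saliency.max()
    return (saliency - lo) / (hi - lo + 1e-8)  # epsilon avoids division by zero


# Hypothetical 4x4 RGB image standing in for the foxglove photo.
fake_photo = torch.randint(0, 256, (4, 4, 3), dtype=torch.uint8)
batch = preprocess(fake_photo)
print(batch.shape)  # torch.Size([1, 3, 4, 4])

heat = to_heatmap(torch.tensor([[0.0, 2.0], [1.0, 4.0]]))
print(heat)  # approximately [[0.0, 0.5], [0.25, 1.0]]
```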


What we can appreciate here is that the network, without additional training, is paying attention to the shape of the flower cups. But many flowers have a similar shape (think bluebells, for instance).

For us humans, it might be obvious (even if we didn’t know the name of the species) that what makes this flower unique is the mottled pattern inside the flower cups. However, the network currently doesn’t know where to pay attention, apart from the general shape of the flower, because it never really needed to in the old task (ImageNet classification).

Now that we have some insight into why the network is doing poorly, I feel ready to train it. Eventually, after trial and error, the trained model managed to achieve 98.7% test accuracy.
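The post doesn't describe the training recipe that reached 98.7%. As a minimal sketch of one common choice, the loop below trains only a new 102-way head on top of frozen features with SGD and cross-entropy loss; the small backbone, the random stand-in data and the learning rate are all hypothetical.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical setup: a frozen "pre-trained" feature extractor plus a new head.
backbone = nn.Sequential(nn.Flatten(), nn.Linear(3 * 8 * 8, 32), nn.ReLU())
for p in backbone.parameters():
    p.requires_grad = False
head = nn.Linear(32, 102)  # one output per flower class
model = nn.Sequential(backbone, head)

# Only parameters that still require gradients are handed to the optimiser.
optimizer = torch.optim.SGD([p for p in model.parameters() if p.requires_grad], lr=0.1)
loss_fn = nn.CrossEntropyLoss()

# Random stand-in data; a real run would iterate over a DataLoader of flower images.
images = torch.rand(16, 3, 8, 8)
labels = torch.randint(0, 102, (16,))


def train_step(x: torch.Tensor, y: torch.Tensor) -> float:
    model.train()
    optimizer.zero_grad()
    loss = loss_fn(model(x), y)
    loss.backward()  # gradients reach only the head's parameters
    optimizer.step()
    return loss.item()


losses = [train_step(images, labels) for _ in range(50)]
print(f"first loss {losses[0]:.3f}, last loss {losses[-1]:.3f}")
```

Unfreezing some of the later backbone layers, usually with a smaller learning rate, is a typical next step when a frozen backbone alone isn't enough.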

Which is great! … but can we explain why?

What is it that the network is seeing now that it wasn’t before?


Pretty different, right?

The network has learnt to pay less attention to the shape of the flower and to focus intensely on those mottled patterns :)

Showing what neural nets have learnt is useful. Going a step further and explaining how they learn is another powerful application of feature visualisation techniques.

Step forward (not away!) from accuracy

With feature visualisation techniques, not only can we better understand what a neural network has learnt about objects, but we are also better equipped to:

  • Diagnose what the network gets wrong and why
  • Spot and correct biases in algorithms
  • Step forward from only looking at accuracy
  • Understand why the network behaves in the way it does
  • Elucidate mechanisms of how neural nets learn

Use FlashTorch today!

If you have projects that use CNNs in PyTorch, FlashTorch can help you make them more interpretable and explainable.

Please let me know what you think if you use it! I would really appreciate your constructive comments, feedback and suggestions.

Thanks, and happy coding!

