
Few-shot Learning with Prototypical Networks

Source: https://towardsdatascience.com/few-shot-learning-with-prototypical-networks-87949de03ccd

Learn to code a Few-shot Learning algorithm on the Omniglot dataset

[Cover image credit: https://unsplash.com/photos/kZO9xqmO_TA]

Introduction

We humans have the ability to recognize a class given only a few examples of it. For instance, a child needs only two or three images of a rabbit to recognize this animal among other species. This capacity to learn from a few examples is beyond any classical machine learning algorithm. Many people think humankind is being overtaken by AI, but here is the truth: to differentiate classes well, a classifier is often fed several thousand images per class, while we only need two or three!

Prototypical Networks is an algorithm introduced by Snell et al. in 2017 (in “Prototypical Networks for Few-shot Learning”) that addresses the Few-shot Learning paradigm. Let’s understand it step by step with an example. In this article, our goal is to classify images. The code provided is in PyTorch, available here.

The Omniglot dataset

In Few-shot Learning, we are given a dataset with only a few images per class (usually 1 to 10). In this article, we will work on the Omniglot dataset, which contains 1,623 different handwritten characters collected from 50 alphabets. This dataset can be found in this GitHub repository. I used the “images_background.zip” and “images_evaluation.zip” files.

[Figure: examples of characters found in the Omniglot dataset]

As suggested in the original paper, data augmentation is performed to increase the number of classes: all the images are rotated by 90°, 180° and 270°, each rotation resulting in an additional class. Once this data augmentation is performed, we have 1,623 × 4 = 6,492 classes. I split the whole dataset into a training set (images of 4,200 classes) and a testing set (images of 2,292 classes).
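A minimal sketch of how this rotation-based augmentation could be implemented in PyTorch; the `augment_with_rotations` name and the dict-of-tensors structure are illustrative assumptions, not the article's actual code:

```python
import torch

def augment_with_rotations(class_images):
    """Turn each class into 4 classes via 0, 90, 180 and 270 degree rotations.

    class_images: dict mapping class name -> tensor of shape (n, C, H, W)
    (an assumed structure; any per-class container would work similarly).
    """
    augmented = {}
    for name, images in class_images.items():
        for k in range(4):  # k quarter-turns of 90 degrees
            # torch.rot90 rotates in the (H, W) plane of the image tensor
            augmented[f"{name}_rot{90 * k}"] = torch.rot90(images, k, dims=(2, 3))
    return augmented
```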

Select a sample

To create a sample, Nc classes are randomly picked among all classes. For each class we have two sets of images: the support set of size Ns and the query set of size Nq.

[Figure: illustration of a sample of Nc classes, each containing a support set and a query set]
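Here is one way such an episode could be sampled; `sample_episode` and the dict-based dataset structure are illustrative assumptions consistent with the augmentation sketch above:

```python
import random
import torch

def sample_episode(dataset, n_c, n_s, n_q):
    """Sample an episode: n_c classes, each with n_s support and n_q query images.

    dataset: dict mapping class label -> tensor of shape (n_images, C, H, W)
    (an assumed structure). Returns support and query tensors of shapes
    (n_c, n_s, C, H, W) and (n_c, n_q, C, H, W).
    """
    classes = random.sample(list(dataset.keys()), n_c)
    support, query = [], []
    for c in classes:
        images = dataset[c]
        # Draw n_s + n_q distinct images, split into support and query
        idx = torch.randperm(images.shape[0])[: n_s + n_q]
        support.append(images[idx[:n_s]])
        query.append(images[idx[n_s:]])
    return torch.stack(support), torch.stack(query)
```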

Embed the images

“Our approach is based on the idea that there exists an embedding in which points cluster around a single prototype representation for each class.” claim the authors of the original paper.

In other words, there exists a mathematical representation of the images in which images of the same class gather in groups called clusters. The main advantage of working in that embedding space is that two images that look alike will be close to each other, while two completely different images will be far apart.

In our case, with the Omniglot dataset, the embedding block takes 28×28×3 images as input and returns 64-dimensional embedding vectors. The image2vector function is composed of 4 modules, each consisting of a convolutional layer, batch normalization, a ReLU activation function and a 2×2 max-pooling layer.

[Figure: the 4 modules of the image2vector function]
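A sketch of what this embedding network looks like in PyTorch, following the four-module description above; the `Image2Vector` class name is mine, and the original repository may organize it differently:

```python
import torch.nn as nn

def conv_block(in_channels, out_channels):
    # One module: 3x3 convolution, batch norm, ReLU, 2x2 max pooling
    return nn.Sequential(
        nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(),
        nn.MaxPool2d(2),
    )

class Image2Vector(nn.Module):
    """Embedding function f_phi: four conv modules, then flatten to 64 dims."""

    def __init__(self, in_channels=3, hidden=64):
        super().__init__()
        self.encoder = nn.Sequential(
            conv_block(in_channels, hidden),
            conv_block(hidden, hidden),
            conv_block(hidden, hidden),
            conv_block(hidden, hidden),
        )

    def forward(self, x):
        # (N, 3, 28, 28) -> (N, 64, 1, 1) -> (N, 64)
        return self.encoder(x).flatten(start_dim=1)
```

With 28×28 inputs, the four 2×2 poolings shrink the feature map to 1×1, so the 64 channels give exactly the 64-dimensional embedding described above.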

Compute the class prototypes

In this step we compute a prototype for each class. Once the support images are embedded, their vectors are averaged to form the class prototype, a kind of “delegate” for that class.

v_k = \frac{1}{|S_k|} \sum_{(x_i, y_i) \in S_k} f_\phi(x_i)

where v_k is the prototype of class k, S_k is the support set of class k, f_\phi is the embedding function and the x_i are the support images.

[Figure: one prototype is computed per class]
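In code, this averaging is a one-liner; the tensor layout below is an assumption carried over from the earlier sampling sketch:

```python
def compute_prototypes(support_embeddings):
    """Average each class's embedded support points to get its prototype.

    support_embeddings: tensor of shape (n_c, n_s, embedding_dim), i.e. the
    embedding network's output on the support set, grouped by class.
    Returns the prototypes, shape (n_c, embedding_dim).
    """
    return support_embeddings.mean(dim=1)
```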

Compute distances between queries and prototypes

This step consists in classifying the query images. To do so, we compute the distance between each query embedding and each class prototype. The choice of metric is crucial here, and the authors of Prototypical Networks must be credited for their choice of distance: the squared Euclidean distance.

Once the distances are computed, a softmax is applied over their negatives to obtain, for each query, the probability of belonging to each class.
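A possible PyTorch implementation of this step (the `classify_queries` name is illustrative); note the minus sign, since a smaller distance should yield a larger probability:

```python
import torch
import torch.nn.functional as F

def classify_queries(query_embeddings, prototypes):
    """Distances between queries and prototypes, turned into probabilities.

    query_embeddings: (n_q_total, embedding_dim)
    prototypes:       (n_c, embedding_dim)
    Returns class probabilities of shape (n_q_total, n_c).
    """
    # Pairwise Euclidean distances, squared as in the original paper
    distances = torch.cdist(query_embeddings, prototypes) ** 2
    # Negate so that closer prototypes get higher probability
    return F.softmax(-distances, dim=1)
```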

Compute the loss and backpropagate

The learning phase of Prototypical Networks proceeds by minimizing the negative log-probability of the true class, also called the log-softmax loss. The main advantage of using a logarithm is that it drastically increases the loss when the model fails to predict the correct class.

Backpropagation computes the gradients, and the weights are updated via Stochastic Gradient Descent (SGD).
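Putting the last three steps together, here is a hedged sketch of the per-episode loss; the helper name and tensor shapes are assumptions consistent with the sketches above:

```python
import torch
import torch.nn.functional as F

def prototypical_loss(model, support, query, labels):
    """Negative log-probability loss for one episode.

    support: (n_c, n_s, C, H, W) support images grouped by class;
    query:   (n_q_total, C, H, W) query images, flattened across classes;
    labels:  (n_q_total,) class index of each query image.
    """
    n_c, n_s = support.shape[:2]
    support_emb = model(support.flatten(end_dim=1)).view(n_c, n_s, -1)
    prototypes = support_emb.mean(dim=1)              # one prototype per class
    distances = torch.cdist(model(query), prototypes) ** 2
    log_p = F.log_softmax(-distances, dim=1)          # log-probability per class
    return F.nll_loss(log_p, labels)                  # negative log-probability
```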

Launch training

The whole sequence described above forms an episode, and the training phase consists of many episodes. I tried to reproduce the results of the original paper. Here are the training settings:

  • Nc: 60 classes
  • Ns: 1 or 5 support points / class
  • Nq: 5 query points / class
  • 5 epochs
  • 2000 episodes / epoch
  • Learning Rate initially at 0.001 and divided by 2 at each epoch

The training took 30 min to run.
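The schedule above could be wired together as follows. This is a sketch under stated assumptions: `train_set` is the dict-of-tensors used by the earlier `sample_episode` sketch, and `Image2Vector` and `prototypical_loss` are the hypothetical helpers defined above. The learning-rate halving is done with a step scheduler:

```python
import torch

model = Image2Vector()                      # embedding network sketched earlier
optimizer = torch.optim.SGD(model.parameters(), lr=0.001)
# Halve the learning rate after each epoch, as in the settings above
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.5)

N_C, N_S, N_Q = 60, 5, 5                    # 60-way, 5-shot, 5 queries per class

for epoch in range(5):
    for episode in range(2000):
        support, query = sample_episode(train_set, N_C, N_S, N_Q)
        # Query labels: class i repeated N_Q times, matching the query order
        labels = torch.arange(N_C).repeat_interleave(N_Q)
        loss = prototypical_loss(model, support, query.flatten(end_dim=1), labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    scheduler.step()
```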

Results

Once the ProtoNet is trained, we can test it on new data. We select samples from the testing set in the same way as during training. The support set is used to compute the prototypes, and each point of the query set is then labelled according to its nearest prototype.

For testing, I tried 5-way and 20-way scenarios, with the same number of support and query points as during the training phase. The tests were performed on 1,000 episodes.
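A sketch of how this evaluation could look, reusing the hypothetical helpers above; accuracy is averaged over the test episodes:

```python
import torch

@torch.no_grad()
def evaluate(model, test_set, n_c, n_s, n_q, n_episodes=1000):
    """Mean accuracy over test episodes (test_set is the assumed dict format)."""
    model.eval()
    accuracies = []
    for _ in range(n_episodes):
        support, query = sample_episode(test_set, n_c, n_s, n_q)
        support_emb = model(support.flatten(end_dim=1)).view(n_c, n_s, -1)
        prototypes = support_emb.mean(dim=1)
        distances = torch.cdist(model(query.flatten(end_dim=1)), prototypes)
        predictions = distances.argmin(dim=1)         # nearest prototype wins
        labels = torch.arange(n_c).repeat_interleave(n_q)
        accuracies.append((predictions == labels).float().mean().item())
    return sum(accuracies) / len(accuracies)
```

For example, `evaluate(model, test_set, n_c=5, n_s=1, n_q=5)` would correspond to the 5-way 1-shot scenario.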

The results are presented in the table below. “5-way 1-shot” means Nc = 5 and Ns = 1.

[Table: obtained results vs. results reported in the original paper]

I obtained results similar to those of the original paper, slightly better in some cases. This may be due to the sampling strategy, which is not specified in the paper; I used random sampling at each episode.

