
Self-Paced Learning for Machine Learning

 4 years ago
source link: https://towardsdatascience.com/self-paced-learning-for-machine-learning-f1c489316c61?gi=7e483191fd76


Smart way to improve neural network convergence (and find anomalies…)

You torment your neural net by cramming data into it ruthlessly! Usually, when a machine learning model is trained with Stochastic Gradient Descent (SGD), the training data set gets shuffled. That way, the model sees the data points in no particular order and can learn the task evenly without getting stuck in local optima. However, as early as 2009, Bengio et al. showed that a certain ordering is beneficial. They called their approach Curriculum Learning and showed that a machine learning model reaches higher accuracy if its training data is presented in a specific order: easier examples at the beginning and harder ones towards the end. For their experiments, they used an already trained model and let that model decide which data points are easy or hard. A new model was then trained on the data set in that order and converged to a higher accuracy than models trained on random orders or on the opposite curriculum.
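As a rough sketch of this idea, assume an already trained "teacher" model has produced one loss value per training sample, where a lower loss means an easier sample. The data can then be sorted from easy to hard and released in growing stages. The function name `curriculum_stages` and the staging scheme are my own illustration, not necessarily how Bengio et al. scheduled their curriculum.

```python
def curriculum_stages(teacher_losses, n_stages):
    """Split sample indices into growing, easy-to-hard training subsets.

    teacher_losses: per-sample losses from an already trained teacher model
    (lower loss = easier sample). Stage k holds the easiest k/n_stages fraction.
    """
    # Sort indices from easiest (lowest loss) to hardest (highest loss).
    order = sorted(range(len(teacher_losses)), key=lambda i: teacher_losses[i])
    stages = []
    for k in range(1, n_stages + 1):
        # Integer ceiling division, so the last stage covers the whole data set.
        size = -(-k * len(order) // n_stages)
        stages.append(order[:size])
    return stages


losses = [0.9, 0.1, 0.5, 0.05, 1.2, 0.3]  # made-up teacher losses
for stage in curriculum_stages(losses, 3):
    print(stage)
# [3, 1]
# [3, 1, 5, 2]
# [3, 1, 5, 2, 0, 4]
```

In practice, each stage would still be shuffled before batching so that SGD does not see its samples in a fixed order.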

While working on my current project, I came across a technique called Self-Paced Learning (SPL). It’s not a new idea; the paper introducing it was published about 10 years ago. Still, the technique is interesting and relevant, because it helps SGD converge faster and even to a higher accuracy. It skips data points that are considered too hard for the model at its current stage. SPL builds on Curriculum Learning but orders the data during training, so there is no need for an additional pre-trained model to decide on the ordering. Hence the name Self-Paced Learning.


Photo by Suzanne D. Williams on Unsplash

Intuition behind SPL

The term Self-Paced Learning comes from a learning technique used by humans, which lets learners set a pace that suits their learning patterns. Think of studying a subject such as maths. We started with counting, then moved on to addition, subtraction, and so on. We didn’t hear about matrix multiplication or derivatives until a certain age. In the same way, SPL for machine learning starts with very easy examples and, once those are learned, continues with harder ones, benefitting from the already learned “basics”.

I imagine SPL as gradually narrowing down a task over time. Consider a simple classification task in a 2-dimensional space, where a model needs to split two point clouds at the right spot. The easier samples lie far away from the area where the clouds overlap, and the harder samples lie close to it. The initial state of the model is a line somewhere in this space. If we start with only easy data points, the model gets gradients that push it in a consistent direction. If we started with only hard points, the model would know it is wrong and would get a direction to move in, but it might overshoot far to the other side. Narrowing down the points to the right area helps the model avoid overshooting and converge more smoothly, as the animation I created below shows.
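To connect this 2-D picture with loss values, here is a small plain-Python sketch. It assumes a linear classifier with a logistic loss and a hypothetical boundary at x0 = 0; the points and weights are made up. A point's loss shrinks as its signed distance (margin) from the line grows, so ranking points by loss matches ranking them by how far they lie on the correct side.

```python
import math


def signed_margin(point, label, w, b):
    # Positive when the point lies on the correct side of the line w . x + b = 0;
    # the larger it is, the further the point is from the boundary.
    return label * (w[0] * point[0] + w[1] * point[1] + b)


def logistic_loss(margin):
    # Per-sample logistic loss log(1 + exp(-margin)) for labels in {-1, +1}.
    return math.log1p(math.exp(-margin))


# Hypothetical boundary x0 = 0 and made-up points labelled -1 (left) or +1 (right).
w, b = (1.0, 0.0), 0.0
points = [((-3.0, 1.0), -1), ((-0.2, 0.5), -1), ((0.3, -1.0), 1),
          ((2.5, 0.0), 1), ((0.4, 2.0), -1)]

losses = [logistic_loss(signed_margin(p, y, w, b)) for p, y in points]
ranking = sorted(range(len(points)), key=lambda i: losses[i])
print(ranking)  # [0, 3, 2, 1, 4]: far-away points first, the misclassified one last
```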


Self-Paced Learning on a simple data set

The Algorithm

The trick is very simple. It uses a threshold, which we call lambda, that is compared to the loss values of the data points in the training set. Usually, lambda starts very low, at a number close to 0, and after each epoch it is multiplied by a fixed factor greater than 1. The model being trained has to calculate the loss values of its training points anyway to perform SGD. These loss values generally get smaller over the training iterations because the model gets better at the task and makes fewer mistakes. The threshold lambda now determines whether a data point is considered easy or hard: whenever the loss of a data point is below lambda, it is easy; if it is above, it is hard. During training, back-propagation is only performed on easy data points, and hard ones are skipped. As lambda grows and the losses shrink, more and more of the harder points are admitted, so the model increases the difficulty of its training instances whenever it has progressed enough. At the very beginning, however, the model may consider no data point to be easy and won’t train at all. Therefore, the authors of SPL introduced a warm-up phase in which no skipping is allowed and only a small subset of the training set is used.
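The selection rule itself fits in a few lines. The following minimal sketch in plain Python assumes the per-sample losses for each epoch are already known; in real training, they come from the model as it learns. The name `spl_schedule` and the loss values are hypothetical.

```python
def spl_schedule(loss_per_epoch, lam=0.1, growth=1.3):
    """Return, for each epoch, the indices SPL would train on.

    loss_per_epoch: one list of per-sample losses per epoch (made up here;
    in real training these come from the model being trained).
    """
    selected = []
    for losses in loss_per_epoch:
        # Easy samples: loss below the current threshold lambda.
        easy = [i for i, loss in enumerate(losses) if loss < lam]
        selected.append(easy)
        # Raise the threshold once per epoch by a fixed factor > 1.
        lam *= growth
    return selected


# Made-up losses that shrink as the model improves.
epoch_losses = [
    [0.05, 0.40, 2.00],
    [0.04, 0.12, 1.50],
    [0.03, 0.09, 0.80],
    [0.02, 0.07, 0.20],
]
print(spl_schedule(epoch_losses))  # [[0], [0, 1], [0, 1], [0, 1, 2]]
```

If every loss starts above lambda, the first entries are empty lists. That is exactly the cold-start problem the warm-up phase addresses.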

Maths

Loss function to minimize

The SPL paper introduces a two-step learning approach: the loss function is minimized in two alternating steps, each of which keeps one set of variables fixed. The loss function is minimized with respect to the model weights “w” and the variables “v”, and it consists of three terms. The first term “r(w)” is a usual regularization term that helps the model avoid overfitting. It is also part of other loss functions not related to SPL. The second term is a sum over the data point losses “f(x, y, w)” of our model, each multiplied by its variable “v”. This variable “v” decides whether the current data point “(x, y)” is easy enough to train on. The third term is the sum over all “v”s multiplied by the threshold “lambda” mentioned in earlier sections, and it is subtracted from the rest. The variables “v” are binary and can only take the values “0” or “1”.

In the first step, the weights “w” are fixed and only the variables “v” are optimized. If you look at the loss function carefully, you see that “lambda” indeed acts as a threshold, because each data point contributes “v(f(x, y, w) - lambda)” to the loss. If “f(x, y, w)” is smaller than “lambda”, this contribution is negative for “v=1”, so choosing “v=1” lowers the loss compared to “v=0”, which contributes nothing. If “f(x, y, w)” is bigger than “lambda”, the contribution is positive for “v=1”, so “v=0” gives the smaller loss. In summary, the first step minimizes “L” by setting “v” to “1” whenever “f(x, y, w)” is smaller than “lambda”, and to “0” otherwise. The second step fixes the “v”s computed before and optimizes “w”. If “v” is “1”, the usual model update is performed, for instance, via back-propagation. If “v” is “0”, that data point’s term in the loss is 0, so its gradient is 0 and it causes no update (apart from the regularization term, which you can ignore for now for a better understanding).
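Written out in symbols, the objective and the two alternating steps look roughly like this. The notation is mine, follows the term-by-term description above, and assumes n training points.

```latex
% SPL objective, minimized jointly over the weights w and the binary selectors v
\min_{\mathbf{w},\,\mathbf{v}} \; L(\mathbf{w}, \mathbf{v})
  = r(\mathbf{w})
  + \sum_{i=1}^{n} v_i \, f(x_i, y_i, \mathbf{w})
  - \lambda \sum_{i=1}^{n} v_i ,
  \qquad v_i \in \{0, 1\}

% Step 1 (w fixed): the v-dependent part separates over the data points,
\sum_{i=1}^{n} v_i \, f(x_i, y_i, \mathbf{w}) - \lambda \sum_{i=1}^{n} v_i
  = \sum_{i=1}^{n} v_i \bigl( f(x_i, y_i, \mathbf{w}) - \lambda \bigr)
% so each v_i is chosen independently:
v_i^{*} =
  \begin{cases}
    1 & \text{if } f(x_i, y_i, \mathbf{w}) < \lambda \\
    0 & \text{otherwise}
  \end{cases}

% Step 2 (v fixed): the term -\lambda \sum_i v_i^{*} is a constant, leaving
\mathbf{w}^{*} = \arg\min_{\mathbf{w}} \; r(\mathbf{w})
  + \sum_{i \,:\, v_i^{*} = 1} f(x_i, y_i, \mathbf{w})
```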
Setting the threshold to a very low number at the beginning would select nothing, since all “v”s would be “0” when no data point loss is below the threshold. Therefore, the authors of SPL suggested a warm-up phase without SPL for a certain number of iterations, with SPL starting afterward.
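A brute-force check in plain Python confirms that the first step really reduces to a threshold rule. It holds “w” (and therefore r(w)) fixed for a handful of made-up losses. Trying every possible binary vector “v” gives the same answer as comparing each loss to lambda. The helper names are hypothetical. The last line also shows the cold-start problem: with a tiny lambda, nothing is selected.

```python
from itertools import product


def objective(v, losses, lam):
    # SPL objective with w fixed; r(w) is dropped because it does not depend on v.
    return sum(vi * fi for vi, fi in zip(v, losses)) - lam * sum(v)


def best_v_brute_force(losses, lam):
    # Try all 2^n binary vectors v and keep the one with the lowest objective.
    return min(product((0, 1), repeat=len(losses)),
               key=lambda v: objective(v, losses, lam))


def best_v_threshold(losses, lam):
    # Closed-form answer: select exactly the samples whose loss is below lambda.
    return tuple(1 if f < lam else 0 for f in losses)


losses = [0.05, 0.8, 0.2, 1.5]
print(best_v_brute_force(losses, lam=0.3))  # (1, 0, 1, 0)
print(best_v_threshold(losses, lam=0.3))    # (1, 0, 1, 0)
print(best_v_threshold(losses, lam=0.01))   # (0, 0, 0, 0): nothing is easy yet
```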

PyTorch Implementation

In the following code example, I show how to implement SPL with PyTorch on a dummy data set. In particular, we implement the training for the animation shown earlier in this post. First, we define a very simple model that takes in 2 features and outputs two numbers, one per class, which tell us which class the model thinks it’s seeing. To turn these outputs into probabilities, we would normally use a softmax function. In the code, however, I use a log_softmax function, which returns log-probabilities, because the NLL loss used later on expects log-probabilities as input. In the end, the model trains in the same way.

import torch
import torch.nn as nn


class Model(nn.Module):
    """Linear classifier that returns log-probabilities for each class."""

    def __init__(self, input_size, output_size):
        super().__init__()
        self.input_layer = nn.Linear(input_size, output_size)

    def forward(self, x):
        x = self.input_layer(x)
        # log_softmax pairs with the NLL loss used in SPLLoss below
        return torch.log_softmax(x, dim=1)
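Why log_softmax? Taking the log-probability of the true class and negating it, which is what the NLL loss does, gives the cross-entropy loss. The plain-Python sketch below checks this numerically for one pair of made-up logits. In PyTorch, nn.CrossEntropyLoss applied to raw logits combines these two steps.

```python
import math


def log_softmax(logits):
    # log p_k = z_k - log(sum_j exp(z_j)); subtracting the max avoids overflow in exp().
    m = max(logits)
    log_norm = m + math.log(sum(math.exp(z - m) for z in logits))
    return [z - log_norm for z in logits]


def nll(log_probs, target):
    # Negative log-likelihood: minus the log-probability of the true class.
    return -log_probs[target]


def cross_entropy(logits, target):
    # Cross-entropy computed the long way: softmax first, then -log of the true class.
    exps = [math.exp(z) for z in logits]
    return -math.log(exps[target] / sum(exps))


logits = [2.0, -1.0]  # made-up model outputs for one sample with two classes
for target in (0, 1):
    print(nll(log_softmax(logits), target), cross_entropy(logits, target))
```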

The loss function code follows. Here, we calculate the loss of each point, which is the NLL loss. If the loss is smaller than the threshold, we multiply it by one, otherwise by zero. Hence, losses multiplied by zero have no effect on the training.

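import torch
from torch import Tensor
import torch.nn as nn


class SPLLoss(nn.NLLLoss):
    def __init__(self, *args, n_samples=0, **kwargs):
        super().__init__(*args, **kwargs)
        self.threshold = 0.1       # initial lambda
        self.growing_factor = 1.3  # lambda is multiplied by this once per epoch
        # v[i] records whether sample i counted as easy in its most recent batch
        self.v = torch.zeros(n_samples).int()

    def forward(self, input: Tensor, target: Tensor, index: Tensor) -> Tensor:
        # Per-sample NLL losses (no reduction), shape: (batch_size,);
        # weight and ignore_index given to the constructor are passed on
        super_loss = nn.functional.nll_loss(
            input, target, weight=self.weight,
            ignore_index=self.ignore_index, reduction="none")
        v = self.spl_loss(super_loss)
        self.v[index] = v
        # Samples with v = 0 contribute nothing to the gradient
        return (super_loss * v).mean()

    def increase_threshold(self):
        self.threshold *= self.growing_factor

    def spl_loss(self, super_loss):
        # v_i = 1 if loss_i < lambda, else 0
        v = super_loss < self.threshold
        return v.int()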
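To see what the masking does numerically, here is a hypothetical plain-Python mirror of SPLLoss.forward for a single batch of three per-sample losses. The function name and the values are mine.

```python
def spl_batch_loss(losses, threshold):
    # Hypothetical plain-Python mirror of SPLLoss.forward for one batch.
    # v_i = 1 if the sample's loss is below the threshold, else 0.
    v = [1 if loss < threshold else 0 for loss in losses]
    # Like .mean() in the snippet, divide by the full batch size,
    # so skipped samples count as zeros.
    masked_mean = sum(loss * vi for loss, vi in zip(losses, v)) / len(losses)
    return v, masked_mean


v, batch_loss = spl_batch_loss([0.05, 0.3, 0.08], threshold=0.1)
print(v)           # [1, 0, 1]
print(batch_loss)  # about 0.0433, i.e. (0.05 + 0.08) / 3
```

The masked losses are averaged over the whole batch. As a result, a batch with few easy samples produces a smaller loss, and therefore smaller gradients, than a batch full of easy ones. Dividing by the number of selected samples instead would be a possible alternative with different training dynamics.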

In the end, the training function looks as usual. We create a data loader, initialize the model and optimizer, and iterate over the data set for multiple epochs. For simplicity, I left out the plotting function for the animation.

import torch.optim as optim

from model import Model
from dataset import get_dataloader  # helper from the repository; must yield (index, data, target)
from loss import SPLLoss


def train():
    model = Model(2, 2)  # 2 input features, 2 classes
    dataloader = get_dataloader()
    criterion = SPLLoss(n_samples=len(dataloader.dataset))
    optimizer = optim.Adam(model.parameters())

    for epoch in range(10):
        for index, data, target in dataloader:
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target, index)
            loss.backward()
            optimizer.step()
        # Raise lambda once per epoch so harder samples are admitted over time
        criterion.increase_threshold()
    return model

The whole project, including the plotting function, can be found on my GitHub repository. Feel free to play around with it.
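If you want to try the idea without PyTorch or the repository, the following sketch reproduces a similar setup in plain Python. It is an illustration under a few assumptions, not the code from the repository:

- the data consists of two hypothetical Gaussian point clouds;
- a single-logit logistic regression, trained with per-sample SGD, stands in for the two-output log_softmax model (with two classes, the two are equivalent);
- a warm-up of two epochs runs without skipping;
- after that, the SPL rule applies, with the same starting threshold (0.1) and growth factor (1.3) as SPLLoss.

```python
import math
import random


def sigmoid(z):
    # Numerically stable logistic function.
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def sample_loss(w, b, x, y):
    # Stable binary negative log-likelihood for a label y in {0, 1}.
    z = w[0] * x[0] + w[1] * x[1] + b
    s = z if y == 1 else -z
    return math.log1p(math.exp(-s)) if s >= 0 else -s + math.log1p(math.exp(s))


def make_clouds(n_per_class=100, seed=0):
    # Hypothetical toy data: two 2-D Gaussian clouds labelled 0 and 1.
    rng = random.Random(seed)
    data = []
    for label, (cx, cy) in enumerate([(-2.0, -2.0), (2.0, 2.0)]):
        for _ in range(n_per_class):
            data.append(((rng.gauss(cx, 1.0), rng.gauss(cy, 1.0)), label))
    return data


def train_spl(data, epochs=20, lr=0.1, lam=0.1, growth=1.3, warmup=2, seed=0):
    rng = random.Random(seed)
    w, b = [0.0, 0.0], 0.0
    order = list(range(len(data)))
    for epoch in range(epochs):
        rng.shuffle(order)  # SGD still visits the samples in random order
        for i in order:
            x, y = data[i]
            # After the warm-up, skip samples whose loss is not below lambda (v_i = 0).
            if epoch >= warmup and sample_loss(w, b, x, y) >= lam:
                continue
            # Gradient of the logistic NLL with respect to the logit is sigmoid(z) - y.
            g = sigmoid(w[0] * x[0] + w[1] * x[1] + b) - y
            w[0] -= lr * g * x[0]
            w[1] -= lr * g * x[1]
            b -= lr * g
        if epoch >= warmup:
            lam *= growth  # let harder samples in over time
    return w, b


data = make_clouds()
w, b = train_spl(data)
accuracy = sum(
    (sigmoid(w[0] * x[0] + w[1] * x[1] + b) >= 0.5) == (y == 1) for x, y in data
) / len(data)
print(f"training accuracy: {accuracy:.2f}")
```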

Anomaly Detection

Since the SPL approach implicitly sorts the data points by hardness based on their loss, I had the idea to use it for anomaly detection. An anomaly is a data point that isn’t similar to any other data point in the data set, is far off, and might be the result of a false input or a systematic error. If an anomaly occurs in the data set, its loss should be higher than the loss of normal points, because a machine learning model cannot generalize to errors that it sees only very rarely. The SPL threshold should therefore pass the anomalies last. That way, we can classify data points as anomalies by observing the order in which they get “activated”, i.e., considered easy.

For this experiment, I used the aforementioned code but didn’t run a fixed number of epochs. Instead, I ran the training as long as more than 5 data points were above the threshold and, hence, considered hard. Once 5 or fewer were left, I stopped the training and plotted them as red points. As you can see in the animation below, the algorithm found an anomaly in the lower left part of the blue cloud. I added this anomaly by taking the point furthest from the orange centroid and changing its class to “orange”.
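The stopping rule can be sketched on a similar toy setup. The code below is a hypothetical, self-contained plain-Python version, and the helper names and data are mine. It proceeds in four steps:

1. It plants an anomaly in the spirit of the one described above, relabelling the class-0 point furthest from the class-1 centroid as class 1.
2. It trains a logistic-regression model with SPL.
3. It stops as soon as at most 5 points remain above the threshold.
4. It returns those remaining points as the anomaly candidates.

```python
import math
import random


def sigmoid(z):
    # Numerically stable logistic function.
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def sample_loss(w, b, x, y):
    # Stable binary negative log-likelihood for a label y in {0, 1}.
    z = w[0] * x[0] + w[1] * x[1] + b
    s = z if y == 1 else -z
    return math.log1p(math.exp(-s)) if s >= 0 else -s + math.log1p(math.exp(s))


def spl_anomalies(data, max_hard=5, lr=0.1, lam=0.1, growth=1.3,
                  warmup=2, max_epochs=200, seed=0):
    """Train with SPL until at most max_hard samples are still hard; return them."""
    rng = random.Random(seed)
    w, b = [0.0, 0.0], 0.0
    order = list(range(len(data)))
    hard = order[:]
    for epoch in range(max_epochs):
        rng.shuffle(order)
        for i in order:
            x, y = data[i]
            # Warm-up epochs use every sample; afterwards only easy ones (loss < lambda).
            if epoch < warmup or sample_loss(w, b, x, y) < lam:
                g = sigmoid(w[0] * x[0] + w[1] * x[1] + b) - y
                w = [w[0] - lr * g * x[0], w[1] - lr * g * x[1]]
                b -= lr * g
        if epoch < warmup:
            continue
        # Samples not yet "activated": their loss is still at or above lambda.
        hard = [i for i, (x, y) in enumerate(data) if sample_loss(w, b, x, y) >= lam]
        if len(hard) <= max_hard:
            break
        lam *= growth  # raise lambda once per epoch
    return sorted(hard)


# Hypothetical toy data: two 2-D Gaussian clouds, labels 0 and 1.
rng = random.Random(1)
data = [((rng.gauss(c, 1.0), rng.gauss(c, 1.0)), label)
        for label, c in ((0, -2.0), (1, 2.0)) for _ in range(100)]

# Plant an anomaly: relabel the class-0 point furthest from the class-1 centroid.
cx = sum(x[0] for x, y in data if y == 1) / 100
cy = sum(x[1] for x, y in data if y == 1) / 100
planted = max((i for i, (x, y) in enumerate(data) if y == 0),
              key=lambda i: (data[i][0][0] - cx) ** 2 + (data[i][0][1] - cy) ** 2)
data[planted] = (data[planted][0], 1)

candidates = spl_anomalies(data)
print(planted, candidates)
```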


Illustrating anomalies as red points in the end

Admittedly, this example isn’t hard, but it illustrates the problem that anomaly detection faces. With more than 2 or 3 dimensions, the task gets more complex, and anomalies are not found as easily as the obvious one in our example.

Conclusion

Why isn’t everyone using SPL? It takes some time to find the right starting value and growing factor for the threshold, because these aren’t generic and vary with the model, loss function, and data set. For the examples in this post, I needed multiple tries to find the right configuration. Now imagine a very large data set and a very large model: it is basically infeasible to run the whole training several times just to tune these values before starting the actual training. However, there are multiple other curriculum learning techniques that fit different training setups. Despite these drawbacks, SPL is a very intuitive idea that is easy to grasp and fun to work with. It’s basically just another set of hyper-parameters you need to optimize ;-)
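A quick calculation can help when picking these two values. Because lambda grows geometrically, the number of epochs until it exceeds a given loss value follows directly from the starting value and the growth factor. The hypothetical helper below counts those epochs. It ignores the fact that the losses themselves shrink during training. As a reference point, a two-class model that is completely undecided about a sample has a loss of ln 2 ≈ 0.693.

```python
import math


def epochs_until(target_loss, lam0=0.1, growth=1.3):
    """Number of threshold updates until lambda = lam0 * growth**t reaches target_loss."""
    if lam0 <= 0 or growth <= 1:
        raise ValueError("need lam0 > 0 and growth > 1")
    t, lam = 0, lam0
    while lam < target_loss:
        lam *= growth  # one update per epoch, as in increase_threshold()
        t += 1
    return t


# Starting value and factor from the SPLLoss snippet, and a slower alternative.
print(epochs_until(math.log(2)))              # 8
print(epochs_until(math.log(2), growth=1.1))  # 21
```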

Resources

Links

Papers

