
From PyTorch to PyTorch Lightning — A gentle introduction

source link: https://towardsdatascience.com/from-pytorch-to-pytorch-lightning-a-gentle-introduction-b371b7caaf09?gi=5c151e4e5111


This post answers the most frequent question about why you need Lightning if you’re using PyTorch.

PyTorch is extremely easy to use to build complex AI models. But once the research gets complicated and things like multi-GPU training, 16-bit precision and TPU training get mixed in, users are likely to introduce bugs.

PyTorch Lightning solves exactly this problem. Lightning structures your PyTorch code so it can abstract the details of training. This makes AI research scalable and fast to iterate on.

Who is PyTorch Lightning For?

PyTorch Lightning was created while doing PhD research at both NYU and FAIR

PyTorch Lightning was created for professional researchers and PhD students working on AI research.

Lightning was born out of my Ph.D. AI research at NYU CILVR and Facebook AI Research. As a result, the framework is designed to be extremely extensible while making state-of-the-art AI research techniques (like TPU training) trivial.

Now the core contributors are all pushing the state of the art in AI using Lightning and continue to add new cool features.


At the same time, the simple interface gives professional production teams and newcomers access to the latest state-of-the-art techniques developed by the PyTorch and PyTorch Lightning community.

Lightning has over 96 contributors and a core team of 8 research scientists, PhD students and professional deep learning engineers.

It is rigorously tested and thoroughly documented.

Outline

This tutorial will walk you through building a simple MNIST classifier showing PyTorch and PyTorch Lightning code side-by-side. While Lightning can be used to build arbitrarily complicated systems, we use MNIST to illustrate how to refactor PyTorch code into PyTorch Lightning.

The full code is available in this Colab notebook.

The Typical AI Research project

In a research project, we normally want to identify the following key components:

  • the model(s)
  • the data
  • the loss
  • the optimizer(s)

The Model

Let’s design a 3-layer fully-connected neural network that takes as input an image that is 28x28 and outputs a probability distribution over 10 possible labels.

First, let’s define the model in PyTorch
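
The original post showed this code as an image that doesn't survive in this copy. Here is a minimal sketch of such a 3-layer network (the layer widths of 128 and 256 are assumptions for illustration):

```python
import torch
from torch import nn
from torch.nn import functional as F


class MNISTClassifier(nn.Module):
    def __init__(self):
        super().__init__()
        # 28x28 images are flattened into 784 input features
        self.layer_1 = nn.Linear(28 * 28, 128)
        self.layer_2 = nn.Linear(128, 256)
        self.layer_3 = nn.Linear(256, 10)

    def forward(self, x):
        x = x.view(x.size(0), -1)            # flatten the image
        x = F.relu(self.layer_1(x))
        x = F.relu(self.layer_2(x))
        x = self.layer_3(x)
        return F.log_softmax(x, dim=1)       # log-probabilities over the 10 digits
```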


This model defines the computational graph to take as input an MNIST image and convert it to a probability distribution over 10 classes for digits 0–9.

(Figure: 3-layer network. Illustration by William Falcon.)

To convert this model to PyTorch Lightning, we simply replace the nn.Module with the pl.LightningModule:
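
A sketch of the same model as a LightningModule; the body of the class is unchanged, only the base class differs (pl is assumed to be the pytorch_lightning import):

```python
import pytorch_lightning as pl
from torch import nn
from torch.nn import functional as F


class LightningMNISTClassifier(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer_1 = nn.Linear(28 * 28, 128)
        self.layer_2 = nn.Linear(128, 256)
        self.layer_3 = nn.Linear(256, 10)

    def forward(self, x):
        x = x.view(x.size(0), -1)
        x = F.relu(self.layer_1(x))
        x = F.relu(self.layer_2(x))
        x = self.layer_3(x)
        return F.log_softmax(x, dim=1)
```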


The new PyTorch Lightning class is EXACTLY the same as the PyTorch one, except that the LightningModule provides a structure for the research code.

Lightning provides structure to PyTorch code


See? The code is EXACTLY the same for both!

This means you can use a LightningModule exactly as you would a PyTorch module, for example to run a prediction:
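
A forward pass works exactly as it would on an nn.Module (the random input is purely illustrative):

```python
import torch

model = LightningMNISTClassifier()
x = torch.rand(1, 1, 28, 28)      # a fake MNIST-sized batch of one image
with torch.no_grad():
    log_probs = model(x)          # same call as for any PyTorch module
```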


Or use it as a pretrained model
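
A sketch of loading it from a saved checkpoint (the checkpoint path is a placeholder):

```python
# hypothetical checkpoint path, for illustration only
model = LightningMNISTClassifier.load_from_checkpoint("path/to/checkpoint.ckpt")
model.eval()      # use it like any pretrained PyTorch module
model.freeze()    # Lightning helper that turns off gradients
```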


The Data

For this tutorial we’re using MNIST.

(Figure: sample MNIST digits. Source: Wikipedia.)

Let’s generate three splits of MNIST: a training, a validation and a test split.

This again, is the same code in PyTorch as it is in Lightning.

Each dataset is wrapped in a DataLoader, which handles the loading, shuffling and batching of the data.

In short, data preparation has 4 steps:

  1. Download images
  2. Image transforms (these are highly subjective).
  3. Generate training, validation and test dataset splits.
  4. Wrap each dataset split in a DataLoader
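
In plain PyTorch those four steps might look roughly like this (the batch size and the standard MNIST normalization constants are assumptions):

```python
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms

# 1. download + 2. transforms
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),   # commonly used MNIST mean/std
])
mnist_train = datasets.MNIST("data", train=True, download=True, transform=transform)
mnist_test = datasets.MNIST("data", train=False, download=True, transform=transform)

# 3. generate training / validation / test splits
mnist_train, mnist_val = random_split(mnist_train, [55000, 5000])

# 4. wrap each split in a DataLoader
train_loader = DataLoader(mnist_train, batch_size=64, shuffle=True)
val_loader = DataLoader(mnist_val, batch_size=64)
test_loader = DataLoader(mnist_test, batch_size=64)
```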


Again, the code is exactly the same except that we’ve organized the PyTorch code into 4 functions:

prepare_data

This function handles downloads and any data processing. It makes sure that when you use multiple GPUs you don’t download the dataset multiple times or apply duplicate transformations to the data.

This is because each GPU will execute the same PyTorch code, thereby causing duplication. Lightning makes sure these critical parts are called from ONLY one GPU.

train_dataloader, val_dataloader, test_dataloader

Each of these is responsible for returning the appropriate data split. Lightning structures it this way so that it is VERY clear HOW the data are being manipulated. If you ever read random GitHub code written in PyTorch, it’s nearly impossible to see how the data are manipulated.

Lightning even allows multiple dataloaders for testing or validating.
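
A sketch of those four methods on the LightningModule. This follows the early Lightning convention of doing the split inside prepare_data; newer releases prefer a separate setup() hook, so treat the placement as illustrative:

```python
class LightningMNISTClassifier(pl.LightningModule):
    ...

    def prepare_data(self):
        # runs on a single process, so the download is not duplicated across GPUs
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,)),
        ])
        mnist_train = datasets.MNIST("data", train=True, download=True, transform=transform)
        self.mnist_test = datasets.MNIST("data", train=False, download=True, transform=transform)
        self.mnist_train, self.mnist_val = random_split(mnist_train, [55000, 5000])

    def train_dataloader(self):
        return DataLoader(self.mnist_train, batch_size=64, shuffle=True)

    def val_dataloader(self):
        return DataLoader(self.mnist_val, batch_size=64)

    def test_dataloader(self):
        return DataLoader(self.mnist_test, batch_size=64)
```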

The Optimizer

Now we choose how we’re going to do the optimization. We’ll use Adam instead of SGD because it is a good default in most DL research.


Again, this is exactly the same in both, except that in Lightning it is organized into the configure_optimizers function.

Lightning is extremely extensible. For instance, if you wanted to use multiple optimizers (i.e., for a GAN), you could just return both here.


You’ll also notice that in Lightning we pass in self.parameters() and not a model because the LightningModule IS the model.
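
A sketch of that method (Adam with a typical learning rate; the learning rate is an assumption):

```python
class LightningMNISTClassifier(pl.LightningModule):
    ...

    def configure_optimizers(self):
        # the LightningModule IS the model, so we optimize self.parameters()
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return optimizer
        # for something like a GAN you could return two optimizers instead,
        # e.g. [opt_generator, opt_discriminator]
```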

The Loss

For n-way classification we want to compute the cross-entropy loss. Cross-entropy is the same as applying the negative log-likelihood loss to log_softmax outputs, which is what we’ll use here.
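
Since the model’s forward already ends in log_softmax, the loss is just the negative log-likelihood of those outputs (a minimal sketch):

```python
from torch.nn import functional as F

def cross_entropy_loss(log_probs, labels):
    # NLL on log_softmax outputs is equivalent to cross-entropy on raw logits
    return F.nll_loss(log_probs, labels)
```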


Again… code is exactly the same!

Training and Validation Loop

We assembled all the key ingredients needed for training:

  1. The model (3-layer NN)
  2. The dataset (MNIST)
  3. An optimizer
  4. A loss

Now we implement a full training routine which does the following:

  • Iterates for many epochs (an epoch is a full pass through the dataset)
  • Each epoch iterates over the dataset in small chunks called batches
  • We perform a forward pass
  • Compute the loss
  • Perform a backward pass to calculate the gradients for each weight
  • Apply the gradients to each weight

In both PyTorch and Lightning the pseudocode looks like this:
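
Roughly this skeleton (names like model, train_loader and optimizer refer to the pieces sketched earlier):

```python
for epoch in range(num_epochs):                # iterate for many epochs
    for x, y in train_loader:                  # iterate the dataset in batches
        log_probs = model(x)                   # forward pass
        loss = F.nll_loss(log_probs, y)        # compute the loss
        optimizer.zero_grad()
        loss.backward()                        # backward pass: compute gradients
        optimizer.step()                       # apply the gradients to the weights
```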


This is where Lightning differs, though. In PyTorch, you write the for loop yourself, which means you have to remember to call the correct things in the right order; this leaves a lot of room for bugs.

Even if your model is simple, it won’t be once you start doing more advanced things like using multiple GPUs, gradient clipping, early stopping, checkpointing, TPU training, 16-bit precision, etc… Your code complexity will quickly explode.

Even if your model is simple, it won’t be once you start doing more advanced things

Here are the validation and training loops for both PyTorch and Lightning.
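
In Lightning, those loops collapse into training_step and validation_step methods on the LightningModule. The exact return format and the epoch-end aggregation hooks have changed a bit across Lightning versions, so this is a sketch of the idea rather than the one definitive API:

```python
class LightningMNISTClassifier(pl.LightningModule):
    ...

    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = F.nll_loss(self(x), y)
        return loss                   # Lightning handles zero_grad/backward/step

    def validation_step(self, batch, batch_idx):
        x, y = batch
        val_loss = F.nll_loss(self(x), y)
        return val_loss               # collected across the validation set
```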


This is the beauty of Lightning. It abstracts the boilerplate but leaves everything else unchanged. This means you are STILL writing PyTorch, except your code has been structured nicely.

This increases readability which helps with reproducibility!

The Lightning Trainer

The trainer is how we abstract the boilerplate code.
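
In its simplest form that’s just a couple of lines (max_epochs is an assumption):

```python
from pytorch_lightning import Trainer

model = LightningMNISTClassifier()
trainer = Trainer(max_epochs=5)
trainer.fit(model)
```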


Again, this is possible because ALL you had to do was organize your PyTorch code into a LightningModule

Full Training Loop for PyTorch

The full MNIST example written in PyTorch is as follows:
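
The embedded gist doesn’t survive in this copy, so here is a condensed sketch of what that manual routine looks like, reusing the model and dataloaders sketched above (epoch count and learning rate are assumptions):

```python
import torch
from torch.nn import functional as F

model = MNISTClassifier()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

for epoch in range(5):
    # training loop
    model.train()
    for x, y in train_loader:
        optimizer.zero_grad()
        loss = F.nll_loss(model(x), y)
        loss.backward()
        optimizer.step()

    # validation loop
    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for x, y in val_loader:
            val_loss += F.nll_loss(model(x), y).item()
    print(f"epoch {epoch}: val_loss={val_loss / len(val_loader):.4f}")
```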

Full Training loop in Lightning

The Lightning version is EXACTLY the same except that:

  • The core ingredients have been organized by the LightningModule
  • The training/validation loop code has been abstracted by the Trainer

Highlights

Let’s call out a few key points

  1. Without Lightning, the PyTorch code can be scattered across arbitrary places. With Lightning, it is structured.
  2. It is the exact same code for both, except that it’s structured in Lightning (worth saying twice, lol).
  3. As the project grows in complexity, your code won’t, because Lightning abstracts most of it away.
  4. You retain the flexibility of PyTorch because you have full control over the key points in training. For instance, you could have an arbitrarily complex training_step, such as for a seq2seq model.

5. In Lightning you get a bunch of freebies, such as a sick progress bar


You also get a beautiful weights summary


TensorBoard logs (yup! you had to do nothing to get this)


and free checkpointing and early stopping.

All for free!

Additional Features

But Lightning is best known for its out-of-the-box goodies, such as TPU training, etc.

In Lightning, you can train your model on CPUs, GPUs, Multiple GPUs, or TPUs without changing a single line of your PyTorch code.

You can also do 16-bit precision training
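
The flag names below come from the Lightning releases of that era (gpus, tpu_cores, precision); newer versions express the same ideas through accelerator/devices arguments, so treat them as illustrative:

```python
# multiple GPUs with 16-bit precision
trainer = pl.Trainer(gpus=4, precision=16)

# or 8 TPU cores
trainer = pl.Trainer(tpu_cores=8)

trainer.fit(model)
```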


Log using 5 other alternatives to TensorBoard


Logging with Neptune.AI (credits: Neptune.ai)


Logging with Comet.ml

We even have a built-in profiler that can tell you where the bottlenecks are in your training.
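
Something like this enables it (profiler=True turns on the simple profiler; the option for the more detailed profiler has varied by version):

```python
# print a report of where training time was spent
trainer = pl.Trainer(profiler=True)

# or a more detailed, per-function breakdown
trainer = pl.Trainer(profiler="advanced")
```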

Setting this flag prints a report of where time was spent during training; switching to the advanced profiler gives a more detailed, per-function breakdown.

We can also train on multiple GPUs at once without you doing any work (you still have to submit a SLURM job)


And there are about 40 other features it supports which you can read about in the documentation.

Extensibility With Hooks

You’re probably wondering how it’s possible for Lightning to do this for you and yet somehow make it so that you have full control over everything?

Unlike Keras or other high-level frameworks, Lightning does not hide any of the necessary details. And if you do find the need to modify every aspect of training on your own, you have two main options.

The first is extensibility by overriding hooks. Here’s a non-exhaustive list:

  • hooks for every stage of the training, validation and test loops (epoch and batch start/end, the backward pass, the optimizer step, and more)
  • anything you would need to configure

These overrides happen in the LightningModule
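
For example, a sketch of overriding a couple of hooks directly on the LightningModule (the backward signature shown follows the 1.x series and differs slightly in other releases):

```python
class LightningMNISTClassifier(pl.LightningModule):
    ...

    def on_train_start(self):
        # called once, right before the training loop begins
        print("training is starting")

    def backward(self, loss, optimizer, optimizer_idx):
        # take full control of the backward pass if you need to
        loss.backward()
```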


Extensibility with Callbacks

A callback is a piece of code that you’d like to be executed at various parts of training. In Lightning, callbacks are reserved for non-essential code such as logging or anything not related to the research code. This keeps the research code super clean and organized.

Let’s say you wanted to print or save something at various points during training. Here’s what the callback would look like:
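
A sketch of such a callback (the print statements stand in for whatever logging or saving you would actually do):

```python
from pytorch_lightning.callbacks import Callback


class MyPrintingCallback(Callback):
    def on_train_start(self, trainer, pl_module):
        print("Training is starting")

    def on_train_end(self, trainer, pl_module):
        print("Training is ending")
```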


Now you pass this into the Trainer, and this code will be called at the corresponding points during training.
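
For example:

```python
trainer = pl.Trainer(callbacks=[MyPrintingCallback()])
trainer.fit(model)
```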


This paradigm keeps your research code organized into three different buckets

  1. Research code (LightningModule) (this is the science).
  2. Engineering code (Trainer)
  3. Non-research related code (Callbacks)

How to start

Hopefully this guide showed you exactly how to get started. The easiest way to start is to run the Colab notebook with the MNIST example.

Or install Lightning
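
```
pip install pytorch-lightning
```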


Or check out the GitHub page.

