
Better Data Loading: 20x PyTorch Speed-Up for Tabular Data

source link: https://towardsdatascience.com/better-data-loading-20x-pytorch-speed-up-for-tabular-data-e264b9e34352?gi=722901b92f93

A simple change to speed up your deep learning training massively

Deep Learning: Need For Speed

When training deep learning models, performance is crucial. Datasets can be huge, and inefficient training means slower research iterations, less time for hyperparameter optimisation, longer deployment cycles, and higher compute cost.

Despite all that, it can be hard to justify investing too much time speeding things up, as there are many potential dead ends to explore. But fortunately, there are some quick wins available!

I’m going to show you how a simple change I made to my dataloaders in PyTorch for tabular data sped up training by over 20x, without any change to the training loop! Just a simple drop-in replacement for PyTorch’s standard dataloader. For the model I was looking at, that’s a sixteen-minute iteration time reduced to forty seconds!

And all without installing any new packages, making any low-level code changes, or changing any hyperparameters.

Research/Industry Disconnect

In supervised learning, a quick look at Arxiv-Sanity tells us that the top research papers at the moment are all either about images (whether classification or GANs for generation), or text (mostly variations on BERT). These are great in areas where traditional machine learning just doesn’t stand a chance — but require expertise and a significant research budget to execute well.

On the other hand, much of the data that many companies hold already resides in databases, in a nice tabular format. Some examples include customer details for lifetime value estimation, click-through optimisation, and financial time-series data.

What’s Special About Tabular Data?

So why is this rift between research and industry a problem for us? Well, the needs of state-of-the-art text/vision researchers are very different from those of people doing supervised learning on tabular data sets.

Having data in tabular form (i.e. a database table, Pandas DataFrame, NumPy Array, or PyTorch Tensor) makes things easier in several ways:

  1. Training batches can be taken from contiguous chunks of memory by slicing.
  2. No per-sample preprocessing cost, allowing us to make full use of large-batch training for additional speed (remembering to increase the learning rate to compensate, as large batches can otherwise hurt generalisation!).
  3. If your dataset is small enough, it can be loaded onto the GPU all in one go (while this is technically also possible with text/vision data, datasets there tend to be larger and some preprocessing steps are more easily done on CPU).
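Points 1 and 3 can be sketched in a few lines. This is a minimal illustration on randomly generated stand-in data (the tensor sizes, variable names and the batch size of 1024 are assumptions chosen to match the examples later on), not code from the article: the whole table is moved to the device once, and every batch is just a slice of rows.

```python
import torch

# Hypothetical stand-in for a tabular dataset: 10,000 rows, 21 real-valued features.
train_x = torch.randn(10_000, 21)
train_y = torch.randint(0, 2, (10_000,)).float()

# Point 3: if the whole dataset fits, move it to the GPU once, up front.
device = "cuda" if torch.cuda.is_available() else "cpu"
train_x, train_y = train_x.to(device), train_y.to(device)

batch_size = 1024
batches = []
for start in range(0, train_x.shape[0], batch_size):
    # Point 1: slicing along the first dimension of a contiguous (row-major)
    # tensor gives a contiguous view of those rows, with no per-row Python work.
    x_batch = train_x[start:start + batch_size]
    y_batch = train_y[start:start + batch_size]
    batches.append((x_batch, y_batch))

print(len(batches))           # 10 (nine full batches and one partial batch)
print(batches[-1][0].shape)   # torch.Size([784, 21])
```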

These optimisations are possible for tabular data and not for text/vision data because of two main areas of difference: models and data.

Models: vision research tends to use large deep convolutional neural nets (CNNs); text tends to use large recurrent neural nets (RNNs) or Transformers; but on tabular data, plain fully connected deep neural nets (FCDNNs) can do fine. While not always the case, vision and text models generally need more parameters, because they learn more nuanced representations than the interactions between variables in tabular data, and so their forward and backward passes take longer.
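For a sense of scale, here is a minimal sketch of a plain fully connected network for tabular input, assuming 21 input features and a single output logit for binary classification. The helper name `make_fcdnn` and the layer widths are hypothetical choices for illustration, not the model used in the article.

```python
import torch
from torch import nn


def make_fcdnn(n_features: int, hidden: int = 256, n_layers: int = 3) -> nn.Sequential:
    """Hypothetical helper: a plain fully connected net with ReLU activations."""
    layers = []
    in_dim = n_features
    for _ in range(n_layers):
        # Each hidden block is a Linear layer followed by a ReLU.
        layers += [nn.Linear(in_dim, hidden), nn.ReLU()]
        in_dim = hidden
    layers.append(nn.Linear(in_dim, 1))  # one logit for binary classification
    return nn.Sequential(*layers)


model = make_fcdnn(n_features=21)
n_params = sum(p.numel() for p in model.parameters())
logits = model(torch.randn(1024, 21))  # a batch of 1,024 rows

print(n_params)       # 137473
print(logits.shape)   # torch.Size([1024, 1])
```

This network has well under 200,000 parameters; for comparison, a ResNet-50 image classifier has roughly 25 million.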

Data: vision data tends to be saved as nested folders full of images, which can require significant preprocessing (cropping, scaling, rotating, etc.). Text data can be large files or other text streams. Both will generally be saved on disk and loaded from disk in batches. This isn’t an issue, because disk read/write speed isn’t the bottleneck there: the preprocessing or the backward passes are. Tabular data, on the other hand, has the nice property of being easily loaded into contiguous chunks of memory in the form of an array or tensor. Preprocessing on tabular data tends to be done separately in advance, either in a database or as a vectorised operation on a dataset.
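As an example of preprocessing done in advance, the sketch below standardises every column of a randomly generated, hypothetical feature matrix with one vectorised operation before training starts, so the data loader has no per-sample work left to do.

```python
import torch

torch.manual_seed(0)

# Hypothetical raw table: 100,000 rows, 21 columns on an arbitrary scale.
raw_x = torch.randn(100_000, 21) * 5.0 + 3.0

# One vectorised pass over the whole table, done once before training:
# per-column mean and standard deviation, then standardisation of every cell.
mean = raw_x.mean(dim=0, keepdim=True)  # shape (1, 21)
std = raw_x.std(dim=0, keepdim=True)    # shape (1, 21)
train_x = (raw_x - mean) / std          # broadcasts over all rows

# Each column of train_x now has a mean close to 0 and a standard deviation close to 1.
```

In a real pipeline, the mean and standard deviation would be computed on the training split only and then reused for the validation and test data.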

[Figure: Comparison of different types of supervised learning research]

PyTorch & DataLoaders

As we’ve seen, loading tabular data can be really easy and fast! So of course PyTorch works great by default for tabular data… right?

It turns out it doesn’t! 😩

Just last week I was training a PyTorch model on some tabular data, and wondering why it was taking so long to train. I couldn’t see any obvious bottlenecks, but for some reason, the GPU usage was much lower than expected. When I dug into it with some profiling I found the culprit… the DataLoader.

What is a DataLoader? DataLoaders do exactly what you might think they do: they load your data from wherever it is (on disk, in the cloud, in memory) to wherever it needs to be for your model to use it (in RAM or GPU memory). In addition to this, they take care of splitting your data into batches, shuffling it, and preprocessing individual samples if necessary. Wrapping this code in a DataLoader is nicer than having it scattered throughout, as it allows you to keep your main training code clean. The official PyTorch tutorial also recommends using DataLoaders.

How do you use them? It depends on the type of data you have. For tabular data, PyTorch’s default DataLoader can take a TensorDataset. This is a lightweight wrapper around the tensors required for training, usually an X (or features) and Y (or labels) tensor.

from torch.utils.data import DataLoader, TensorDataset

data_set = TensorDataset(train_x, train_y)
train_batches = DataLoader(data_set, batch_size=1024, shuffle=False)

You can then use this in your training loop:

for x_batch, y_batch in train_batches:
    optimizer.zero_grad()
    loss = loss_fn(model(x_batch), y_batch)
    loss.backward()
    optimizer.step()
    ...
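For completeness, here is a self-contained, runnable version of the setup and loop above. The random data, the easy synthetic target, the small model, `BCEWithLogitsLoss` and Adam are all stand-in assumptions for illustration; only the `TensorDataset`/`DataLoader` lines and the loop body come from the text.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)

# Stand-in data (assumption): 8,192 rows, 21 features, and a simple learnable
# binary target shaped (N, 1) to match the model's single output logit.
train_x = torch.randn(8_192, 21)
train_y = (train_x[:, 0] > 0).float().unsqueeze(1)

model = nn.Sequential(nn.Linear(21, 64), nn.ReLU(), nn.Linear(64, 1))
loss_fn = nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)

data_set = TensorDataset(train_x, train_y)
train_batches = DataLoader(data_set, batch_size=1024, shuffle=False)

epoch_losses = []
for epoch in range(5):
    total = 0.0
    for x_batch, y_batch in train_batches:
        optimizer.zero_grad()                    # clear gradients from the previous step
        loss = loss_fn(model(x_batch), y_batch)  # forward pass
        loss.backward()                          # backward pass
        optimizer.step()                         # parameter update
        total += loss.item() * x_batch.shape[0]  # accumulate the summed loss over rows
    epoch_losses.append(total / len(data_set))
    print(f"epoch {epoch}: mean loss {epoch_losses[-1]:.4f}")
```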

Why is this bad? This looks pretty good, and certainly very clean! The problem is that each time a batch is loaded, PyTorch’s DataLoader calls the __getitem__() function on the Dataset once per example and then stacks the results into a batch, rather than reading the whole batch in one go as a single big chunk! So we don’t end up making use of the advantages of our tabular data set. This is especially bad when we use large batch sizes, because the per-example loading cost stays the same while large batches make the model’s per-example cost on the GPU cheaper.
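One way to see this per-example behaviour directly is to count the calls. The sketch below uses a hypothetical `CountingTensorDataset`, a subclass of `TensorDataset` that increments a counter in `__getitem__()`. With the default DataLoader settings (a single process and automatic batching), one pass over 4,096 rows in batches of 1,024 yields 4 batches but 4,096 `__getitem__()` calls.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset


class CountingTensorDataset(TensorDataset):
    """Hypothetical helper: a TensorDataset that counts __getitem__ calls."""

    def __init__(self, *tensors):
        super().__init__(*tensors)
        self.calls = 0

    def __getitem__(self, index):
        self.calls += 1
        return super().__getitem__(index)


train_x = torch.randn(4_096, 21)
train_y = torch.randint(0, 2, (4_096,))
data_set = CountingTensorDataset(train_x, train_y)
train_batches = DataLoader(data_set, batch_size=1024, shuffle=False)

n_batches = 0
for x_batch, y_batch in train_batches:
    n_batches += 1  # each batch was assembled from 1,024 separate __getitem__ calls

print(n_batches, data_set.calls)  # 4 4096
```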

How can we fix this? Easy: replace the TensorDataset and DataLoader lines above with the line below, and copy the definition of FastTensorDataLoader from this file (credit for this goes to Jesse Mu, for this answer on the PyTorch forums):

train_batches = FastTensorDataLoader(train_x, train_y, batch_size=1024, shuffle=False)

FastTensorDataLoader is just a small custom class with no dependencies other than PyTorch, and using it doesn’t require any changes to your training code! It supports shuffling too, though the benchmarks below are for non-shuffled data.
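The core idea is small enough to sketch. The class below is a hypothetical, minimal illustration of a batch-slicing loader written for this explanation; it is not the FastTensorDataLoader from the linked file. It assumes all tensors are already in memory, share the same first dimension and live on the same device. Each batch is one slice per tensor, and shuffling applies a single permutation per epoch to all tensors so that features and labels stay aligned.

```python
import torch


class SliceTensorLoader:
    """Hypothetical sketch of a batch-slicing loader for in-memory tensors.

    Each batch is one slice per tensor, instead of batch_size separate
    __getitem__ calls followed by a stack.
    """

    def __init__(self, *tensors, batch_size=32, shuffle=False):
        n = tensors[0].shape[0]
        if any(t.shape[0] != n for t in tensors):
            raise ValueError("all tensors must have the same first dimension")
        self.tensors = tensors
        self.n = n
        self.batch_size = batch_size
        self.shuffle = shuffle

    def __len__(self):
        # Number of batches, counting a final partial batch.
        return (self.n + self.batch_size - 1) // self.batch_size

    def __iter__(self):
        tensors = self.tensors
        if self.shuffle:
            # One permutation per epoch, applied to every tensor so rows stay aligned.
            perm = torch.randperm(self.n, device=tensors[0].device)
            tensors = tuple(t[perm] for t in tensors)
        for start in range(0, self.n, self.batch_size):
            yield tuple(t[start:start + self.batch_size] for t in tensors)


# Usage mirrors the DataLoader call it replaces (random stand-in data).
train_x = torch.randn(10_000, 21)
train_y = torch.randint(0, 2, (10_000,))
train_batches = SliceTensorLoader(train_x, train_y, batch_size=1024, shuffle=False)
print(len(train_batches))  # 10
```

Because `__iter__` is a generator, each new `for` loop (that is, each epoch) starts a fresh pass, with a fresh permutation when `shuffle=True`, much like reusing a DataLoader across epochs. Indexing with a permutation copies the data once per epoch, but that is a single vectorised copy rather than one call per row.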

What difference does this make? On the benchmark set I used, the custom tabular DataLoader ends up being over 20x faster. In this case, that means that instead of a 10-epoch run taking over 15 minutes, it takes less than 40 seconds: a huge difference in iteration speed!
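To get a feel for the effect on your own machine, a rough timing sketch like the one below compares one pass of the default DataLoader against plain slicing on random stand-in data. It measures loading only (no model), the helper names are hypothetical, and the absolute times and the ratio will depend on your hardware, PyTorch version, dataset size and batch size; it is not the article’s Higgs benchmark.

```python
import time

import torch
from torch.utils.data import DataLoader, TensorDataset


def sliced_batches(x, y, batch_size):
    # Hypothetical helper: yields one contiguous row slice per batch.
    for start in range(0, x.shape[0], batch_size):
        yield x[start:start + batch_size], y[start:start + batch_size]


def time_one_epoch(batches):
    # Time a full pass over the batches with no model, isolating loading cost.
    start = time.perf_counter()
    n_rows = 0
    for x_batch, _ in batches:
        n_rows += x_batch.shape[0]
    return time.perf_counter() - start, n_rows


train_x = torch.randn(100_000, 21)
train_y = torch.randint(0, 2, (100_000,))

default_s, default_rows = time_one_epoch(
    DataLoader(TensorDataset(train_x, train_y), batch_size=1024, shuffle=False)
)
sliced_s, sliced_rows = time_one_epoch(sliced_batches(train_x, train_y, 1024))

print(f"default DataLoader: {default_s:.3f}s, slicing: {sliced_s:.3f}s, "
      f"ratio: {default_s / sliced_s:.1f}x")
```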

[Figure: Two near-identical runs, except one takes over 15 minutes and the other takes less than a minute]

This benchmark was run on the Higgs dataset used in this Nature paper. With 11 million examples, it makes for a more realistic deep learning benchmark than most public tabular ML datasets (which can be tiny!). It’s a binary classification problem with 21 real-valued features. It’s nice to see that we can get to over 0.77 ROC AUC on the test set within just 40 seconds of training, before any hyperparameter optimisation, though we’re still a while off from the 0.88 reached in the paper.

I hope this has been helpful, and that you’re able to see similar speed increases in your own training code! After implementing this I found some further optimisations which resulted in a total speedup of closer to 100x! Leave a comment if you’d like to see more, and we can cover these in a follow-up article.

See the Appendix for how you can run the benchmark code yourself. The example includes code for running the default PyTorch DataLoader, the faster custom one, as well as timing the results and logging to TensorBoard.

