source link: https://github.com/msamogh/nonechucks
nonechucks
nonechucks is a library that provides wrappers for PyTorch's datasets, samplers, and transforms to allow for dropping unwanted or invalid samples dynamically.
Introduction
Why do I need it?
- When you have a small fraction of undesirable samples in a large dataset, or
- When your sample-loading operation is expensive, or
- When you want to let downstream consumers know that a sample is undesirable (with nonechucks, transforms are not restricted to modifying samples; they can drop them as well), or
- When you want your dataset and sampler to be decoupled.

Examples
1. Dealing with bad samples
Create a dataset (the usual way)
Using something like torchvision's ImageFolder dataset class, we can load an entire folder of labelled images for a typical supervised classification task.
```python
import torchvision.datasets as datasets

fruits_dataset = datasets.ImageFolder('fruits/')
```
Without nonechucks
Now, if you have a sneaky, corrupted `fruits/apple/143.jpg` sitting in your `fruits/` folder, then to avoid the entire pipeline surprise-failing, you would have to resort to something like this:
```python
import random

# Shuffle dataset indices by hand
indices = list(range(len(fruits_dataset)))
random.shuffle(indices)

batch_size = 4
for i in range(0, len(indices), batch_size):
    try:
        batch = [fruits_dataset[idx] for idx in indices[i:i + batch_size]]
        # Do something with it
        pass
    except IOError:
        # Skip the entire batch
        continue
```
Not only do you have to put your code inside an extra `try-except` block, but you are also forced to use a for-loop, depriving yourself of PyTorch's built-in `DataLoader`, which means you can't use features like batching, shuffling, multiprocessing, and custom samplers for your dataset.
I don't know about you, but not being able to do that kind of defeats the whole point of using a data processing module for me.
With nonechucks
You can transform your dataset into a `SafeDataset` with a single line of code.
```python
import nonechucks as nc

fruits_dataset = nc.SafeDataset(fruits_dataset)
```
That's it! Seriously.
And that's not all. You can also use a `DataLoader` on top of this.
```python
dataloader = nc.SafeDataLoader(fruits_dataset, batch_size=4, shuffle=True)
for i_batch, sample_batched in enumerate(dataloader):
    # Do something with it
    pass
```
In this case, `SafeDataset` will skip the erroneous image and use the next one in its place (as opposed to dropping the entire batch).
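The skip-and-replace idea is straightforward to picture. Here is a minimal, self-contained sketch of a wrapper with that behavior; it is an illustration only, not nonechucks' actual implementation (which, among other things, also memoizes which indices are safe):

```python
class SkippingDataset:
    """Illustrative wrapper: on a loading error, substitute the next valid sample."""

    def __init__(self, dataset):
        self.dataset = dataset

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        # Walk forward (wrapping around) until a sample loads successfully.
        n = len(self.dataset)
        for offset in range(n):
            i = (index + offset) % n
            try:
                return self.dataset[i]
            except IOError:
                continue  # corrupted sample: try the next index
        raise RuntimeError("no valid samples in dataset")
```

With a wrapper like this, requesting the index of a corrupted sample transparently yields the next loadable one, which is why downstream batching code never sees the failure.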
2. Use Transforms as Filters!
The function of transforms in PyTorch is restricted to modifying samples. With nonechucks, you can simply return `None` (or raise an exception) from the transform's `__call__` method, and nonechucks will drop the sample from the dataset for you, allowing you to use transforms as filters!
For the example, we'll assume a `PDFDocumentsDataset`, which reads PDF files from a folder; a `PlainTextTransform`, which transforms the files into raw text; and a `LanguageFilter`, which retains only documents of a particular language.
```python
class LanguageFilter:
    def __init__(self, language):
        self.language = language

    def __call__(self, sample):
        # Do machine learning magic
        document_language = detect_language(sample)
        if document_language != self.language:
            return None
        return sample

transforms = transforms.Compose([
    PlainTextTransform(),
    LanguageFilter('en')
])
en_documents = PDFDocumentsDataset(data_dir='pdf_files/', transform=transforms)
en_documents = nc.SafeDataset(en_documents)
```
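To see the filtering mechanic in isolation, here is a runnable toy version of the same pattern. The `detect_language` stub below is a placeholder assumption (it just checks for ASCII text, standing in for real language detection), and `apply_filters` mimics the drop-on-`None` behavior rather than calling nonechucks itself:

```python
def detect_language(text):
    # Hypothetical stand-in for a real language detector:
    # treat ASCII-only text as English.
    return "en" if text.isascii() else "other"

class LanguageFilter:
    def __init__(self, language):
        self.language = language

    def __call__(self, sample):
        if detect_language(sample) != self.language:
            return None  # signal "drop this sample"
        return sample

def apply_filters(samples, transform):
    # Mimics nonechucks dropping samples whose transform returns None.
    return [t for t in (transform(s) for s in samples) if t is not None]
```

Running `apply_filters` over a mixed list of strings with `LanguageFilter("en")` keeps only the ASCII ones, which is the filtering behavior the transform expresses by returning `None`.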
Installation
To install nonechucks, simply use pip:
or clone this repo, and build from source with:
Contributing
Licensing