

GitHub - PyTorchLightning/lightning-flash
source link: https://github.com/PyTorchLightning/lightning-flash

Collection of tasks for fast prototyping, baselining, finetuning and solving problems with deep learning
Installation
Pip / conda
pip install lightning-flash -U
What is Flash
Flash is a framework of tasks for fast prototyping, baselining, finetuning and solving business and scientific problems with deep learning. It is focused on:
- Predictions
- Finetuning
- Task-based training
It is built for data scientists, machine learning practitioners, and applied researchers.
Scalability
Flash is built on top of PyTorch Lightning (by the Lightning team), which is a thin organizational layer on top of PyTorch. If you know PyTorch, you know PyTorch Lightning and Flash already!
As a result, Flash can scale across hardware such as GPUs and TPUs with zero changes to your code. Each task also embeds best practices from AI research, so you don't need a PhD in deep learning to leverage its power :)
Predictions
```python
from flash.text import TranslationTask

# 1. Load finetuned task
model = TranslationTask.load_from_checkpoint("https://flash-weights.s3.amazonaws.com/translation_model_en_ro.pt")

# 2. Translate a few sentences!
predictions = model.predict([
    "BBC News went to meet one of the project's first graduates.",
    "A recession has come as quickly as 11 months after the first rate hike and as long as 86 months.",
])
print(predictions)
```
Finetuning
First, finetune:
```python
import flash
from flash.core.data.utils import download_data
from flash.image import ImageClassificationData, ImageClassifier

# 1. Download the data
download_data("https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip", "data/")

# 2. Load the data
datamodule = ImageClassificationData.from_folders(
    train_folder="data/hymenoptera_data/train/",
    val_folder="data/hymenoptera_data/val/",
    test_folder="data/hymenoptera_data/test/",
)

# 3. Build the model
model = ImageClassifier(num_classes=datamodule.num_classes, backbone="resnet18")

# 4. Create the trainer (a single epoch over the training data)
trainer = flash.Trainer(max_epochs=1)

# 5. Finetune the model, keeping the pretrained backbone frozen
trainer.finetune(model, datamodule=datamodule, strategy="freeze")

# 6. Save it!
trainer.save_checkpoint("image_classification_model.pt")
```
Then use the finetuned model:
```python
from flash.image import ImageClassifier

# Load the finetuned model
classifier = ImageClassifier.load_from_checkpoint("image_classification_model.pt")

# Predict!
predictions = classifier.predict("data/hymenoptera_data/val/bees/65038344_52a45d090d.jpg")
print(predictions)
```
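The `strategy="freeze"` argument in the finetuning step keeps the pretrained backbone's weights fixed and trains only the new classification head. The framework-free sketch below illustrates that idea under simplifying assumptions: the parameter groups `backbone` and `head`, their values, and the `sgd_step` helper are hypothetical, plain lists stand in for tensors, and this is not how Flash implements its finetuning strategies internally.

```python
from typing import Dict, List, Set

# Hypothetical parameter groups (plain lists stand in for weight tensors).
params: Dict[str, List[float]] = {
    "backbone": [0.5, -1.2, 0.3],  # "pretrained" weights
    "head": [0.0, 0.0],  # newly initialised classifier head
}
grads: Dict[str, List[float]] = {
    "backbone": [0.1, 0.2, -0.3],
    "head": [0.4, -0.6],
}


def sgd_step(
    params: Dict[str, List[float]],
    grads: Dict[str, List[float]],
    frozen: Set[str],
    lr: float = 0.1,
) -> Dict[str, List[float]]:
    """Return updated parameters after one SGD step, skipping frozen groups."""
    updated = {}
    for name, values in params.items():
        if name in frozen:
            updated[name] = list(values)  # frozen group: weights copied unchanged
        else:
            updated[name] = [w - lr * g for w, g in zip(values, grads[name])]
    return updated


# Under a "freeze"-style strategy only the head is optimized.
new_params = sgd_step(params, grads, frozen={"backbone"})
print(new_params)
```

Only the `head` values move; the `backbone` values pass through untouched, which is what lets the pretrained features survive finetuning on a small dataset.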
Tasks
Flash is built as a collection of community-built tasks. A task is highly opinionated and laser-focused on solving a single problem well, using state-of-the-art methods.
Example 1: Image embedding
Flash has an Image Embedder task that encodes an image into a vector of image features, which can then be used for downstream tasks such as clustering, similarity search, or classification.
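Once images are encoded as vectors, similarity search reduces to comparing vectors, commonly with cosine similarity. A minimal standard-library sketch, assuming the embeddings have already been produced; the file names and 4-dimensional vectors below are made up for illustration and are not real model outputs.

```python
import math
from typing import Dict, List, Tuple


def cosine_similarity(a: List[float], b: List[float]) -> float:
    """Cosine of the angle between two vectors (1.0 means same direction)."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


def most_similar(
    query: List[float], gallery: Dict[str, List[float]], k: int = 2
) -> List[Tuple[str, float]]:
    """Return the k gallery items closest to the query by cosine similarity."""
    scored = [(name, cosine_similarity(query, vec)) for name, vec in gallery.items()]
    scored.sort(key=lambda item: item[1], reverse=True)
    return scored[:k]


# Hypothetical embeddings; real image embeddings have hundreds of dimensions.
gallery = {
    "ant_1.jpg": [0.9, 0.1, 0.0, 0.2],
    "ant_2.jpg": [0.8, 0.2, 0.1, 0.1],
    "bee_1.jpg": [0.1, 0.9, 0.3, 0.0],
}
query = [0.85, 0.15, 0.05, 0.15]
print(most_similar(query, gallery))
```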
Example 2: Text Summarization
Flash has a Summarization task that condenses a longer article into a short description.
Example 3: Tabular Classification
Flash has a Tabular Classification task for classification problems on tabular data.
Example 4: Object Detection
Flash has an Object Detection task to identify and locate objects in images.
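Locating objects usually means predicting bounding boxes, and predicted boxes are compared with ground-truth boxes using intersection over union (IoU). The sketch below is a generic illustration, not Flash's internal code, and assumes boxes are given as `(x_min, y_min, x_max, y_max)` pixel coordinates.

```python
from typing import Tuple

Box = Tuple[float, float, float, float]  # (x_min, y_min, x_max, y_max)


def iou(a: Box, b: Box) -> float:
    """Intersection over union of two axis-aligned boxes."""
    # Width and height of the overlap rectangle (zero if the boxes are disjoint).
    inter_w = max(0.0, min(a[2], b[2]) - max(a[0], b[0]))
    inter_h = max(0.0, min(a[3], b[3]) - max(a[1], b[1]))
    inter = inter_w * inter_h
    area_a = (a[2] - a[0]) * (a[3] - a[1])
    area_b = (b[2] - b[0]) * (b[3] - b[1])
    union = area_a + area_b - inter
    return inter / union if union > 0 else 0.0


# Two 10x10 boxes overlapping in a 5x5 square: 25 / (100 + 100 - 25).
print(iou((0, 0, 10, 10), (5, 5, 15, 15)))
```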
Example 5: Video Classification with PyTorchVideo
Flash has a Video Classification task to classify videos using PyTorchVideo.
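Video classifiers typically consume a fixed number of frames per clip, and PyTorchVideo provides transforms such as `UniformTemporalSubsample` for that purpose. The standard-library sketch below only shows the underlying idea of choosing evenly spaced frame indices; the helper name and frame counts are assumptions for illustration, not PyTorchVideo's code.

```python
from typing import List


def uniform_frame_indices(num_frames: int, num_samples: int) -> List[int]:
    """Pick num_samples evenly spaced frame indices from a clip of num_frames.

    Integer floor division keeps the first index at 0 and the last at
    num_frames - 1 without floating-point rounding surprises.
    """
    if num_frames <= 0 or num_samples <= 0:
        raise ValueError("num_frames and num_samples must be positive")
    if num_samples == 1:
        return [0]
    return [i * (num_frames - 1) // (num_samples - 1) for i in range(num_samples)]


# Hypothetical 30-frame clip reduced to the 8 frames a model expects.
print(uniform_frame_indices(num_frames=30, num_samples=8))
```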
Example 6: Semantic Segmentation
Flash has a Semantic Segmentation task that assigns a class label to every pixel of an image.
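A semantic segmentation model outputs one score per class for every pixel, and the predicted mask takes, at each pixel, the class with the highest score. A tiny standard-library sketch with made-up scores for a 2x2 image and 3 classes; the helpers and the simple pixel-accuracy metric are illustrative assumptions, not Flash's implementation.

```python
from typing import List

# scores[c][y][x]: hypothetical score for class c at pixel (y, x).
Scores = List[List[List[float]]]


def scores_to_mask(scores: Scores) -> List[List[int]]:
    """Per-pixel argmax over classes, giving a mask of class indices."""
    num_classes = len(scores)
    height, width = len(scores[0]), len(scores[0][0])
    mask = [[0] * width for _ in range(height)]
    for y in range(height):
        for x in range(width):
            mask[y][x] = max(range(num_classes), key=lambda c: scores[c][y][x])
    return mask


def pixel_accuracy(pred: List[List[int]], target: List[List[int]]) -> float:
    """Fraction of pixels whose predicted class matches the target class."""
    total = sum(len(row) for row in target)
    correct = sum(p == t for pr, tr in zip(pred, target) for p, t in zip(pr, tr))
    return correct / total


scores = [
    [[0.9, 0.1], [0.2, 0.3]],  # class 0, e.g. background
    [[0.05, 0.8], [0.1, 0.6]],  # class 1
    [[0.05, 0.1], [0.7, 0.1]],  # class 2
]
mask = scores_to_mask(scores)
print(mask)
print(pixel_accuracy(mask, [[0, 1], [2, 0]]))
```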
Example 7: Style Transfer with Pystiche
Flash has a Style Transfer task for Neural Style Transfer (NST) with Pystiche.
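In the neural style transfer formulation of Gatys et al., style is represented by Gram matrices of a network's feature maps: for a layer with C channels, each flattened to N values, G[i][j] is the inner product of channels i and j. The standard-library sketch below computes one on a tiny made-up feature map. It illustrates the general technique rather than pystiche's code, and dividing by N is only one of several normalization conventions in use.

```python
from typing import List


def gram_matrix(features: List[List[float]], normalize: bool = True) -> List[List[float]]:
    """Gram matrix of a feature map given as C rows of N flattened values.

    G[i][j] measures how strongly channels i and j co-activate, regardless
    of where in the image the activations occur.
    """
    channels = len(features)
    n = len(features[0])
    gram = [
        [sum(features[i][k] * features[j][k] for k in range(n)) for j in range(channels)]
        for i in range(channels)
    ]
    if normalize:
        # Dividing by N is one common convention; some implementations also divide by C.
        gram = [[value / n for value in row] for row in gram]
    return gram


# Hypothetical 2-channel feature map with 3 spatial positions per channel.
features = [
    [1.0, 0.0, 2.0],
    [0.0, 1.0, 1.0],
]
print(gram_matrix(features, normalize=False))
```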
A general task
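Flash also comes with a general-purpose task, flash.Task, which wraps any PyTorch model together with a loss function and an optimizer, so it can handle a large share of deep learning problems.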
```python
import flash
from torch import nn, optim
from torch.utils.data import DataLoader, random_split
from torchvision import transforms, datasets

# model
model = nn.Sequential(
    nn.Flatten(),
    nn.Linear(28 * 28, 128),
    nn.ReLU(),
    nn.Linear(128, 10),
)

# data: the 60,000 MNIST training images, split into train and validation sets
dataset = datasets.MNIST("./data_folder", download=True, transform=transforms.ToTensor())
train, val = random_split(dataset, [55000, 5000])

# task
classifier = flash.Task(model, loss_fn=nn.functional.cross_entropy, optimizer=optim.Adam)

# train (DataLoader defaults to batch_size=1, so a larger batch is set explicitly)
flash.Trainer().fit(
    classifier,
    DataLoader(train, batch_size=64, shuffle=True),
    DataLoader(val, batch_size=64),
)
```
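The example above splits the work three ways: the model computes predictions, the task bundles it with a loss and an optimizer, and the Trainer only runs the loop. The standard-library sketch below mirrors that division of labour on a toy problem. `ToyTask`, `fit`, the one-weight linear model, and the finite-difference gradient are hypothetical stand-ins chosen so the sketch runs without PyTorch; they are not Flash's API, and plain gradient descent stands in for an optimizer such as Adam.

```python
from typing import Callable, List, Tuple


class ToyTask:
    """Hypothetical stand-in for a task: bundles a model, a loss and an update rule."""

    def __init__(
        self,
        model: Callable[[float, float], float],
        loss_fn: Callable[[float, float], float],
        learning_rate: float = 0.05,
        eps: float = 1e-6,
    ):
        self.model = model
        self.loss_fn = loss_fn
        self.learning_rate = learning_rate
        self.eps = eps
        self.weight = 0.0  # the single trainable parameter

    def loss(self, weight: float, x: float, y: float) -> float:
        return self.loss_fn(self.model(weight, x), y)

    def training_step(self, x: float, y: float) -> float:
        """One gradient-descent step; the gradient is a central finite difference."""
        loss = self.loss(self.weight, x, y)
        grad = (
            self.loss(self.weight + self.eps, x, y) - self.loss(self.weight - self.eps, x, y)
        ) / (2 * self.eps)
        self.weight -= self.learning_rate * grad
        return loss


def fit(task: ToyTask, data: List[Tuple[float, float]], epochs: int) -> None:
    """Minimal 'trainer': it only loops; the task defines what a step does."""
    for _ in range(epochs):
        for x, y in data:
            task.training_step(x, y)


# Toy data from y = 3 * x; the model is w * x with a squared-error loss.
task = ToyTask(model=lambda w, x: w * x, loss_fn=lambda pred, y: (pred - y) ** 2)
fit(task, data=[(1.0, 3.0), (2.0, 6.0)], epochs=50)
print(round(task.weight, 3))
```

Because the loop never looks inside the task, swapping in a different model or loss requires no change to `fit`, which is the same separation that lets a real Trainer run a task on different hardware.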
Infinitely customizable
Tasks can be built in just a few minutes because Flash is built on top of PyTorch Lightning LightningModules, which are highly extensible and let you train across GPUs, TPUs, etc. without any code changes.
```python
from typing import Callable, Mapping, Sequence, Type, Union

import torch
import torch.nn.functional as F
from torchmetrics import Accuracy

from flash.core.classification import ClassificationTask


class LinearClassifier(ClassificationTask):
    def __init__(
        self,
        num_inputs: int,
        num_classes: int,
        loss_fn: Callable = F.cross_entropy,
        optimizer: Type[torch.optim.Optimizer] = torch.optim.SGD,
        metrics: Union[Callable, Mapping, Sequence, None] = None,
        learning_rate: float = 1e-3,
    ):
        # Create the metric per instance instead of sharing one mutable default.
        if metrics is None:
            metrics = [Accuracy()]
        super().__init__(
            model=None,
            loss_fn=loss_fn,
            optimizer=optimizer,
            metrics=metrics,
            learning_rate=learning_rate,
        )
        self.save_hyperparameters()
        self.linear = torch.nn.Linear(num_inputs, num_classes)

    def forward(self, x):
        return self.linear(x)


classifier = LinearClassifier(128, 10)
...
```
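The `LinearClassifier` above produces raw class scores (logits); its default `F.cross_entropy` loss turns them into a softmax distribution and takes the negative log-probability of the true class, while an accuracy metric compares the highest-scoring class with the label. The standard-library sketch below spells out those two computations on made-up logits. It is an illustration of the math, not how PyTorch or torchmetrics implement it, and it handles one sample at a time, whereas `F.cross_entropy` averages over a batch by default.

```python
import math
from typing import List, Sequence


def cross_entropy(logits: Sequence[float], target: int) -> float:
    """-log softmax(logits)[target], computed stably via log-sum-exp."""
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[target]


def accuracy(batch_logits: List[Sequence[float]], targets: List[int]) -> float:
    """Fraction of samples whose highest-scoring class equals the target."""
    correct = 0
    for logits, target in zip(batch_logits, targets):
        predicted = max(range(len(logits)), key=lambda c: logits[c])
        correct += int(predicted == target)
    return correct / len(targets)


# Hypothetical logits for two samples and three classes.
batch = [[2.0, 0.5, -1.0], [0.1, 0.2, 3.0]]
print(cross_entropy(batch[0], target=0))
print(accuracy(batch, targets=[0, 1]))
```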
When you reach the limits of Flash's flexibility, you can seamlessly transition to PyTorch Lightning, which gives you full flexibility because it is simply organized PyTorch.
Contribute!
The Lightning + Flash team is hard at work building more tasks for common deep learning use cases, but we're looking for incredible contributors like you to submit new tasks!
Join our Slack and/or read our CONTRIBUTING guidelines to get help becoming a contributor!
Community
For help or questions, join our huge community on Slack!
Citations
We’re excited to continue the strong legacy of open-source software and have been inspired over the years by Caffe, Theano, Keras, PyTorch, torchbearer, and fast.ai. When/if a paper is written about this, we’ll be happy to cite these frameworks and the corresponding authors.
Flash leverages models from torchvision and timm for the vision tasks, huggingface/transformers for the text tasks, and pytorch-tabnet for the tabular tasks. It also supports self-supervised backbones from bolts.
License
Please observe the Apache 2.0 license that is listed in this repository. In addition, the Lightning framework is Patent Pending.