PyTorch Lightning Developer Blog

PyTorch Lightning is a lightweight machine learning framework that handles most of the engineering work, leaving you to focus on the science. Check it out: pytorchlightning.ai


Lightning Flash 0.3 — New Tasks, Visualization Tools, Data Pipeline, and Flash Registry API

PyTorch Lightning team
PyTorch Lightning Developer Blog
8 min read · May 24, 2021


8 New Flash Tasks

Lightning Flash is a library from the creators of PyTorch Lightning to enable quick baselining and experimentation with state-of-the-art models for popular Deep Learning tasks.

We are excited to announce the release of Flash v0.3, which focuses primarily on the design of a modular API that makes it easier for developers to contribute new tasks and extend existing ones.

In addition to that, we have included 8 new tasks across the Computer Vision and NLP domains, visualization tools to help with debugging, and an API to facilitate the use of existing pre-trained state-of-the-art Deep Learning models.

New Out-of-the-Box Flash Tasks

Flash now includes 10 tasks across Vision and NLP! Any task can easily be used for finetuning or predicting.

Computer Vision

Multi-label Image Classification: the ability to categorize an image into multiple classes is useful for applications such as categorizing movies into genres based on their posters, or detecting plant diseases.

Image source: https://www.comingsoon.net/movies/features/1157662-the-25-best-movie-posters-of-2020

Flash makes it easy to create a fine-tuning pipeline for multi-label image classification, using ImageClassifier with pre-trained state-of-the-art backbone models.
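What makes a classifier multi-label is mostly how its outputs are decoded. A common approach gives each class an independent sigmoid score and keeps every class above a threshold, rather than taking a single argmax over a softmax. The sketch below shows only that decoding step, in plain Python. The genre names, the 0.5 threshold, and the `decode_multilabel` helper are illustrative assumptions, not part of Flash's API.

```python
import math
from typing import List, Sequence

# Hypothetical label set for the movie-poster example.
GENRES = ["action", "comedy", "drama", "horror", "romance"]


def sigmoid(x: float) -> float:
    """Map a raw logit to an independent per-class probability."""
    return 1.0 / (1.0 + math.exp(-x))


def decode_multilabel(logits: Sequence[float], labels: Sequence[str],
                      threshold: float = 0.5) -> List[str]:
    """Return every label whose sigmoid score reaches the threshold.

    Unlike softmax + argmax, this can return several labels, or none at all.
    """
    return [label for logit, label in zip(logits, labels)
            if sigmoid(logit) >= threshold]


poster_logits = [2.1, -1.3, 0.4, -3.0, 1.7]  # made-up model outputs
print(decode_multilabel(poster_logits, GENRES))  # ['action', 'drama', 'romance']
```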

Object Detection: detecting objects in images is useful for applications like traffic control, license plate detection, or OCR.

Flash lets you finetune a custom ObjectDetector on the COCO dataset with a RetinaNet backbone model.

You can also use the Flash pre-trained ObjectDetector for predictions with just a few lines of code.

Video Classification: the task of producing a label for the actions identified in a given video. Flash includes a VideoClassifier that leverages PyTorchVideo.

Code showing the VideoClassifier with PyTorchVideo's pretrained X3D model.

Semantic Segmentation: classify each pixel or image region into one of a set of given classes.

Image source: https://i2.wp.com/syncedreview.com/wp-content/uploads/2019/12/image-9-1.png

Code showing SemanticSegmentation with torchvision's pretrained FCN model.

Style Transfer: transfer the style of an image from one domain to another.

Top left: The Starry Night by Vincent van Gogh, c. June 1889 | Top right: Stanford campus | Bottom: image generated by the Neural Style Transfer algorithm of Gatys et al.

We thank Phillip Meier (@pmeier), the main author of PyStiche, for contributing the StyleTransfer task.

Image Embedder: generate image vector encodings that can be used for tasks such as clustering, similarity search, or classification.

Image source: https://ai.googleblog.com/2016/12/open-sourcing-embedding-projector-tool.html

Code showing how to use the ImageEmbedder API to generate image embeddings.
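Once images are encoded as vectors, similarity search reduces to comparing vectors. The sketch below ranks a small gallery of embeddings by cosine similarity to a query. The embeddings are made-up toy vectors standing in for ImageEmbedder outputs, and `most_similar` is a hypothetical helper, not a Flash function.

```python
import math
from typing import Dict, List, Sequence, Tuple


def cosine_similarity(a: Sequence[float], b: Sequence[float]) -> float:
    """Cosine of the angle between two vectors (1.0 means same direction)."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


def most_similar(query: Sequence[float], gallery: Dict[str, Sequence[float]],
                 k: int = 2) -> List[Tuple[str, float]]:
    """Return the k gallery items closest to the query, best first."""
    scored = [(name, cosine_similarity(query, vec)) for name, vec in gallery.items()]
    scored.sort(key=lambda item: item[1], reverse=True)
    return scored[:k]


# Toy 4-dimensional "embeddings"; real image embeddings are much larger.
gallery = {
    "cat_1.jpg": [0.9, 0.1, 0.0, 0.2],
    "cat_2.jpg": [0.8, 0.2, 0.1, 0.1],
    "car_1.jpg": [0.0, 0.9, 0.8, 0.1],
}
query = [0.85, 0.15, 0.05, 0.15]
top = most_similar(query, gallery)
print([name for name, _ in top])  # ['cat_1.jpg', 'cat_2.jpg']
```

The same ranking step underlies clustering and nearest-neighbour classification. For large galleries, a dedicated vector index would replace the linear scan.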

NLP

Summarization: summarize text from a larger document or article into a short sentence or description.

Translation: translate text from one language to another, such as English to Romanian.

New Features

Model Hub Integrations

In this release, we integrated several model hubs, such as Hugging Face, TIMM, and PyTorchVideo, all accessible from Flash tasks. This lets you train, finetune, and run inference with thousands of SOTA models out of the box.

Task Visualization

We also added easy-to-use visualization callbacks to Flash tasks, which help you visually inspect the impact of different transforms on your data before launching training.

You can implement custom visualizations by subclassing the BaseVisualization callback: override any of the show_{preprocess_hook_name} methods to receive the associated data and visualize it.
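The `show_{preprocess_hook_name}` convention is essentially name-based dispatch. The callback records what each preprocessing hook produced, then calls the matching `show_*` method if the subclass defines one. The plain-Python sketch below mimics that pattern. `ToyVisualization`, `record`, and `PrintSamples` are hypothetical stand-ins. The real BaseVisualization callback receives its data from Flash's data pipeline rather than from manual calls.

```python
from typing import Any, Dict, List


class ToyVisualization:
    """Hypothetical stand-in for a visualization callback base class."""

    def __init__(self) -> None:
        self.batches: Dict[str, List[Any]] = {}

    def record(self, hook_name: str, data: Any) -> None:
        """Store a hook's output, then dispatch to show_<hook_name> if it exists."""
        self.batches.setdefault(hook_name, []).append(data)
        show = getattr(self, f"show_{hook_name}", None)
        if callable(show):
            show(data)  # only hooks with a matching show_* method are visualized


class PrintSamples(ToyVisualization):
    """Subclass that overrides only the hooks it wants to inspect."""

    def __init__(self) -> None:
        super().__init__()
        self.shown: List[str] = []

    def show_load_sample(self, data: Any) -> None:
        self.shown.append("load_sample")
        print("loaded sample:", data)

    def show_per_batch_transform(self, data: Any) -> None:
        self.shown.append("per_batch_transform")
        print("batch after transform:", data)


viz = PrintSamples()
viz.record("load_sample", {"input": [1, 2, 3], "target": 0})
viz.record("to_tensor_transform", [1, 2, 3])  # recorded, but no show_ method defined
viz.record("per_batch_transform", [[1, 2, 3]])
```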

Contribute Your Own Task

This release of Lightning Flash goes beyond finetuning, predicting, and visualizing out-of-the-box tasks. It includes APIs that give you the flexibility to extend and customize your Tasks, with little to no boilerplate.

You can now create a reusable and modular data processing pipeline, add visualization methods, and register model backbones for any task, making the Flash tasks a lot more structured!

Here are some example scenarios that we would love the community to help contribute as new tasks:

  • Visual Question Answering
  • Optical Character Recognition
  • Graph Classification
  • Speaker Recognition
  • Audio Classification
  • Face Detection
  • Coreference Resolution
  • Time Series Forecasting

For more details, check out the guide on how to contribute a task to Flash.

New Task API Design and Improvements

In this section, we go over the new improvements for data and task flexibility.

Data Processing in Flash

To make data processing in Flash more customizable and reusable, we created a few new components that establish a standard for the data pipeline.

Hooks

The Lightning Flash API, just like PyTorch Lightning, is built as a collection of hooks: methods you can override to customize behavior at different points of the model pipeline. This pipeline consists of 4 main routines: training, validation, testing, and predicting.

You can further customize any Flash hook for specific routines by adding a {train, val, test, predict} prefix to the hook's name.

For example, Flash DataSource has a load_data hook. Overriding it will modify the data loading behavior for all routines, but you can also override train_load_data to modify the data loading logic only in training.
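The routine-prefix rule can be pictured as a lookup that prefers the most specific hook. The sketch below resolves `train_load_data` first and falls back to `load_data` when no routine-specific override exists. The `resolve_hook` function and the `ToyDataSource`/`FolderSource` classes are hypothetical illustrations of the idea, not Flash internals.

```python
from typing import Callable


class ToyDataSource:
    """Hypothetical data source with one generic hook."""

    def load_data(self, data):
        return [f"generic:{x}" for x in data]


class FolderSource(ToyDataSource):
    # Overrides the hook for the training routine only.
    def train_load_data(self, data):
        return [f"train-only:{x}" for x in data]


ROUTINES = ("train", "val", "test", "predict")


def resolve_hook(obj, hook_name: str, routine: str) -> Callable:
    """Return '<routine>_<hook_name>' if defined, otherwise the generic hook."""
    if routine not in ROUTINES:
        raise ValueError(f"unknown routine: {routine}")
    specific = getattr(obj, f"{routine}_{hook_name}", None)
    return specific if callable(specific) else getattr(obj, hook_name)


source = FolderSource()
print(resolve_hook(source, "load_data", "train")(["a"]))  # ['train-only:a']
print(resolve_hook(source, "load_data", "val")(["a"]))    # ['generic:a']
```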

Data Loading

DataSource provides a hook-based API for creating datasets. It includes two hooks for data loading: load_data and load_sample.

load_data takes any input, such as a directory or a zip file, and returns a sequence or iterator of samples. Each of these samples is then passed to load_sample, which returns the loaded sample. We found it good practice for load_sample to return a dictionary with an input and a target.

In this example:

  • load_data takes a folder path and returns a list of input paths and their labels.
  • load_sample takes sample metadata as input, and returns a dictionary with a loaded object and its target.
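The two bullets above can be sketched without Flash as a small class with the same two hooks. This minimal sketch assumes a folder layout of `root/<label>/<file>`. `FolderDataSource` and the dictionary keys `"input"` and `"target"` are illustrative choices. For simplicity, `load_sample` reads raw bytes rather than decoding images.

```python
import os
import tempfile
from typing import Any, Dict, List, Tuple


class FolderDataSource:
    """Hypothetical data source following the load_data / load_sample pattern."""

    def load_data(self, folder: str) -> List[Tuple[str, str]]:
        """Return (file_path, label) pairs, using each sub-folder name as the label."""
        samples = []
        for label in sorted(os.listdir(folder)):
            class_dir = os.path.join(folder, label)
            if not os.path.isdir(class_dir):
                continue
            for name in sorted(os.listdir(class_dir)):
                samples.append((os.path.join(class_dir, name), label))
        return samples

    def load_sample(self, sample: Tuple[str, str]) -> Dict[str, Any]:
        """Load one file lazily and return it together with its target."""
        path, label = sample
        with open(path, "rb") as f:
            return {"input": f.read(), "target": label}


# Build a tiny throwaway folder: root/cat/a.txt and root/dog/b.txt
root = tempfile.mkdtemp()
for label, fname, content in [("cat", "a.txt", b"meow"), ("dog", "b.txt", b"woof")]:
    os.makedirs(os.path.join(root, label))
    with open(os.path.join(root, label, fname), "wb") as f:
        f.write(content)

source = FolderDataSource()
metadata = source.load_data(root)  # cheap: only paths and labels
loaded = [source.load_sample(s) for s in metadata]
print(loaded)
```

Keeping load_data cheap (paths and labels only) and deferring file reads to load_sample means that, in a real pipeline, the expensive loading can happen per sample inside data-loader workers.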

Data Transforms

The Preprocess object provides a series of hooks encapsulating all the data processing logic that should run before the data is passed to the model. Compared to a traditional PyTorch Dataset, decoupling this logic reduces the engineering overhead of running inference on raw data or deploying a model in a production environment.

The Preprocess holds the DataSource instances and the transforms for each routine.

In Flash, transforms are expected to be a dictionary mapping a hook name to a callable, for example a batched transform from Kornia.

Flash implements an internal mechanism to apply your transforms at the right time.

The pre_tensor_transform, to_tensor_transform, post_tensor_transform, collate and per_batch_transform transforms are injected directly into the PyTorch DataLoader as its collate_fn. This lets your transforms run in parallel worker processes when num_workers > 0.

The per_sample_transform_on_device, collate and per_batch_transform_on_device transforms will be applied after the data have been transferred to GPU or TPU.

Finally, the default_transforms hook exposes the default transforms associated with a Task's Preprocess.
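To make the ordering concrete, here is a framework-free sketch of how a dictionary of named transforms can be folded into a single `collate_fn`, so that the CPU-side steps run inside DataLoader worker processes. The transform functions, `build_collate_fn`, and the `DEFAULT_TRANSFORMS` dictionary (playing the role of default_transforms) are hypothetical. Only the hook names mirror the ones listed above. The `*_on_device` steps, which run after the batch reaches the accelerator, are omitted.

```python
from typing import Any, Callable, Dict, List


def identity(x: Any) -> Any:
    return x


def default_collate(samples: List[Any]) -> List[Any]:
    # Stand-in for stacking samples into a batch tensor.
    return list(samples)


# Hypothetical defaults; user-supplied transforms override these by hook name.
DEFAULT_TRANSFORMS: Dict[str, Callable] = {
    "pre_tensor_transform": lambda s: {**s, "input": s["input"].strip().lower()},
    # Stand-in for tensor conversion: characters become integer codes.
    "to_tensor_transform": lambda s: {**s, "input": [ord(c) for c in s["input"]]},
    "post_tensor_transform": identity,
    "collate": default_collate,
    "per_batch_transform": identity,
}


def build_collate_fn(transforms: Dict[str, Callable]) -> Callable[[List[Any]], Any]:
    """Chain the CPU-side hooks, in order, into one function usable as collate_fn."""
    t = {**DEFAULT_TRANSFORMS, **transforms}

    def collate_fn(samples: List[Any]) -> Any:
        # Per-sample hooks first, then collate, then the per-batch hook.
        per_sample = [
            t["post_tensor_transform"](t["to_tensor_transform"](t["pre_tensor_transform"](s)))
            for s in samples
        ]
        return t["per_batch_transform"](t["collate"](per_sample))

    return collate_fn


collate_fn = build_collate_fn({"per_batch_transform": lambda batch: batch[::-1]})
batch = collate_fn([{"input": " Hi ", "target": 1}, {"input": "OK", "target": 0}])
print(batch)
```

In a PyTorch script, a function like this would be passed as `DataLoader(dataset, batch_size=..., collate_fn=collate_fn, num_workers=4)`. Because the whole chain lives inside collate_fn, it runs in the worker processes.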

Predictions Transforms

Similar to Preprocess, the Postprocess provides a simple hook-based API to encapsulate your post-processing logic. The Postprocess hooks cover everything from model outputs to predictions export.

For example, in the Semantic Segmentation task, the shape of each image is stored as metadata, and the Postprocess uses that metadata to resize each semantic prediction back to the original image size.
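The resizing step can be sketched with nearest-neighbour interpolation on a plain 2-D list of class indices. This minimal sketch assumes each prediction carries its original `(height, width)` in a metadata dictionary. The `restore_size` name and the `"size"` key are hypothetical. Nearest-neighbour is used because blending class indices (as bilinear interpolation would) can produce labels that do not exist.

```python
from typing import Dict, List, Tuple

Mask = List[List[int]]


def resize_nearest(mask: Mask, out_h: int, out_w: int) -> Mask:
    """Nearest-neighbour resize of a 2-D grid of class indices."""
    in_h, in_w = len(mask), len(mask[0])
    return [
        [mask[r * in_h // out_h][c * in_w // out_w] for c in range(out_w)]
        for r in range(out_h)
    ]


def restore_size(prediction: Mask, metadata: Dict[str, Tuple[int, int]]) -> Mask:
    """Resize a predicted mask back to the original image size stored in metadata."""
    height, width = metadata["size"]
    return resize_nearest(prediction, height, width)


# A 2x2 prediction for an image that was originally 4x4.
pred = [[0, 1],
        [2, 3]]
print(restore_size(pred, {"size": (4, 4)}))
```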

Predictions Serializations

You can control the output format by overriding the Flash Serializer, which provides a method to convert model outputs (a tensor, after the Postprocess) to the desired prediction format.

For example, a serializer can output probabilities rather than logits.

You can change the serializer for your task, for example to such a probabilities serializer, by passing it to the model constructor or by setting the model's serializer property.
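A serializer is essentially a function applied to each model output. The sketch below implements a probabilities serializer in plain Python, using a numerically stable softmax. It also shows the two ways of attaching a serializer described above. `ProbabilitiesSerializer`, its `serialize` method, and `ToyModel` are hypothetical names chosen for illustration. In Flash you would subclass its Serializer class instead and work with tensors.

```python
import math
from typing import List, Sequence


class ProbabilitiesSerializer:
    """Hypothetical serializer that turns raw logits into class probabilities."""

    def serialize(self, logits: Sequence[float]) -> List[float]:
        # Subtract the max logit before exponentiating to avoid overflow.
        m = max(logits)
        exps = [math.exp(x - m) for x in logits]
        total = sum(exps)
        return [e / total for e in exps]


class ToyModel:
    """Hypothetical model whose serializer can be passed in or swapped later."""

    def __init__(self, serializer=None) -> None:
        self.serializer = serializer

    def predict(self, logits: Sequence[float]):
        # Without a serializer, raw outputs are returned unchanged.
        return self.serializer.serialize(logits) if self.serializer else list(logits)


model = ToyModel()                            # or: ToyModel(serializer=ProbabilitiesSerializer())
model.serializer = ProbabilitiesSerializer()  # swap the output format after construction
probs = model.predict([2.0, 1.0, 0.1])
print([round(p, 3) for p in probs])
```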

Read more on predictions in Flash here.

Flash Model Backbones

To make it easier to add different model backbones to Flash tasks, we added the Flash Registry: an internal key-value database of models that makes them accessible across the entire Flash codebase. Each Flash task can include several registries as static attributes.

For example, all timm models are registered within a TIMM_BACKBONES_REGISTRY.
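The registry idea itself is small: a mapping from names to factory functions, filled in by a decorator or a loop and queried by name. The sketch below is a simplified stand-in, not Flash's registry implementation. The `Registry` class, its methods, and the toy backbone factories are hypothetical. Registering a whole hub would presumably loop over its model names (for timm, `timm.list_models()` returns them) instead of the two hard-coded names used here.

```python
from functools import partial
from typing import Callable, Dict, List, Optional


class Registry:
    """Hypothetical key-value store mapping backbone names to factory functions."""

    def __init__(self, name: str) -> None:
        self.name = name
        self._functions: Dict[str, Callable] = {}

    def __call__(self, fn: Optional[Callable] = None, *, name: Optional[str] = None):
        """Register fn under `name` (defaults to fn.__name__); usable as a decorator."""
        def register(f: Callable) -> Callable:
            key = name or f.__name__
            if key in self._functions:
                raise KeyError(f"{key} is already registered in {self.name}")
            self._functions[key] = f
            return f
        return register(fn) if fn is not None else register

    def get(self, key: str) -> Callable:
        if key not in self._functions:
            raise KeyError(f"{key} not found in {self.name}: {self.available_keys()}")
        return self._functions[key]

    def available_keys(self) -> List[str]:
        return sorted(self._functions)


BACKBONES = Registry("backbones")


@BACKBONES
def tiny_cnn(pretrained: bool = False) -> dict:
    # Toy factory; a real one would build and return a model.
    return {"arch": "tiny_cnn", "pretrained": pretrained}


def hub_model(model_name: str, pretrained: bool = False) -> dict:
    # Toy stand-in for a hub's "create model by name" function.
    return {"arch": model_name, "pretrained": pretrained}


# Bulk registration, in the spirit of looping over an external hub's model names.
for model_name in ["resnet18", "efficientnet_b0"]:
    BACKBONES(partial(hub_model, model_name), name=model_name)

print(BACKBONES.available_keys())
print(BACKBONES.get("resnet18")(pretrained=True))
```

Keeping factories rather than built models in the registry means nothing is instantiated (or downloaded) until a task actually asks for a backbone by name.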

We are looking for community contributors to help integrate additional models and hubs with the new tasks in this release. Please reach out on GitHub or Slack if you are interested in helping out.

Next Steps

To learn more about Flash, read the docs.

In the next releases, we plan to expand our task offering to new domains and use cases. Feel free to open an issue or a PR with new tasks!

Join our Slack to get involved!
