Experiment with Billion-Parameter Models Faster using DeepSpeed and Meta Tensors

PyTorch’s Meta Tensors can save you huge amounts of time. Together with DeepSpeed and just a single line of code, PyTorch Lightning lets you train large, billion-parameter models even faster.

PyTorch Lightning team
PyTorch Lightning Developer Blog
4 min read · Apr 19, 2022


PyTorch Lightning is a lightweight PyTorch wrapper for high-performance AI research. PyTorch Lightning provides true flexibility by reducing the engineering boilerplate and resources required to implement state-of-the-art AI. Organizing PyTorch code with Lightning enables seamless training on multiple GPUs, TPUs, and CPUs, as well as the use of difficult-to-implement best practices such as model sharding and even 16-bit/bfloat16 precision, without changing your code.

This post will show how to use the recently introduced Meta Tensors in PyTorch to save huge amounts of time with just a single line of code when creating large, billion-parameter models in Lightning. We’ll also show you how to pair this with DeepSpeed to train billion-parameter models.

You can learn more about our DeepSpeed integration with PyTorch Lightning in this blog post.

What are PyTorch Meta Tensors?

Imagine a situation where you want to know the output shape of an operation. Typically, you would run the operation and then check the size of the output tensor.
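
A minimal sketch of what this looks like (the tensor sizes here are illustrative):

```python
import torch
import torch.nn.functional as F

# A regular CPU tensor: memory is allocated and the pooling is actually computed.
x = torch.randn(1, 3, 224, 224)
out = F.max_pool2d(x, kernel_size=2)
print(out.shape)  # torch.Size([1, 3, 112, 112])
```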

Applying max_pool2d filter on a CPU tensor

With Meta Tensors, you don’t have to compute the output to find the answer. Meta Tensors are just like normal tensors, except they carry no data. In PyTorch, meta is a device, just like cpu or cuda. You can even operate on Meta Tensors, saving you from computing the true output.
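
The same sketch on the meta device:

```python
import torch
import torch.nn.functional as F

# The same operation on the meta device: no memory is allocated and no pooling
# is computed, yet the output shape (and dtype) is still derived.
x = torch.randn(1, 3, 224, 224, device="meta")
out = F.max_pool2d(x, kernel_size=2)
print(out.shape)   # torch.Size([1, 3, 112, 112])
print(out.device)  # meta
```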

Applying max_pool2d filter on a Meta tensor, skipping instantiation!

How Can Meta Tensors Help Large Models?

Imagine a situation where you want to instantiate a very large model, e.g., a large transformer model. Typically, you would do something like this:
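
The code below is a hypothetical stand-in for such a model; the sizes are illustrative, not the exact 45B configuration from this post:

```python
import torch.nn as nn

# A hypothetical GPT-style stack. Every weight is allocated eagerly, in full,
# on every process that runs this code.
def build_gpt(n_layer=96, n_embd=12288, n_head=96, vocab_size=50257):
    block = nn.TransformerEncoderLayer(
        d_model=n_embd, nhead=n_head, dim_feedforward=4 * n_embd, batch_first=True
    )
    return nn.ModuleDict({
        "wte": nn.Embedding(vocab_size, n_embd),
        "blocks": nn.TransformerEncoder(block, num_layers=n_layer),
        "head": nn.Linear(n_embd, vocab_size, bias=False),
    })

model = build_gpt()  # at this scale: minutes of setup, hundreds of GB of RAM
```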

Instantiating a 45-billion-parameter GPT model takes considerable time and memory, especially when instantiating on all devices in multi-GPU or multi-node training. If the model is large enough, like the one above, it can even be infeasible to instantiate it in CPU RAM.

When training large models, we usually pair them with techniques such as DeepSpeed ZeRO Stage 3. Under the hood, each device holds only a portion of the model’s weights, so in the above example we waste precious time and memory instantiating the full set of weights on each device. DeepSpeed/Lightning does provide a context manager to assist; however, using it requires some model changes, which can be undesirable.
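
For reference, here is a sketch of the raw DeepSpeed context manager (Lightning surfaces it through the configure_sharded_model hook); build_gpt is the hypothetical model from above:

```python
import deepspeed

# DeepSpeed's sharded-initialization context: each submodule's weights are
# partitioned across ranks as it is constructed, so no single rank ever holds
# the full model. The catch: construction must happen inside the context,
# which is the model change mentioned above.
with deepspeed.zero.Init():
    model = build_gpt()
```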

[Figure: sharding model weights across devices with DeepSpeed ZeRO Stage 3. Image by Phoeby Naren]

With Lightning, we can leverage Meta Tensors in a single added line to alleviate the problem of memory constraints and reduce the instantiation time of our large GPT model.

To enable DeepSpeed ZeRO Stage 3, all it takes is an argument to the Trainer. To leverage Meta Tensors, all we need to do is wrap the construction of our LightningModule in init_meta_context, and our model is automatically instantiated on the meta device, skipping the CPU instantiation that would blow up our memory!

Using init_meta_context to skip unnecessary model initialization
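
A sketch of both pieces together, where MyGPT is a stand-in for your LightningModule (the import path for init_meta_context is the one used in Lightning 1.6, current when this post was written):

```python
from pytorch_lightning import Trainer
from pytorch_lightning.utilities.meta import init_meta_context

# Construct the model on the meta device: no real weights are allocated.
with init_meta_context():
    model = MyGPT()  # MyGPT is a hypothetical LightningModule

# A single Trainer argument enables DeepSpeed ZeRO Stage 3; Lightning then
# materializes the meta-device weights inside DeepSpeed's sharding context.
trainer = Trainer(accelerator="gpu", devices=8, strategy="deepspeed_stage_3", precision=16)
trainer.fit(model)
```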

With this context manager, we now dramatically speed up model initialization and reduce memory overhead as the model’s actual weights are not instantiated. Internally, Lightning will instantiate the model within DeepSpeed’s sharding context, automatically sharding the model across all devices without instantiating the entire model on each device.

[Benchmark chart, measured using minGPT]

Conclusion

With just a single line of code, you can save time starting experiments with large, billion-parameter models in Lightning.

See our extensive model parallel docs for more information!

About The Authors

Thomas Chaton, Sean Narenthiran, and Kushashwa Ravi Shrimali are research engineers at Grid AI Labs working on PyTorch Lightning.

Thomas Chaton focuses on deep geometric learning over graphs and manifolds to provide a cutting-edge technological differentiator.

Sean Narenthiran leads efforts in building advanced distributed capabilities to train massive, billion-parameter models.

Kushashwa Ravi Shrimali focuses on maintaining and improving the libraries in the Lightning ecosystem; he loves optimizing things, taking models to production, and contributing to open source.
