This repository has been archived by the owner on Feb 1, 2024. It is now read-only.

Lightning extension: Fairscale

General checks CI testing Build Status pre-commit.ci status

* The Read the Docs build is failing because it leads to the public domain, which requires the repo to be public too.

PyTorch has its own version of FSDP, which is upstreamed from the fairscale project. It was introduced in the v1.11.0 release, but using it with PyTorch v1.12 or later is recommended, and that is what Lightning supports.

Auto Wrapping

Model layers should be wrapped in FSDP in a nested way to reduce peak memory and to overlap communication with computation. The simplest way to do this is auto wrapping, which can serve as a drop-in replacement for DDP without changing the rest of the code. Unlike manual wrapping, it does not require you to wrap layers yourself.

When initializing optimizers inside the configure_optimizers hook, make sure to use self.trainer.model.parameters(); otherwise PyTorch will raise an error. This is required because with auto wrapping the model layers are sharded, so lightning_module.parameters() returns a generator with no parameters. This inconvenience will be addressed in the future.

from lightning_fairscale.strategies import DDPFullyShardedStrategy
from pytorch_lightning import Trainer
from pytorch_lightning.demos.boring_classes import BoringModel

model = BoringModel()
trainer = Trainer(accelerator="gpu", devices=4, strategy=DDPFullyShardedStrategy(), precision=16)
trainer.fit(model)

Read more here.

Manual Wrapping

Manual wrapping can be useful to explore complex sharding strategies by applying wrap selectively to some parts of the model. To activate parameter sharding with manual wrapping, you can wrap your model using the wrap function. Internally in Lightning, we enable a context manager around the configure_sharded_model function to make sure the wrap parameters are passed correctly.

When not using Fully Sharded, the wrap function is a no-op. This means that once the changes have been made, there is no need to remove them for other strategies.

wrap simply wraps the module with a Fully Sharded Parallel class with the correct parameters from the Lightning context manager.
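The no-op behavior can be checked in plain PyTorch, outside of Lightning. This is a small sketch that assumes PyTorch v1.12 or later with distributed support available; the variable names are illustrative only.

```python
import torch.nn as nn
from torch.distributed.fsdp.wrap import wrap

layer = nn.Linear(32, 32)

# No FSDP wrapping context is active here (Lightning only enables one around
# `configure_sharded_model` when a fully sharded strategy is used),
# so `wrap` hands back the very same module object.
maybe_wrapped = wrap(layer)
is_noop = maybe_wrapped is layer
```

Because the original module comes back unchanged, the same configure_sharded_model code can stay in place when switching to a non-sharded strategy.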

Here's an example that uses wrap to create your model:

import torch
import torch.nn as nn
from lightning_fairscale.strategies import DDPFullyShardedStrategy
from pytorch_lightning import Trainer, LightningModule
from torch.distributed.fsdp.wrap import wrap


class MyModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.linear_layer = nn.Linear(32, 32)
        self.block = nn.Sequential(nn.Linear(32, 32), nn.Linear(32, 32))

    def configure_sharded_model(self):
        # modules are sharded across processes
        # as soon as they are wrapped with `wrap`.
        # During the forward/backward passes, weights get synced across processes
        # and de-allocated once computation is complete, saving memory.

        # Wraps the layer in a Fully Sharded Wrapper automatically
        linear_layer = wrap(self.linear_layer)

        for i, layer in enumerate(self.block):
            self.block[i] = wrap(layer)

        self.model = nn.Sequential(linear_layer, nn.ReLU(), self.block)

    def configure_optimizers(self):
        return torch.optim.AdamW(self.model.parameters())


model = MyModel()
trainer = Trainer(accelerator="gpu", devices=4, strategy=DDPFullyShardedStrategy(), precision=16)
trainer.fit(model)

You can customize the strategy configuration by adjusting the arguments of pytorch_lightning.strategies.fully_sharded_native.DDPFullyShardedNativeStrategy and passing the resulting strategy object to the strategy argument of the Trainer.

from pytorch_lightning import Trainer
from lightning_fairscale.strategies import DDPFullyShardedStrategy

native_fsdp = DDPFullyShardedStrategy(cpu_offload=True)
trainer = Trainer(strategy=native_fsdp, accelerator="gpu", devices=4)

Check out this tutorial to learn more about native support.


Activation Checkpointing

Activation checkpointing reduces GPU memory usage by avoiding the storage of intermediate activation tensors in selected layers. The tradeoff is that the computation cost for the backpropagation increases as the dropped activations need to be recomputed.

Enable checkpointing on large layers (like Transformers) by providing the layer class/type to the strategy:

from pytorch_lightning import Trainer
from lightning_fairscale.strategies import DDPFullyShardedStrategy

# MyTransformerBlock is a placeholder for your own nn.Module subclass
fsdp = DDPFullyShardedStrategy(
    activation_checkpointing=MyTransformerBlock,  # or pass a list with multiple types
)
trainer = Trainer(strategy=fsdp, accelerator="gpu", devices=4)
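To make the tradeoff concrete, here is a plain PyTorch sketch of activation checkpointing using torch.utils.checkpoint. It is independent of the strategy above and is not meant to show how the strategy applies checkpointing internally; the small block and the tensor shapes are arbitrary assumptions.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

torch.manual_seed(0)
block = nn.Sequential(nn.Linear(32, 32), nn.ReLU(), nn.Linear(32, 32))
x = torch.randn(8, 32, requires_grad=True)

# Regular forward: intermediate activations inside `block` are kept for backward.
block(x).sum().backward()
grad_plain = x.grad.clone()

# Checkpointed forward: activations inside `block` are not stored and are
# recomputed during backward, trading extra compute for lower memory use.
x.grad = None
block.zero_grad()
checkpoint(block, x, use_reentrant=False).sum().backward()
grad_checkpointed = x.grad.clone()

# Both paths are expected to produce matching gradients.
same = torch.allclose(grad_plain, grad_checkpointed)
```

The gradients match in both cases; only the memory held between the forward and backward passes differs, which matters most for large layers.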

Tests / Docs notes