Skip to content

Commit

Permalink
[Tensor Parallelism] fix shape of output tensor proxy's given off fro…
Browse files Browse the repository at this point in the history
…m tensor parallel ops (Lightning-AI#534)
  • Loading branch information
crcrpar authored Jun 6, 2024
1 parent d3c06ff commit 0342223
Show file tree
Hide file tree
Showing 3 changed files with 125 additions and 21 deletions.
39 changes: 22 additions & 17 deletions thunder/distributed/tensor_parallel/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,35 +107,40 @@ def eligible_for_comm_optimization(self) -> bool:
return self._has_other_tensor_parallel

def __call__(self, bsym: BoundSymbol) -> VISIT_TYPE:
from thunder.core.prims import PrimIDs
from thunder.core.transforms import VISIT_TYPE
from thunder.core.trace import get_tracectx
from thunder.core.proxies import variableify

for t in bsym.flat_proxy_args:
if bsym.sym.id in {
PrimIDs.UNPACK_TRIVIAL,
PrimIDs.UNPACK_SEQUENCE,
PrimIDs.UNPACK_KEY,
PrimIDs.UNPACK_EMPTY_DICT,
}:
return VISIT_TYPE.NO_OP

pre_post_process: PrePostProcessInterface | None = self.bsym_to_prepostprocess.get(bsym, None)
new_bsym = bsym.from_bsym_swap_proxies(self.swap_map)
for t in new_bsym.flat_proxy_args:
self._maybe_other_tensor_parallel(t)

input_swap_map: dict[VariableInterface, ProxyInterface] = {}
pre_post_process: PrePostProcessInterface | None = None
if bsym in self.bsym_to_prepostprocess:
pre_post_process = self.bsym_to_prepostprocess[bsym]
orig_arg = bsym.flat_proxy_args[0]
if pre_post_process is not None:
orig_arg = new_bsym.flat_proxy_args[0]
new_arg, preprocess_artifacts = pre_post_process.preprocess(orig_arg)
if new_arg.name != orig_arg.name:
input_swap_map[variableify(orig_arg)] = new_arg

new_bsym = bsym.from_bsym_swap_proxies(self.swap_map, skip_output=True)
if pre_post_process is not None:
new_bsym = new_bsym.from_bsym_swap_proxies(input_swap_map)
new_bsym = new_bsym.from_bsym_swap_proxies({variableify(orig_arg): new_arg})
new_bsym = pre_post_process.maybe_modify_args_and_kwargs(new_bsym)
# note(crcrpar): This header seems to be lost in the extrace.
new_bsym.header = f"{pre_post_process.__class__.layer_type}"
trace = get_tracectx()
trace.scopes[-1].append(new_bsym)

new_out = new_bsym.sym(*new_bsym.args, **new_bsym.kwargs)

var_original_bsym_output = variableify(new_bsym.flat_proxy_outs[0])
if pre_post_process is not None:
y = bsym.flat_proxy_outs[0]
processed_y = pre_post_process.postprocess(y, preprocess_artifacts)
self.swap_map[variableify(y)] = processed_y
processed_y = pre_post_process.postprocess(new_out, preprocess_artifacts)
self.swap_map[var_original_bsym_output] = processed_y
else:
self.swap_map[var_original_bsym_output] = new_out

return VISIT_TYPE.REPLACE

Expand Down
50 changes: 50 additions & 0 deletions thunder/tests/distributed/modules.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
from __future__ import annotations
from typing import ClassVar, TYPE_CHECKING

import torch.nn as nn

from thunder.core import utils

if TYPE_CHECKING:
import torch


__all__ = [
"ParallelMLP",
]


class ParallelMLP(nn.Module):
"""Simplified version of Megatron/NeMo's ParallelMLP.
Ref: https://github.com/NVIDIA/NeMo/blob/95ca2f4/nemo/collections/nlp/modules/common/megatron/mlp.py#L61
"""

COLUMN_WISE: ClassVar[tuple[str]] = ("dense_h_to_4h",)
ROW_WISE: ClassVar[tuple[str]] = ("dense_4h_to_h",)

SUPPORTED_GELU_APPROX: ClassVar[tuple[str, str]] = ("none", "tanh")

def __init__(
self,
hidden_size: int,
ffn_hidden_size: int | None = None,
bias: bool = True,
gelu_approximate: str = "none",
) -> None:
utils.check(
gelu_approximate in ParallelMLP.SUPPORTED_GELU_APPROX,
lambda: f"Invalid {gelu_approximate}, supported are {ParallelMLP.SUPPORTED_GELU_APPROX}",
)
if ffn_hidden_size is None:
ffn_hidden_size = 4 * hidden_size

super().__init__()
self.dense_h_to_4h = nn.Linear(hidden_size, ffn_hidden_size, bias=bias)
self.dense_4h_to_h = nn.Linear(ffn_hidden_size, hidden_size, bias=bias)
self.gelu = nn.GELU(approximate=gelu_approximate)

def forward(self, x: torch.Tensor) -> torch.Tensor:
four_h = self.gelu(self.dense_h_to_4h(x))
h = self.dense_4h_to_h(four_h)
return h
57 changes: 53 additions & 4 deletions thunder/tests/distributed/test_tensor_parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@
from thunder.distributed import column_parallel, row_parallel
import thunder.executors
from thunder.tests.distributed.helper import ToyModel, DataParallelTestCase
from thunder.tests.distributed.modules import ParallelMLP

from torch.testing._internal import common_utils
from torch.distributed import distributed_c10d as c10d

_COL = "column"
_ROW = "row"
Expand All @@ -24,7 +24,7 @@ class TensorParallelTest(DataParallelTestCase):

@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="")
@common_utils.parametrize("name,bias", product(tuple(_name_to_transform.keys()), (True, False)))
def test_tensor_parallel_linear(self, name, bias):
def test_linear(self, name, bias):
device = torch.device("cuda", self.rank)
x = torch.randn(2, 12).to(device).requires_grad_()
x_ref = x.clone().detach().requires_grad_()
Expand Down Expand Up @@ -74,7 +74,7 @@ def test_tensor_parallel_linear(self, name, bias):

@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="")
@common_utils.parametrize("name", tuple(_name_to_transform.keys()))
def test_tensor_parallel_embedding(self, name):
def test_embedding(self, name):
num_embeddings = 128
embedding_dim = 32

Expand Down Expand Up @@ -130,7 +130,7 @@ def forward(self, x):

@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="")
@common_utils.parametrize("bias", (True, False))
def test_tensor_parallel_both_column_and_row(self, bias):
def test_both_column_and_row(self, bias):
num_embeddings = 128
embedding_dim = 32
n_hidden = 96
Expand Down Expand Up @@ -189,6 +189,55 @@ def forward(self, x):
grad = tp_model.get_parameter(param_fqn).grad
torch.testing.assert_close(actual=grad, expected=ref_grad, msg=msg)

@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="")
def test_parallel_mlp(self):
from thunder.distributed.prims import PrimIDs

sequence_length: int = 32
batch_size: int = 4
hidden_size: int = 128
ffn_hidden_size: int = 512
device = torch.device("cuda", self.rank)

ref_mlp = ParallelMLP(hidden_size=hidden_size, ffn_hidden_size=ffn_hidden_size).to(device)
ref_state_dict = ref_mlp.state_dict()
mlp = ParallelMLP(hidden_size=hidden_size, ffn_hidden_size=ffn_hidden_size).to(device)
mlp.load_state_dict(ref_state_dict)
tp_mlp = thunder.jit(mlp)
tp_mlp = column_parallel(tp_mlp, ParallelMLP.COLUMN_WISE)
tp_mlp = row_parallel(tp_mlp, ParallelMLP.ROW_WISE)

# See https://github.com/NVIDIA/NeMo/blob/95ca2f4/nemo/collections/nlp/modules/common/megatron/mlp.py#L221 for the input shape.
x_ref = torch.randn((sequence_length, batch_size, hidden_size), device=device, requires_grad=True)
x = x_ref.clone().detach().requires_grad_(True)

expected = ref_mlp(x_ref)
actual = tp_mlp(x)
torch.testing.assert_close(actual=actual, expected=expected)

grad = torch.rand_like(x_ref)
expected.backward(grad)
actual.backward(grad)
torch.testing.assert_close(actual=x.grad, expected=x_ref.grad)

tp_syncs = {PrimIDs.SYNCHRONIZE_TENSOR_PARALLEL_INPUT, PrimIDs.SYNCHRONIZE_TENSOR_PARALLEL_OUTPUT}
fwd_traces_with_tensor_parallel_syncs = list(
filter(
lambda trace: any(bsym.sym.id in tp_syncs for bsym in trace.bound_symbols),
thunder.last_traces(tp_mlp),
)
)

last_fwd_trace_with_tp_sync = fwd_traces_with_tensor_parallel_syncs[-1]
bsyms_of_tp_sync = tuple(
filter(lambda bsym: bsym.sym.id in tp_syncs, last_fwd_trace_with_tp_sync.bound_symbols)
)
msg = f"{bsyms_of_tp_sync=}"
# Two bsyms are supposed to be
# - preprocessing of column-wise parallel linear
# - postprocessing of row-wise parallel linear
self.assertEqual(len(bsyms_of_tp_sync), 2, msg=msg)


common_utils.instantiate_parametrized_tests(TensorParallelTest)

Expand Down

0 comments on commit 0342223

Please sign in to comment.