fsdp_v2 : fix weight sharing (Lightning-AI#490)
kshitij12345 authored May 30, 2024
1 parent e579ace commit ebe1326
Showing 2 changed files with 52 additions and 21 deletions.
18 changes: 18 additions & 0 deletions thunder/distributed/__init__.py
@@ -404,6 +404,10 @@ def fsdp_transform_module(
     # modify module
     sharded_params = {}
     device_adjustments = {}
+    # We use `shared_params` dictionary to track the shared parameters.
+    # Key to this dictionary is the original parameter from the user's Module.
+    # Values are the copied and sharded parameter for the thunder module and meta-data related to sharding.
+    shared_params = WeakTensorKeyDictionary()
     for module_name, _ in thunder_model._model.named_modules():
         submodule = thunder_model.get_submodule(module_name)
 
@@ -448,6 +452,14 @@ def fsdp_transform_module(
             tdist.broadcast(thunder_model.get_buffer(pn), src=broadcast_from, group=process_group, async_op=False)
 
         for pn, p in submodule.named_parameters(recurse=False, prefix=module_name):
+            # If there are shared params in the original user Module, we reuse the sharded copy created from the original parameter below.
+            # This way we re-create parameter sharing in thunder's copy of the Module.
+            if p in shared_params:
+                # Re-use the previous copy of this parameter.
+                thunder_model._overrides_parameters[pn] = shared_params[p]["param_copy"]
+                sharded_params[pn] = shared_params[p]["param_shard_meta"]
+                continue
+
             # if we don't have an override or it is just the original, do create a copy
             if thunder_model._overrides_parameters.get(pn, p) is p:
                 thunder_model._overrides_parameters[pn] = copy.copy(p)
@@ -459,6 +471,12 @@ def fsdp_transform_module(
             new_shape = thunder_model._overrides_parameters[pn].shape
             sharded_params[pn] = (old_shape, new_shape, thunder_model._overrides_parameters[pn].device)
 
+            # Track the original param and its corresponding copied shard and metadata.
+            shared_params[p] = {
+                "param_copy": thunder_model._overrides_parameters[pn],
+                "param_shard_meta": sharded_params[pn],
+            }
+
     early_transform_from_trace_to_fsdp_trace = FSDPTraceTransform(
         sharded_params=sharded_params,
         process_group=process_group,
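The hunks above re-create weight sharing by keying a WeakTensorKeyDictionary on the original parameter, so every parameter name that aliases the same tensor receives the same sharded copy. Below is a minimal sketch of that pattern outside thunder, assuming only PyTorch's torch.utils.weak.WeakTensorKeyDictionary; the helper make_sharded_copy is a hypothetical stand-in for the real copy-and-shard step, not a thunder API.

import copy
import torch
from torch.utils.weak import WeakTensorKeyDictionary


def make_sharded_copy(p: torch.nn.Parameter) -> torch.nn.Parameter:
    # Hypothetical stand-in: the real transform copies the parameter and then shards it.
    return copy.copy(p)


def build_overrides(named_params):
    overrides = {}  # parameter name -> copy used by the transformed module
    shared = WeakTensorKeyDictionary()  # original parameter -> its single copy
    for name, p in named_params:
        if p in shared:
            # A previous name already pointed at this tensor; reuse its copy so the
            # tie survives the transform instead of producing two independent shards.
            overrides[name] = shared[p]
            continue
        overrides[name] = make_sharded_copy(p)
        shared[p] = overrides[name]
    return overrides

For a module where fc1.weight is fc2.weight, build_overrides(model.named_parameters(remove_duplicate=False)) maps both names to the same copy object, mirroring what the transform records in thunder_model._overrides_parameters.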
55 changes: 34 additions & 21 deletions thunder/tests/distributed/test_ddp.py
@@ -1037,36 +1037,49 @@ def __init__(self) -> None:
             def forward(self, x):
                 return self.fc1(x) + self.fc2(x)
 
-        # Check `jit(fsdp(model))` works
-        with device:
-            model = Model()
-            x = torch.ones(4, 16)
+        def _test_model_output_and_gradients(model, x):
+            output = model(x)
+            with device:
+                grad_output = torch.ones_like(output)
+            output.backward(grad_output)
+            expected_shape = (4, 16)
 
-        model.fc1.weight = model.fc2.weight
+            assert output.shape == expected_shape, f"{output.shape=} - {expected_shape=}"
 
-        model = thunder.jit(thunder.distributed.fsdp(model), executors=["torch"])
+            # Verify that both params point to same grad tensor.
+            assert id(model.get_parameter("fc1.weight").grad) == id(model.get_parameter("fc2.weight").grad)
 
-        output = model(x)
-        with device:
-            grad_output = torch.ones_like(output)
-        output.backward(grad_output)
-        expected_shape = (4, 16)
+            # Verify that we accumulate the gradients for the shared parameter.
+            gathered_grad_shape = (model.get_parameter("fc1.weight").shape[0] * self.world_size,) + model.get_parameter(
+                "fc1.weight"
+            ).shape[1:]
+            with device:
+                actual_grad_gathered = torch.empty(gathered_grad_shape)
 
-        assert output.shape == expected_shape, f"{output.shape=} - {expected_shape=}"
+            tdist.all_gather_into_tensor(actual_grad_gathered, model.get_parameter("fc1.weight").grad)
 
-        # Verify that both params point to same grad tensor.
-        assert id(model.get_parameter("fc1.weight").grad) == id(model.get_parameter("fc2.weight").grad)
+            # Based on the forward, grad for both params is `(grad_output.T @ x)`. Multiplying by 2 as the grad will be accumulated.
+            expected_grad = 2 * (grad_output.T @ x)
+            torch.testing.assert_close(actual_grad_gathered, expected_grad)
 
-        # Verify that we accumulate the gradients for the shared parameter.
-        gathered_grad_shape = (model.fc1.weight.shape[0] * self.world_size,) + model.fc1.weight.shape[1:]
         with device:
-            actual_grad_gathered = torch.empty(gathered_grad_shape)
+            jit_fsdp_model = Model()
+            fsdp_jit_model = Model()
+            x = torch.ones(4, 16)
 
+        # Check `jit(fsdp(model))` works
+        jit_fsdp_model.fc1.weight = jit_fsdp_model.fc2.weight
+
+        jit_fsdp_model = thunder.jit(thunder.distributed.fsdp(jit_fsdp_model), executors=["torch"])
+
+        _test_model_output_and_gradients(jit_fsdp_model, x)
+
+        # Check `fsdp(jit(model))` works
+        fsdp_jit_model.fc1.weight = fsdp_jit_model.fc2.weight
 
-        tdist.all_gather_into_tensor(actual_grad_gathered, model.get_parameter("fc1.weight").grad)
+        fsdp_jit_model = thunder.distributed.fsdp(thunder.jit(fsdp_jit_model, executors=["torch"]))
 
-        # Based on the forward, grad for both params is `(grad_output.T @ x)`. Multiplying by 2 as the grad will be accumulated.
-        expected_grad = 2 * (grad_output.T @ x)
-        torch.testing.assert_close(actual_grad_gathered, expected_grad)
+        _test_model_output_and_gradients(fsdp_jit_model, x)
 
 
 common_utils.instantiate_parametrized_tests(CompileDDPTest)
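The expected gradient in the test follows from the forward self.fc1(x) + self.fc2(x) with tied weights: each branch contributes grad_output.T @ x, and autograd accumulates both contributions into the one shared tensor, giving 2 * (grad_output.T @ x). A small, thunder-free sketch of the same arithmetic, assuming a single process and no sharding:

import torch
import torch.nn as nn

fc1, fc2 = nn.Linear(16, 16), nn.Linear(16, 16)
fc2.weight = fc1.weight  # tie the weights, as the test does for Model

x = torch.ones(4, 16)
output = fc1(x) + fc2(x)
grad_output = torch.ones_like(output)
output.backward(grad_output)

# One shared parameter means one grad tensor that accumulates both branches.
assert fc1.weight.grad is fc2.weight.grad
torch.testing.assert_close(fc1.weight.grad, 2 * (grad_output.T @ x))

In the distributed test the same comparison is made against the all-gathered gradient shard, which is why the grad is gathered across world_size before checking it against expected_grad.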
