Commit

fix(gmm): change communicator.grad_hook to async (#371)
blankde authored Dec 10, 2024
1 parent f6c66bd commit 71c32c8
Showing 1 changed file with 29 additions and 9 deletions.
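
For the backward that goes through communicator.weight_hook/grad_hook (the last hunk below), the empty-grad_output path previously reduced the weight gradient with a blocking grad_hook call; the commit switches it to async_op=True and waits on the returned handle only after the (zero) input gradient has been built, so the reduction can overlap with that work. A minimal sketch of the pattern, with torch.distributed.all_reduce standing in for communicator.grad_hook (the helper name, signature, and process group below are illustrative assumptions, not the repository's API):

    import torch
    import torch.distributed as dist

    def empty_grad_backward(x, weight, needs_input_grad, group=None):
        # Sketch only: launch the weight-grad reduction first, build the (zero)
        # input gradient while the collective is in flight, then wait.
        grad_input, grad_weight, handle = None, None, None

        if needs_input_grad[1]:
            # Every rank still participates in the collective, even with no tokens.
            grad_weight = torch.zeros_like(weight)
            handle = dist.all_reduce(grad_weight, group=group, async_op=True)

        if needs_input_grad[0]:
            # This overlaps with the in-flight reduction above.
            grad_input = torch.zeros_like(x)

        if handle is not None:
            handle.wait()  # synchronize only once the overlapping work is done

        return grad_input, grad_weight
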
38 changes: 29 additions & 9 deletions internlm/model/modules/linear.py
@@ -344,6 +344,9 @@ def forward(
         ctx.compute_weight_gradient = weight.requires_grad
         ctx.backend = backend

+        saved_x = None if ctx.compute_weight_gradient is False else x
+        ctx.save_for_backward(saved_x, weight, batch_sizes)
+
         if torch.is_autocast_enabled():
             x = x.to(dtype=torch.get_autocast_gpu_dtype())
         x = x.contiguous()
@@ -358,8 +361,7 @@

         output = torch.matmul(x, weight)

-        saved_x = None if ctx.compute_weight_gradient is False else x
-        ctx.save_for_backward(saved_x, weight, batch_sizes)
+        assert len(output.shape) == len(x.shape)

         return output

@@ -372,6 +374,14 @@ def backward(ctx, grad_output):
         x, weight, batch_sizes = ctx.saved_tensors
         grad_input, grad_weight = None, None

+        if grad_output.numel() == 0:
+            if ctx.needs_input_grad[1]:
+                grad_weight = torch.zeros_like(weight)
+            if ctx.needs_input_grad[0]:
+                grad_input = torch.zeros_like(x)
+
+            return grad_input, grad_weight, None, None, None, None, None
+
         if ctx.needs_input_grad[1]:
             assert ctx.compute_weight_gradient
             if backend == "gmm":
@@ -450,6 +460,8 @@ def forward(
         saved_x = None if ctx.compute_weight_gradient is False else x
         ctx.save_for_backward(saved_x, weight, batch_sizes)

+        assert len(output.shape) == len(x.shape)
+
         return output

     @staticmethod
@@ -461,20 +473,28 @@ def backward(ctx, grad_output):
         backend = ctx.backend
         full_weight_shape = ctx.full_weight_shape

-        grad_output = grad_output.contiguous()
-
-        total_weight = communicator.weight_hook(weight, module=module)
-        total_weight = total_weight.reshape(full_weight_shape)
-        grad_input, grad_weight = None, None
         if grad_output.numel() == 0:
+            if ctx.needs_input_grad[1]:
+                total_weight_shape = torch.Size(
+                    (full_weight_shape.numel() // full_weight_shape[-1], full_weight_shape[-1])
+                )
+                grad_weight = torch.zeros(total_weight_shape, dtype=weight.dtype, device=weight.device)
+                grad_weight, grad_weight_sync = communicator.grad_hook(
+                    grad_weight, async_op=True, module=module, is_bias=False
+                )
             if ctx.needs_input_grad[0]:
                 grad_input = torch.zeros_like(x)
             if ctx.needs_input_grad[1]:
-                grad_weight = torch.zeros_like(total_weight).reshape(-1, full_weight_shape[-1])
-                grad_weight, _ = communicator.grad_hook(grad_weight, async_op=False, module=module, is_bias=False)
+                grad_weight_sync.wait()

             return grad_input, grad_weight, None, None, None, None, None

+        grad_output = grad_output.contiguous()
+
+        total_weight = communicator.weight_hook(weight, module=module)
+        total_weight = total_weight.reshape(full_weight_shape)
+        grad_input, grad_weight = None, None
+
         if ctx.needs_input_grad[1]:
             assert ctx.compute_weight_gradient
             if backend == "gmm":
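
The first backward hunk above adds the same kind of early return for an empty grad_output (e.g. an expert group that received no tokens): zero gradients are returned without entering the gmm/matmul path. A compact sketch of that shape of autograd.Function backward, assuming 2-D x and using a plain matmul in place of the grouped GEMM (all names here are illustrative, not the repository's):

    import torch

    class GroupedLinearSketch(torch.autograd.Function):
        # Illustrative only: the real function dispatches to a grouped GEMM per expert.

        @staticmethod
        def forward(ctx, x, weight):
            ctx.save_for_backward(x, weight)
            return torch.matmul(x, weight)

        @staticmethod
        def backward(ctx, grad_output):
            x, weight = ctx.saved_tensors
            grad_input, grad_weight = None, None

            if grad_output.numel() == 0:
                # No tokens in this group: return zero grads without touching
                # the (empty) matmul path, mirroring the early return above.
                if ctx.needs_input_grad[1]:
                    grad_weight = torch.zeros_like(weight)
                if ctx.needs_input_grad[0]:
                    grad_input = torch.zeros_like(x)
                return grad_input, grad_weight

            grad_input = torch.matmul(grad_output, weight.t())   # (n, k)
            grad_weight = torch.matmul(x.t(), grad_output)       # (k, m)
            return grad_input, grad_weight
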
