Skip to content

Commit

Permalink
merge branch
Browse files Browse the repository at this point in the history
  • Loading branch information
inkcherry committed Nov 14, 2024
1 parent 62baaec commit 8516213
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions megatron/core/tensor_parallel/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -377,17 +377,17 @@ def backward(ctx, grad_output):
ctx.bwd_stream.wait_stream(get_accelerator().current_stream())
with get_accelerator().stream(ctx.bwd_stream):
WeightGradStore.put(total_input, grad_output, weight, gradientUpdateFunction)
ctx.bwd_stream.activation_buffer_list = [total_input, grad_output]
grad_weight = None
if args.enable_zbh1_pipeline:
from megatron.core.tensor_parallel.weight_grad_store import WeightGradStore
elif args.enable_zbh1_pipeline:
WeightGradStore.put(total_input, grad_output, weight, gradientUpdateFunction)
grad_weight = None
else:
grad_weight = grad_output.t().matmul(total_input)

grad_bias = grad_output.sum(dim=0) if use_bias else None

if ctx.bwd_stream is not None:
total_input.record_stream(ctx.bwd_stream)
grad_output.record_stream(ctx.bwd_stream)
if ctx.sequence_parallel:
handle.wait()
return sub_grad_input, grad_weight, grad_bias, None, None, None
Expand Down

0 comments on commit 8516213

Please sign in to comment.