diff --git a/megatron/core/tensor_parallel/layers.py b/megatron/core/tensor_parallel/layers.py
index a9e7beec53..58e1d1f976 100644
--- a/megatron/core/tensor_parallel/layers.py
+++ b/megatron/core/tensor_parallel/layers.py
@@ -377,17 +377,17 @@ def backward(ctx, grad_output):
             ctx.bwd_stream.wait_stream(get_accelerator().current_stream())
             with get_accelerator().stream(ctx.bwd_stream):
                 WeightGradStore.put(total_input, grad_output, weight, gradientUpdateFunction)
-                ctx.bwd_stream.activation_buffer_list = [total_input, grad_output]
             grad_weight = None
-        if args.enable_zbh1_pipeline:
-            from megatron.core.tensor_parallel.weight_grad_store import WeightGradStore
+        elif args.enable_zbh1_pipeline:
             WeightGradStore.put(total_input, grad_output, weight, gradientUpdateFunction)
             grad_weight = None
         else:
             grad_weight = grad_output.t().matmul(total_input)
         grad_bias = grad_output.sum(dim=0) if use_bias else None
-
+        if ctx.bwd_stream is not None:
+            total_input.record_stream(ctx.bwd_stream)
+            grad_output.record_stream(ctx.bwd_stream)
         if ctx.sequence_parallel:
             handle.wait()
         return sub_grad_input, grad_weight, grad_bias, None, None, None
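
The diff makes two changes: the second `WeightGradStore.put` branch becomes an `elif`, so the gradient is no longer enqueued twice when both paths are enabled, and the manual keep-alive (`ctx.bwd_stream.activation_buffer_list`) is replaced with `Tensor.record_stream()`. The sketch below (not code from this PR; `side_stream_matmul` is a hypothetical helper, and it assumes a CUDA device with standard PyTorch stream semantics) illustrates the pattern the new code relies on: when a tensor produced on the current stream is consumed on a side stream, `record_stream()` tells PyTorch's caching allocator not to reuse that memory until the side stream's pending work completes, so no explicit Python reference needs to be held.

```python
import torch

def side_stream_matmul(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Compute a.t() @ b on a side stream, mirroring the diff's pattern."""
    side = torch.cuda.Stream()
    # Order the side stream after the producer of a and b.
    side.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(side):
        out = a.t().matmul(b)
    # Without these calls, dropping the last Python reference to a or b
    # could let the caching allocator hand their memory to a new tensor
    # while the matmul is still running on `side`.
    a.record_stream(side)
    b.record_stream(side)
    # Rejoin the default stream before the result is consumed.
    torch.cuda.current_stream().wait_stream(side)
    return out

# Usage: shapes follow grad_output.t().matmul(total_input) in the diff.
grad_output = torch.randn(128, 64, device="cuda")
total_input = torch.randn(128, 32, device="cuda")
grad_weight = side_stream_matmul(grad_output, total_input)  # (64, 32)
```

Compared with stashing the tensors in a list on the stream object, `record_stream()` lets the allocator free the memory as soon as the side stream catches up, rather than holding it until the list itself is cleared.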