From 59835c7d1fa472458485499a950fc98b737a2411 Mon Sep 17 00:00:00 2001
From: Yu Yao <54727607+yaoyu-33@users.noreply.github.com>
Date: Tue, 21 Jan 2025 09:15:14 -0800
Subject: [PATCH] Add CP support to Neva in NeMo2 (#11850)

* api updates and fixes
Signed-off-by: yaoyu-33

* Apply isort and black reformatting
Signed-off-by: yaoyu-33

* fix
Signed-off-by: yaoyu-33

* fix arg
Signed-off-by: yaoyu-33

* update seq packing in mock ds
Signed-off-by: yaoyu-33

* save
Signed-off-by: yaoyu-33

* update preprocess_data
Signed-off-by: yaoyu-33

* update seq packing
Signed-off-by: yaoyu-33

* Apply isort and black reformatting
Signed-off-by: yaoyu-33

* fix sp
Signed-off-by: yaoyu-33

* save
Signed-off-by: yaoyu-33

* fix seq packing
Signed-off-by: yaoyu-33

* add truncation and padding
Signed-off-by: yaoyu-33

* Apply isort and black reformatting
Signed-off-by: yaoyu-33

* Fix issues
Signed-off-by: yaoyu-33

* change LLaVATemplateConfig variables to class variables

* change to use field with default attributes

* Apply isort and black reformatting
Signed-off-by: yashaswikarnati

* Apply isort and black reformatting
Signed-off-by: yaoyu-33

* Initial support for CP

* Add seq packing option in energon
Signed-off-by: yaoyu-33

* Fix energon conversation
Signed-off-by: yaoyu-33

* add energon option in neva training script
Signed-off-by: yaoyu-33

* Apply isort and black reformatting
Signed-off-by: yaoyu-33

* Apply isort and black reformatting
Signed-off-by: parthmannan

* Improvements

* add ci test for packed seq
Signed-off-by: yaoyu-33

* Fix for PP+CP

* Max seq len fix

* fix mock dataset seq packing
Signed-off-by: yaoyu-33

* Apply isort and black reformatting
Signed-off-by: yaoyu-33

* fix mock dataset seq packing
Signed-off-by: yaoyu-33

* Apply isort and black reformatting
Signed-off-by: yaoyu-33

* fix lint and update seq pack func
Signed-off-by: yaoyu-33

* fix energon module
Signed-off-by: yaoyu-33

* Apply isort and black reformatting
Signed-off-by: yaoyu-33

* fix comments
Signed-off-by: yaoyu-33

* Apply isort and black reformatting
Signed-off-by: yaoyu-33

* address lightning issues
Signed-off-by: yaoyu-33

* Apply isort and black reformatting
Signed-off-by: yaoyu-33

* Update sequence_packing.py
Signed-off-by: Yu Yao <54727607+yaoyu-33@users.noreply.github.com>

* update energon requirements
Signed-off-by: yaoyu-33

* Fix for energon update
Signed-off-by: yaoyu-33

* fix for test
Signed-off-by: yaoyu-33

* Apply isort and black reformatting
Signed-off-by: yaoyu-33

* revert overlap config change
Signed-off-by: yaoyu-33

---------

Signed-off-by: yaoyu-33
Signed-off-by: yaoyu-33
Signed-off-by: yashaswikarnati
Signed-off-by: parthmannan
Signed-off-by: Yu Yao <54727607+yaoyu-33@users.noreply.github.com>
Co-authored-by: yaoyu-33
Co-authored-by: ykarnati
Co-authored-by: yashaswikarnati
Co-authored-by: Parth Mannan
Co-authored-by: parthmannan
Co-authored-by: Parth Mannan
---
 nemo/collections/vlm/neva/model/base.py | 160 ++++++++++++++++++++----
 scripts/vlm/neva_finetune.py            |   2 +
 2 files changed, 135 insertions(+), 27 deletions(-)

diff --git a/nemo/collections/vlm/neva/model/base.py b/nemo/collections/vlm/neva/model/base.py
index 8cead72b4832..d7ac0788581d 100644
--- a/nemo/collections/vlm/neva/model/base.py
+++ b/nemo/collections/vlm/neva/model/base.py
@@ -132,6 +132,11 @@ def neva_data_step(dataloader_iter) -> Dict[str, torch.Tensor]:
             if value is not None:
                 setattr(packed_seq_params, attr, value.cuda(non_blocking=True))
         _batch["packed_seq_params"] = packed_seq_params
+    if ps.get_context_parallel_world_size() > 1:
+        num_valid_tokens_in_ub = None
+        if "loss_mask" in _batch and _batch["loss_mask"] is not None:
+            num_valid_tokens_in_ub = _batch["loss_mask"].sum()
+        _batch["num_valid_tokens_in_ub"] = num_valid_tokens_in_ub
 
     return _batch
 
@@ -381,6 +386,59 @@ def forward(
         return super().forward(x, attention_mask)
 
 
+class _get_data_on_this_cp_rank(torch.autograd.Function):
+    """Performs sharding for Context Parallelism in THD format
+
+    In the forward pass, indices are selected for each CP rank and remaining tokens are dropped.
+    In the backward pass, this class takes care of managing gradients for dropped tokens on each
+    CP rank.
+    """
+
+    @staticmethod
+    # def forward(ctx, decoder_embeddings, labels, loss_mask, packed_seq_params):
+    def forward(ctx, batch, packed_seq_params):
+        cp_size = ps.get_context_parallel_world_size()
+        if cp_size > 1:
+            try:
+                import transformer_engine_torch as tex
+            except ModuleNotFoundError as e:
+                logging.error(
+                    "Please update Transformer Engine to >= 1.10 to use \
+                        Context Parallel with THD format data"
+                )
+                raise e
+            cp_rank = ps.get_context_parallel_rank()
+            for key, data in batch.items():
+                index = tex.thd_get_partitioned_indices(
+                    packed_seq_params.cu_seqlens_q_padded, data.size(1), cp_size, cp_rank
+                )
+                if key == "combined_embeddings":
+                    ctx.decoder_emb_index = index
+                    ctx.decoder_emb_seqlen = data.size(1)
+                batch[key] = data.index_select(1, index)
+
+        return batch
+
+    @staticmethod
+    def backward(ctx, grad_out, grad_label, grad_loss):
+        seqlen = ctx.decoder_emb_seqlen
+        index = ctx.decoder_emb_index
+        assert grad_out.size(1) == index.size(
+            0
+        ), f"Shape mismatch in incoming gradient {grad_out.shape} and \
+            index from THD CP sharding {index.shape}"
+        grad_in = torch.zeros(
+            grad_out.size(0),
+            seqlen,
+            *grad_out.size()[2:],
+            dtype=grad_out.dtype,
+            device=grad_out.device,
+        )
+        grad_in[:, ctx.decoder_emb_index, :] = grad_out
+
+        return (grad_in, None, None, None)
+
+
 class MCoreNevaModel(MCoreLLaVAModel):
     def __init__(
         self,
@@ -604,6 +662,13 @@ def forward(
             packed_seq_params,
         )  # [combined_seq_len, b, h_language], [b, combined_seq_len], [b, combined_seq_len]
 
+        if self.context_parallel_lm > 1 or self.sequence_parallel_lm:
+            combined_embeddings, final_labels, final_loss_mask, packed_seq_params = (
+                self._process_embedding_token_parallel(
+                    combined_embeddings, final_labels, final_loss_mask, packed_seq_params
+                )
+            )
+
         output = self.language_model(
             input_ids=None,
             position_ids=None,
@@ -850,7 +915,9 @@ def _preprocess_data(
             final_loss_mask = final_loss_mask[:, : self._language_max_sequence_length]
 
         if final_embedding is not None:
-            final_embedding = final_embedding.transpose(1, 0).contiguous()
+            if self.context_parallel_lm == 1:
+                # Transpose to [s,b,h] if not using CP or not using packed_sequence/THD format
+                final_embedding = final_embedding.transpose(1, 0).contiguous()
             # Truncate if exceeding the language model's max sequence length.
             if final_embedding.shape[0] > self._language_max_sequence_length:
                 final_embedding = final_embedding[: self._language_max_sequence_length]
@@ -864,34 +931,73 @@ def _preprocess_data(
                 packed_seq_params.cu_seqlens_q[-1] >= packed_seq_params.cu_seqlens_q[-2]
             ), "with packed sequence, the truncation can only truncate on the last sequence."
 
-        if self.sequence_parallel_lm and not packed_sequence:
-            # Create an attention mask. This ensures correct computation.
-            # This is done even when no padding was done as we set mask_type to
-            # 'padding' or 'padding_causal' when using SP.
-            if attention_mask is None:
-                # Create base attention mask with original seq len to indicate valid tokens
-                attention_mask = (
-                    torch.ones(
-                        (
-                            final_embedding.shape[1],
-                            final_embedding.shape[0] - sp_padding_needed,
-                        ),
-                        device=final_embedding.device,
-                    )
-                    .unsqueeze(1)
-                    .unsqueeze(1)
-                )  # [b, 1, 1, final seq len - sp_padding_needed]
-            if sp_padding_needed > 0:
-                # Add the padding portion of the mask
-                attention_mask = torch.nn.functional.pad(attention_mask, (0, sp_padding_needed))
-
-            # Attention mask True/False meaning flipped in 1.7.0
-            attention_mask = attention_mask < 0.5
-        if self.sequence_parallel_lm:
-            final_embedding = tensor_parallel.scatter_to_sequence_parallel_region(final_embedding)
-
         return final_embedding, final_labels, final_loss_mask, attention_mask
 
+    def _process_embedding_token_parallel(self, combined_embeddings, new_labels, new_loss_mask, packed_seq_params):
+        """Processes the input data for model parallelism support."""
+
+        # No pre or post processing needed with PP middle chunks.
+        if not self.pre_process and not self.post_process:
+            return combined_embeddings, new_labels, new_loss_mask, packed_seq_params
+
+        if self.pre_process:
+            if self.context_parallel_lm > 1 and self.sequence_parallel_lm:
+                shard_factor = self.tensor_model_parallel_size_lm * self.context_parallel_lm * 2
+                seq_dim = 1
+            elif self.context_parallel_lm > 1:
+                shard_factor = self.context_parallel_lm * 2
+                seq_dim = 1
+            elif self.sequence_parallel_lm:
+                shard_factor = self.tensor_model_parallel_size_lm
+                seq_dim = 0
+
+            assert (
+                combined_embeddings.shape[seq_dim] % shard_factor == 0
+            ), f"Sequence length should be divisible by {shard_factor} for \
+                Sequence/Context parallelism"
+            if self.sequence_parallel_lm and self.tp_comm_overlap_lm:
+                assert (
+                    combined_embeddings.shape[seq_dim] == self._language_max_sequence_length
+                ), f"TP Comm overlap either requires Vision+Text token length \
+                    == language_max_sequence_length"
+
+        if self.context_parallel_lm > 1:
+            batch = dict()
+            if self.pre_process:
+                batch.update(
+                    {
+                        "combined_embeddings": combined_embeddings,
+                    }
+                )
+            if self.post_process:
+                batch.update(
+                    {
+                        "new_labels": new_labels,
+                        "new_loss_mask": new_loss_mask,
+                    }
+                )
+            # Distribute sequence across CP ranks
+            if packed_seq_params is None or packed_seq_params.qkv_format == 'sbhd':
+                from megatron.training.utils import get_batch_on_this_cp_rank
+
+                batch = get_batch_on_this_cp_rank(batch)
+            else:
+                batch = _get_data_on_this_cp_rank.apply(batch, packed_seq_params)
+
+            if self.pre_process:
+                combined_embeddings = batch["combined_embeddings"]  # [B, S/CP, H]
+                combined_embeddings = combined_embeddings.transpose(1, 0).contiguous()  # [B,S/CP,H] -> [S/CP,B,H]
+            if self.post_process:
+                new_labels = batch["new_labels"]
+                new_loss_mask = batch["new_loss_mask"]
+
+        if self.sequence_parallel_lm and self.pre_process:
+            combined_embeddings = tensor_parallel.scatter_to_sequence_parallel_region(
+                combined_embeddings
+            )  # [S/(CP*TP),B,H]
+
+        return combined_embeddings, new_labels, new_loss_mask, packed_seq_params
+
 
 class NevaModel(L.LightningModule, io.IOMixin, io.ConnectorMixin, fn.FNMixin):
     def __init__(
diff --git a/scripts/vlm/neva_finetune.py b/scripts/vlm/neva_finetune.py
index 3bf0084ea60d..81749f8ef03e 100644
--- a/scripts/vlm/neva_finetune.py
+++ b/scripts/vlm/neva_finetune.py
@@ -153,6 +153,7 @@ def main(args):
         tensor_model_parallel_size=args.tp_size,
         pipeline_model_parallel_size=args.pp_size,
         encoder_pipeline_model_parallel_size=args.encoder_pp_size,
+        context_parallel_size=args.cp_size,
         pipeline_dtype=torch.bfloat16,
         sequence_parallel=True,
         ddp=DistributedDataParallelConfig(
@@ -271,6 +272,7 @@ def main(args):
     parser.add_argument("--max_steps", type=int, required=False, default=5190)
     parser.add_argument("--tp_size", type=int, required=False, default=1)
     parser.add_argument("--pp_size", type=int, required=False, default=1)
+    parser.add_argument("--cp_size", type=int, required=False, default=1)
     parser.add_argument("--encoder_pp_size", type=int, required=False, default=0)
     parser.add_argument("--projector_type", type=str, required=False, default="mcore_mlp")
     parser.add_argument("--name", type=str, required=False, default="neva_pretrain")
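
For readers of the SBHD branch above: Megatron's get_batch_on_this_cp_rank implements the usual load-balanced context-parallel split, cutting each sequence into 2 * cp_size chunks along the sequence dimension and keeping the chunk pair (cp_rank, 2 * cp_size - 1 - cp_rank) on each rank so that causal-attention work stays balanced; the THD branch obtains the equivalent split per packed sub-sequence via tex.thd_get_partitioned_indices. The following is a minimal pure-PyTorch sketch of that split for illustration only, not code from this patch; shard_seq_for_cp_rank is a hypothetical helper name.

import torch

def shard_seq_for_cp_rank(x: torch.Tensor, cp_size: int, cp_rank: int, seq_dim: int = 1) -> torch.Tensor:
    # Illustration only (assumes the standard Megatron load-balanced CP split):
    # cut the sequence dimension into 2 * cp_size chunks and keep chunks
    # cp_rank and (2 * cp_size - 1 - cp_rank) on this rank.
    seq_len = x.size(seq_dim)
    assert seq_len % (2 * cp_size) == 0, "sequence length must be divisible by 2 * cp_size"
    chunked = x.view(
        *x.shape[:seq_dim], 2 * cp_size, seq_len // (2 * cp_size), *x.shape[seq_dim + 1 :]
    )
    index = torch.tensor([cp_rank, 2 * cp_size - 1 - cp_rank], device=x.device)
    selected = chunked.index_select(seq_dim, index)
    return selected.view(*x.shape[:seq_dim], -1, *x.shape[seq_dim + 1 :])

if __name__ == "__main__":
    # A [B=2, S=16, H=4] embedding sharded for cp_size=2: each rank keeps S/CP = 8 tokens.
    emb = torch.arange(2 * 16 * 4, dtype=torch.float32).view(2, 16, 4)
    for rank in range(2):
        print(rank, shard_seq_for_cp_rank(emb, cp_size=2, cp_rank=rank).shape)  # torch.Size([2, 8, 4])

Across all ranks the 2 * cp_size chunks are covered exactly once, which is why the sequence length must be divisible by 2 * cp_size (or by TP * CP * 2 when sequence parallelism is also enabled, as the assert in _process_embedding_token_parallel checks).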