Add CP support to Neva in NeMo2 (#11850)
* api updates and fixes

Signed-off-by: yaoyu-33 <yaoyu.094@gmail.com>

* Apply isort and black reformatting

Signed-off-by: yaoyu-33 <yaoyu-33@users.noreply.github.com>

* fix

Signed-off-by: yaoyu-33 <yaoyu.094@gmail.com>

* fix arg

Signed-off-by: yaoyu-33 <yaoyu.094@gmail.com>

* update seq packing in mock ds

Signed-off-by: yaoyu-33 <yaoyu.094@gmail.com>

* save

Signed-off-by: yaoyu-33 <yaoyu.094@gmail.com>

* update preprocess_data

Signed-off-by: yaoyu-33 <yaoyu.094@gmail.com>

* update seq packing

Signed-off-by: yaoyu-33 <yaoyu.094@gmail.com>

* Apply isort and black reformatting

Signed-off-by: yaoyu-33 <yaoyu-33@users.noreply.github.com>

* fix sp

Signed-off-by: yaoyu-33 <yaoyu.094@gmail.com>

* save

Signed-off-by: yaoyu-33 <yaoyu.094@gmail.com>

* fix seq packing

Signed-off-by: yaoyu-33 <yaoyu.094@gmail.com>

* add truncation and padding

Signed-off-by: yaoyu-33 <yaoyu.094@gmail.com>

* Apply isort and black reformatting

Signed-off-by: yaoyu-33 <yaoyu-33@users.noreply.github.com>

* Fix issues

Signed-off-by: yaoyu-33 <yaoyu.094@gmail.com>

* change LLaVATemplateConfig variables to class variables

* change to use field with default attributes

* Apply isort and black reformatting

Signed-off-by: yashaswikarnati <yashaswikarnati@users.noreply.github.com>

* Apply isort and black reformatting

Signed-off-by: yaoyu-33 <yaoyu-33@users.noreply.github.com>

* Initial support for CP

* Add seq packing option in energon

Signed-off-by: yaoyu-33 <yaoyu.094@gmail.com>

* Fix energon conversation

Signed-off-by: yaoyu-33 <yaoyu.094@gmail.com>

* add energon option in neva training script

Signed-off-by: yaoyu-33 <yaoyu.094@gmail.com>

* Apply isort and black reformatting

Signed-off-by: yaoyu-33 <yaoyu-33@users.noreply.github.com>

* Apply isort and black reformatting

Signed-off-by: parthmannan <parthmannan@users.noreply.github.com>

* Improvements

* add ci test for packed seq

Signed-off-by: yaoyu-33 <yaoyu.094@gmail.com>

* Fix for PP+CP

* Max seq len fix

* fix mock dataset seq packing

Signed-off-by: yaoyu-33 <yaoyu.094@gmail.com>

* Apply isort and black reformatting

Signed-off-by: yaoyu-33 <yaoyu-33@users.noreply.github.com>

* fix mock dataset seq packing

Signed-off-by: yaoyu-33 <yaoyu.094@gmail.com>

* Apply isort and black reformatting

Signed-off-by: yaoyu-33 <yaoyu-33@users.noreply.github.com>

* fix lint and update seq pack func

Signed-off-by: yaoyu-33 <yaoyu.094@gmail.com>

* fix energon module

Signed-off-by: yaoyu-33 <yaoyu.094@gmail.com>

* Apply isort and black reformatting

Signed-off-by: yaoyu-33 <yaoyu-33@users.noreply.github.com>

* fix comments

Signed-off-by: yaoyu-33 <yaoyu.094@gmail.com>

* Apply isort and black reformatting

Signed-off-by: yaoyu-33 <yaoyu-33@users.noreply.github.com>

* address lightning issues

Signed-off-by: yaoyu-33 <yaoyu.094@gmail.com>

* Apply isort and black reformatting

Signed-off-by: yaoyu-33 <yaoyu-33@users.noreply.github.com>

* Update sequence_packing.py

Signed-off-by: Yu Yao <54727607+yaoyu-33@users.noreply.github.com>

* update energon requirements

Signed-off-by: yaoyu-33 <yaoyu.094@gmail.com>

* Fix for energon update

Signed-off-by: yaoyu-33 <yaoyu.094@gmail.com>

* fix for test

Signed-off-by: yaoyu-33 <yaoyu.094@gmail.com>

* Apply isort and black reformatting

Signed-off-by: yaoyu-33 <yaoyu-33@users.noreply.github.com>

* revert overlap config change

Signed-off-by: yaoyu-33 <yaoyu.094@gmail.com>

---------

Signed-off-by: yaoyu-33 <yaoyu.094@gmail.com>
Signed-off-by: yaoyu-33 <yaoyu-33@users.noreply.github.com>
Signed-off-by: yashaswikarnati <yashaswikarnati@users.noreply.github.com>
Signed-off-by: parthmannan <parthmannan@users.noreply.github.com>
Signed-off-by: Yu Yao <54727607+yaoyu-33@users.noreply.github.com>
Co-authored-by: yaoyu-33 <yaoyu-33@users.noreply.github.com>
Co-authored-by: ykarnati <ykarnati@nvidia.com>
Co-authored-by: yashaswikarnati <yashaswikarnati@users.noreply.github.com>
Co-authored-by: Parth Mannan <pmannan@nvidia.com>
Co-authored-by: parthmannan <parthmannan@users.noreply.github.com>
Co-authored-by: Parth Mannan <parth.mannan95@gmail.com>
7 people authored Jan 21, 2025
1 parent 7d74e71 commit 59835c7
Showing 2 changed files with 135 additions and 27 deletions.
160 changes: 133 additions & 27 deletions nemo/collections/vlm/neva/model/base.py
@@ -132,6 +132,11 @@ def neva_data_step(dataloader_iter) -> Dict[str, torch.Tensor]:
if value is not None:
setattr(packed_seq_params, attr, value.cuda(non_blocking=True))
_batch["packed_seq_params"] = packed_seq_params
if ps.get_context_parallel_world_size() > 1:
num_valid_tokens_in_ub = None
if "loss_mask" in _batch and _batch["loss_mask"] is not None:
num_valid_tokens_in_ub = _batch["loss_mask"].sum()
_batch["num_valid_tokens_in_ub"] = num_valid_tokens_in_ub

return _batch
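
The new num_valid_tokens_in_ub field above is only attached when the context-parallel world size is greater than one. A minimal, illustrative sketch (not part of the commit) of why the count is taken over the whole micro-batch: each CP rank ends up holding only a shard of the tokens, so its masked loss sum has to be normalized by the global valid-token count rather than by the size of its local shard.

import torch

def cp_local_loss(
    local_token_losses: torch.Tensor,
    local_loss_mask: torch.Tensor,
    num_valid_tokens_in_ub: torch.Tensor,
) -> torch.Tensor:
    # Sum the masked losses seen by this CP rank and divide by the micro-batch-wide
    # valid-token count; summing this value across the CP group reproduces the mean
    # loss of the unsharded computation.
    return (local_token_losses * local_loss_mask).sum() / num_valid_tokens_in_ub.clamp(min=1)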

@@ -381,6 +386,59 @@ def forward(
return super().forward(x, attention_mask)


class _get_data_on_this_cp_rank(torch.autograd.Function):
"""Performs sharding for Context Parallelism in THD format
In the forward pass, indices are selected for each CP rank and remaining tokens are dropped.
In the backward pass, this class takes care of managing gradients for dropped tokens on each
CP rank.
"""

@staticmethod
# def forward(ctx, decoder_embeddings, labels, loss_mask, packed_seq_params):
def forward(ctx, batch, packed_seq_params):
cp_size = ps.get_context_parallel_world_size()
if cp_size > 1:
try:
import transformer_engine_torch as tex
except ModuleNotFoundError as e:
logging.error(
"Please update Transformer Engine to >= 1.10 to use \
Context Parallel with THD format data"
)
raise e
cp_rank = ps.get_context_parallel_rank()
for key, data in batch.items():
index = tex.thd_get_partitioned_indices(
packed_seq_params.cu_seqlens_q_padded, data.size(1), cp_size, cp_rank
)
if key == "combined_embeddings":
ctx.decoder_emb_index = index
ctx.decoder_emb_seqlen = data.size(1)
batch[key] = data.index_select(1, index)

return batch

@staticmethod
def backward(ctx, grad_out, grad_label, grad_loss):
seqlen = ctx.decoder_emb_seqlen
index = ctx.decoder_emb_index
assert grad_out.size(1) == index.size(
0
), f"Shape mismatch in incoming gradient {grad_out.shape} and \
index from THD CP sharding {index.shape}"
grad_in = torch.zeros(
grad_out.size(0),
seqlen,
*grad_out.size()[2:],
dtype=grad_out.dtype,
device=grad_out.device,
)
grad_in[:, ctx.decoder_emb_index, :] = grad_out

return (grad_in, None, None, None)


class MCoreNevaModel(MCoreLLaVAModel):
def __init__(
self,
@@ -604,6 +662,13 @@ def forward(
packed_seq_params,
) # [combined_seq_len, b, h_language], [b, combined_seq_len], [b, combined_seq_len]

if self.context_parallel_lm > 1 or self.sequence_parallel_lm:
combined_embeddings, final_labels, final_loss_mask, packed_seq_params = (
self._process_embedding_token_parallel(
combined_embeddings, final_labels, final_loss_mask, packed_seq_params
)
)

output = self.language_model(
input_ids=None,
position_ids=None,
@@ -850,7 +915,9 @@ def _preprocess_data(
final_loss_mask = final_loss_mask[:, : self._language_max_sequence_length]

if final_embedding is not None:
final_embedding = final_embedding.transpose(1, 0).contiguous()
if self.context_parallel_lm == 1:
# Transpose to [s,b,h] if not using CP or not using packed_sequence/THD format
final_embedding = final_embedding.transpose(1, 0).contiguous()
# Truncate if exceeding the language model's max sequence length.
if final_embedding.shape[0] > self._language_max_sequence_length:
final_embedding = final_embedding[: self._language_max_sequence_length]
@@ -864,34 +931,73 @@
packed_seq_params.cu_seqlens_q[-1] >= packed_seq_params.cu_seqlens_q[-2]
), "with packed sequence, the truncation can only truncate on the last sequence."

if self.sequence_parallel_lm and not packed_sequence:
# Create an attention mask. This ensures correct computation.
# This is done even when no padding was done as we set mask_type to
# 'padding' or 'padding_causal' when using SP.
if attention_mask is None:
# Create base attention mask with original seq len to indicate valid tokens
attention_mask = (
torch.ones(
(
final_embedding.shape[1],
final_embedding.shape[0] - sp_padding_needed,
),
device=final_embedding.device,
)
.unsqueeze(1)
.unsqueeze(1)
) # [b, 1, 1, final seq len - sp_padding_needed]
if sp_padding_needed > 0:
# Add the padding portion of the mask
attention_mask = torch.nn.functional.pad(attention_mask, (0, sp_padding_needed))

# Attention mask True/False meaning flipped in 1.7.0
attention_mask = attention_mask < 0.5
if self.sequence_parallel_lm:
final_embedding = tensor_parallel.scatter_to_sequence_parallel_region(final_embedding)

return final_embedding, final_labels, final_loss_mask, attention_mask

def _process_embedding_token_parallel(self, combined_embeddings, new_labels, new_loss_mask, packed_seq_params):
"""Processes the input data for model parallelism support."""

# No pre or post processing needed with PP middle chunks.
if not self.pre_process and not self.post_process:
return combined_embeddings, new_labels, new_loss_mask, packed_seq_params

if self.pre_process:
if self.context_parallel_lm > 1 and self.sequence_parallel_lm:
shard_factor = self.tensor_model_parallel_size_lm * self.context_parallel_lm * 2
seq_dim = 1
elif self.context_parallel_lm > 1:
shard_factor = self.context_parallel_lm * 2
seq_dim = 1
elif self.sequence_parallel_lm:
shard_factor = self.tensor_model_parallel_size_lm
seq_dim = 0

assert (
combined_embeddings.shape[seq_dim] % shard_factor == 0
), f"Sequence length should be divisible by {shard_factor} for \
Sequence/Context parallelism"
if self.sequence_parallel_lm and self.tp_comm_overlap_lm:
assert (
combined_embeddings.shape[seq_dim] == self._language_max_sequence_length
), f"TP Comm overlap either requires Vision+Text token length \
== language_max_sequence_length"

if self.context_parallel_lm > 1:
batch = dict()
if self.pre_process:
batch.update(
{
"combined_embeddings": combined_embeddings,
}
)
if self.post_process:
batch.update(
{
"new_labels": new_labels,
"new_loss_mask": new_loss_mask,
}
)
# Distribute sequence across CP ranks
if packed_seq_params is None or packed_seq_params.qkv_format == 'sbhd':
from megatron.training.utils import get_batch_on_this_cp_rank

batch = get_batch_on_this_cp_rank(batch)
else:
batch = _get_data_on_this_cp_rank.apply(batch, packed_seq_params)

if self.pre_process:
combined_embeddings = batch["combined_embeddings"] # [B, S/CP, H]
combined_embeddings = combined_embeddings.transpose(1, 0).contiguous() # [B,S/CP,H] -> [S/CP,B,H]
if self.post_process:
new_labels = batch["new_labels"]
new_loss_mask = batch["new_loss_mask"]

if self.sequence_parallel_lm and self.pre_process:
combined_embeddings = tensor_parallel.scatter_to_sequence_parallel_region(
combined_embeddings
) # [S/(CP*TP),B,H]

return combined_embeddings, new_labels, new_loss_mask, packed_seq_params


class NevaModel(L.LightningModule, io.IOMixin, io.ConnectorMixin, fn.FNMixin):
def __init__(
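
For reference, a pure-PyTorch approximation (not from the commit) of the index selection that transformer_engine's thd_get_partitioned_indices performs inside the _get_data_on_this_cp_rank class above, assuming the usual load-balanced scheme in which each padded packed sequence is split into 2 * cp_size chunks and rank r keeps chunks r and 2 * cp_size - 1 - r:

import torch

def thd_partition_indices(cu_seqlens_padded: torch.Tensor, cp_size: int, cp_rank: int) -> torch.Tensor:
    # Each padded sequence length is assumed to be divisible by 2 * cp_size.
    indices = []
    for start, end in zip(cu_seqlens_padded[:-1].tolist(), cu_seqlens_padded[1:].tolist()):
        chunk = (end - start) // (2 * cp_size)
        for c in (cp_rank, 2 * cp_size - 1 - cp_rank):
            indices.append(torch.arange(start + c * chunk, start + (c + 1) * chunk))
    return torch.cat(indices)

# Two packed sequences of padded length 8, cp_size=2: rank 0 keeps tokens 0-1 and
# 6-7 of each sequence, i.e. tensor([ 0,  1,  6,  7,  8,  9, 14, 15])
print(thd_partition_indices(torch.tensor([0, 8, 16]), cp_size=2, cp_rank=0))

The backward pass of the class above mirrors this selection: it allocates a zero tensor of the full sequence length and writes the incoming gradient back only at the selected positions, so dropped tokens receive zero gradient.
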
2 changes: 2 additions & 0 deletions scripts/vlm/neva_finetune.py
@@ -153,6 +153,7 @@ def main(args):
tensor_model_parallel_size=args.tp_size,
pipeline_model_parallel_size=args.pp_size,
encoder_pipeline_model_parallel_size=args.encoder_pp_size,
context_parallel_size=args.cp_size,
pipeline_dtype=torch.bfloat16,
sequence_parallel=True,
ddp=DistributedDataParallelConfig(
@@ -271,6 +272,7 @@ def main(args):
parser.add_argument("--max_steps", type=int, required=False, default=5190)
parser.add_argument("--tp_size", type=int, required=False, default=1)
parser.add_argument("--pp_size", type=int, required=False, default=1)
parser.add_argument("--cp_size", type=int, required=False, default=1)
parser.add_argument("--encoder_pp_size", type=int, required=False, default=0)
parser.add_argument("--projector_type", type=str, required=False, default="mcore_mlp")
parser.add_argument("--name", type=str, required=False, default="neva_pretrain")
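
A hypothetical sketch of how the new --cp_size argument reaches the parallelism configuration. The hunks above show only keyword arguments, so the constructor name (NeMo-2's nl.MegatronStrategy) and the helper below are assumptions, not code from the commit:

import torch
from nemo import lightning as nl

def build_strategy(tp_size: int = 1, pp_size: int = 1, cp_size: int = 1) -> nl.MegatronStrategy:
    # tp_size * pp_size * cp_size must divide the world size; the remaining factor
    # becomes data parallelism.
    return nl.MegatronStrategy(
        tensor_model_parallel_size=tp_size,
        pipeline_model_parallel_size=pp_size,
        context_parallel_size=cp_size,  # exposed by the new --cp_size flag
        pipeline_dtype=torch.bfloat16,
        sequence_parallel=True,
    )

When both context and sequence parallelism are enabled, the base.py changes additionally require the combined vision+text sequence length to be divisible by tp_size * cp_size * 2.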
