Commit 9a13dd7
savae
Signed-off-by: Jimmy Zhang <jiemingz@nvidia.com>
jiemingz committed Jan 7, 2025
1 parent c3d374a commit 9a13dd7
Showing 4 changed files with 41 additions and 11 deletions.

nemo/collections/llm/gpt/model/llama.py (1 addition, 1 deletion)

@@ -148,7 +148,7 @@ class Llama3Config70B(Llama3Config):
     rotary_base: int = 500_000
     seq_length: int = 8192
     num_layers: int = 80
-    hidden_size: int = 8192
+    hidden_size: int = 1024
     ffn_hidden_size: int = 28672
     num_attention_heads: int = 64
     init_method_std: float = 0.008944
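
Note: hidden_size=1024 is far below the published Llama 3 70B width of 8192 (while ffn_hidden_size stays at 28672), so this reads as a temporary debug-scale override rather than a model fix. A sketch of a pattern that avoids editing library defaults, assuming Llama3Config70B is a dataclass as the field syntax suggests:

# Hypothetical usage sketch, not part of this commit: shrink the model for
# local testing by overriding the instance instead of the class default.
from nemo.collections.llm.gpt.model.llama import Llama3Config70B

cfg = Llama3Config70B(hidden_size=1024)  # debug width; the real 70B uses 8192
assert cfg.num_layers == 80              # the remaining 70B defaults stay intact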

nemo/lightning/megatron_parallel.py (31 additions, 6 deletions)

@@ -262,6 +262,13 @@ def forward(
         else:
             forward_step_func = _forward_step
 
+
+        # if isinstance(forward_step, _ModuleStepFunction):
+        #     forward_step_func = _forward_step(self)
+        # else:
+        #     forward_step_func = _forward_step
+
+
         step = MegatronStep.infer(
             self,
             data,
@@ -271,6 +278,7 @@ def forward(
             num_microbatches=num_microbatches,
             seq_length=seq_length,
             step_i=step_i,
+            data_step=_data_step,
         )
         _forward_context["step"] = step
         step = self.callbacks.transform_event("on_megatron_step_start", step)
@@ -461,13 +469,15 @@ def wrapped_forward_step(
 
     @functools.wraps(forward_step)
     def wrapped_forward_step_func(dataloader_iter, model):
-        if isinstance(data_step, _ModuleStepFunction):
-            _data_step = data_step(model)
-        else:
-            _data_step = data_step
+        # if isinstance(data_step, _ModuleStepFunction):
+        #     _data_step = data_step(model)
+        # else:
+        #     _data_step = data_step
+        # batch = _data_step(dataloader_iter)
+        # torch.distributed.breakpoint()
 
-        batch = _data_step(dataloader_iter)
         step = context["step"]
+        batch = dataloader_iter.pop(0)
 
         if isinstance(loss_reduction, _ModuleStepFunction):
            forward_callback = loss_reduction(model)
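
Note: with the data_step call commented out here, the object passed in as dataloader_iter is no longer a live iterator; the pop(0) implies it is a pre-built list of batches, one per microbatch (see the __call__ hunk below). A minimal sketch of the new contract, using stand-in batch values:

# Sketch (an assumption inferred from this diff): each forward step now
# receives a plain list of already-fetched microbatches and pops the next one.
def toy_forward_step(dataloader_iter, model):
    batch = dataloader_iter.pop(0)  # mirrors wrapped_forward_step_func above
    return model(batch)

microbatches = [{"tokens": [1, 2]}, {"tokens": [3, 4]}]  # stand-in data
toy_forward_step(microbatches, model=lambda b: b)  # consumes the first batch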
@@ -1072,6 +1082,7 @@ class MegatronStep(Generic[ModelT, DataT]):
     num_microbatches: Optional[int] = None
     step_i: Optional[int] = None
     decoder_seq_length: Optional[int] = None
+    data_step: Optional[int] = None
 
     @classmethod
     def infer(
@@ -1084,6 +1095,7 @@ def infer(
         seq_length: Optional[int] = None,
         num_microbatches: Optional[int] = None,
         step_i: Optional[int] = None,
+        data_step = None,
     ) -> "MegatronStep[ModelT, DataT]":
         """
         Creates a MegatronStep instance, inferring missing parameters if possible.
@@ -1115,6 +1127,7 @@ def infer(
             seq_length=seq_length or cls.infer_seq_length(data),
             num_microbatches=num_microbatches or cls.infer_num_microbatches(data),
             step_i=step_i,
+            data_step=data_step,
         )
 
     def __call__(self) -> List[Any]:
@@ -1144,9 +1157,21 @@ def __call__(self) -> List[Any]:
         data_iterator, seq_length = self.get_data_iterator_and_seq_length()
         seq_length = seq_length or self.seq_length
 
+        if isinstance(self.data_step, _ModuleStepFunction):
+            _data_step = self.data_step(self.model)
+        else:
+            _data_step = self.data_step
+
+
+        data_list = []
+        for _ in range(self.num_microbatches):
+            data_list.append(_data_step(data_iterator))
+
+        data_list = [data_list]
+
         return self.forward_backward_func(
             forward_step_func=self.forward_step_func,
-            data_iterator=data_iterator,
+            data_iterator=data_list,
             model=self.model,
             num_microbatches=self.num_microbatches,
             seq_length=seq_length,
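
Note: this is the other half of the change. __call__ now materializes all microbatches up front by invoking data_step once per microbatch, wraps them in an outer list (one inner list per model chunk; a single chunk here), and hands that to forward_backward_func in place of the live iterator. A self-contained sketch of the pattern, not the NeMo API:

from typing import Any, Callable, Iterator, List

def prefetch_microbatches(
    data_step: Callable[[Iterator[Any]], Any],
    data_iterator: Iterator[Any],
    num_microbatches: int,
) -> List[List[Any]]:
    # Drain the iterator eagerly so the pipeline schedule can pop batches
    # from a list instead of pulling from the dataloader inside each step.
    data_list = [data_step(data_iterator) for _ in range(num_microbatches)]
    return [data_list]  # outer list: one entry per model chunk

batches = prefetch_microbatches(next, iter(range(4)), num_microbatches=4)
print(batches)  # [[0, 1, 2, 3]]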

nemo/lightning/run/plugins.py (2 additions, 0 deletions)

@@ -367,3 +367,5 @@ def setup(self, task: run.Partial | run.Script, executor: run.Executor):
             if (executor.setup_lines and len(executor.setup_lines) > 0)
             else vboost_cmd
         )
+
+        executor.env_vars["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
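
Note: expandable_segments:True tells PyTorch's CUDA caching allocator to grow existing memory segments instead of carving out new fixed-size ones, which reduces fragmentation when allocation sizes vary between steps. Exporting it through executor.env_vars means the job's process sees it before CUDA is initialized; the equivalent by hand:

import os

# Safest to set before importing torch, so the allocator is guaranteed to
# read the variable before any CUDA initialization happens.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

import torch  # noqa: E402  (deliberately imported after setting the env var)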

nemo/utils/get_rank.py (7 additions, 4 deletions)

@@ -19,6 +19,12 @@
 
 def is_global_rank_zero():
     """Helper function to determine if the current process is global_rank 0 (the main process)"""
+
+    # Try to get the MPI global rank env var
+    mpi_rank = get_envint("OMPI_COMM_WORLD_RANK", None)
+    if mpi_rank is not None:
+        return mpi_rank == 0
+
     # Try to get the pytorch RANK env var
     # RANK is set by torch.distributed.launch
     rank = get_envint("RANK", None)
@@ -31,10 +37,7 @@ def is_global_rank_zero():
     if slurm_rank is not None:
         return slurm_rank == 0
 
-    # Try to get the MPI global rank env var
-    mpi_rank = get_envint("OMPI_COMM_WORLD_RANK", None)
-    if mpi_rank is not None:
-        return mpi_rank == 0
 
 
     # if neither pytorch, SLURM nor MPI env vars are set
     # check NODE_RANK/GROUP_RANK and LOCAL_RANK env vars
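
Note: the MPI check moves from last to first, so under mpirun OMPI_COMM_WORLD_RANK now wins even when RANK or SLURM_PROCID is also set (e.g. SLURM jobs launched via mpirun). A condensed sketch of the resulting precedence, using plain os.environ in place of NeMo's get_envint helper:

import os

def is_rank_zero() -> bool:
    # Precedence after this commit: MPI, then torch.distributed, then SLURM.
    for var in ("OMPI_COMM_WORLD_RANK", "RANK", "SLURM_PROCID"):
        value = os.environ.get(var)
        if value is not None:
            return int(value) == 0
    # No launcher env var set; the real helper additionally consults
    # NODE_RANK/GROUP_RANK and LOCAL_RANK before concluding rank zero.
    return True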
