Fix(QA): fix loading ckpt and add launcher setting for test loss (#206)
li126com authored Jul 9, 2024
1 parent 69cb00c commit 751c103
Showing 1 changed file with 26 additions and 10 deletions.
tests/test_training/test_loss.py (36 changes: 26 additions & 10 deletions)
@@ -6,6 +6,7 @@
 import torch.distributed as dist
 
 import internlm
+from internlm.accelerator import AcceleratorType, get_accelerator
 from internlm.checkpoint import CheckpointManager
 from internlm.core.context import Config, ParallelMode
 from internlm.core.context import global_context as gpc
@@ -44,8 +45,8 @@
     4.616517543792725,
 ]
 
-
 cur_loss_list = []
+internlm_accelerator = get_accelerator()
 
 
 def train(
@@ -57,7 +58,8 @@ def train(
     interleaved: bool = False,
     tp_mode: str = "mtp",
     enable_sp: bool = False,
-    enable_ckpt: bool = False,
+    save_ckpt: bool = False,
+    load_ckpt: bool = False,
     model_type: str = "INTERNLM",
     optimizer_ver: str = "v1",
 ):
@@ -69,6 +71,9 @@ def train(
     config.data.fixed_random_dataset_seqlen = False
     config.lr_scheduler.total_steps = TOTAL_STEPS
     config.model_type = model_type
+    config.ckpt.load_ckpt_folder = None
+    config.ckpt.load_ckpt_info = None
+    config.ckpt.auto_resume = False
     total_steps = config.data.total_steps
     skip_batches = config.data.skip_batches
     label_smoothing = config.loss.label_smoothing
@@ -80,16 +85,16 @@ def train(
     # update ckpt config
     if model_type == "INTERNLM" and tp_mode != "isp" and interleaved is False:
         config.ckpt.load_ckpt_info = dict(path=INTERNLM1_CKPT_PATH, content=("model",), ckpt_type="internlm_test")
-        config.ckpt.auto_resume = False
 
-    if enable_ckpt:
+    if save_ckpt:
         config.ckpt.enable_save_ckpt = True
         config.ckpt.checkpoint_every = 10
         config.ckpt.save_ckpt_folder = "local:llm_ckpts/"
-        config.ckpt.load_ckpt_folder = "local:llm_ckpts/"
-        config.ckpt.load_ckpt_info["content"] = ("all",)
         config.ckpt.oss_snapshot_freq = 100
 
+    if load_ckpt:
+        config.ckpt.load_ckpt_info = dict(path="local:llm_ckpts/10", content=("all",), ckpt_type="internevo")
+
     # update parallel config
     config.parallel.tensor = dict(size=tp_size, mode=tp_mode)
     config.parallel.pipeline = dict(size=pp_size)
@@ -98,7 +103,18 @@ def train(
         config.parallel.pipeline = dict(size=pp_size, interleaved_overlap=True)
         config.model.num_chunks = num_chunks
 
-    initialize_distributed_env(config=config)
+    if tp_mode == "isp" and internlm_accelerator.get_accelerator_backend() in [
+        AcceleratorType.NPU,
+        AcceleratorType.DIPU,
+    ]:
+        config.data.use_packed_dataset = False
+
+    if internlm_accelerator.get_accelerator_backend() == AcceleratorType.GPU:
+        launcher = "slurm"
+    else:
+        launcher = "torch"
+
+    initialize_distributed_env(config=config, launcher=launcher)
     assert hasattr(gpc, "config") and gpc.config is not None
 
     # check parallel config
@@ -241,7 +257,7 @@ def train(
         )
         if gpc.is_rank_for_log():
             assert loss is not None and not math.isnan(loss.item())
-            global cur_loss_list
+            global cur_loss_list  # pylint: disable=W0602
             cur_loss_list.append((loss.item() - moe_loss.item() if moe_loss is not None else loss.item()))
         timer("fwd-bwd").stop()
 
@@ -463,7 +479,7 @@ def test_training_with_isp_save_ckpt():
     CONFIG_FILE_PATH = "./configs/7B_isp_sft.py"
 
     # model training save ckpt
-    train(dp_size=4, tp_size=2, wp_size=4, tp_mode="isp", enable_sp=True, enable_ckpt=True)
+    train(dp_size=4, tp_size=2, wp_size=4, tp_mode="isp", enable_sp=True, save_ckpt=True)
 
 
 @pytest.mark.training_8GPU_ISP_LOAD_CKPT
@@ -476,7 +492,7 @@ def test_training_with_isp_load_ckpt():
     TOTAL_STEPS = 20
 
     # model training load ckpt
-    train(dp_size=4, tp_size=2, wp_size=4, tp_mode="isp", enable_sp=True, enable_ckpt=True)
+    train(dp_size=4, tp_size=2, wp_size=4, tp_mode="isp", enable_sp=True, load_ckpt=True)
 
 
 @pytest.mark.training_llama2
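Splitting `enable_ckpt` into `save_ckpt` and `load_ckpt` lets the two ISP checkpoint tests run as a pair: the save test writes a local snapshot every 10 steps, and the load test resumes everything from `local:llm_ckpts/10`. A minimal sketch of the resulting config logic, pulled out of `train()` for readability; `apply_ckpt_flags` and the `SimpleNamespace` scaffolding are illustrative stand-ins, not part of the repo:

    from types import SimpleNamespace

    def apply_ckpt_flags(ckpt, save_ckpt=False, load_ckpt=False):
        # Reset first so a previous test's checkpoint settings cannot leak in.
        ckpt.load_ckpt_folder = None
        ckpt.load_ckpt_info = None
        ckpt.auto_resume = False

        if save_ckpt:
            # Write local checkpoints every 10 steps, as the save test expects.
            ckpt.enable_save_ckpt = True
            ckpt.checkpoint_every = 10
            ckpt.save_ckpt_folder = "local:llm_ckpts/"
            ckpt.oss_snapshot_freq = 100

        if load_ckpt:
            # Resume model, optimizer, and sampler state from the step-10 snapshot.
            ckpt.load_ckpt_info = dict(
                path="local:llm_ckpts/10", content=("all",), ckpt_type="internevo"
            )
        return ckpt

    ckpt = apply_ckpt_flags(SimpleNamespace(), load_ckpt=True)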

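The launcher branch keys the distributed bootstrap off the detected backend: GPU CI jobs run under Slurm, while other backends (e.g. NPU/DIPU) are started with torchrun and need `launcher="torch"`; the same hunk also turns off `use_packed_dataset` for ISP runs on NPU/DIPU. A standalone sketch of the selection under those assumptions, using only the accelerator API imported in the diff (`pick_launcher` is a hypothetical helper):

    from internlm.accelerator import AcceleratorType, get_accelerator

    internlm_accelerator = get_accelerator()

    def pick_launcher() -> str:
        # Slurm supplies the process-group info on GPU clusters; everywhere
        # else the tests are launched via torchrun.
        backend = internlm_accelerator.get_accelerator_backend()
        return "slurm" if backend == AcceleratorType.GPU else "torch"

    # initialize_distributed_env(config=config, launcher=pick_launcher())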