diff --git a/doc/code-docs/source/checkpoint.rst b/doc/code-docs/source/checkpoint.rst index c01c6950..4b22f754 100644 --- a/doc/code-docs/source/checkpoint.rst +++ b/doc/code-docs/source/checkpoint.rst @@ -36,7 +36,7 @@ CheckpointManager ckpt = dict( enable_save_ckpt=False, # enable ckpt save. save_ckpt_folder=SAVE_CKPT_FOLDER, # Path to save training ckpt. - load_ckpt_info=dict(path="local:/mnt/mfs/ckpt", content=["all",], ckpt_type="internlm"), + load_ckpt_info=dict(path="local:/mnt/mfs/ckpt", content=["all",], ckpt_type="internevo"), auto_resume=False, # disable auto-resume, internlm will load model checkpoint from the path of 'load_ckpt_info'. checkpoint_every=CHECKPOINT_EVERY, async_upload=True, # async ckpt upload. (only work for boto3, volc and oss2 ckpt) @@ -95,7 +95,6 @@ load_ckpt_info 由三个字段组成, ``path`` 、 ``content`` 和 ``ckpt_type - ``internevo``:internevo约定的checkpoint存储格式。 - ``llama``:llama约定的checkpoint存储格式。 - ``hf_llama``:huggingface llama约定的checkpoint存储格式。 - - ``hf_model``:适用于加载huggingface所有模型的checkpoint存储格式。 下面给出两个例子: @@ -107,10 +106,6 @@ load_ckpt_info 由三个字段组成, ``path`` 、 ``content`` 和 ``ckpt_type # 从文件存储相对路径 ckpt_model 中加载所有的状态,适合断点续训的场景 load_ckpt_info = dict(path="local:ckpt_model", content=("all",), ckpt_type="internevo") - # 从 huggingface 下载指定模型,加载checkpoint - load_ckpt_info = dict(path="internlm/internlm-7b", content=("model",), ckpt_type="hf_model") - - .. _asyncupload: 异步上传 @@ -179,4 +174,3 @@ config.ckpt 中相关的参数: # 如果存入的step>0,则任务会在存储ckpt后自动退出 # 如果存入的step<0,则任务会在存储ckpt后会继续训练 echo "999" > ./llm_alter/1006_pr.log - diff --git a/doc/en/usage.md b/doc/en/usage.md index f1d0f65e..919a3ef2 100644 --- a/doc/en/usage.md +++ b/doc/en/usage.md @@ -11,7 +11,7 @@ If you are using the HuggingFace datasets for on-the-fly streaming load and toke ```python TRAIN_FOLDER = "roneneldan/TinyStories" data = dict( - type="hf", + type="streaming", tokenizer_path="internlm/internlm-7b", ) ``` @@ -117,7 +117,7 @@ ckpt = dict( # 1. the 'path' indicate ckpt path, # 2. the 'content‘ means what states will be loaded, support: "model", "sampler", "optimizer", "scheduler", "all" # 3. the ’ckpt_type‘ means the type of checkpoint to be loaded, now only 'normal' type is supported. - load_ckpt_info=dict(path=MODEL_ONLY_FOLDER, content=("model",), ckpt_type="internlm"), + load_ckpt_info=dict(path=MODEL_ONLY_FOLDER, content=("model",), ckpt_type="internevo"), checkpoint_every=CHECKPOINT_EVERY, async_upload=True, # async ckpt upload. (only work for boto3 ckpt) async_upload_tmp_folder="/dev/shm/internlm_tmp_ckpt/", # path for temporarily files during asynchronous upload. @@ -295,8 +295,8 @@ ckpt = dict( # When resuming training from a breakpoint,: # (1) 'path' is the path of the loaded checkpoint. # (2) 'content' indicates which state will be loaded, support: "model", "sampler", "optimizer", "scheduler", "all" - # (3) 'ckpt_type' indicates which type ckpt will be loaded, currently supported: "internlm" - load_ckpt_info=dict(path=MODEL_ONLY_FOLDER, content=("model",), ckpt_type="internlm"), + # (3) 'ckpt_type' indicates which type ckpt will be loaded, currently supported: "internevo" + load_ckpt_info=dict(path=MODEL_ONLY_FOLDER, content=("model",), ckpt_type="internevo"), ) ``` @@ -427,4 +427,4 @@ model = dict( Regarding the principle of Dyanmic NTK, please refer to 1. https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases -2. https://kexue.fm/archives/9675 \ No newline at end of file +2. https://kexue.fm/archives/9675 diff --git a/doc/usage.md b/doc/usage.md index e05cc693..66fb3885 100644 --- a/doc/usage.md +++ b/doc/usage.md @@ -11,7 +11,7 @@ ```python TRAIN_FOLDER = "roneneldan/TinyStories" data = dict( - type="hf", + type="streaming", tokenizer_path="internlm/internlm-7b", ) ``` @@ -106,7 +106,7 @@ ckpt = dict( # 'load_ckpt_info' setting guide: # 1. the 'path' indicate ckpt path, # 2. the 'content‘ means what states will be loaded, support: "model", "sampler", "optimizer", "scheduler", "all" - # 3. the ’ckpt_type‘ means the type of checkpoint to be loaded, support: "internevo", "llama", "hf_llama", "hf_model". + # 3. the ’ckpt_type‘ means the type of checkpoint to be loaded, support: "internevo", "llama", "hf_llama". load_ckpt_info=dict(path=MODEL_ONLY_FOLDER, content=("model",), ckpt_type="internevo"), # 'auto_resume' is designed to automatically load the latest checkpoint from 'save_ckpt_folder' when encountering # training interruptions/hangs caused by hardware failures, using a scheduling system (such as k8s/slurm) @@ -327,7 +327,7 @@ train_folder设置为huggingface上可以通过load_dataset直接下载的数据 TRAIN_FOLDER = "roneneldan/TinyStories" SEQ_LEN = 2048 data = dict( - type="hf", + type="streaming", tokenizer_path="internlm/internlm-7b", seq_len=SEQ_LEN, # 数据样本长度,默认值为 2048 micro_num=1, # micro_num 是指在一次模型参数更新中会处理的 micro_batch 的数目,默认值为 1 @@ -353,10 +353,8 @@ ckpt = dict( checkpoint_every=float("inf"), # 每多少个 step 存储一次 checkpoint,默认值为 inf # 断点续训时,加载模型和优化器等权重的路径,将从指定的 step 恢复训练 # content 表示哪些状态会被加载,支持: "model", "sampler", "optimizer", "scheduler", "all" - # ckpt_type 表示加载的模型类型,目前支持: "internevo", "llama", "hf_llama", "hf_model" - # 其中,"hf_model"类型表示从huggingface上下载模型加载ckpt,MODEL_ONLY_FOLDER需要设置为可以 - # 通过AutoModel直接加载的模型路径,如:"internlm/internlm-7b" - load_ckpt_info=dict(path=MODEL_ONLY_FOLDER, content=("model",), ckpt_type="internlm"), + # ckpt_type 表示加载的模型类型,目前支持: "internevo", "llama", "hf_llama" + load_ckpt_info=dict(path=MODEL_ONLY_FOLDER, content=("model",), ckpt_type="internevo"), # 'auto_resume' 旨在在遇到由硬件故障引起的训练中断/挂起时,自动从 'save_ckpt_folder' 加载最新的检查点, # 使用调度系统(例如 k8s/slurm)在训练重启时自动重启机制。 # 请注意,如果未设置 auto_resume(其默认值为 True),它将不会默认加载 load_ckpt_info 中指定的检查点路径。 @@ -510,4 +508,4 @@ generation = dict( 关于 Dyanmic NTK 的原理,详细请参考 1. [dynamically_scaled_rope_further_increases](https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases) -2. [https://kexue.fm/archives/9675](https://kexue.fm/archives/9675) \ No newline at end of file +2. [https://kexue.fm/archives/9675](https://kexue.fm/archives/9675) diff --git a/internlm/checkpoint/load_funcs.py b/internlm/checkpoint/load_funcs.py index 306c6681..a8b7d55e 100644 --- a/internlm/checkpoint/load_funcs.py +++ b/internlm/checkpoint/load_funcs.py @@ -10,7 +10,6 @@ from internlm.utils.logger import get_logger from internlm.utils.storage_manager import get_fns, llm_load from internlm.utils.utils import ModelType -from transformers import AutoModelForCausalLM logger = get_logger(__file__) internlm_accelerator = get_accelerator() @@ -300,22 +299,8 @@ def load_internlm_with_dynamic_parallel_size(folder, model): ) -def load_hf_model_pretrained_weights(folder, model): - """NOTE: when loading huggingface's model pretrained weights, you should set `adapt_hf=True` in your config.""" - assert folder is not None, "Please specify the folder of the pretrained model" - if gpc.is_rank_for_log(): - logger.info(f"Loading pretrained model from {folder}") - - pretrained_model = AutoModelForCausalLM.from_pretrained(folder, trust_remote_code=True) - model.load_state_dict(pretrained_model.state_dict(), strict=False) - - if gpc.is_rank_for_log(): - logger.info("Pretrained weights loaded successfully") - - LOAD_FUNC_DICT = { "llama": load_llama_pretrained_weights, "hf_llama": load_hf_llama_pretrained_weights, "internlm_test": load_internlm_with_dynamic_parallel_size, - "hf_model": load_hf_model_pretrained_weights, } diff --git a/internlm/core/parallel/comm/isp.py b/internlm/core/parallel/comm/isp.py index c0c8a304..03c3da16 100644 --- a/internlm/core/parallel/comm/isp.py +++ b/internlm/core/parallel/comm/isp.py @@ -892,41 +892,40 @@ def _(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, **kwargs) -> torc return context -def auto_wrap_distributed_attention(cls: nn.Module) -> Callable[[bool, Any, float], nn.Module]: +def auto_wrap_distributed_attention(attn_impl: nn.Module) -> Callable[[bool, Any, float], nn.Module]: """ Wrap a local attention module to a distributed one, which will be used in the ISP parallelism. """ # should we impl distributed attention as a metaclass? - def _attetion_constructor( - local_attn_cls: type, causal=False, softmax_scale=None, attention_dropout=0.0 - ) -> nn.Module: + def _attetion_constructor(attn_impl: type, causal=False, softmax_scale=None, attention_dropout=0.0) -> nn.Module: tp_mode = gpc.config.parallel["tensor"].get("mode", TensorParallelMode.mtp.name) if tp_mode != TensorParallelMode.isp.name: - return local_attn_cls(causal, softmax_scale, attention_dropout) + return attn_impl(causal, softmax_scale, attention_dropout) else: return DistributedAttention( - local_attention=local_attn_cls(causal, softmax_scale, attention_dropout), + local_attention=attn_impl(causal, softmax_scale, attention_dropout), sequence_process_group=gpc.get_group(ParallelMode.TENSOR), ) - return partial(_attetion_constructor, local_attn_cls=cls) + return partial(_attetion_constructor, attn_impl=attn_impl) -def auto_wrap_func_distributed_attention(func: Callable) -> Callable[..., Callable]: +def auto_wrap_func_distributed_attention(attn_impl: Callable) -> Callable[..., Callable]: """ Wrap a local attention function to a distributed one, which will be used in the ISP parallelism. """ - def _attention_func_constructor(*args, local_attn_func=None, **kwargs) -> Callable: + # should we impl distributed attention as a metaclass? + def _attetion_constructor(*args, attn_impl: type, **kwargs) -> Callable: tp_mode = gpc.config.parallel["tensor"].get("mode", TensorParallelMode.mtp.name) if tp_mode != TensorParallelMode.isp.name: - return local_attn_func(*args, **kwargs) + return attn_impl(*args, **kwargs) else: return DistributedAttention( - local_attention=local_attn_func, sequence_process_group=gpc.get_group(ParallelMode.TENSOR) + local_attention=attn_impl, sequence_process_group=gpc.get_group(ParallelMode.TENSOR) )(*args, **kwargs) - return partial(_attention_func_constructor, local_attn_func=func) + return partial(_attetion_constructor, attn_impl=attn_impl) diff --git a/internlm/core/parallel/shard.py b/internlm/core/parallel/shard.py index 2f7b0f70..074d551c 100644 --- a/internlm/core/parallel/shard.py +++ b/internlm/core/parallel/shard.py @@ -11,7 +11,7 @@ from internlm.core.context import global_context as gpc from internlm.core.parallel.comm.utils import _split from internlm.utils.logger import get_logger -from internlm.utils.utils import ModelType, TensorParallelMode +from internlm.utils.utils import TensorParallelMode logger = get_logger(__file__) @@ -31,12 +31,6 @@ def split_data_sequence_parallel(data, label): and gpc.config.parallel["tensor"].get("mode", TensorParallelMode.mtp.name) == TensorParallelMode.isp.name ): data["indexes"] = _split(data["indexes"], ParallelMode.TENSOR, dim=_indexes_seq_dim) - if ( - gpc.config.model_type == ModelType.HF.name - and "position_ids" in data - and gpc.config.parallel["tensor"].get("mode", TensorParallelMode.mtp.name) == TensorParallelMode.isp.name - ): - data["position_ids"] = _split(data["position_ids"], ParallelMode.TENSOR, dim=_indexes_seq_dim) data["input_ids"] = _split(data["input_ids"], ParallelMode.TENSOR, dim=_seq_dim) label = _split(label, ParallelMode.TENSOR, dim=_seq_dim) diff --git a/internlm/core/trainer_builder.py b/internlm/core/trainer_builder.py index 0884664e..c285e297 100644 --- a/internlm/core/trainer_builder.py +++ b/internlm/core/trainer_builder.py @@ -12,7 +12,7 @@ from internlm.core.context import global_context as gpc from internlm.core.context.process_group_initializer import ParallelMode from internlm.core.trainer import Trainer -from internlm.data.streaming.utils import hf_simple_resume +from internlm.data.streaming.utils import streaming_simple_resume from internlm.data.train_state import get_train_state from internlm.eval.evaluation import evaluate_on_val_dls from internlm.initialize.initialize_trainer import initialize_trainer @@ -24,6 +24,7 @@ initialize_llm_profile, initialize_optimizer, initialize_parallel_communicator, + inject_model, load_new_batch, record_current_batch_training_metrics, ) @@ -48,13 +49,26 @@ class TrainerBuilder(Trainer): """ - Manage InternEvo training process. + A class for building and managing InternEvo training workflow. + + `TrainerBuilder` extends the base `Trainer` class to include additional functionality + for initializing and managing various components involved in the training process. + This includes setting up logging, checkpoints, loss functions, optimizers, metrics, + train states, and profiling tools. The class supports distributed training and allows + for seamless management of training, evaluation, and checkpointing. Args: model (torch.nn.Module): The model to be trained. - train_dl (torch.utils.data.DataLoader): The training data loader. - val_dls (Optional[Dict[str, torch.utils.data.DataLoader]]): The validation data loaders. - kwargs: Additional keyward arguments. + train_dl (DataLoader): DataLoader for training data. + val_dls (Optional[Dict[str, DataLoader]], optional): DataLoaders for validation data. + **kwargs: Additional keyword arguments including: + - config (str): Path to the configuration file. + - profiling (bool): Whether to enable profiling. + - dataset_types (list): List of dataset types to be used for training. + + Methods: + __init__: Initializes the `TrainerBuilder` with the model, data loaders, and other components. + fit: Runs the training loop, processing batches and handling evaluation and checkpointing. """ def __init__( @@ -65,50 +79,100 @@ def __init__( **kwargs, ): """ - Initialize InternEvo TrainerBuilder class. + Initialize TrainerBuilder with necessary components for training. Args: model (torch.nn.Module): The model to be trained. - train_dl (torch.utils.data.DataLoader): The training data loader. - val_dls (Optional[Dict[str, torch.utils.data.DataLoader]]): The validation data loaders. - kwargs: Additional keyward arguments. + train_dl (DataLoader): DataLoader for training data. + val_dls (Optional[Dict[str, DataLoader]], optional): DataLoaders for validation data. + **kwargs: Additional keyword arguments including: + - config (str): Path to the configuration file. + - profiling (bool): Whether to enable profiling. + - dataset_types (list): List of dataset types to be used for training. + """ + # set very_beginning_time + self.very_beginning_time = time.time() + # broadcast current_time and setup logging + self.current_time = self._setup_time_and_logging() + # load config_lines + config_lines = self._read_config(kwargs["config"]) - # record very_begining_time - very_begining_time = time.time() + # inject model for amp and parallel training + model = inject_model(model) - # set torch expandable_segments + # enable torch expandable_segments enable_pytorch_expandable_segments() - # get and broadcast current time + # initialize loss function + criterion = self._initialize_criterion() + + # initialize isp communicator + isp_communicator = initialize_parallel_communicator(model) + + # initialize train state + train_state = get_train_state(train_dl) + + # initialize optimizer + optimizer, beta2_scheduler, lr_scheduler = initialize_optimizer(model, isp_communicator) + + # initialize checkpoint manager and try resume training + self.ckpt_manager = self._initialize_checkpoint_manager(model, optimizer, lr_scheduler, train_dl, config_lines) + self.ckpt_manager.try_resume_training(train_state, self.current_time) + + # initialize customed llm writer + self.writer = self._initialize_writer(train_state, config_lines) + + # initialize metric for calculating accuracy and perplexity + self.metric = self._initialize_metric(kwargs["dataset_types"]) + + # initialize simple memory profiler + self.memory_profiler = self._initialize_memory_profiler(model, optimizer, kwargs["profiling"]) + + # initialize batch skipper + self.batch_skipper = self._initialize_batch_skipper(train_state) + + # initialize trainer + engine, scheduler = initialize_trainer( + model=model, + optimizer=optimizer, + criterion=criterion, + lr_scheduler=lr_scheduler, + beta2_scheduler=beta2_scheduler, + scheduler_hooks=get_scheduler_hooks(self.metric, optimizer, isp_communicator), + ) + + # set attributes + self._set_attributes( + kwargs["profiling"], train_dl, val_dls, train_state, optimizer, beta2_scheduler, isp_communicator + ) + + super().__init__(engine, scheduler) + + def _setup_time_and_logging(self) -> str: current_time = launch_time() objs = [current_time] dist.broadcast_object_list(objs, src=0) current_time = objs[0].replace(":", ".") global logger logger = get_logger( - __file__, launch_time=current_time, job_name=gpc.config.JOB_NAME, file_name=get_parallel_log_file_name() + __name__, launch_time=current_time, job_name=gpc.config.JOB_NAME, file_name=get_parallel_log_file_name() ) + return current_time - # initialize isp communicator - isp_communicator = initialize_parallel_communicator(model) - - with open(kwargs["config"], "r") as f: - config_lines = f.readlines() + def _read_config(self, config_path: str) -> list: + with open(config_path, "r") as f: + return f.readlines() - # initialize loss function - criterion = FlashGPTLMLoss( + def _initialize_criterion(self) -> FlashGPTLMLoss: + return FlashGPTLMLoss( parallel_output=gpc.config.model.parallel_output, label_smoothing=gpc.config.loss.label_smoothing ) - # initialize and resume train state - train_state = get_train_state(train_dl) - - # initialize optimizer - optimizer, beta2_scheduler, lr_scheduler = initialize_optimizer(model, isp_communicator) - - # initialize checkpoint manager - ckpt_manager = CheckpointManager( + def _initialize_checkpoint_manager( + self, model, optimizer, lr_scheduler, train_dl, config_lines + ) -> CheckpointManager: + return CheckpointManager( ckpt_config=gpc.config.ckpt, model=model, optimizer=optimizer, @@ -119,17 +183,14 @@ def __init__( feishu_address=gpc.config.monitor.alert.feishu_alert_address, ) - # load other persistent training states - ckpt_manager.try_resume_training(train_state, current_time) - - # initialize customed llm writer - writer = Writer( + def _initialize_writer(self, train_state, config_lines) -> Writer: + return Writer( job_name=gpc.config.JOB_NAME, - launch_time=current_time, + launch_time=self.current_time, file_name=get_parallel_log_file_name(), tensorboard_folder=gpc.config.tensorboard_folder, - resume_tb_folder=train_state.resume_tb_folder, # resume from ckpt. - step_count=train_state.step_count, # resume from ckpt. + resume_tb_folder=train_state.resume_tb_folder, + step_count=train_state.step_count, config=config_lines, logger=logger, enable_tb=gpc.config.enable_tb, @@ -137,200 +198,177 @@ def __init__( total_steps=gpc.config.data.total_steps, ) - # initialize metric for calculating accuracy and perplexity + def _initialize_metric(self, dataset_types) -> AccPerplex: _dp_pg = gpc.get_group(ParallelMode.ISP_DATA) if is_using_isp() else gpc.get_group(ParallelMode.DATA) _tp_pg = dist.new_group([gpc.get_global_rank()]) if is_using_isp() else gpc.get_group(ParallelMode.TENSOR) - metric = AccPerplex( + return AccPerplex( device=get_current_device(), tp_pg=_tp_pg, dp_pg=_dp_pg, - dataset_types=kwargs["dataset_types"], + dataset_types=dataset_types, ) - # initialize simple memory profiler - if kwargs["profiling"]: - self.memory_profiler = SimpleMemoryProfiler( + def _initialize_memory_profiler(self, model, optimizer, profiling) -> Optional[SimpleMemoryProfiler]: + if profiling: + return SimpleMemoryProfiler( model, optimizer.optim, - log_folder=f"RUN/{gpc.config.JOB_NAME}/{current_time}/memory_trace/rank{gpc.get_global_rank()}_" + log_folder=f"RUN/{gpc.config.JOB_NAME}/{self.current_time}/memory_trace/rank{gpc.get_global_rank()}_" + f"dp{gpc.get_local_rank(ParallelMode.DATA)}_" + f"wp{gpc.get_local_rank(ParallelMode.WEIGHT)}_" + f"tp{gpc.get_local_rank(ParallelMode.TENSOR)}", ) else: - self.memory_profiler = None + return None - # initialize batch skipper + def _initialize_batch_skipper(self, train_state) -> BatchSkipper: skip_batches = gpc.config.data.skip_batches - if gpc.config.data.type == DataType.hf.name and gpc.config.ckpt.auto_resume: - skip_batches = hf_simple_resume(train_state) - self.batch_skipper = BatchSkipper(skip_batches) - - # set TrainerBuilder attributes - self.very_begining_time = very_begining_time - self.profiling = kwargs["profiling"] - self.current_time = current_time + if gpc.config.data.type == DataType.tokenized.name and gpc.config.ckpt.auto_resume: + skip_batches = streaming_simple_resume(train_state) + return BatchSkipper(skip_batches) + + def _set_attributes(self, profiling, train_dl, val_dls, train_state, optimizer, beta2_scheduler, isp_communicator): + self.profiling = profiling self.train_dl = train_dl self.val_dls = val_dls self.train_state = train_state self.optimizer = optimizer self.beta2_scheduler = beta2_scheduler self.isp_communicator = isp_communicator - self.writer = writer - self.ckpt_manager = ckpt_manager - self.metric = metric - - # initialize trainer - engine, scheduler = initialize_trainer( - model=model, - optimizer=optimizer, - criterion=criterion, - lr_scheduler=lr_scheduler, - beta2_scheduler=beta2_scheduler, - scheduler_hooks=get_scheduler_hooks(metric, optimizer, isp_communicator), - ) - - super().__init__(engine, scheduler) def fit(self): """ - Launch InternEvo TrainerBuilder training process. + Run InternEvo training loop. """ - self.train() train_iter = iter(self.train_dl) with initialize_llm_profile(profiling=self.profiling, start_time=self.current_time) as prof: - # close automatic garbage collection gc.disable() - # start iterating the train data and begin training for batch_count in range(self.train_state.batch_count, gpc.config.data.total_steps): - empty_cache_and_diag(batch_count, interval=gpc.config.data.empty_cache_and_diag_interval) - # internlm_accelerator.memory._record_memory_history() - start_time = time.time() - timer("one-batch").start() - - # load batch data - batch, train_iter = load_new_batch( - train_dl=self.train_dl, train_iter=train_iter, train_state=self.train_state - ) + if self._process_batch(batch_count, train_iter, prof): + break - # record the consumed samples in training - self.train_state.batch_count = batch_count - self.train_state.num_consumed_samples_in_epoch += len(batch[1]) - if self.batch_skipper(batch_count): # skip this batch - if gpc.is_rank_for_log(): - logger.info(f"Skip batch count:`{batch_count}`...") - timer("one-batch").stop() - continue - - # zero the grads of parameters - self.zero_grad() - # process data - if batch[0].get("type_ids", None) is not None: - self.metric.set_current_type_ids(type_ids=batch[0].pop("type_ids", None)) - # if batch[0].get("cu_seqlens", None) is not None: - # metric.set_cu_seqlens(cu_seqlens=batch[0].pop("cu_seqlens", None)) - - # do forward and backward - timer("fwd-bwd").start() - - moe_loss = None - if hasattr(gpc.config.model, "num_experts"): - _, _, loss, moe_loss = self.execute_schedule( - batch, - forward_only=False, - return_loss=True, - return_output_label=False, - ) - else: - _, _, loss = self.execute_schedule( # pylint: disable=W0632 - batch, - forward_only=False, - return_loss=True, - return_output_label=False, - ) - timer("fwd-bwd").stop() - - if self.isp_communicator and self.isp_communicator.enable_memory_pool: - self.isp_communicator.memory_pool.reset_lazy_pools() - - # update parameters, and returns (success_update, grad_norm) - trainer_result = self.step() - assert trainer_result is not None - - success_update, grad_norm_groups = trainer_result - if success_update: # update parameters successfully - self.train_state.step_count += 1 - else: - self.train_state.inf_nan_skip_batches += ( - 1 # record the amount of updating parameters unsuccessfully. - ) - if -1 in grad_norm_groups.values() and gpc.is_rank_for_log(): # -1 encodes a specific failure case - logger.warning(f"Warning: skip parameter update at step {batch_count}.") - send_alert_message( - address=gpc.config.monitor.alert.feishu_alert_address, - message=f"Warning: skip parameter update at step {batch_count}.", - ) - - get_tflops_func = partial( - get_megatron_flops, - checkpoint=gpc.config.model.checkpoint, - seq_len=gpc.config.data["seq_len"], - hidden_size=gpc.config.model.hidden_size, - num_layers=gpc.config.model.num_layers, - vocab_size=gpc.config.model.vocab_size, - global_batch_size=gpc.config.data.micro_bsz - * gpc.config.data.micro_num - * gpc.get_world_size(ParallelMode.DATA), - global_world_size=gpc.get_world_size(ParallelMode.GLOBAL), - mlp_ratio=gpc.config.model["mlp_ratio"], - ) + self.ckpt_manager.wait_async_upload_finish() - # calculate and record the training metrics, eg. loss, accuracy and so on. - record_current_batch_training_metrics( - get_tflops_func=get_tflops_func, - logger=logger, - writer=self.writer, - success_update=success_update, - batch_count=batch_count, - batch=batch, - train_state=self.train_state, - optimizer=self.optimizer, - beta2_scheduler=self.beta2_scheduler, - trainer=self, - start_time=start_time, - very_begining_time=self.very_begining_time, - loss=loss, - moe_loss=moe_loss, - grad_norm=grad_norm_groups, - metric=self.metric, + def _process_batch(self, batch_count: int, train_iter, prof) -> bool: + empty_cache_and_diag(batch_count, interval=gpc.config.data.empty_cache_and_diag_interval) + start_time = time.time() + timer("one-batch").start() + + batch, train_iter = self._load_and_prepare_batch(batch_count, train_iter) + if self.batch_skipper(batch_count): + if gpc.is_rank_for_log(): + logger.info(f"Skip batch count:`{batch_count}`...") + timer("one-batch").stop() + return False + + timer("fwd-bwd").start() + loss, moe_loss = self._forward_backward(batch) + timer("fwd-bwd").stop() + + if self.isp_communicator and self.isp_communicator.enable_memory_pool: + self.isp_communicator.memory_pool.reset_lazy_pools() + + success_update, grad_norm_groups = self._update_parameters() + self._record_metrics(batch_count, batch, start_time, loss, moe_loss, success_update, grad_norm_groups) + timer("one-batch").stop() + + if self._should_evaluate(): + self._evaluate() + + if self.ckpt_manager.try_save_checkpoint(self.train_state): + return True + + self._update_profilers(batch_count, prof) + return False + + def _load_and_prepare_batch(self, batch_count: int, train_iter): + batch, train_iter = load_new_batch(train_dl=self.train_dl, train_iter=train_iter, train_state=self.train_state) + self.train_state.batch_count = batch_count + self.train_state.num_consumed_samples_in_epoch += len(batch[1]) + if batch[0].get("type_ids", None) is not None: + self.metric.set_current_type_ids(type_ids=batch[0].pop("type_ids", None)) + return batch, train_iter + + def _forward_backward(self, batch): + self.zero_grad() + if hasattr(gpc.config.model, "num_experts"): + _, _, loss, moe_loss = self.execute_schedule( + batch, forward_only=False, return_loss=True, return_output_label=False + ) + else: + _, _, loss = self.execute_schedule(batch, forward_only=False, return_loss=True, return_output_label=False) + moe_loss = None + return loss, moe_loss + + def _update_parameters(self): + trainer_result = self.step() + assert trainer_result is not None + success_update, grad_norm_groups = trainer_result + if success_update: + self.train_state.step_count += 1 + else: + self.train_state.inf_nan_skip_batches += 1 + if -1 in grad_norm_groups.values() and gpc.is_rank_for_log(): + logger.warning(f"Warning: skip parameter update at step {self.train_state.batch_count}.") + send_alert_message( + address=gpc.config.monitor.alert.feishu_alert_address, + message=f"Warning: skip parameter update at step {self.train_state.batch_count}.", ) + return success_update, grad_norm_groups + + def _record_metrics(self, batch_count: int, batch, start_time, loss, moe_loss, success_update, grad_norm_groups): + get_tflops_func = partial( + get_megatron_flops, + checkpoint=gpc.config.model.checkpoint, + seq_len=gpc.config.data["seq_len"], + hidden_size=gpc.config.model.hidden_size, + num_layers=gpc.config.model.num_layers, + vocab_size=gpc.config.model.vocab_size, + global_batch_size=gpc.config.data.micro_bsz + * gpc.config.data.micro_num + * gpc.get_world_size(ParallelMode.DATA), + global_world_size=gpc.get_world_size(ParallelMode.GLOBAL), + mlp_ratio=gpc.config.model["mlp_ratio"], + ) + record_current_batch_training_metrics( + get_tflops_func=get_tflops_func, + logger=logger, + writer=self.writer, + success_update=success_update, + batch_count=batch_count, + batch=batch, + train_state=self.train_state, + optimizer=self.optimizer, + beta2_scheduler=self.beta2_scheduler, + engine=self.engine, + start_time=start_time, + very_begining_time=self.very_beginning_time, + loss=loss, + moe_loss=moe_loss, + grad_norm=grad_norm_groups, + metric=self.metric, + ) - timer("one-batch").stop() - - # evaluate on validation data loaders - if gpc.config.data.valid_every > 0 and self.train_state.step_count % gpc.config.data.valid_every == 0: - evaluate_on_val_dls( - self, - val_dls=self.val_dls, - writer=self.writer, - logger=logger, - step_count=self.train_state.step_count, - ) - - # checkpoint the training states in specific steps, which is determined by the args "checkpoint_every" - # # save batch sampler that tracks the true consumed samples - now_break = self.ckpt_manager.try_save_checkpoint(self.train_state) - if now_break: - break - - if self.memory_profiler is not None: - self.memory_profiler.step() - - if batch_count % 2 == 0: - prof.step() + def _should_evaluate(self) -> bool: + return ( + gpc.config.data.valid_every > 0 + and self.train_state.step_count > 0 + and self.train_state.step_count % gpc.config.data.valid_every == 0 + ) - # internlm_accelerator.memory._dump_snapshot(f"my_snapshot_{gpc.get_global_rank()}.pickle") + def _evaluate(self): + evaluate_on_val_dls( + self, + val_dls=self.val_dls, + writer=self.writer, + logger=logger, + step_count=self.train_state.step_count, + ) - self.ckpt_manager.wait_async_upload_finish() + def _update_profilers(self, batch_count: int, prof): + if self.memory_profiler is not None: + self.memory_profiler.step() + if batch_count % 2 == 0: + prof.step() diff --git a/internlm/data/build_dataloader.py b/internlm/data/build_dataloader.py index 15781b15..ae0452b1 100644 --- a/internlm/data/build_dataloader.py +++ b/internlm/data/build_dataloader.py @@ -7,10 +7,10 @@ from internlm.core.context import ParallelMode from internlm.core.context import global_context as gpc from internlm.data.streaming.batch_sampler import StreamingStaticBatchSampler -from internlm.data.streaming.collaters import pack_collate_fn +from internlm.data.streaming.collaters import streaming_packed_collate_fn from internlm.data.streaming.dataset import ( - HuggingFacePackedDataset, - HuggingFaceStreamingDataset, + StreamingDataset, + StreamingPackedDatasetWithCut, ) from internlm.data.tokenized.batch_sampler import ( StaticBatchSampler, @@ -119,15 +119,15 @@ def get_tokenized_valid_loader_items(data_cfg): return valid_ds, valid_collate_fn -def get_hf_train_loader_items(data_cfg): - assert not data_cfg.pack_sample_into_one, "hf dataloader curently only supports pack_sample_into_one=False" - train_ds = HuggingFaceStreamingDataset( +def get_streaming_train_loader_items(data_cfg): + assert not data_cfg.pack_sample_into_one, "streaming dataloader curently only supports pack_sample_into_one=False" + train_ds = StreamingDataset( dataset_name=data_cfg.train_folder, tokenizer_name=data_cfg.tokenizer_path, model_max_length=data_cfg.seq_len, subset_name=data_cfg.get("subset_name", None), ) - train_ds = HuggingFacePackedDataset( + train_ds = StreamingPackedDatasetWithCut( dataset=train_ds, seq_len=data_cfg.seq_len, micro_bsz=data_cfg.micro_bsz, @@ -137,7 +137,10 @@ def get_hf_train_loader_items(data_cfg): batch_size=data_cfg.micro_num, rampup_batch_size=data_cfg.rampup_batch_size ) train_collate_fn = partial( - pack_collate_fn, micro_num=data_cfg.micro_num, micro_bsz=data_cfg.micro_bsz, seq_len=data_cfg.seq_len + streaming_packed_collate_fn, + micro_num=data_cfg.micro_num, + micro_bsz=data_cfg.micro_bsz, + seq_len=data_cfg.seq_len, ) return train_ds, train_sampler, train_collate_fn @@ -150,14 +153,10 @@ def build_train_loader_with_data_type(): """ data_cfg = gpc.config.data - train_folder = data_cfg.get("train_folder", None) - if data_cfg.type == DataType.tokenized.name: train_ds, train_sampler, train_collate_fn = get_tokenized_train_loader_items(data_cfg) - dataset_types = list(get_dataset_type_ids_map(train_folder).keys()) if train_folder else ["en", "cn", "code"] - elif data_cfg.type == DataType.hf.name: - train_ds, train_sampler, train_collate_fn = get_hf_train_loader_items(data_cfg) - dataset_types = ["en"] + elif data_cfg.type == DataType.streaming.name: + train_ds, train_sampler, train_collate_fn = get_streaming_train_loader_items(data_cfg) else: raise ValueError(f"dataset type {data_cfg.type} is not supported") @@ -171,6 +170,9 @@ def build_train_loader_with_data_type(): persistent_workers=data_cfg.get("num_worker", 4) > 0, ) + train_folder = data_cfg.get("train_folder", None) + dataset_types = list(get_dataset_type_ids_map(train_folder).keys()) if train_folder else ["en", "cn", "code"] + return train_dl, dataset_types @@ -179,7 +181,7 @@ def build_valid_loader_with_data_type(): data_cfg = gpc.config.data - if data_cfg.type in [DataType.tokenized.name, DataType.hf.name]: + if data_cfg.type in [DataType.tokenized.name, DataType.streaming.name]: valid_ds, valid_collate_fn = get_tokenized_valid_loader_items(data_cfg) else: raise ValueError(f"dataset type {data_cfg.type} is not supported") diff --git a/internlm/data/streaming/__init__.py b/internlm/data/streaming/__init__.py index a1f58aa0..98b99791 100644 --- a/internlm/data/streaming/__init__.py +++ b/internlm/data/streaming/__init__.py @@ -1,12 +1,12 @@ from .batch_sampler import StreamingStaticBatchSampler -from .collaters import pack_collate_fn -from .dataset import HuggingFacePackedDataset, HuggingFaceStreamingDataset -from .utils import hf_simple_resume +from .collaters import streaming_packed_collate_fn +from .dataset import StreamingDataset, StreamingPackedDatasetWithCut +from .utils import streaming_simple_resume __all__ = [ "StreamingStaticBatchSampler", - "pack_collate_fn", - "HuggingFaceStreamingDataset", - "HuggingFacePackedDataset", - "hf_simple_resume", + "streaming_packed_collate_fn", + "StreamingDataset", + "StreamingPackedDatasetWithCut", + "streaming_simple_resume", ] diff --git a/internlm/data/streaming/collaters.py b/internlm/data/streaming/collaters.py index 39ee2394..bbce69c4 100644 --- a/internlm/data/streaming/collaters.py +++ b/internlm/data/streaming/collaters.py @@ -1,7 +1,7 @@ import torch -def pack_collate_fn(batch, micro_num, micro_bsz, seq_len): +def streaming_packed_collate_fn(batch, micro_num, micro_bsz, seq_len): packed_length = micro_bsz * seq_len input_ids_list = [] diff --git a/internlm/data/streaming/dataset.py b/internlm/data/streaming/dataset.py index 68178736..c565571d 100644 --- a/internlm/data/streaming/dataset.py +++ b/internlm/data/streaming/dataset.py @@ -11,9 +11,9 @@ from transformers import AutoTokenizer -class HuggingFaceStreamingDataset(Dataset): +class StreamingDataset(Dataset): """ - Streaming and on-the-fly tokenized dataset for huggingface + Streaming and on-the-fly tokenized dataset """ def __init__( @@ -47,7 +47,6 @@ def _tokenize(self, samples): texts = [sample["text"] for sample in samples] tokenized_outputs = self.tokenizer(texts, truncation=True) for i in range(len(samples)): - assert "input_ids" in tokenized_outputs, "huggingface tokenizer should generate input_ids" if len(tokenized_outputs["input_ids"][i]) > 0: yield {key: tokenized_outputs[key][i] for key in tokenized_outputs} @@ -55,9 +54,9 @@ def __getitem__(self, _): return next(self.senior_iterator) -class HuggingFacePackedDataset(Dataset): +class StreamingPackedDatasetWithCut(Dataset): """ - Simple packed dataset for huggingface + Packed dataset for streaming """ def __init__(self, dataset, seq_len, micro_bsz, pad_token_id=0): diff --git a/internlm/data/streaming/utils.py b/internlm/data/streaming/utils.py index ee331ba2..adf63124 100644 --- a/internlm/data/streaming/utils.py +++ b/internlm/data/streaming/utils.py @@ -4,11 +4,11 @@ from internlm.core.context import global_context as gpc -# simple auto_resume for huggingface streaming dataloader -def hf_simple_resume(train_state): +# simple auto_resume for streaming dataloader +def streaming_simple_resume(train_state): skip_batches = gpc.config.data.get("skip_batches", "") if train_state.batch_count > 0: - assert skip_batches == "", "skip_batches should be empty when huggingface dataloader resume from ckpts" + assert skip_batches == "", "skip_batches should be empty when streaming dataloader resume from ckpts" skip_batches = f"0-{train_state.batch_count - 1}" train_state.batch_count = 0 train_state.num_consumed_samples_in_epoch = 0 diff --git a/internlm/data/train_state.py b/internlm/data/train_state.py index 776cd04e..aa061b82 100644 --- a/internlm/data/train_state.py +++ b/internlm/data/train_state.py @@ -6,7 +6,7 @@ def get_train_state(dataloader): # initialize and resume train state - if gpc.config.data.type in [DataType.tokenized.name, DataType.hf.name]: + if gpc.config.data.type in [DataType.tokenized.name, DataType.streaming.name]: train_state = TrainState(gpc.config, dataloader.batch_sampler) else: raise ValueError(f"dataset type {gpc.config.data.type} is not supported") diff --git a/internlm/data/utils.py b/internlm/data/utils.py index 63cc0fbf..b05c9ad0 100644 --- a/internlm/data/utils.py +++ b/internlm/data/utils.py @@ -6,8 +6,6 @@ import torch from internlm.core.context import global_context as gpc -from internlm.core.context.process_group_initializer import ParallelMode -from internlm.utils.utils import ModelType def get_dataset_type_ids_map(path): @@ -70,9 +68,4 @@ def packed_data_normalizer(data, label): data["cu_seqlens"] = data["cu_seqlens"][0].squeeze(0) data["max_seqlen"] = (data["cu_seqlens"][1:] - data["cu_seqlens"][:-1]).max().item() - if gpc.config.model_type == ModelType.HF.name: - gpc.config.data[f"cu_seqlens_data_rank{gpc.get_local_rank(ParallelMode.DATA)}"] = data.pop("cu_seqlens") - gpc.config.data[f"max_seqlen_data_rank{gpc.get_local_rank(ParallelMode.DATA)}"] = data.pop("max_seqlen") - data["position_ids"] = data.pop("indexes") - return data, label diff --git a/internlm/model/builder.py b/internlm/model/builder.py index b50a1fdb..f5516f2d 100644 --- a/internlm/model/builder.py +++ b/internlm/model/builder.py @@ -5,19 +5,12 @@ from internlm.core.context import ParallelMode from internlm.core.context import global_context as gpc from internlm.core.parallel.shard import pipeline_parallel_sharding_wrapper -from internlm.model.registry import hf_config_initializer, model_initializer -from internlm.model.utils import convert_hf_config +from internlm.model.registry import model_initializer from internlm.utils.common import get_current_device -from internlm.utils.utils import ModelType def create_model(model_type) -> Union[nn.Module, List[nn.Module]]: - if model_type == ModelType.HF.name: - extra_kwargs = {"return_dict": False, "attn_implementation": "flash_attention_2"} - config = hf_config_initializer.get_module(module_name=model_type)(**extra_kwargs) - convert_hf_config(config) - kwargs = dict(gpc.config.model) num_layers = kwargs.pop("num_layers") @@ -34,13 +27,10 @@ def create_model(model_type) -> Union[nn.Module, List[nn.Module]]: model_buidler = model_initializer.get_module(module_name=model_type) if not gpc.is_using_parallel_mode(ParallelMode.PIPELINE): - if model_type == ModelType.HF.name: - model = model_buidler(config).to(kwargs["device"]) - else: - kwargs["first"] = kwargs["last"] = True - kwargs["start_layer_idx"] = 0 - kwargs["num_layers"] = num_layers - model = model_buidler(**kwargs).to(kwargs["device"]) + kwargs["first"] = kwargs["last"] = True + kwargs["start_layer_idx"] = 0 + kwargs["num_layers"] = num_layers + model = model_buidler(**kwargs).to(kwargs["device"]) setattr(model, "first_layer", 0) setattr(model, "last_layer", num_layers) else: diff --git a/internlm/model/registry.py b/internlm/model/registry.py index affd0ab6..a1921ab6 100644 --- a/internlm/model/registry.py +++ b/internlm/model/registry.py @@ -8,10 +8,11 @@ from internlm.model.modeling_llama import Llama2 from internlm.model.modeling_llava import Llava from internlm.model.modeling_moe import Internlm1MoE +from internlm.utils.common import SingletonMeta from internlm.utils.utils import ModelType -class Registry: +class Registry(metaclass=SingletonMeta): """This is a registry class used to register classes and modules so that a universal object builder can be enabled. @@ -34,13 +35,12 @@ def register_module(self, module_name: str, func: Callable): module_name (str): The name of module to be registered. Returns: function: The module to be registered, so as to use it normally if via importing. - Raises: - AssertionError: Raises an AssertionError if the module has already been registered before. """ - assert module_name not in self._registry, f"{module_name} already registered in {self.name}" - - self._registry[module_name] = func + if self.has(module_name): + return + else: + self._registry[module_name] = func def get_module(self, module_name: str): """Retrieves a module with name `module_name` and returns the module if it has @@ -54,9 +54,10 @@ def get_module(self, module_name: str): NameError: Raises a NameError if the module to be retrieved has neither been registered directly nor as third party modules before. """ - if module_name in self._registry: + if self.has(module_name): return self._registry[module_name] - raise NameError(f"Module {module_name} not found in the registry {self.name}") + else: + raise NameError(f"Module {module_name} not found in the registry {self.name}") def has(self, module_name: str): """Searches for a module with name `module_name` and returns a boolean value indicating @@ -74,7 +75,6 @@ def has(self, module_name: str): model_initializer = Registry("model_initializer") -hf_config_initializer = Registry("hf_config_initializer") def register_model_initializer() -> None: @@ -83,3 +83,6 @@ def register_model_initializer() -> None: model_initializer.register_module(ModelType.LLAMA2.name, Llama2) model_initializer.register_module(ModelType.INTERNLM_MoE.name, Internlm1MoE) model_initializer.register_module(ModelType.LLAVA.name, Llava) + + +register_model_initializer() diff --git a/internlm/model/utils.py b/internlm/model/utils.py index e4a40dab..c2311007 100644 --- a/internlm/model/utils.py +++ b/internlm/model/utils.py @@ -1,6 +1,5 @@ from typing import Any, Dict, List -from internlm.core.context import global_context as gpc from internlm.model.modules.mha import MHA @@ -52,15 +51,3 @@ def convert_attn_args_to_kwargs(args, kwargs) -> Dict[str, Any]: kwargs["max_seqlen"] = args[3] return kwargs - - -def convert_hf_config(config): - gpc.config.model.vocab_size = gpc.config.VOCAB_SIZE = config.vocab_size - gpc.config.model.hidden_size = gpc.config.HIDDEN_SIZE = config.hidden_size - gpc.config.model.num_layers = gpc.config.NUM_LAYER = config.num_hidden_layers - gpc.config.model.num_attention_heads = gpc.config.NUM_ATTENTION_HEAD = config.num_attention_heads - gpc.config.model.mlp_ratio = gpc.config.MLP_RATIO = config.intermediate_size / config.hidden_size - - # For models that use GQA - if hasattr(config, "num_key_value_heads"): - gpc.config.model.num_kv_attention_heads = gpc.config.NUM_KV_ATTENTION_HEAD = config.num_key_value_heads diff --git a/internlm/train/pipeline.py b/internlm/train/pipeline.py index 47ab70ce..dcded6f8 100644 --- a/internlm/train/pipeline.py +++ b/internlm/train/pipeline.py @@ -93,6 +93,8 @@ except (ImportError, ModuleNotFoundError): pass +IS_INJECTED = "is_injected" + logger = get_logger(__file__) internlm_accelerator = get_accelerator() @@ -163,7 +165,6 @@ def _check_module(name, module): def initialize_model(pre_process_func: Optional[Callable] = None, post_process_func: Optional[Callable] = None): """ Initialize model with Automatic Mixed Precision. - Returns: torch.nn.Module: The neural network model to be trained or evaluated. @@ -178,6 +179,25 @@ def initialize_model(pre_process_func: Optional[Callable] = None, post_process_f if post_process_func: post_process_func(pre_process_output) + return inject_model(model) + + +def inject_model(model): + """ + Inject model with Automatic Mixed Precision. + + Args: + torch.nn.Module: + The bare neural network model to be trained or evaluated. + + Returns: + torch.nn.Module: + The injected neural network model to be trained or evaluated. + """ + + if hasattr(model, IS_INJECTED) and getattr(model, IS_INJECTED): + return model + # should be set before NaiveAMPModel set_fp32_attr_for_model(model) @@ -217,6 +237,9 @@ def initialize_model(pre_process_func: Optional[Callable] = None, post_process_f random_mode = ParallelMode.WEIGHT_DATA if is_using_isp() else ParallelMode.DATA set_mode(random_mode) + # set is_injected flag + setattr(model, "IS_INJECTED", True) + return model @@ -525,7 +548,7 @@ def record_current_batch_training_metrics( train_state, optimizer, beta2_scheduler, - trainer, + engine, start_time, very_begining_time, loss, @@ -547,10 +570,10 @@ def record_current_batch_training_metrics( if success_update and gpc.is_rank_for_log(): lr = optimizer.param_groups[0]["lr"] - if hasattr(trainer.engine.optimizer, "grad_scaler"): - scaler = trainer.engine.optimizer.grad_scaler._scale.item() - elif hasattr(trainer.engine.optimizer.optim, "grad_scaler"): - scaler = trainer.engine.optimizer.optim.grad_scaler._scale.item() + if hasattr(engine.optimizer, "grad_scaler"): + scaler = engine.optimizer.grad_scaler._scale.item() + elif hasattr(engine.optimizer.optim, "grad_scaler"): + scaler = engine.optimizer.optim.grad_scaler._scale.item() num_tokens_in_batch = batch[1].nelement() real_num_tokens = math.ceil(acc_perplex.pop("real_token_num") / gpc.get_world_size(ParallelMode.GLOBAL)) diff --git a/internlm/utils/utils.py b/internlm/utils/utils.py index 38080b1e..03dee6df 100644 --- a/internlm/utils/utils.py +++ b/internlm/utils/utils.py @@ -42,16 +42,15 @@ def __str__(self) -> str: class ModelType(Enum): - HF = 1 - INTERNLM = 2 - INTERNLM2_PUBLIC = 3 - LLAMA2 = 4 - INTERNLM_MoE = 5 - LLAVA = 6 + INTERNLM = 1 + INTERNLM2_PUBLIC = 2 + LLAMA2 = 3 + INTERNLM_MoE = 4 + LLAVA = 5 class DataType(Enum): - hf = 1 + streaming = 1 tokenized = 2 diff --git a/tests/test_training/7B_check_init.py b/tests/test_training/7B_check_init.py index c682cf4e..b13f254a 100644 --- a/tests/test_training/7B_check_init.py +++ b/tests/test_training/7B_check_init.py @@ -32,7 +32,7 @@ # 1. the 'path' indicate ckpt path, # 2. the 'content‘ means what states will be loaded, support: "model", "sampler", "optimizer", "scheduler", "all" # 3. the ’ckpt_type‘ means the type of checkpoint to be loaded, now only 'normal' type is supported. - # load_ckpt_info=dict(path=MODEL_ONLY_FOLDER, content=("all",), ckpt_type="internlm"), + # load_ckpt_info=dict(path=MODEL_ONLY_FOLDER, content=("all",), ckpt_type="internevo"), # 'auto_resume' is designed to automatically load the latest checkpoint from 'save_ckpt_folder' when encountering # training interruptions/hangs caused by hardware failures, using a scheduling system (such as k8s/slurm) # with an automatic restart mechanism upon training reboot. @@ -144,18 +144,7 @@ use_flash_attn=True, num_chunks=1, # if num_chunks > 1, interleaved pipeline scheduler is used. ) -""" -zero1 parallel: - 1. if zero1 <= 0, The size of the zero process group is equal to the size of the dp process group, - so parameters will be divided within the range of dp. - 2. if zero1 == 1, zero is not used, and all dp groups retain the full amount of model parameters. - 3. zero1 > 1 and zero1 <= dp world size, the world size of zero is a subset of dp world size. - For smaller models, it is usually a better choice to split the parameters within nodes with a setting <= 8. -pipeline parallel (dict): - 1. size: int, the size of pipeline parallel. - 2. interleaved_overlap: bool, enable/disable communication overlap when using interleaved pipeline scheduler. -tensor parallel: tensor parallel size, usually the number of GPUs per node. -""" + parallel = dict( zero1=dict(size=-1), tensor=dict(size=2, mode="mtp"), diff --git a/train.py b/train.py index 08534420..acfacdf3 100644 --- a/train.py +++ b/train.py @@ -8,15 +8,15 @@ build_valid_loader_with_data_type, ) from internlm.initialize import initialize_distributed_env +from internlm.model.builder import create_model from internlm.monitor import internevo_monitor -from internlm.train import initialize_model from internlm.utils.common import parse_args @internevo_monitor(feishu_alert=True, clean_run=True) def main(args): # initialize model - model = initialize_model() + model = create_model(model_type=gpc.config.model_type) # initialize train dataloader train_dl, dataset_types = build_train_loader_with_data_type() @@ -24,11 +24,8 @@ def main(args): # initialize validation dataloader val_dls = build_valid_loader_with_data_type() - # initialize kwargs - kwargs = vars(args) | {"dataset_types": dataset_types} - # build trainer - trainer = TrainerBuilder(model, train_dl, val_dls, **kwargs) + trainer = TrainerBuilder(model, train_dl, val_dls, **(vars(args) | {"dataset_types": dataset_types})) # training trainer.fit()