diff --git a/configs/7B_internlm2.py b/configs/7B_internlm2.py index a69896ce..891885c3 100644 --- a/configs/7B_internlm2.py +++ b/configs/7B_internlm2.py @@ -1,5 +1,5 @@ JOB_NAME = "7b_internlm2_train" -model_type="INTERNLM2_PUBLIC" +model_type = "INTERNLM2_PUBLIC" DO_ALERT = False VOCAB_SIZE = 92544 @@ -205,3 +205,18 @@ # metric_dtype can be "fp32" or other string # only when set to "fp32" will use fp32 to calc in metrics # metric_dtype = "fp32" + +generation = dict( + ckpt_folder="/path/to/saved/ckpt", + output_folder="/path/to/save/generation", + batch_size=1, + eos_id=[2, 0], + bos_id=1, + max_length=100, + do_sample=True, + temperature=1.0, + top_k=50, + top_p=1.0, + repetition_penalty=1, + length_penalty=1.0, +) diff --git a/configs/_base_/models/internlm2_1B.py b/configs/_base_/models/internlm2_1B.py index ff0569d3..7d063919 100644 --- a/configs/_base_/models/internlm2_1B.py +++ b/configs/_base_/models/internlm2_1B.py @@ -25,7 +25,7 @@ mlp_ratio=MLP_RATIO, multiple_of=MULTIPLE_OF, norm_type="rmsnorm", - adapt_hf=True, + qk_interleaved=False, apply_post_layer_norm=False, no_bias=True, layer_norm_epsilon=1e-5, diff --git a/configs/_base_/models/internlm2_20B.py b/configs/_base_/models/internlm2_20B.py index 82b06249..1347b98f 100644 --- a/configs/_base_/models/internlm2_20B.py +++ b/configs/_base_/models/internlm2_20B.py @@ -23,7 +23,7 @@ num_kv_attention_heads=NUM_KV_ATTENTION_HEAD, mlp_ratio=MLP_RATIO, norm_type="rmsnorm", - adapt_hf=True, + qk_interleaved=False, apply_post_layer_norm=False, no_bias=True, layer_norm_epsilon=1e-5, diff --git a/configs/_base_/models/internlm2_7B.py b/configs/_base_/models/internlm2_7B.py index 81f5acd4..94cae4b3 100644 --- a/configs/_base_/models/internlm2_7B.py +++ b/configs/_base_/models/internlm2_7B.py @@ -23,7 +23,7 @@ num_kv_attention_heads=NUM_KV_ATTENTION_HEAD, mlp_ratio=MLP_RATIO, norm_type="rmsnorm", - adapt_hf=False, + qk_interleaved=True, apply_post_layer_norm=False, no_bias=True, layer_norm_epsilon=1e-5, diff --git a/doc/usage.md b/doc/usage.md index 78b92960..ad78fe2e 100644 --- a/doc/usage.md +++ b/doc/usage.md @@ -459,6 +459,31 @@ $ torchrun --nnodes=1 --nproc_per_node=8 train.py --config ./configs/7B_sft.py - 2023-07-07 12:29:16,994 INFO train.py:323 in record_current_batch_training_metrics -- tflops=189.3109313713174,step=5,loss=9.822169303894043,tgs (tokens/gpu/second)=4262.67,lr=1.4000000000000001e-06,loss_scale=65536.0,grad_norm=47.10386835560855,micro_num=4,num_consumed_tokens=786432,inf_nan_skip_batches=0,num_samples_in_batch=17,largest_length=2048,largest_batch=6,smallest_batch=3,adam_beta2=0.95,fwd_bwd_time=3.69 ``` +### 加载训练的checkpoint并生成 + +若在 slurm 上启动分布式运行环境,多节点 16 卡的运行命令如下所示: +```bash +$ srun -p internllm -N 2 -n 16 --ntasks-per-node=8 --gpus-per-task=1 python generate.py --config ./configs/7B_sft.py +``` + +在配置文件中添加`generation`配置 +``` +generation = dict( + ckpt_folder="/path/to/saved/ckpt", + output_folder="/path/to/save/generation", + batch_size=1, + eos_id=[2, 0], + bos_id=1, + max_length=100, + do_sample=True, + temperature=1.0, + top_k=50, + top_p=1.0, + repetition_penalty=1, + length_penalty=1.0, +) +``` + ### 长文本生成 在推理阶段,我们可以使用 Dynamic NTK RoPE 来代替原始的 RoPE,从而使得模型能够适应长文本的输入输出,达到 16K 的外推效果。 diff --git a/generate.py b/generate.py new file mode 100644 index 00000000..4ae76029 --- /dev/null +++ b/generate.py @@ -0,0 +1,251 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +import gc +import json +import logging +import os +import shutil +import socket +import traceback +from pathlib import Path + +import numpy as 
np +import torch +from tqdm import tqdm + +from internlm.accelerator import get_accelerator +from internlm.apis.inference import SequenceGenerator +from internlm.core.context import global_context as gpc +from internlm.data import build_generation_loader_with_data_type +from internlm.initialize import initialize_distributed_env +from internlm.monitor import initialize_monitor_manager +from internlm.monitor.monitor import monitor_manager as mm +from internlm.train import initialize_model, initialize_parallel_communicator +from internlm.utils.common import ( + enable_pytorch_expandable_segments, + launch_time, + parse_args, +) +from internlm.utils.gputest import empty_cache_and_diag +from internlm.utils.logger import get_logger +from internlm.utils.megatron_timers import megatron_timer as timer +from internlm.utils.parallel import get_parallel_log_file_name +from internlm.utils.storage_manager import init_storage_manager +from tools.load_internlm2_model import get_model_device, merge_pp_within_tp + +# global llm logger +logger = logging.getLogger(__file__) +internlm_accelerator = get_accelerator() + + +def get_latest_subdirectory(folder_path): + if ":" in folder_path: + prefix, folder_path = folder_path.split(":", 1) + prefix += ":" + else: + prefix = "" + subdirectories = [name for name in os.listdir(folder_path) if os.path.isdir(os.path.join(folder_path, name))] + subdirectories_sorted = sorted( + subdirectories, key=lambda x: os.path.getctime(os.path.join(folder_path, x)), reverse=True + ) + if subdirectories_sorted: + return prefix + os.path.join(folder_path, subdirectories_sorted[0]) + else: + return None + + +def main(): + enable_pytorch_expandable_segments() + + generation_config = gpc.config["generation"] + + generation_config = type( + "", + (object,), + { + "output_folder": Path(generation_config["output_folder"]), + "ckpt_folder": generation_config["ckpt_folder"] + if "ckpt_folder" in generation_config + else get_latest_subdirectory(gpc.config.ckpt.save_ckpt_folder), + "data_folder": generation_config["data_folder"] if "data_folder" in generation_config else None, + "batch_size": generation_config.get("batch_size", None), + "eos_id": generation_config.get("eos_id", 2), + "bos_id": generation_config.get("bos_id", 1), + "pad_id": generation_config.get("bos_id", 1), + "additional_eos_token_list": generation_config.get("additional_eos_token_list", None), + "max_length": generation_config.get("max_length", 100), + "do_sample": generation_config.get("do_sample", True), + "temperature": generation_config.get("temperature", 1.0), + "num_beams": generation_config.get("num_beams", 1), + "top_k": generation_config.get("top_k", 50), + "top_p": generation_config.get("top_p", 1.0), + "repetition_penalty": generation_config.get("repetition_penalty", 1), + "length_penalty": generation_config.get("length_penalty", 1.0), + }, + ) + + if not os.path.exists(generation_config.output_folder.absolute()): + generation_config.output_folder.mkdir(exist_ok=True, parents=True) + + # get and broadcast current time + current_time = launch_time() + objs = [current_time] + torch.distributed.broadcast_object_list(objs, src=0) + current_time = objs[0].replace(":", ".") + global logger + logger = get_logger( + __file__, launch_time=current_time, job_name=gpc.config.JOB_NAME, file_name=get_parallel_log_file_name() + ) + + try: + init_storage_manager(False, None, None) + except AssertionError: + pass + except Exception as e: + raise e + + # initialize model + model = initialize_model() + _ = 
initialize_parallel_communicator(model) + model = model.model + + state_dict = merge_pp_within_tp(generation_config.ckpt_folder, del_model_prefix=True) + missing_k, unexpected_keys = model.load_state_dict(state_dict, strict=False) + if len(missing_k) != 0: + logger.warning(f"Warning: missing keys {missing_k}") + if len(unexpected_keys) != 0: + logger.warning(f"Warning: unexpected keys {unexpected_keys}") + + param_dtype = gpc.config.model.dtype + if isinstance(param_dtype, str): + try: + param_dtype = eval(param_dtype) # pylint: disable=W0123 + finally: + pass + if param_dtype == "torch.tf32": + param_dtype = torch.float32 + torch.backends.cudnn.allow_tf32 = True + torch.backends.cuda.matmul.allow_tf32 = True + + model.to(param_dtype) + model.eval() + torch.distributed.barrier() + + data_cfg = gpc.config.data + if generation_config.data_folder: + data_cfg.valid_folder = generation_config.data_folder + gene_dls = build_generation_loader_with_data_type(data_cfg, generation_config) + + sequenece_generator = SequenceGenerator( + decoder=model, + eos_token_id=generation_config.eos_id, + pad_token_id=generation_config.bos_id, + bos_token_id=generation_config.pad_id, + additional_eos_token_list=generation_config.additional_eos_token_list, + ) + + ds_count = 0 + gc.disable() + with torch.inference_mode(): + for ds_name, gene_dl in gene_dls.items(): + if len(gene_dl) == 0: + logger.info(f"Validation dataset: {ds_name} is empty") + continue + timer(f"dataset {ds_count}").start() + + # pylint: disable=forgotten-debug-statement + all_output_str = [] + # pylint: disable=unused-variable + for val_idx, (labels, input_ids) in tqdm( + enumerate(gene_dl), + desc="generate.", + total=len(gene_dl), + position=1, + leave=False, + ): + empty_cache_and_diag(val_idx, interval=gpc.config.data.empty_cache_and_diag_interval) + input_ids = torch.LongTensor(input_ids) + if input_ids.size(1) >= generation_config.max_length: + logger.warning( + f"Not generating for the {val_idx}'th batch, because the sequence " + f"length of the batch is {input_ids.size(1)} over the max generation" + f"length {generation_config.max_length}" + ) + output_ids = input_ids[:, : generation_config.max_length, ...] 
+ else: + input_ids = input_ids.clamp(min=0, max=gpc.config.model.vocab_size).to(get_model_device(model)) + output_ids = sequenece_generator.generate( + tokens=input_ids, + max_length=generation_config.max_length, + do_sample=generation_config.do_sample, + temperature=generation_config.temperature, + num_beams=generation_config.num_beams, + top_k=generation_config.top_k, + top_p=generation_config.top_p, + repetition_penalty=generation_config.repetition_penalty, + length_penalty=generation_config.length_penalty, + ) + for output in output_ids: + not_pad_indices = torch.nonzero(output != generation_config.pad_id) + if not_pad_indices.nelement() != 0: + sequence = output[not_pad_indices[0] :] + else: + sequence = output + sequence = sequence.tolist() + line = str.encode(json.dumps({"tokens": sequence})) + all_output_str.append( + ( + line, + len(line), + ) + ) + + bin_meta, last_position = [], 0 + with open(generation_config.output_folder.joinpath(f"{ds_name}.bin"), "wb") as file: + for line, token_num in all_output_str: + file.write(line) + bin_meta.append((last_position, token_num)) + last_position += len(line) + + with open(generation_config.output_folder.joinpath(f"{ds_name}.bin.meta"), "wb") as file: + np.save(file, bin_meta) + + timer(f"dataset {ds_count}").stop() + ds_count += 1 + + +if __name__ == "__main__": + args = parse_args() + hostname = socket.gethostname() + + # initialize distributed environment + initialize_distributed_env(config=args.config, launcher=args.launcher, master_port=args.port, seed=args.seed) + assert hasattr(gpc, "config") and gpc.config is not None + assert "generation" in gpc.config, f"Please set `generation` config in `{args.config}` file" + assert ( + "output_folder" in gpc.config["generation"] + ), "Must set `output_folder` for the save folder of generation data" + + # initialize monitor manager context + with initialize_monitor_manager( + job_name=gpc.config.JOB_NAME, alert_address=gpc.config.monitor.alert.feishu_alert_address + ): + try: + main() + except Exception: + logger.error( + f"Raise exception from {hostname} with rank id: {gpc.get_global_rank()}\n{traceback.format_exc()}", + ) + mm.monitor_exception( + alert_address=gpc.config.monitor.alert.feishu_alert_address, excp_info=traceback.format_exc() + ) + + # internlm_accelerator.memory._dump_snapshot(f"my_snapshot_{gpc.get_global_rank()}.pickle") + finally: + # local rank0 delete all files in shm_path, when use shm + devices_per_node = internlm_accelerator.device_count() + local_rank = gpc.get_global_rank() % devices_per_node + if gpc.config.data.use_shm and local_rank == 0: + if os.path.exists(gpc.config.data.shm_path): + shutil.rmtree(gpc.config.data.shm_path) diff --git a/internlm/data/__init__.py b/internlm/data/__init__.py index 08ad5d88..35f6ade4 100644 --- a/internlm/data/__init__.py +++ b/internlm/data/__init__.py @@ -1,4 +1,5 @@ from .build_dataloader import ( + build_generation_loader_with_data_type, build_train_loader_with_data_type, build_valid_loader_with_data_type, ) @@ -6,4 +7,5 @@ __all__ = [ "build_train_loader_with_data_type", "build_valid_loader_with_data_type", + "build_generation_loader_with_data_type", ] diff --git a/internlm/data/build_dataloader.py b/internlm/data/build_dataloader.py index 5af73b84..aa09a960 100644 --- a/internlm/data/build_dataloader.py +++ b/internlm/data/build_dataloader.py @@ -16,7 +16,11 @@ StaticBatchSampler, get_dpsampler_dataloader, ) -from internlm.data.tokenized.collaters import jsonl_ds_collate_fn, packed_collate_fn +from 
internlm.data.tokenized.collaters import ( + generation_collate_fn, + jsonl_ds_collate_fn, + packed_collate_fn, +) from internlm.data.tokenized.dataset import get_dataset_dict from internlm.data.tokenized.dummy_dataset import RandomDataset from internlm.data.tokenized.dummy_dataset_multimodal import RandomDatasetMultimodal @@ -213,3 +217,46 @@ def build_valid_loader_with_data_type(): ) return val_dls + + +def build_generation_loader_with_data_type(data_cfg, generation_cfg): + """Generate and return the validation data loader based on data type.""" + + if data_cfg.type == "tokenized": + gene_ds, _ = get_tokenized_valid_loader_items(data_cfg) + else: + raise ValueError(f"dataset type {data_cfg.type} is not supported") + + if gene_ds is None: + return None + + gene_dls = {} + for gene_name, ds in gene_ds.items(): + # making the batch_size of validate larger can speed up the evaluation, but it should not be too large, + # otherwise too much data may be dropped + batch_size = min( + data_cfg.valid_micro_num * data_cfg.micro_bsz, len(ds) // gpc.get_world_size(ParallelMode.DATA) + ) + batch_size = batch_size // data_cfg.micro_bsz * data_cfg.micro_bsz + if generation_cfg.batch_size: + batch_size = generation_cfg.batch_size + + if batch_size == 0 and gpc.is_rank_for_log(): + logger.info(f"skip validate {gene_name}.") + continue + + gene_dls[gene_name] = get_dpsampler_dataloader( + ds, + shuffle=False, + num_workers=data_cfg.get("num_worker", 0), + batch_size=batch_size, + collate_fn=partial(generation_collate_fn, pad_id=generation_cfg.pad_id), + ) + + if gpc.is_rank_for_log(): + logger.info( + f"load validation dataset {gene_name} with valid batch size {str(batch_size)} and " + f"samples {str(len(gene_dls[gene_name]))}." + ) + + return gene_dls diff --git a/internlm/data/tokenized/collaters.py b/internlm/data/tokenized/collaters.py index 785ecc60..fab7c5ac 100644 --- a/internlm/data/tokenized/collaters.py +++ b/internlm/data/tokenized/collaters.py @@ -100,3 +100,29 @@ def jsonl_ds_collate_fn(batch, max_length_per_sample): return {"input_ids": xs, "images": images}, ys else: return {"input_ids": xs}, ys + + +def generation_collate_fn(batch, pad_id=0): + """ + Collate function for generation dataset. + + Args: + batch (List[Dict]): List of dictionaries representing each sample in batch. + Each dictionary contains "tokens". + + Returns: + Tuple[Dict[str, torch.Tensor], torch.Tensor]: A tuple containing a dictionary of tensors with "input_ids", + and the tensor of padded "labels". 
+ + """ + xs, ys = [], [] + for x in batch: + tokens = [abs(w) for w in x["tokens"]] + labels = [w if w > 0 else -100 for w in x["tokens"]] + labels = labels[1:] + [-100] + xs.append(torch.as_tensor(tokens[::-1])) + ys.append(torch.as_tensor(labels[::-1])) # y has been shifted + xs = torch.nn.utils.rnn.pad_sequence(xs, batch_first=True, padding_value=pad_id).flip(dims=[1]) + ys = torch.nn.utils.rnn.pad_sequence(ys, batch_first=True, padding_value=-100).flip(dims=[1]) + + return {"input_ids": xs}, ys diff --git a/internlm/data/tokenized/dataset.py b/internlm/data/tokenized/dataset.py index e39a39b7..8991272b 100644 --- a/internlm/data/tokenized/dataset.py +++ b/internlm/data/tokenized/dataset.py @@ -51,6 +51,6 @@ def get_dataset_dict(folder, split="valid") -> Dict: datasets.append(ds) if datasets: ds = ConcatDataset(datasets=datasets) - data_dict[os.path.basename(root)] = ds + data_dict[os.path.basename(root.rstrip(os.path.sep))] = ds return data_dict diff --git a/internlm/data/tokenized/packed_dataset.py b/internlm/data/tokenized/packed_dataset.py index 1d525965..b2a8b109 100644 --- a/internlm/data/tokenized/packed_dataset.py +++ b/internlm/data/tokenized/packed_dataset.py @@ -599,6 +599,7 @@ class PackedDatasetWithPadForMultimodal(PackedDataset): Args: dataset: The original dataset to pack. max_length_per_sample: The maximum length of each original sample. Default is 2048. + padding_side: The padding side. Default is "right". packed_length: The length of each packed sample. Default is 4096. padding_idx: The token id of padding. Default is 0. """ @@ -609,13 +610,17 @@ def __init__( max_length_per_sample: int = 2048, packed_length: int = 4096, padding_idx: int = 0, + padding_side: str = "right", image_token_id: int = 200000, + has_image: bool = True, ): super().__init__(dataset, max_length_per_sample, packed_length) self.padding_idx = padding_idx + self.padding_side = padding_side self.sample_indices, self.belongs = self.accu_sample_len(self.seed) self.num_tokens = sum(self.lengths) self.image_token_id = image_token_id + self.has_image = has_image def get_dataset_name(self): return self.dataset.get_dataset_name() @@ -653,7 +658,10 @@ def __len__(self): def build_pack(self, index): - pack, cu_seqlens, indexes, labels, type_ids, images = [], [0], [], [], [], [] + pack, cu_seqlens, indexes, labels, type_ids = [], [0], [], [], [] + + if self.has_image: + images = [] start_pos = np.searchsorted(self.belongs, index, "left") end_pos = np.searchsorted(self.belongs, index, "right") @@ -665,8 +673,9 @@ def build_pack(self, index): for sample_idx in cur_samples: sample = self.dataset[sample_idx] length = min(len(sample["tokens"]), self.max_length_per_sample) - cur_images = sample["images"] - images.extend(cur_images) + if self.has_image: + cur_images = sample["images"] + images.extend(cur_images) chunk = sample["tokens"][:length] pack.extend(chunk) cu_seqlens.append(cu_seqlens[-1] + len(chunk)) @@ -680,10 +689,16 @@ def build_pack(self, index): indexes.extend(list(range(length))) if cu_seqlens[-1] != self.packed_length: - pack = pack + [self.padding_idx] * (self.packed_length - cu_seqlens[-1]) - labels = labels + [-100] * (self.packed_length - cu_seqlens[-1]) - type_ids = type_ids + [0] * (self.packed_length - cu_seqlens[-1]) - indexes.extend([0] * (self.packed_length - cu_seqlens[-1])) + if self.padding_side == "right": + pack = pack + [self.padding_idx] * (self.packed_length - cu_seqlens[-1]) + labels = labels + [-100] * (self.packed_length - cu_seqlens[-1]) + type_ids = type_ids + [0] * 
(self.packed_length - cu_seqlens[-1]) + indexes.extend([0] * (self.packed_length - cu_seqlens[-1])) + else: + pack = [self.padding_idx] * (self.packed_length - cu_seqlens[-1]) + pack + labels = [-100] * (self.packed_length - cu_seqlens[-1]) + labels + type_ids = [0] * (self.packed_length - cu_seqlens[-1]) + type_ids + indexes = [0] * (self.packed_length - cu_seqlens[-1]) + indexes cu_seqlens.append(self.packed_length) out = { diff --git a/tools/README.md b/tools/README.md index 2b47b1f4..a24040ca 100644 --- a/tools/README.md +++ b/tools/README.md @@ -5,7 +5,7 @@ ├── interface.py # 生成用的接口 ├── internlm_sft_on_moss.py # 在 moss 数据集上进行 SFT 训练的样例 ├── intern_moss_example.py # 在 moss 数据集上进行训练的样例 -├── load_internlm_model.py # 加载 InternLM 原生格式并进行推理的工具 +├── load_internlm2_model.py # 加载 InternLM 原生格式并进行推理的工具 ├── openai_api.py # 使用 OpenAI 接口实现的流式部署 ├── pal_inference.py # PAL 范式推理的工具 ├── README_EN.md @@ -141,3 +141,58 @@ if __name__ == "__main__": if hasattr(chunk.choices[0].delta, "content"): print(chunk.choices[0].delta.content, end="", flush=True) ``` + +# load_internlm2_model.py + +加载`InternEvo`框架训练的模型权重并进行推理 + +```bash +torchrun --master_port 12321 --nnodes=1 --node_rank=0 --nproc_per_node=1 --ckpt_dir=[where the internlm2 model weights are stored] --tokenizer_path=tools/tokenizer_internlm2.model tools/load_internlm2_model.py +``` + +LLaMA 7B推理的例子: + +```python + model = initialize_internlm_model( + model_type="LLAMA2", + ckpt_dir=args.ckpt_dir, + model_config=dict( + num_chunks=1, + checkpoint=0.2, + dtype="torch.bfloat16", + embed_split_hidden=True, + num_layers=32, + hidden_size=4096, + vocab_size=32000, + embed_grad_scale=1, + parallel_output=True, + num_attention_heads=32, + num_kv_attention_heads=32, + mlp_ratio=2.675, + use_flash_attn=True, + norm_type="rmsnorm", + apply_post_layer_norm=False, + no_bias=True, + layer_norm_epsilon=1e-5, + ), + del_model_prefix=True, + ) + + from sentencepiece import SentencePieceProcessor + + prompt = """<|User|>:{query}\n<|Bot|>:""" + prompt = prompt.replace("{query}", "hello") + # LLaMA tokenizer转换成SentencePieceProcessor 或 此处加载Huggingface Tokenizer,则需额外将generate中调用的decode等方法修改成HF风格 + tokenizer = SentencePieceProcessor(args.tokenizer_path) + generation_config = GenerationConfig() + output_generator = internlm_interactive_generation( + model=model, + tokenizer=tokenizer, + prompt=prompt, + generation_config=generation_config, + additional_eos_token_list=[tokenizer.eos_id()], + ) + + for text in output_generator: + print(text) +``` diff --git a/tools/README_EN.md b/tools/README_EN.md index 63aba410..fe93560d 100644 --- a/tools/README_EN.md +++ b/tools/README_EN.md @@ -6,7 +6,7 @@ This directory provide some tools for model training with the following file str ├── interface.py # interface for generation ├── internlm_sft_on_moss.py # example for SFT training on moss dataset ├── intern_moss_example.py # example for training on moss dataset -├── load_internlm_model.py # tools for loading InternLM checkpoints and generating +├── load_internlm2_model.py # tools for loading InternLM checkpoints and generating ├── openai_api.py # stream deployment with OpenAI APIs ├── pal_inference.py # tools for PAL reasoning ├── README_EN.md diff --git a/tools/load_internlm_model.py b/tools/load_internlm2_model.py similarity index 85% rename from tools/load_internlm_model.py rename to tools/load_internlm2_model.py index 3de52c22..6f1561b0 100644 --- a/tools/load_internlm_model.py +++ b/tools/load_internlm2_model.py @@ -1,3 +1,4 @@ +import argparse import inspect import logging 
import os @@ -9,9 +10,8 @@ from internlm.apis.inference import SequenceGenerator from internlm.core.context import ParallelMode from internlm.core.context import global_context as gpc -from internlm.initialize.launch import launch_from_torch -from internlm.model.registry import model_initializer -from internlm.train import initialize_model +from internlm.initialize.launch import initialize_distributed_env +from internlm.train import initialize_model, initialize_parallel_communicator from internlm.utils.storage_manager import get_fns, init_storage_manager, llm_load from tools.interface import GenerationConfig @@ -102,6 +102,10 @@ def match_fn_signature(func: Callable, args_dict: Dict) -> None: logger.warning(f"These args:{args_set} are popped for func:{func.__name__}.") +def use_torchrun_starter(): + return os.getenv("RANK") is not None + + def get_tp_rank() -> int: """Get the tensor parallel rank. This script uses torchrun to initialize the environment, so RANK in the environment variable is the tensor @@ -119,7 +123,7 @@ def get_tp_world_size() -> int: Returns: int: The tensor parallel world size to which the current process belongs. """ - return int(os.environ.get("WORLD_SIZE", 0)) + return int(os.environ.get("WORLD_SIZE", 1)) def initialize_internlm_model( @@ -173,27 +177,32 @@ def initialize_internlm_model( model_config["dtype"] = param_dtype model_config["parallel_output"] = False # FIXME: fix it. - match_fn_signature(model_initializer.get_module(model_type), model_config) if gpc.is_rank_for_log(): logger.info(f"model_config: {model_config}.") - launch_from_torch( + + initialize_distributed_env( config=dict( model_type=model_type, model=model_config, parallel=dict( zero1=dict(size=1, fsdp=False), pipeline=dict(size=1, interleaved_overlap=True), - tensor=get_tp_world_size(), + tensor=dict(size=get_tp_world_size(), mode="mtp"), sequence_parallel=0, ), ), + launcher="torch" if use_torchrun_starter() else "slurm", seed=seed, + master_port=23574, + args_check=False, ) - model = initialize_model() # Directly get the origin model without NativeAMP wrapper. + model = initialize_model() + _ = initialize_parallel_communicator(model) model = model.model state_dict = merge_pp_within_tp(ckpt_dir, del_model_prefix=del_model_prefix) + load_info = model.load_state_dict(state_dict, strict=False) logger.info(f"Rank:{gpc.get_local_rank(ParallelMode.TENSOR)}. 
Load info: {load_info}.") @@ -224,11 +233,11 @@ def internlm_interactive_generation( sequenece_generator = SequenceGenerator( decoder=model, eos_token_id=tokenizer.eos_id(), - pad_token_id=tokenizer.eos_id(), + pad_token_id=tokenizer.bos_id(), bos_token_id=tokenizer.bos_id(), additional_eos_token_list=additional_eos_token_list, ) - additional_eos_token_list = torch.LongTensor(additional_eos_token_list) + additional_eos_token_list = torch.LongTensor(additional_eos_token_list) if additional_eos_token_list else None input_ids = [tokenizer.bos_id()] + tokenizer.encode(prompt) input_ids = torch.LongTensor([input_ids]).to(get_model_device(model)) output_generator = sequenece_generator.streaming_generate( @@ -250,32 +259,48 @@ def internlm_interactive_generation( yield cur_output +def get_default_parser(): + parser = argparse.ArgumentParser() + parser.add_argument("--ckpt_dir", type=str, help="path to the ckpt file", required=True) + parser.add_argument( + "--tokenizer_path", type=str, default="tools/tokenizer_internlm2.model", help="path to the tokenizer file" + ) + + return parser + + if __name__ == "__main__": + parser = get_default_parser() + args = parser.parse_args() + """ Here is a simple example to generate with origin internlm model architecture. Use the following command to run: - >>> torchrun --master_port 12331 --nnodes=1 --node_rank=0 --nproc_per_node=1 tools/load_internlm_model.py + >>> torchrun --master_port 12321 --nnodes=1 --node_rank=0 --nproc_per_node=1 tools/load_internlm2_model.py """ model = initialize_internlm_model( - model_type="INTERNLM", - ckpt_dir="[Please replace this with the directory where the internlm model weights are stored]", + model_type="INTERNLM2_PUBLIC", + ckpt_dir=args.ckpt_dir, model_config=dict( - checkpoint=False, - num_attention_heads=32, + num_chunks=1, + checkpoint=0.2, + dtype="torch.bfloat16", embed_split_hidden=True, - vocab_size=103168, - embed_grad_scale=1, - parallel_output=False, - hidden_size=4096, num_layers=32, - mlp_ratio=8 / 3, - apply_post_layer_norm=False, - dtype="torch.bfloat16", + hidden_size=4096, + vocab_size=92544, + embed_grad_scale=1, + parallel_output=True, + num_attention_heads=32, + num_kv_attention_heads=8, + mlp_ratio=3.5, + use_flash_attn=True, norm_type="rmsnorm", + qk_interleaved=True, + apply_post_layer_norm=False, + no_bias=True, layer_norm_epsilon=1e-5, - use_flash_attn=True, - num_chunks=1, - use_dynamic_ntk_rope=True, + rope_base=1000000, ), del_model_prefix=True, ) @@ -284,15 +309,14 @@ def internlm_interactive_generation( prompt = """<|User|>:{query}\n<|Bot|>:""" prompt = prompt.replace("{query}", "hello") - tokenizer = SentencePieceProcessor("tools/tokenizer_internlm.model") # pylint: disable=E1121 - + tokenizer = SentencePieceProcessor(args.tokenizer_path) # pylint: disable=E1121 generation_config = GenerationConfig() output_generator = internlm_interactive_generation( model=model, tokenizer=tokenizer, prompt=prompt, generation_config=generation_config, - additional_eos_token_list=[103028], + additional_eos_token_list=[tokenizer.eos_id()], ) for text in output_generator: diff --git a/web_demo_internlm.py b/web_demo_internlm.py index 8730c0c2..abe0568e 100644 --- a/web_demo_internlm.py +++ b/web_demo_internlm.py @@ -8,7 +8,7 @@ from internlm.accelerator import get_accelerator from tools.interface import GenerationConfig -from tools.load_internlm_model import ( +from tools.load_internlm2_model import ( initialize_internlm_model, internlm_interactive_generation, )