Commit 1c81359
initial refactor
zigzagcai committed Aug 21, 2024
1 parent 81fb735 commit 1c81359
Showing 11 changed files with 271 additions and 283 deletions.
6 changes: 0 additions & 6 deletions internlm/core/parallel/shard.py
@@ -31,12 +31,6 @@ def split_data_sequence_parallel(data, label):
and gpc.config.parallel["tensor"].get("mode", TensorParallelMode.mtp.name) == TensorParallelMode.isp.name
):
data["indexes"] = _split(data["indexes"], ParallelMode.TENSOR, dim=_indexes_seq_dim)
if (
gpc.config.model_type == ModelType.HF.name
and "position_ids" in data
and gpc.config.parallel["tensor"].get("mode", TensorParallelMode.mtp.name) == TensorParallelMode.isp.name
):
data["position_ids"] = _split(data["position_ids"], ParallelMode.TENSOR, dim=_indexes_seq_dim)

data["input_ids"] = _split(data["input_ids"], ParallelMode.TENSOR, dim=_seq_dim)
label = _split(label, ParallelMode.TENSOR, dim=_seq_dim)
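For readers unfamiliar with sequence-parallel sharding: the _split calls above cut packed tensors along the sequence dimension so each tensor-parallel rank keeps only its slice. Below is a standalone illustration of the idea using plain torch.chunk, not the repository's _split helper; the rank and group size are invented for the example.

import torch

# Hypothetical illustration only -- not the repository's _split implementation.
world_size, rank = 4, 1                      # assumed tensor-parallel group of 4, this process is rank 1
input_ids = torch.arange(32).unsqueeze(0)    # packed batch of shape (1, seq_len=32)
local_input_ids = torch.chunk(input_ids, world_size, dim=1)[rank]
print(local_input_ids.shape)                 # torch.Size([1, 8]) -- each rank keeps seq_len // world_size tokens
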
428 changes: 233 additions & 195 deletions internlm/core/trainer_builder.py

Large diffs are not rendered by default.

16 changes: 6 additions & 10 deletions internlm/data/build_dataloader.py
@@ -146,21 +146,17 @@ def build_train_loader_with_data_type():
"""
Build and return the training data loader based on data type.
Returns: A tuple of (train_dl, dataset_types).
Returns: train_dl
"""
data_cfg = gpc.config.data

train_folder = data_cfg.get("train_folder", None)


if data_cfg.type == DataType.tokenized.name:
train_ds, train_sampler, train_collate_fn = get_tokenized_train_loader_items(data_cfg)
dataset_types = list(get_dataset_type_ids_map(train_folder).keys()) if train_folder else ["en", "cn", "code"]
elif data_cfg.type == DataType.hf.name:
elif data_cfg.type == DataType.streaming.name:
train_ds, train_sampler, train_collate_fn = get_hf_train_loader_items(data_cfg)
dataset_types = ["en"]
else:
raise ValueError(f"dataset type {data_cfg.type} is not supported")

# Create the training data loader
train_dl = DataLoader(
dataset=train_ds,
@@ -171,15 +167,15 @@ def build_train_loader_with_data_type():
persistent_workers=data_cfg.get("num_worker", 4) > 0,
)

return train_dl, dataset_types
return train_dl


def build_valid_loader_with_data_type():
"""Generate and return the validation data loader based on data type."""

data_cfg = gpc.config.data

if data_cfg.type in [DataType.tokenized.name, DataType.hf.name]:
if data_cfg.type in [DataType.tokenized.name, DataType.streaming.name]:
valid_ds, valid_collate_fn = get_tokenized_valid_loader_items(data_cfg)
else:
raise ValueError(f"dataset type {data_cfg.type} is not supported")
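With this change, build_train_loader_with_data_type returns only the DataLoader; dataset_types is gone from the public contract. A minimal usage sketch, assuming the distributed environment has already been initialized so that gpc.config is populated, and assuming the functions are importable from internlm.data.build_dataloader:

from internlm.core.context import global_context as gpc
from internlm.data.build_dataloader import (
    build_train_loader_with_data_type,
    build_valid_loader_with_data_type,
)
from internlm.utils.utils import DataType

# data.type must now be "tokenized" or "streaming"; any other value raises ValueError.
assert gpc.config.data.type in (DataType.tokenized.name, DataType.streaming.name)

train_dl = build_train_loader_with_data_type()   # single return value after the refactor
val_dls = build_valid_loader_with_data_type()
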
2 changes: 1 addition & 1 deletion internlm/data/train_state.py
@@ -6,7 +6,7 @@

def get_train_state(dataloader):
# initialize and resume train state
if gpc.config.data.type in [DataType.tokenized.name, DataType.hf.name]:
if gpc.config.data.type in [DataType.tokenized.name, DataType.streaming.name]:
train_state = TrainState(gpc.config, dataloader.batch_sampler)
else:
raise ValueError(f"dataset type {gpc.config.data.type} is not supported")
7 changes: 1 addition & 6 deletions internlm/data/utils.py
@@ -69,10 +69,5 @@ def packed_data_normalizer(data, label):
data["indexes"] = data["indexes"][0]
data["cu_seqlens"] = data["cu_seqlens"][0].squeeze(0)
data["max_seqlen"] = (data["cu_seqlens"][1:] - data["cu_seqlens"][:-1]).max().item()

if gpc.config.model_type == ModelType.HF.name:
gpc.config.data[f"cu_seqlens_data_rank{gpc.get_local_rank(ParallelMode.DATA)}"] = data.pop("cu_seqlens")
gpc.config.data[f"max_seqlen_data_rank{gpc.get_local_rank(ParallelMode.DATA)}"] = data.pop("max_seqlen")
data["position_ids"] = data.pop("indexes")


return data, label
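The max_seqlen line kept above derives the longest sub-sequence in a packed batch from its cumulative-length vector. A self-contained numerical example of that computation (values invented for illustration):

import torch

# cu_seqlens marks the boundaries of three packed sequences of lengths 5, 7 and 8.
cu_seqlens = torch.tensor([0, 5, 12, 20])
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
print(max_seqlen)  # 8
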
20 changes: 5 additions & 15 deletions internlm/model/builder.py
@@ -5,19 +5,12 @@
from internlm.core.context import ParallelMode
from internlm.core.context import global_context as gpc
from internlm.core.parallel.shard import pipeline_parallel_sharding_wrapper
from internlm.model.registry import hf_config_initializer, model_initializer
from internlm.model.utils import convert_hf_config
from internlm.model.registry import model_initializer
from internlm.utils.common import get_current_device
from internlm.utils.utils import ModelType


def create_model(model_type) -> Union[nn.Module, List[nn.Module]]:

if model_type == ModelType.HF.name:
extra_kwargs = {"return_dict": False, "attn_implementation": "flash_attention_2"}
config = hf_config_initializer.get_module(module_name=model_type)(**extra_kwargs)
convert_hf_config(config)

kwargs = dict(gpc.config.model)

num_layers = kwargs.pop("num_layers")
@@ -34,13 +27,10 @@ def create_model(model_type) -> Union[nn.Module, List[nn.Module]]:
model_buidler = model_initializer.get_module(module_name=model_type)

if not gpc.is_using_parallel_mode(ParallelMode.PIPELINE):
if model_type == ModelType.HF.name:
model = model_buidler(config).to(kwargs["device"])
else:
kwargs["first"] = kwargs["last"] = True
kwargs["start_layer_idx"] = 0
kwargs["num_layers"] = num_layers
model = model_buidler(**kwargs).to(kwargs["device"])
kwargs["first"] = kwargs["last"] = True
kwargs["start_layer_idx"] = 0
kwargs["num_layers"] = num_layers
model = model_buidler(**kwargs).to(kwargs["device"])
setattr(model, "first_layer", 0)
setattr(model, "last_layer", num_layers)
else:
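After the HF branch is dropped, create_model always goes through the model_initializer registry and builds the module from keyword arguments taken from gpc.config.model. A hedged sketch of the call site, assuming the distributed environment and config are already initialized (the builder itself moves the module to the current device):

from internlm.core.context import global_context as gpc
from internlm.model.builder import create_model
from internlm.model.registry import register_model_initializer

register_model_initializer()                              # populate the model_initializer registry
model = create_model(model_type=gpc.config.model_type)    # e.g. "INTERNLM2_PUBLIC"
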
2 changes: 0 additions & 2 deletions internlm/model/registry.py
@@ -74,8 +74,6 @@ def has(self, module_name: str):


model_initializer = Registry("model_initializer")
hf_config_initializer = Registry("hf_config_initializer")


def register_model_initializer() -> None:
model_initializer.register_module(ModelType.INTERNLM.name, InternLM1)
12 changes: 0 additions & 12 deletions internlm/model/utils.py
@@ -52,15 +52,3 @@ def convert_attn_args_to_kwargs(args, kwargs) -> Dict[str, Any]:
kwargs["max_seqlen"] = args[3]

return kwargs


def convert_hf_config(config):
gpc.config.model.vocab_size = gpc.config.VOCAB_SIZE = config.vocab_size
gpc.config.model.hidden_size = gpc.config.HIDDEN_SIZE = config.hidden_size
gpc.config.model.num_layers = gpc.config.NUM_LAYER = config.num_hidden_layers
gpc.config.model.num_attention_heads = gpc.config.NUM_ATTENTION_HEAD = config.num_attention_heads
gpc.config.model.mlp_ratio = gpc.config.MLP_RATIO = config.intermediate_size / config.hidden_size

# For models that use GQA
if hasattr(config, "num_key_value_heads"):
gpc.config.model.num_kv_attention_heads = gpc.config.NUM_KV_ATTENTION_HEAD = config.num_key_value_heads
21 changes: 6 additions & 15 deletions internlm/train/pipeline.py
@@ -160,23 +160,14 @@ def _check_module(name, module):


@llm_timeout(func_name="initialize_model")
def initialize_model(pre_process_func: Optional[Callable] = None, post_process_func: Optional[Callable] = None):
def initialize_model(model):
"""
Initialize model with Automatic Mixed Precision.
Returns:
torch.nn.Module:
The neural network model to be trained or evaluated.
"""
if pre_process_func:
pre_process_output = pre_process_func()

register_model_initializer()

model = create_model(model_type=gpc.config.model_type)

if post_process_func:
post_process_func(pre_process_output)

# should be set before NaiveAMPModel
set_fp32_attr_for_model(model)
@@ -525,7 +516,7 @@ def record_current_batch_training_metrics(
train_state,
optimizer,
beta2_scheduler,
trainer,
engine,
start_time,
very_begining_time,
loss,
@@ -547,10 +538,10 @@

if success_update and gpc.is_rank_for_log():
lr = optimizer.param_groups[0]["lr"]
if hasattr(trainer.engine.optimizer, "grad_scaler"):
scaler = trainer.engine.optimizer.grad_scaler._scale.item()
elif hasattr(trainer.engine.optimizer.optim, "grad_scaler"):
scaler = trainer.engine.optimizer.optim.grad_scaler._scale.item()
if hasattr(engine.optimizer, "grad_scaler"):
scaler = engine.optimizer.grad_scaler._scale.item()
elif hasattr(engine.optimizer.optim, "grad_scaler"):
scaler = engine.optimizer.optim.grad_scaler._scale.item()

num_tokens_in_batch = batch[1].nelement()
real_num_tokens = math.ceil(acc_perplex.pop("real_token_num") / gpc.get_world_size(ParallelMode.GLOBAL))
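Two call contracts change in this file: initialize_model no longer constructs the model itself (the caller passes in the module built by create_model), and record_current_batch_training_metrics receives the engine directly instead of reaching through trainer.engine. A sketch of the new initialize_model usage, assuming it still returns the AMP-wrapped module as its docstring states:

from internlm.core.context import global_context as gpc
from internlm.model.builder import create_model
from internlm.model.registry import register_model_initializer
from internlm.train import initialize_model

register_model_initializer()
raw_model = create_model(model_type=gpc.config.model_type)
model = initialize_model(raw_model)   # assumed to return the module wrapped for mixed-precision training
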
19 changes: 9 additions & 10 deletions internlm/utils/utils.py
@@ -25,9 +25,9 @@ def read_base():


class QKVPackType(IntEnum):
QKVPACKED = 2
KVPACKED = 3
QKVSPLITED = 4
QKVPACKED = 1
KVPACKED = 2
QKVSPLITED = 3

def __str__(self) -> str:
return str(self.value)
@@ -42,16 +42,15 @@ def __str__(self) -> str:


class ModelType(Enum):
HF = 1
INTERNLM = 2
INTERNLM2_PUBLIC = 3
LLAMA2 = 4
INTERNLM_MoE = 5
LLAVA = 6
INTERNLM = 1
INTERNLM2_PUBLIC = 2
LLAMA2 = 3
INTERNLM_MoE = 4
LLAVA = 5


class DataType(Enum):
hf = 1
streaming = 1
tokenized = 2


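The renamed enums propagate into user configs: a data section that previously set type = "hf" should now use "streaming", and ModelType.HF is removed entirely. A hypothetical config fragment reflecting the rename (field names follow the lookups in build_dataloader.py; values are illustrative, not repository defaults):

data = dict(
    type="streaming",      # was "hf" before this commit
    train_folder=None,     # streaming datasets need no tokenized train_folder
    num_worker=4,
)
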
21 changes: 10 additions & 11 deletions train.py
@@ -8,27 +8,26 @@
build_valid_loader_with_data_type,
)
from internlm.initialize import initialize_distributed_env
from internlm.model.builder import create_model
from internlm.model.registry import register_model_initializer
from internlm.monitor import internevo_monitor
from internlm.train import initialize_model
from internlm.utils.common import parse_args


@internevo_monitor(feishu_alert=True, clean_run=True)
def main(args):
# initialize model
model = initialize_model()

# initialize built-in model
register_model_initializer()
model = create_model(model_type=gpc.config.model_type)

# initialize train dataloader
train_dl, dataset_types = build_train_loader_with_data_type()

train_dl = build_train_loader_with_data_type()
# initialize validation dataloader
val_dls = build_valid_loader_with_data_type()

# initialize kwargs
kwargs = vars(args) | {"dataset_types": dataset_types}


# build trainer
trainer = TrainerBuilder(model, train_dl, val_dls, **kwargs)
trainer = TrainerBuilder(model, train_dl, val_dls, **vars(args))

# training
trainer.fit()
