feat(usability): Refine model inject helper to support huggingface models (#331)

Co-authored-by: sallyjunjun <jun_sally@126.com>
season0528 and sallyjunjun authored Sep 20, 2024
1 parent 6e591ac commit 569eb25
Showing 18 changed files with 332 additions and 136 deletions.
14 changes: 9 additions & 5 deletions internlm/core/parallel/shard.py
@@ -31,11 +31,11 @@ def _split_data_for_sequence_parallel(data, label):
        and gpc.config.parallel["tensor"].get("mode", TensorParallelMode.mtp.name) == TensorParallelMode.isp.name
    ):
        data["indexes"] = _split(data["indexes"], ParallelMode.TENSOR, dim=_indexes_seq_dim)
    if (
        "position_ids" in data
        and gpc.config.parallel["tensor"].get("mode", TensorParallelMode.mtp.name) == TensorParallelMode.isp.name
    ):
        data["position_ids"] = _split(data["position_ids"], ParallelMode.TENSOR, dim=_indexes_seq_dim)

    # NOTICE: For compatibility where the shape of position_ids is [batch, seqlen, ...]
    if "inject_info" in gpc.config.model and gpc.config.model.inject_info.get("data_helper", False):
        _position_ids_seq_dim = 1
        data["position_ids"] = _split(data["position_ids"], ParallelMode.TENSOR, dim=_position_ids_seq_dim)

    data["input_ids"] = _split(data["input_ids"], ParallelMode.TENSOR, dim=_seq_dim)

@@ -158,8 +158,12 @@ def get_parallel_strategies_split_mode(linear_name: str) -> str:

if linear_name in ("head", "output"):
return "head"
if linear_name in ("gate"):
return "head" # for MoE model
elif linear_name in ("wqkv", "wq", "wk", "wv", "wkv", "w1", "w3", "w13"):
return "column"
elif linear_name in ("fc1", "fc2", "linear_1", "linear_2"): # for vit model
return "column"
elif linear_name in ("wo", "out_proj", "w2") and tp_mode == TensorParallelMode.isp.name:
return "column"
elif linear_name in ("wo", "out_proj", "w2"):
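Note on the hunk above: the helper maps a linear layer's name to its tensor-parallel split strategy, and this commit adds the MoE gate, ViT MLP, and ISP output-projection cases. A minimal standalone sketch of the mapping as it reads after this change; the final row-parallel fall-through and the default return are assumptions, since the hunk is truncated at the last elif.

# Hypothetical condensed sketch; mirrors the diff above, truncated tail assumed.
def get_split_mode_sketch(linear_name: str, tp_mode: str) -> str:
    if linear_name in ("head", "output", "gate"):  # "gate" is the MoE router
        return "head"
    if linear_name in ("wqkv", "wq", "wk", "wv", "wkv", "w1", "w3", "w13"):
        return "column"
    if linear_name in ("fc1", "fc2", "linear_1", "linear_2"):  # ViT MLP layers
        return "column"
    if linear_name in ("wo", "out_proj", "w2"):
        # ISP splits output projections by column; otherwise row (assumption)
        return "column" if tp_mode == "isp" else "row"
    return "unknown"  # assumed default for unmatched names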
10 changes: 5 additions & 5 deletions internlm/core/trainer_builder.py
@@ -2,7 +2,7 @@
import logging
import time
from functools import partial
from typing import Dict, Optional
from typing import Dict, List, Optional, Union

import torch
import torch.distributed as dist
@@ -59,7 +59,7 @@ class TrainerBuilder(Trainer):
    for seamless management of training, evaluation, and checkpointing.
    Args:
        model (torch.nn.Module): The model to be trained.
        model (Union[torch.nn.Module, List[torch.nn.Module]]): The model to be trained.
        train_dl (DataLoader): DataLoader for training data.
        val_dls (Optional[Dict[str, DataLoader]], optional): DataLoaders for validation data.
        **kwargs: Additional keyword arguments including:
@@ -74,7 +74,7 @@

    def __init__(
        self,
        model: torch.nn.Module,
        model: Union[torch.nn.Module, List[torch.nn.Module]],
        train_dl: DataLoader,
        val_dls: Optional[Dict[str, DataLoader]] = None,
        **kwargs,
@@ -83,7 +83,7 @@ def __init__(
        Initialize TrainerBuilder with necessary components for training.
        Args:
            model (torch.nn.Module): The model to be trained.
            model (Union[torch.nn.Module, List[torch.nn.Module]]): The model to be trained.
            train_dl (DataLoader): DataLoader for training data.
            val_dls (Optional[Dict[str, DataLoader]], optional): DataLoaders for validation data.
            **kwargs: Additional keyword arguments including:
@@ -235,7 +235,7 @@ def _initialize_memory_profiler(self, model, optimizer, profiling) -> Optional[S

    def _initialize_batch_skipper(self, train_state) -> BatchSkipper:
        skip_batches = gpc.config.data.skip_batches
        if gpc.config.data.type == DataType.tokenized.name and gpc.config.ckpt.auto_resume:
        if gpc.config.data.type == DataType.streaming.name and gpc.config.ckpt.auto_resume:
            skip_batches = streaming_simple_resume(train_state)
        return BatchSkipper(skip_batches)

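The widened model annotation above means TrainerBuilder now accepts either a single module or a list of modules. A minimal illustration of the two call shapes; the placeholder modules and the elided dataloader arguments are assumptions, not InternEvo APIs:

import torch.nn as nn

# Placeholder modules; a list is presumably used for multi-chunk setups.
single_model = nn.Linear(16, 16)
model_chunks = [nn.Linear(16, 16), nn.Linear(16, 16)]

# Both forms now match the new signature (dataloaders elided):
# trainer = TrainerBuilder(single_model, train_dl, val_dls)
# trainer = TrainerBuilder(model_chunks, train_dl, val_dls)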
30 changes: 20 additions & 10 deletions internlm/data/build_dataloader.py
@@ -16,7 +16,8 @@
from internlm.data.streaming.collaters import streaming_packed_collate_fn
from internlm.data.streaming.dataset import (
    StreamingDataset,
    StreamingPackedDatasetWithCut,
    StreamingDatasetPackSampleIntoOneWithCut,
    StreamingDatasetPackSampleWithPad,
)
from internlm.data.tokenized.batch_sampler import (
    StaticBatchSampler,
@@ -128,19 +129,25 @@ def get_tokenized_valid_loader_items(data_cfg):


def get_streaming_train_loader_items(data_cfg):
    assert not data_cfg.pack_sample_into_one, "streaming dataloader currently only supports pack_sample_into_one=False"
    train_ds = StreamingDataset(
        dataset_path=data_cfg.train_folder,
        train_folder=data_cfg.train_folder,
        tokenizer_path=data_cfg.tokenizer_path,
        model_max_length=data_cfg.seq_len,
        content_name=data_cfg.get("content_name", "text"),
        subset_name=data_cfg.get("subset_name", None),
    )
    train_ds = StreamingPackedDatasetWithCut(
        dataset=train_ds,
        seq_len=data_cfg.seq_len,
        micro_bsz=data_cfg.micro_bsz,
    )
    if data_cfg.get("pack_sample_into_one", False):
        train_ds = StreamingDatasetPackSampleIntoOneWithCut(
            dataset=train_ds,
            seq_len=data_cfg.seq_len,
            micro_bsz=data_cfg.micro_bsz,
        )
    else:
        train_ds = StreamingDatasetPackSampleWithPad(
            dataset=train_ds,
            seq_len=data_cfg.seq_len,
            micro_bsz=data_cfg.micro_bsz,
        )
    train_sampler = StreamingStaticBatchSampler(
        batch_size=data_cfg.micro_num, rampup_batch_size=data_cfg.rampup_batch_size
    )
@@ -192,11 +199,14 @@ def get_megatron_train_loader_items(data_cfg):


def get_mock_train_loader_items(data_cfg):
    assert data_cfg.get(
        "pack_sample_into_one", False
    ), "mocked dataloader currently only supports pack_sample_into_one=True"
    train_ds = MockedDataset(
        data_dir=data_cfg.train_folder,  # defines the path of the mocked data
        train_folder=data_cfg.train_folder,
        micro_bsz=data_cfg.micro_bsz,
        micro_num=data_cfg.micro_num,
        seq_len=data_cfg.seq_len,
        mocked_steps=data_cfg.mocked_steps,  # defines the number of mocked steps
    )
    train_sampler = MockedSequentialBatchSampler(train_ds, data_cfg.micro_num)
    train_collate_fn = partial(packed_collate_fn, packed_length=data_cfg.seq_len * data_cfg.micro_bsz)
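The new branch above picks the streaming wrapper from pack_sample_into_one: padding each sample versus packing samples into one sequence with a cut. A hedged sketch of the two config shapes that exercise each branch; the field names follow the diff, while the paths and sizes are illustrative only:

# Illustrative data_cfg fragments; surrounding config keys are assumed.
streaming_pad_cfg = dict(
    type="streaming",
    train_folder="/path/to/hf/dataset",   # hypothetical path
    tokenizer_path="/path/to/tokenizer",  # hypothetical path
    seq_len=4096,
    micro_bsz=2,
    micro_num=4,
    rampup_batch_size=None,
    pack_sample_into_one=False,  # -> StreamingDatasetPackSampleWithPad
)
streaming_cut_cfg = dict(streaming_pad_cfg, pack_sample_into_one=True)
# -> StreamingDatasetPackSampleIntoOneWithCut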
2 changes: 2 additions & 0 deletions internlm/data/megatron/dataset.py
@@ -1,3 +1,5 @@
# adapted from https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/datasets/gpt_dataset.py
# adapted from https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/datasets/indexed_dataset.py
import hashlib
import os
import struct
8 changes: 4 additions & 4 deletions internlm/data/mocked/batch_sampler.py
@@ -6,18 +6,18 @@ class MockedSequentialBatchSampler:
    MockedSequentialBatchSampler
    """

    def __init__(self, data_source, micro_num):
        self.data_source = data_source
    def __init__(self, train_ds, micro_num):
        self.train_ds = train_ds
        self.micro_num = micro_num

    def __iter__(self):
        num_samples = len(self.data_source)
        num_samples = len(self.train_ds)
        for start in range(0, num_samples, self.micro_num):
            end = min(start + self.micro_num, num_samples)
            yield list(range(start, end))

    def __len__(self):
        return (len(self.data_source) + self.micro_num - 1) // self.micro_num
        return (len(self.train_ds) + self.micro_num - 1) // self.micro_num

    # TODO: implement copy method that is compatible with InternEvo trainstate
    def copy(self):
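The renamed sampler walks the dataset sequentially, micro_num indices at a time. A self-contained toy run showing the batch indices it yields; the class body is copied from the hunk above, and the five-element dataset is illustrative:

class MockedSequentialBatchSampler:
    def __init__(self, train_ds, micro_num):
        self.train_ds = train_ds
        self.micro_num = micro_num

    def __iter__(self):
        num_samples = len(self.train_ds)
        for start in range(0, num_samples, self.micro_num):
            yield list(range(start, min(start + self.micro_num, num_samples)))

sampler = MockedSequentialBatchSampler(train_ds=[None] * 5, micro_num=2)
print(list(sampler))  # [[0, 1], [2, 3], [4]]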
183 changes: 116 additions & 67 deletions internlm/data/mocked/dataset.py
@@ -1,4 +1,7 @@
import glob
import os
import re
from typing import Dict, List

import torch
from torch.utils.data import Dataset
@@ -7,85 +10,131 @@
from internlm.core.context import global_context as gpc


def merge_tensors(file_pattern):
    files = sorted(glob.glob(file_pattern))
    tensors = []
    for file in files:
        tensor = torch.load(file)
        tensors.append(tensor)
    merged_tensor = torch.cat(tensors, dim=0)
    return merged_tensor
def merge_tensors(fn_pattern: str) -> torch.Tensor:
    """
    Merge per-step saved tensors into one, across all dp ranks.
    Args:
        fn_pattern: glob pattern for saved tensors, such as tokens_step{}_dp* or labels_step{}_dp*
    Returns:
        merged tensor
    """
    return torch.cat([torch.load(fn) for fn in sorted(glob.glob(fn_pattern))], dim=0)


def split_tensors(raw_data: List[torch.Tensor], micro_bsz: int) -> List[torch.Tensor]:
    """
    Split saved tensors into a list of packed micro batches, each with shape (micro_bsz * seq_len,).
    Args:
        raw_data: list of tensors, whose length is mocked_steps * micro_num * micro_bsz,
            where each element is a tensor with shape (seq_len,)
        micro_bsz: micro batch size
    Returns:
        list of tensors, whose length is mocked_steps * micro_num,
        where each element is a packed micro batch with shape (micro_bsz * seq_len,)
    """
    return [torch.cat(raw_data[i : i + micro_bsz], dim=0) for i in range(0, len(raw_data), micro_bsz)]

def process_raw_data(raw_data, micro_bsz):
    num_groups = len(raw_data) // micro_bsz
    result = []
    for i in range(num_groups):
        start_idx = i * micro_bsz
        end_idx = start_idx + micro_bsz
        group = raw_data[start_idx:end_idx]
        concatenated = torch.cat(group, dim=0)
        result.append(concatenated)
    return result

def get_mocked_steps(data_dir: str) -> int:
    step_pattern = r"_step(\d+)_dp"
    mocked_steps = 0

    for fn in os.listdir(data_dir):
        step_match = re.search(step_pattern, fn)
        if step_match:
            step = int(step_match.group(1))
            mocked_steps = max(mocked_steps, step + 1)  # step numbers start from 0

    return mocked_steps


class MockedDataset(Dataset):
    """
    MockedDataset
    Mocked dataset for easier precision alignment.
    Assume the saved data has the following format:
        tokens_step{}_dp{}.pt, where {} is the saved step number (starting from 0) and the dp rank (starting from 0).
        labels_step{}_dp{}.pt, where {} is the saved step number (starting from 0) and the dp rank (starting from 0).
    Each saved file is a tensor accumulated over micro_num micro batches, where each micro batch is (micro_bsz, seq_len).
    Hence, the shape of tokens_step{}_dp{}.pt and labels_step{}_dp{}.pt is (micro_num * micro_bsz, seq_len).
    """

    def __init__(self, data_dir, micro_bsz, seq_len, mocked_steps):
        db_input_ids = []
        db_labels = []
    def __init__(self, train_folder: str, micro_bsz: int, micro_num: int, seq_len: int):

        # load all saved data
        for i in range(mocked_steps):
            # define load pattern
            input_ids_pattern = data_dir + f"_tokens_step{i+1}_dp*"
            labels_pattern = data_dir + f"_labels_step{i+1}_dp*"
            # merge input_ids, labels, and then chunk across dp
            input_ids = torch.chunk(merge_tensors(input_ids_pattern), gpc.get_world_size(ParallelMode.DATA))[
                gpc.get_local_rank(ParallelMode.DATA)
            ]
            labels = torch.chunk(merge_tensors(labels_pattern), gpc.get_world_size(ParallelMode.DATA))[
                gpc.get_local_rank(ParallelMode.DATA)
            ]
            # load one step
            db_input_ids.append(input_ids)
            db_labels.append(labels)

        # transform db
        db_input_ids = torch.concat(db_input_ids, dim=0)
        db_labels = torch.concat(db_labels, dim=0)
        db_input_ids = [db_input_ids[i] for i in range(db_input_ids.size(0))]
        db_labels = [db_labels[i] for i in range(db_labels.size(0))]

        # gen data for internevo format
        db_input_ids = process_raw_data(db_input_ids, micro_bsz)
        db_labels = process_raw_data(db_labels, micro_bsz)
        self.db_input_ids = [item.tolist() for item in db_input_ids]
        self.db_labels = [item.tolist() for item in db_labels]

        assert len(self.db_input_ids) == len(self.db_labels)
        self.dataset_len = len(self.db_input_ids)
        self.micro_bsz = micro_bsz
        self.micro_num = micro_num
        self.seq_len = seq_len

    def __len__(self):
        return self.dataset_len

    def __getitem__(self, idx):
        tokens = self.db_input_ids[idx]
        cu_seqlens = list(range(self.micro_bsz + 1))
        cu_seqlens = [i * self.seq_len for i in cu_seqlens]
        indexes = list(range(self.seq_len)) * self.micro_bsz
        labels = self.db_labels[idx]
        type_ids = [0] * self.micro_bsz * self.seq_len
        dp_size = gpc.get_world_size(ParallelMode.DATA)
        dp_rank = gpc.get_local_rank(ParallelMode.DATA)
        mocked_steps = get_mocked_steps(train_folder)

        tokens_list = []
        labels_list = []
        for i in range(mocked_steps):
            # define fn pattern
            tokens_fn_pattern = f"{train_folder}/tokens_step{i}_dp*"
            labels_fn_pattern = f"{train_folder}/labels_step{i}_dp*"

            # merge per-step mocked data and chunk across dp ranks
            tokens = torch.chunk(merge_tensors(tokens_fn_pattern), dp_size)[dp_rank]  # (micro_num * micro_bsz, seq_len)
            labels = torch.chunk(merge_tensors(labels_fn_pattern), dp_size)[dp_rank]  # (micro_num * micro_bsz, seq_len)

            # check and append
            assert tokens.size() == labels.size(), "Mismatch for tokens and labels"
            assert tokens.size(1) == seq_len, "Mismatch for seq_len"
            assert tokens.size(0) == micro_bsz * micro_num, "Mismatch for global_bsz"
            tokens_list.append(tokens)
            labels_list.append(labels)

        # concatenate across mocked_steps
        db_tokens = torch.cat(tokens_list, dim=0)  # (mocked_steps * micro_num * micro_bsz, seq_len)
        db_labels = torch.cat(labels_list, dim=0)  # (mocked_steps * micro_num * micro_bsz, seq_len)

        # split into (mocked_steps * micro_num, packed_length), where packed_length = micro_bsz * seq_len
        self.db_tokens = [
            item.tolist() for item in split_tensors([db_tokens[i] for i in range(db_tokens.size(0))], micro_bsz)
        ]
        self.db_labels = [
            item.tolist() for item in split_tensors([db_labels[i] for i in range(db_labels.size(0))], micro_bsz)
        ]

        # simple sanity check: ensure loaded per-step data is equivalent to saved per-step data
        self.sanity_check(tokens_list, labels_list)

    def __len__(self) -> int:
        return len(self.db_tokens)

    def __getitem__(self, idx: int) -> Dict[str, List[int]]:
        return {
            "tokens": tokens,
            "cu_seqlens": cu_seqlens,
            "indexes": indexes,
            "labels": labels,
            "type_ids": type_ids,
            "tokens": self.db_tokens[idx],
            "cu_seqlens": [i * self.seq_len for i in range(self.micro_bsz + 1)],
            "indexes": list(range(self.seq_len)) * self.micro_bsz,
            "labels": self.db_labels[idx],
            "type_ids": [0] * (self.micro_bsz * self.seq_len),
        }

    def sanity_check(self, tokens_list: List[torch.Tensor], labels_list: List[torch.Tensor]):
        tokens_list_tocheck = []
        for i in range(len(self.db_tokens)):
            tokens_list_tocheck += self.db_tokens[i]
            if (i + 1) % self.micro_num == 0:
                tokens_list_ref = tokens_list[i // self.micro_num].flatten(0, 1).tolist()
                assert tokens_list_tocheck == tokens_list_ref, "loaded tokens not equivalent to saved tokens"
                tokens_list_tocheck = []

        labels_list_tocheck = []
        for i in range(len(self.db_labels)):
            labels_list_tocheck += self.db_labels[i]
            if (i + 1) % self.micro_num == 0:
                labels_list_ref = labels_list[i // self.micro_num].flatten(0, 1).tolist()
                assert labels_list_tocheck == labels_list_ref, "loaded labels not equivalent to saved labels"
                labels_list_tocheck = []
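The rewritten MockedDataset discovers steps by filename, so precision-alignment data must land on disk in the layout the docstring describes. A hedged sketch of a writer that produces that layout; the random tokens, vocab size, and output path are illustrative only (real data would be dumped from the reference run):

import os
import torch

def save_mocked_step(train_folder: str, step: int, dp_rank: int,
                     micro_num: int, micro_bsz: int, seq_len: int) -> None:
    # One tensor per step and dp rank, shape (micro_num * micro_bsz, seq_len).
    os.makedirs(train_folder, exist_ok=True)
    shape = (micro_num * micro_bsz, seq_len)
    tokens = torch.randint(0, 32000, shape)  # illustrative vocab size
    labels = torch.randint(0, 32000, shape)  # real labels come from the run
    torch.save(tokens, f"{train_folder}/tokens_step{step}_dp{dp_rank}.pt")
    torch.save(labels, f"{train_folder}/labels_step{step}_dp{dp_rank}.pt")

# Steps and dp ranks both start from 0, matching get_mocked_steps' scan.
for step in range(2):
    save_mocked_step("./mocked_data", step, dp_rank=0,
                     micro_num=4, micro_bsz=2, seq_len=4096)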
9 changes: 7 additions & 2 deletions internlm/data/streaming/__init__.py
@@ -1,12 +1,17 @@
from .batch_sampler import StreamingStaticBatchSampler
from .collaters import streaming_packed_collate_fn
from .dataset import StreamingDataset, StreamingPackedDatasetWithCut
from .dataset import (
    StreamingDataset,
    StreamingDatasetPackSampleIntoOneWithCut,
    StreamingDatasetPackSampleWithPad,
)
from .utils import streaming_simple_resume

__all__ = [
    "StreamingStaticBatchSampler",
    "streaming_packed_collate_fn",
    "StreamingDataset",
    "StreamingPackedDatasetWithCut",
    "StreamingDatasetPackSampleWithPad",
    "StreamingDatasetPackSampleIntoOneWithCut",
    "streaming_simple_resume",
]