Skip to content

Commit

Permalink
feat(dataloader): refine implementation of mocked and megatron datalo…
Browse files Browse the repository at this point in the history
…ader (#344)

Co-authored-by: sallyjunjun <72725839+sallyjunjun@users.noreply.github.com>
  • Loading branch information
zigzagcai and sallyjunjun authored Dec 10, 2024
1 parent cd53c32 commit ae2243c
Show file tree
Hide file tree
Showing 8 changed files with 150 additions and 226 deletions.
36 changes: 25 additions & 11 deletions internlm/data/build_dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,13 @@
import subprocess
from functools import partial

import torch
import torch.distributed as dist
from torch.utils.data import ConcatDataset, DataLoader

from internlm.accelerator.abstract_accelerator import get_accelerator
from internlm.core.context import ParallelMode
from internlm.core.context import global_context as gpc
from internlm.data.megatron.batch_sampler import MegatronBatchSampler
from internlm.data.megatron.collaters import megatron_collate_fn
from internlm.data.megatron.dataset import build_megatron_dataset
from internlm.data.mocked.batch_sampler import MockedSequentialBatchSampler
Expand Down Expand Up @@ -41,8 +42,8 @@
from internlm.utils.logger import get_logger
from internlm.utils.utils import DataType

# global llm logger
logger = get_logger(__file__)
internlm_accelerator = get_accelerator()


def get_tokenized_train_loader_items(data_cfg):
Expand Down Expand Up @@ -156,10 +157,14 @@ def get_streaming_train_loader_items(data_cfg):


def get_megatron_train_loader_items(data_cfg):
assert data_cfg.get(
"pack_sample_into_one", False
), "megatron dataloader curently only supports pack_sample_into_one=True"
try:
from internlm.data.megatron import helpers # noqa # pylint: disable=W0611
except ImportError:
if gpc.is_rank_for_log():
# Compile dynamic library on-demand
if gpc.get_global_rank() % internlm_accelerator.device_count() == 0:
subprocess.run( # noqa # pylint: disable=W1510
[
"g++",
Expand All @@ -173,23 +178,28 @@ def get_megatron_train_loader_items(data_cfg):
"internlm/data/megatron/helpers.cpp",
"-o",
"internlm/data/megatron/helpers.so",
]
],
)
torch.distributed.barrier()

# NOTICE: Currently we only support single megatron dataset, a.k.a., single .bin and .idx
# Megatron dataset (.bin and.idx) should be generated by Megatron-LM tools/preprocess_data.py
# https://github.com/NVIDIA/Megatron-LM/blob/main/tools/preprocess_data.py
train_ds = build_megatron_dataset(
data_prefix=data_cfg.train_folder,
data_impl=data_cfg.get("data_impl", "infer"),
splits_string="1.0, 0.0, 0.0",
train_valid_test_num_samples=[9600000, 0, 0],
seq_len=data_cfg.seq_len,
seed=data_cfg.get("seed", 1024),
skip_warmup=True,
)

train_sampler = MegatronBatchSampler(
total_samples=len(train_ds),
consumed_samples=0,
train_sampler = StaticBatchSampler(
train_ds.datasets if isinstance(train_ds, ConcatDataset) else [train_ds],
batch_size=data_cfg.micro_num * data_cfg.micro_bsz,
rampup_batch_size=data_cfg.rampup_batch_size,
micro_bsz=data_cfg.micro_bsz,
seed=data_cfg.get("seed", 1024),
drop_last=True,
data_rank=gpc.get_local_rank(ParallelMode.DATA),
data_world_size=gpc.get_world_size(ParallelMode.DATA),
)

train_collate_fn = partial(
Expand All @@ -203,14 +213,18 @@ def get_mock_train_loader_items(data_cfg):
assert data_cfg.get(
"pack_sample_into_one", False
), "mocked dataloader curently only supports pack_sample_into_one=True"

train_ds = MockedDataset(
train_folder=data_cfg.train_folder,
micro_bsz=data_cfg.micro_bsz,
micro_num=data_cfg.micro_num,
seq_len=data_cfg.seq_len,
)

train_sampler = MockedSequentialBatchSampler(train_ds, data_cfg.micro_num)

train_collate_fn = partial(packed_collate_fn, packed_length=data_cfg.seq_len * data_cfg.micro_bsz)

return train_ds, train_sampler, train_collate_fn


Expand Down
2 changes: 0 additions & 2 deletions internlm/data/megatron/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
from .batch_sampler import MegatronBatchSampler
from .collaters import megatron_collate_fn
from .dataset import build_megatron_dataset

__all__ = [
"MegatronBatchSampler",
"build_megatron_dataset",
"megatron_collate_fn",
]
62 changes: 0 additions & 62 deletions internlm/data/megatron/batch_sampler.py

This file was deleted.

56 changes: 22 additions & 34 deletions internlm/data/megatron/collaters.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,48 +2,36 @@


def megatron_collate_fn(batch, micro_num, micro_bsz, seq_len):

input_ids_result = [[] for _ in range(micro_num)]
labels_result = [[] for _ in range(micro_num)]
cu_seqlens = []
input_ids_list = [[] for _ in range(micro_num)]
labels_list = [[] for _ in range(micro_num)]
cu_seqlens_list = []
indexes = []
indexes_list = []

for i, item in enumerate(batch):
assert i < micro_num * micro_bsz
seq_len_list = item["text"]
assert len(seq_len_list) == seq_len + 1

micro_bsz_index = i % micro_bsz
micro_num_index = i // micro_bsz

input_ids_result[micro_num_index].append(seq_len_list[:-1])
labels_result[micro_num_index].append(seq_len_list[1:])

cu_seqlens.append(seq_len * micro_bsz_index)
indexes = indexes + list(range(seq_len))
assert len(batch) == micro_bsz * micro_num
for idx, b in enumerate(batch):
tokens = b["text"]
# The length of megatron preprocessed data samples is (seq_len + 1)
# So we use the first seq_len tokens as input and the last seq_len tokens as shifted labels
assert len(tokens) == seq_len + 1
micro_bsz_index = idx % micro_bsz
micro_num_index = idx // micro_bsz
input_ids_list[micro_num_index].append(tokens[:-1])
labels_list[micro_num_index].append(tokens[1:])

if micro_bsz_index == micro_bsz - 1:
input_ids_result[micro_num_index] = torch.cat(
[torch.from_numpy(arr).long() for arr in input_ids_result[micro_num_index]], dim=0
# Since megatron data sample is numpy format, we need to convert it to tensor and concate within micro batch
input_ids_list[micro_num_index] = torch.cat(
[torch.from_numpy(arr) for arr in input_ids_list[micro_num_index]], dim=0
)
labels_result[micro_num_index] = torch.cat(
[torch.from_numpy(arr).long() for arr in labels_result[micro_num_index]], dim=0
labels_list[micro_num_index] = torch.cat(
[torch.from_numpy(arr) for arr in labels_list[micro_num_index]], dim=0
)
cu_seqlens.append(seq_len * micro_bsz)
cu_seqlens_list.append(torch.IntTensor(cu_seqlens))
cu_seqlens = []
indexes_list.append(torch.IntTensor(indexes))
indexes = []

input_ids = torch.stack(input_ids_result)
labels = torch.stack(labels_result)
indexes = torch.stack(indexes_list)
cu_seqlens_list.append(torch.IntTensor([i * seq_len for i in range(micro_bsz + 1)]))
indexes_list.append(torch.IntTensor(list(range(seq_len)) * micro_bsz))

return {
"input_ids": input_ids,
"input_ids": torch.stack(input_ids_list),
"cu_seqlens": cu_seqlens_list,
"indexes": indexes,
"indexes": torch.stack(indexes_list),
"type_ids": torch.zeros(micro_num, micro_bsz * seq_len, dtype=torch.int64),
}, labels
}, torch.stack(labels_list)
90 changes: 17 additions & 73 deletions internlm/data/megatron/dataset.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# adapted from https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/datasets/gpt_dataset.py
# adapted from https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/datasets/indexed_dataset.py

import hashlib
import os
import struct
Expand Down Expand Up @@ -764,82 +765,25 @@ def get_indexed_dataset_(data_prefix, data_impl, skip_warmup):
return indexed_dataset


def get_train_valid_test_split_(splits_string, size):
"""Get dataset splits from comma or '/' separated string list."""

splits = []
if splits_string.find(",") != -1:
splits = [float(s) for s in splits_string.split(",")]
elif splits_string.find("/") != -1:
splits = [float(s) for s in splits_string.split("/")]
else:
splits = [float(splits_string)]
while len(splits) < 3:
splits.append(0.0)
splits = splits[:3]
splits_sum = sum(splits)
assert splits_sum > 0.0
splits = [split / splits_sum for split in splits]
splits_index = [0]
for index, split in enumerate(splits):
splits_index.append(splits_index[index] + int(round(split * float(size))))
diff = splits_index[-1] - size
for index in range(1, len(splits_index)):
splits_index[index] -= diff
assert len(splits_index) == 4
assert splits_index[-1] == size
return splits_index


def build_megatron_dataset(
data_prefix,
data_impl,
splits_string,
train_valid_test_num_samples,
seq_len,
seed,
skip_warmup,
return_doc_ids=False,
*,
data_cache_path=None,
):

# Indexed dataset.
indexed_dataset = get_indexed_dataset_(data_prefix, data_impl, skip_warmup)

total_num_of_documents = indexed_dataset.sizes.shape[0]
splits = get_train_valid_test_split_(splits_string, total_num_of_documents)

# Print stats about the splits.
print_rank_0(" > dataset split:")

def print_split_stats(index, name):
print_rank_0(" {}:".format(name))
print_rank_0(
" document indices in [{}, {}) total of {} "
"documents".format(splits[index], splits[index + 1], splits[index + 1] - splits[index])
)

print_split_stats(0, "train")

def build_dataset(index, name):
dataset = None
if splits[index + 1] > splits[index]:
documents = np.arange(start=splits[index], stop=splits[index + 1], step=1, dtype=np.int32)
dataset = GPTDataset(
name,
data_prefix,
documents,
indexed_dataset,
splits_string,
train_valid_test_num_samples[index],
seq_len,
seed,
return_doc_ids,
data_cache_path=data_cache_path,
)
return dataset

train_dataset = build_dataset(0, "train")

return train_dataset
indexed_dataset = get_indexed_dataset_(data_prefix, data_impl="infer", skip_warmup=True)

# GPT dataset.
return GPTDataset(
name="train",
data_prefix=data_prefix,
documents=np.arange(start=0, stop=indexed_dataset.sizes.shape[0], step=1, dtype=np.int32),
indexed_dataset=indexed_dataset,
splits_string="1.0, 0.0, 0.0", # proportion of dataset for train/valid/test, we set 1.0 for train only
num_samples=gpc.config.data.micro_bsz
* gpc.config.data.micro_num
* gpc.get_world_size(ParallelMode.DATA)
* gpc.config.data.total_steps, # total number of train samples
seq_length=seq_len,
seed=seed,
)
Loading

0 comments on commit ae2243c

Please sign in to comment.