Add BERT Embedding Models (#11737)
* Add BERT Embedding model

* Fix validation loss func

* Revert lightning/io changes

* Revert llm/api

* add global in-batch negatives option

* Apply isort and black reformatting

Signed-off-by: suiyoubi <suiyoubi@users.noreply.github.com>

* codeQL

* pylint

* Apply isort and black reformatting

Signed-off-by: suiyoubi <suiyoubi@users.noreply.github.com>

* pylint

* minor changes

* Apply isort and black reformatting

Signed-off-by: suiyoubi <suiyoubi@users.noreply.github.com>

* remove test code for specter

* remove import

---------

Signed-off-by: suiyoubi <suiyoubi@users.noreply.github.com>
Co-authored-by: suiyoubi <suiyoubi@users.noreply.github.com>
suiyoubi authored Jan 6, 2025
1 parent 638d0f7 commit 5ed118f
Showing 9 changed files with 794 additions and 5 deletions.
9 changes: 8 additions & 1 deletion nemo/collections/llm/__init__.py
@@ -18,9 +18,12 @@
safe_import("transformer_engine")

from nemo.collections.llm import peft
from nemo.collections.llm.bert.data import BERTMockDataModule, BERTPreTrainingDataModule
from nemo.collections.llm.bert.data import BERTMockDataModule, BERTPreTrainingDataModule, SpecterDataModule
from nemo.collections.llm.bert.model import (
BertConfig,
BertEmbeddingLargeConfig,
BertEmbeddingMiniConfig,
BertEmbeddingModel,
BertModel,
HuggingFaceBertBaseConfig,
HuggingFaceBertConfig,
@@ -157,7 +160,10 @@
"T5Config3B",
"T5Config11B",
"BertConfig",
"BertEmbeddingModel",
"BertModel",
"BertEmbeddingLargeConfig",
"BertEmbeddingMiniConfig",
"t5_data_step",
"t5_forward_step",
"MaskedTokenLossReduction",
@@ -247,6 +253,7 @@
"MegatronBertLargeConfig",
"BERTMockDataModule",
"BERTPreTrainingDataModule",
"SpecterDataModule",
"DollyDataModule",
"tokenizer",
"mock",
3 changes: 2 additions & 1 deletion nemo/collections/llm/bert/data/__init__.py
@@ -1,4 +1,5 @@
from nemo.collections.llm.bert.data.mock import BERTMockDataModule
from nemo.collections.llm.bert.data.pre_training import BERTPreTrainingDataModule
from nemo.collections.llm.bert.data.specter import SpecterDataModule

__all__ = ["BERTPreTrainingDataModule", "BERTMockDataModule"]
__all__ = ["BERTPreTrainingDataModule", "BERTMockDataModule", "SpecterDataModule"]
62 changes: 62 additions & 0 deletions nemo/collections/llm/bert/data/core.py
@@ -0,0 +1,62 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from pathlib import Path
from typing import TYPE_CHECKING, Optional

from nemo.collections.nlp.data.information_retrieval.bert_embedding_dataset import BertEmbeddingDataset
from nemo.lightning.base import NEMO_DATASETS_CACHE

if TYPE_CHECKING:
from nemo.collections.common.tokenizers import TokenizerSpec


def get_dataset_root(name: str) -> Path:
"""Retrieve the root path for the dataset. Create the folder if not exists."""
output = Path(NEMO_DATASETS_CACHE) / name
output.mkdir(parents=True, exist_ok=True)

return output


def create_sft_dataset(
path: Path,
tokenizer: "TokenizerSpec",
seq_length: int = 2048,
add_bos: bool = False,
add_eos: bool = True,
seed: int = 1234,
index_mapping_dir: Optional[str] = None,
truncation_method: str = 'right',
memmap_workers: int = 2,
data_type: str = 'train',
num_hard_negatives: int = 1,
**kwargs,
) -> "BertEmbeddingDataset":
"""Create BertEmbeddingDataset for SFT training."""

return BertEmbeddingDataset(
file_path=str(path),
tokenizer=tokenizer,
max_seq_length=seq_length,
add_bos=add_bos,
add_eos=add_eos,
memmap_workers=memmap_workers,
seed=seed,
index_mapping_dir=index_mapping_dir,
truncation_method=truncation_method,
data_type=data_type,
num_hard_negatives=num_hard_negatives,
**kwargs,
)
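
create_sft_dataset above simply forwards its arguments to BertEmbeddingDataset, so a caller mainly needs a JSONL file and a tokenizer. A minimal usage sketch follows, assuming a NeMo AutoTokenizer wrapping a Hugging Face BERT tokenizer; the dataset name and tokenizer choice are illustrative, not part of this diff.

from nemo.collections.common.tokenizers import AutoTokenizer
from nemo.collections.llm.bert.data.core import create_sft_dataset, get_dataset_root

# get_dataset_root resolves (and creates) NEMO_DATASETS_CACHE/<name>.
train_file = get_dataset_root("specter") / "training.jsonl"

dataset = create_sft_dataset(
    path=train_file,
    tokenizer=AutoTokenizer("bert-base-uncased"),  # illustrative tokenizer choice
    seq_length=512,
    num_hard_negatives=1,
    data_type="train",
)
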
227 changes: 227 additions & 0 deletions nemo/collections/llm/bert/data/fine_tuning.py
@@ -0,0 +1,227 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
from functools import lru_cache
from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union

import lightning.pytorch as pl
from torch.utils.data import DataLoader

from nemo.collections.common.tokenizers import AutoTokenizer
from nemo.collections.llm.bert.data.core import create_sft_dataset
from nemo.lightning.data import WrappedDataLoader
from nemo.lightning.pytorch.plugins import MegatronDataSampler
from nemo.utils import logging

if TYPE_CHECKING:
from nemo.collections.common.tokenizers import TokenizerSpec


class FineTuningDataModule(pl.LightningDataModule):
"""Base class for fine-tuning an Bert.
This class provides a foundation for building custom data modules for fine-tuning Nemo NLP models. It inherits from
`pl.LightningDataModule` from the PyTorch Lightning library and handles data loading, preprocessing, and batch
creation for training, validation, and testing.
Args:
dataset_root (Union[str, Path]): The root directory containing the training, validation, and test data.
seq_length (int, optional): The maximum sequence length for the input and output text. Defaults to 2048.
tokenizer (Optional[TokenizerSpec], optional): The tokenizer to use for preprocessing the text.
If not provided, a Megatron GPT2 BPE tokenizer will be used.
micro_batch_size (int, optional): The micro batch size for training. Defaults to 4.
global_batch_size (int, optional): The global batch size for training. Defaults to 8.
rampup_batch_size (Optional[List[int]], optional): A list of batch sizes for ramping up during training.
Defaults to None.
seed (int, optional): The random seed for data shuffling. Defaults to 1234.
memmap_workers (int, optional): The number of worker processes for loading data using TextMemMapDataset.
Defaults to 1.
num_workers (int, optional): The number of worker processes for data loading. Defaults to 8.
pin_memory (bool, optional): Whether to pin memory during data loading for faster GPU training.
Defaults to True.
persistent_workers (bool, optional): Whether to keep data loading workers persistent across epochs.
Defaults to False.
        dataset_kwargs (Optional[Dict[str, Any]], optional): Keyword arguments to pass into the BertEmbeddingDataset class.
"""

def __init__(
self,
dataset_root: Union[str, Path],
seq_length: int = 2048,
tokenizer: Optional["TokenizerSpec"] = None,
micro_batch_size: int = 4,
global_batch_size: int = 8,
rampup_batch_size: Optional[List[int]] = None,
seed: int = 1234,
memmap_workers: int = 1,
num_workers: int = 8,
pin_memory: bool = True,
persistent_workers: bool = False,
dataset_kwargs: Optional[Dict[str, Any]] = None,
):
super().__init__()
self.seq_length = seq_length
self.seed = seed
self.dataset_root = Path(dataset_root)
self.tokenizer = tokenizer
self.memmap_workers = memmap_workers
self.num_workers = num_workers
self.pin_memory = pin_memory
self.persistent_workers = persistent_workers
self.micro_batch_size = micro_batch_size
self.global_batch_size = global_batch_size
self.rampup_batch_size = rampup_batch_size
self.data_sampler = None
self.max_train_samples = None
self.dataset_kwargs = dataset_kwargs or {}

def setup(self, stage: str):
"""Called by pytorch lightning in datamodule setup"""

# data_sampler is used in `setup_data_sampler` in MegatronStrategy.setup
self.data_sampler = MegatronDataSampler(
seq_len=self.seq_length,
micro_batch_size=self.micro_batch_size,
global_batch_size=self.global_batch_size,
rampup_batch_size=self.rampup_batch_size,
dataloader_type="batch",
)

# Follows the calculation in nemo.collections.nlp.data.language_modeling.megatron.
# base_dataset_utils.get_datasets_weights_and_num_samples
self.max_train_samples = int(math.ceil(self.global_batch_size * self.trainer.max_steps * 1.005))
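        # For example, with global_batch_size=8 and trainer.max_steps=1000 this reserves
        # ceil(8 * 1000 * 1.005) = 8040 training samples, i.e. a 0.5% margin over the
        # samples consumed by the scheduled optimizer steps.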

def state_dict(self) -> Dict[str, Any]:
"""Called when saving a checkpoint, implement to generate and save datamodule state.
Returns:
A dictionary containing datamodule state.
"""
consumed_samples = self.data_sampler.compute_consumed_samples(
self.trainer.global_step - self.data_sampler.init_global_step
)
return {"consumed_samples": consumed_samples}

def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
"""Called when loading a checkpoint, implement to reload datamodule state given datamodule stat
Args:
state_dict: the datamodule state returned by ``state_dict``.
"""
try:
from megatron.core.num_microbatches_calculator import update_num_microbatches

except (ImportError, ModuleNotFoundError):
logging.warning("Megatron num_microbatches_calculator not found, using Apex version.")
from apex.transformer.pipeline_parallel.utils import update_num_microbatches
consumed_samples = state_dict["consumed_samples"]
self.data_sampler.init_consumed_samples = consumed_samples
self.data_sampler.prev_consumed_samples = consumed_samples

update_num_microbatches(
consumed_samples=consumed_samples,
consistency_check=False,
)
self.data_sampler.if_first_step = 1

def train_dataloader(self) -> DataLoader:
# pylint: disable=C0115,C0116
return self._create_dataloader(
self._create_dataset(
self.train_path,
max_num_samples=self.max_train_samples,
**self.dataset_kwargs,
),
mode="train",
)

    def val_dataloader(self) -> DataLoader:
        # pylint: disable=C0115,C0116
        return self._create_dataloader(
            self._create_dataset(
                self.validation_path,
                **self.dataset_kwargs,
            ),
            mode="validation",
        )

    def test_dataloader(self) -> DataLoader:
        # pylint: disable=C0115,C0116
        return self._create_dataloader(
            self._create_dataset(
                self.test_path,
                **self.dataset_kwargs,
            ),
            mode="test",
        )

@lru_cache
def _create_dataset(self, path, **kwargs):
return create_sft_dataset(
path,
tokenizer=self.tokenizer,
seq_length=self.seq_length,
memmap_workers=self.memmap_workers,
seed=self.seed,
**kwargs,
)

def _create_dataloader(self, dataset, mode, **kwargs) -> DataLoader:
return WrappedDataLoader(
mode=mode,
dataset=dataset,
num_workers=self.num_workers,
pin_memory=self.pin_memory,
persistent_workers=self.persistent_workers,
collate_fn=dataset.collate_fn,
**kwargs,
)

@property
def train_path(self) -> Path:
"""Path to training dataset file"""
return self.dataset_root / "training.jsonl"

@property
def validation_path(self) -> Path:
"""Path to validation dataset file"""
return self.dataset_root / "validation.jsonl"

@property
def test_path(self) -> Path:
"""Path to test dataset file"""
return self.dataset_root / "test.jsonl"

def _extract_tokenizer_model_name(self) -> str:
"""Automatically get the model name from model path."""
if isinstance(self.tokenizer, AutoTokenizer):
name = self.tokenizer.tokenizer.name_or_path
if name.endswith("context/nemo_tokenizer"):
# NEMO_HOME/hf_org/hf_model/context/nemo_tokenizer => hf_org--hf_model
tokenizer_model_name = '--'.join(name.split("/")[-4:-2])
elif name.endswith("nemo_tokenizer"):
# NEMO_HOME/hf_org/hf_model/nemo_tokenizer => hf_org--hf_model
tokenizer_model_name = '--'.join(name.split("/")[-3:-1])
else:
# hf_org/hf_model => hf_org--hf_model
tokenizer_model_name = name.replace("/", "--")
else:
tokenizer_model_name = f"unknown_tokenizer_{hash(self.tokenizer)}"
return tokenizer_model_name
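
Putting the pieces together, the module expects dataset_root to contain training.jsonl, validation.jsonl, and test.jsonl and wraps each split in a WrappedDataLoader over a BertEmbeddingDataset. A minimal construction sketch follows; the dataset directory and tokenizer are illustrative assumptions, not part of this diff.

from nemo.collections.common.tokenizers import AutoTokenizer
from nemo.collections.llm.bert.data.fine_tuning import FineTuningDataModule

datamodule = FineTuningDataModule(
    dataset_root="/data/my_embedding_dataset",  # hypothetical directory with training/validation/test.jsonl
    seq_length=512,
    tokenizer=AutoTokenizer("bert-base-uncased"),  # illustrative tokenizer
    micro_batch_size=4,
    global_batch_size=8,
    dataset_kwargs={"num_hard_negatives": 1},  # forwarded to BertEmbeddingDataset
)

In practice the module is driven by a PyTorch Lightning trainer configured with NeMo's MegatronStrategy, which picks up the MegatronDataSampler created in setup().
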
(Diffs for the remaining 5 changed files in this commit are not shown here.)