[DRAFT] Llama3.2 1B Embedding Model Support #11909

Open · wants to merge 14 commits into base: main
6 changes: 6 additions & 0 deletions nemo/collections/llm/__init__.py
@@ -36,6 +36,7 @@
from nemo.collections.llm.gpt.data import (
AlpacaDataModule,
ChatDataModule,
CustomRetrievalDataModule,
DollyDataModule,
FineTuningDataModule,
HFDatasetDataModule,
@@ -109,6 +110,8 @@
Nemotron4Config340B,
NemotronConfig,
NemotronModel,
NVEmbedLlama32Config1B,
NVEmbedLlamaModel,
NVIDIAMambaConfig8B,
NVIDIAMambaHybridConfig8B,
Phi3Config,
@@ -150,6 +153,7 @@
__all__ = [
"MockDataModule",
"T5MockDataModule",
"CustomRetrievalDataModule",
"GPTModel",
"GPTConfig",
"gpt_data_step",
@@ -185,6 +189,8 @@
"Nemotron4Config15B",
"Nemotron4Config340B",
"NemotronConfig",
"NVEmbedLlamaModel",
"NVEmbedLlama32Config1B",
"Phi3Config",
"Phi3ConfigMini",
"Phi3Model",
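For reference, a minimal sketch of what these new top-level exports enable once the rest of the PR is installed. The constructor call mirrors other NeMo GPT models and is an assumption, not something taken from this diff:

# Minimal sketch, not part of the diff: exercises the names this hunk exports.
# NVEmbedLlamaModel(config) follows the pattern of other NeMo GPT model
# constructors and is an assumption.
from nemo.collections.llm import (
    CustomRetrievalDataModule,  # new data module export, used in a later hunk
    NVEmbedLlama32Config1B,
    NVEmbedLlamaModel,
)

config = NVEmbedLlama32Config1B()
model = NVEmbedLlamaModel(config)
print(type(model).__name__, type(config).__name__)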
94 changes: 94 additions & 0 deletions nemo/collections/llm/bert/loss.py
@@ -99,6 +99,100 @@ def reduce(self, losses_reduced_per_micro_batch) -> torch.Tensor:
return torch.tensor(0.0, device=torch.cuda.current_device())


class HardNegativesCELoss(MegatronLossReduction):
"""
This loss uses hard-negative samples.
The difference between this loss and the default MultipleNegativesRankingLoss
from Sentence Transformers is that the latter shares hard negatives
as negatives for all examples, whereas this loss uses hard negatives
exclusively for the example they are associated with.
"""

def __init__(
self,
validation_step: bool = False,
val_drop_last: bool = True,
num_hard_negatives: int = 1,
scale: float = 50,
label_smoothing: float = 0.0,
encode_separately: bool = True,
) -> None:
super().__init__()
self.validation_step = validation_step
self.val_drop_last = val_drop_last
self.num_hard_negatives = num_hard_negatives
self.scale = scale
self.encode_separately = encode_separately
self.cross_entropy_loss = nn.CrossEntropyLoss(label_smoothing=label_smoothing)

def forward(
self, batch: Dict[str, torch.Tensor], forward_out: torch.Tensor
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
from megatron.core import parallel_state

cp_size = parallel_state.get_context_parallel_world_size()
if cp_size != 1:
raise NotImplementedError(f'CP is not supported for {self.__class__} yet.')

num_tensors_per_example = 2 + self.num_hard_negatives # 1 query, 1 pos, num_hard_negatives negs
current_train_n_passages = 1 + self.num_hard_negatives
batch_size = forward_out.shape[0] // num_tensors_per_example
# Get Query, Key (Positives, Negatives)
if self.encode_separately:
# forward_out was concat of [query, key]
query = forward_out[:batch_size]
key = forward_out[batch_size:]
else:
# forward_out was chunked [(q1, k1), (q2, k2), ...]
chunks = forward_out.chunk(batch_size)
query = torch.stack([item[0] for item in chunks])
key = torch.cat(
[torch.stack([item[i + 1] for item in chunks]) for i in range(current_train_n_passages)],
dim=0,
)

assert key.shape[0] % query.shape[0] == 0, '{} % {} > 0'.format(key.shape[0], query.shape[0])
assert key.shape[0] / query.shape[0] == current_train_n_passages, '{} / {} != {}'.format(
key.shape[0], query.shape[0], current_train_n_passages
)
query_shape = query.shape
repeated_query = query.repeat(1, 1, current_train_n_passages).reshape(
query_shape[0] * current_train_n_passages, query_shape[1]
)
scores = torch.sum(repeated_query * key, dim=-1).reshape(query_shape[0], current_train_n_passages)
labels = torch.zeros(query_shape[0], dtype=torch.long, device=query.device)

scores *= self.scale
ce_loss = self.cross_entropy_loss(scores, labels)
reduced_loss = average_losses_across_data_parallel_group([ce_loss])
return ce_loss, {"avg": reduced_loss}

def reduce(self, losses_reduced_per_micro_batch) -> torch.Tensor:
"""Taken from: https://github.com/NVIDIA/NeMo/blob/main
/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py#L535-L552 ."""
if losses_reduced_per_micro_batch:
if "avg" in losses_reduced_per_micro_batch[0]:
loss_tensors_list = [loss_reduced["avg"] for loss_reduced in losses_reduced_per_micro_batch]
loss_tensor = torch.concat(loss_tensors_list)

return loss_tensor.mean()

# Get the total loss since micro batch sizes are not uniform
loss_sum_tensors_list: List[torch.Tensor] = [
loss_sum["loss_sum_and_ub_size"]
for loss_sum in losses_reduced_per_micro_batch
if loss_sum["loss_sum_and_ub_size"][1] > 0
]
loss_sum = (
torch.vstack(loss_sum_tensors_list).sum(dim=0)
if len(loss_sum_tensors_list) > 0
else torch.tensor([0.0, 0.0], device=torch.cuda.current_device())
)
return loss_sum

return torch.tensor(0.0, device=torch.cuda.current_device())


class BERTInBatchExclusiveHardNegativesRankingLoss(MegatronLossReduction):
"""
This loss uses in-batch negative samples + hard-negative samples.
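A standalone sketch of the score/label construction that HardNegativesCELoss.forward performs, using random tensors. Illustration only: the real class consumes forward_out from the Megatron forward step, and the per-query key ordering here is an assumption.

# Standalone illustration of the hard-negatives ranking objective (not the NeMo class).
# Assumes key rows are grouped per query as [pos, neg_1, ..., neg_k], so the
# positive lands in column 0 of the score matrix and the cross-entropy target is 0.
import torch
import torch.nn as nn

batch_size, hidden, num_hard_negatives, scale = 4, 8, 2, 50.0
n_passages = 1 + num_hard_negatives  # 1 positive + hard negatives per query

query = torch.randn(batch_size, hidden)             # [B, H]
key = torch.randn(batch_size * n_passages, hidden)  # [B * n_passages, H]

repeated_query = query.repeat(1, n_passages).reshape(batch_size * n_passages, hidden)
scores = scale * (repeated_query * key).sum(-1).reshape(batch_size, n_passages)
labels = torch.zeros(batch_size, dtype=torch.long)  # positive assumed at column 0

loss = nn.CrossEntropyLoss(label_smoothing=0.0)(scores, labels)
print(scores.shape, loss.item())                    # torch.Size([4, 3]) <scalar>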
2 changes: 2 additions & 0 deletions nemo/collections/llm/gpt/data/__init__.py
@@ -19,6 +19,7 @@
from nemo.collections.llm.gpt.data.hf_dataset import HFDatasetDataModule
from nemo.collections.llm.gpt.data.mock import MockDataModule
from nemo.collections.llm.gpt.data.pre_training import PreTrainingDataModule, build_pretraining_datamodule
from nemo.collections.llm.gpt.data.retrieval import CustomRetrievalDataModule
from nemo.collections.llm.gpt.data.squad import SquadDataModule

__all__ = [
@@ -31,4 +32,5 @@
"PreTrainingDataModule",
"build_pretraining_datamodule",
"SquadDataModule",
"CustomRetrievalDataModule",
]
96 changes: 96 additions & 0 deletions nemo/collections/llm/gpt/data/retrieval.py
@@ -0,0 +1,96 @@
import json
import os.path
from typing import TYPE_CHECKING, Any, Dict, List, Optional

from datasets import Dataset

from nemo.collections.llm.bert.data.fine_tuning import FineTuningDataModule
from nemo.collections.llm.gpt.data.core import get_dataset_root
from nemo.utils import logging

if TYPE_CHECKING:
from nemo.collections.common.tokenizers import TokenizerSpec
from nemo.collections.llm.gpt.data.packed_sequence import PackedSequenceSpecs


# Custom retrieval data module loaded from a JSON file
class CustomRetrievalDataModule(FineTuningDataModule):
"""Data module for fine-tuning retrieval models on a custom JSON dataset of
query / positive-document / negative-document records."""

def __init__(
self,
data_root: str,
dataset_identifier: str = "custom_retrieval_dataset",
seq_length: int = 2048,
tokenizer: Optional["TokenizerSpec"] = None,
micro_batch_size: int = 4,
global_batch_size: int = 8,
rampup_batch_size: Optional[List[int]] = None,
force_redownload: bool = False,
delete_raw: bool = True,
seed: int = 1234,
memmap_workers: int = 1,
num_workers: int = 8,
pin_memory: bool = True,
persistent_workers: bool = False,
packed_sequence_specs: Optional["PackedSequenceSpecs"] = None,
query_key: str = "question",
pos_doc_key: str = "pos_doc",
neg_doc_key: str = "neg_doc",
dataset_kwargs: Optional[Dict[str, Any]] = None,
):
self.force_redownload = force_redownload
self.delete_raw = delete_raw

assert packed_sequence_specs is None, "CustomRetrievalDataModule does not support packed sequences."
assert os.path.exists(data_root), "Data root does not exist."
self.query_key = query_key
self.pos_doc_key = pos_doc_key
self.neg_doc_key = neg_doc_key
self.unprocessed_root = data_root
super().__init__(
dataset_root=get_dataset_root(dataset_identifier),
seq_length=seq_length,
tokenizer=tokenizer,
micro_batch_size=micro_batch_size,
global_batch_size=global_batch_size,
rampup_batch_size=rampup_batch_size,
seed=seed,
memmap_workers=memmap_workers,
num_workers=num_workers,
pin_memory=pin_memory,
persistent_workers=persistent_workers,
dataset_kwargs=dataset_kwargs,
)

def prepare_data(self) -> None:
"""Prepare data if not split already."""
if not self.train_path.exists() or self.force_redownload:
self._preprocess_and_split_data()
super().prepare_data()

def _preprocess_and_split_data(self, train_ratio: float = 0.95, val_ratio: float = 0.04):
"""Split the raw JSON input into training/validation/test JSONL files under the dataset root."""
logging.info(f"Preprocessing {self.__class__.__name__} to jsonl format and splitting...")

test_ratio = 1 - train_ratio - val_ratio
save_splits = {}
dataset = Dataset.from_list(json.load(open(self.unprocessed_root, 'r')))
split_dataset = dataset.train_test_split(test_size=val_ratio + test_ratio, seed=self.seed)
split_dataset2 = split_dataset['test'].train_test_split(
test_size=test_ratio / (val_ratio + test_ratio), seed=self.seed
)
save_splits['training'] = split_dataset['train']
save_splits['validation'] = split_dataset2['train']
save_splits['test'] = split_dataset2['test']

for split_name, dataset in save_splits.items():
output_file = self.dataset_root / f"{split_name}.jsonl"
with output_file.open("w", encoding="utf-8") as f:
for o in dataset:
# We only write one positive document for now
# All negative documents are written
pos_doc = o[self.pos_doc_key][0] if isinstance(o[self.pos_doc_key], list) else o[self.pos_doc_key]
neg_doc = o[self.neg_doc_key] if isinstance(o[self.neg_doc_key], list) else [o[self.neg_doc_key]]
f.write(json.dumps({"query": o[self.query_key], "pos_doc": pos_doc, "neg_doc": neg_doc}) + "\n")

logging.info(f"{split_name} split saved to {output_file}")
3 changes: 3 additions & 0 deletions nemo/collections/llm/gpt/model/__init__.py
@@ -81,6 +81,7 @@
NemotronConfig,
NemotronModel,
)
from nemo.collections.llm.gpt.model.nvembed_llama import NVEmbedLlama32Config1B, NVEmbedLlamaModel
from nemo.collections.llm.gpt.model.phi3mini import Phi3Config, Phi3ConfigMini, Phi3Model
from nemo.collections.llm.gpt.model.qwen2 import (
Qwen2Config,
@@ -145,6 +146,8 @@
"Nemotron3Config22B",
"Nemotron4Config340B",
"NemotronModel",
"NVEmbedLlamaModel",
"NVEmbedLlama32Config1B",
"Phi3Config",
"Phi3ConfigMini",
"Phi3Model",