Skip to content

Commit

Permalink
feat(loss)/add different operator types for cross_entropy (#386)
Browse files Browse the repository at this point in the history
  • Loading branch information
yingtongxiong authored Dec 17, 2024
1 parent 0ec6cdc commit 141e9eb
Show file tree
Hide file tree
Showing 22 changed files with 682 additions and 377 deletions.
14 changes: 14 additions & 0 deletions configs/7B_MoE4_sft.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,20 @@
clip_grad_norm=1.0,
)


# loss config (dict):
# 1. label_smoothing
# 2. op_type: cross_entropy operator type, we support five types for loss computing,
# including ["torch_naive", "apex_naive", "py_naive", "flash_vocab_parallel", "py_vocab_parallel"]
# default is "py_vocab_parallel".
# "torch_naive": cross_entropy imported from torch, i.e. torch.nn.CrossEntropyLoss
# "apex_naive": cross_entropy from apex
# "py_naive": self-implemented cross_entropy
# "flash_vocab_parallel": vocab parallel cross_entropy imported from flash_attn
# "py_vocab_parallel": self-implemented vocab parallel cross_entropy
# * op_types that ends with "naive" only support parallel_output=False;
# * if in no-GPU env, only "torch_naive" and "py_vocab_parallel" are supported.

loss = dict(
label_smoothing=0,
moe_loss_coeff=0.1,
Expand Down
18 changes: 15 additions & 3 deletions configs/7B_internlm2.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,9 +98,21 @@
clip_grad_norm=1.0,
)

loss = dict(
label_smoothing=0,
)

# loss config (dict):
# 1. label_smoothing
# 2. op_type: cross_entropy operator type, we support five types for loss computing,
# including ["torch_naive", "apex_naive", "py_naive", "flash_vocab_parallel", "py_vocab_parallel"]
# default is "py_vocab_parallel".
# "torch_naive": cross_entropy imported from torch, i.e. torch.nn.CrossEntropyLoss
# "apex_naive": cross_entropy from apex
# "py_naive": self-implemented cross_entropy
# "flash_vocab_parallel": vocab parallel cross_entropy imported from flash_attn
# "py_vocab_parallel": self-implemented vocab parallel cross_entropy

# * op_types that ends with "naive" only support parallel_output=False;
# * if in no-GPU env, only "torch_naive" and "py_vocab_parallel" are supported.
loss = dict(label_smoothing=0, op_type="py_vocab_parallel")

adam = dict(
lr=1e-4,
Expand Down
16 changes: 16 additions & 0 deletions configs/7B_isp_sft.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,8 +108,24 @@
clip_grad_norm=1.0,
)


# loss config (dict):
# 1. label_smoothing
# 2. op_type: cross_entropy operator type, we support five types for loss computing,
# including ["torch_naive", "apex_naive", "py_naive", "flash_vocab_parallel", "py_vocab_parallel"]
# default is "py_vocab_parallel".
# "torch_naive": cross_entropy imported from torch, i.e. torch.nn.CrossEntropyLoss
# "apex_naive": cross_entropy from apex
# "py_naive": self-implemented cross_entropy
# "flash_vocab_parallel": vocab parallel cross_entropy imported from flash_attn
# "py_vocab_parallel": self-implemented vocab parallel cross_entropy

# * op_types that ends with "naive" only support parallel_output=False;
# * if in no-GPU env, only "torch_naive" and "py_vocab_parallel" are supported.

loss = dict(
label_smoothing=0,
op_type="flash_vocab_parallel",
)

adam = dict(
Expand Down
10 changes: 6 additions & 4 deletions internlm/core/trainer_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from internlm.data.train_state import get_train_state
from internlm.eval.evaluation import evaluate_on_val_dls
from internlm.initialize.initialize_trainer import initialize_trainer
from internlm.model.losses.ce_loss import FlashGPTLMLoss
from internlm.model.losses.ce_loss import InternLoss
from internlm.model.metrics import AccPerplex
from internlm.monitor.monitor import send_alert_message
from internlm.train.pipeline import (
Expand Down Expand Up @@ -172,9 +172,11 @@ def _read_config(self, config_path: str) -> list:
with open(config_path, "r") as f:
return f.readlines()

def _initialize_criterion(self) -> FlashGPTLMLoss:
return FlashGPTLMLoss(
parallel_output=gpc.config.model.parallel_output, label_smoothing=gpc.config.loss.label_smoothing
def _initialize_criterion(self) -> InternLoss:
return InternLoss(
parallel_output=gpc.config.model.parallel_output,
label_smoothing=gpc.config.loss.label_smoothing,
op_type=gpc.config.loss.op_type,
)

def _initialize_checkpoint_manager(
Expand Down
19 changes: 8 additions & 11 deletions internlm/initialize/launch.py
Original file line number Diff line number Diff line change
Expand Up @@ -351,17 +351,6 @@ def args_sanity_check():
if "use_flash_attn" not in gpc.config.model:
gpc.config.model._add_item("use_flash_attn", True)

old_parallel_output = gpc.config.model.get("parallel_output", None)
# Try to change user setting
if internlm_accelerator.get_accelerator_backend() is not AcceleratorType.GPU:
gpc.config.model.update({"parallel_output": False})
if old_parallel_output is True and gpc.is_rank_for_log():
logger.warning(
"'parallel_output' is converted from 'True' to 'False'."
"Because 'parallel_output' only support by FlashCrossEntropyLoss."
"Please make sure you are using flash attention in cuda device."
)

if "MoE" in gpc.config.get("model_type", ModelType.INTERNLM.name):
if "num_experts" not in model:
model._add_item("num_experts", 1)
Expand Down Expand Up @@ -449,6 +438,9 @@ def args_sanity_check():
]:
gpc.config.parallel.sequence_parallel = True

if gpc.config.model.get("parallel_output", False) is False:
logger.warning("When enable sequence parallel, it recommend to enable parallel_output")

# set default value for weight parallel
if gpc.config.parallel["weight"].get("overlap", None) is None:
gpc.config.parallel["weight"]["overlap"] = False
Expand Down Expand Up @@ -583,6 +575,11 @@ def args_sanity_check():
gpc.config.data.use_packed_dataset is False
), "only unpacked data is supported when using 2D sequence parallel."

# loss operator type
loss_cfg = gpc.config.loss
if loss_cfg.get("op_type", None) is None:
loss_cfg._add_item("op_type", "py_vocab_parallel")


def launch(
config: Union[str, Path, Config, Dict],
Expand Down
4 changes: 2 additions & 2 deletions internlm/model/losses/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from .ce_loss import FlashGPTLMLoss
from .ce_loss import InternLoss

__all__ = [
"FlashGPTLMLoss",
"InternLoss",
]
72 changes: 53 additions & 19 deletions internlm/model/losses/ce_loss.py
Original file line number Diff line number Diff line change
@@ -1,36 +1,61 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-

import torch
from torch import nn

from internlm.core.context import global_context as gpc
from internlm.accelerator import get_accelerator
from internlm.model.ops.cross_entropy import new_cross_entropy
from internlm.utils.logger import get_logger

logger = get_logger(__file__)
internlm_accelerator = get_accelerator()


class FlashGPTLMLoss(nn.Module):
"""
Loss function for flash GPT Language Model.
class InternLoss(nn.Module):
"""We use a base class to wrap different CrossEntropy implementations
and unify input and output parameters.
This class is designed not to rely on gpc, making it easy to transplant.
Different variants of CrossEntropy, with supporting parallel computation and inplace operations.
If parallel_output is False, the output will gather head's output, only 'FlashCrossEntropyLoss' and
'CrossEntropyApexVocabParallel' support it.
"""

def __init__(self, parallel_output=True, label_smoothing=0):
def __init__(
self,
parallel_output=False,
ignore_index=-100,
reduction="mean",
label_smoothing=0.0,
inplace_backward=True,
op_type="py_vocab_parallel",
) -> None:
super().__init__()

if label_smoothing is not None:
if label_smoothing != 0:
if gpc.is_rank_for_log():
print(f"use label_smoothing: {label_smoothing}")
print(f"use label_smoothing: {label_smoothing}", flush=True)
else:
label_smoothing = 0

self.label_smoothing = label_smoothing

self.reduction = reduction
self.ignore_index = ignore_index
self.op_type = op_type

assert self.reduction in [
"mean",
"none",
], f"Only support reduction is mean/none, but the passed in reduction is {self.reduction}"

# In order to facilitate the calculation of loss for different datasets, we set reduction as 'none',
# and do loss reduction ourselves.
self.loss_fn = new_cross_entropy(
reduction="mean",
label_smoothing=self.label_smoothing,
op_type=op_type,
ignore_index=ignore_index,
label_smoothing=label_smoothing,
parallel_output=parallel_output,
inplace_backward=True,
inplace_backward=inplace_backward,
reduction="none",
)

def forward(self, *args):
Expand All @@ -44,9 +69,18 @@ def forward(self, *args):
raise RuntimeError(f"The number of criterion inputs are:{len(args)}")
shift_logits = logits.contiguous().view(-1, logits.size(-1))
shift_labels = labels.contiguous().view(-1)
loss = self.loss_fn(
shift_logits, shift_labels
) # There is no need to consider the ignore_index problem here, because the loss calculation will be
# calculated through the calculation range, and -100 must be outside this range, so there is no problem

with torch.autocast(device_type=internlm_accelerator.get_backend_name()):
loss_list = self.loss_fn(
shift_logits, shift_labels
) # There is no need to consider the ignore_index problem here, because the loss calculation will be
# # calculated through the calculation range, and -100 must be outside this range, so there is no problem

cond = shift_labels != self.ignore_index
if self.reduction == "mean":
# This loss is only for one dp rank.
loss = loss_list.sum() / (cond).sum()
else:
loss = loss_list

return loss
1 change: 1 addition & 0 deletions internlm/model/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,6 +305,7 @@ def __init__(self, device, dp_pg, dataset_types: List[str] = None) -> None:
reduction="none",
parallel_output=gpc.config.model.parallel_output,
inplace_backward=True,
op_type=gpc.config.loss.op_type,
)
self.scatter_sum = scatter_sum_impl

Expand Down
Loading

0 comments on commit 141e9eb

Please sign in to comment.