From 4452ad63e5e40208a447ac43c1ae2e595f25d6b0 Mon Sep 17 00:00:00 2001
From: ytxiong <45058324+yingtongxiong@users.noreply.github.com>
Date: Fri, 6 Sep 2024 19:53:04 +0800
Subject: [PATCH] fix(cross_entropy.py): replace the fa loss with apex loss
(#317)
---
README-zh-Hans.md | 27 +++-
README.md | 27 +++-
doc/en/parallel_output.md | 5 +
doc/parallel_output.md | 5 +
internlm/model/ops/cross_entropy.py | 211 +++++++++++++++++++++++++---
5 files changed, 256 insertions(+), 19 deletions(-)
create mode 100644 doc/en/parallel_output.md
create mode 100644 doc/parallel_output.md
diff --git a/README-zh-Hans.md b/README-zh-Hans.md
index 582dfd8a..ea38b04e 100644
--- a/README-zh-Hans.md
+++ b/README-zh-Hans.md
@@ -116,7 +116,7 @@ $ torchrun --nnodes=1 --nproc_per_node=8 train.py --config ./configs/7B_sft.py -
InternEvo 特性列表
-
+
@@ -171,6 +171,31 @@ $ torchrun --nnodes=1 --nproc_per_node=8 train.py --config ./configs/7B_sft.py -
|
+## 常见tips
+
+
+
+
+
+
+
+ 现象
+ |
+
+ 介绍
+ |
+
+
+
+ 在Vocab维度并行计算loss
+ |
+
+ 说明
+ |
+
+
+
+
## 贡献
我们感谢所有的贡献者为改进和提升 InternEvo 所作出的努力。非常欢迎社区用户能参与进项目中来。请参考贡献指南来了解参与项目贡献的相关指引。
diff --git a/README.md b/README.md
index 1a22f419..9a61d573 100644
--- a/README.md
+++ b/README.md
@@ -116,7 +116,7 @@ Please refer to the [System Architecture document](./doc/en/structure.md) for ar
InternEvo Feature Zoo
-
+
@@ -171,6 +171,31 @@ Please refer to the [System Architecture document](./doc/en/structure.md) for ar
|
+## Common Tips
+
+
+
+
+
+
+
+ Item
+ |
+
+ Introduction
+ |
+
+
+
+ Parallel Computing Loss
+ |
+
+ link
+ |
+
+
+
+
## Contribution
We appreciate all the contributors for their efforts to improve and enhance InternEvo. Community users are highly encouraged to participate in the project. Please refer to the contribution guidelines for instructions on how to contribute to the project.
diff --git a/doc/en/parallel_output.md b/doc/en/parallel_output.md
new file mode 100644
index 00000000..8d9fd0b3
--- /dev/null
+++ b/doc/en/parallel_output.md
@@ -0,0 +1,5 @@
+## Parallel Computing Loss
+
+The parallel computing loss function in InternEvo is adapted from [Apex]( https://github.com/NVIDIA/apex/blob/master/apex/transformer/tensor_parallel/cross_entropy.py). Users can replace the loss function with [Flash-Attention](https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/losses/cross_entropy.py) to obtain speedup, which may lead to loss divergence.
+
+For detailed modifications in InternEvo,please refer to the code [InternEvo-parallel-loss](https://github.com/InternLM/InternEvo/blob/develop/internlm/model/ops/cross_entropy.py)
\ No newline at end of file
diff --git a/doc/parallel_output.md b/doc/parallel_output.md
new file mode 100644
index 00000000..3be91779
--- /dev/null
+++ b/doc/parallel_output.md
@@ -0,0 +1,5 @@
+## 并行计算loss
+
+InternEvo目前使用的并行计算loss方法改编自[Apex]( https://github.com/NVIDIA/apex/blob/master/apex/transformer/tensor_parallel/cross_entropy.py)。如需要加速计算loss,可将并行计算loss方法改为[Flash-Attention](https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/losses/cross_entropy.py)的并行计算方法,需要注意的是,这可能会出现loss不收敛的情况。
+
+具体修改代码可见[InternEvo-parallel-loss](https://github.com/InternLM/InternEvo/blob/develop/internlm/model/ops/cross_entropy.py)
\ No newline at end of file
diff --git a/internlm/model/ops/cross_entropy.py b/internlm/model/ops/cross_entropy.py
index eba7f4dc..1b4b03c7 100644
--- a/internlm/model/ops/cross_entropy.py
+++ b/internlm/model/ops/cross_entropy.py
@@ -7,6 +7,7 @@
"""
import torch
+import torch.distributed as dist
from torch import nn
from internlm.accelerator import AcceleratorType, get_accelerator
@@ -14,15 +15,6 @@
from internlm.core.context import global_context as gpc
from internlm.utils.logger import get_logger
-try:
- from flash_attn.losses.cross_entropy import (
- CrossEntropyLoss as FlashCrossEntropyLoss,
- )
-
- flash_cross_entropy_impl = True
-except (ModuleNotFoundError, ImportError):
- flash_cross_entropy_impl = False
-
logger = get_logger(__file__)
internlm_accelerator = get_accelerator()
@@ -152,6 +144,191 @@ def forward(self, _input, target):
return _loss_list.view(-1)
+class _VocabParallelCrossEntropy(torch.autograd.Function):
+ """Adapt from: https://github.com/NVIDIA/apex/blob/master/apex/transformer/tensor_parallel/cross_entropy.py
+ Supports vocab parallel loss calculation, but does not support inplace backward.
+ NOTE: This class is different from the original Apex implementation. Apex will calculate the loss of
+ ignore_index and flashCrossEntropy will set it to 0. InterEvo adapts the second approach.
+ """
+
+ @staticmethod
+ @internlm_accelerator.amp.custom_fwd
+ def forward(ctx, vocab_parallel_logits, target, label_smoothing=0.0, process_group=None):
+ # Maximum value along vocab dimension across all GPUs.
+ logits_max = torch.max(vocab_parallel_logits, dim=-1)[0]
+ if process_group is not None and dist.get_world_size(process_group) > 1:
+ torch.distributed.all_reduce(logits_max, op=torch.distributed.ReduceOp.MAX, group=process_group)
+ # Subtract the maximum value.
+ vocab_parallel_logits = vocab_parallel_logits - logits_max.unsqueeze(dim=-1)
+
+ # Get the partition's vocab indecies
+ # get_vocab_range = VocabUtility.vocab_range_from_per_partition_vocab_size
+ partition_vocab_size = vocab_parallel_logits.size()[-1]
+ if process_group is not None and dist.get_world_size(process_group) > 1:
+ rank = dist.get_rank(process_group)
+ # world_size = dist.get_world_size(process_group)
+ part_len = vocab_parallel_logits.shape[-1]
+ vocab_start_index, vocab_end_index = part_len * rank, part_len * (rank + 1)
+ else:
+ vocab_start_index, vocab_end_index = 0, vocab_parallel_logits.shape[-1]
+
+ # vocab_start_index, vocab_end_index = get_vocab_range(partition_vocab_size, rank, world_size)
+ # Create a mask of valid vocab ids (1 means it needs to be masked).
+ target_mask = (target < vocab_start_index) | (target >= vocab_end_index)
+ ignore_mask = target == -100
+ masked_target = target.clone() - vocab_start_index
+ masked_target[target_mask] = 0
+
+ # Get predicted-logits = logits[target].
+ # For Simplicity, we convert logits to a 2-D tensor with size
+ # [*, partition-vocab-size] and target to a 1-D tensor of size [*].
+ logits_2d = vocab_parallel_logits.view(-1, partition_vocab_size)
+ masked_target_1d = masked_target.view(-1)
+ arange_1d = torch.arange(start=0, end=logits_2d.size()[0], device=logits_2d.device)
+ predicted_logits_1d = logits_2d[arange_1d, masked_target_1d]
+ predicted_logits_1d = predicted_logits_1d.clone().contiguous()
+ predicted_logits = predicted_logits_1d.view_as(target)
+ predicted_logits[target_mask] = 0.0
+
+ # All reduce is needed to get the chunks from other GPUs.
+ if process_group is not None and dist.get_world_size(process_group) > 1:
+ torch.distributed.all_reduce(predicted_logits, op=torch.distributed.ReduceOp.SUM, group=process_group)
+
+ # Sum of exponential of logits along vocab dimension across all GPUs.
+ exp_logits = vocab_parallel_logits
+ torch.exp(vocab_parallel_logits, out=exp_logits)
+ sum_exp_logits = exp_logits.sum(dim=-1)
+
+ if process_group is not None and dist.get_world_size(process_group) > 1:
+ torch.distributed.all_reduce(sum_exp_logits, op=torch.distributed.ReduceOp.SUM, group=process_group)
+
+ # Normalize and optionally smooth logits
+ exp_logits.div_(sum_exp_logits.unsqueeze(dim=-1))
+
+ # Loss = log(sum(exp(logits))) - predicted-logit.
+ sum_exp_logits = torch.log(sum_exp_logits)
+ loss = sum_exp_logits - predicted_logits
+ loss[ignore_mask] = 0.0
+
+ vocab_size = exp_logits.size(-1)
+ if label_smoothing > 0:
+ r"""
+ We'd like to assign 1 / (K - 1) probability mass to every index that is not the ground truth.
+ = (1 - alpha) * y_gt + alpha * mean(y_{i for i != gt})
+ = (1 - alpha) * y_gt + (alpha / (K - 1)) * \sum_{i != gt} y_i
+ = ((K - 1) * (1 - alpha) / (K - 1)) * y_gt + (alpha / (K - 1)) * \sum_{i != gt} y_i
+ = (K * (1 - alpha) - 1) / (K - 1)) * y_gt + (alpha / (K - 1)) * \sum_{i} y_i
+ = (1 - (alpha * K) / (K - 1)) * y_gt + ( (alpha * K) / (K - 1) ) * \sum_{i} y_i / K
+ From: https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/common/losses/smoothed_cross_entropy.py
+ """
+ assert 1.0 > label_smoothing > 0.0
+ smoothing = label_smoothing * vocab_size / (vocab_size - 1)
+
+ # Exp logits at this point are normalized probabilities. So we can just take the log to get log-probs.
+ log_probs = torch.log(exp_logits)
+ mean_log_probs = log_probs.mean(dim=-1)
+ loss = (1.0 - smoothing) * loss - smoothing * mean_log_probs
+
+ ctx.label_smoothing, ctx.vocab_size = label_smoothing, vocab_size
+ # Store softmax, target-mask and masked-target for backward pass.
+ ctx.save_for_backward(exp_logits, target_mask, masked_target_1d, ignore_mask)
+
+ return loss
+
+ @staticmethod
+ @internlm_accelerator.amp.custom_bwd
+ def backward(ctx, grad_output):
+
+ # Retreive tensors from the forward path.
+ softmax, target_mask, masked_target_1d, ignore_mask = ctx.saved_tensors
+ label_smoothing, vocab_size = ctx.label_smoothing, ctx.vocab_size
+
+ # All the inputs have softmax as thier gradient.
+ grad_input = softmax # s_{k}
+ # For simplicity, work with the 2D gradient.
+ partition_vocab_size = softmax.size()[-1]
+ grad_2d = grad_input.view(-1, partition_vocab_size)
+
+ # Add the gradient from matching classes.
+ arange_1d = torch.arange(start=0, end=grad_2d.size()[0], device=grad_2d.device)
+
+ softmax_update = 1.0 - target_mask.view(-1).float()
+
+ if label_smoothing > 0:
+ smoothing = label_smoothing * vocab_size / (vocab_size - 1)
+ grad_2d[arange_1d, masked_target_1d] -= (1.0 - smoothing) * softmax_update
+ average_grad = 1 / vocab_size
+ grad_2d[arange_1d, :] -= smoothing * average_grad
+ else:
+ grad_2d[arange_1d, masked_target_1d] -= softmax_update
+
+ # Finally elementwise multiplication with the output gradients.
+ grad_input.mul_(grad_output.unsqueeze(dim=-1))
+ grad_input[ignore_mask] = 0.0 # set ignore token loss as 0.
+
+ return grad_input, None, None, None
+
+
+class CrossEntropyApexVocabParallel(nn.Module):
+ """Adapt from: https://github.com/NVIDIA/apex/blob/master/apex/transformer/tensor_parallel/cross_entropy.py
+ Supports vocab parallel loss calculation, but does not support inplace backward.
+ """
+
+ def __init__(
+ self, ignore_index=-100, reduction="mean", label_smoothing=0.0, process_group=None, inplace_backward=False
+ ):
+ super().__init__()
+ if reduction not in ["mean", "none"]:
+ raise NotImplementedError("Only support reduction = 'mean' or 'none'")
+ assert inplace_backward is False, "does not support inplace backward"
+ self.ignore_index = ignore_index
+ self.reduction = reduction
+ self.label_smoothing = label_smoothing
+ self.process_group = process_group
+
+ def forward(self, vocab_parallel_logits, target):
+ # assert vocab_parallel_logits.is_cuda and vocab_parallel_logits.is_cuda
+
+ # SoftmaxCrossEntropyLoss implicitly casts to float
+ loss = _VocabParallelCrossEntropy.apply(vocab_parallel_logits, target, self.label_smoothing, self.process_group)
+ if self.reduction == "mean":
+ return loss.sum() / (target != self.ignore_index).sum()
+ else:
+ return loss
+
+
+def flash_loss(
+ ignore_index=-100,
+ reduction="mean",
+ label_smoothing=0.0,
+ process_group=None,
+ inplace_backward=False, # pylint:disable=W0613
+):
+ try:
+ from flash_attn.losses.cross_entropy import (
+ CrossEntropyLoss as FlashCrossEntropyLoss,
+ )
+
+ flash_cross_entropy_impl = True
+ except (ModuleNotFoundError, ImportError):
+ flash_cross_entropy_impl = False
+
+ assert (
+ gpc.config.model.get("use_flash_attn", False) and flash_cross_entropy_impl
+ ), "Only flash cross entropy support parallel_output"
+
+ assert (
+ internlm_accelerator.get_accelerator_backend() is AcceleratorType.GPU
+ ), "flash cross entropy only support gpu backend"
+
+ return FlashCrossEntropyLoss(
+ ignore_index=ignore_index,
+ reduction=reduction,
+ label_smoothing=label_smoothing,
+ process_group=process_group,
+ )
+
+
# TODO: ops是否需要实现更加统一的形式
def new_cross_entropy(
ignore_index: int = -100,
@@ -171,14 +348,14 @@ def new_cross_entropy(
# )
if parallel_output:
- assert (
- gpc.config.model.get("use_flash_attn", False) and flash_cross_entropy_impl
- ), "Only flash cross entropy support parallel_output"
- assert (
- internlm_accelerator.get_accelerator_backend() is AcceleratorType.GPU
- ), "flash cross entropy only support gpu backend"
-
- return FlashCrossEntropyLoss(
+ # return flash_loss(
+ # ignore_index=ignore_index,
+ # reduction=reduction,
+ # label_smoothing=label_smoothing,
+ # process_group=gpc.get_group(ParallelMode.TENSOR),
+ # )
+
+ return CrossEntropyApexVocabParallel(
ignore_index=ignore_index,
reduction=reduction,
label_smoothing=label_smoothing,