Skip to content

Commit

Permalink
fix(cross_entropy.py): replace the fa loss with apex loss (#317)
Browse files Browse the repository at this point in the history
  • Loading branch information
yingtongxiong authored Sep 6, 2024
1 parent 3dfb540 commit 4452ad6
Show file tree
Hide file tree
Showing 5 changed files with 256 additions and 19 deletions.
27 changes: 26 additions & 1 deletion README-zh-Hans.md
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ $ torchrun --nnodes=1 --nproc_per_node=8 train.py --config ./configs/7B_sft.py -
<div align="center">
<b>InternEvo 特性列表</b>
</div>
<table align="center">
<table>
<tbody>
<tr align="center" valign="bottom">
<td>
Expand Down Expand Up @@ -171,6 +171,31 @@ $ torchrun --nnodes=1 --nproc_per_node=8 train.py --config ./configs/7B_sft.py -
</tbody>
</table>

## 常见tips

<div align="center">
</div>
<table>
<tbody>
<tr align="center" valign="bottom">
<td>
<b>现象</b>
</td>
<td>
<b>介绍</b>
</td>
</tr>
<tr valign="bottom">
<td>
<b>在Vocab维度并行计算loss</b>
</td>
<td>
<b><a href="doc/parallel_output.md">说明</a></b>
</td>
</tr>
</tbody>
</table>

## 贡献

我们感谢所有的贡献者为改进和提升 InternEvo 所作出的努力。非常欢迎社区用户能参与进项目中来。请参考贡献指南来了解参与项目贡献的相关指引。
Expand Down
27 changes: 26 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ Please refer to the [System Architecture document](./doc/en/structure.md) for ar
<div align="center">
<b>InternEvo Feature Zoo</b>
</div>
<table align="center">
<table>
<tbody>
<tr align="center" valign="bottom">
<td>
Expand Down Expand Up @@ -171,6 +171,31 @@ Please refer to the [System Architecture document](./doc/en/structure.md) for ar
</tbody>
</table>

## Common Tips

<div align="center">
</div>
<table>
<tbody>
<tr align="center" valign="bottom">
<td>
<b>Item</b>
</td>
<td>
<b>Introduction</b>
</td>
</tr>
<tr valign="bottom">
<td>
<b>Parallel Computing Loss</b>
</td>
<td>
<b><a href="doc/en/parallel_output.md">link</a></b>
</td>
</tr>
</tbody>
</table>

## Contribution

We appreciate all the contributors for their efforts to improve and enhance InternEvo. Community users are highly encouraged to participate in the project. Please refer to the contribution guidelines for instructions on how to contribute to the project.
Expand Down
5 changes: 5 additions & 0 deletions doc/en/parallel_output.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
## Parallel Computing Loss

The parallel loss computation in InternEvo is adapted from [Apex](https://github.com/NVIDIA/apex/blob/master/apex/transformer/tensor_parallel/cross_entropy.py). Users can replace it with the [Flash-Attention](https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/losses/cross_entropy.py) implementation to obtain a speedup, but note that doing so may cause the loss to diverge.

For the detailed modifications made in InternEvo, please refer to the code: [InternEvo-parallel-loss](https://github.com/InternLM/InternEvo/blob/develop/internlm/model/ops/cross_entropy.py).
5 changes: 5 additions & 0 deletions doc/parallel_output.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
## 并行计算loss

InternEvo目前使用的并行计算loss方法改编自[Apex](https://github.com/NVIDIA/apex/blob/master/apex/transformer/tensor_parallel/cross_entropy.py)。如需要加速计算loss,可将并行计算loss方法改为[Flash-Attention](https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/losses/cross_entropy.py)的并行计算方法,需要注意的是,这可能会出现loss不收敛的情况。

具体修改代码可见[InternEvo-parallel-loss](https://github.com/InternLM/InternEvo/blob/develop/internlm/model/ops/cross_entropy.py)
211 changes: 194 additions & 17 deletions internlm/model/ops/cross_entropy.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,22 +7,14 @@
"""

import torch
import torch.distributed as dist
from torch import nn

from internlm.accelerator import AcceleratorType, get_accelerator
from internlm.core.context import ParallelMode
from internlm.core.context import global_context as gpc
from internlm.utils.logger import get_logger

try:
from flash_attn.losses.cross_entropy import (
CrossEntropyLoss as FlashCrossEntropyLoss,
)

flash_cross_entropy_impl = True
except (ModuleNotFoundError, ImportError):
flash_cross_entropy_impl = False

logger = get_logger(__file__)
internlm_accelerator = get_accelerator()

Expand Down Expand Up @@ -152,6 +144,191 @@ def forward(self, _input, target):
return _loss_list.view(-1)


class _VocabParallelCrossEntropy(torch.autograd.Function):
    """Cross entropy over logits whose vocab dimension is split across a process group.

    Adapted from:
    https://github.com/NVIDIA/apex/blob/master/apex/transformer/tensor_parallel/cross_entropy.py

    Supports vocab parallel loss calculation, but does not support inplace backward.

    NOTE: This class differs from the original Apex implementation. Apex computes a loss for
    positions whose target is the ignore index, while flash-attn's cross entropy sets it to 0.
    InternEvo adopts the second approach: positions whose target equals -100 get zero loss
    (with or without label smoothing) and zero gradient.
    """

    @staticmethod
    @internlm_accelerator.amp.custom_fwd
    def forward(ctx, vocab_parallel_logits, target, label_smoothing=0.0, process_group=None):
        """Compute the per-position cross entropy loss.

        Args:
            vocab_parallel_logits: Logits whose last dimension is this rank's vocab partition.
            target: Integer class ids; must have as many elements as there are logit rows,
                since the gathered per-row logits are reshaped with ``view_as(target)``.
                Positions equal to -100 are ignored.
            label_smoothing: Smoothing factor; must satisfy ``0 < label_smoothing < 1`` when
                non-zero.
            process_group: Group the vocab dimension is sharded over. ``None`` or a group of
                size 1 disables all collective communication.

        Returns:
            The unreduced loss tensor, with 0 at ignored positions.
        """
        # Collectives are only needed when the vocab is actually split over several ranks.
        is_parallel = process_group is not None and dist.get_world_size(process_group) > 1

        # Maximum value along vocab dimension across all GPUs.
        logits_max = torch.max(vocab_parallel_logits, dim=-1)[0]
        if is_parallel:
            dist.all_reduce(logits_max, op=dist.ReduceOp.MAX, group=process_group)
        # Subtract the maximum value. This is out-of-place, so the caller's logits stay intact
        # and the in-place ops below only touch this new tensor.
        vocab_parallel_logits = vocab_parallel_logits - logits_max.unsqueeze(dim=-1)

        # Get the partition's vocab indices: rank r owns ids [r * P, (r + 1) * P), i.e. every
        # rank is assumed to hold an equally sized shard of P = partition_vocab_size entries.
        partition_vocab_size = vocab_parallel_logits.size(-1)
        rank = dist.get_rank(process_group) if is_parallel else 0
        vocab_start_index = partition_vocab_size * rank
        vocab_end_index = vocab_start_index + partition_vocab_size

        # Mask of targets that do NOT live on this partition (True means it needs to be masked).
        # Negative targets such as the ignore value -100 are always masked here.
        target_mask = (target < vocab_start_index) | (target >= vocab_end_index)
        ignore_mask = target == -100
        masked_target = target - vocab_start_index
        masked_target[target_mask] = 0

        # Get predicted-logits = logits[target].
        # For simplicity, we convert logits to a 2-D tensor with size
        # [*, partition-vocab-size] and target to a 1-D tensor of size [*].
        logits_2d = vocab_parallel_logits.view(-1, partition_vocab_size)
        masked_target_1d = masked_target.view(-1)
        arange_1d = torch.arange(start=0, end=logits_2d.size(0), device=logits_2d.device)
        predicted_logits_1d = logits_2d[arange_1d, masked_target_1d]
        predicted_logits_1d = predicted_logits_1d.clone().contiguous()
        predicted_logits = predicted_logits_1d.view_as(target)
        # Ranks that do not own the target contribute 0 to the SUM all-reduce below.
        predicted_logits[target_mask] = 0.0

        # All reduce is needed to get the chunks from other GPUs.
        if is_parallel:
            dist.all_reduce(predicted_logits, op=dist.ReduceOp.SUM, group=process_group)

        # Sum of exponential of logits along vocab dimension across all GPUs.
        # exp is taken in place to avoid a second logits-sized buffer.
        exp_logits = vocab_parallel_logits
        torch.exp(vocab_parallel_logits, out=exp_logits)
        sum_exp_logits = exp_logits.sum(dim=-1)
        if is_parallel:
            dist.all_reduce(sum_exp_logits, op=dist.ReduceOp.SUM, group=process_group)

        # Normalize: exp_logits now holds the softmax probabilities of the local partition.
        exp_logits.div_(sum_exp_logits.unsqueeze(dim=-1))

        # Loss = log(sum(exp(logits))) - predicted-logit.
        loss = torch.log(sum_exp_logits) - predicted_logits

        # NOTE(review): vocab_size is the size of the local partition and mean_log_probs below
        # averages over the local shard only (no all-reduce); confirm this is intended when the
        # vocab is split over more than one rank.
        vocab_size = exp_logits.size(-1)
        if label_smoothing > 0:
            # We'd like to assign 1 / (K - 1) probability mass to every index that is not the
            # ground truth.
            # = (1 - alpha) * y_gt + alpha * mean(y_{i for i != gt})
            # = (1 - alpha) * y_gt + (alpha / (K - 1)) * \sum_{i != gt} y_i
            # = ((K - 1) * (1 - alpha) / (K - 1)) * y_gt + (alpha / (K - 1)) * \sum_{i != gt} y_i
            # = ((K * (1 - alpha) - 1) / (K - 1)) * y_gt + (alpha / (K - 1)) * \sum_{i} y_i
            # = (1 - (alpha * K) / (K - 1)) * y_gt + ((alpha * K) / (K - 1)) * \sum_{i} y_i / K
            # From: https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/common/losses/smoothed_cross_entropy.py
            assert 1.0 > label_smoothing > 0.0
            smoothing = label_smoothing * vocab_size / (vocab_size - 1)

            # Exp logits at this point are normalized probabilities, so taking the log gives
            # log-probs.
            log_probs = torch.log(exp_logits)
            mean_log_probs = log_probs.mean(dim=-1)
            loss = (1.0 - smoothing) * loss - smoothing * mean_log_probs

        # Zero the ignored positions only AFTER smoothing; otherwise the smoothing term would
        # put a non-zero loss back on tokens whose gradient backward() forces to zero.
        loss[ignore_mask] = 0.0

        ctx.label_smoothing, ctx.vocab_size = label_smoothing, vocab_size
        # Store softmax, target-mask, masked-target and ignore-mask for the backward pass.
        ctx.save_for_backward(exp_logits, target_mask, masked_target_1d, ignore_mask)

        return loss

    @staticmethod
    @internlm_accelerator.amp.custom_bwd
    def backward(ctx, grad_output):
        """Return the gradient w.r.t. the logits; the other forward inputs get ``None``."""
        # Retrieve tensors from the forward path.
        softmax, target_mask, masked_target_1d, ignore_mask = ctx.saved_tensors
        label_smoothing, vocab_size = ctx.label_smoothing, ctx.vocab_size

        # All the inputs have softmax as their gradient (s_{k}). The saved softmax buffer is
        # reused and modified in place rather than copied.
        grad_input = softmax
        # For simplicity, work with the 2D gradient.
        partition_vocab_size = softmax.size(-1)
        grad_2d = grad_input.view(-1, partition_vocab_size)

        # Add the gradient from matching classes: 1 where this rank owns the target, else 0.
        arange_1d = torch.arange(start=0, end=grad_2d.size(0), device=grad_2d.device)
        softmax_update = 1.0 - target_mask.view(-1).float()

        if label_smoothing > 0:
            smoothing = label_smoothing * vocab_size / (vocab_size - 1)
            grad_2d[arange_1d, masked_target_1d] -= (1.0 - smoothing) * softmax_update
            # The smoothing term spreads its mass uniformly over the vocab entries.
            average_grad = 1 / vocab_size
            grad_2d[arange_1d, :] -= smoothing * average_grad
        else:
            grad_2d[arange_1d, masked_target_1d] -= softmax_update

        # Finally elementwise multiplication with the output gradients.
        grad_input.mul_(grad_output.unsqueeze(dim=-1))
        grad_input[ignore_mask] = 0.0  # ignored tokens contribute no gradient.

        return grad_input, None, None, None


class CrossEntropyApexVocabParallel(nn.Module):
    """Module wrapper around ``_VocabParallelCrossEntropy``.

    Adapted from:
    https://github.com/NVIDIA/apex/blob/master/apex/transformer/tensor_parallel/cross_entropy.py

    Supports vocab parallel loss calculation, but does not support inplace backward.

    Args:
        ignore_index: Target value whose positions contribute neither loss nor gradient.
        reduction: ``"mean"`` (sum of losses divided by the number of non-ignored targets)
            or ``"none"`` (per-position losses).
        label_smoothing: Smoothing factor forwarded to the autograd function.
        process_group: Process group forwarded to the autograd function.
        inplace_backward: Must be ``False``; kept for signature parity with other losses.

    Raises:
        NotImplementedError: If ``reduction`` is neither ``"mean"`` nor ``"none"``.
        AssertionError: If ``inplace_backward`` is not ``False``.
    """

    # The underlying autograd function only recognises this hard-coded target value as ignored.
    _KERNEL_IGNORE_INDEX = -100

    def __init__(
        self, ignore_index=-100, reduction="mean", label_smoothing=0.0, process_group=None, inplace_backward=False
    ):
        super().__init__()
        if reduction not in ["mean", "none"]:
            raise NotImplementedError("Only support reduction = 'mean' or 'none'")
        assert inplace_backward is False, "does not support inplace backward"
        self.ignore_index = ignore_index
        self.reduction = reduction
        self.label_smoothing = label_smoothing
        self.process_group = process_group

    def forward(self, vocab_parallel_logits, target):
        """Compute the loss for ``vocab_parallel_logits`` against ``target``.

        Returns:
            A scalar when ``reduction == "mean"``, otherwise the unreduced loss tensor.
        """
        ignored = target == self.ignore_index
        if self.ignore_index != self._KERNEL_IGNORE_INDEX:
            # Translate the configured ignore value into the one the autograd function
            # understands; otherwise ignored tokens would still produce loss and gradient
            # while being excluded from the mean denominator below.
            target = target.masked_fill(ignored, self._KERNEL_IGNORE_INDEX)

        # SoftmaxCrossEntropyLoss implicitly casts to float
        loss = _VocabParallelCrossEntropy.apply(vocab_parallel_logits, target, self.label_smoothing, self.process_group)
        if self.reduction == "mean":
            # Average over the non-ignored positions only.
            return loss.sum() / (~ignored).sum()
        return loss


def flash_loss(
    ignore_index=-100,
    reduction="mean",
    label_smoothing=0.0,
    process_group=None,
    inplace_backward=False,  # pylint:disable=W0613
):
    """Build flash-attn's ``CrossEntropyLoss`` after checking that it may be used.

    flash-attn is imported inside the function, so a missing installation only matters when
    this factory is actually called.

    Args:
        ignore_index: Forwarded to the flash-attn loss.
        reduction: Forwarded to the flash-attn loss.
        label_smoothing: Forwarded to the flash-attn loss.
        process_group: Forwarded to the flash-attn loss.
        inplace_backward: Accepted but not forwarded.

    Returns:
        A ``flash_attn.losses.cross_entropy.CrossEntropyLoss`` instance.

    Raises:
        AssertionError: If ``use_flash_attn`` is not enabled in the model config, flash-attn
            cannot be imported, or the accelerator backend is not GPU.
    """
    try:
        from flash_attn.losses.cross_entropy import (
            CrossEntropyLoss as FlashCrossEntropyLoss,
        )
    except (ModuleNotFoundError, ImportError):
        has_flash_loss = False
    else:
        has_flash_loss = True

    flash_requested = gpc.config.model.get("use_flash_attn", False)
    assert flash_requested and has_flash_loss, "Only flash cross entropy support parallel_output"

    backend = internlm_accelerator.get_accelerator_backend()
    assert backend is AcceleratorType.GPU, "flash cross entropy only support gpu backend"

    loss_kwargs = {
        "ignore_index": ignore_index,
        "reduction": reduction,
        "label_smoothing": label_smoothing,
        "process_group": process_group,
    }
    return FlashCrossEntropyLoss(**loss_kwargs)


# TODO: should the ops be implemented in a more unified form?
def new_cross_entropy(
ignore_index: int = -100,
Expand All @@ -171,14 +348,14 @@ def new_cross_entropy(
# )

if parallel_output:
assert (
gpc.config.model.get("use_flash_attn", False) and flash_cross_entropy_impl
), "Only flash cross entropy support parallel_output"
assert (
internlm_accelerator.get_accelerator_backend() is AcceleratorType.GPU
), "flash cross entropy only support gpu backend"

return FlashCrossEntropyLoss(
# return flash_loss(
# ignore_index=ignore_index,
# reduction=reduction,
# label_smoothing=label_smoothing,
# process_group=gpc.get_group(ParallelMode.TENSOR),
# )

return CrossEntropyApexVocabParallel(
ignore_index=ignore_index,
reduction=reduction,
label_smoothing=label_smoothing,
Expand Down

0 comments on commit 4452ad6

Please sign in to comment.