fix(cross_entropy.py): replace the fa loss with apex loss (#317)
1 parent 3dfb540, commit 4452ad6
Showing 5 changed files with 256 additions and 19 deletions.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
## Parallel Computing Loss | ||
|
||
The parallel loss computation in InternEvo is adapted from [Apex](https://github.com/NVIDIA/apex/blob/master/apex/transformer/tensor_parallel/cross_entropy.py). Users can switch to the [Flash-Attention](https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/losses/cross_entropy.py) implementation for a speedup, but doing so may cause the loss to diverge. | ||
|
||
For the detailed modifications in InternEvo, please refer to the code: [InternEvo-parallel-loss](https://github.com/InternLM/InternEvo/blob/develop/internlm/model/ops/cross_entropy.py). |
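
To make the idea of a parallel loss concrete, here is a minimal single-process sketch of the vocabulary-parallel cross-entropy scheme that the Apex file implements. It is not InternEvo's or Apex's code: the ranks are simulated with plain Python lists, the collective operations are replaced by ordinary `max` and `sum`, and the names `shard_vocab`, `vocab_parallel_cross_entropy`, and `reference_cross_entropy` are hypothetical helpers introduced only for illustration.

```python
import math


def shard_vocab(logits, world_size):
    """Split one token's logits into contiguous vocabulary shards, one per simulated rank."""
    assert len(logits) % world_size == 0, "sketch assumes an evenly divisible vocabulary"
    per_rank = len(logits) // world_size
    return [logits[r * per_rank:(r + 1) * per_rank] for r in range(world_size)]


def vocab_parallel_cross_entropy(shards, target):
    """Cross-entropy for one token whose logits are split across simulated ranks."""
    per_rank = len(shards[0])

    # Step 1: each rank computes its local max; a MAX all-reduce would give the global max.
    global_max = max(max(shard) for shard in shards)

    # Step 2: only the rank that owns the target index contributes the shifted
    # target logit, the others contribute 0; a SUM all-reduce would combine them.
    target_logit = 0.0
    for rank, shard in enumerate(shards):
        start = rank * per_rank
        if start <= target < start + per_rank:
            target_logit += shard[target - start] - global_max

    # Step 3: each rank sums exp(shifted logits) locally; a SUM all-reduce
    # would give the full softmax denominator.
    sum_exp = sum(sum(math.exp(x - global_max) for x in shard) for shard in shards)

    # loss = log(sum_j exp(z_j - max)) - (z_target - max)
    return math.log(sum_exp) - target_logit


def reference_cross_entropy(logits, target):
    """Unsharded cross-entropy, used only to check the sharded result."""
    m = max(logits)
    return math.log(sum(math.exp(x - m) for x in logits)) - (logits[target] - m)


logits = [2.0, -1.0, 0.5, 3.0, 0.0, 1.5, -2.0, 0.25]
target = 5
parallel_loss = vocab_parallel_cross_entropy(shard_vocab(logits, 4), target)
reference_loss = reference_cross_entropy(logits, target)
```

The sharded and unsharded results should agree up to floating-point rounding. In a real tensor-parallel run the three reductions would be collective calls across GPUs, so each rank only ever holds its own slice of the vocabulary dimension.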
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
## Parallel Computing Loss | ||
|
||
The parallel loss computation currently used in InternEvo is adapted from [Apex](https://github.com/NVIDIA/apex/blob/master/apex/transformer/tensor_parallel/cross_entropy.py). To speed up loss computation, you can switch to the parallel method from [Flash-Attention](https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/losses/cross_entropy.py); note that this may cause the loss to fail to converge. | ||
|
||
For the specific code changes, see [InternEvo-parallel-loss](https://github.com/InternLM/InternEvo/blob/develop/internlm/model/ops/cross_entropy.py). |