diff --git a/internlm/model/ops/cross_entropy.py b/internlm/model/ops/cross_entropy.py
index 99bf1e04..ac9bce8f 100644
--- a/internlm/model/ops/cross_entropy.py
+++ b/internlm/model/ops/cross_entropy.py
@@ -67,6 +67,12 @@ def new_cross_entropy(
     except KeyError:
         raise KeyError(f"op_type only support: {cross_entropy_op_name_map.keys()}")
 
+    if not gpc.config.get("use_fp32_logits", True):
+        assert op_type in [
+            CrossEntropyOpType.flash_vocab_parallel,
+            CrossEntropyOpType.apex_naive,
+        ], "use_fp32_logits=False only supports the 'flash_vocab_parallel' or 'apex_naive' loss functions"
+
     if internlm_accelerator.get_accelerator_backend() is not AcceleratorType.GPU:
         assert op_type in [
             CrossEntropyOpType.torch_naive,
diff --git a/internlm/train/pipeline.py b/internlm/train/pipeline.py
index 79e9caf4..242d7c6c 100644
--- a/internlm/train/pipeline.py
+++ b/internlm/train/pipeline.py
@@ -300,7 +300,7 @@ def inject_model(model):
     else:
         model = NaiveAMPModel(
             model=model,
-            output_to_fp32=gpc.is_no_pp_or_last_stage(),
+            output_to_fp32=gpc.is_no_pp_or_last_stage() and gpc.config.get("use_fp32_logits", True),
             dtype=gpc.config.model.get("dtype", torch.half),
             sync_buffer=False,
         )
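
Taken together, the two hunks add a `use_fp32_logits` config switch: when it is False, the last pipeline stage stops casting model outputs to fp32 (`output_to_fp32` in `NaiveAMPModel`), and the loss op is restricted to the two implementations that accept low-precision logits. A minimal sketch of the corresponding config entries follows; the `loss = dict(...)` layout is an assumption for illustration, since only the key `use_fp32_logits` and the two op names appear in the patch itself.

# Hypothetical training-config snippet (Python-style config file, as read
# via gpc.config). Only `use_fp32_logits` and the two op names are grounded
# in the patch above; the `loss` block layout is assumed.
use_fp32_logits = False  # keep logits in the compute dtype (e.g. bf16/fp16)

# With use_fp32_logits=False, the new assert in new_cross_entropy() only
# permits the ops that handle low-precision logits:
loss = dict(
    op_type="flash_vocab_parallel",  # or "apex_naive"
)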