diff --git a/configs/7B_sft.py b/configs/7B_sft.py
index 2cab7537..44f94c1f 100644
--- a/configs/7B_sft.py
+++ b/configs/7B_sft.py
@@ -22,14 +22,14 @@ CHECKPOINT_EVERY = 50
 ckpt = dict(
     enable_save_ckpt=False,  # enable ckpt save.
-    enable_internevo2hf_ckpt=False,  # enable ckpt save for huggingface format.
+    enable_internevo2hf_ckpt=False,  # enable ckpt save for huggingface format.
     save_ckpt_folder=SAVE_CKPT_FOLDER,  # Path to save training ckpt.
     # load_ckpt_folder= dict(path=MODEL_ONLY_FOLDER, content=["model"], ckpt_type="normal"),
     load_ckpt_folder="local:llm_ckpts/",
     # 'load_ckpt_info' setting guide:
     # 1. the 'path' indicate ckpt path,
     # 2. the 'content‘ means what states will be loaded, support: "model", "sampler", "optimizer", "scheduler", "all"
-    # 3. the ’ckpt_type‘ means the type of checkpoint to be loaded, support: "internevo", "hf", or other custom-defined
+    # 3. the ’ckpt_type‘ means the type of checkpoint to be loaded, support: "internevo", "hf", or other custom-defined
     # load function such as "llama"
     load_ckpt_info=dict(path=MODEL_ONLY_FOLDER, content=("model",), ckpt_type="internevo"),
     # 'auto_resume' is designed to automatically load the latest checkpoint from 'save_ckpt_folder' when encountering
@@ -214,6 +214,8 @@
     ),
 )
 
+use_apex_adam = False
+
 # metric_dtype can be "fp32" or other string
 # only when set to "fp32" will use fp32 to calc in metrics
 # metric_dtype = "fp32"
diff --git a/internlm/initialize/launch.py b/internlm/initialize/launch.py
index b1e766d1..0aca4332 100644
--- a/internlm/initialize/launch.py
+++ b/internlm/initialize/launch.py
@@ -73,6 +73,9 @@ def args_sanity_check():
     if "model_type" not in gpc.config:
         gpc.config._add_item("model_type", ModelType.INTERNLM.name)
 
+    if "use_apex_adam" not in gpc.config:
+        gpc.config._add_item("use_apex_adam", False)
+
     # procssing the parallel config in gpc
     if "zero1" not in gpc.config.parallel:
         gpc.config.parallel._add_item("zero1", dict(size=-1, fsdp=False))
diff --git a/internlm/model/modeling_llama.py b/internlm/model/modeling_llama.py
index 43071245..83e64bda 100644
--- a/internlm/model/modeling_llama.py
+++ b/internlm/model/modeling_llama.py
@@ -138,7 +138,7 @@ def __init__(
             qk_interleaved=qk_interleaved,
             bias=not no_bias,
             rope_base=rope_base,
-            enable_qkv_fusion=False,
+            enable_qkv_fusion=True,
         )
 
         self.dropout1 = nn.Dropout(drop_rate)
diff --git a/internlm/solver/optimizer/compatible_adamw.py b/internlm/solver/optimizer/compatible_adamw.py
index bca8c274..f1c35f54 100644
--- a/internlm/solver/optimizer/compatible_adamw.py
+++ b/internlm/solver/optimizer/compatible_adamw.py
@@ -19,34 +19,48 @@
     npu_adamw_impl = False
 
 
+try:
+    from apex.optimizers import FusedAdam as apex_adam
+
+    apex_adamw_impl = True
+except (ModuleNotFoundError, ImportError):
+    apex_adamw_impl = False
+
+
 # TODO: give the upper layer a unified interface that all the underlying implementations can support; which parameters should be kept and which should be omitted?
-def new_compatible_adamw(params, lr: float = 0.001, betas: Tuple[float, float] = (0.9, 0.999), eps: float = 1e-8):
+def new_compatible_adamw(
+    params, lr: float = 0.001, betas: Tuple[float, float] = (0.9, 0.999), eps: float = 1e-8, use_apex_adam=False
+):
     """
     return a compatibel adamw instance.
     """
-    adam_extra_kwargs = {}
-    backend = internlm_accelerator.get_accelerator_backend()
-
-    if backend is AcceleratorType.GPU and torch.__version__ >= "2.1.0":
-        if gpc.is_rank_for_log():
-            logger.warning(
-                "Use fused AdamaW to avoid nan grad norm when "
-                "model size is larger and use_fp32_norm=True, Please note this!"
-            )
-        adam_extra_kwargs["fused"] = True
-    elif backend is AcceleratorType.NPU:
-        if gpc.is_rank_for_log():
-            logger.warning(
-                "Use normal AdamaW, NPU fused_adamw currently has"
-                "accuracy issues and is not supported yet. Please note this!"
-            )
-        # TODO: support npu version adamw
-    elif backend is AcceleratorType.DIPU:
-        if gpc.is_rank_for_log():
-            logger.warning("Use torch.optim.AdamW rather than deeplink adamw. Please note this!")
-        # TODO: support deeplink version adamw
-    else:
-        if gpc.is_rank_for_log():
-            logger.warning("Use torch.optim.AdamW rather than FusedAdamW. Please note this!")
+    if not use_apex_adam:
+        adam_extra_kwargs = {}
+        backend = internlm_accelerator.get_accelerator_backend()
 
-    return torch.optim.AdamW(params, lr=lr, betas=betas, eps=eps, **adam_extra_kwargs)
+        if backend is AcceleratorType.GPU and torch.__version__ >= "2.1.0":
+            if gpc.is_rank_for_log():
+                logger.warning(
+                    "Use fused AdamaW to avoid nan grad norm when "
+                    "model size is larger and use_fp32_norm=True, Please note this!"
+                )
+            adam_extra_kwargs["fused"] = True
+        elif backend is AcceleratorType.NPU:
+            if gpc.is_rank_for_log():
+                logger.warning(
+                    "Use normal AdamaW, NPU fused_adamw currently has"
+                    "accuracy issues and is not supported yet. Please note this!"
+                )
+            # TODO: support npu version adamw
+        elif backend is AcceleratorType.DIPU:
+            if gpc.is_rank_for_log():
+                logger.warning("Use torch.optim.AdamW rather than deeplink adamw. Please note this!")
+            # TODO: support deeplink version adamw
+        else:
+            if gpc.is_rank_for_log():
+                logger.warning("Use torch.optim.AdamW rather than FusedAdamW. Please note this!")
+
+        return torch.optim.AdamW(params, lr=lr, betas=betas, eps=eps, **adam_extra_kwargs)
+    else:
+        assert apex_adamw_impl, "FusedAdam cannot be imported from apex.optimizers"
+        return apex_adam(params, lr=lr, betas=betas, eps=eps, adam_w_mode=True)
diff --git a/internlm/train/pipeline.py b/internlm/train/pipeline.py
index ead3dd98..0867c430 100644
--- a/internlm/train/pipeline.py
+++ b/internlm/train/pipeline.py
@@ -430,6 +430,7 @@ def initialize_optimizer(model: Union[nn.Module, nn.ModuleList], isp_communicato
     adam_cfg = gpc.config.adam
     zero_cfg = gpc.config.hybrid_zero_optimizer
     grad_scal_cfg = gpc.config.grad_scaler
+    use_apex_adam = getattr(gpc.config, "use_apex_adam", False)
 
     if "use_split_tensor_optim" in zero_cfg and zero_cfg.use_split_tensor_optim:
         map_param_block(model)
@@ -441,6 +442,7 @@
         lr=adam_cfg.lr,
         betas=(adam_cfg.adam_beta1, adam_cfg.adam_beta2),
         eps=adam_cfg.adam_eps,
+        use_apex_adam=use_apex_adam,
    )
 
     if (
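Usage note: the new `use_apex_adam` switch flows from configs/7B_sft.py through args_sanity_check() and initialize_optimizer() into new_compatible_adamw(), which then returns apex's FusedAdam in AdamW mode instead of torch.optim.AdamW. Below is a minimal sketch of how the flag is consumed, assuming an environment with InternEvo and apex installed; the toy Linear module and the hyperparameter values are placeholders, while the flag name, function, and module path come from the diff above.

# Sketch only: first opt in via the training config (configs/7B_sft.py):
#     use_apex_adam = True
# initialize_optimizer() then forwards the flag, roughly equivalent to:
import torch

from internlm.solver.optimizer.compatible_adamw import new_compatible_adamw

model = torch.nn.Linear(8, 8)  # placeholder module; the real params come from the wrapped training model
optimizer = new_compatible_adamw(
    model.parameters(),
    lr=1e-4,
    betas=(0.9, 0.95),
    eps=1e-8,
    use_apex_adam=True,  # returns apex.optimizers.FusedAdam(adam_w_mode=True); asserts if apex is missing
)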