Feat(adam): support apex FusedAdam (#333)
li126com authored Sep 19, 2024
1 parent f951cdd commit fd11a58
Showing 5 changed files with 50 additions and 29 deletions.
6 changes: 4 additions & 2 deletions configs/7B_sft.py
@@ -22,14 +22,14 @@
CHECKPOINT_EVERY = 50
ckpt = dict(
    enable_save_ckpt=False,  # enable ckpt save.
    enable_internevo2hf_ckpt=False,  # enable ckpt save for huggingface format.
    save_ckpt_folder=SAVE_CKPT_FOLDER,  # Path to save training ckpt.
    # load_ckpt_folder= dict(path=MODEL_ONLY_FOLDER, content=["model"], ckpt_type="normal"),
    load_ckpt_folder="local:llm_ckpts/",
    # 'load_ckpt_info' setting guide:
    # 1. the 'path' indicates the ckpt path,
    # 2. the 'content' means what states will be loaded, support: "model", "sampler", "optimizer", "scheduler", "all"
    # 3. the 'ckpt_type' means the type of checkpoint to be loaded, support: "internevo", "hf", or other custom-defined
    #    load function such as "llama"
    load_ckpt_info=dict(path=MODEL_ONLY_FOLDER, content=("model",), ckpt_type="internevo"),
    # 'auto_resume' is designed to automatically load the latest checkpoint from 'save_ckpt_folder' when encountering
@@ -214,6 +214,8 @@
    ),
)

use_apex_adam = False

# metric_dtype can be "fp32" or other string
# only when set to "fp32" will fp32 be used to calculate metrics
# metric_dtype = "fp32"
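
For reference, enabling the new optimizer path from a user config only requires flipping the flag added above. A minimal sketch (the adam block values are illustrative defaults, not part of this commit; the field names are the ones internlm/train/pipeline.py reads):

# Opt in to apex FusedAdam (requires apex to be installed).
use_apex_adam = True

# Example adam block; initialize_optimizer() forwards these fields to new_compatible_adamw().
adam = dict(
    lr=1e-4,
    adam_beta1=0.9,
    adam_beta2=0.95,
    adam_eps=1e-8,
)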
3 changes: 3 additions & 0 deletions internlm/initialize/launch.py
@@ -73,6 +73,9 @@ def args_sanity_check():
if "model_type" not in gpc.config:
gpc.config._add_item("model_type", ModelType.INTERNLM.name)

if "use_apex_adam" not in gpc.config:
gpc.config._add_item("use_apex_adam", False)

# procssing the parallel config in gpc
if "zero1" not in gpc.config.parallel:
gpc.config.parallel._add_item("zero1", dict(size=-1, fsdp=False))
2 changes: 1 addition & 1 deletion internlm/model/modeling_llama.py
@@ -138,7 +138,7 @@ def __init__(
        qk_interleaved=qk_interleaved,
        bias=not no_bias,
        rope_base=rope_base,
        enable_qkv_fusion=False,
        enable_qkv_fusion=True,
    )

    self.dropout1 = nn.Dropout(drop_rate)
66 changes: 40 additions & 26 deletions internlm/solver/optimizer/compatible_adamw.py
@@ -19,34 +19,48 @@
npu_adamw_impl = False


try:
    from apex.optimizers import FusedAdam as apex_adam

    apex_adamw_impl = True
except (ModuleNotFoundError, ImportError):
    apex_adamw_impl = False


# TODO: provide the upper layer with a unified interface that every underlying implementation can support;
# which parameters should be kept, and which should be omitted?
def new_compatible_adamw(params, lr: float = 0.001, betas: Tuple[float, float] = (0.9, 0.999), eps: float = 1e-8):
def new_compatible_adamw(
    params, lr: float = 0.001, betas: Tuple[float, float] = (0.9, 0.999), eps: float = 1e-8, use_apex_adam=False
):
"""
return a compatibel adamw instance.
"""
    adam_extra_kwargs = {}
    backend = internlm_accelerator.get_accelerator_backend()

    if backend is AcceleratorType.GPU and torch.__version__ >= "2.1.0":
        if gpc.is_rank_for_log():
            logger.warning(
                "Use fused AdamW to avoid nan grad norm when "
                "the model is large and use_fp32_norm=True. Please note this!"
            )
        adam_extra_kwargs["fused"] = True
    elif backend is AcceleratorType.NPU:
        if gpc.is_rank_for_log():
            logger.warning(
                "Use normal AdamW; NPU fused_adamw currently has "
                "accuracy issues and is not supported yet. Please note this!"
            )
        # TODO: support npu version adamw
    elif backend is AcceleratorType.DIPU:
        if gpc.is_rank_for_log():
            logger.warning("Use torch.optim.AdamW rather than deeplink adamw. Please note this!")
        # TODO: support deeplink version adamw
    else:
        if gpc.is_rank_for_log():
            logger.warning("Use torch.optim.AdamW rather than FusedAdamW. Please note this!")
    if not use_apex_adam:
        adam_extra_kwargs = {}
        backend = internlm_accelerator.get_accelerator_backend()

    return torch.optim.AdamW(params, lr=lr, betas=betas, eps=eps, **adam_extra_kwargs)
        if backend is AcceleratorType.GPU and torch.__version__ >= "2.1.0":
            if gpc.is_rank_for_log():
                logger.warning(
                    "Use fused AdamW to avoid nan grad norm when "
                    "the model is large and use_fp32_norm=True. Please note this!"
                )
            adam_extra_kwargs["fused"] = True
        elif backend is AcceleratorType.NPU:
            if gpc.is_rank_for_log():
                logger.warning(
                    "Use normal AdamW; NPU fused_adamw currently has "
                    "accuracy issues and is not supported yet. Please note this!"
                )
            # TODO: support npu version adamw
        elif backend is AcceleratorType.DIPU:
            if gpc.is_rank_for_log():
                logger.warning("Use torch.optim.AdamW rather than deeplink adamw. Please note this!")
            # TODO: support deeplink version adamw
        else:
            if gpc.is_rank_for_log():
                logger.warning("Use torch.optim.AdamW rather than FusedAdamW. Please note this!")

        return torch.optim.AdamW(params, lr=lr, betas=betas, eps=eps, **adam_extra_kwargs)
    else:
        assert apex_adamw_impl, "FusedAdam cannot be imported from apex.optimizers"
        return apex_adam(params, lr=lr, betas=betas, eps=eps, adam_w_mode=True)
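
Taken together, new_compatible_adamw keeps torch.optim.AdamW as the default and only switches to apex's FusedAdam when asked. A minimal usage sketch (the toy model and hyperparameters are illustrative, not from this commit):

import torch.nn as nn

from internlm.solver.optimizer.compatible_adamw import new_compatible_adamw

model = nn.Linear(16, 16)  # stand-in for the real InternEvo model

# Default path: torch.optim.AdamW (fused kernel on GPU with torch >= 2.1.0).
optimizer = new_compatible_adamw(model.parameters(), lr=1e-4, betas=(0.9, 0.95), eps=1e-8)

# Opt-in path: apex FusedAdam in AdamW mode; the assert fires if apex is not installed.
# optimizer = new_compatible_adamw(model.parameters(), lr=1e-4, use_apex_adam=True)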
2 changes: 2 additions & 0 deletions internlm/train/pipeline.py
@@ -430,6 +430,7 @@ def initialize_optimizer(model: Union[nn.Module, nn.ModuleList], isp_communicator
    adam_cfg = gpc.config.adam
    zero_cfg = gpc.config.hybrid_zero_optimizer
    grad_scal_cfg = gpc.config.grad_scaler
    use_apex_adam = getattr(gpc.config, "use_apex_adam", False)

if "use_split_tensor_optim" in zero_cfg and zero_cfg.use_split_tensor_optim:
map_param_block(model)
@@ -441,6 +442,7 @@ def initialize_optimizer(model: Union[nn.Module, nn.ModuleList], isp_communicator
        lr=adam_cfg.lr,
        betas=(adam_cfg.adam_beta1, adam_cfg.adam_beta2),
        eps=adam_cfg.adam_eps,
        use_apex_adam=use_apex_adam,
    )

    if (
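
When the flag is set, the factory call above resolves to apex's fused kernel. A rough standalone equivalent, as a sketch (assumes apex is installed; the values mirror the adam_cfg fields read above):

import torch.nn as nn

from apex.optimizers import FusedAdam

model = nn.Linear(16, 16)  # stand-in for the real model

# Approximately what new_compatible_adamw(..., use_apex_adam=True) constructs:
optimizer = FusedAdam(
    model.parameters(),
    lr=1e-4,            # adam_cfg.lr in the real pipeline
    betas=(0.9, 0.95),  # (adam_cfg.adam_beta1, adam_cfg.adam_beta2)
    eps=1e-8,           # adam_cfg.adam_eps
    adam_w_mode=True,   # decoupled weight decay, i.e. AdamW rather than Adam semantics
)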
