
Commit

fix adamw
edwardyehuang committed Mar 7, 2022
1 parent 013a7ce commit 2bfd740
Showing 1 changed file with 4 additions and 1 deletion.
5 changes: 4 additions & 1 deletion core_optimizer.py
@@ -19,6 +19,7 @@ def get_optimizer(
     decay_strategy="poly",
     optimizer="sgd",
     sgd_momentum_rate=0.9,
+    adamw_weight_decay=0.0001,
 ):
 
     kwargs = {
@@ -31,6 +32,7 @@ def get_optimizer(
         "decay_strategy": decay_strategy,
         "optimizer": optimizer,
         "sgd_momentum_rate": sgd_momentum_rate,
+        "adamw_weight_decay": adamw_weight_decay,
     }
 
     keys = kwargs.keys()
@@ -98,6 +100,7 @@ def __get_optimizer(
     decay_strategy="poly",
     optimizer="sgd",
     sgd_momentum_rate=0.9,
+    adamw_weight_decay=0.0001,
 ):
 
     learning_rate = initial_lr
@@ -127,7 +130,7 @@ def __get_optimizer(
     elif optimizer == "amsgrad":
         _optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate, amsgrad=True)
     elif optimizer == "adamw":
-        _optimizer = AdamW(weight_decay=0, learning_rate=learning_rate)
+        _optimizer = AdamW(weight_decay=adamw_weight_decay, learning_rate=learning_rate)
     else:
         raise ValueError(f"Unsupported optimizer {optimizer}")
 

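In short, the commit replaces the hard-coded weight_decay=0, which disabled decoupled weight decay and made the "adamw" branch behave like plain Adam, with a configurable adamw_weight_decay keyword (default 0.0001) that is threaded through get_optimizer into the AdamW constructor. Below is a minimal sketch of the updated call site; it assumes AdamW refers to tensorflow_addons' tfa.optimizers.AdamW, since the import is not shown in this diff, and uses a placeholder learning rate in place of the repository's schedule logic.

    # Minimal sketch of the fixed "adamw" branch (assumption: AdamW is
    # tfa.optimizers.AdamW from tensorflow_addons; the import is not in this diff).
    import tensorflow_addons as tfa

    adamw_weight_decay = 0.0001   # new keyword argument introduced by this commit
    learning_rate = 0.01          # placeholder; the real code derives this from its LR/decay strategy

    # Before the fix, AdamW(weight_decay=0, ...) applied no decoupled weight decay,
    # so the optimizer degenerated to plain Adam. Passing the configurable value
    # restores the intended AdamW behaviour.
    _optimizer = tfa.optimizers.AdamW(
        weight_decay=adamw_weight_decay,
        learning_rate=learning_rate,
    )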