From 2bf57f8383cd84aa145905aa01a30f59824feb2b Mon Sep 17 00:00:00 2001 From: bhack Date: Tue, 1 Sep 2020 19:48:02 +0200 Subject: [PATCH] Base optimizer tracking (#2126) * Update lookahead.py Inital fix of https://github.com/tensorflow/addons/issues/2094 https://github.com/tensorflow/addons/pull/2102 * Fix linting * Resolve name conflict with mixed prexision * Track baseline optimizer in avg --- tensorflow_addons/optimizers/average_wrapper.py | 1 + tensorflow_addons/optimizers/lookahead.py | 1 + 2 files changed, 2 insertions(+) diff --git a/tensorflow_addons/optimizers/average_wrapper.py b/tensorflow_addons/optimizers/average_wrapper.py index ded496478f..d7c46d2d71 100644 --- a/tensorflow_addons/optimizers/average_wrapper.py +++ b/tensorflow_addons/optimizers/average_wrapper.py @@ -46,6 +46,7 @@ def __init__( raise TypeError("sequential_update must be of bool type") self._optimizer = optimizer + self._track_trackable(self._optimizer, "awg_optimizer") if sequential_update is not None: warnings.warn( diff --git a/tensorflow_addons/optimizers/lookahead.py b/tensorflow_addons/optimizers/lookahead.py index 31fe13b043..e18ccaa8e0 100644 --- a/tensorflow_addons/optimizers/lookahead.py +++ b/tensorflow_addons/optimizers/lookahead.py @@ -80,6 +80,7 @@ def __init__( self._set_hyper("sync_period", sync_period) self._set_hyper("slow_step_size", slow_step_size) self._initialized = False + self._track_trackable(self._optimizer, "lh_base_optimizer") def _create_slots(self, var_list): self._optimizer._create_slots(