From e85da2006a89fe708ac0a97352a781c2d5aabcb2 Mon Sep 17 00:00:00 2001 From: MoFHeka Date: Fri, 6 Sep 2024 18:46:25 +0800 Subject: [PATCH] [fix] There is no signature "strict_predicate_restore" in register_checkpoint_saver fucntion from tf saved_model registration when TF version 2.9. Make a function signature inspect to be compitible. --- .../python/ops/tf_save_restore_patch.py | 21 +++++++++++++------ 1 file changed, 15 insertions(+), 6 deletions(-) diff --git a/tensorflow_recommenders_addons/dynamic_embedding/python/ops/tf_save_restore_patch.py b/tensorflow_recommenders_addons/dynamic_embedding/python/ops/tf_save_restore_patch.py index ccd063eb..d0c167af 100644 --- a/tensorflow_recommenders_addons/dynamic_embedding/python/ops/tf_save_restore_patch.py +++ b/tensorflow_recommenders_addons/dynamic_embedding/python/ops/tf_save_restore_patch.py @@ -15,6 +15,7 @@ # lint-as: python3 """patch on tensorflow""" +import inspect import functools import os.path import re @@ -569,12 +570,20 @@ def patch_on_tf_save_restore(): from tensorflow.python.saved_model.registration.registration import register_checkpoint_saver class_obj = de.Variable predicate = lambda x: isinstance(x, class_obj) - register_checkpoint_saver("DECustomSaver", - name=class_obj.__name__, - predicate=predicate, - save_fn=_de_var_fs_save_fn, - restore_fn=_de_var_fs_restore_fn, - strict_predicate_restore=False) + prekwargs = { + "package": "DECustomSaver", + "name": class_obj.__name__, + "predicate": predicate, + "save_fn": _de_var_fs_save_fn, + "restore_fn": _de_var_fs_restore_fn, + "strict_predicate_restore": False + } + rcs_sig = inspect.signature(register_checkpoint_saver) + kwargs = {} + for param in rcs_sig.parameters.values(): + k_name = param.name + kwargs[k_name] = prekwargs[k_name] + register_checkpoint_saver(**kwargs) except: functional_saver._SingleDeviceSaver = _DynamicEmbeddingSingleDeviceSaver saver.Saver = _DynamicEmbeddingSaver