Skip to content

Commit

Permalink
[fix] There is no signature "strict_predicate_restore" in register_ch…
Browse files Browse the repository at this point in the history
…eckpoint_saver fucntion from tf saved_model registration when TF version 2.9.

Make a function signature inspect to be compitible.
  • Loading branch information
MoFHeka authored and rhdong committed Sep 17, 2024
1 parent 852fa2e commit e85da20
Showing 1 changed file with 15 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
# lint-as: python3
"""patch on tensorflow"""

import inspect
import functools
import os.path
import re
Expand Down Expand Up @@ -569,12 +570,20 @@ def patch_on_tf_save_restore():
from tensorflow.python.saved_model.registration.registration import register_checkpoint_saver
class_obj = de.Variable
predicate = lambda x: isinstance(x, class_obj)
register_checkpoint_saver("DECustomSaver",
name=class_obj.__name__,
predicate=predicate,
save_fn=_de_var_fs_save_fn,
restore_fn=_de_var_fs_restore_fn,
strict_predicate_restore=False)
prekwargs = {
"package": "DECustomSaver",
"name": class_obj.__name__,
"predicate": predicate,
"save_fn": _de_var_fs_save_fn,
"restore_fn": _de_var_fs_restore_fn,
"strict_predicate_restore": False
}
rcs_sig = inspect.signature(register_checkpoint_saver)
kwargs = {}
for param in rcs_sig.parameters.values():
k_name = param.name
kwargs[k_name] = prekwargs[k_name]
register_checkpoint_saver(**kwargs)
except:
functional_saver._SingleDeviceSaver = _DynamicEmbeddingSingleDeviceSaver
saver.Saver = _DynamicEmbeddingSaver
Expand Down

0 comments on commit e85da20

Please sign in to comment.