From ddd75864e57012d338ca9b341d89caad0565c716 Mon Sep 17 00:00:00 2001 From: Julian Qian Date: Mon, 15 Jul 2024 19:57:37 -0700 Subject: [PATCH] add HkvHashTableExportWithScores op --- .../core/ops/hkv_hashtable_ops.cc | 20 +++++++++ .../python/keras/layers/embedding_test.py | 23 ++++++++++ .../dynamic_embedding/python/keras/models.py | 16 +++---- .../kernel_tests/hkv_hashtable_evict_test.py | 44 +++++++++++++++++++ .../python/ops/dynamic_embedding_creator.py | 4 +- .../python/ops/dynamic_embedding_variable.py | 22 +++++++++- .../python/ops/hkv_hashtable_ops.py | 14 ++++++ 7 files changed, 132 insertions(+), 11 deletions(-) diff --git a/tensorflow_recommenders_addons/dynamic_embedding/core/ops/hkv_hashtable_ops.cc b/tensorflow_recommenders_addons/dynamic_embedding/core/ops/hkv_hashtable_ops.cc index ac365667f..9ed3a059a 100644 --- a/tensorflow_recommenders_addons/dynamic_embedding/core/ops/hkv_hashtable_ops.cc +++ b/tensorflow_recommenders_addons/dynamic_embedding/core/ops/hkv_hashtable_ops.cc @@ -256,6 +256,26 @@ REGISTER_OP(PREFIX_OP_NAME(HkvHashTableSaveToFileSystem)) .Attr("dirpath_env: string") .Attr("append_to_file: bool") .Attr("buffer_size: int >= 1"); +REGISTER_OP(PREFIX_OP_NAME(HkvHashTableExportWithScores)) + .Input("table_handle: resource") + .Output("keys: key_dtype") + .Output("values: value_dtype") + .Output("scores: int64") + .Attr("key_dtype: type") + .Attr("value_dtype: type") + .Attr("split_size: int") + .SetShapeFn([](InferenceContext* c) { + ShapeHandle handle; + TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &handle)); + ShapeHandle keys = c->UnknownShapeOfRank(1); + ShapeHandle values = c->UnknownShapeOfRank(2); + ShapeHandle scores = c->UnknownShapeOfRank(1); + ShapeAndType value_shape_and_type; + c->set_output(0, keys); + c->set_output(1, values); + c->set_output(2, scores); + return TFOkStatus; + }); REGISTER_OP(PREFIX_OP_NAME(HkvHashTableExportKeysAndScores)) .Input("table_handle: resource") .Output("keys: key_dtype") diff --git
a/tensorflow_recommenders_addons/dynamic_embedding/python/keras/layers/embedding_test.py b/tensorflow_recommenders_addons/dynamic_embedding/python/keras/layers/embedding_test.py index 7d050fade..c5087a8e6 100644 --- a/tensorflow_recommenders_addons/dynamic_embedding/python/keras/layers/embedding_test.py +++ b/tensorflow_recommenders_addons/dynamic_embedding/python/keras/layers/embedding_test.py @@ -157,6 +157,29 @@ def test_backward(self): model.fit(x, y, verbose=0) self.assertAllEqual(emb_layer.params.size(), start) + def test_backward_adagrad(self): + if not context.executing_eagerly(): + self.skipTest('Only test in eager mode') + init = tf.keras.initializers.RandomNormal(seed=0) + model = get_sequential_model(de.keras.layers.Embedding, + 4, + initializer=init, + bp_v2=False, + name='go582') + optmz = tf.keras.optimizers.Adagrad(1E-4) + optmz = de.DynamicEmbeddingOptimizer(optmz) + emb_layer = model.layers[0] + model.compile(optimizer=optmz, loss='binary_crossentropy') + start = 0 + batch_size = 10 + for i in range(1, 10): + x = math_ops.range(start, start + batch_size * i, dtype=dtypes.int64) + x = tf.reshape(x, (batch_size, -1)) + start += batch_size * i + y = tf.zeros((batch_size, 1), dtype=dtypes.float32) + model.fit(x, y, verbose=0) + self.assertAllEqual(emb_layer.params.size(), start) + def test_backward_bp_v2(self): if not context.executing_eagerly(): self.skipTest('Only test in eager mode') diff --git a/tensorflow_recommenders_addons/dynamic_embedding/python/keras/models.py b/tensorflow_recommenders_addons/dynamic_embedding/python/keras/models.py index 65fc8be54..36bf0bf17 100644 --- a/tensorflow_recommenders_addons/dynamic_embedding/python/keras/models.py +++ b/tensorflow_recommenders_addons/dynamic_embedding/python/keras/models.py @@ -99,7 +99,7 @@ def _check_saveable_and_redirect_new_de_dir(hvd_rank=0): if hasattr(de_var, 'saveable'): de_var.saveable._saver_config.save_path = de_dir - def _traverse_emb_layers_and_save(hvd_rank=0): + def 
_traverse_emb_layers_and_save(proc_size=1, proc_rank=0): for var in model.variables: if not hasattr(var, "params"): continue @@ -117,24 +117,24 @@ def _traverse_emb_layers_and_save(hvd_rank=0): a2a_emb.optimizer_vars, "as_list") else a2a_emb.optimizer_vars for de_opt_var in de_opt_vars: de_opt_var.save_to_file_system(dirpath=de_dir, - proc_size=hvd.size(), - proc_rank=hvd.rank()) - if hvd_rank == 0: + proc_size=proc_size, + proc_rank=proc_rank) + if proc_rank == 0: # FileSystemSaver works well at rank 0. continue # save Dynamic Embedding Parameters de_var.save_to_file_system(dirpath=de_dir, - proc_size=hvd.size(), - proc_rank=hvd.rank()) + proc_size=proc_size, + proc_rank=proc_rank) if hvd is None: call_original_save_func() - _traverse_emb_layers_and_save(0) + _traverse_emb_layers_and_save() else: _check_saveable_and_redirect_new_de_dir(hvd.rank()) if hvd.rank() == 0: call_original_save_func() - _traverse_emb_layers_and_save(hvd.rank()) + _traverse_emb_layers_and_save(hvd.size(), hvd.rank()) hvd.join() # Sync for avoiding rank conflict diff --git a/tensorflow_recommenders_addons/dynamic_embedding/python/kernel_tests/hkv_hashtable_evict_test.py b/tensorflow_recommenders_addons/dynamic_embedding/python/kernel_tests/hkv_hashtable_evict_test.py index 3cf8b77fe..b916317c4 100644 --- a/tensorflow_recommenders_addons/dynamic_embedding/python/kernel_tests/hkv_hashtable_evict_test.py +++ b/tensorflow_recommenders_addons/dynamic_embedding/python/kernel_tests/hkv_hashtable_evict_test.py @@ -194,6 +194,50 @@ def test_export_keys_and_scores(self): del table + @test_util.run_in_graph_and_eager_modes() + def test_export_with_scores(self): + if not is_gpu_available: + self.skipTest('Only test when gpu is available.') + key_dtype = dtypes.int64 + value_dtype = dtypes.int32 + dim = 8 + for strategy in de.HkvEvictStrategy: + with self.session(use_gpu=True, config=default_config): + table = de.get_variable( + str(strategy), + key_dtype=key_dtype, + value_dtype=value_dtype, +
initializer=0, + dim=dim, + init_size=1024, + kv_creator=de.HkvHashTableCreator( + config=de.HkvHashTableConfig(init_capacity=1024, + max_capacity=1024, + max_hbm_for_values=1024 * 64, + evict_strategy=strategy, + gen_scores_fn=gen_scores_fn))) + keys = constant_op.constant( + np.array([0, 1, 2, 3]).astype(_type_converter(key_dtype)), + key_dtype) + values = constant_op.constant( + _convert([[0] * dim, [1] * dim, [2] * dim, [3] * dim], value_dtype), + value_dtype) + + self.evaluate(table.upsert(keys, values)) + + exported_keys, exported_values, exported_scores = self.evaluate( + table.export_with_scores(1)) + self.assertAllEqual(np.sort(exported_keys), keys) + self.assertAllEqual(exported_values[np.argsort(exported_keys)], values) + if strategy is de.HkvEvictStrategy.CUSTOMIZED: + self.assertAllEqual(np.sort(exported_scores), gen_scores_fn(keys)) + elif strategy is de.HkvEvictStrategy.EPOCHLFU: + self.assertAllEqual(exported_scores, np.full((4), 1)) + elif strategy is de.HkvEvictStrategy.LFU: + self.assertAllEqual(exported_scores, np.ones(4)) + + del table + def test_evict_strategy_lfu(self): if not is_gpu_available: self.skipTest('Only test when gpu is available.') diff --git a/tensorflow_recommenders_addons/dynamic_embedding/python/ops/dynamic_embedding_creator.py b/tensorflow_recommenders_addons/dynamic_embedding/python/ops/dynamic_embedding_creator.py index c08c7bca6..c8d7638a8 100644 --- a/tensorflow_recommenders_addons/dynamic_embedding/python/ops/dynamic_embedding_creator.py +++ b/tensorflow_recommenders_addons/dynamic_embedding/python/ops/dynamic_embedding_creator.py @@ -395,7 +395,7 @@ def __init__(self, proc_size: int = None, proc_rank: int = None, save_path: str = None, - buffer_size: int = 4194304): + buffer_size: int = 4096): """ FileSystemSaverConfig can be used to assign save_path of DynamicEmbeddings.
""" if type(proc_rank) != type(proc_size): @@ -493,7 +493,7 @@ def __init__(self, proc_size: int = None, proc_rank: int = None, save_path: str = None, - buffer_size: int = 4194304): + buffer_size: int = 4096): self.config = FileSystemSaverConfig(proc_size=proc_size, proc_rank=proc_rank, save_path=save_path, diff --git a/tensorflow_recommenders_addons/dynamic_embedding/python/ops/dynamic_embedding_variable.py b/tensorflow_recommenders_addons/dynamic_embedding/python/ops/dynamic_embedding_variable.py index e8fd70367..146df872d 100644 --- a/tensorflow_recommenders_addons/dynamic_embedding/python/ops/dynamic_embedding_variable.py +++ b/tensorflow_recommenders_addons/dynamic_embedding/python/ops/dynamic_embedding_variable.py @@ -717,7 +717,27 @@ def _convert_anything_to_init(self, raw_init, dim): else: raise ValueError except: - init = array_ops.fill([dim], array_ops.reshape(init, [-1])[0]) + + def is_indexable_and_nonempty(obj): + has_getitem = hasattr(obj, '__getitem__') + is_nonempty = hasattr(obj, '__len__') and len(obj) > 0 + return has_getitem and is_nonempty + + if isinstance(init, int) or isinstance(init, float): + first_element = init + elif not isinstance(init, tf.Tensor) and is_indexable_and_nonempty(init): + first_element = init[0] + else: + reshaped_init = array_ops.reshape(init, [-1]) + size_of_reshaped_init = tf.size(reshaped_init) + + def get_default_value(): + default_value = 0.0 if self.value_dtype.is_floating else 0 + return tf.constant(default_value, dtype=self.value_dtype) + + first_element = tf.cond(tf.greater(size_of_reshaped_init, 0), + lambda: reshaped_init[0], get_default_value) + init = array_ops.fill([dim], first_element) init = math_ops.cast(init, dtype=self.value_dtype) return init diff --git a/tensorflow_recommenders_addons/dynamic_embedding/python/ops/hkv_hashtable_ops.py b/tensorflow_recommenders_addons/dynamic_embedding/python/ops/hkv_hashtable_ops.py index 434c0b32f..22e53ba45 100644 --- 
a/tensorflow_recommenders_addons/dynamic_embedding/python/ops/hkv_hashtable_ops.py +++ b/tensorflow_recommenders_addons/dynamic_embedding/python/ops/hkv_hashtable_ops.py @@ -433,6 +433,20 @@ def export_keys_and_scores(self, split_size, name=None): split_size=split_size) return keys, scores + def export_with_scores(self, split_size, name=None): + if not (isinstance(split_size, int) and split_size > 0): + raise ValueError('split_size must be positive integer.') + + with ops.name_scope(name, "%s_lookup_table_export_with_scores" % self.name, + [self.resource_handle]): + with ops.colocate_with(self.resource_handle): + keys, values, scores = hkv_ops.tfra_hkv_hash_table_export_with_scores( + self.resource_handle, + key_dtype=self._key_dtype, + value_dtype=self._value_dtype, + split_size=split_size) + return keys, values, scores + def save_to_file_system(self, dirpath, file_name=None,