From 1f532145c8b50a0f6ff09f47ac7883c23303a035 Mon Sep 17 00:00:00 2001 From: artem-sereda Date: Sun, 14 May 2023 17:39:36 +0200 Subject: [PATCH] allow registering custom layers in prune and clustering registries. --- .../python/core/api/clustering/keras/__init__.py | 1 + .../python/core/api/sparsity/keras/__init__.py | 2 ++ .../python/core/clustering/keras/clustering_registry.py | 4 ++++ .../python/core/sparsity/keras/prune_registry.py | 6 ++++-- 4 files changed, 11 insertions(+), 2 deletions(-) diff --git a/tensorflow_model_optimization/python/core/api/clustering/keras/__init__.py b/tensorflow_model_optimization/python/core/api/clustering/keras/__init__.py index df355124a..c36b17413 100644 --- a/tensorflow_model_optimization/python/core/api/clustering/keras/__init__.py +++ b/tensorflow_model_optimization/python/core/api/clustering/keras/__init__.py @@ -25,4 +25,5 @@ from tensorflow_model_optimization.python.core.clustering.keras.clustering_algorithm import ClusteringAlgorithm from tensorflow_model_optimization.python.core.clustering.keras.clustering_callbacks import ClusteringSummaries from tensorflow_model_optimization.python.core.clustering.keras.clusterable_layer import ClusterableLayer +from tensorflow_model_optimization.python.core.clustering.keras.clustering_registry import ClusteringRegistry # pylint: enable=g-bad-import-order diff --git a/tensorflow_model_optimization/python/core/api/sparsity/keras/__init__.py b/tensorflow_model_optimization/python/core/api/sparsity/keras/__init__.py index 0a0507cb9..5778b8437 100644 --- a/tensorflow_model_optimization/python/core/api/sparsity/keras/__init__.py +++ b/tensorflow_model_optimization/python/core/api/sparsity/keras/__init__.py @@ -31,4 +31,6 @@ from tensorflow_model_optimization.python.core.sparsity.keras.pruning_policy import PruningPolicy from tensorflow_model_optimization.python.core.sparsity.keras.pruning_policy import PruneForLatencyOnXNNPack +from tensorflow_model_optimization.python.core.sparsity.keras.prune_registry import PruneRegistry + # pylint: enable=g-bad-import-order diff --git a/tensorflow_model_optimization/python/core/clustering/keras/clustering_registry.py b/tensorflow_model_optimization/python/core/clustering/keras/clustering_registry.py index 0421ca607..cb8edaad9 100644 --- a/tensorflow_model_optimization/python/core/clustering/keras/clustering_registry.py +++ b/tensorflow_model_optimization/python/core/clustering/keras/clustering_registry.py @@ -220,3 +220,7 @@ def get_clusterable_weights_mha(): # pylint: disable=missing-docstring layer.get_clusterable_weights = get_clusterable_weights return layer + + @classmethod + def register_clusterable_layer(cls, layer: layers.Layer, clusterable_weights: list[str]): + cls._LAYERS_WEIGHTS_MAP[layer] = clusterable_weights diff --git a/tensorflow_model_optimization/python/core/sparsity/keras/prune_registry.py b/tensorflow_model_optimization/python/core/sparsity/keras/prune_registry.py index fdabb8f5c..be741fd52 100644 --- a/tensorflow_model_optimization/python/core/sparsity/keras/prune_registry.py +++ b/tensorflow_model_optimization/python/core/sparsity/keras/prune_registry.py @@ -19,8 +19,6 @@ from tensorflow_model_optimization.python.core.sparsity.keras import prunable_layer -# TODO(b/139939526): move to public API. - layers = tf.keras.layers layers_compat_v1 = tf.compat.v1.keras.layers @@ -226,3 +224,7 @@ def get_prunable_weights_mha_weight(weight_name): layer.get_prunable_weights = get_prunable_weights return layer + + @classmethod + def register_prunable_layer(cls, layer: layers.Layer, prunable_weights: list[str]): + cls._LAYERS_WEIGHTS_MAP[layer] = prunable_weights