Skip to content

Commit

Permalink
Remove unused keyword arguments to Keras Model.save and Model.load.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 707577369
  • Loading branch information
vkarampudi authored and Responsible ML Infra Team committed Dec 27, 2024
1 parent 6ba5f98 commit 0da0891
Show file tree
Hide file tree
Showing 5 changed files with 66 additions and 45 deletions.
2 changes: 1 addition & 1 deletion fairness_indicators/example_model_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ def test_example_model(self):
]),
batch_size=1,
)
classifier.save(self._model_dir, save_format='tf')
tf.saved_model.save(classifier, self._model_dir)

eval_config = text_format.Parse(
"""
Expand Down
10 changes: 5 additions & 5 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,15 +38,15 @@ def select_constraint(default, nightly=None, git_master=None):
return default

REQUIRED_PACKAGES = [
'tensorflow>=2.15,<2.16',
'tensorflow>=2.16,<2.17',
'tensorflow-hub>=0.16.1,<1.0.0',
'tensorflow-data-validation' + select_constraint(
default='>=1.15.1,<2.0.0',
nightly='>=1.16.0.dev',
default='>=1.16.1,<2.0.0',
nightly='>=1.17.0.dev',
git_master='@git+https://github.com/tensorflow/data-validation@master'),
'tensorflow-model-analysis' + select_constraint(
default='>=0.46,<0.47',
nightly='>=0.47.0.dev',
default='>=0.47.0,<0.48.0',
nightly='>=0.48.0.dev',
git_master='@git+https://github.com/tensorflow/model-analysis@master'),
'witwidget>=1.4.4,<2',
'protobuf>=3.20.3,<5',
Expand Down
8 changes: 4 additions & 4 deletions tensorboard_plugin/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,12 +43,12 @@ def select_constraint(default, nightly=None, git_master=None):

REQUIRED_PACKAGES = [
'protobuf>=3.20.3,<5',
'tensorboard>=2.15.2,<2.16.0',
'tensorflow>=2.15,<2.16',
'tensorboard>=2.16.2,<2.17.0',
'tensorflow>=2.16,<2.17',
'tensorflow-model-analysis'
+ select_constraint(
default='>=0.46,<0.47',
nightly='>=0.47.0.dev',
default='>=0.47,<0.48',
nightly='>=0.48.0.dev',
git_master='@git+https://github.com/tensorflow/model-analysis@master',
),
'werkzeug<2',
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@
from tensorboard_plugin_fairness_indicators import metadata
import six
import tensorflow_model_analysis as tfma
from tensorflow_model_analysis.addons.fairness.view import widget_view
# from tensorflow_model_analysis.addons.fairness.view import widget_view
from tensorflow_model_analysis.view import widget_view
from werkzeug import wrappers
from google.protobuf import json_format
from tensorboard.backend import http_util
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,21 +25,38 @@
from tensorboard_plugin_fairness_indicators import plugin
from tensorboard_plugin_fairness_indicators import summary_v2
import six
import tensorflow.compat.v1 as tf
import tensorflow.compat.v2 as tf2
import tensorflow as tf2
from tensorflow.keras import layers
from tensorflow.keras import models
import tensorflow_model_analysis as tfma
from tensorflow_model_analysis.eval_saved_model.example_trainers import linear_classifier
from werkzeug import test as werkzeug_test
from werkzeug import wrappers

from tensorboard.backend import application
from tensorboard.backend.event_processing import plugin_event_multiplexer as event_multiplexer
from tensorboard.plugins import base_plugin

tf.enable_eager_execution()
Sequential = models.Sequential
Dense = layers.Dense

tf = tf2


# Define keras based linear classifier.
def create_linear_classifier(model_dir):

inputs = tf.keras.Input(shape=(2,))
outputs = layers.Dense(1, activation="sigmoid")(inputs)
model = tf.keras.Model(inputs=inputs, outputs=outputs)

model.compile(
optimizer="adam", loss="binary_crossentropy", metrics=["accuracy"]
)

tf.saved_model.save(model, model_dir)
return model


class PluginTest(tf.test.TestCase):
"""Tests for Fairness Indicators plugin server."""

Expand Down Expand Up @@ -74,19 +91,19 @@ def tearDown(self):
super(PluginTest, self).tearDown()
shutil.rmtree(self._log_dir, ignore_errors=True)

def _exportEvalSavedModel(self, classifier):
def _export_eval_saved_model(self):
"""Export the evaluation saved model."""
temp_eval_export_dir = os.path.join(self.get_temp_dir(), "eval_export_dir")
_, eval_export_dir = classifier(None, temp_eval_export_dir)
return eval_export_dir
return create_linear_classifier(temp_eval_export_dir)

def _writeTFExamplesToTFRecords(self, examples):
def _write_tf_examples_to_tfrecords(self, examples):
data_location = os.path.join(self.get_temp_dir(), "input_data.rio")
with tf.io.TFRecordWriter(data_location) as writer:
for example in examples:
writer.write(example.SerializeToString())
return data_location

def _makeExample(self, age, language, label):
def _make_tf_example(self, age, language, label):
example = tf.train.Example()
example.features.feature["age"].float_list.value[:] = [age]
example.features.feature["language"].bytes_list.value[:] = [
Expand All @@ -112,14 +129,14 @@ def testRoutes(self):
"foo": "".encode("utf-8")
}},
)
def testIsActive(self, get_random_stub):
def testIsActive(self):
self.assertTrue(self._plugin.is_active())

@mock.patch.object(
event_multiplexer.EventMultiplexer,
"PluginRunToTagToContent",
return_value={})
def testIsInactive(self, get_random_stub):
def testIsInactive(self):
self.assertFalse(self._plugin.is_active())

def testIndexJsRoute(self):
Expand All @@ -134,16 +151,15 @@ def testVulcanizedTemplateRoute(self):
self.assertEqual(200, response.status_code)

def testGetEvalResultsRoute(self):
model_location = self._exportEvalSavedModel(
linear_classifier.simple_linear_classifier)
model_location = self._export_eval_saved_model() # Call the method
examples = [
self._makeExample(age=3.0, language="english", label=1.0),
self._makeExample(age=3.0, language="chinese", label=0.0),
self._makeExample(age=4.0, language="english", label=1.0),
self._makeExample(age=5.0, language="chinese", label=1.0),
self._makeExample(age=5.0, language="hindi", label=1.0)
self._make_tf_example(age=3.0, language="english", label=1.0),
self._make_tf_example(age=3.0, language="chinese", label=0.0),
self._make_tf_example(age=4.0, language="english", label=1.0),
self._make_tf_example(age=5.0, language="chinese", label=1.0),
self._make_tf_example(age=5.0, language="hindi", label=1.0),
]
data_location = self._writeTFExamplesToTFRecords(examples)
data_location = self._write_tf_examples_to_tfrecords(examples)
_ = tfma.run_model_analysis(
eval_shared_model=tfma.default_eval_shared_model(
eval_saved_model_path=model_location, example_weight_key="age"),
Expand All @@ -155,32 +171,36 @@ def testGetEvalResultsRoute(self):
self.assertEqual(200, response.status_code)

def testGetEvalResultsFromURLRoute(self):
model_location = self._exportEvalSavedModel(
linear_classifier.simple_linear_classifier)
model_location = self._export_eval_saved_model() # Call the method
examples = [
self._makeExample(age=3.0, language="english", label=1.0),
self._makeExample(age=3.0, language="chinese", label=0.0),
self._makeExample(age=4.0, language="english", label=1.0),
self._makeExample(age=5.0, language="chinese", label=1.0),
self._makeExample(age=5.0, language="hindi", label=1.0)
self._make_tf_example(age=3.0, language="english", label=1.0),
self._make_tf_example(age=3.0, language="chinese", label=0.0),
self._make_tf_example(age=4.0, language="english", label=1.0),
self._make_tf_example(age=5.0, language="chinese", label=1.0),
self._make_tf_example(age=5.0, language="hindi", label=1.0),
]
data_location = self._writeTFExamplesToTFRecords(examples)
data_location = self._write_tf_examples_to_tfrecords(examples)
_ = tfma.run_model_analysis(
eval_shared_model=tfma.default_eval_shared_model(
eval_saved_model_path=model_location, example_weight_key="age"),
data_location=data_location,
output_path=self._eval_result_output_dir)

response = self._server.get(
"/data/plugin/fairness_indicators/" +
"get_evaluation_result_from_remote_path?evaluation_output_path=" +
os.path.join(self._eval_result_output_dir, tfma.METRICS_KEY))
"/data/plugin/fairness_indicators/"
+ "get_evaluation_result_from_remote_path?evaluation_output_path="
+ self._eval_result_output_dir
)
self.assertEqual(200, response.status_code)

def testGetOutputFileFormat(self):
self.assertEqual("", self._plugin._get_output_file_format("abc_path"))
self.assertEqual("tfrecord",
self._plugin._get_output_file_format("abc_path.tfrecord"))
def test_get_output_file_format(self):
evaluation_output_path = os.path.join(
self._eval_result_output_dir, "eval_result.tfrecord"
)
self.assertEqual(
self._plugin._get_output_file_format(evaluation_output_path),
"tfrecord",
)


if __name__ == "__main__":
Expand Down

0 comments on commit 0da0891

Please sign in to comment.