Skip to content

Commit

Permalink
NA
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 713366693
  • Loading branch information
zhouhao138 authored and Responsible ML Infra Team committed Jan 8, 2025
1 parent fd265d4 commit 7f48025
Showing 1 changed file with 76 additions and 35 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,11 @@
import tensorflow.compat.v1 as tf
import tensorflow.compat.v2 as tf2
import tensorflow_model_analysis as tfma
from tensorflow_model_analysis.eval_saved_model.example_trainers import linear_classifier
from tensorflow_model_analysis.utils import example_keras_model
from werkzeug import test as werkzeug_test
from werkzeug import wrappers

from google.protobuf import text_format
from tensorboard.backend import application
from tensorboard.backend.event_processing import plugin_event_multiplexer as event_multiplexer
from tensorboard.plugins import base_plugin
Expand Down Expand Up @@ -74,19 +75,20 @@ def tearDown(self):
super(PluginTest, self).tearDown()
shutil.rmtree(self._log_dir, ignore_errors=True)

def _exportEvalSavedModel(self, classifier):
def _export_keras_model(self, classifier):
temp_eval_export_dir = os.path.join(self.get_temp_dir(), "eval_export_dir")
_, eval_export_dir = classifier(None, temp_eval_export_dir)
return eval_export_dir
classifier.compile(optimizer=tf.keras.optimizers.Adam(), loss="mse")
tf.saved_model.save(classifier, temp_eval_export_dir)
return temp_eval_export_dir

def _writeTFExamplesToTFRecords(self, examples):
def _write_tf_examples_to_tfrecords(self, examples):
data_location = os.path.join(self.get_temp_dir(), "input_data.rio")
with tf.io.TFRecordWriter(data_location) as writer:
for example in examples:
writer.write(example.SerializeToString())
return data_location

def _makeExample(self, age, language, label):
def _make_example(self, age, language, label):
example = tf.train.Example()
example.features.feature["age"].float_list.value[:] = [age]
example.features.feature["language"].bytes_list.value[:] = [
Expand All @@ -95,6 +97,27 @@ def _makeExample(self, age, language, label):
example.features.feature["label"].float_list.value[:] = [label]
return example

def _make_eval_config(self):
return text_format.Parse(
"""
model_specs {
signature_name: "serving_default"
prediction_key: "predictions" # placeholder
label_key: "label" # placeholder
}
slicing_specs {}
metrics_specs {
metrics {
class_name: "ExampleCount"
}
metrics {
class_name: "Accuracy"
}
}
""",
tfma.EvalConfig(),
)

def testRoutes(self):
self.assertIsInstance(self._routes["/get_evaluation_result"],
abc.Callable)
Expand All @@ -112,14 +135,14 @@ def testRoutes(self):
"foo": "".encode("utf-8")
}},
)
def testIsActive(self, get_random_stub):
def testIsActive(self, get_random_stub): # pylint: disable=unused-argument
self.assertTrue(self._plugin.is_active())

@mock.patch.object(
event_multiplexer.EventMultiplexer,
"PluginRunToTagToContent",
return_value={})
def testIsInactive(self, get_random_stub):
def testIsInactive(self, get_random_stub): # pylint: disable=unused-argument
self.assertFalse(self._plugin.is_active())

def testIndexJsRoute(self):
Expand All @@ -130,57 +153,75 @@ def testIndexJsRoute(self):
def testVulcanizedTemplateRoute(self):
"""Tests that the /tags route offers the correct run to tag mapping."""
response = self._server.get(
"/data/plugin/fairness_indicators/vulcanized_tfma.js")
"/data/plugin/fairness_indicators/vulcanized_tfma.js"
)
self.assertEqual(200, response.status_code)

def testGetEvalResultsRoute(self):
model_location = self._exportEvalSavedModel(
linear_classifier.simple_linear_classifier)
model_location = self._export_keras_model(
example_keras_model.get_example_classifier_model(
input_feature_key="language"
)
)
examples = [
self._makeExample(age=3.0, language="english", label=1.0),
self._makeExample(age=3.0, language="chinese", label=0.0),
self._makeExample(age=4.0, language="english", label=1.0),
self._makeExample(age=5.0, language="chinese", label=1.0),
self._makeExample(age=5.0, language="hindi", label=1.0)
self._make_example(age=3.0, language="english", label=1.0),
self._make_example(age=3.0, language="chinese", label=0.0),
self._make_example(age=4.0, language="english", label=1.0),
self._make_example(age=5.0, language="chinese", label=1.0),
self._make_example(age=5.0, language="hindi", label=1.0),
]
data_location = self._writeTFExamplesToTFRecords(examples)
eval_config = self._make_eval_config()
data_location = self._write_tf_examples_to_tfrecords(examples)
_ = tfma.run_model_analysis(
eval_shared_model=tfma.default_eval_shared_model(
eval_saved_model_path=model_location, example_weight_key="age"),
eval_saved_model_path=model_location, eval_config=eval_config
),
eval_config=eval_config,
data_location=data_location,
output_path=self._eval_result_output_dir)
output_path=self._eval_result_output_dir,
)

response = self._server.get(
"/data/plugin/fairness_indicators/get_evaluation_result?run=.")
"/data/plugin/fairness_indicators/get_evaluation_result?run=."
)
self.assertEqual(200, response.status_code)

def testGetEvalResultsFromURLRoute(self):
model_location = self._exportEvalSavedModel(
linear_classifier.simple_linear_classifier)
model_location = self._export_keras_model(
example_keras_model.get_example_classifier_model(
input_feature_key="language"
)
)
examples = [
self._makeExample(age=3.0, language="english", label=1.0),
self._makeExample(age=3.0, language="chinese", label=0.0),
self._makeExample(age=4.0, language="english", label=1.0),
self._makeExample(age=5.0, language="chinese", label=1.0),
self._makeExample(age=5.0, language="hindi", label=1.0)
self._make_example(age=3.0, language="english", label=1.0),
self._make_example(age=3.0, language="chinese", label=0.0),
self._make_example(age=4.0, language="english", label=1.0),
self._make_example(age=5.0, language="chinese", label=1.0),
self._make_example(age=5.0, language="hindi", label=1.0),
]
data_location = self._writeTFExamplesToTFRecords(examples)
eval_config = self._make_eval_config()
data_location = self._write_tf_examples_to_tfrecords(examples)
_ = tfma.run_model_analysis(
eval_shared_model=tfma.default_eval_shared_model(
eval_saved_model_path=model_location, example_weight_key="age"),
eval_saved_model_path=model_location, eval_config=eval_config
),
eval_config=eval_config,
data_location=data_location,
output_path=self._eval_result_output_dir)
output_path=self._eval_result_output_dir,
)

response = self._server.get(
"/data/plugin/fairness_indicators/" +
"get_evaluation_result_from_remote_path?evaluation_output_path=" +
os.path.join(self._eval_result_output_dir, tfma.METRICS_KEY))
"/data/plugin/fairness_indicators/"
+ "get_evaluation_result_from_remote_path?evaluation_output_path="
+ os.path.join(self._eval_result_output_dir, tfma.METRICS_KEY)
)
self.assertEqual(200, response.status_code)

def testGetOutputFileFormat(self):
self.assertEqual("", self._plugin._get_output_file_format("abc_path"))
self.assertEqual("tfrecord",
self._plugin._get_output_file_format("abc_path.tfrecord"))
self.assertEqual(
"tfrecord", self._plugin._get_output_file_format("abc_path.tfrecord")
)


if __name__ == "__main__":
Expand Down

0 comments on commit 7f48025

Please sign in to comment.