From 4f6f6a2cb3bfb25d70e20ca20394b212d0e00f8b Mon Sep 17 00:00:00 2001 From: Yating Date: Thu, 14 Sep 2023 10:37:57 -0400 Subject: [PATCH] Add `last_value` to `ScalarTimeSeries` interface. (#6579) This allows `list_scalars` of some of the data provider implementations to return the last scalar value, which will improve the performance when loading experiments with Hparams data. Googlers, see b/292102513 for context. Tested internally: cl/563163418 #hparams --- tensorboard/data/provider.py | 14 ++++++++++++++ tensorboard/data/provider_test.py | 23 ++++++++++++++++------- 2 files changed, 30 insertions(+), 7 deletions(-) diff --git a/tensorboard/data/provider.py b/tensorboard/data/provider.py index 9be75dd030..722f66de9a 100644 --- a/tensorboard/data/provider.py +++ b/tensorboard/data/provider.py @@ -777,6 +777,7 @@ class _TimeSeries: "_plugin_content", "_description", "_display_name", + "_last_value", ) def __init__( @@ -787,12 +788,14 @@ def __init__( plugin_content, description, display_name, + last_value=None, ): self._max_step = max_step self._max_wall_time = max_wall_time self._plugin_content = plugin_content self._description = description self._display_name = display_name + self._last_value = last_value @property def max_step(self): @@ -814,6 +817,10 @@ def description(self): def display_name(self): return self._display_name + @property + def last_value(self): + return self._last_value + class ScalarTimeSeries(_TimeSeries): """Metadata about a scalar time series for a particular run and tag. @@ -830,6 +837,9 @@ class ScalarTimeSeries(_TimeSeries): empty if no description was specified. display_name: An optional long-form Markdown description, as a `str` that is empty if no description was specified. Deprecated; may be removed soon. + last_value: An optional value for the latest scalar in the time series, + corresponding to the scalar at `max_step`. Note that this field might NOT + be populated by all data provider implementations. """ def __eq__(self, other): @@ -845,6 +855,8 @@ def __eq__(self, other): return False if self._display_name != other._display_name: return False + if self._last_value != other._last_value: + return False return True def __hash__(self): @@ -855,6 +867,7 @@ def __hash__(self): self._plugin_content, self._description, self._display_name, + self._last_value, ) ) @@ -866,6 +879,7 @@ def __repr__(self): "plugin_content=%r" % (self._plugin_content,), "description=%r" % (self._description,), "display_name=%r" % (self._display_name,), + "last_value=%r" % (self._last_value,), ) ) diff --git a/tensorboard/data/provider_test.py b/tensorboard/data/provider_test.py index ea9341d6eb..c5a7c38d24 100644 --- a/tensorboard/data/provider_test.py +++ b/tensorboard/data/provider_test.py @@ -91,7 +91,13 @@ def test_repr(self): class ScalarTimeSeriesTest(tb_test.TestCase): def _scalar_time_series( - self, max_step, max_wall_time, plugin_content, description, display_name + self, + max_step, + max_wall_time, + plugin_content, + description, + display_name, + last_value, ): # Helper to use explicit kwargs. return provider.ScalarTimeSeries( @@ -100,6 +106,7 @@ def _scalar_time_series( plugin_content=plugin_content, description=description, display_name=display_name, + last_value=last_value, ) def test_repr(self): @@ -109,6 +116,7 @@ def test_repr(self): plugin_content=b"AB\xCD\xEF!\x00", description="test test", display_name="one two", + last_value=0.0001, ) repr_ = repr(x) self.assertIn(repr(x.max_step), repr_) @@ -116,19 +124,20 @@ def test_repr(self): self.assertIn(repr(x.plugin_content), repr_) self.assertIn(repr(x.description), repr_) self.assertIn(repr(x.display_name), repr_) + self.assertIn(repr(x.last_value), repr_) def test_eq(self): - x1 = self._scalar_time_series(77, 1234.5, b"\x12", "one", "two") - x2 = self._scalar_time_series(77, 1234.5, b"\x12", "one", "two") - x3 = self._scalar_time_series(66, 4321.0, b"\x7F", "hmm", "hum") + x1 = self._scalar_time_series(77, 1234.5, b"\x12", "one", "two", 512) + x2 = self._scalar_time_series(77, 1234.5, b"\x12", "one", "two", 512) + x3 = self._scalar_time_series(66, 4321.0, b"\x7F", "hmm", "hum", 1024) self.assertEqual(x1, x2) self.assertNotEqual(x1, x3) self.assertNotEqual(x1, object()) def test_hash(self): - x1 = self._scalar_time_series(77, 1234.5, b"\x12", "one", "two") - x2 = self._scalar_time_series(77, 1234.5, b"\x12", "one", "two") - x3 = self._scalar_time_series(66, 4321.0, b"\x7F", "hmm", "hum") + x1 = self._scalar_time_series(77, 1234.5, b"\x12", "one", "two", 512) + x2 = self._scalar_time_series(77, 1234.5, b"\x12", "one", "two", 512) + x3 = self._scalar_time_series(66, 4321.0, b"\x7F", "hmm", "hum", 1024) self.assertEqual(hash(x1), hash(x2)) # The next check is technically not required by the `__hash__` # contract, but _should_ pass; failure on this assertion would at