Skip to content

Commit

Permalink
Add last_value to ScalarTimeSeries interface. (#6579)
Browse files Browse the repository at this point in the history
This allows `list_scalars` of some of the data provider implementations to return the last scalar value, which will improve the performance when loading experiments with Hparams data.

Googlers, see b/292102513 for context.

Tested internally: cl/563163418

#hparams
  • Loading branch information
yatbear authored Sep 14, 2023
1 parent 1dde0cb commit 4f6f6a2
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 7 deletions.
14 changes: 14 additions & 0 deletions tensorboard/data/provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -777,6 +777,7 @@ class _TimeSeries:
"_plugin_content",
"_description",
"_display_name",
"_last_value",
)

def __init__(
Expand All @@ -787,12 +788,14 @@ def __init__(
plugin_content,
description,
display_name,
last_value=None,
):
self._max_step = max_step
self._max_wall_time = max_wall_time
self._plugin_content = plugin_content
self._description = description
self._display_name = display_name
self._last_value = last_value

@property
def max_step(self):
Expand All @@ -814,6 +817,10 @@ def description(self):
def display_name(self):
return self._display_name

@property
def last_value(self):
return self._last_value


class ScalarTimeSeries(_TimeSeries):
"""Metadata about a scalar time series for a particular run and tag.
Expand All @@ -830,6 +837,9 @@ class ScalarTimeSeries(_TimeSeries):
empty if no description was specified.
display_name: An optional long-form Markdown description, as a `str` that is
empty if no description was specified. Deprecated; may be removed soon.
last_value: An optional value for the latest scalar in the time series,
corresponding to the scalar at `max_step`. Note that this field might NOT
be populated by all data provider implementations.
"""

def __eq__(self, other):
Expand All @@ -845,6 +855,8 @@ def __eq__(self, other):
return False
if self._display_name != other._display_name:
return False
if self._last_value != other._last_value:
return False
return True

def __hash__(self):
Expand All @@ -855,6 +867,7 @@ def __hash__(self):
self._plugin_content,
self._description,
self._display_name,
self._last_value,
)
)

Expand All @@ -866,6 +879,7 @@ def __repr__(self):
"plugin_content=%r" % (self._plugin_content,),
"description=%r" % (self._description,),
"display_name=%r" % (self._display_name,),
"last_value=%r" % (self._last_value,),
)
)

Expand Down
23 changes: 16 additions & 7 deletions tensorboard/data/provider_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,13 @@ def test_repr(self):

class ScalarTimeSeriesTest(tb_test.TestCase):
def _scalar_time_series(
self, max_step, max_wall_time, plugin_content, description, display_name
self,
max_step,
max_wall_time,
plugin_content,
description,
display_name,
last_value,
):
# Helper to use explicit kwargs.
return provider.ScalarTimeSeries(
Expand All @@ -100,6 +106,7 @@ def _scalar_time_series(
plugin_content=plugin_content,
description=description,
display_name=display_name,
last_value=last_value,
)

def test_repr(self):
Expand All @@ -109,26 +116,28 @@ def test_repr(self):
plugin_content=b"AB\xCD\xEF!\x00",
description="test test",
display_name="one two",
last_value=0.0001,
)
repr_ = repr(x)
self.assertIn(repr(x.max_step), repr_)
self.assertIn(repr(x.max_wall_time), repr_)
self.assertIn(repr(x.plugin_content), repr_)
self.assertIn(repr(x.description), repr_)
self.assertIn(repr(x.display_name), repr_)
self.assertIn(repr(x.last_value), repr_)

def test_eq(self):
x1 = self._scalar_time_series(77, 1234.5, b"\x12", "one", "two")
x2 = self._scalar_time_series(77, 1234.5, b"\x12", "one", "two")
x3 = self._scalar_time_series(66, 4321.0, b"\x7F", "hmm", "hum")
x1 = self._scalar_time_series(77, 1234.5, b"\x12", "one", "two", 512)
x2 = self._scalar_time_series(77, 1234.5, b"\x12", "one", "two", 512)
x3 = self._scalar_time_series(66, 4321.0, b"\x7F", "hmm", "hum", 1024)
self.assertEqual(x1, x2)
self.assertNotEqual(x1, x3)
self.assertNotEqual(x1, object())

def test_hash(self):
x1 = self._scalar_time_series(77, 1234.5, b"\x12", "one", "two")
x2 = self._scalar_time_series(77, 1234.5, b"\x12", "one", "two")
x3 = self._scalar_time_series(66, 4321.0, b"\x7F", "hmm", "hum")
x1 = self._scalar_time_series(77, 1234.5, b"\x12", "one", "two", 512)
x2 = self._scalar_time_series(77, 1234.5, b"\x12", "one", "two", 512)
x3 = self._scalar_time_series(66, 4321.0, b"\x7F", "hmm", "hum", 1024)
self.assertEqual(hash(x1), hash(x2))
# The next check is technically not required by the `__hash__`
# contract, but _should_ pass; failure on this assertion would at
Expand Down

0 comments on commit 4f6f6a2

Please sign in to comment.