Skip to content

Commit

Permalink
Resolves potential Nones caught by tytype.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 675998642
  • Loading branch information
Sonnet Contributor authored and copybara-github committed Oct 1, 2024
1 parent 6d59725 commit 99b1809
Showing 1 changed file with 20 additions and 4 deletions.
24 changes: 20 additions & 4 deletions sonnet/src/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,13 @@ def initialize(self, value: tf.Tensor):
def update(self, value: tf.Tensor):
"""See base class."""
self.initialize(value)
self.sum.assign_add(value)
self._checked_sum.assign_add(value)

@property
def _checked_sum(self):
if self.sum is None:
raise ValueError("Metric is not initialized. Call `initialize` first.")
return self.sum

@property
def value(self) -> tf.Tensor:
Expand All @@ -71,6 +77,8 @@ def value(self) -> tf.Tensor:

def reset(self):
"""See base class."""
if self.sum is None:
raise ValueError("Metric is not initialized. Call `initialize` first.")
self.sum.assign(tf.zeros_like(self.sum))


Expand All @@ -90,15 +98,23 @@ def initialize(self, value: tf.Tensor):
def update(self, value: tf.Tensor):
"""See base class."""
self.initialize(value)
self.sum.assign_add(value)
self._checkedsum.assign_add(value)
self.count.assign_add(1)

@property
def _checked_sum(self) -> tf.Variable:
if self.sum is None:
raise ValueError("Metric is not initialized. Call `initialize` first.")
return self.sum

@property
def value(self) -> tf.Tensor:
"""See base class."""
# TODO(cjfj): Assert summed type is floating-point?
return self.sum / tf.cast(self.count, dtype=self.sum.dtype)
return self._checked_sum / tf.cast(
self.count, dtype=self._checked_sum.dtype
)

def reset(self):
self.sum.assign(tf.zeros_like(self.sum))
self._checked_sum.assign(tf.zeros_like(self._checked_sum))
self.count.assign(0)

0 comments on commit 99b1809

Please sign in to comment.