diff --git a/tensorflow_probability/python/stats/sample_stats.py b/tensorflow_probability/python/stats/sample_stats.py index 5773574564..637b64024d 100644 --- a/tensorflow_probability/python/stats/sample_stats.py +++ b/tensorflow_probability/python/stats/sample_stats.py @@ -915,6 +915,12 @@ def _prepare_window_args(x, low_indices=None, high_indices=None, axis=0): low_indices = high_indices // 2 else: low_indices = tf.convert_to_tensor(low_indices) + + indices_rank = tf.get_static_value(ps.rank(low_indices)) + x_rank = tf.get_static_value(ps.rank(x)) + if indices_rank is None or x_rank is None: + raise ValueError("`indices` and `x` ranks must be statically known.") + # Broadcast indices together. high_indices = high_indices + tf.zeros_like(low_indices) low_indices = low_indices + tf.zeros_like(high_indices) diff --git a/tensorflow_probability/python/stats/sample_stats_test.py b/tensorflow_probability/python/stats/sample_stats_test.py index ce0d91b357..b7cd7dff38 100644 --- a/tensorflow_probability/python/stats/sample_stats_test.py +++ b/tensorflow_probability/python/stats/sample_stats_test.py @@ -15,10 +15,14 @@ """Tests for Sample Stats Ops.""" # Dependency imports -import functools +import itertools + import numpy as np import tensorflow.compat.v1 as tf1 import tensorflow.compat.v2 as tf +from absl.testing import parameterized +from tensorflow.python.framework.errors_impl import InvalidArgumentError + from tensorflow_probability.python.internal import test_util from tensorflow_probability.python.stats import sample_stats @@ -721,7 +725,8 @@ def apply_func(vector, l, h): out = np.transpose(t_out, axes=dims) return out - def check_gaussian_windowed(self, shape, indice_shape, axis, + + def check_gaussian_windowed_func(self, shape, indice_shape, axis, window_func, np_func): stat_shape = np.array(shape).astype(np.int32) stat_shape[axis] = 1 @@ -753,51 +758,56 @@ def check_gaussian_windowed(self, shape, indice_shape, axis, def _make_dynamic_shape(self, x): return tf1.placeholder_with_default(x, shape=(None,)*len(x.shape)) - def check_windowed(self, func, numpy_func): - check_fn = functools.partial(self.check_gaussian_windowed, - window_func=func, np_func=numpy_func) - check_fn((64, 4, 8), (128, 1, 1), axis=0) - check_fn((64, 4, 8), (32, 1, 1), axis=0) - check_fn((64, 4, 8), (32, 4, 1), axis=0) - check_fn((64, 4, 8), (32, 4, 8), axis=0) - check_fn((64, 4, 8), (64, 4, 8), axis=0) - check_fn((64, 4, 8), (128, 1), axis=0) - check_fn((64, 4, 8), (32,), axis=0) - check_fn((64, 4, 8), (32, 4), axis=0) - - check_fn((64, 4, 8), (64, 64, 1), axis=1) - check_fn((64, 4, 8), (1, 64, 1), axis=1) - check_fn((64, 4, 8), (64, 2, 8), axis=1) - check_fn((64, 4, 8), (64, 4, 8), axis=1) - check_fn((64, 4, 8), (16,), axis=1) - check_fn((64, 4, 8), (1, 64), axis=1) - - check_fn((64, 4, 8), (64, 4, 64), axis=2) - check_fn((64, 4, 8), (1, 1, 64), axis=2) - check_fn((64, 4, 8), (64, 4, 4), axis=2) - check_fn((64, 4, 8), (1, 1, 4), axis=2) - check_fn((64, 4, 8), (64, 4, 8), axis=2) - check_fn((64, 4, 8), (16,), axis=2) - check_fn((64, 4, 8), (1, 4), axis=2) - check_fn((64, 4, 8), (64, 4), axis=2) - - with self.assertRaises(Exception): - # Non broadcastable shapes - check_fn((64, 4, 8), (4, 1, 4), axis=2) - - with self.assertRaises(Exception): - # Non broadcastable shapes - check_fn((64, 4, 8), (2, 4), axis=2) - - def test_windowed_mean(self): - self.check_windowed(func=sample_stats.windowed_mean, numpy_func=np.mean) - - def test_windowed_mean_graph(self): - func = tf.function(sample_stats.windowed_mean) - self.check_windowed(func=func, numpy_func=np.mean) - - def test_windowed_variance(self): - self.check_windowed(func=sample_stats.windowed_variance, numpy_func=np.var) + @parameterized.named_parameters(*[( + f"{np_func.__name__} shape={a} indices_shape={b} axis={axis}", a, b, axis, + tf_func, np_func) for a, (b, axis), (tf_func, np_func) in + itertools.product([(64, 4, 8), ], + [((128, 1, 1), 0), + ((32, 1, 1), 0), + ((32, 4, 1), 0), + ((32, 4, 8), 0), + ((64, 4, 8), 0), + ((128, 1), 0), + ((32,), 0), + ((32, 4), 0), + + ((64, 64, 1), 1), + ((1, 64, 1), 1), + ((64, 2, 8), 1), + ((64, 4, 8), 1), + ((16,), 1), + ((1, 64), 1), + + ((64, 4, 64), 2), + ((1, 1, 64), 2), + ((64, 4, 4), 2), + ((1, 1, 4), 2), + ((64, 4, 8), 2), + ((16,), 2), + ((1, 4), 2), + ((64, 4), 2)], + [ + (sample_stats.windowed_mean, np.mean), + (sample_stats.windowed_variance, np.var) + ])]) + def test_windowed(self, shape, indice_shape, axis, window_func, np_func): + self.check_gaussian_windowed_func(shape, indice_shape, axis, window_func, + np_func) + + + @parameterized.named_parameters(*[( + f"{np_func.__name__} shape={a} indices_shape={b} axis={axis}", a, b, axis, + tf_func, np_func) for a, (b, axis), (tf_func, np_func) in + itertools.product([(64, 4, 8), ], + [((4, 1, 4), 2), ((2, 4), 2)], + [(sample_stats.windowed_mean, np.mean), + (sample_stats.windowed_variance, np.var)])]) + def test_non_broadcastable_shapes(self, shape, indice_shape, axis, + window_func, np_func): + with self.assertRaisesRegexp((IndexError, ValueError, InvalidArgumentError), + '^shape mismatch|Incompatible shapes'): + self.check_gaussian_windowed_func(shape, indice_shape, axis, window_func, + np_func) @test_util.test_all_tf_execution_regimes