From c6b8435fc8aac95ce6dc9f33f0a402f2fd05fc33 Mon Sep 17 00:00:00 2001 From: Tom Ward Date: Tue, 24 Sep 2024 14:58:47 -0700 Subject: [PATCH] Change `pack_tensor` to do unsafe casting when an explicit dtype is provided. This fixes issues when running with NumPy 2.0. Previously, we would cast using `same_kind`, which mean't only "safe" casts or casts within a kind, like float64 to float32, were allowed. In NumPy 2.0, the meaning of `same_kind` changed such that unsigned and signed integers are no longer deemed the same. Considering it was still possible for data loss even when using same_kind casting, we propose changing to use `unsafe` when an explicit dtype is provided, requiring the end user to ensure the value being packed in compatible with the dtype. PiperOrigin-RevId: 678408939 Change-Id: I1770714fb23335f5ee1a9ae5b7f6263394bb8297 --- dm_env_rpc/v1/spec_manager_test.py | 6 +----- dm_env_rpc/v1/tensor_utils.py | 2 +- dm_env_rpc/v1/tensor_utils_test.py | 2 +- 3 files changed, 3 insertions(+), 7 deletions(-) diff --git a/dm_env_rpc/v1/spec_manager_test.py b/dm_env_rpc/v1/spec_manager_test.py index a8781ff..ea14b7c 100644 --- a/dm_env_rpc/v1/spec_manager_test.py +++ b/dm_env_rpc/v1/spec_manager_test.py @@ -102,13 +102,9 @@ def test_pack_wrong_shape_raises_error(self): self._spec_manager.pack({'foo': [1, 2]}) def test_pack_wrong_dtype_raises_error(self): - with self.assertRaisesRegex(TypeError, 'int32'): + with self.assertRaises(ValueError): self._spec_manager.pack({'foo': 'hello'}) - def test_pack_cast_float_to_int_raises_error(self): - with self.assertRaisesRegex(TypeError, 'int32'): - self._spec_manager.pack({'foo': [0.5, 1.0, 1]}) - def test_pack_cast_int_to_float_is_ok(self): packed = self._spec_manager.pack({'fuzz': [1, 2]}) self.assertEqual([1.0, 2.0], packed[54].floats.array) diff --git a/dm_env_rpc/v1/tensor_utils.py b/dm_env_rpc/v1/tensor_utils.py index fc8bf39..02ee94e 100644 --- a/dm_env_rpc/v1/tensor_utils.py +++ b/dm_env_rpc/v1/tensor_utils.py @@ -322,7 +322,7 @@ def pack_tensor( value = value.astype( dtype=_DM_ENV_RPC_DTYPE_TO_NUMPY_DTYPE.get(dtype, dtype), copy=False, - casting='same_kind' if value.size else 'unsafe') + casting='unsafe') packed.shape[:] = value.shape packer = get_packer(value.dtype.type) diff --git a/dm_env_rpc/v1/tensor_utils_test.py b/dm_env_rpc/v1/tensor_utils_test.py index 63cf8bc..73c5a9f 100644 --- a/dm_env_rpc/v1/tensor_utils_test.py +++ b/dm_env_rpc/v1/tensor_utils_test.py @@ -168,7 +168,7 @@ def test_packed_rowmajor(self): np.testing.assert_array_equal([1, 2, 3, 4, 5, 6], tensor.int32s.array) def test_mixed_scalar_types_raises_exception(self): - with self.assertRaises(TypeError): + with self.assertRaises(ValueError): tensor_utils.pack_tensor(['hello!', 75], dtype=np.float32) def test_jagged_arrays_throw_exceptions(self):