diff --git a/tests/test_pytorch.py b/tests/test_pytorch.py index 6c73cc3..1d0ce35 100644 --- a/tests/test_pytorch.py +++ b/tests/test_pytorch.py @@ -431,9 +431,9 @@ def test_batching__partial_soma_batches_are_concatenated( use_eager_fetch=use_eager_fetch, ) - full_result = list(exp_data_pipe) + batches = list(exp_data_pipe) - assert [len(batch[0]) for batch in full_result] == [3, 3, 3, 1] + assert [len(batch[0]) for batch in batches] == [3, 3, 3, 1] @pytest.mark.parametrize( @@ -471,9 +471,9 @@ def test_distributed__returns_data_partition_for_rank( obs_column_names=["soma_joinid"], io_batch_size=2, ) - full_result = list(iter(dp)) + batches = list(iter(dp)) soma_joinids = np.concatenate( - [t[1]["soma_joinid"].to_numpy() for t in full_result] + [batch[1]["soma_joinid"].to_numpy() for batch in batches] ) expected_joinids = np.array_split(np.arange(obs_range), world_size)[rank][