From 982677b49e57753e625c2ce0bc89d53618405639 Mon Sep 17 00:00:00 2001 From: Ryan Williams Date: Thu, 3 Oct 2024 15:54:22 -0400 Subject: [PATCH] update `test_batching__partial_final_batch_size` --- tests/test_pytorch.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/tests/test_pytorch.py b/tests/test_pytorch.py index a8379dc..493c7eb 100644 --- a/tests/test_pytorch.py +++ b/tests/test_pytorch.py @@ -322,11 +322,13 @@ def test_soma_joinids( [(5, 3, pytorch_x_value_gen)], ) @pytest.mark.parametrize("use_eager_fetch", [True, False]) +@pytest.mark.parametrize("return_sparse_X", [True, False]) @pytest.mark.parametrize("PipeClass", PipeClasses) def test_batching__partial_final_batch_size( PipeClass: PipeClassType, soma_experiment: Experiment, use_eager_fetch: bool, + return_sparse_X: bool, ) -> None: with soma_experiment.axis_query(measurement_name="RNA") as query: exp_data_pipe = PipeClass( @@ -335,12 +337,18 @@ def test_batching__partial_final_batch_size( obs_column_names=["label"], batch_size=3, use_eager_fetch=use_eager_fetch, + return_sparse_X=return_sparse_X, ) + assert exp_data_pipe.shape == (2, 3) batch_iter = iter(exp_data_pipe) next(batch_iter) - batch = next(batch_iter) - assert batch[0].tolist() == [[1, 0, 1], [0, 1, 0]] + X_batch, obs_batch = next(batch_iter) + if return_sparse_X: + assert isinstance(X_batch, sparse.csr_matrix) + X_batch = X_batch.todense() + assert X_batch.tolist() == [[1, 0, 1], [0, 1, 0]] + assert_frame_equal(obs_batch, pd.DataFrame({"label": ["3", "4"]})) with pytest.raises(StopIteration): next(batch_iter)