Skip to content

Commit

Permalink
update test_batching__partial_final_batch_size
Browse files Browse the repository at this point in the history
  • Loading branch information
ryan-williams committed Oct 3, 2024
1 parent 585e5ff commit 982677b
Showing 1 changed file with 10 additions and 2 deletions.
12 changes: 10 additions & 2 deletions tests/test_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,11 +322,13 @@ def test_soma_joinids(
[(5, 3, pytorch_x_value_gen)],
)
@pytest.mark.parametrize("use_eager_fetch", [True, False])
@pytest.mark.parametrize("return_sparse_X", [True, False])
@pytest.mark.parametrize("PipeClass", PipeClasses)
def test_batching__partial_final_batch_size(
PipeClass: PipeClassType,
soma_experiment: Experiment,
use_eager_fetch: bool,
return_sparse_X: bool,
) -> None:
with soma_experiment.axis_query(measurement_name="RNA") as query:
exp_data_pipe = PipeClass(
Expand All @@ -335,12 +337,18 @@ def test_batching__partial_final_batch_size(
obs_column_names=["label"],
batch_size=3,
use_eager_fetch=use_eager_fetch,
return_sparse_X=return_sparse_X,
)
assert exp_data_pipe.shape == (2, 3)
batch_iter = iter(exp_data_pipe)

next(batch_iter)
batch = next(batch_iter)
assert batch[0].tolist() == [[1, 0, 1], [0, 1, 0]]
X_batch, obs_batch = next(batch_iter)
if return_sparse_X:
assert isinstance(X_batch, sparse.csr_matrix)
X_batch = X_batch.todense()
assert X_batch.tolist() == [[1, 0, 1], [0, 1, 0]]
assert_frame_equal(obs_batch, pd.DataFrame({"label": ["3", "4"]}))

with pytest.raises(StopIteration):
next(batch_iter)
Expand Down

0 comments on commit 982677b

Please sign in to comment.