From 446f438320092c69f469290b145376a4e2941402 Mon Sep 17 00:00:00 2001 From: Ryan Williams Date: Thu, 3 Oct 2024 15:54:56 -0400 Subject: [PATCH] update `test_batching__exactly_one_batch` --- tests/test_pytorch.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/test_pytorch.py b/tests/test_pytorch.py index 493c7eb..6c73cc3 100644 --- a/tests/test_pytorch.py +++ b/tests/test_pytorch.py @@ -373,11 +373,11 @@ def test_batching__exactly_one_batch( batch_size=3, use_eager_fetch=use_eager_fetch, ) + assert exp_data_pipe.shape == (1, 3) batch_iter = iter(exp_data_pipe) - - batch = next(batch_iter) - assert batch[0].tolist() == [[0, 1, 0], [1, 0, 1], [0, 1, 0]] - assert batch[1]["label"].tolist() == ["0", "1", "2"] + X_batch, obs_batch = next(batch_iter) + assert X_batch.tolist() == [[0, 1, 0], [1, 0, 1], [0, 1, 0]] + assert_frame_equal(obs_batch, pd.DataFrame({"label": ["0", "1", "2"]})) with pytest.raises(StopIteration): next(batch_iter)