diff --git a/tests/test_pytorch.py b/tests/test_pytorch.py index 493c7eb..6c73cc3 100644 --- a/tests/test_pytorch.py +++ b/tests/test_pytorch.py @@ -373,11 +373,11 @@ def test_batching__exactly_one_batch( batch_size=3, use_eager_fetch=use_eager_fetch, ) + assert exp_data_pipe.shape == (1, 3) batch_iter = iter(exp_data_pipe) - - batch = next(batch_iter) - assert batch[0].tolist() == [[0, 1, 0], [1, 0, 1], [0, 1, 0]] - assert batch[1]["label"].tolist() == ["0", "1", "2"] + X_batch, obs_batch = next(batch_iter) + assert X_batch.tolist() == [[0, 1, 0], [1, 0, 1], [0, 1, 0]] + assert_frame_equal(obs_batch, pd.DataFrame({"label": ["0", "1", "2"]})) with pytest.raises(StopIteration): next(batch_iter)