Skip to content

Commit

Permalink
update test_batching__exactly_one_batch
Browse files Browse the repository at this point in the history
  • Loading branch information
ryan-williams committed Oct 3, 2024
1 parent 982677b commit 446f438
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions tests/test_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -373,11 +373,11 @@ def test_batching__exactly_one_batch(
batch_size=3,
use_eager_fetch=use_eager_fetch,
)
assert exp_data_pipe.shape == (1, 3)
batch_iter = iter(exp_data_pipe)

batch = next(batch_iter)
assert batch[0].tolist() == [[0, 1, 0], [1, 0, 1], [0, 1, 0]]
assert batch[1]["label"].tolist() == ["0", "1", "2"]
X_batch, obs_batch = next(batch_iter)
assert X_batch.tolist() == [[0, 1, 0], [1, 0, 1], [0, 1, 0]]
assert_frame_equal(obs_batch, pd.DataFrame({"label": ["0", "1", "2"]}))

with pytest.raises(StopIteration):
next(batch_iter)
Expand Down

0 comments on commit 446f438

Please sign in to comment.