diff --git a/tests/test_pytorch.py b/tests/test_pytorch.py index 2bf4b4c..8cbae43 100644 --- a/tests/test_pytorch.py +++ b/tests/test_pytorch.py @@ -567,6 +567,7 @@ def test_splits() -> None: from tiledbsoma_ml.pytorch import _splits assert _splits(10, 1).tolist() == [0, 10] + assert _splits(10, 2).tolist() == [0, 5, 10] assert _splits(10, 3).tolist() == [0, 4, 7, 10] assert _splits(10, 4).tolist() == [0, 3, 6, 8, 10] assert _splits(10, 10).tolist() == [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]