Skip to content

Commit

Permalink
update `test_distributed_and_multiprocessing__returns_data_partition_for_rank`
Browse files Browse the repository at this point in the history
  • Loading branch information
ryan-williams committed Oct 3, 2024
1 parent ce0f835 commit cdf78af
Showing 1 changed file with 44 additions and 40 deletions.
84 changes: 44 additions & 40 deletions tests/test_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -482,63 +482,67 @@ def test_distributed__returns_data_partition_for_rank(
assert sorted(soma_joinids) == expected_joinids


# fmt: off
@pytest.mark.parametrize(
"obs_range,var_range,X_value_gen",
[(12, 3, pytorch_x_value_gen), (13, 3, pytorch_x_value_gen)],
)
@pytest.mark.parametrize(
"world_size,rank,num_workers,worker_id",
"obs_range,var_range,X_value_gen,world_size,num_workers,splits",
[
(3, 1, 2, 0),
(3, 1, 2, 1),
(12, 3, pytorch_x_value_gen, 3, 2, [[0, 2, 4], [4, 6, 8], [ 8, 10, 12]]),
(13, 3, pytorch_x_value_gen, 3, 2, [[0, 2, 4], [5, 7, 9], [ 9, 11, 13]]),
(15, 3, pytorch_x_value_gen, 3, 2, [[0, 4, 5], [5, 9, 10], [10, 14, 15]]),
(16, 3, pytorch_x_value_gen, 3, 2, [[0, 4, 5], [6, 10, 11], [11, 15, 16]]),
(18, 3, pytorch_x_value_gen, 3, 2, [[0, 4, 6], [6, 10, 12], [12, 16, 18]]),
(19, 3, pytorch_x_value_gen, 3, 2, [[0, 4, 6], [7, 11, 13], [13, 17, 19]]),
(20, 3, pytorch_x_value_gen, 3, 2, [[0, 4, 6], [7, 11, 13], [14, 18, 20]]),
(21, 3, pytorch_x_value_gen, 3, 2, [[0, 4, 7], [7, 11, 14], [14, 18, 21]]),
(25, 3, pytorch_x_value_gen, 3, 2, [[0, 4, 8], [9, 13, 17], [17, 21, 25]]),
(27, 3, pytorch_x_value_gen, 3, 2, [[0, 6, 9], [9, 15, 18], [18, 24, 27]]),
],
)
@pytest.mark.parametrize("PipeClass", PipeClasses)
# fmt: on
def test_distributed_and_multiprocessing__returns_data_partition_for_rank(
PipeClass: PipeClassType,
soma_experiment: Experiment,
obs_range: int,
world_size: int,
rank: int,
num_workers: int,
worker_id: int,
splits: list[list[int]],
) -> None:
"""Tests pytorch._partition_obs_joinids() behavior in a simulated PyTorch distributed processing mode and
DataLoader multiprocessing mode, using mocks to avoid having to do distributed pytorch
setup or real DataLoader multiprocessing."""

with (
patch("torch.utils.data.get_worker_info") as mock_get_worker_info,
patch("torch.distributed.is_initialized") as mock_dist_is_initialized,
patch("torch.distributed.get_rank") as mock_dist_get_rank,
patch("torch.distributed.get_world_size") as mock_dist_get_world_size,
):
mock_get_worker_info.return_value = WorkerInfo(
id=worker_id, num_workers=num_workers, seed=1234
)
mock_dist_is_initialized.return_value = True
mock_dist_get_rank.return_value = rank
mock_dist_get_world_size.return_value = world_size

with soma_experiment.axis_query(measurement_name="RNA") as query:
dp = PipeClass(
query,
X_name="raw",
obs_column_names=["soma_joinid"],
io_batch_size=2,
for rank in range(world_size):
proc_splits = splits[rank]
for worker_id in range(num_workers):
expected_joinids = list(
range(proc_splits[worker_id], proc_splits[worker_id + 1])
)
with (
patch("torch.utils.data.get_worker_info") as mock_get_worker_info,
patch("torch.distributed.is_initialized") as mock_dist_is_initialized,
patch("torch.distributed.get_rank") as mock_dist_get_rank,
patch("torch.distributed.get_world_size") as mock_dist_get_world_size,
):
mock_get_worker_info.return_value = WorkerInfo(
id=worker_id, num_workers=num_workers, seed=1234
)
mock_dist_is_initialized.return_value = True
mock_dist_get_rank.return_value = rank
mock_dist_get_world_size.return_value = world_size

full_result = list(iter(dp))
with soma_experiment.axis_query(measurement_name="RNA") as query:
dp = ExperimentAxisQueryIterable(
query,
X_name="raw",
obs_column_names=["soma_joinid"],
io_batch_size=2,
)

soma_joinids = np.concatenate(
[t[1]["soma_joinid"].to_numpy() for t in full_result]
)
batches = list(iter(dp))

expected_joinids = np.array_split(np.arange(obs_range), world_size)[rank][
0 : obs_range // world_size
]
expected_joinids = np.array_split(expected_joinids, num_workers)[worker_id]
assert sorted(soma_joinids) == expected_joinids.tolist()
soma_joinids = np.concatenate(
[batch[1]["soma_joinid"].to_numpy() for batch in batches]
).tolist()

assert soma_joinids == expected_joinids


def test_batched() -> None:
Expand Down

0 comments on commit cdf78af

Please sign in to comment.