
Commit

Add barrier for cross device synchronization
rsuderman committed Jan 15, 2025
1 parent e34ffec commit cd01543
Showing 3 changed files with 30 additions and 4 deletions.
7 changes: 7 additions & 0 deletions sharktank/sharktank/ops/default_impls.py
@@ -444,6 +444,13 @@ def transfer_to_logical_device_default(tensor: Tensor, ordinal: int):
     )
 
 
+@barrier_on_logical_device.override(Tensor)
+def barrier_on_device_default(tensor: Tensor, ordinal: int):
+    return iree.turbine.ops.iree.barrier_on_logical_device(
+        f"{ordinal}", unbox_tensor(tensor)
+    )
+
+
 @transpose.override(Tensor)
 def transpose_default(
     tensor: Union[Tensor, PrimitiveTensor], dim0: int, dim1: int
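
For orientation, a minimal usage sketch of the new default override. It assumes the op is re-exported from sharktank.ops like the other signature ops and that, as with transfer_to_logical_device, the underlying turbine op acts as a pass-through marker in eager mode and only becomes a real synchronization point under export/tracing; the import path and behavior are assumptions, not part of this commit.

# Hypothetical usage sketch (not part of this commit); import path and eager
# pass-through behavior are assumptions.
import torch
from sharktank import ops

x = torch.randn(4, 4)
# Dispatches through the new signature op to the Tensor override above, which
# forwards to iree.turbine.ops.iree.barrier_on_logical_device("0", x).
y = ops.barrier_on_logical_device(x, 0)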
8 changes: 4 additions & 4 deletions sharktank/sharktank/ops/sharded_impls.py
@@ -42,7 +42,7 @@ def all_gather_split(
     shards = [
         cat(
             [
-                shard if i == j else transfer_to_logical_device(shard, i)
+                barrier_on_logical_device(shard, i) if i == j else transfer_to_logical_device(shard, i)
                 for j, shard in enumerate(input.shards)
             ],
             dim=dim,
@@ -63,7 +63,7 @@ def all_reduce_split_or_unreduced(
         functools.reduce(
             lambda x, y: elementwise(torch.add, x, y),
             [
-                shard if i == j else transfer_to_logical_device(shard, i)
+                barrier_on_logical_device(shard, i) if i == j else transfer_to_logical_device(shard, i)
                 for j, shard in enumerate(input.shards)
             ],
         )
@@ -1085,7 +1085,7 @@ def reshard_like_unreduced_to_replicated(
 @sharded_cat.override(SplitPrimitiveTensor)
 def sharded_cat_unsharded(tensor: SplitPrimitiveTensor):
     shard_ts = [
-        transfer_to_logical_device(shard.as_torch(), 0) if i != 0 else shard.as_torch()
+        transfer_to_logical_device(shard.as_torch(), 0) if i != 0 else barrier_on_logical_device(shard.as_torch(), 0)
         for i, shard in enumerate(tensor.shards)
     ]
     return torch.cat(shard_ts, dim=tensor.shard_dim)
@@ -1177,7 +1177,7 @@ def unshard_split(input: SplitPrimitiveTensor) -> Tensor:
 def unshard_unreduced(input: UnreducedTensor) -> Tensor:
     shards = input.shards
     shards = [
-        shard if i == 0 else transfer_to_logical_device(shard, 0)
+        barrier_on_logical_device(shard, i) if i == 0 else transfer_to_logical_device(shard, 0)
         for i, shard in enumerate(shards)
     ]
     return functools.reduce(lambda x, y: elementwise(torch.add, x, y), shards)
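
The pattern repeated in these hunks: when a result is materialized on device i, shards coming from other devices go through transfer_to_logical_device, while the shard already resident on device i now gets barrier_on_logical_device instead of being passed through untouched, so every operand carries an explicit synchronization point. A stand-alone sketch of that selection logic, with hypothetical names (shards, device_index) and the two ops assumed in scope as in sharded_impls.py:

# Illustrative sketch only; mirrors the comprehension used in all_gather_split
# and all_reduce_split_or_unreduced above.
def place_shards_on_device(shards, device_index):
    # Local shard: barrier. Remote shards: explicit cross-device transfer.
    return [
        barrier_on_logical_device(shard, device_index)
        if j == device_index
        else transfer_to_logical_device(shard, device_index)
        for j, shard in enumerate(shards)
    ]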
19 changes: 19 additions & 0 deletions sharktank/sharktank/ops/signatures.py
@@ -20,6 +20,7 @@
 __all__ = [
     "all_gather",
     "all_reduce",
+    "barrier_on_logical_device",
     "cat",
     "conv2d",
     "einsum_2args",
@@ -1012,6 +1013,24 @@ def _to_trampoline(d: SignatureDispatcher, tensor: AnyTensor, *args, **kwargs):
         d.fail(dispatch_args)
 
 
+@overridable
+def barrier_on_logical_device(tensor: AnyTensor, ordinal: int) -> AnyTensor:
"""Transfer the tensor to a device with ordinal `ordinal`."""
+    ...
+
+
+@barrier_on_logical_device.trampoline
+def _barrier_on_logical_device_trampoline(
+    d: SignatureDispatcher, tensor: AnyTensor, ordinal: int
+):
+    tensors = (tensor,)
+    for override in d.find_overrides(tensors):
+        result = override(tensor, ordinal)
+        if result is not NotImplemented:
+            return override, result
+    else:
+        d.fail(tensors)
+
 @overridable
 def transfer_to_logical_device(tensor: AnyTensor, ordinal: int) -> AnyTensor:
     """Transfer the tensor to a device with ordinal `ordinal`."""
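
For readers unfamiliar with the @overridable/trampoline machinery, a toy, self-contained model of the dispatch flow the new trampoline relies on. This is an assumption-level sketch, not sharktank's actual SignatureDispatcher; every name below is invented for illustration.

# Toy model of the override/trampoline dispatch pattern (illustrative only).
class ToyDispatcher:
    def __init__(self):
        self._overrides = []  # (type, callable) pairs in registration order

    def override(self, t):
        def decorator(fn):
            self._overrides.append((t, fn))
            return fn
        return decorator

    def find_overrides(self, tensors):
        # Overrides whose registered type matches the first argument.
        return [fn for t, fn in self._overrides if isinstance(tensors[0], t)]

    def fail(self, tensors):
        raise NotImplementedError(f"No override for {type(tensors[0]).__name__}")


barrier_dispatch = ToyDispatcher()


def barrier_on_logical_device(tensor, ordinal: int):
    # Mirrors the trampoline above: try each override, accept the first result
    # that is not NotImplemented; otherwise report a dispatch failure.
    tensors = (tensor,)
    for override in barrier_dispatch.find_overrides(tensors):
        result = override(tensor, ordinal)
        if result is not NotImplemented:
            return result
    else:
        barrier_dispatch.fail(tensors)


@barrier_dispatch.override(float)
def _barrier_float(tensor, ordinal):
    return tensor  # toy behavior: the barrier is a no-op marker here


print(barrier_on_logical_device(1.0, 0))  # -> 1.0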
