Add barrier for cross device synchronization #830 #851

Merged
merged 12 commits on Jan 21, 2025
2 changes: 1 addition & 1 deletion requirements-iree-pinned.txt
@@ -4,4 +4,4 @@
--find-links https://iree.dev/pip-release-links.html
iree-base-compiler==3.2.0rc20250120
iree-base-runtime==3.2.0rc20250120
iree-turbine==3.2.0rc20250119
iree-turbine==3.2.0rc20250121
8 changes: 8 additions & 0 deletions sharktank/sharktank/ops/default_impls.py
@@ -68,6 +68,7 @@ def conv2d_default(
conv2d.override(Tensor, Tensor, Tensor, auto_dequant=True)(conv2d_default)
conv2d.override(Tensor, Tensor, auto_dequant=True)(conv2d_default)


# Einsum
def mk_menk_men(inputs, weights):
# batch dims: m, lhs pdims: none, lhs rdims: k, rhs pdims: en, rhs rdims: k
@@ -443,6 +444,13 @@ def transfer_to_logical_device_default(tensor: Tensor, ordinal: int):
)


@barrier_on_logical_device.override(Tensor)
def barrier_on_device_default(tensor: Tensor, ordinal: int):
return iree.turbine.ops.iree.barrier_on_logical_device(
f"{ordinal}", unbox_tensor(tensor)
)


@transpose.override(Tensor)
def transpose_default(
tensor: Union[Tensor, PrimitiveTensor], dim0: int, dim1: int
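Note (not part of the diff): a minimal usage sketch of how the new Tensor override is reached, assuming `sharktank.ops` re-exports the `barrier_on_logical_device` signature the same way the existing ops are:

import torch
from sharktank import ops

t = torch.ones(4, 4)
# The trampoline in signatures.py dispatches this call to the Tensor override
# above, which forwards to iree.turbine.ops.iree.barrier_on_logical_device
# with the device moniker "0".
t0 = ops.barrier_on_logical_device(t, 0)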
24 changes: 20 additions & 4 deletions sharktank/sharktank/ops/sharded_impls.py
@@ -42,7 +42,11 @@ def all_gather_split(
shards = [
cat(
[
shard if i == j else transfer_to_logical_device(shard, i)
(
barrier_on_logical_device(shard, i)
if i == j
else transfer_to_logical_device(shard, i)
)
for j, shard in enumerate(input.shards)
],
dim=dim,
@@ -63,7 +67,11 @@ def all_reduce_split_or_unreduced(
functools.reduce(
lambda x, y: elementwise(torch.add, x, y),
[
shard if i == j else transfer_to_logical_device(shard, i)
(
barrier_on_logical_device(shard, i)
if i == j
else transfer_to_logical_device(shard, i)
)
for j, shard in enumerate(input.shards)
],
)
@@ -1090,7 +1098,11 @@ def reshard_like_unreduced_to_replicated(
@sharded_cat.override(SplitPrimitiveTensor)
def sharded_cat_unsharded(tensor: SplitPrimitiveTensor):
shard_ts = [
transfer_to_logical_device(shard.as_torch(), 0) if i != 0 else shard.as_torch()
(
transfer_to_logical_device(shard.as_torch(), 0)
if i != 0
else barrier_on_logical_device(shard.as_torch(), 0)
)
for i, shard in enumerate(tensor.shards)
]
return torch.cat(shard_ts, dim=tensor.shard_dim)
@@ -1182,7 +1194,11 @@ def unshard_split(input: SplitPrimitiveTensor) -> Tensor:
def unshard_unreduced(input: UnreducedTensor) -> Tensor:
shards = input.shards
shards = [
shard if i == 0 else transfer_to_logical_device(shard, 0)
(
barrier_on_logical_device(shard, i)
if i == 0
else transfer_to_logical_device(shard, 0)
)
for i, shard in enumerate(shards)
]
return functools.reduce(lambda x, y: elementwise(torch.add, x, y), shards)
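Note (not part of the diff): the hunks above all apply the same pattern. A shard already resident on the target device gets a barrier while every other shard gets a device-to-device transfer, so each input to the following cat/reduce carries an explicit synchronization point. A condensed sketch of that pattern, with a hypothetical helper name, `shards`/`target`/`dim` as stand-ins, and the assumption that `sharktank.ops` re-exports these signatures:

from sharktank.ops import barrier_on_logical_device, cat, transfer_to_logical_device

def gather_shards_to(shards, target: int, dim: int):
    # Local shard: barrier only. Remote shards: transfer to the target device.
    staged = [
        (
            barrier_on_logical_device(shard, target)
            if i == target
            else transfer_to_logical_device(shard, target)
        )
        for i, shard in enumerate(shards)
    ]
    return cat(staged, dim=dim)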
20 changes: 20 additions & 0 deletions sharktank/sharktank/ops/signatures.py
@@ -20,6 +20,7 @@
__all__ = [
"all_gather",
"all_reduce",
"barrier_on_logical_device",
"cat",
"conv2d",
"einsum_2args",
@@ -1025,6 +1026,25 @@ def _to_trampoline(d: SignatureDispatcher, tensor: AnyTensor, *args, **kwargs):
d.fail(dispatch_args)


@overridable
def barrier_on_logical_device(tensor: AnyTensor, ordinal: int) -> AnyTensor:
"""Transfer the tensor to a device with ordinal `ordinal`."""
...


@barrier_on_logical_device.trampoline
def _barrier_on_logical_device_trampoline(
d: SignatureDispatcher, tensor: AnyTensor, ordinal: int
):
tensors = (tensor,)
for override in d.find_overrides(tensors):
result = override(tensor, ordinal)
if result is not NotImplemented:
return override, result
else:
d.fail(tensors)


@overridable
def transfer_to_logical_device(tensor: AnyTensor, ordinal: int) -> AnyTensor:
"""Transfer the tensor to a device with ordinal `ordinal`."""
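Note (not part of the diff): the new signature uses the same overridable/trampoline machinery as `transfer_to_logical_device`, so additional tensor types can register their own implementations. A hedged sketch, where `MyTensorType` and its `.raw` attribute are hypothetical and the surrounding signatures.py names are assumed to be in scope:

@barrier_on_logical_device.override(MyTensorType)  # MyTensorType is hypothetical
def barrier_on_device_my_type(tensor, ordinal: int):
    # Unwrap to the underlying torch.Tensor (hypothetical .raw attribute) and
    # re-dispatch; returning NotImplemented here would instead make the
    # trampoline continue to the next registered override.
    return barrier_on_logical_device(tensor.raw, ordinal)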