
Commit f74680b

Add support for jax.sharding.PartitionSpec.UNCONSTRAINED in logical specification

PiperOrigin-RevId: 630547382
marksandler2 authored and Flax Authors committed May 7, 2024
1 parent 3d98eda commit f74680b
Showing 2 changed files with 60 additions and 38 deletions.
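
What the change enables, as a hedged sketch (the 'batch'→'data' rule and the array shape are illustrative, not from this commit; in real use this runs under a jit/mesh context): UNCONSTRAINED may now appear directly in a logical axis spec and is forwarded untouched into the resulting jax.sharding.PartitionSpec, leaving that dimension's sharding to the compiler. It mirrors the new test case added below.

import jax
import jax.numpy as jnp
from flax.linen import partitioning

UNCONSTRAINED = jax.sharding.PartitionSpec.UNCONSTRAINED

# Map the logical 'batch' axis to the 'data' mesh axis; leave the second
# dimension unconstrained so the compiler picks its sharding.
with partitioning.axis_rules((('batch', 'data'),)):
  y = partitioning.with_sharding_constraint(
      jnp.ones((8, 16)), ('batch', UNCONSTRAINED)
  )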
12 changes: 10 additions & 2 deletions flax/linen/spmd.py
@@ -30,6 +30,7 @@
import enum
import functools
import threading
+import types
from typing import Any, Callable, List, Optional, Sequence, Tuple, Union

import jax
@@ -127,7 +128,10 @@ def _logical_to_mesh_axes(
    raise ValueError('Unknown axis rule specification type.')
  # We assign mesh axes using a priority based ruleset over logical axis names.
  result: List[Union[_UnassignedAxis, None, str, Tuple[str, ...]]]
-  result = [_unassigned_axis] * len(array_dim_names)
+  result = [
+      (_unassigned_axis if isinstance(name, str) else name)
+      for name in array_dim_names
+  ]
  for rule_model_name, rule_mesh_names in rules:
    if rule_model_name in array_dim_names:
      pos = array_dim_names.index(rule_model_name)
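
The rewritten comprehension above passes non-string entries (None, UNCONSTRAINED) through to the result instead of treating them as unassigned axes. A minimal sketch of the effect through the public logical_to_mesh_axes API, reusing the rule names from the new test case added below:

import jax
from flax.linen import partitioning

u = jax.sharding.PartitionSpec.UNCONSTRAINED
with partitioning.axis_rules((('baz', 'data'), ('foo', ('model', 'emb')))):
  spec = partitioning.logical_to_mesh_axes(('baz', u, 'foo'))
  # spec == PartitionSpec('data', UNCONSTRAINED, ('model', 'emb'))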
@@ -263,10 +267,14 @@ def _with_sharding_constraint_one_fallback(
      x, jax.sharding.PartitionSpec(*mesh_axes), mesh=mesh
  )

+_AxisTypes = (
+    str, types.NoneType, type(jax.sharding.PartitionSpec.UNCONSTRAINED)
+)
+

def _is_logical_spec(x):
  return x is None or (
-      isinstance(x, tuple) and all(isinstance(e, str) or e is None for e in x)
+      isinstance(x, tuple) and all(isinstance(e, _AxisTypes) for e in x)
  )


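With _AxisTypes widened, _is_logical_spec now accepts tuples that mix axis names, None, and UNCONSTRAINED. A small illustration (flax.linen.spmd is a private module, so this is a sketch against the code shown above, not a stable API):

import jax
from flax.linen import spmd

u = jax.sharding.PartitionSpec.UNCONSTRAINED
assert spmd._is_logical_spec(('batch', None, u))  # newly accepted
assert spmd._is_logical_spec(('batch', 'embed'))  # accepted as before
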
86 changes: 50 additions & 36 deletions tests/linen/partitioning_test.py
@@ -92,42 +92,51 @@ def test_logical_to_mesh_axes_priorities(self):
    )

  @parameterized.parameters(
-      dict(
-          rules=(('a', ('model', 'data')), ('b', 'data')),
-          axes=('a', 'b'),
-          expected=(('model', 'data'), None),
-      ),
-      dict(
-          rules=(('a', ('model', 'replica')), ('b', 'data')),
-          axes=('a', 'b'),
-          expected=(('model', 'replica'), 'data'),
-      ),
-      dict(
-          rules=(('a', ('model', 'replica')), ('b', ('data', 'model'))),
-          axes=('a', 'b'),
-          expected=(('model', 'replica'), None),
-      ),
-      dict(
-          rules=(('a', ('model', 'replica')), ('b', 'model')),
-          axes=('a', 'b', 'c'),
-          expected=(('model', 'replica'), None, None),
-      ),
-      dict(rules=(), axes=('a', 'b', 'c'), expected=(None, None, None)),
-      dict(
-          rules=(('a', None), ('a', 'model')),
-          axes=('a', 'b'),
-          expected=(None, None),
-      ),
-      dict(
-          rules=(
-              ('baz', 'data'),
-              ('bar', None),
-              ('foo', 'model'),
-              ('foo', 'data'),
-          ),
-          axes=('baz', 'bar', 'foo'),
-          expected=('data', None, 'model'),
-      ),
+      dict(
+          rules=(('a', ('model', 'data')), ('b', 'data')),
+          axes=('a', 'b'),
+          expected=(('model', 'data'), None),
+      ),
+      dict(
+          rules=(('a', ('model', 'replica')), ('b', 'data')),
+          axes=('a', 'b'),
+          expected=(('model', 'replica'), 'data'),
+      ),
+      dict(
+          rules=(('a', ('model', 'replica')), ('b', ('data', 'model'))),
+          axes=('a', 'b'),
+          expected=(('model', 'replica'), None),
+      ),
+      dict(
+          rules=(('a', ('model', 'replica')), ('b', 'model')),
+          axes=('a', 'b', 'c'),
+          expected=(('model', 'replica'), None, None),
+      ),
+      dict(rules=(), axes=('a', 'b', 'c'), expected=(None, None, None)),
+      dict(
+          rules=(('a', None), ('a', 'model')),
+          axes=('a', 'b'),
+          expected=(None, None),
+      ),
+      dict(
+          rules=(
+              ('baz', 'data'),
+              ('bar', None),
+              ('foo', 'model'),
+              ('foo', 'data'),
+          ),
+          axes=('baz', 'bar', 'foo'),
+          expected=('data', None, 'model'),
+      ),
+      dict(
+          rules=(('baz', 'data'), ('foo', ('model', 'emb'))),
+          axes=('baz', jax.sharding.PartitionSpec.UNCONSTRAINED, 'foo'),
+          expected=(
+              'data',
+              jax.sharding.PartitionSpec.UNCONSTRAINED,
+              ('model', 'emb'),
+          ),
+      ),
  )
  def test_logical_to_mesh_axes_cases(self, rules, axes, expected):
    with partitioning.axis_rules(rules):
@@ -136,6 +145,7 @@ def test_logical_to_mesh_axes_cases(self, rules, axes, expected):

  @mock.patch('flax.linen.spmd._with_sharding_constraint')
  def test_with_sharding_constraint(self, wsc_fn):
+    unconstrained = jax.sharding.PartitionSpec.UNCONSTRAINED
    arr = jnp.ones((2, 2))
    axes = ('foo', 'bar')
    partitioning.set_axis_rules(())
@@ -146,7 +156,11 @@ def test_with_sharding_constraint(self, wsc_fn):
    wsc_fn.assert_not_called()
    _ = partitioning.with_sharding_constraint(arr, axes)
    wsc_fn.assert_called_with(
        arr, jax.sharding.PartitionSpec('data', 'model'), mesh=None
    )
+    _ = partitioning.with_sharding_constraint(arr, ('foo', unconstrained))
+    wsc_fn.assert_called_with(
+        arr, jax.sharding.PartitionSpec('data', unconstrained), mesh=None
+    )

  @mock.patch('flax.linen.spmd._with_sharding_constraint')
