From f74680b3aef94f92735fa0a858788211dba8c614 Mon Sep 17 00:00:00 2001
From: Mark Sandler
Date: Fri, 3 May 2024 18:12:32 -0700
Subject: [PATCH] Add support for `jax.sharding.PartitionSpec.UNCONSTRAINED` in
 logical specification

PiperOrigin-RevId: 630547382
---
 flax/linen/spmd.py               | 12 ++++-
 tests/linen/partitioning_test.py | 86 +++++++++++++++++++-------------
 2 files changed, 60 insertions(+), 38 deletions(-)

diff --git a/flax/linen/spmd.py b/flax/linen/spmd.py
index 1610d7f87f..c2f5c16c59 100644
--- a/flax/linen/spmd.py
+++ b/flax/linen/spmd.py
@@ -30,6 +30,7 @@
 import enum
 import functools
 import threading
+import types
 from typing import Any, Callable, List, Optional, Sequence, Tuple, Union
 
 import jax
@@ -127,7 +128,10 @@ def _logical_to_mesh_axes(
     raise ValueError('Unknown axis rule specification type.')
   # We assign mesh axes using a priority based ruleset over logical axis names.
   result: List[Union[_UnassignedAxis, None, str, Tuple[str, ...]]]
-  result = [_unassigned_axis] * len(array_dim_names)
+  result = [
+      (_unassigned_axis if isinstance(name, str) else name)
+      for name in array_dim_names
+  ]
   for rule_model_name, rule_mesh_names in rules:
     if rule_model_name in array_dim_names:
       pos = array_dim_names.index(rule_model_name)
@@ -263,10 +267,14 @@ def _with_sharding_constraint_one_fallback(
       x, jax.sharding.PartitionSpec(*mesh_axes), mesh=mesh
   )
 
+_AxisTypes = (
+    str, types.NoneType, type(jax.sharding.PartitionSpec.UNCONSTRAINED)
+)
+
 
 def _is_logical_spec(x):
   return x is None or (
-      isinstance(x, tuple) and all(isinstance(e, str) or e is None for e in x)
+      isinstance(x, tuple) and all(isinstance(e, _AxisTypes) for e in x)
   )
diff --git a/tests/linen/partitioning_test.py b/tests/linen/partitioning_test.py
index d3a1e93ab6..1aa9d38817 100644
--- a/tests/linen/partitioning_test.py
+++ b/tests/linen/partitioning_test.py
@@ -92,42 +92,51 @@ def test_logical_to_mesh_axes_priorities(self):
     )
 
   @parameterized.parameters(
       dict(
           rules=(('a', ('model', 'data')), ('b', 'data')),
           axes=('a', 'b'),
           expected=(('model', 'data'), None),
       ),
       dict(
           rules=(('a', ('model', 'replica')), ('b', 'data')),
           axes=('a', 'b'),
           expected=(('model', 'replica'), 'data'),
       ),
       dict(
           rules=(('a', ('model', 'replica')), ('b', ('data', 'model'))),
           axes=('a', 'b'),
           expected=(('model', 'replica'), None),
       ),
       dict(
           rules=(('a', ('model', 'replica')), ('b', 'model')),
           axes=('a', 'b', 'c'),
           expected=(('model', 'replica'), None, None),
       ),
       dict(rules=(), axes=('a', 'b', 'c'), expected=(None, None, None)),
       dict(
           rules=(('a', None), ('a', 'model')),
           axes=('a', 'b'),
           expected=(None, None),
       ),
       dict(
           rules=(
               ('baz', 'data'),
               ('bar', None),
               ('foo', 'model'),
               ('foo', 'data'),
+          ),
+          axes=('baz', 'bar', 'foo'),
+          expected=('data', None, 'model'),
+      ),
+      dict(
+          rules=(('baz', 'data'), ('foo', ('model', 'emb'))),
+          axes=('baz', jax.sharding.PartitionSpec.UNCONSTRAINED, 'foo'),
+          expected=(
+              'data',
+              jax.sharding.PartitionSpec.UNCONSTRAINED,
+              ('model', 'emb'),
+          ),
       ),
-      axes=('baz', 'bar', 'foo'),
-      expected=('data', None, 'model'),
-      ),
   )
   def test_logical_to_mesh_axes_cases(self, rules, axes, expected):
     with partitioning.axis_rules(rules):
@@ -136,6 +145,7 @@ def test_logical_to_mesh_axes_cases(self, rules, axes, expected):
 
   @mock.patch('flax.linen.spmd._with_sharding_constraint')
   def test_with_sharding_constraint(self, wsc_fn):
+    unconstrained = jax.sharding.PartitionSpec.UNCONSTRAINED
     arr = jnp.ones((2, 2))
     axes = ('foo', 'bar')
     partitioning.set_axis_rules(())
@@ -146,7 +156,11 @@ def test_with_sharding_constraint(self, wsc_fn):
       wsc_fn.assert_not_called()
       _ = partitioning.with_sharding_constraint(arr, axes)
       wsc_fn.assert_called_with(
-          arr, jax.sharding.PartitionSpec('data', 'model'), mesh=None
+          arr, jax.sharding.PartitionSpec('data', 'model'), mesh=None
+      )
+      _ = partitioning.with_sharding_constraint(arr, ('foo', unconstrained))
+      wsc_fn.assert_called_with(
+          arr, jax.sharding.PartitionSpec('data', unconstrained), mesh=None
       )
 
   @mock.patch('flax.linen.spmd._with_sharding_constraint')
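
Note (not part of the patch): a minimal usage sketch of what the change enables, mirroring the new
test case above. The rule names 'batch' and 'embed' and the mesh axis names 'data' and 'model'
below are illustrative assumptions; partitioning.axis_rules and partitioning.logical_to_mesh_axes
are the same helpers the updated tests call.

import jax
from flax.linen import partitioning

UNCONSTRAINED = jax.sharding.PartitionSpec.UNCONSTRAINED

# Hypothetical logical-to-mesh axis rules for this sketch.
rules = (('batch', 'data'), ('embed', 'model'))
with partitioning.axis_rules(rules):
  # Named logical axes resolve through the rules; the UNCONSTRAINED entry is
  # carried through to the resulting PartitionSpec untouched.
  spec = partitioning.logical_to_mesh_axes(('batch', UNCONSTRAINED, 'embed'))
  assert spec == jax.sharding.PartitionSpec('data', UNCONSTRAINED, 'model')

The same mixed logical spec is also accepted by partitioning.with_sharding_constraint, which is
what the extended test_with_sharding_constraint verifies against the mocked
_with_sharding_constraint call.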