From b62b88ba540865823ab205896718bb1a9579b85d Mon Sep 17 00:00:00 2001 From: Ata Shaker Date: Thu, 2 Jan 2025 19:42:16 +0100 Subject: [PATCH 1/3] Created alias file for functions in optax.tree_utils and dropped "tree" from them --- optax/tree.py | 46 ++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 46 insertions(+) create mode 100644 optax/tree.py diff --git a/optax/tree.py b/optax/tree.py new file mode 100644 index 000000000..8efe08c04 --- /dev/null +++ b/optax/tree.py @@ -0,0 +1,46 @@ +"""Utilities for working with tree-like container data structures. + +The :mod:`optax.tree` namespace contains aliases of utilities from :mod:`optax.tree_util`. +""" +from optax.tree_utils._casting import ( +tree_cast as cast, +tree_dtype as dtype, +) +from optax.tree_utils._random import ( + tree_random_like as random_like, + tree_split_key_like as split_key_like, +) + +from optax.tree_utils._state_utils import ( + NamedTupleKey as NamedTupleKey, + tree_get as get, + tree_get_all_with_path as get_all_with_path, + tree_map_params as map_params, + tree_set as set, +) +from optax.tree_utils._tree_math import ( + tree_add as add, + tree_add_scalar_mul as add_scalar_mul, + tree_bias_correction as bias_correction, + tree_clip as clip, + tree_conj as conj, + tree_div as div, + tree_full_like as full_like, + tree_l1_norm as l1_norm, + tree_l2_norm as l2_norm, + tree_linf_norm as linf_norm, + tree_max as max, + tree_mul as mul, + tree_ones_like as ones_like, + tree_real as real, + tree_scalar_mul as scalar_mul, + tree_sub as sub, + tree_sum as sum, + tree_update_infinity_moment as update_infinity_moment, + tree_update_moment as update_moment, + tree_update_moment_per_elem_norm as update_moment_per_elem_norm, + tree_vdot as vdot, + tree_where as where, + tree_zeros_like as zeros_like, +) + From 1c4215f2d69bc8e80a98f49d83131e39aaa69e28 Mon Sep 17 00:00:00 2001 From: Ata Shaker Date: Fri, 3 Jan 2025 18:21:05 +0100 Subject: [PATCH 2/3] Move 'optax.tree' to '_src' --- optax/__init__.py | 1 + optax/{ => _src}/tree.py | 8 +++++--- 2 files changed, 6 insertions(+), 3 deletions(-) rename optax/{ => _src}/tree.py (89%) diff --git a/optax/__init__.py b/optax/__init__.py index 39203cdc9..e0e4fb4db 100644 --- a/optax/__init__.py +++ b/optax/__init__.py @@ -142,6 +142,7 @@ from optax._src.utils import multi_normal from optax._src.utils import scale_gradient from optax._src.utils import value_and_grad_from_state +from optax._src import tree # TODO(mtthss): remove contrib aliases from flat namespace once users updated. # Deprecated modules diff --git a/optax/tree.py b/optax/_src/tree.py similarity index 89% rename from optax/tree.py rename to optax/_src/tree.py index 8efe08c04..6807822fe 100644 --- a/optax/tree.py +++ b/optax/_src/tree.py @@ -1,7 +1,10 @@ +# pylint: disable=line-too-long +# pylint: disable=redefined-builtin """Utilities for working with tree-like container data structures. The :mod:`optax.tree` namespace contains aliases of utilities from :mod:`optax.tree_util`. """ +# pylint: disable=unused-import from optax.tree_utils._casting import ( tree_cast as cast, tree_dtype as dtype, @@ -12,7 +15,7 @@ ) from optax.tree_utils._state_utils import ( - NamedTupleKey as NamedTupleKey, + NamedTupleKey, tree_get as get, tree_get_all_with_path as get_all_with_path, tree_map_params as map_params, @@ -41,6 +44,5 @@ tree_update_moment_per_elem_norm as update_moment_per_elem_norm, tree_vdot as vdot, tree_where as where, - tree_zeros_like as zeros_like, + tree_zeros_like as zeros_like ) - From 7887400c5b0f4422b475f2debb5cc9f86eccc797 Mon Sep 17 00:00:00 2001 From: Ata Shaker Date: Thu, 9 Jan 2025 01:32:48 +0100 Subject: [PATCH 3/3] Reformat the aliases as seperate imports --- optax/_src/tree.py | 83 +++++++++++++++++++++------------------------- 1 file changed, 38 insertions(+), 45 deletions(-) diff --git a/optax/_src/tree.py b/optax/_src/tree.py index 6807822fe..d434bd1c2 100644 --- a/optax/_src/tree.py +++ b/optax/_src/tree.py @@ -1,48 +1,41 @@ -# pylint: disable=line-too-long -# pylint: disable=redefined-builtin -"""Utilities for working with tree-like container data structures. +"""Utilities for working with tree-like container data structures. The +:mod:`optax.tree` namespace contains aliases of utilities from +:mod:`optax.tree_util`.""" -The :mod:`optax.tree` namespace contains aliases of utilities from :mod:`optax.tree_util`. -""" # pylint: disable=unused-import -from optax.tree_utils._casting import ( -tree_cast as cast, -tree_dtype as dtype, -) -from optax.tree_utils._random import ( - tree_random_like as random_like, - tree_split_key_like as split_key_like, -) +from optax.tree_utils._casting import tree_cast as cast +from optax.tree_utils._casting import tree_dtype as dtype + +from optax.tree_utils._random import tree_random_like as random_like +from optax.tree_utils._random import tree_split_key_like as split_key_like + +from optax.tree_utils._state_utils import NamedTupleKey +from optax.tree_utils._state_utils import tree_get as get +from optax.tree_utils._state_utils import tree_get_all_with_path as get_all_with_path +from optax.tree_utils._state_utils import tree_map_params as map_params +from optax.tree_utils._state_utils import tree_set as set # pylint: disable=redefined-builtin + +from optax.tree_utils._tree_math import tree_add as add +from optax.tree_utils._tree_math import tree_add_scalar_mul as add_scalar_mul +from optax.tree_utils._tree_math import tree_bias_correction as bias_correction +from optax.tree_utils._tree_math import tree_clip as clip +from optax.tree_utils._tree_math import tree_conj as conj +from optax.tree_utils._tree_math import tree_div as div +from optax.tree_utils._tree_math import tree_full_like as full_like +from optax.tree_utils._tree_math import tree_l1_norm as l1_norm +from optax.tree_utils._tree_math import tree_l2_norm as l2_norm +from optax.tree_utils._tree_math import tree_linf_norm as linf_norm +from optax.tree_utils._tree_math import tree_max as max # pylint: disable=redefined-builtin +from optax.tree_utils._tree_math import tree_mul as mul +from optax.tree_utils._tree_math import tree_ones_like as ones_like +from optax.tree_utils._tree_math import tree_real as real +from optax.tree_utils._tree_math import tree_scalar_mul as scalar_mul +from optax.tree_utils._tree_math import tree_sub as sub +from optax.tree_utils._tree_math import tree_sum as sum # pylint: disable=redefined-builtin +from optax.tree_utils._tree_math import tree_update_infinity_moment as update_infinity_moment +from optax.tree_utils._tree_math import tree_update_moment as update_moment +from optax.tree_utils._tree_math import tree_update_moment_per_elem_norm as update_moment_per_elem_norm +from optax.tree_utils._tree_math import tree_vdot as vdot +from optax.tree_utils._tree_math import tree_where as where +from optax.tree_utils._tree_math import tree_zeros_like as zeros_like -from optax.tree_utils._state_utils import ( - NamedTupleKey, - tree_get as get, - tree_get_all_with_path as get_all_with_path, - tree_map_params as map_params, - tree_set as set, -) -from optax.tree_utils._tree_math import ( - tree_add as add, - tree_add_scalar_mul as add_scalar_mul, - tree_bias_correction as bias_correction, - tree_clip as clip, - tree_conj as conj, - tree_div as div, - tree_full_like as full_like, - tree_l1_norm as l1_norm, - tree_l2_norm as l2_norm, - tree_linf_norm as linf_norm, - tree_max as max, - tree_mul as mul, - tree_ones_like as ones_like, - tree_real as real, - tree_scalar_mul as scalar_mul, - tree_sub as sub, - tree_sum as sum, - tree_update_infinity_moment as update_infinity_moment, - tree_update_moment as update_moment, - tree_update_moment_per_elem_norm as update_moment_per_elem_norm, - tree_vdot as vdot, - tree_where as where, - tree_zeros_like as zeros_like -)