Renaming functions in tree_utils #1163

Open
Ata-Shaker opened this issue Dec 29, 2024 · 6 comments

@Ata-Shaker

Hi, I was curious whether there's a specific reason why all the functions in tree_utils start with "tree_". Are there any plans to simplify the naming, similar to what JAX has done with jax.tree?
Thanks,
Ata

@rdyro
Collaborator

rdyro commented Jan 2, 2025

Adding the tree_ prefix makes the function names a little more explicit. JAX's jax.tree module is just an alias module.

We currently don't have plans to support this, but if you'd like, we'd love a PR! We would need a file similar to https://github.com/jax-ml/jax/blob/main/jax/tree.py
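To make the "alias module" idea concrete: jax/tree.py spells out each short name as an explicit import. The sketch below takes a different, programmatic route instead, copying every tree_-prefixed function of a module onto a new module under its prefix-free name. The tree_utils module here is a tiny hypothetical stand-in (toy functions over flat dicts), not optax, and make_alias_module is an illustrative helper, not an existing API.

```python
import types


def _tree_add(x, y):
    # Toy "tree" addition over flat dicts (stand-in for a real pytree op).
    return {k: x[k] + y[k] for k in x}


def _tree_sum(x):
    # Toy reduction; the builtin sum is not shadowed in this file.
    return sum(x.values())


# Hypothetical stand-in for a module of tree_-prefixed utilities.
tree_utils = types.ModuleType("tree_utils")
tree_utils.tree_add = _tree_add
tree_utils.tree_sum = _tree_sum


def make_alias_module(source, prefix="tree_", name="tree"):
    """Build a module exposing source's `prefix*` callables without the prefix."""
    alias = types.ModuleType(name)
    for attr, value in vars(source).items():
        if attr.startswith(prefix) and callable(value):
            # Same function object, bound under a second (shorter) name.
            setattr(alias, attr[len(prefix):], value)
    return alias


tree = make_alias_module(tree_utils)
print(tree.add({"a": 1}, {"a": 2}))     # {'a': 3}
print(tree.sum({"a": 1, "b": 2}))       # 3
print(tree.add is tree_utils.tree_add)  # True
```

Because the aliases are read as attributes (tree.sum), calling code keeps the builtin sum. Pylint's redefined-builtin warning only arises when a module binds a name like sum at its own top level, as an explicit "import ... as sum" line does.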

Would you be interested in contributing?

@rdyro rdyro self-assigned this Jan 2, 2025
@Ata-Shaker
Author

Sure! But I am using a Mac. Would that be a problem?

@rdyro
Collaborator

rdyro commented Jan 2, 2025

Not a problem at all

The usual workflow works fine: fork the repo, make your changes in your own branch, and then open a pull request on GitHub.

See https://github.com/google-deepmind/optax/blob/main/CONTRIBUTING.md for how to sign the Contributor License Agreement.

Let me know if you have any questions!

@Ata-Shaker
Author

Ata-Shaker commented Jan 2, 2025

I created the file "tree.py" just like in JAX and renamed the functions. However, I am not sure where to put the file. Is it OK to just leave it at "./optax/tree.py"?
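After renaming, one way to double-check that every tree_-prefixed function received a short alias is a small comparison like the one below. missing_aliases is a hypothetical helper, and both modules are stand-ins built with types.ModuleType rather than the real optax modules; names without the tree_ prefix are not covered by this check.

```python
import types


def missing_aliases(source, alias, prefix="tree_"):
    """List `prefix*` names in `source` with no matching short name in `alias`.

    A name also counts as missing if the short name exists in `alias` but is
    bound to a different object than the original function.
    """
    missing = []
    for name in sorted(vars(source)):
        if not name.startswith(prefix):
            continue
        short = name[len(prefix):]
        if getattr(alias, short, None) is not getattr(source, name):
            missing.append(name)
    return missing


# Hypothetical stand-ins for the original utilities and the new alias module.
tree_utils = types.ModuleType("tree_utils")
tree_utils.tree_add = lambda x, y: x + y
tree_utils.tree_sub = lambda x, y: x - y

tree = types.ModuleType("tree")
tree.add = tree_utils.tree_add  # tree_sub deliberately left without an alias

print(missing_aliases(tree_utils, tree))  # ['tree_sub']
```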
Also, should I be worried about the following message I received after running test.sh:
"************* Module optax.tree
optax/tree.py:3:0: C0301: Line too long (90/80) (line-too-long)
optax/tree.py:14:0: W0622: Redefining built-in 'set' (redefined-builtin)
optax/tree.py:21:0: W0622: Redefining built-in 'max' (redefined-builtin)
optax/tree.py:21:0: W0622: Redefining built-in 'sum' (redefined-builtin)
optax/tree.py:14:0: C0414: Import alias does not rename original package (useless-import-alias)
optax/tree.py:5:0: W0611: Unused tree_cast imported from optax.tree_utils._casting as cast (unused-import)
optax/tree.py:5:0: W0611: Unused tree_dtype imported from optax.tree_utils._casting as dtype (unused-import)
optax/tree.py:9:0: W0611: Unused tree_random_like imported from optax.tree_utils._random as random_like (unused-import)
optax/tree.py:9:0: W0611: Unused tree_split_key_like imported from optax.tree_utils._random as split_key_like (unused-import)
optax/tree.py:14:0: W0611: Unused NamedTupleKey imported from optax.tree_utils._state_utils as NamedTupleKey (unused-import)
optax/tree.py:14:0: W0611: Unused tree_get imported from optax.tree_utils._state_utils as get (unused-import)
optax/tree.py:14:0: W0611: Unused tree_get_all_with_path imported from optax.tree_utils._state_utils as get_all_with_path (unused-import)
optax/tree.py:14:0: W0611: Unused tree_map_params imported from optax.tree_utils._state_utils as map_params (unused-import)
optax/tree.py:14:0: W0611: Unused tree_set imported from optax.tree_utils._state_utils as set (unused-import)
optax/tree.py:21:0: W0611: Unused tree_add imported from optax.tree_utils._tree_math as add (unused-import)
optax/tree.py:21:0: W0611: Unused tree_add_scalar_mul imported from optax.tree_utils._tree_math as add_scalar_mul (unused-import)
optax/tree.py:21:0: W0611: Unused tree_bias_correction imported from optax.tree_utils._tree_math as bias_correction (unused-import)
optax/tree.py:21:0: W0611: Unused tree_clip imported from optax.tree_utils._tree_math as clip (unused-import)
optax/tree.py:21:0: W0611: Unused tree_conj imported from optax.tree_utils._tree_math as conj (unused-import)
optax/tree.py:21:0: W0611: Unused tree_div imported from optax.tree_utils._tree_math as div (unused-import)
optax/tree.py:21:0: W0611: Unused tree_full_like imported from optax.tree_utils._tree_math as full_like (unused-import)
optax/tree.py:21:0: W0611: Unused tree_l1_norm imported from optax.tree_utils._tree_math as l1_norm (unused-import)
optax/tree.py:21:0: W0611: Unused tree_l2_norm imported from optax.tree_utils._tree_math as l2_norm (unused-import)
optax/tree.py:21:0: W0611: Unused tree_linf_norm imported from optax.tree_utils._tree_math as linf_norm (unused-import)
optax/tree.py:21:0: W0611: Unused tree_max imported from optax.tree_utils._tree_math as max (unused-import)
optax/tree.py:21:0: W0611: Unused tree_mul imported from optax.tree_utils._tree_math as mul (unused-import)
optax/tree.py:21:0: W0611: Unused tree_ones_like imported from optax.tree_utils._tree_math as ones_like (unused-import)
optax/tree.py:21:0: W0611: Unused tree_real imported from optax.tree_utils._tree_math as real (unused-import)
optax/tree.py:21:0: W0611: Unused tree_scalar_mul imported from optax.tree_utils._tree_math as scalar_mul (unused-import)
optax/tree.py:21:0: W0611: Unused tree_sub imported from optax.tree_utils._tree_math as sub (unused-import)
optax/tree.py:21:0: W0611: Unused tree_sum imported from optax.tree_utils._tree_math as sum (unused-import)
optax/tree.py:21:0: W0611: Unused tree_update_infinity_moment imported from optax.tree_utils._tree_math as update_infinity_moment (unused-import)
optax/tree.py:21:0: W0611: Unused tree_update_moment imported from optax.tree_utils._tree_math as update_moment (unused-import)
optax/tree.py:21:0: W0611: Unused tree_update_moment_per_elem_norm imported from optax.tree_utils._tree_math as update_moment_per_elem_norm (unused-import)
optax/tree.py:21:0: W0611: Unused tree_vdot imported from optax.tree_utils._tree_math as vdot (unused-import)
optax/tree.py:21:0: W0611: Unused tree_where imported from optax.tree_utils._tree_math as where (unused-import)
optax/tree.py:21:0: W0611: Unused tree_zeros_like imported from optax.tree_utils._tree_math as zeros_like (unused-import)


Your code has been rated at 9.91/10

The following messages were raised:

  • warning message issued
  • convention message issued

Fatal messages detected. Failing..."?

@rdyro
Collaborator

rdyro commented Jan 3, 2025

You'll need to disable the unused-import warning for the whole file, and disable the other specific warnings (such as redefined-builtin) on the individual lines that trigger them.
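A sketch of that pragma layout. To keep it runnable without optax, stdlib functions stand in for the optax.tree_utils imports (math.fsum plays the role of a function re-exported as sum); only the placement of the pylint comments is the point. The same per-line form applies to the other messages in the log, and useless-import-alias can also be avoided by writing "from m import X" instead of "from m import X as X".

```python
"""Stand-in alias module illustrating pylint pragma placement.

Assumption: stdlib functions replace the optax.tree_utils imports so this runs
without optax installed.
"""
# File-level pragma: every import below exists only to be re-exported, so
# unused-import would otherwise fire on each one.
# pylint: disable=unused-import

from operator import add
from operator import mul

# Binding a module-level name "sum" shadows the builtin, so this single line
# gets its own targeted pragma instead of a file-wide disable.
from math import fsum as sum  # pylint: disable=redefined-builtin
```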

Place the tree.py file in optax/_src and import it in optax/__init__.py like so: from optax._src import tree
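That wiring can be tried out in a throwaway package. The sketch writes a hypothetical mypkg with the same layout (the alias module under _src, re-exported from the package's __init__.py) into a temporary directory and imports it. mypkg, its toy tree_add, and build_and_import are illustrative stand-ins, not optax code.

```python
import importlib
import sys
import tempfile
from pathlib import Path

# Hypothetical package contents; only the layout mirrors optax/_src/tree.py
# plus "from optax._src import tree" in optax/__init__.py.
FILES = {
    "mypkg/__init__.py": "from mypkg._src import tree\n",
    "mypkg/_src/__init__.py": "",
    "mypkg/_src/tree_utils.py": (
        "def tree_add(x, y):\n"
        "    return {k: x[k] + y[k] for k in x}\n"
    ),
    "mypkg/_src/tree.py": (
        "# pylint: disable=unused-import\n"
        "from mypkg._src.tree_utils import tree_add as add\n"
    ),
}


def build_and_import(root: Path):
    """Write FILES under root, put root on sys.path, and import mypkg."""
    for rel, text in FILES.items():
        path = root / rel
        path.parent.mkdir(parents=True, exist_ok=True)
        path.write_text(text)
    sys.path.insert(0, str(root))
    importlib.invalidate_caches()  # make the freshly written files visible
    return importlib.import_module("mypkg")


pkg = build_and_import(Path(tempfile.mkdtemp()))
print(pkg.tree.add({"a": 1}, {"a": 2}))  # {'a': 3}
```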

A quick reference: https://stackoverflow.com/questions/28829236/is-it-possible-to-ignore-one-single-specific-line-with-pylint

@Ata-Shaker
Author

Ata-Shaker commented Jan 3, 2025

Thanks for the guidance. I have created the pull request.
