From 4b9000c9a123df388528af3fc3bd36b14dffeea0 Mon Sep 17 00:00:00 2001 From: Jake Harmon Date: Wed, 4 Dec 2024 15:46:08 -0800 Subject: [PATCH] Update references to JAX's GitHub repo JAX has moved from https://github.com/google/jax to https://github.com/jax-ml/jax PiperOrigin-RevId: 702886885 --- README.md | 4 ++-- docs/index.rst | 2 +- docs/notebooks/non_trainable.ipynb | 2 +- examples/haiku_lstms.ipynb | 2 +- examples/impala_lite.py | 2 +- haiku/_src/batch_norm_test.py | 2 +- haiku/_src/embed.py | 2 +- 7 files changed, 8 insertions(+), 8 deletions(-) diff --git a/README.md b/README.md index 07a10d55e..d1ff96f0c 100644 --- a/README.md +++ b/README.md @@ -230,7 +230,7 @@ Haiku is written in pure Python, but depends on C++ code via JAX. Because JAX installation is different depending on your CUDA version, Haiku does not list JAX as a dependency in `requirements.txt`. -First, follow [these instructions](https://github.com/google/jax#installation) +First, follow [these instructions](https://github.com/jax-ml/jax#installation) to install JAX with the relevant accelerator support. Then, install Haiku using pip: @@ -462,7 +462,7 @@ In this bibtex entry, the version number is intended to be from [`haiku/__init__.py`](https://github.com/deepmind/dm-haiku/blob/main/haiku/__init__.py), and the year corresponds to the project's open-source release. -[JAX]: https://github.com/google/jax +[JAX]: https://github.com/jax-ml/jax [Sonnet]: https://github.com/deepmind/sonnet [Tensorflow]: https://github.com/tensorflow/tensorflow [Flax]: https://github.com/google/flax diff --git a/docs/index.rst b/docs/index.rst index 4746e661d..8cf3e0edd 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -26,7 +26,7 @@ abstractions for machine learning research. Installation ------------ -See https://github.com/google/jax#pip-installation for instructions on +See https://github.com/jax-ml/jax#pip-installation for instructions on installing JAX. We suggest installing the latest version of Haiku by running:: diff --git a/docs/notebooks/non_trainable.ipynb b/docs/notebooks/non_trainable.ipynb index 28f217f57..22a508665 100644 --- a/docs/notebooks/non_trainable.ipynb +++ b/docs/notebooks/non_trainable.ipynb @@ -63,7 +63,7 @@ "output_type": "stream", "text": [ "WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)\n", - "/tmp/haiku-docs-env/lib/python3.8/site-packages/jax/_src/lax/lax.py:6271: UserWarning: Explicitly requested dtype float64 requested in zeros is not available, and will be truncated to dtype float32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/google/jax#current-gotchas for more.\n", + "/tmp/haiku-docs-env/lib/python3.8/site-packages/jax/_src/lax/lax.py:6271: UserWarning: Explicitly requested dtype float64 requested in zeros is not available, and will be truncated to dtype float32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/jax-ml/jax#current-gotchas for more.\n", " warnings.warn(msg.format(dtype, fun_name , truncated_dtype))\n" ] }, diff --git a/examples/haiku_lstms.ipynb b/examples/haiku_lstms.ipynb index a856fe507..6a222db98 100644 --- a/examples/haiku_lstms.ipynb +++ b/examples/haiku_lstms.ipynb @@ -46,7 +46,7 @@ "source": [ "# LSTMs in Haiku\n", "\n", - "**[Haiku](https://github.com/deepmind/dm-haiku) is a simple neural network library for [JAX](https://github.com/google/jax).**\n", + "**[Haiku](https://github.com/deepmind/dm-haiku) is a simple neural network library for [JAX](https://github.com/jax-ml/jax).**\n", "\n", "This notebook walks through a simple LSTM in JAX with Haiku.\n", "\n", diff --git a/examples/impala_lite.py b/examples/impala_lite.py index 50a14ad22..ce3f114f0 100644 --- a/examples/impala_lite.py +++ b/examples/impala_lite.py @@ -91,7 +91,7 @@ def step( def loss(self, params: hk.Params, trajs: Transition) -> jax.Array: """Computes a loss of trajs wrt params.""" # Re-run the agent over the trajectories. - # Due to https://github.com/google/jax/issues/1459, we use hk.BatchApply + # Due to https://github.com/jax-ml/jax/issues/1459, we use hk.BatchApply # instead of vmap. # BatchApply turns the input tensors from [T, B, ...] into [T*B, ...]. # We `functools.partial` params in so it does not get transformed. diff --git a/haiku/_src/batch_norm_test.py b/haiku/_src/batch_norm_test.py index 6bea97960..4a3493b28 100644 --- a/haiku/_src/batch_norm_test.py +++ b/haiku/_src/batch_norm_test.py @@ -174,7 +174,7 @@ def test_no_offset_beta_init_provided(self): offset_init=jnp.zeros) def test_eps_cast_to_var_dtype(self): - # See https://github.com/google/jax/issues/4718 for more info. In the + # See https://github.com/jax-ml/jax/issues/4718 for more info. In the # context of this test we need to assert NumPy bf16 params/state and a # Python float for eps preserve bf16 output. diff --git a/haiku/_src/embed.py b/haiku/_src/embed.py index 52c87ca91..e8156ea1b 100644 --- a/haiku/_src/embed.py +++ b/haiku/_src/embed.py @@ -177,7 +177,7 @@ def __call__( # it along the row dimension and treat each row as a separate index into # one of the dimensions of the array. The error only surfaces when # indexing with DeviceArray, while indexing with numpy.ndarray works fine. - # See https://github.com/google/jax/issues/620 for more details. + # See https://github.com/jax-ml/jax/issues/620 for more details. # Cast to a jnp array in case `ids` is a tracer (eg un a dynamic_unroll). return jnp.asarray(self.embeddings)[(ids,)]