Finish documentation of V2 API.
Also fixes the decoding loop sampler to allow prompt padding on the right.

PiperOrigin-RevId: 647501304
danieldjohnson authored and Penzai Developers committed Jun 28, 2024
1 parent b553f1f commit 9bdaa16
Showing 14 changed files with 11,132 additions and 104 deletions.
88 changes: 55 additions & 33 deletions README.md
@@ -19,49 +19,70 @@ With Penzai, your neural networks could look like this:
Penzai is structured as a collection of modular tools, designed together but
each useable independently:

* `penzai.nn` (`pz.nn`): A declarative combinator-based neural network
library and an alternative to other neural network libraries like Flax, Haiku,
Keras, or Equinox, which exposes the full structure of your model's
forward pass in the model pytree. This means you can see everything your model
does by pretty printing it, and inject new runtime logic with `jax.tree_util`.
Like Equinox, there's no magic: models are just callable pytrees under the
hood.

* `penzai.treescope` (`pz.ts`): A superpowered interactive Python
pretty-printer, which works as a drop-in replacement for the ordinary

* A superpowered interactive Python pretty-printer:

* `penzai.treescope` (``pz.ts``): A drop-in replacement for the ordinary
IPython/Colab renderer. It's designed to help understand Penzai models and
other deeply-nested JAX pytrees, with built-in support for visualizing
arbitrary-dimensional NDArrays.

* `penzai.core.selectors` (`pz.select`): A pytree swiss-army-knife,
generalizing JAX's `.at[...].set(...)` syntax to arbitrary type-driven
pytree traversals, and making it easy to do complex rewrites or
on-the-fly patching of Penzai models and other data structures.
* A set of JAX tree and array manipulation utilities:

* `penzai.core.selectors` (``pz.select``): A pytree swiss-army-knife,
generalizing JAX's ``.at[...].set(...)`` syntax to arbitrary type-driven
pytree traversals, and making it easy to do complex rewrites or
on-the-fly patching of Penzai models and other data structures (see the short sketch after this list).

* `penzai.core.named_axes` (``pz.nx``): A lightweight named axis system which
lifts ordinary JAX functions to vectorize over named axes, and allows you to
seamlessly switch between named and positional programming styles without
having to learn a new array API.

* A declarative combinator-based neural network library, where models are
represented as easy-to-modify data structures:

* `penzai.core.named_axes` (`pz.nx`): A lightweight named axis system which
lifts ordinary JAX functions to vectorize over named axes, and allows you to
seamlessly switch between named and positional programming styles without
having to learn a new array API.
* `penzai.nn` (``pz.nn``): An alternative to other neural network libraries like
Flax, Haiku, Keras, or Equinox, which exposes the full structure of your model's
forward pass in the model pytree. This means you can see everything your model
does by pretty printing it, and inject new runtime logic with `jax.tree_util`.
Like Equinox, there's no magic: models are just callable pytrees under the
hood.

* `penzai.data_effects` (`pz.de`): An opt-in system for side arguments, random
numbers, and state variables that is built on pytree traversal and puts you
in control, without getting in the way of writing or using your model.
* `penzai.data_effects` (``pz.de``): An opt-in system for side arguments, random
numbers, and state variables that is built on pytree traversal and puts you
in control, without getting in the way of writing or using your model.

* **(NEW)** `penzai.experimental.v2`: An improved version of `penzai.nn` with
less boilerplate, including first-class support for mutable state and
parameter sharing.

* An implementation of the Gemma open-weights model using modular components and
named axes, built to enable interpretability and model surgery research.

* **(NEW)** The V2 version also supports Llama, Mistral, and GPT-NeoX / Pythia
models!
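
As referenced in the selectors bullet above, here is a short sketch of `pz.select` usage. It is an illustrative example, not code from this commit; the toy dictionary and the doubling function are made up for the illustration:

```python
import jax
import jax.numpy as jnp
from penzai.experimental.v2 import pz

# A toy pytree standing in for a model or parameter collection.
params = {"weights": jnp.ones((3, 4)), "bias": jnp.zeros((4,))}

# Select every jax.Array leaf and rewrite it, returning a new pytree.
rescaled = pz.select(params).at_instances_of(jax.Array).apply(lambda x: x * 2)
```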

Documentation on Penzai can be found at
[https://penzai.readthedocs.io](https://penzai.readthedocs.io).

> [!WARNING]
> Penzai's API is currently unstable and may change in future releases.
> [!IMPORTANT]
> Penzai currently has two versions of its neural network API: the original
> "V1" API, and a new "V2" API located in `penzai.experimental.v2`.
>
> In particular, the way Penzai handles parameter initialization, parameter
> sharing, and local mutable state in `penzai.nn` and
> `penzai.data_effects` is likely to be simplified in the future.
> Some internal details of the `treescope` pretty-printer intermediate
> representation may also change to make it easier to extend and configure.
> The V2 API aims to be simpler and more flexible, by introducing first-class
> support for mutable state and parameter sharing, and removing unnecessary
> boilerplate. It also includes a more flexible transformer implementation with
> support for more pretrained model variants. You can read about the
> differences between the two APIs in the
> ["Changes in the V2 API"](v2_differences) overview.
>
> Projects that use Penzai's neural network components or model implementations,
> or that define their own handlers for `treescope`, are encouraged to pin the
> `0.1.x` release series (e.g. `penzai>=0.1,<0.2`) to avoid breaking changes.
> We plan to stabilize the V2 API and move it out of experimental in release
> ``0.2.0``, replacing the V1 API. If you wish to keep the V1 behavior, we
> recommend pinning the ``0.1.x`` release series (e.g. ``penzai>=0.1,<0.2``)
> to avoid breaking changes.
[v2_differences]: https://penzai.readthedocs.io/en/stable/guides/v2_differences.html


## Getting Started
@@ -128,10 +149,11 @@ output, intermediates = mlp_with_captured_activations(
```

To learn more about how to build and manipulate neural networks with Penzai,
we recommend starting with the ["How to Think in Penzai" tutorial][], or one
we recommend starting with the "How to Think in Penzai" tutorial ([V1 API version][how_to_think_1], [V2 API version][how_to_think_2]), or one
of the other tutorials in the [Penzai documentation][].

["How to Think in Penzai" tutorial]: https://penzai.readthedocs.io/en/stable/notebooks/how_to_think_in_penzai.html
[how_to_think_1]: https://penzai.readthedocs.io/en/stable/notebooks/how_to_think_in_penzai.html
[how_to_think_2]: https://penzai.readthedocs.io/en/stable/notebooks/v2_how_to_think_in_penzai.html
[Penzai documentation]: https://penzai.readthedocs.io


1 change: 1 addition & 0 deletions docs/_autogen_root.rst
@@ -35,4 +35,5 @@
:hidden:

notebooks/induction_heads_2B
notebooks/v2_induction_heads_2B
_include/_glue_figures
44 changes: 25 additions & 19 deletions docs/_include/_glue_figures.ipynb
@@ -22,7 +22,7 @@
"import myst_nb\n",
"\n",
"import penzai\n",
"from penzai import pz\n",
"from penzai.experimental.v2 import pz\n",
"\n",
"import jax\n",
"import jax.numpy as jnp\n",
@@ -53,7 +53,8 @@
},
"outputs": [],
"source": [
"from penzai.example_models import gemma"
"from penzai.experimental.v2.models.transformer.variants import gemma\n",
"from penzai.experimental.v2.models.transformer.variants import llamalike_common"
]
},
{
@@ -103,26 +104,28 @@
" )\n",
" flat_params = checkpointer.restore(ckpt_path, restore_args=restore_args)\n",
"\n",
" model = gemma.model_core.GemmaTransformer.from_pretrained(\n",
" model = gemma.gemma_from_pretrained_checkpoint(\n",
" flat_params, upcast_activations_to_float32=False\n",
" )\n",
"\n",
"else:\n",
" model = pz.nn.initialize_parameters(\n",
" gemma.model_core.GemmaTransformer.from_config(\n",
" gemma.model_core.GemmaTransformerConfig(\n",
" num_heads=8,\n",
" embedding_dim=256,\n",
" projection_dim=32,\n",
" single_kv_head=False,\n",
" mlp_hidden_dim=512,\n",
" num_decoder_blocks=10,\n",
" vocab_size=1000,\n",
" parameter_dtype=jnp.float32,\n",
" activation_dtype=jnp.float32,\n",
" )\n",
" model = llamalike_common.build_llamalike_transformer(\n",
" llamalike_common.LlamalikeTransformerConfig(\n",
" num_kv_heads=8,\n",
" query_head_multiplier=1,\n",
" embedding_dim=256,\n",
" projection_dim=32,\n",
" mlp_hidden_dim=512,\n",
" num_decoder_blocks=10,\n",
" vocab_size=1000,\n",
" mlp_variant=\"geglu_approx\",\n",
" rope_wavelength=10_000,\n",
" tie_embedder_and_logits=True,\n",
" use_layer_stack=False,\n",
" parameter_dtype=jnp.float32,\n",
" activation_dtype=jnp.float32,\n",
" ),\n",
" jax.random.key(1),\n",
" init_base_rng=jax.random.key(42),\n",
" )"
]
},
@@ -138,8 +141,8 @@
"\n",
"with IPython.utils.capture.capture_output() as capturer:\n",
" pz.select(model).at(lambda root: (\n",
" root.body.body.body.sublayers[2].sublayers[0].delta.sublayers[1].input_to_query,\n",
" root.body.body.body.sublayers[2].sublayers[1].delta.sublayers[1],\n",
" root.body.sublayers[2].sublayers[0].delta.sublayers[1].input_to_query,\n",
" root.body.sublayers[2].sublayers[1].delta.sublayers[1],\n",
" )).show_value()"
]
},
@@ -161,6 +164,9 @@
}
],
"metadata": {
"colab": {
"provenance": []
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
4 changes: 2 additions & 2 deletions docs/api/penzai.experimental.v2.rst
@@ -152,8 +152,8 @@ Language Modeling
pz.nn.Attention
pz.nn.KVCachingAttention
pz.nn.ApplyExplicitAttentionMask
pz.nnApplyCausalAttentionMask
pz.nnApplyCausalSlidingWindowAttentionMask
pz.nn.ApplyCausalAttentionMask
pz.nn.ApplyCausalSlidingWindowAttentionMask
pz.nn.EmbeddingTable
pz.nn.EmbeddingLookup
pz.nn.EmbeddingDecode
4 changes: 3 additions & 1 deletion docs/conf.py
@@ -216,9 +216,11 @@ def filter(self, record):
'py',
]
hoverxref_role_types = {
'obj': 'tooltip',
'class': 'tooltip',
'exc': 'tooltip',
'func': 'tooltip',
'mod': 'tooltip',
'obj': 'tooltip',
}

# -- Source code links -------------------------------------------------------
29 changes: 25 additions & 4 deletions docs/guides/howto_reference.md
@@ -10,7 +10,7 @@ from penzai.experimental.v2 import pz

## Visualization

This is a short overview of some of Penzai's visualization tools. For more, see the tutorials on [pretty printing](../notebooks/treescope_prettyprinting.ipynb) and [array visualization]](../notebooks/treescope_arrayviz.ipynb).
This is a short overview of some of Penzai's visualization tools. For more, see the tutorials on [pretty printing](../notebooks/treescope_prettyprinting.ipynb) and [array visualization](../notebooks/treescope_arrayviz.ipynb).

### Setting up pretty-printing
When using Penzai in IPython notebooks, it's recommended to set up Penzai's pretty-printer as the default pretty-printer and turn on array autovisualization. You can do this by running
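
The exact snippet is collapsed in this diff; a minimal sketch of that setup, assuming the `pz.ts.register_as_default` and `pz.ts.register_autovisualize_arrays` helpers, looks like:

```python
from penzai.experimental.v2 import pz

# Make treescope the default IPython/Colab pretty-printer, and enable
# automatic visualization of NDArrays in notebook outputs.
pz.ts.register_as_default()
pz.ts.register_autovisualize_arrays()
```
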
Expand Down Expand Up @@ -220,17 +220,38 @@ You can read more about Penzai's conventions for layers in "How to Think in Penz

Penzai's Gemma implementation includes a conversion utility that converts the ["Flax" model weights from Kaggle](https://www.kaggle.com/models/google/gemma) into the correct form. You can load it using:

```
```python
import os

import kagglehub
import orbax.checkpoint
from penzai.experimental.v2.models.transformer.variants import gemma
from penzai.experimental.v2.models.transformer import variants

weights_dir = kagglehub.model_download('google/gemma/Flax/7b')
ckpt_path = os.path.join(weights_dir, '7b')

checkpointer = orbax.checkpoint.PyTreeCheckpointer()
flax_params_dict = checkpointer.restore(ckpt_path)
model = gemma.gemma_from_pretrained_checkpoint(flax_params_dict)
model = variants.gemma.gemma_from_pretrained_checkpoint(flax_params_dict)
```

### Loading Llama, Mistral, or GPT-NeoX / Pythia

Penzai also includes re-implementations of the architectures used by [Llama](https://llama.meta.com/), [Mistral](https://mistral.ai/), and the [GPT-NeoX](https://www.eleuther.ai/artifacts/gpt-neox-20b) family of models, including the [Pythia](https://github.com/EleutherAI/pythia) model scaling suite. To load these models into Penzai, you can first load the weights using the HuggingFace `transformers` library, then convert them to Penzai:

```python
import transformers
from penzai.experimental.v2.models.transformer import variants

# To load a Llama model:
hf_model = transformers.LlamaForCausalLM.from_pretrained(...)
pz_model = variants.llama.llama_from_huggingface_model(hf_model)

# To load a Mistral model:
hf_model = transformers.MistralForCausalLM.from_pretrained(...)
pz_model = variants.mistral.mistral_from_huggingface_model(hf_model)

# To load a GPT-NeoX / Pythia model:
hf_model = transformers.GPTNeoXForCausalLM.from_pretrained(...)
pz_model = variants.gpt_neox.gpt_neox_from_huggingface_model(hf_model)
```

### Freezing pretrained model weights
23 changes: 22 additions & 1 deletion docs/guides/v2_differences.md
@@ -1,4 +1,4 @@
# Differences between the v1 and v2 neural network APIs
# Changes in the V2 API

Penzai includes two neural network APIs:

@@ -21,6 +21,8 @@ This document explains the major changes in the V2 API, relative to the V1 API.
- The data-effect system is no longer used.
- Parameter sharing, state, and side outputs will instead use `Parameter` and `StateVariable` (see the sketch after this list).
- Side inputs should be passed as keyword arguments.
- The built-in Transformer implementation also supports loading Llama, Mistral, and GPT-NeoX / Pythia models.
- This implementation is in `penzai.experimental.v2.models.transformer`, and shares the same high-level interface across all transformer variants.
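
A rough sketch of the variable-based style referenced in the list above; the exact `StateVariable` constructor used here is an assumption for illustration, not something shown in this commit:

```python
from penzai.experimental.v2 import pz

# A state variable holds a mutable value; code reads and writes `.value`
# directly instead of threading effect handlers through the model.
counter = pz.StateVariable(value=0)  # constructor signature assumed; see the V2 API docs
counter.value = counter.value + 1
```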

For Penzai release v0.2.0, we are planning to move the V2 API to penzai.nn, and deprecate the original system.
(This will be a **breaking change** to `penzai.nn` and Penzai's existing model implementations.)
@@ -366,3 +368,22 @@ output, new_vars = model_without_vars.call_with_local_vars(
for k, var in vars.items():
    var.value = new_vars[k].value
```


### Loading pretrained transformers

The V2 API includes a new transformer implementation with support for additional transformer variants. If you are using the current Gemma model, you will need to change how you load it:

```python
# Old
from penzai.example_models import gemma
model = gemma.model_core.GemmaTransformer.from_pretrained(flax_params_dict)
# (model is an instance of GemmaTransformer)

# New
from penzai.experimental.v2.models.transformer import variants
model = variants.gemma.gemma_from_pretrained_checkpoint(flax_params_dict)
# (model is an instance of TransformerLM)
```

Additionally, the types of various model components have changed to become more generic (e.g. `TransformerFeedForward` instead of `GemmaFeedForward`).