Skip to content

Commit

Permalink
Adjust common block colors, fix README.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 660128213
  • Loading branch information
danieldjohnson authored and Penzai Developers committed Aug 6, 2024
1 parent a4fbb59 commit 8a745ae
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 4 deletions.
7 changes: 3 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -153,11 +153,10 @@ intermediates = var.value
```

To learn more about how to build and manipulate neural networks with Penzai,
we recommend starting with the "How to Think in Penzai" tutorial ([V1 API version][how_to_think_1], [V2 API version][how_to_think_2]), or one
of the other tutorials in the [Penzai documentation][].
we recommend starting with the ["How to Think in Penzai" tutorial][how_to_think]
or one of the other tutorials in the [Penzai documentation][].

[how_to_think_1]: https://penzai.readthedocs.io/en/stable/notebooks/how_to_think_in_penzai.html
[how_to_think_2]: https://penzai.readthedocs.io/en/stable/notebooks/how_to_think_in_penzai.html
[how_to_think]: https://penzai.readthedocs.io/en/stable/notebooks/how_to_think_in_penzai.html
[Penzai documentation]: https://penzai.readthedocs.io


Expand Down
4 changes: 4 additions & 0 deletions penzai/models/transformer/model_parts.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,10 @@ class TransformerFeedForward(pz.nn.Sequential):
class TransformerBlock(pz.nn.Sequential):
"""Informatively-named Sequential subclass for the main transformer blocks."""

def treescope_color(self):
color = "oklch(0.785 0.103 186.9 / 1.0)"
return color, f"color-mix(in oklab, {color} 25%, white)"


@pz.pytree_dataclass
class TransformerLM(pz.nn.Layer):
Expand Down
3 changes: 3 additions & 0 deletions penzai/nn/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,9 @@ def __call__(
output = self.attn_value_to_output((attn, value), **side_inputs)
return output

def treescope_color(self):
return "oklch(0.785 0.103 38.5 / 1.0)"


@struct.pytree_dataclass
class KVCachingAttention(layer_base.Layer):
Expand Down

0 comments on commit 8a745ae

Please sign in to comment.