Skip to content

Commit

Permalink
updated haiku guide
Browse files Browse the repository at this point in the history
  • Loading branch information
chiamp committed May 5, 2024
1 parent d0e080d commit 331f138
Show file tree
Hide file tree
Showing 16 changed files with 246 additions and 222 deletions.
140 changes: 106 additions & 34 deletions docs/_ext/codediff.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,16 @@
Use directive as follows:
.. codediff::
:title_left: <LEFT_CODE_BLOCK_TITLE>
:title_right: <RIGHT_CODE_BLOCK_TITLE>
:title: <LEFT_CODE_BLOCK_TITLE>, <RIGHT_CODE_BLOCK_TITLE>
<CODE_BLOCK_LEFT>
---
<CODE_BLOCK_RIGHT>
In order to highlight a line of code, append "#!" to it.
"""
from typing import List, Tuple

from typing import List, Optional, Tuple

import sphinx
from docutils import nodes
Expand All @@ -40,29 +40,99 @@
class CodeDiffParser:
def parse(
self,
lines,
title_left='Base',
title_right='Diff',
code_sep='---',
sync=MISSING,
lines: List[str],
title: str,
groups: Optional[List[str]] = None,
skip_test: Optional[str] = None,
code_sep: str = '---',
sync: object = MISSING,
):
sync = sync is not MISSING
"""Parse the code diff block and format it so that it
renders in different tabs and is tested by doctest.
if code_sep not in lines:
raise ValueError(
'Code separator not found! Code snippets should be '
f'separated by {code_sep}.'
)
idx = lines.index(code_sep)
code_left = self._code_block(lines[0:idx])
test_code = lines[idx + 1 :]
code_right = self._code_block(test_code)

output = self._tabs(
(title_left, code_left), (title_right, code_right), sync=sync
)
For example:
.. testcode:: tab0, tab2, tab3
<CODE_BLOCK_A>
.. codediff::
:title: Tab 0, Tab 1, Tab 2, Tab 3
:groups: tab0, tab1, tab2, tab3
:skip_test: tab1, tab3
<CODE_BLOCK_B0>
---
return output, test_code
<CODE_BLOCK_B1>
---
<CODE_BLOCK_B2>
---
<CODE_BLOCK_B3>
For group tab0: <CODE_BLOCK_A> and <CODE_BLOCK_B0> are executed.
For group tab1: Nothing is executed.
For group tab2: <CODE_BLOCK_A> and <CODE_BLOCK_B2> are executed.
For group tab3: <CODE_BLOCK_A> is executed.
Arguments:
lines: a string list, where each element is a single string code line
title: a single string that contains the titles of each tab (they should
be separated by commas)
groups: a single string that contains the group of each tab (they should
be separated by commas). Code snippets that are part of the same group
will be executed together. If groups=None, then the group names will
default to '0', '1', '2', etc...
skip_test: a single string denoting which group(s) to skip testing (they
should be separated by commas). This is useful for legacy code snippets
that no longer run correctly anymore. If skip_test=None, then no tests
are skipped.
code_sep: the separator character(s) used to denote a separate code block
for a new tab. The default code separator is '---'.
sync: an option for Sphinx directives, that will sync all tabs together.
This means that if the user clicks to switch to another tab, all tabs
will switch to the new tab.
"""
title = [t.strip() for t in title.split(',')]
num_tabs = len(title)

sync = sync is not MISSING
# skip legacy code snippets in upgrade guides
if skip_test is not None:
skip_test = set([index.strip() for index in skip_test.split(',')])
else:
skip_test = set()

code_blocks = '\n'.join(lines)
if code_blocks.count(code_sep) != num_tabs-1:
raise ValueError(f'Expected {num_tabs-1} code separators for {num_tabs} tabs, but got {code_blocks.count(code_sep)} code separators instead.')
code_blocks = [code_block.split('\n') for code_block in code_blocks.split(code_sep+'\n')] # list[code_tab_list1[string_line1, ...], ...]

# TODO: test codediff_test.py is actually running on CI but forcing an error to be thrown
# TODO: test different groups and multiple tabs
# TODO: test visual render on RTD
# by default, put each code snippet in a different group denoted by an index number, to be executed separately
if groups is not None:
groups = [group_name.strip() for group_name in groups.split(',')]
else:
groups = [str(i) for i in range(num_tabs)]
if len(groups) != num_tabs:
raise ValueError(f'Expected {num_tabs} group assignments for {num_tabs} tabs, but got {len(groups)} group assignments instead.')

tabs = []
test_codes = []
for i, code_block in enumerate(code_blocks):
if groups[i] not in skip_test:
test_codes.append((code_block, groups[i]))
tabs.append((title[i], self._code_block(code_block)))
output = self._tabs(*tabs, sync=sync)

return output, test_codes

def _code_block(self, lines):
"""Creates a codeblock."""
Expand Down Expand Up @@ -99,36 +169,38 @@ def _tabs(self, *contents: Tuple[str, List[str]], sync):
class CodeDiffDirective(SphinxDirective):
has_content = True
option_spec = {
'title_left': directives.unchanged,
'title_right': directives.unchanged,
'title': directives.unchanged,
'groups': directives.unchanged,
'skip_test': directives.unchanged,
'code_sep': directives.unchanged,
'sync': directives.flag,
}

def run(self):
table_code, test_code = CodeDiffParser().parse(
table_code, test_codes = CodeDiffParser().parse(
list(self.content), **self.options
)

# Create a test node as a comment node so it won't show up in the docs.
# We add attribute "testnodetype" so it is be picked up by the doctest
# builder. This functionality is not officially documented but can be found
# in the source code:
# https://github.com/sphinx-doc/sphinx/blob/3.x/sphinx/ext/doctest.py
# https://github.com/sphinx-doc/sphinx/blob/master/sphinx/ext/doctest.py
# (search for 'testnodetype').
test_code = '\n'.join(test_code)
test_node = nodes.comment(test_code, test_code, testnodetype='testcode')
# Set the source info so the error message is correct when testing.
self.set_source_info(test_node)
test_node['options'] = {}
test_node['language'] = 'python3'
test_nodes = []
for test_code, group in test_codes:
test_node = nodes.comment('\n'.join(test_code), '\n'.join(test_code), testnodetype='testcode', groups=[group])
self.set_source_info(test_node)
test_node['options'] = {}
test_node['language'] = 'python3'
test_nodes.append(test_node)

# The table node is the side-by-side diff view that will be shown on RTD.
table_node = nodes.paragraph()
self.content = ViewList(table_code, self.content.parent)
self.state.nested_parse(self.content, self.content_offset, table_node)

return [table_node, test_node]
return [table_node] + test_nodes


def setup(app):
Expand Down
39 changes: 19 additions & 20 deletions docs/_ext/codediff_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,36 +33,35 @@ def get_initial_params(key):
initial_params = CNN().init(key, init_val)['params']
return initial_params"""

expected_table = r"""+----------------------------------------------------------+----------------------------------------------------------+
| Single device | Ensembling on multiple devices |
+----------------------------------------------------------+----------------------------------------------------------+
| .. code-block:: python | .. code-block:: python |
| :emphasize-lines: 1,2 | :emphasize-lines: 1 |
| | |
| @jax.jit | @jax.pmap |
| def get_initial_params(key): | def get_initial_params(key): |
| init_val = jnp.ones((1, 28, 28, 1), jnp.float32) | init_val = jnp.ones((1, 28, 28, 1), jnp.float32) |
| initial_params = CNN().init(key, init_val)['params'] | initial_params = CNN().init(key, init_val)['params'] |
| extra_line | return initial_params |
| return initial_params | |
+----------------------------------------------------------+----------------------------------------------------------+"""
expected_table = """.. tab-set::\n \n .. tab-item:: Single device\n \n .. code-block:: python\n :emphasize-lines: 1,2\n \n @jax.jit\n def get_initial_params(key):\n init_val = jnp.ones((1, 28, 28, 1), jnp.float32)\n initial_params = CNN().init(key, init_val)['params']\n extra_line\n return initial_params\n \n .. tab-item:: Ensembling on multiple devices\n \n .. code-block:: python\n :emphasize-lines: 1\n \n @jax.pmap\n def get_initial_params(key):\n init_val = jnp.ones((1, 28, 28, 1), jnp.float32)\n initial_params = CNN().init(key, init_val)['params']\n return initial_params"""

expected_testcode = r"""@jax.pmap #!
expected_testcodes = [
r"""@jax.jit #!
def get_initial_params(key): #!
init_val = jnp.ones((1, 28, 28, 1), jnp.float32)
initial_params = CNN().init(key, init_val)['params']
extra_line
return initial_params
""",
r"""@jax.pmap #!
def get_initial_params(key):
init_val = jnp.ones((1, 28, 28, 1), jnp.float32)
initial_params = CNN().init(key, init_val)['params']
return initial_params"""
return initial_params""",
]

title_left = 'Single device'
title_right = 'Ensembling on multiple devices'

actual_table, actual_testcode = CodeDiffParser().parse(
actual_table, actual_testcodes = CodeDiffParser().parse(
lines=input_text.split('\n'),
title_left=title_left,
title_right=title_right,
title=f'{title_left}, {title_right}',
)

actual_table = '\n'.join(actual_table)
actual_testcode = '\n'.join(actual_testcode)
actual_testcodes = ['\n'.join(testcode) for testcode, _ in actual_testcodes]

assert False
self.assertEqual(expected_table, actual_table)
self.assertEqual(expected_testcode, actual_testcode)
self.assertEqual(expected_testcodes[0], actual_testcodes[0])
self.assertEqual(expected_testcodes[1], actual_testcodes[1])
8 changes: 3 additions & 5 deletions docs/experimental/nnx/transforms.rst
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ seamlessly switch between them or use them together. We will be focusing on the

First, let's set up imports and generate some dummy data:

.. testcode::
.. testcode:: 0, 1

from flax.experimental import nnx
import jax
Expand Down Expand Up @@ -38,8 +38,7 @@ whereas the function signature of JAX-transformed functions can only accept the
the transformed function.

.. codediff::
:title_left: NNX transforms
:title_right: JAX transforms
:title: NNX transforms, JAX transforms
:sync:

@nnx.jit
Expand Down Expand Up @@ -83,8 +82,7 @@ NNX and JAX transformations can be mixed together, so long as the JAX-transforme
pure and has valid argument types that are recognized by JAX.

.. codediff::
:title_left: Using ``nnx.jit`` with ``jax.grad``
:title_right: Using ``jax.jit`` with ``nnx.grad``
:title: Using ``nnx.jit`` with ``jax.grad``, Using ``jax.jit`` with ``nnx.grad``
:sync:

@nnx.jit
Expand Down
Loading

0 comments on commit 331f138

Please sign in to comment.