Skip to content

Commit

Permalink
[nnx] add cache_args
Browse files Browse the repository at this point in the history
  • Loading branch information
cgarciae committed Jan 14, 2025
1 parent 1961c12 commit 2ef1999
Show file tree
Hide file tree
Showing 58 changed files with 3,548 additions and 2,459 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/flax_publish.yml
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,9 @@ jobs:
runs-on: ubuntu-latest

steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
- uses: actions/checkout@v2
- name: Set up Python
uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0
uses: actions/setup-python@v1
with:
python-version: '3.x'
- name: Install dependencies
Expand Down
42 changes: 21 additions & 21 deletions .github/workflows/flax_test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,6 @@

name: Flax - Test

concurrency:
group: ${{ github.workflow }}-${{ github.ref }}
cancel-in-progress: true

on:
push:
branches:
Expand All @@ -17,26 +13,30 @@ on:
- main

jobs:
cancel-previous:
name: Cancel Previous Runs
runs-on: ubuntu-latest
steps:
- name: Cancel previous
uses: styfle/cancel-workflow-action@0.10.1
if: ${{github.ref != 'refs/head/main'}}
with:
access_token: ${{ github.token }}
pre-commit:
name: Test pre-commit hooks
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
- uses: actions/checkout@v3
- name: Set up Python
uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0
uses: actions/setup-python@v4
with:
python-version: '3.10'
- run: python -m pip install pre-commit
- uses: actions/cache@1bd1e32a3bdc45362d1e726936510720a7c30a57 # v4.2.0
with:
path: ~/.cache/pre-commit
key: pre-commit-${{ env.pythonLocation }}-${{ hashFiles('.pre-commit-config.yaml', 'pyproject.toml') }}
- run: pre-commit run --show-diff-on-failure --color=always --all-files
- uses: pre-commit/action@v2.0.3
commit-count:
name: Check commit count
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
- uses: actions/checkout@v3
# We allow at most 5 commits in a branch to ensure our CI doesn't break.
- name: Check commit count in PR
if: always()
Expand Down Expand Up @@ -65,12 +65,12 @@ jobs:
matrix:
python-version: ['3.10', '3.11']
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
- uses: actions/checkout@v3
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}
- uses: astral-sh/setup-uv@887a942a15af3a7626099df99e897a18d9e5ab3a # v5.1.0
- uses: astral-sh/setup-uv@v2
with:
uv-version: "0.3.0"
- name: Install standalone dependencies only
Expand All @@ -81,7 +81,7 @@ jobs:
uv run python -c "import flax"
tests:
name: Run Tests
needs: [pre-commit, commit-count, test-import]
needs: [cancel-previous, pre-commit, commit-count, test-import]
runs-on: ubuntu-20.04-16core
strategy:
matrix:
Expand All @@ -98,14 +98,14 @@ jobs:
test-type: pytest
jax-version: '0.4.27' # keep in sync with jax pin in pyproject.toml
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
- uses: actions/checkout@v3
- name: Set up Python ${{ matrix.python-version }}
id: setup_python
uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}
- name: Setup uv
uses: astral-sh/setup-uv@887a942a15af3a7626099df99e897a18d9e5ab3a # v5.1.0
uses: astral-sh/setup-uv@v2
with:
version: "0.3.0"

Expand Down Expand Up @@ -135,7 +135,7 @@ jobs:
fi
- name: Upload coverage to Codecov
if: matrix.test-type == 'pytest'
uses: codecov/codecov-action@1e68e06f1dbfde0e4cefc87efeba9e4643565303 # v5.1.2
uses: codecov/codecov-action@v4
env:
CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
with:
Expand Down
18 changes: 9 additions & 9 deletions .github/workflows/flaxlib_publish.yml
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,12 @@ jobs:
os: [ubuntu-latest, windows-latest, macos-13, macos-14]

steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
- uses: actions/checkout@v4

- uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0
- uses: actions/setup-python@v5

- name: Setup Rust
uses: actions-rust-lang/setup-rust-toolchain@11df97af8e8102fd60b60a77dfbf58d40cd843b8 # v1.10.1
uses: actions-rust-lang/setup-rust-toolchain@v1

- name: Install cibuildwheel
run: python -m pip install cibuildwheel==2.21.0
Expand All @@ -41,7 +41,7 @@ jobs:
curl https://sh.rustup.rs -sSf | sh -s -- --default-toolchain=stable --profile=minimal -y &&
rustup show
- uses: actions/upload-artifact@65c4c4a1ddee5b72f698fdd19549f0f0fb45cf08 # v4.6.0
- uses: actions/upload-artifact@v4
with:
name: cibw-wheels-${{ matrix.os }}-${{ strategy.job-index }}
path: ./flaxlib/wheelhouse/*.whl
Expand All @@ -51,15 +51,15 @@ jobs:
name: Build source distribution
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
- uses: actions/checkout@v4

- name: Setup Rust
uses: actions-rust-lang/setup-rust-toolchain@11df97af8e8102fd60b60a77dfbf58d40cd843b8 # v1.10.1
uses: actions-rust-lang/setup-rust-toolchain@v1

- name: Build sdist
run: pipx run build --sdist flaxlib

- uses: actions/upload-artifact@65c4c4a1ddee5b72f698fdd19549f0f0fb45cf08 # v4.6.0
- uses: actions/upload-artifact@v4
with:
name: cibw-sdist
path: ./flaxlib/dist/*.tar.gz
Expand All @@ -72,14 +72,14 @@ jobs:
permissions:
id-token: write
steps:
- uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0
- uses: actions/setup-python@v1
with:
python-version: '3.x'
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install setuptools build wheel twine
- uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8
- uses: actions/download-artifact@v4
with:
# unpacks all CIBW artifacts into dist/
pattern: cibw-*
Expand Down
52 changes: 0 additions & 52 deletions .github/workflows/jax_nightly.yml

This file was deleted.

63 changes: 46 additions & 17 deletions benchmarks/nnx_graph_overhead.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,31 +24,52 @@
from absl import app

FLAGS = flags.FLAGS
flags.DEFINE_enum('mode', 'all', ['all', 'nnx', 'jax'], 'Mode to run the script in')
flags.DEFINE_enum(
'mode', 'nnx', ['all', 'nnx', 'jax'], 'Mode to run the script in'
)
flags.DEFINE_integer('total_steps', 100, 'Total number of training steps')
flags.DEFINE_integer('width', 32, 'Hidden layer size')
flags.DEFINE_integer('depth', 5, 'Depth of the model')



class Linear(nnx.Module):
def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs):
self.list = [
nnx.Param(jax.random.uniform(rngs.params(), (din, dout))),
nnx.Param(jnp.zeros((dout,))),
]
self.dict = {
'w': nnx.Param(jax.random.uniform(rngs.params(), (din, dout))),
'b': nnx.Param(jnp.zeros((dout,))),
}
self.w = nnx.Param(jax.random.uniform(rngs.params(), (din, dout)))
self.b = nnx.Param(jnp.zeros((dout,)))

def __call__(self, x):
return x @ self.w + self.b


class Block(nnx.Module):
def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs):
self.linear = Linear(din, dout, rngs=rngs)
self.bn = nnx.BatchNorm(dout, rngs=rngs)

def __call__(self, x):
return nnx.relu(self.bn(self.linear(x)))


class Count(nnx.Variable):
pass


class MLP(nnx.Module):
def __init__(self, depth, *, rngs: nnx.Rngs):
def __init__(self, din, dhidden, dout, depth, *, rngs: nnx.Rngs):
self.count = Count(jnp.array(0))
self.linear_in = Block(din, dhidden, rngs=rngs)
self.intermediates = [
Linear(10, 10, rngs=rngs) for _ in range(depth)
Block(dhidden, dhidden, rngs=rngs) for _ in range(depth - 2)
]
self.linear_out = Block(dhidden, dout, rngs=rngs)

def __call__(self, x):
self.count.value += 1
x = nnx.relu(self.linear_in(x))
for layer in self.intermediates:
x = nnx.relu(layer(x))
x = self.linear_out(x)
return x


def main(argv):
Expand All @@ -63,21 +84,24 @@ def main(argv):
X = np.linspace(0, 1, 100)[:, None]
Y = 0.8 * X**2 + 0.1 + np.random.normal(0, 0.1, size=X.shape)

model = MLP(depth=depth, rngs=nnx.Rngs(0))
tx = optax.sgd(1e-3)
optimizer = nnx.Optimizer(model, tx)

#------------------------------------------------------------
# NNX
#------------------------------------------------------------
if mode in ['all', 'nnx']:
model = MLP(din=1, dhidden=width, dout=1, depth=depth, rngs=nnx.Rngs(0))
tx = optax.sgd(1e-3)
optimizer = nnx.Optimizer(model, tx)
t0 = time()

@nnx.jit
def step_nnx(model: MLP, optimizer: nnx.Optimizer):
pass

cached_step_nnx = nnx.cache_args(step_nnx, model, optimizer)

t0 = time()
for _ in range(total_steps):
step_nnx(model, optimizer)
cached_step_nnx()

total_time = time() - t0
time_per_step = total_time / total_steps
Expand All @@ -93,6 +117,11 @@ def step_nnx(model: MLP, optimizer: nnx.Optimizer):
#------------------------------------------------------------

if mode in ['all', 'jax']:
model = MLP(din=1, dhidden=width, dout=1, depth=depth, rngs=nnx.Rngs(0))
tx = optax.sgd(1e-3)
optimizer = nnx.Optimizer(model, tx)
t0 = time()

@jax.jit
def step_jax(graphdef, state):
return graphdef, state
Expand Down
Loading

0 comments on commit 2ef1999

Please sign in to comment.