⚙️ Implementing Predictive Coding from Scratch (in JAX)¤

+

In this notebook, we walk through how to implement Predictive Coding (PC) from scratch using JAX. In particular, we are going to train a simple feedforward network to classify MNIST digits with PC. It might be a good idea to revisit the introductory lecture on PC from this week.

+

If you're not familiar with JAX, have a look at their docs, but we will explain all the necessary concepts below. JAX is basically numpy for GPUs and other hardware accelerators. We will also use: Equinox, which allows you to define neural nets with PyTorch-like syntax; and Optax, which provides a range of common machine learning optimisers such as gradient descent and Adam.
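To get a feel for how these three libraries fit together, here is a minimal sketch of the patterns we will rely on throughout: JAX arrays with automatic differentiation, an Equinox layer called like a plain function, and an Optax optimiser carrying an explicit state.

```python
import jax.numpy as jnp
import jax.random as jr
from jax import grad
import equinox as eqx
import optax

# JAX: numpy-style arrays with automatic differentiation
f = lambda x: jnp.sum(x ** 2)
print(grad(f)(jnp.array([1.0, 2.0])))    # [2. 4.]

# Equinox: a layer is just a callable PyTree of parameters
layer = eqx.nn.Linear(3, 2, key=jr.PRNGKey(0))
print(layer(jnp.ones(3)).shape)          # (2,)

# Optax: optimisers expose init/update and carry an explicit state
optim = optax.sgd(1e-2)
opt_state = optim.init(eqx.filter(layer, eqx.is_array))
```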

+
+
+
+
+
+
+

Installations & imports¤

+
+
+
+
+
+ +
#@title installations
+
+
+%%capture
+!pip install torch==2.3.1
+!pip install torchvision==0.18.1
+
+ +
+
+
+
+ +
#@title imports
+
+
+import jax.random as jr
+import jax.numpy as jnp
+from jax import vmap, grad
+from jax.tree_util import tree_map
+
+import equinox as eqx
+import equinox.nn as nn
+from equinox import filter_grad
+import optax
+
+import torch
+from torch.utils.data import DataLoader
+from torchvision import datasets, transforms
+
+import warnings
+warnings.simplefilter("ignore")
+
+ +
+
+
+
+
+

Hyperparameters¤

+

We define some global parameters related to the data, network, optimisers, etc.

+
+
+
+
+
+ +
SEED = 827
+
+INPUT_DIM, OUTPUT_DIM = 28*28, 10
+NETWORK_WIDTH = 300
+
+ACTIVITY_LR = 1e-1
+INFERENCE_STEPS = 20
+
+PARAM_LR = 1e-3
+BATCH_SIZE = 64
+
+TEST_EVERY = 50
+N_TRAIN_ITERS = 500
+
+ +
+
+
+
+
+

Dataset¤

+

Some utils to fetch MNIST.

+
+
+
+
+
+ +
#@title data utils
+
+
+def get_mnist_loaders(batch_size):
+    train_data = MNIST(train=True, normalise=True)
+    test_data = MNIST(train=False, normalise=True)
+    train_loader = DataLoader(
+        dataset=train_data,
+        batch_size=batch_size,
+        shuffle=True,
+        drop_last=True
+    )
+    test_loader = DataLoader(
+        dataset=test_data,
+        batch_size=batch_size,
+        shuffle=True,
+        drop_last=True
+    )
+    return train_loader, test_loader
+
+
+class MNIST(datasets.MNIST):
+    def __init__(self, train, normalise=True, save_dir="data"):
+        if normalise:
+            transform = transforms.Compose(
+                [
+                    transforms.ToTensor(),
+                    transforms.Normalize(
+                        mean=(0.1307), std=(0.3081)
+                    )
+                ]
+            )
+        else:
+            transform = transforms.Compose([transforms.ToTensor()])
+        super().__init__(save_dir, download=True, train=train, transform=transform)
+
+    def __getitem__(self, index):
+        img, label = super().__getitem__(index)
+        img = torch.flatten(img)
+        label = one_hot(label)
+        return img, label
+
+
+def one_hot(labels, n_classes=10):
+    arr = torch.eye(n_classes)
+    return arr[labels]
+
+ +
+
+
+
+
+

PC energy¤

+
+
+
+
+
+
+

First, recall that PC can be derived as a variational inference algorithm under certain assumptions. In particular, if we assume

* a Dirac delta (point mass) posterior and
* a hierarchical Gaussian generative model,

+

we get the standard PC energy

+

\begin{equation}
    \mathcal{F} = \frac{1}{2N}\sum_{i=1}^{N} \sum_{\ell=1}^L ||\mathbf{z}_{\ell, i} - f_\ell(W_\ell \mathbf{z}_{\ell-1, i} + \mathbf{b}_\ell)||^2_2
\end{equation}
which is just a sum of squared prediction errors at each network layer. Here we are being a little bit more precise than in the lecture, including multiple (\(N\)) data points and biases \(\mathbf{b}_\ell\).

+

🤔 Food for thought: Think about how the form of this energy could change depending on other assumptions we make about the generative model. See, for example, Learning on Arbitrary Graph Topologies via Predictive Coding by Salvatori et al. (2022).

+

Let's start by implementing this energy below. The function takes the model (with all its parameters), some initialised activities, and some input and output. Given these, it sums the squared prediction error at each layer.

+

NOTE: below we use vmap, one of the core JAX transforms that allows you to vectorise operations, in this case for multiple data points or over a batch. See https://jax.readthedocs.io/en/latest/automatic-vectorization.html for more details.
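As a toy illustration of this, vmap takes a function written for a single example and maps it over a leading batch dimension:

```python
import jax.numpy as jnp
from jax import vmap

def squared_norm(z):                      # written for a single vector of shape (d,)
    return jnp.sum(z ** 2)

batch = jnp.ones((64, 10))                # a batch of 64 vectors
print(vmap(squared_norm)(batch).shape)    # (64,): one value per example
```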

+
+
+
+
+
+ +
def pc_energy_fn(model, activities, input, output):
+    batch_size = output.shape[0]
+    n_activity_layers = len(activities) - 1
+    n_layers = len(model) - 1
+
+    # prediction error at the output layer (z_L is clamped to the target)
+    eL = output - vmap(model[-1])(activities[-2])
+    energies = [jnp.sum(eL ** 2)]
+
+    # prediction errors at the intermediate layers
+    for act_l, net_l in zip(
+            range(1, n_activity_layers),
+            range(1, n_layers)
+    ):
+        err = activities[act_l] - vmap(model[net_l])(activities[act_l - 1])
+        energies.append(jnp.sum(err ** 2))
+
+    # prediction error at the first layer (z_0 is clamped to the input)
+    e1 = activities[0] - vmap(model[0])(input)
+    energies.append(jnp.sum(e1 ** 2))
+
+    return jnp.sum(jnp.array(energies)) / batch_size
+
+ +
+
+
+
+
+

Now let's test it. To do so, we first need a model. Below we use Equinox to create a simple feedforward network with 2 hidden layers and tanh activations. Note that we split the model into different parts with nn.Sequential to define the activities which PC will optimise over (during inference, more on this below).

+

Question: Think about other ways in which we could split the layers, for example by separating the non-linearities. Can you think of potential issues with this?
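For instance, one alternative split (shown purely as an illustration, not the model we use below) would give each non-linearity its own stage, so that PC also maintains activities for the pre-activation values:

```python
import jax.numpy as jnp
import jax.random as jr
import equinox.nn as nn

alt_subkeys = jr.split(jr.PRNGKey(SEED), 3)
alt_model = [
    nn.Linear(INPUT_DIM, NETWORK_WIDTH, key=alt_subkeys[0]),
    nn.Lambda(jnp.tanh),   # the activation now has its own layer of activities
    nn.Linear(NETWORK_WIDTH, NETWORK_WIDTH, key=alt_subkeys[1]),
    nn.Lambda(jnp.tanh),
    nn.Linear(NETWORK_WIDTH, OUTPUT_DIM, key=alt_subkeys[2]),
]
```

One potential issue with such a split is that prediction errors would then also be defined on pre-activation values, which changes the energy being minimised and roughly doubles the number of inference variables.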

+
+
+
+
+
+ +
# jax uses explicit random number generators (see https://jax.readthedocs.io/en/latest/random-numbers.html)
+key = jr.PRNGKey(SEED)
+subkeys = jr.split(key, 3)
+
+model = [
+    nn.Sequential(
+        [
+            nn.Linear(INPUT_DIM, NETWORK_WIDTH, key=subkeys[0]),
+            nn.Lambda(jnp.tanh)
+        ],
+    ),
+    nn.Sequential(
+        [
+            nn.Linear(NETWORK_WIDTH, NETWORK_WIDTH, key=subkeys[1]),
+            nn.Lambda(jnp.tanh)
+        ],
+    ),
+    nn.Linear(NETWORK_WIDTH, OUTPUT_DIM, key=subkeys[2]),
+]
+model
+
+ +
+
+
+
+
+
+[Sequential(
+   layers=(
+     Linear(
+       weight=f32[300,784],
+       bias=f32[300],
+       in_features=784,
+       out_features=300,
+       use_bias=True
+     ),
+     Lambda(fn=<wrapped function tanh>)
+   )
+ ),
+ Sequential(
+   layers=(
+     Linear(
+       weight=f32[300,300],
+       bias=f32[300],
+       in_features=300,
+       out_features=300,
+       use_bias=True
+     ),
+     Lambda(fn=<wrapped function tanh>)
+   )
+ ),
+ Linear(
+   weight=f32[10,300],
+   bias=f32[10],
+   in_features=300,
+   out_features=10,
+   use_bias=True
+ )]
+
+
+
+
+
+
+
+
+
+

The last thing we need is to initialise the activities. For this, we will use a feedforward pass, as is often done in practice.

+

Question: Can you think of other ways of initialising the activities?
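One alternative (sketched here with a made-up helper, purely for illustration) is to initialise the free activities randomly, for example from a zero-mean Gaussian, rather than with a feedforward pass:

```python
import jax.numpy as jnp
import jax.random as jr

def init_activities_from_gaussian(key, layer_sizes, batch_size, sigma=0.05):
    """Sample each layer's activities from N(0, sigma^2)."""
    keys = jr.split(key, len(layer_sizes))
    return [
        sigma * jr.normal(k, (batch_size, size))
        for k, size in zip(keys, layer_sizes)
    ]

# e.g. for our network: two hidden layers of width 300 and a 10-dimensional output
random_activities = init_activities_from_gaussian(
    jr.PRNGKey(0), [NETWORK_WIDTH, NETWORK_WIDTH, OUTPUT_DIM], batch_size=BATCH_SIZE
)
```

A random initialisation will generally require more inference steps to settle, since the activities start far away from the network's own predictions.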

+
+
+
+
+
+ +
def init_activities_with_ffwd(model, input):
+    activities = [vmap(model[0])(input)]
+    for l in range(1, len(model)):
+        layer_output = vmap(model[l])(activities[l - 1])
+        activities.append(layer_output)
+
+    return activities
+
+ +
+
+
+
+
+

Let's test it on an MNIST sample.

+
+
+
+
+
+ +
# get a data sample
+train_loader, test_loader = get_mnist_loaders(BATCH_SIZE)
+img_batch, label_batch = next(iter(train_loader))
+
+# we need to turn the torch.Tensor data into numpy arrays for jax
+img_batch, label_batch = img_batch.numpy(), label_batch.numpy()
+
+# let's check our initialised activities
+activities = init_activities_with_ffwd(model, img_batch)
+for i, a in enumerate(activities):
+    print(f"activity z at layer {i+1}: {a.shape}")
+
+ +
+
+
+
+
+
+activity z at layer 1: (64, 300)
+activity z at layer 2: (64, 300)
+activity z at layer 3: (64, 10)
+
+
+
+
+
+
+
+
+
+
+

Ok so now we have everything to test our PC energy function: model, activities, and some data.

+
+
+
+
+
+ +
pc_energy_fn(
+    model=model,
+    activities=activities,
+    input=img_batch,
+    output=label_batch
+)
+
+ +
+
+
+
+
+
+Array(1.2335204, dtype=float32)
+
+
+
+
+
+
+
+
+
+

And it works!

+
+
+
+
+
+
+

Energy gradients¤

+
+
+
+
+
+
+

How do we minimise the PC energy we defined above (Eq. 1)? Recall from the lecture that we do this in two phases: first with respect to the activities (inference) and then with respect to the weights (learning).

+
\[\begin{equation}
    \textit{Inference:} \quad - \frac{\partial \mathcal{F}}{\partial \mathbf{z}_\ell}
\end{equation}\]

\[\begin{equation}
    \textit{Learning:} \quad - \frac{\partial \mathcal{F}}{\partial W_\ell}
\end{equation}\]
+

So we just need to take these gradients of the energy. We are going to use automatic differentiation, which JAX supports by design (https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html). If you're familiar with PyTorch, you are probably used to calling loss.backward() for this, which can feel opaque at times. JAX, on the other hand, follows a fully functional (as opposed to object-oriented) style whose syntax stays very close to the maths, as you can see below.

+
+
+
+
+
+ +
# note how close this code is to the maths
+# this can be read as "take the gradient of the energy...
+# ...with respect to the 2nd argument (the activities)"
+
+def compute_activity_grad(model, activities, input, output):
+    return grad(pc_energy_fn, argnums=1)(
+        model,
+        activities,
+        input,
+        output
+    )
+
+ +
+
+
+
+
+

Let's test this out.

+
+
+
+
+
+ +
dFdzs = compute_activity_grad(
+    model=model, 
+    activities=activities, 
+    input=img_batch, 
+    output=label_batch
+)
+for i, dFdz in enumerate(dFdzs):
+    print(f"activity gradient dFdz shape at layer {i+1}: {dFdz.shape}")
+
+ +
+
+
+
+
+
+activity gradient dFdz shape at layer 1: (64, 300)
+activity gradient dFdz shape at layer 2: (64, 300)
+activity gradient dFdz shape at layer 3: (64, 10)
+
+
+
+
+
+
+
+
+
+
+

Now we do the same and take the gradient of the energy with respect to the parameters.

+

Technical note: below we use Equinox's convenience function filter_grad rather than JAX's native grad. This is because things like activation functions do not have parameters and so we do not want to differentiate them. filter_grad automatically filters these non-differentiable objects for us, while grad alone would throw an error.
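To make this concrete, here is a toy illustration of what filter_grad does with a mixed PyTree containing both an array and a function (the dictionary below is just an example, not how our model is structured):

```python
import jax.numpy as jnp
import equinox as eqx

params = {"w": jnp.array(2.0), "act": jnp.tanh}   # one array leaf, one function leaf

def loss(p):
    return p["act"](p["w"]) ** 2

grads = eqx.filter_grad(loss)(params)
print(grads["w"])    # gradient with respect to the array leaf
print(grads["act"])  # None: non-array leaves are filtered out rather than differentiated
```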

+
+
+
+
+
+ +
# note that, compared to the previous function,...
+# ...we just change the argument with respect to which...
+# ...we are differentiating (the first, or in this case the model)
+
+def compute_param_grad(model, activities, input, output):
+    return filter_grad(pc_energy_fn)(
+        model,
+        activities,
+        input,
+        output
+    )
+
+ +
+
+
+
+
+

And let's test it.

+
+
+
+
+
+ +
param_grads = compute_param_grad(
+    model=model, 
+    activities=activities, 
+    input=img_batch, 
+    output=label_batch
+)
+
+ +
+
+
+
+
+

Updates¤

+
+
+
+
+
+
+

Before putting everything together, let's wrap our gradients into update functions. This will also allow us to use JAX's jit transformation, which compiles your code the first time it's executed so that subsequent calls run much faster (see https://jax.readthedocs.io/en/latest/jit-compilation.html for more details).

+

These functions take an (Optax) optimiser such as gradient descent in addition to the previous arguments (model, activities and data).
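The update functions below follow the standard Optax pattern, sketched here on a toy parameter vector: initialise an optimiser state from the parameters, transform raw gradients into updates, and then apply them.

```python
import jax.numpy as jnp
import optax

params = jnp.array([1.0, -2.0])
grads = jnp.array([0.1, -0.3])    # pretend these are gradients of the energy

optim = optax.sgd(1e-1)
opt_state = optim.init(params)

updates, opt_state = optim.update(grads, opt_state)   # transform/scale the gradients
params = optax.apply_updates(params, updates)         # params + updates
```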

+
+
+
+
+
+ +
@eqx.filter_jit
+def update_activities(model, activities, optim, opt_state, input, output):
+    activity_grads = compute_activity_grad(
+        model=model,
+        activities=activities,
+        input=input,
+        output=output
+    )
+    activity_updates, activity_opt_state = optim.update(
+        updates=activity_grads,
+        state=opt_state,
+        params=activities
+    )
+    activities = eqx.apply_updates(
+        model=activities,
+        updates=activity_updates
+    )
+    return activities, optim, activity_opt_state  # return the updated optimiser state
+
+
+# note that the only difference with the above function is...
+# ...the variable we are updating (parameters vs activities)
+@eqx.filter_jit
+def update_params(model, activities, optim, opt_state, input, output):
+    param_grads = compute_param_grad(
+        model=model,
+        activities=activities,
+        input=input,
+        output=output
+    )
+    param_updates, param_opt_state = optim.update(
+        updates=param_grads,
+        state=opt_state,
+        params=model
+    )
+    model = eqx.apply_updates(
+        model=model,
+        updates=param_updates
+    )
+    return model, optim, param_opt_state  # return the updated optimiser state
+
+ +
+
+
+
+
+

Putting everything together: Training and testing¤

+

Now that we have our activity and parameter updates, we just need to wrap them in a training and test loop.

+
+
+
+
+
+ +
# note: the accuracy computation below could be sped up...
+# ...with jit in a separate function
+
+def evaluate(model, test_loader):
+    avg_test_acc = 0
+    for test_iter, (img_batch, label_batch) in enumerate(test_loader):
+        img_batch, label_batch = img_batch.numpy(), label_batch.numpy()
+
+        preds = init_activities_with_ffwd(model, img_batch)[-1]
+        test_acc = jnp.mean(
+            jnp.argmax(label_batch, axis=1) == jnp.argmax(preds, axis=1)
+        ) * 100
+        avg_test_acc += test_acc
+
+    return avg_test_acc / len(test_loader)
+
+
+def train(
+      model,
+      activity_lr,
+      inference_steps,
+      param_lr,
+      batch_size,
+      test_every,
+      n_train_iters
+):
+    # define optimisers for activities and parameters
+    activity_optim = optax.sgd(activity_lr)
+    param_optim = optax.adam(param_lr)
+    param_opt_state = param_optim.init(eqx.filter(model, eqx.is_array))
+
+    train_loader, test_loader = get_mnist_loaders(batch_size)
+    for train_iter, (img_batch, label_batch) in enumerate(train_loader):
+        img_batch, label_batch = img_batch.numpy(), label_batch.numpy()
+
+        # initialise activities
+        activities = init_activities_with_ffwd(model, img_batch)
+        activity_opt_state = activity_optim.init(activities)
+
+        train_loss = jnp.mean((label_batch - activities[-1])**2)
+
+        # inference
+        for t in range(inference_steps):
+            activities, activity_optim, activity_opt_state = update_activities(
+                model=model, 
+                activities=activities, 
+                optim=activity_optim, 
+                opt_state=activity_opt_state, 
+                input=img_batch, 
+                output=label_batch
+            )
+
+        # learning
+        model, param_optim, param_opt_state = update_params(
+            model=model,
+            activities=activities,  # note how we use the optimised activities
+            optim=param_optim,
+            opt_state=param_opt_state,
+            input=img_batch,
+            output=label_batch
+        )
+        if ((train_iter+1) % test_every) == 0:
+            avg_test_acc = evaluate(model, test_loader)
+            print(
+                f"Train iter {train_iter+1}, train loss={train_loss:4f}, "
+                f"avg test accuracy={avg_test_acc:4f}"
+            )
+            if (train_iter+1) >= n_train_iters:
+                break
+
+ +
+
+
+
+
+

Run¤

+
+
+
+
+
+
+

Let's test our implementation.

+
+
+
+
+
+ +
train(
+    model=model,
+    activity_lr=ACTIVITY_LR,
+    inference_steps=INFERENCE_STEPS,
+    param_lr=PARAM_LR,
+    batch_size=BATCH_SIZE,
+    test_every=TEST_EVERY,
+    n_train_iters=N_TRAIN_ITERS
+)
+
+ +
+
+
+
+
+
+Train iter 50, train loss=0.065566, avg test accuracy=72.726364
+Train iter 100, train loss=0.046521, avg test accuracy=76.292068
+Train iter 150, train loss=0.042710, avg test accuracy=86.568512
+Train iter 200, train loss=0.029598, avg test accuracy=89.082535
+Train iter 250, train loss=0.031486, avg test accuracy=89.222755
+Train iter 300, train loss=0.016624, avg test accuracy=91.296074
+Train iter 350, train loss=0.025201, avg test accuracy=92.648239
+Train iter 400, train loss=0.018597, avg test accuracy=92.968750
+Train iter 450, train loss=0.019027, avg test accuracy=94.130608
+Train iter 500, train loss=0.014850, avg test accuracy=93.760017
+
+
+
+
+
+
+
+
+
+
+

🥳 Great, we see that our model is training! You can probably improve the performance by tweaking some of the hyperparameters (e.g. try a higher number of inference steps).
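For example, you could simply call train again with more inference steps per batch (the values below are just an illustration). Note that train does not return the trained model and Equinox models are immutable, so this rerun starts again from the same initial weights as before.

```python
train(
    model=model,
    activity_lr=ACTIVITY_LR,
    inference_steps=100,   # 5x more inference iterations per batch
    param_lr=PARAM_LR,
    batch_size=BATCH_SIZE,
    test_every=TEST_EVERY,
    n_train_iters=N_TRAIN_ITERS
)
```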

+

Even if you didn't follow all the implementation details, you should now have at least an idea of how PC works in practice. Indeed, this is basically the core code behind a new PC library our lab will soon release: JPC. Play around with the notebook examples there where you can learn how to train a variety of PC networks.

+
+
+

Theoretical activities of deep linear networks¤

+

Open in Colab

+
+
+
+
+
+ +
%%capture
+!pip install torch==2.3.1
+!pip install torchvision==0.18.1
+!pip install plotly==5.11.0
+!pip install -U kaleido
+
+ +
+
+
+
+ +
import jpc
+
+import jax
+from jax import vmap
+import jax.numpy as jnp
+import equinox as eqx
+import equinox.nn as nn
+import optax
+
+import torch
+from torch.utils.data import DataLoader
+from torchvision import datasets, transforms
+
+import plotly.graph_objs as go
+import plotly.io as pio
+
+pio.renderers.default = 'iframe'
+
+ +
+
+
+
+
+

Hyperparameters¤

+

We define some global parameters, including network architecture, learning rate, batch size etc.

+
+
+
+
+
+ +
SEED = 0
+LEARNING_RATE = 1e-3
+BATCH_SIZE = 64
+TEST_EVERY = 10
+N_TRAIN_ITERS = 20
+
+ +
+
+
+
+
+

Dataset¤

+

Some utils to fetch MNIST.

+
+
+
+
+
+ +
#@title data utils
+
+
+def get_mnist_loaders(batch_size):
+    train_data = MNIST(train=True, normalise=True)
+    test_data = MNIST(train=False, normalise=True)
+    train_loader = DataLoader(
+        dataset=train_data,
+        batch_size=batch_size,
+        shuffle=True,
+        drop_last=True
+    )
+    test_loader = DataLoader(
+        dataset=test_data,
+        batch_size=batch_size,
+        shuffle=True,
+        drop_last=True
+    )
+    return train_loader, test_loader
+
+
+class MNIST(datasets.MNIST):
+    def __init__(self, train, normalise=True, save_dir="data"):
+        if normalise:
+            transform = transforms.Compose(
+                [
+                    transforms.ToTensor(),
+                    transforms.Normalize(
+                        mean=(0.1307), std=(0.3081)
+                    )
+                ]
+            )
+        else:
+            transform = transforms.Compose([transforms.ToTensor()])
+        super().__init__(save_dir, download=True, train=train, transform=transform)
+
+    def __getitem__(self, index):
+        img, label = super().__getitem__(index)
+        img = torch.flatten(img)
+        label = one_hot(label)
+        return img, label
+
+
+def one_hot(labels, n_classes=10):
+    arr = torch.eye(n_classes)
+    return arr[labels]
+
+ +
+
+
+
+
+

Plotting¤

+
+
+
+
+
+ +
def plot_layer_energies(energies):
+    n_train_iters = energies["theory"].shape[0]
+    n_energies = energies["theory"].shape[1]
+    train_iters = [b+1 for b in range(n_train_iters)]
+
+    colors = ['#636EFA', '#EF553B', '#00CC96', '#AB63FA', '#FFA15A', '#19D3F3', '#FF6692', '#B6E880', '#FF97FF', '#FECB52', '#8C564B']
+
+    fig = go.Figure()
+    for n in range(n_energies):
+        fig.add_traces(
+            go.Scatter(
+                x=train_iters,
+                y=energies["theory"][:, n],
+                mode="lines",
+                line=dict(
+                    width=2, 
+                    dash="dash",
+                    color=colors[n % len(colors)]  # wrap around if there are more layers than colours
+                ),
+                showlegend=False
+            )
+        )
+        fig.add_traces(
+            go.Scatter(
+                x=train_iters,
+                y=energies["experiment"][:, n],
+                name=f"$\Large{{\ell_{n+1}}}$",
+                mode="lines",
+                line=dict(
+                    width=3, 
+                    dash="solid",
+                    color=colors[n % len(colors)]  # wrap around if there are more layers than colours
+                ),
+            )
+        )
+
+    fig.update_layout(
+        height=350,
+        width=475,
+        xaxis=dict(
+            title="Training iteration",
+            tickvals=[1, int(train_iters[-1]/2), train_iters[-1]],
+            ticktext=[1, int(train_iters[-1]/2), train_iters[-1]],
+        ),
+        yaxis=dict(
+            title="Energy",
+            nticks=3
+        ),
+        font=dict(size=16),
+    )
+    fig.write_image("dln_layer_energies_example.pdf")
+    return fig
+
+ +
+
+
+
+
+

Linear network¤

+
+
+
+
+
+ +
key = jax.random.PRNGKey(0)
+subkeys = jax.random.split(key, 21)
+
+# 21 linear layers (no biases): 784 -> 300 -> ... -> 300 -> 10
+layer_sizes = [784] + [300] * 20 + [10]
+network = [
+    eqx.nn.Linear(d_in, d_out, key=subkeys[i], use_bias=False)
+    for i, (d_in, d_out) in enumerate(zip(layer_sizes[:-1], layer_sizes[1:]))
+]
+
+ +
+
+
+
+
+

Train and test¤

+

A PC network can be trained in a single line of code with jpc.make_pc_step(). See the documentation for more. Similarly, we can use jpc.test_discriminative_pc() to compute the network accuracy. Note that these functions are already "jitted" for performance.

+

Below we simply wrap each of these functions in our training and test loops, respectively.

+
+
+
+
+
+ +
def evaluate(model, test_loader):
+    test_acc = 0
+    for batch_id, (img_batch, label_batch) in enumerate(test_loader):
+        img_batch = img_batch.numpy()
+        label_batch = label_batch.numpy()
+
+        test_acc += jpc.test_discriminative_pc(
+            model=model,
+            y=label_batch,
+            x=img_batch
+        )
+
+    return test_acc / len(test_loader)
+
+
+def train(
+      model,  
+      lr,
+      batch_size,
+      t1,
+      test_every,
+      n_train_iters
+):
+    optim = optax.adam(lr)
+    opt_state = optim.init(eqx.filter(model, eqx.is_array))
+    train_loader, test_loader = get_mnist_loaders(batch_size)
+
+    As, cond_numbers = [], []
+    num_energies, theory_energies = [], []
+    num_total_energies, theory_total_energies = [], []
+    for iter, (img_batch, label_batch) in enumerate(train_loader):
+        img_batch = img_batch.numpy()
+        label_batch = label_batch.numpy()
+
+        theory_total_energies.append(
+            jpc.linear_equilib_energy(
+                network=model, 
+                x=img_batch, 
+                y=label_batch
+            )
+        )
+        A = jpc.linear_activities_coeff_matrix([l.weight for l in model])
+        As.append(A)
+        cond_numbers.append(jnp.linalg.cond(A))
+        theory_activities = jpc.linear_activities_solution(
+            network=model, 
+            x=img_batch, 
+            y=label_batch
+        )
+        theory_energies.append(
+            jnp.flip(jpc.pc_energy_fn(
+                model,
+                theory_activities,
+                x=img_batch,
+                y=label_batch,
+                record_layers=True
+            ))
+        )
+        result = jpc.make_pc_step(
+            model,
+            optim,
+            opt_state,
+            y=label_batch,
+            x=img_batch,
+            t1=t1,
+            record_energies=True
+        )
+        model, optim, opt_state = result["model"], result["optim"], result["opt_state"]
+        train_loss, t_max = result["loss"], result["t_max"]
+        num_total_energies.append(result["energies"][:, t_max-1].sum())
+        num_energies.append(result["energies"][:, t_max-1])
+
+        if ((iter+1) % test_every) == 0:
+            avg_test_acc = evaluate(model, test_loader)
+            print(
+                f"Train iter {iter+1}, train loss={train_loss:4f}, "
+                f"avg test accuracy={avg_test_acc:4f}"
+            )
+            if (iter+1) >= n_train_iters:
+                break
+
+    return {
+        "experiment": jnp.array(num_energies),
+        "theory": jnp.array(theory_energies)
+    }, As, cond_numbers
+
+ +
+
+
+
+
+

Run¤

+
+
+
+
+
+ +
energies, As, cond_numbers = train(
+    model=network,
+    lr=LEARNING_RATE,
+    batch_size=BATCH_SIZE,
+    t1=300,
+    test_every=TEST_EVERY,
+    n_train_iters=N_TRAIN_ITERS
+)
+
+ +
+
+
+
+
+
+/Users/fi69/PycharmProjects/jpc/venv/lib/python3.10/site-packages/jax/_src/core.py:678: FutureWarning:
+
+unhashable type: <class 'jax._src.interpreters.partial_eval.DynamicJaxprTracer'>. Attempting to hash a tracer will lead to an error in a future JAX release.
+
+
+
+
+
+
+
+
+Train iter 10, train loss=0.091816, avg test accuracy=0.270333
+Train iter 20, train loss=0.089849, avg test accuracy=0.335337
+
+
+
+
+
+
+
+
+
+ +
import matplotlib.pyplot as plt
+
+ +
+
+
+
+ +
plt.plot(cond_numbers)
+
+ +
+
+
+
+
+
+[<matplotlib.lines.Line2D at 0x31bee37c0>]
+
+
+
+
+
+ +
+
+
+
+
+
+
+ +
n_train_iters = energies["theory"].shape[0]
+n_energies = energies["theory"].shape[1]
+train_iters = [b+1 for b in range(n_train_iters)]
+
+# the 11-colour palette, repeated to cover all 21 layers
+colors = [
+    '#636EFA',
+    '#EF553B',
+    '#00CC96',
+    '#AB63FA',
+    '#FFA15A',
+    '#19D3F3',
+    '#FF6692',
+    '#B6E880',
+    '#FF97FF',
+    '#FECB52',
+    '#8C564B',
+] * 3
+
+fig = go.Figure()
+for n in range(n_energies):
+    fig.add_traces(
+        go.Scatter(
+            x=train_iters,
+            y=energies["theory"][:, n],
+            mode="lines",
+            line=dict(
+                width=3, 
+                dash="dash",
+                color=colors[n]
+            ),
+            showlegend=False
+        )
+    )
+    fig.add_traces(
+        go.Scatter(
+            x=train_iters,
+            y=energies["experiment"][:, n],
+            name=f"$\Large{{\ell_{{{n+1}}}}}$",
+            mode="lines",
+            line=dict(
+                width=2, 
+                dash="solid",
+                color=colors[n]
+            ),
+        )
+    )
+
+fig.update_layout(
+    height=400,
+    width=600,
+    xaxis=dict(
+        title="Training iteration",
+        tickvals=[1, int(train_iters[-1]/2), train_iters[-1]],
+        ticktext=[1, int(train_iters[-1]/2), train_iters[-1]],
+    ),
+    yaxis=dict(
+        title="Energy",
+        nticks=3
+    ),
+    font=dict(size=16)
+)
+
+ +
+
+
+
+
+ +
+
+
+
+
+
+
+ +

+
+ +
+
+ + + + + + + +
+
+
+ +
+ + + + +
+
+
+
+ + + + + + + + + + + + + + \ No newline at end of file diff --git a/search/search_index.json b/search/search_index.json index 232eb76..2b841d5 100644 --- a/search/search_index.json +++ b/search/search_index.json @@ -1 +1 @@ -{"config":{"indexing":"full","lang":["en"],"min_search_length":3,"prebuild_index":false,"separator":"[\\s\\-]+"},"docs":[{"location":"","text":"Getting started \u00a4 JPC is a J AX library for training neural networks with P redictive C oding (PC). It is built on top of three main libraries: Equinox , to define neural networks with PyTorch-like syntax, Diffrax , to solve the PC inference (activity) dynamics, and Optax , for parameter optimisation. JPC provides a simple , relatively fast and flexible API for training of a variety of PCNs including discriminative, generative and hybrid models. Like JAX, JPC is completely functional, and the core library is <1000 lines of code. Unlike existing implementations, JPC leverages ordinary differential equation (ODE) solvers to integrate the inference dynamics of PC networks (PCNs), which we find can provide significant speed-ups compared to standard optimisers, especially for deeper models. JPC also provides some analytical tools that can be used to study and diagnose issues with PCNs. \ud83d\udcbb Installation \u00a4 pip install jpc Requires Python 3.9+, JAX 0.4.23+, Equinox 0.11.2+, Diffrax 0.6.0+, and Optax 0.2.4+. For GPU usage, upgrade jax to the appropriate cuda version (12 as an example here). pip install --upgrade \"jax[cuda12]\" \u26a1\ufe0f Quick example \u00a4 Use jpc.make_pc_step to update the parameters of any neural network compatible with PC updates (see examples) import jax.random as jr import jax.numpy as jnp import equinox as eqx import optax import jpc # toy data x = jnp . array ([ 1. , 1. , 1. ]) y = - x # define model and optimiser key = jr . PRNGKey ( 0 ) model = jpc . make_mlp ( key , layer_sizes = [ 3 , 5 , 5 , 3 ], act_fn = \"relu\" ) optim = optax . adam ( 1e-3 ) opt_state = optim . init ( ( eqx . filter ( model , eqx . is_array ), None ) ) # perform one training step with PC result = jpc . make_pc_step ( model = model , optim = optim , opt_state = opt_state , output = y , input = x ) # updated model and optimiser model = result [ \"model\" ] optim , opt_state = result [ \"optim\" ], result [ \"opt_state\" ] Under the hood, jpc.make_pc_step integrates the inference (activity) dynamics using a Diffrax ODE solver, and updates model parameters at the numerical solution of the activities with a given Optax optimiser. NOTE : All convenience training and test functions including make_pc_step are already \"jitted\" (for increased performance) for the user's convenience. \ud83e\udde0\ufe0f Predictive coding primer \u00a4 ... \ud83d\ude80 Advanced usage \u00a4 Advanced users can access all the underlying functions of jpc.make_pc_step as well as additional features. A custom PC training step looks like the following: import jpc # 1. initialise activities with a feedforward pass activities = jpc . init_activities_with_ffwd ( model = model , input = x ) # 2. run inference to equilibrium equilibrated_activities = jpc . solve_inference ( params = ( model , None ), activities = activities , output = y , input = x ) # 3. update parameters at the activities' solution with PC result = jpc . update_params ( params = ( model , None ), activities = equilibrated_activities , optim = optim , opt_state = opt_state , output = y , input = x ) which can be embedded in a jitted function with any other additional computations. 
\ud83d\udcc4 Citation \u00a4 If you found this library useful in your work, please cite (arXiv link): @article { innocenti2024jpc , title = {JPC: Flexible Inference for Predictive Coding Networks in JAX} , author = {Innocenti, Francesco and Kinghorn, Paul and Yun-Farmbrough, Will and Singh, Ryan and De Llanza Varona, Miguel and Buckley, Christopher} , journal = {arXiv preprint} , year = {2024} } Also consider starring the project on GitHub ! \u2b50\ufe0f \u23ed\ufe0f Next steps \u00a4","title":"Getting started"},{"location":"#getting-started","text":"JPC is a J AX library for training neural networks with P redictive C oding (PC). It is built on top of three main libraries: Equinox , to define neural networks with PyTorch-like syntax, Diffrax , to solve the PC inference (activity) dynamics, and Optax , for parameter optimisation. JPC provides a simple , relatively fast and flexible API for training of a variety of PCNs including discriminative, generative and hybrid models. Like JAX, JPC is completely functional, and the core library is <1000 lines of code. Unlike existing implementations, JPC leverages ordinary differential equation (ODE) solvers to integrate the inference dynamics of PC networks (PCNs), which we find can provide significant speed-ups compared to standard optimisers, especially for deeper models. JPC also provides some analytical tools that can be used to study and diagnose issues with PCNs.","title":"Getting started"},{"location":"#installation","text":"pip install jpc Requires Python 3.9+, JAX 0.4.23+, Equinox 0.11.2+, Diffrax 0.6.0+, and Optax 0.2.4+. For GPU usage, upgrade jax to the appropriate cuda version (12 as an example here). pip install --upgrade \"jax[cuda12]\"","title":"\ud83d\udcbb Installation"},{"location":"#quick-example","text":"Use jpc.make_pc_step to update the parameters of any neural network compatible with PC updates (see examples) import jax.random as jr import jax.numpy as jnp import equinox as eqx import optax import jpc # toy data x = jnp . array ([ 1. , 1. , 1. ]) y = - x # define model and optimiser key = jr . PRNGKey ( 0 ) model = jpc . make_mlp ( key , layer_sizes = [ 3 , 5 , 5 , 3 ], act_fn = \"relu\" ) optim = optax . adam ( 1e-3 ) opt_state = optim . init ( ( eqx . filter ( model , eqx . is_array ), None ) ) # perform one training step with PC result = jpc . make_pc_step ( model = model , optim = optim , opt_state = opt_state , output = y , input = x ) # updated model and optimiser model = result [ \"model\" ] optim , opt_state = result [ \"optim\" ], result [ \"opt_state\" ] Under the hood, jpc.make_pc_step integrates the inference (activity) dynamics using a Diffrax ODE solver, and updates model parameters at the numerical solution of the activities with a given Optax optimiser. NOTE : All convenience training and test functions including make_pc_step are already \"jitted\" (for increased performance) for the user's convenience.","title":"\u26a1\ufe0f Quick example"},{"location":"#predictive-coding-primer","text":"...","title":"\ud83e\udde0\ufe0f Predictive coding primer"},{"location":"#advanced-usage","text":"Advanced users can access all the underlying functions of jpc.make_pc_step as well as additional features. A custom PC training step looks like the following: import jpc # 1. initialise activities with a feedforward pass activities = jpc . init_activities_with_ffwd ( model = model , input = x ) # 2. run inference to equilibrium equilibrated_activities = jpc . 
solve_inference ( params = ( model , None ), activities = activities , output = y , input = x ) # 3. update parameters at the activities' solution with PC result = jpc . update_params ( params = ( model , None ), activities = equilibrated_activities , optim = optim , opt_state = opt_state , output = y , input = x ) which can be embedded in a jitted function with any other additional computations.","title":"\ud83d\ude80 Advanced usage"},{"location":"#citation","text":"If you found this library useful in your work, please cite (arXiv link): @article { innocenti2024jpc , title = {JPC: Flexible Inference for Predictive Coding Networks in JAX} , author = {Innocenti, Francesco and Kinghorn, Paul and Yun-Farmbrough, Will and Singh, Ryan and De Llanza Varona, Miguel and Buckley, Christopher} , journal = {arXiv preprint} , year = {2024} } Also consider starring the project on GitHub ! \u2b50\ufe0f","title":"\ud83d\udcc4 Citation"},{"location":"#next-steps","text":"","title":"\u23ed\ufe0f Next steps"},{"location":"FAQs/","text":"","title":"FAQs"},{"location":"advanced_usage/","text":"","title":"Advanced usage"},{"location":"basic_usage/","text":"Info JPC provides two types of API depending on the use case: a simple, basic API that allows to train and test models with predictive coding with a few lines of code a more advanced and flexible API allowing for Describe purposes/use cases of both basic and advanced. Basic usage \u00a4 JPC provides a single convenience function jpc.make_pc_step() to train predictive coding networks (PCNs) on classification and generation tasks, in a supervised as well as unsupervised manner. import jpc relu_net = jpc . get_fc_network ( key , [ 10 , 100 , 100 , 10 ], \"relu\" ) result = jpc . make_pc_step ( model = relu_net , optim = optim , opt_state = opt_state , y = y , x = x ) At a minimum, jpc.make_pc_step() takes a model, an optax optimiser and its state, and an output target. Under the hood, jpc.make_pc_step() uses diffrax to solve the activity (inference) dynamics of PC. The arguments can be changed import jpc result = jpc . make_pc_step ( model = network , optim = optim , opt_state = opt_state , y = y , x = x , solver = other_solver , dt = 1e-1 , ) Moreover, JPC provides a similar function for training a hybrid PCN import jax import jax.numpy as jnp from equinox import nn as nn # some datasets x = jnp . array ([ 1. , 1. , 1. ]) y = - x # network key = jax . random . key ( 0 ) _ , * subkeys = jax . random . split ( key ) network = [ nn . Sequential ( [ nn . Linear ( 3 , 100 , key = subkeys [ 0 ]), nn . Lambda ( jax . nn . relu )], ), nn . Linear ( 100 , 3 , key = subkeys [ 1 ]), ]","title":"Basic usage"},{"location":"basic_usage/#basic-usage","text":"JPC provides a single convenience function jpc.make_pc_step() to train predictive coding networks (PCNs) on classification and generation tasks, in a supervised as well as unsupervised manner. import jpc relu_net = jpc . get_fc_network ( key , [ 10 , 100 , 100 , 10 ], \"relu\" ) result = jpc . make_pc_step ( model = relu_net , optim = optim , opt_state = opt_state , y = y , x = x ) At a minimum, jpc.make_pc_step() takes a model, an optax optimiser and its state, and an output target. Under the hood, jpc.make_pc_step() uses diffrax to solve the activity (inference) dynamics of PC. The arguments can be changed import jpc result = jpc . 
make_pc_step ( model = network , optim = optim , opt_state = opt_state , y = y , x = x , solver = other_solver , dt = 1e-1 , ) Moreover, JPC provides a similar function for training a hybrid PCN import jax import jax.numpy as jnp from equinox import nn as nn # some datasets x = jnp . array ([ 1. , 1. , 1. ]) y = - x # network key = jax . random . key ( 0 ) _ , * subkeys = jax . random . split ( key ) network = [ nn . Sequential ( [ nn . Linear ( 3 , 100 , key = subkeys [ 0 ]), nn . Lambda ( jax . nn . relu )], ), nn . Linear ( 100 , 3 , key = subkeys [ 1 ]), ]","title":"Basic usage"},{"location":"extending_jpc/","text":"","title":"Extending JPC"},{"location":"api/Analytical%20tools/","text":"Analytical tools \u00a4 jpc . linear_equilib_energy ( network : PyTree [ equinox . nn . _linear . Linear ], x : ArrayLike , y : ArrayLike ) -> Array \u00a4 Computes the theoretical equilibrated PC energy for a deep linear network (DLN). \\[ \\mathcal{F}^* = 1/N \\sum_i^N (\\mathbf{y}_i - W_{L:1}\\mathbf{x}_i)^T S^{-1}(\\mathbf{y}_i - W_{L:1}\\mathbf{x}_i) \\] where the rescaling is \\(S = I_{d_y} + \\sum_{\\ell=2}^L (W_{L:\\ell})(W_{L:\\ell})^T\\) , and we use the shorthand \\(W_{L:\\ell} = W_L W_{L-1} \\dots W_\\ell\\) . See reference below. Note This expression assumes no biases. Reference @article { innocenti2024only , title = {Only Strict Saddles in the Energy Landscape of Predictive Coding Networks?} , author = {Innocenti, Francesco and Achour, El Mehdi and Singh, Ryan and Buckley, Christopher L} , journal = {arXiv preprint arXiv:2408.11979} , year = {2024} } Main arguments: network : Linear network defined as a list of Equinox Linear layers. x : Network input. y : Network output. Returns: Mean total analytical energy over a batch or dataset. jpc . linear_activities_solution ( network : PyTree [ equinox . nn . _linear . Linear ], x : ArrayLike , y : ArrayLike ) -> PyTree [ Array ] \u00a4 Computes the theoretical solution for the PC activities of a deep linear network (DLN). \\[ \\mathbf{z}^* = A^{-1} \\mathbf{b} \\] where \\(A\\) is a sparse block diagonal matrix depending only on the weights, and \\(\\mathbf{b} = [W_1 \\mathbf{x}, \\mathbf{0}, \\dots, W_L^T \\mathbf{y}]^T\\) . In particular, \\(A_{\\ell,k} = I + W_\\ell^T W_\\ell\\) if \\(\\ell = k\\) , \\(A_{\\ell,k} = -W_\\ell\\) if \\(\\ell = k+1\\) , \\(A_{\\ell,k} = -W_\\ell^T\\) if \\(\\ell = k-1\\) , and \\(\\mathbf{0}\\) otherwise, for \\(\\ell, k \\in [2, \\dots, L]\\) . Note This expression assumes no biases. Main arguments: network : Linear network defined as a list of Equinox Linear layers. x : Network input. y : Network output. Returns: List of theoretical activities for each layer.","title":"Analytical tools"},{"location":"api/Analytical%20tools/#analytical-tools","text":"","title":"Analytical tools"},{"location":"api/Analytical%20tools/#jpc.linear_equilib_energy","text":"Computes the theoretical equilibrated PC energy for a deep linear network (DLN). \\[ \\mathcal{F}^* = 1/N \\sum_i^N (\\mathbf{y}_i - W_{L:1}\\mathbf{x}_i)^T S^{-1}(\\mathbf{y}_i - W_{L:1}\\mathbf{x}_i) \\] where the rescaling is \\(S = I_{d_y} + \\sum_{\\ell=2}^L (W_{L:\\ell})(W_{L:\\ell})^T\\) , and we use the shorthand \\(W_{L:\\ell} = W_L W_{L-1} \\dots W_\\ell\\) . See reference below. Note This expression assumes no biases. 
Reference @article { innocenti2024only , title = {Only Strict Saddles in the Energy Landscape of Predictive Coding Networks?} , author = {Innocenti, Francesco and Achour, El Mehdi and Singh, Ryan and Buckley, Christopher L} , journal = {arXiv preprint arXiv:2408.11979} , year = {2024} } Main arguments: network : Linear network defined as a list of Equinox Linear layers. x : Network input. y : Network output. Returns: Mean total analytical energy over a batch or dataset.","title":"linear_equilib_energy()"},{"location":"api/Analytical%20tools/#jpc.linear_activities_solution","text":"Computes the theoretical solution for the PC activities of a deep linear network (DLN). \\[ \\mathbf{z}^* = A^{-1} \\mathbf{b} \\] where \\(A\\) is a sparse block diagonal matrix depending only on the weights, and \\(\\mathbf{b} = [W_1 \\mathbf{x}, \\mathbf{0}, \\dots, W_L^T \\mathbf{y}]^T\\) . In particular, \\(A_{\\ell,k} = I + W_\\ell^T W_\\ell\\) if \\(\\ell = k\\) , \\(A_{\\ell,k} = -W_\\ell\\) if \\(\\ell = k+1\\) , \\(A_{\\ell,k} = -W_\\ell^T\\) if \\(\\ell = k-1\\) , and \\(\\mathbf{0}\\) otherwise, for \\(\\ell, k \\in [2, \\dots, L]\\) . Note This expression assumes no biases. Main arguments: network : Linear network defined as a list of Equinox Linear layers. x : Network input. y : Network output. Returns: List of theoretical activities for each layer.","title":"linear_activities_solution()"},{"location":"api/Energy%20functions/","text":"Energy functions \u00a4 jpc . pc_energy_fn ( params : Tuple [ PyTree [ Callable ], Optional [ PyTree [ Callable ]]], activities : PyTree [ ArrayLike ], y : ArrayLike , x : Optional [ ArrayLike ] = None , loss : str = 'MSE' , record_layers : bool = False ) -> Array | Array \u00a4 Computes the free energy for a feedforward neural network of the form \\[ \\mathcal{F}(\\mathbf{z}; \u03b8) = 1/N \\sum_i^N \\sum_{\\ell=1}^L || \\mathbf{z}_{i, \\ell} - f_\\ell(\\mathbf{z}_{i, \\ell-1}; \u03b8) ||^2 \\] given parameters \\(\u03b8\\) , free activities \\(\\mathbf{z}\\) , output \\(\\mathbf{z}_L = \\mathbf{y}\\) and optional input \\(\\mathbf{z}_0 = \\mathbf{x}\\) for supervised training. The activity of each layer \\(\\mathbf{z}_\\ell\\) is some function of the previous layer, e.g. ReLU \\((W_\\ell \\mathbf{z}_{\\ell-1} + \\mathbf{b}_\\ell)\\) for a fully connected layer with biases and ReLU as activation. Note The input \\(x\\) and output \\(y\\) correspond to the prior and observation of the generative model, respectively. Main arguments: params : Tuple with callable model (e.g. neural network) layers and optional skip connections. activities : List of activities for each layer free to vary. y : Observation or target of the generative model. x : Optional prior of the generative model (for supervised training). Other arguments: loss : Loss function to use at the output layer (mean squared error 'MSE' vs cross-entropy 'CE'). ??? cite \"Reference\" @article { tscshantz2023hybrid , title = {Hybrid predictive coding: Inferring, fast and slow} , author = {Tscshantz, Alexander and Millidge, Beren and Seth, Anil K and Buckley, Christopher L} , journal = {PLoS Computational Biology} , volume = {19} , number = {8} , pages = {e1011280} , year = {2023} , publisher = {Public Library of Science San Francisco, CA USA} } - record_layers : If True , returns energies for each layer. Returns: The total or layer-wise energy normalised by the batch size. jpc . hpc_energy_fn ( model : PyTree [ typing . 
Callable ], equilib_activities : PyTree [ ArrayLike ], amort_activities : PyTree [ ArrayLike ], x : ArrayLike , y : Optional [ ArrayLike ] = None , record_layers : bool = False ) -> Array | Array \u00a4 Computes the free energy of an amortised PC network \\[ \\mathcal{F}(\\mathbf{z}^*, \\hat{\\mathbf{z}}; \u03b8) = 1/N \\sum_i^N \\sum_{\\ell=1}^L || \\mathbf{z}^*_{i, \\ell} - f_\\ell(\\hat{\\mathbf{z}}_{i, \\ell-1}; \u03b8) ||^2 \\] given the equilibrated activities of the generator \\(\\mathbf{z}^*\\) (target for the amortiser), the feedforward guesses of the amortiser \\(\\hat{\\mathbf{z}}\\) , the amortiser's parameters \\(\u03b8\\) , input \\(\\mathbf{z}_0 = \\mathbf{x}\\) , and optional output \\(\\mathbf{z}_L = \\mathbf{y}\\) for supervised training. Note The input \\(x\\) and output \\(y\\) are reversed compared to pc_energy_fn ( \\(x\\) is the generator's target and \\(y\\) is its optional input or prior). Just think of \\(x\\) and \\(y\\) as the actual input and output of the amortiser, respectively. Main arguments: model : List of callable model (e.g. neural network) layers. equilib_activities : List of equilibrated activities reached by the generator and target for the amortiser. amort_activities : List of amortiser's feedforward guesses (initialisation) for the network activities. x : Input to the amortiser. y : Optional target of the amortiser (for supervised training). Other arguments: record_layers : If True , returns energies for each layer. Returns: The total or layer-wise energy normalised by batch size.","title":"Energy functions"},{"location":"api/Energy%20functions/#energy-functions","text":"","title":"Energy functions"},{"location":"api/Energy%20functions/#jpc.pc_energy_fn","text":"Computes the free energy for a feedforward neural network of the form \\[ \\mathcal{F}(\\mathbf{z}; \u03b8) = 1/N \\sum_i^N \\sum_{\\ell=1}^L || \\mathbf{z}_{i, \\ell} - f_\\ell(\\mathbf{z}_{i, \\ell-1}; \u03b8) ||^2 \\] given parameters \\(\u03b8\\) , free activities \\(\\mathbf{z}\\) , output \\(\\mathbf{z}_L = \\mathbf{y}\\) and optional input \\(\\mathbf{z}_0 = \\mathbf{x}\\) for supervised training. The activity of each layer \\(\\mathbf{z}_\\ell\\) is some function of the previous layer, e.g. ReLU \\((W_\\ell \\mathbf{z}_{\\ell-1} + \\mathbf{b}_\\ell)\\) for a fully connected layer with biases and ReLU as activation. Note The input \\(x\\) and output \\(y\\) correspond to the prior and observation of the generative model, respectively. Main arguments: params : Tuple with callable model (e.g. neural network) layers and optional skip connections. activities : List of activities for each layer free to vary. y : Observation or target of the generative model. x : Optional prior of the generative model (for supervised training). Other arguments: loss : Loss function to use at the output layer (mean squared error 'MSE' vs cross-entropy 'CE'). ??? cite \"Reference\" @article { tscshantz2023hybrid , title = {Hybrid predictive coding: Inferring, fast and slow} , author = {Tscshantz, Alexander and Millidge, Beren and Seth, Anil K and Buckley, Christopher L} , journal = {PLoS Computational Biology} , volume = {19} , number = {8} , pages = {e1011280} , year = {2023} , publisher = {Public Library of Science San Francisco, CA USA} } - record_layers : If True , returns energies for each layer. 
Returns: The total or layer-wise energy normalised by the batch size.","title":"pc_energy_fn()"},{"location":"api/Energy%20functions/#jpc.hpc_energy_fn","text":"Computes the free energy of an amortised PC network \\[ \\mathcal{F}(\\mathbf{z}^*, \\hat{\\mathbf{z}}; \u03b8) = 1/N \\sum_i^N \\sum_{\\ell=1}^L || \\mathbf{z}^*_{i, \\ell} - f_\\ell(\\hat{\\mathbf{z}}_{i, \\ell-1}; \u03b8) ||^2 \\] given the equilibrated activities of the generator \\(\\mathbf{z}^*\\) (target for the amortiser), the feedforward guesses of the amortiser \\(\\hat{\\mathbf{z}}\\) , the amortiser's parameters \\(\u03b8\\) , input \\(\\mathbf{z}_0 = \\mathbf{x}\\) , and optional output \\(\\mathbf{z}_L = \\mathbf{y}\\) for supervised training. Note The input \\(x\\) and output \\(y\\) are reversed compared to pc_energy_fn ( \\(x\\) is the generator's target and \\(y\\) is its optional input or prior). Just think of \\(x\\) and \\(y\\) as the actual input and output of the amortiser, respectively. Main arguments: model : List of callable model (e.g. neural network) layers. equilib_activities : List of equilibrated activities reached by the generator and target for the amortiser. amort_activities : List of amortiser's feedforward guesses (initialisation) for the network activities. x : Input to the amortiser. y : Optional target of the amortiser (for supervised training). Other arguments: record_layers : If True , returns energies for each layer. Returns: The total or layer-wise energy normalised by batch size.","title":"hpc_energy_fn()"},{"location":"api/Gradients/","text":"Gradients \u00a4 jpc . neg_activity_grad ( t : float | int , activities : PyTree [ ArrayLike ], args : Tuple [ Tuple [ PyTree [ Callable ], Optional [ PyTree [ Callable ]]], ArrayLike , Optional [ ArrayLike ], str , diffrax . _step_size_controller . base . AbstractStepSizeController ]) -> PyTree [ Array ] \u00a4 Computes the negative gradient of the energy with respect to the activities \\(- \\partial \\mathcal{F} / \\partial \\mathbf{z}\\) . This defines an ODE system to be integrated by solve_pc_inference . Main arguments: t : Time step of the ODE system, used for downstream integration by diffrax.diffeqsolve . activities : List of activities for each layer free to vary. args : 5-Tuple with (i) Tuple with callable model layers and optional skip connections, (ii) network output (observation), (iii) network input (prior), (iv) Loss specified at the output layer (MSE vs cross-entropy), and (v) diffrax controller for step size integration. Returns: List of negative gradients of the energy w.r.t. the activities. jpc . compute_pc_param_grads ( params : Tuple [ PyTree [ Callable ], Optional [ PyTree [ Callable ]]], activities : PyTree [ ArrayLike ], y : ArrayLike , x : Optional [ ArrayLike ] = None , loss_id : str = 'MSE' ) -> Tuple [ PyTree [ Array ], PyTree [ Array ]] \u00a4 Computes the gradient of the PC energy with respect to model parameters \\(\\partial \\mathcal{F} / \\partial \u03b8\\) . Main arguments: params : Tuple with callable model layers and optional skip connections. activities : List of activities for each layer free to vary. y : Observation or target of the generative model. x : Optional prior of the generative model. Other arguments: loss_id : Loss function for the output layer (mean squared error 'MSE' vs cross-entropy 'CE'). Returns: List of parameter gradients for each network layer. jpc . compute_hpc_param_grads ( model : PyTree [ typing . 
Callable ], equilib_activities : PyTree [ ArrayLike ], amort_activities : PyTree [ ArrayLike ], x : ArrayLike , y : Optional [ ArrayLike ] = None ) -> PyTree [ Array ] \u00a4 Computes the gradient of the hybrid energy with respect to an amortiser's parameters \\(\\partial \\mathcal{F} / \\partial \u03b8\\) . Main arguments: model : List of callable model (e.g. neural network) layers. equilib_activities : List of equilibrated activities reached by the generator and target for the amortiser. amort_activities : List of amortiser's feedforward guesses (initialisation) for the network activities. x : Input to the amortiser. y : Optional target of the amortiser (for supervised training). Note The input \\(x\\) and output \\(y\\) are reversed compared to compute_pc_param_grads ( \\(x\\) is the generator's target and \\(y\\) is its optional input or prior). Just think of \\(x\\) and \\(y\\) as the actual input and output of the amortiser, respectively. Returns: List of parameter gradients for each network layer.","title":"Gradients"},{"location":"api/Gradients/#gradients","text":"","title":"Gradients"},{"location":"api/Gradients/#jpc.neg_activity_grad","text":"Computes the negative gradient of the energy with respect to the activities \\(- \\partial \\mathcal{F} / \\partial \\mathbf{z}\\) . This defines an ODE system to be integrated by solve_pc_inference . Main arguments: t : Time step of the ODE system, used for downstream integration by diffrax.diffeqsolve . activities : List of activities for each layer free to vary. args : 5-Tuple with (i) Tuple with callable model layers and optional skip connections, (ii) network output (observation), (iii) network input (prior), (iv) Loss specified at the output layer (MSE vs cross-entropy), and (v) diffrax controller for step size integration. Returns: List of negative gradients of the energy w.r.t. the activities.","title":"neg_activity_grad()"},{"location":"api/Gradients/#jpc.compute_pc_param_grads","text":"Computes the gradient of the PC energy with respect to model parameters \\(\\partial \\mathcal{F} / \\partial \u03b8\\) . Main arguments: params : Tuple with callable model layers and optional skip connections. activities : List of activities for each layer free to vary. y : Observation or target of the generative model. x : Optional prior of the generative model. Other arguments: loss_id : Loss function for the output layer (mean squared error 'MSE' vs cross-entropy 'CE'). Returns: List of parameter gradients for each network layer.","title":"compute_pc_param_grads()"},{"location":"api/Gradients/#jpc.compute_hpc_param_grads","text":"Computes the gradient of the hybrid energy with respect to an amortiser's parameters \\(\\partial \\mathcal{F} / \\partial \u03b8\\) . Main arguments: model : List of callable model (e.g. neural network) layers. equilib_activities : List of equilibrated activities reached by the generator and target for the amortiser. amort_activities : List of amortiser's feedforward guesses (initialisation) for the network activities. x : Input to the amortiser. y : Optional target of the amortiser (for supervised training). Note The input \\(x\\) and output \\(y\\) are reversed compared to compute_pc_param_grads ( \\(x\\) is the generator's target and \\(y\\) is its optional input or prior). Just think of \\(x\\) and \\(y\\) as the actual input and output of the amortiser, respectively. 
Returns: List of parameter gradients for each network layer.","title":"compute_hpc_param_grads()"},{"location":"api/Inference/","text":"Inference \u00a4 jpc . solve_inference ( params : Tuple [ PyTree [ Callable ], Optional [ PyTree [ Callable ]]], activities : PyTree [ ArrayLike ], output : ArrayLike , input : Optional [ ArrayLike ] = None , loss_id : str = 'MSE' , solver : AbstractSolver = Heun ( scan_kind = None ), max_t1 : int = 20 , dt : float | int = None , stepsize_controller : AbstractStepSizeController = PIDController ( rtol = 0.001 , atol = 0.001 , pcoeff = 0 , icoeff = 1 , dcoeff = 0 , dtmin = None , dtmax = None , force_dtmin = True , step_ts = None , jump_ts = None , factormin = 0.2 , factormax = 10.0 , norm =< function rms_norm > , safety = 0.9 , error_order = None ), record_iters : bool = False , record_every : int = None ) -> PyTree [ Array ] \u00a4 Solves the inference (activity) dynamics of a predictive coding network. This is a wrapper around diffrax.diffeqsolve to integrate the gradient ODE system _neg_activity_grad defining the PC inference dynamics \\[ \\partial \\mathbf{z} / \\partial t = - \\partial \\mathcal{F} / \\partial \\mathbf{z} \\] where \\(\\mathcal{F}\\) is the free energy, \\(\\mathbf{z}\\) are the activities, with \\(\\mathbf{z}_L\\) clamped to some target and \\(\\mathbf{z}_0\\) optionally equal to some prior. Main arguments: params : Tuple with callable model layers and optional skip connections. activities : List of activities for each layer free to vary. output : Observation or target of the generative model. input : Optional prior of the generative model. Other arguments: loss : Loss function to use at the output layer (mean squared error 'MSE' vs cross-entropy 'CE'). solver : Diffrax (ODE) solver to be used. Default is Heun, a 2nd order explicit Runge--Kutta method. max_t1 : Maximum end of integration region (500 by default). dt : Integration step size. Defaults to None since the default stepsize_controller will automatically determine it. stepsize_controller : diffrax controller for step size integration. Defaults to PIDController . Note that the relative and absolute tolerances of the controller will also determine the steady state to terminate the solver. record_iters : If True , returns all integration steps. record_every : int determining the sampling frequency the integration steps. Returns: List with solution of the activity dynamics for each layer.","title":"Inference"},{"location":"api/Inference/#inference","text":"","title":"Inference"},{"location":"api/Inference/#jpc.solve_inference","text":"Solves the inference (activity) dynamics of a predictive coding network. This is a wrapper around diffrax.diffeqsolve to integrate the gradient ODE system _neg_activity_grad defining the PC inference dynamics \\[ \\partial \\mathbf{z} / \\partial t = - \\partial \\mathcal{F} / \\partial \\mathbf{z} \\] where \\(\\mathcal{F}\\) is the free energy, \\(\\mathbf{z}\\) are the activities, with \\(\\mathbf{z}_L\\) clamped to some target and \\(\\mathbf{z}_0\\) optionally equal to some prior. Main arguments: params : Tuple with callable model layers and optional skip connections. activities : List of activities for each layer free to vary. output : Observation or target of the generative model. input : Optional prior of the generative model. Other arguments: loss : Loss function to use at the output layer (mean squared error 'MSE' vs cross-entropy 'CE'). solver : Diffrax (ODE) solver to be used. Default is Heun, a 2nd order explicit Runge--Kutta method. 
max_t1 : Maximum end of integration region (500 by default). dt : Integration step size. Defaults to None since the default stepsize_controller will automatically determine it. stepsize_controller : diffrax controller for step size integration. Defaults to PIDController . Note that the relative and absolute tolerances of the controller will also determine the steady state to terminate the solver. record_iters : If True , returns all integration steps. record_every : int determining the sampling frequency the integration steps. Returns: List with solution of the activity dynamics for each layer.","title":"solve_inference()"},{"location":"api/Initialisation/","text":"Initialisation \u00a4 jpc . init_activities_with_ffwd ( model : PyTree [ typing . Callable ], input : ArrayLike , skip_model : Optional [ PyTree [ Callable ]] = None ) -> PyTree [ Array ] \u00a4 Initialises layers' activity with a feedforward pass \\(\\{ f_\\ell(\\mathbf{z}_{\\ell-1}) \\}_{\\ell=1}^L\\) where \\(\\mathbf{z}_0 = \\mathbf{x}\\) is the input. Main arguments: model : List of callable model (e.g. neural network) layers. input : input to the model. Other arguments: skip_model : Optional skip connection model. Returns: List with activity values of each layer. jpc . init_activities_from_normal ( key : PRNGKeyArray , layer_sizes : PyTree [ int ], mode : str , batch_size : int , sigma : Array = 0.05 ) -> PyTree [ Array ] \u00a4 Initialises network activities from a zero-mean Gaussian \\(\\sim \\mathcal{N}(0, \\sigma^2)\\) . Main arguments: key : jax.random.PRNGKey for sampling. layer_sizes : List with dimension of all layers (input, hidden and output). mode : If supervised , all hidden layers are initialised. If unsupervised the input layer \\(\\mathbf{z}_0\\) is also initialised. batch_size : Dimension of data batch. sigma : Standard deviation for Gaussian to sample activities from. Defaults to 5e-2. Returns: List of randomly initialised activities for each layer. jpc . init_activities_with_amort ( amortiser : PyTree [ typing . Callable ], generator : PyTree [ typing . Callable ], input : ArrayLike ) -> PyTree [ Array ] \u00a4 Initialises layers' activity with an amortised network \\(\\{ f_{L-\\ell+1}(\\mathbf{z}_{L-\\ell}) \\}_{\\ell=1}^L\\) where \\(\\mathbf{z}_0 = \\mathbf{y}\\) is the input or generator's target. Note The output order is reversed for downstream use by the generator. Main arguments: amortiser : List of callable layers for model amortising the inference of the generator . generator : List of callable layers for the generative model. input : Input to the amortiser. Returns: List with amortised initialisation of each layer.","title":"Initialisation"},{"location":"api/Initialisation/#initialisation","text":"","title":"Initialisation"},{"location":"api/Initialisation/#jpc.init_activities_with_ffwd","text":"Initialises layers' activity with a feedforward pass \\(\\{ f_\\ell(\\mathbf{z}_{\\ell-1}) \\}_{\\ell=1}^L\\) where \\(\\mathbf{z}_0 = \\mathbf{x}\\) is the input. Main arguments: model : List of callable model (e.g. neural network) layers. input : input to the model. Other arguments: skip_model : Optional skip connection model. Returns: List with activity values of each layer.","title":"init_activities_with_ffwd()"},{"location":"api/Initialisation/#jpc.init_activities_from_normal","text":"Initialises network activities from a zero-mean Gaussian \\(\\sim \\mathcal{N}(0, \\sigma^2)\\) . Main arguments: key : jax.random.PRNGKey for sampling. 
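A short, hypothetical sketch of the Gaussian initialisation described above, covering both modes; the exact mode strings and the generative layer ordering are assumptions based on the wording of the documentation.

import jax
import jpc

key = jax.random.PRNGKey(0)
layer_sizes = [10, 300, 300, 784]   # e.g. a generator mapping labels to images

# Supervised mode: all hidden layers are initialised
# (input and output will be clamped to prior and target).
activities_sup = jpc.init_activities_from_normal(
    key=key, layer_sizes=layer_sizes, mode="supervised", batch_size=64, sigma=0.05
)

# Unsupervised mode: the input layer z_0 is initialised (and free to vary) too.
activities_unsup = jpc.init_activities_from_normal(
    key=key, layer_sizes=layer_sizes, mode="unsupervised", batch_size=64, sigma=0.05
)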
layer_sizes : List with dimension of all layers (input, hidden and output). mode : If supervised , all hidden layers are initialised. If unsupervised the input layer \\(\\mathbf{z}_0\\) is also initialised. batch_size : Dimension of data batch. sigma : Standard deviation for Gaussian to sample activities from. Defaults to 5e-2. Returns: List of randomly initialised activities for each layer.","title":"init_activities_from_normal()"},{"location":"api/Initialisation/#jpc.init_activities_with_amort","text":"Initialises layers' activity with an amortised network \\(\\{ f_{L-\\ell+1}(\\mathbf{z}_{L-\\ell}) \\}_{\\ell=1}^L\\) where \\(\\mathbf{z}_0 = \\mathbf{y}\\) is the input or generator's target. Note The output order is reversed for downstream use by the generator. Main arguments: amortiser : List of callable layers for model amortising the inference of the generator . generator : List of callable layers for the generative model. input : Input to the amortiser. Returns: List with amortised initialisation of each layer.","title":"init_activities_with_amort()"},{"location":"api/Testing/","text":"Testing \u00a4 jpc . test_discriminative_pc ( model : PyTree [ typing . Callable ], output : ArrayLike , input : ArrayLike , loss : str = 'MSE' , skip_model : Optional [ PyTree [ Callable ]] = None ) -> Tuple [ Array , Array ] \u00a4 Computes test metrics for a discriminative predictive coding network. Main arguments: model : List of callable model (e.g. neural network) layers. output : Observation or target of the generative model. input : Optional prior of the generative model. Other arguments: loss : - loss : Loss function to use at the output layer (mean squared error 'MSE' vs cross-entropy 'CE'). skip_model : Optional list of callable skip connection functions. Returns: Test loss and accuracy of output predictions. jpc . test_generative_pc ( model : PyTree [ typing . Callable ], output : ArrayLike , input : ArrayLike , key : PRNGKeyArray , layer_sizes : PyTree [ int ], batch_size : int , sigma : Array = 0.05 , ode_solver : AbstractSolver = Heun ( scan_kind = None ), max_t1 : int = 500 , dt : Array | int = None , stepsize_controller : AbstractStepSizeController = PIDController ( rtol = 0.001 , atol = 0.001 , pcoeff = 0 , icoeff = 1 , dcoeff = 0 , dtmin = None , dtmax = None , force_dtmin = True , step_ts = None , jump_ts = None , factormin = 0.2 , factormax = 10.0 , norm =< function rms_norm > , safety = 0.9 , error_order = None ), skip_model : Optional [ PyTree [ Callable ]] = None ) -> Tuple [ Array , Array ] \u00a4 Computes test metrics for a generative predictive coding network. Gets output predictions (e.g. of an image given a label) with a feedforward pass and calculates accuracy of inferred input (e.g. of a label given an image). Main arguments: model : List of callable model (e.g. neural network) layers. output : Observation or target of the generative model. input : Optional prior of the generative model. key : jax.random.PRNGKey for random initialisation of activities. layer_sizes : Dimension of all layers (input, hidden and output). batch_size : Dimension of data batch for activity initialisation. Other arguments: sigma : Standard deviation for Gaussian to sample activities from. Defaults to 5e-2. ode_solver : Diffrax ODE solver to be used. Default is Heun, a 2nd order explicit Runge--Kutta method. max_t1 : Maximum end of integration region (500 by default). dt : Integration step size. Defaults to None since the default stepsize_controller will automatically determine it. 
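A hedged single-batch sketch of the generative test function described above; the generator layout and the dummy arrays stand in for real MNIST batches, and only the documented arguments are passed.

import jax
import jax.numpy as jnp
import jpc

key = jax.random.PRNGKey(0)
layer_sizes = [10, 300, 300, 784]                 # generator: label -> image
generator = jpc.make_mlp(key, layer_sizes, act_fn="relu")

img_batch = jnp.ones((64, 784))                              # dummy observations
label_batch = jax.nn.one_hot(jnp.zeros(64, dtype=int), 10)   # dummy one-hot labels

# Accuracy of the inferred input (labels given images) and feedforward
# output predictions (images given labels).
accuracy, img_preds = jpc.test_generative_pc(
    model=generator,
    output=img_batch,
    input=label_batch,
    key=key,
    layer_sizes=layer_sizes,
    batch_size=64,
)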
stepsize_controller : diffrax controller for step size integration. Defaults to PIDController . Note that the relative and absolute tolerances of the controller will also determine the steady state to terminate the solver. Returns: Accuracy and output predictions. jpc . test_hpc ( generator : PyTree [ typing . Callable ], amortiser : PyTree [ typing . Callable ], output : ArrayLike , input : ArrayLike , key : PRNGKeyArray , layer_sizes : PyTree [ int ], batch_size : int , sigma : Array = 0.05 , ode_solver : AbstractSolver = Heun ( scan_kind = None ), max_t1 : int = 500 , dt : Array | int = None , stepsize_controller : AbstractStepSizeController = PIDController ( rtol = 0.001 , atol = 0.001 , pcoeff = 0 , icoeff = 1 , dcoeff = 0 , dtmin = None , dtmax = None , force_dtmin = True , step_ts = None , jump_ts = None , factormin = 0.2 , factormax = 10.0 , norm =< function rms_norm > , safety = 0.9 , error_order = None )) -> Tuple [ Array , Array , Array , Array ] \u00a4 Computes test metrics for hybrid predictive coding trained in a supervised manner. Calculates input accuracy of (i) amortiser, (ii) generator, and (iii) hybrid (amortiser + generator). Also returns output predictions (e.g. of an image given a label) with a feedforward pass of the generator. Note The input and output of the generator are the output and input of the amortiser, respectively. Main arguments: generator : List of callable layers for the generative model. amortiser : List of callable layers for model amortising the inference of the generator . output : Observation or target of the generative model. input : Optional prior of the generator, target for the amortiser. key : jax.random.PRNGKey for random initialisation of activities. layer_sizes : Dimension of all layers (input, hidden and output). batch_size : Dimension of data batch for initialisation of activities. Other arguments: sigma : Standard deviation for Gaussian to sample activities from. Defaults to 5e-2. ode_solver : Diffrax ODE solver to be used. Default is Heun, a 2nd order explicit Runge--Kutta method. max_t1 : Maximum end of integration region (500 by default). dt : Integration step size. Defaults to None since the default stepsize_controller will automatically determine it. stepsize_controller : diffrax controller for step size integration. Defaults to PIDController . Note that the relative and absolute tolerances of the controller will also determine the steady state to terminate the solver. Returns: Accuracies of all models and output predictions.","title":"Testing"},{"location":"api/Testing/#testing","text":"","title":"Testing"},{"location":"api/Testing/#jpc.test_discriminative_pc","text":"Computes test metrics for a discriminative predictive coding network. Main arguments: model : List of callable model (e.g. neural network) layers. output : Observation or target of the generative model. input : Optional prior of the generative model. Other arguments: loss : - loss : Loss function to use at the output layer (mean squared error 'MSE' vs cross-entropy 'CE'). skip_model : Optional list of callable skip connection functions. Returns: Test loss and accuracy of output predictions.","title":"test_discriminative_pc()"},{"location":"api/Testing/#jpc.test_generative_pc","text":"Computes test metrics for a generative predictive coding network. Gets output predictions (e.g. of an image given a label) with a feedforward pass and calculates accuracy of inferred input (e.g. of a label given an image). Main arguments: model : List of callable model (e.g. 
neural network) layers. output : Observation or target of the generative model. input : Optional prior of the generative model. key : jax.random.PRNGKey for random initialisation of activities. layer_sizes : Dimension of all layers (input, hidden and output). batch_size : Dimension of data batch for activity initialisation. Other arguments: sigma : Standard deviation for Gaussian to sample activities from. Defaults to 5e-2. ode_solver : Diffrax ODE solver to be used. Default is Heun, a 2nd order explicit Runge--Kutta method. max_t1 : Maximum end of integration region (500 by default). dt : Integration step size. Defaults to None since the default stepsize_controller will automatically determine it. stepsize_controller : diffrax controller for step size integration. Defaults to PIDController . Note that the relative and absolute tolerances of the controller will also determine the steady state to terminate the solver. Returns: Accuracy and output predictions.","title":"test_generative_pc()"},{"location":"api/Testing/#jpc.test_hpc","text":"Computes test metrics for hybrid predictive coding trained in a supervised manner. Calculates input accuracy of (i) amortiser, (ii) generator, and (iii) hybrid (amortiser + generator). Also returns output predictions (e.g. of an image given a label) with a feedforward pass of the generator. Note The input and output of the generator are the output and input of the amortiser, respectively. Main arguments: generator : List of callable layers for the generative model. amortiser : List of callable layers for model amortising the inference of the generator . output : Observation or target of the generative model. input : Optional prior of the generator, target for the amortiser. key : jax.random.PRNGKey for random initialisation of activities. layer_sizes : Dimension of all layers (input, hidden and output). batch_size : Dimension of data batch for initialisation of activities. Other arguments: sigma : Standard deviation for Gaussian to sample activities from. Defaults to 5e-2. ode_solver : Diffrax ODE solver to be used. Default is Heun, a 2nd order explicit Runge--Kutta method. max_t1 : Maximum end of integration region (500 by default). dt : Integration step size. Defaults to None since the default stepsize_controller will automatically determine it. stepsize_controller : diffrax controller for step size integration. Defaults to PIDController . Note that the relative and absolute tolerances of the controller will also determine the steady state to terminate the solver. Returns: Accuracies of all models and output predictions.","title":"test_hpc()"},{"location":"api/Training/","text":"Training \u00a4 jpc . make_pc_step ( model : PyTree [ typing . Callable ], optim : optax . _src . base . GradientTransformation | optax . _src . base . GradientTransformationExtraArgs , opt_state : Union [ jax . Array , numpy . ndarray , numpy . bool , numpy . 
number , Iterable [ ArrayTree ], Mapping [ Any , ArrayTree ]], output : ArrayLike , input : Optional [ ArrayLike ] = None , loss_id : str = 'MSE' , ode_solver : AbstractSolver = Heun ( scan_kind = None ), max_t1 : int = 20 , dt : Array | int = None , stepsize_controller : AbstractStepSizeController = PIDController ( rtol = 0.001 , atol = 0.001 , pcoeff = 0 , icoeff = 1 , dcoeff = 0 , dtmin = None , dtmax = None , force_dtmin = True , step_ts = None , jump_ts = None , factormin = 0.2 , factormax = 10.0 , norm =< function rms_norm > , safety = 0.9 , error_order = None ), skip_model : Optional [ PyTree [ Callable ]] = None , key : Optional [ PRNGKeyArray ] = None , layer_sizes : Optional [ PyTree [ int ]] = None , batch_size : Optional [ int ] = None , sigma : Array = 0.05 , record_activities : bool = False , record_energies : bool = False , record_every : int = None , activity_norms : bool = False , param_norms : bool = False , grad_norms : bool = False , calculate_accuracy : bool = False ) -> Dict \u00a4 Updates network parameters with predictive coding. Main arguments: model : List of callable model (e.g. neural network) layers. optim : Optax optimiser, e.g. optax.sgd() . opt_state : State of Optax optimiser. output : Observation or target of the generative model. input : Optional prior of the generative model. Note key , layer_sizes and batch_size must be passed if input is None , since unsupervised training will be assumed and activities need to be initialised randomly. Other arguments: loss_id : Loss function for the output layer (mean squared error 'MSE' vs cross-entropy 'CE'). ode_solver : Diffrax ODE solver to be used. Default is Heun, a 2nd order explicit Runge--Kutta method. max_t1 : Maximum end of integration region (20 by default). dt : Integration step size. Defaults to None since the default stepsize_controller will automatically determine it. stepsize_controller : diffrax controller for step size integration. Defaults to PIDController . Note that the relative and absolute tolerances of the controller will also determine the steady state to terminate the solver. skip_model : Optional list of callable skip connection functions. key : jax.random.PRNGKey for random initialisation of activities. layer_sizes : Dimension of all layers (input, hidden and output). batch_size : Dimension of data batch for activity initialisation. sigma : Standard deviation for Gaussian to sample activities from for random initialisation. Defaults to 5e-2. record_activities : If True , returns activities at every inference iteration. record_energies : If True , returns layer-wise energies at every inference iteration. record_every : int determining the sampling frequency the integration steps. activity_norms : If True , computes l2 norm of the activities. param_norms : If True , computes l2 norm of the parameters. grad_norms : If True , computes l2 norm of parameter gradients. calculate_accuracy : If True , computes the training accuracy. Returns: Dict including model (and optional skip model) with updated parameters, optimiser, updated optimiser state, loss, energies, activities, and optionally other metrics (see other args above). Raises: ValueError for inconsistent inputs and invalid losses. jpc . make_hpc_step ( generator : PyTree [ typing . Callable ], amortiser : PyTree [ typing . Callable ], optims : Tuple [ optax . _src . base . GradientTransformationExtraArgs ], opt_states : Tuple [ Union [ jax . Array , numpy . ndarray , numpy . bool , numpy . 
number , Iterable [ ArrayTree ], Mapping [ Any , ArrayTree ]]], output : ArrayLike , input : Optional [ ArrayLike ] = None , ode_solver : AbstractSolver = Heun ( scan_kind = None ), max_t1 : int = 300 , dt : Array | int = None , stepsize_controller : AbstractStepSizeController = PIDController ( rtol = 0.001 , atol = 0.001 , pcoeff = 0 , icoeff = 1 , dcoeff = 0 , dtmin = None , dtmax = None , force_dtmin = True , step_ts = None , jump_ts = None , factormin = 0.2 , factormax = 10.0 , norm =< function rms_norm > , safety = 0.9 , error_order = None ), record_activities : bool = False , record_energies : bool = False ) -> Dict \u00a4 Updates parameters of a hybrid predictive coding network. Reference @article { tscshantz2023hybrid , title = {Hybrid predictive coding: Inferring, fast and slow} , author = {Tscshantz, Alexander and Millidge, Beren and Seth, Anil K and Buckley, Christopher L} , journal = {PLoS Computational Biology} , volume = {19} , number = {8} , pages = {e1011280} , year = {2023} , publisher = {Public Library of Science San Francisco, CA USA} } Note The input and output of the generator are the output and input of the amortiser, respectively. Main arguments: generator : List of callable layers for the generative model. amortiser : List of callable layers for model amortising the inference of the generator . optims : Optax optimisers (e.g. optax.sgd() ), one for each model. opt_states : State of Optax optimisers, one for each model. output : Observation of the generator, input to the amortiser. input : Optional prior of the generator, target for the amortiser. Other arguments: ode_solver : Diffrax ODE solver to be used. Default is Heun, a 2nd order explicit Runge--Kutta method.. max_t1 : Maximum end of integration region (300 by default). dt : Integration step size. Defaults to None since the default stepsize_controller will automatically determine it. stepsize_controller : diffrax controller for step size integration. Defaults to PIDController . Note that the relative and absolute tolerances of the controller will also determine the steady state to terminate the solver. record_activities : If True , returns activities at every inference iteration. record_energies : If True , returns layer-wise energies at every inference iteration. Returns: Dict including models with updated parameters, optimiser and state for each model, model activities, last inference step for the generator, MSE losses, and energies.","title":"Training"},{"location":"api/Training/#training","text":"","title":"Training"},{"location":"api/Training/#jpc.make_pc_step","text":"Updates network parameters with predictive coding. Main arguments: model : List of callable model (e.g. neural network) layers. optim : Optax optimiser, e.g. optax.sgd() . opt_state : State of Optax optimiser. output : Observation or target of the generative model. input : Optional prior of the generative model. Note key , layer_sizes and batch_size must be passed if input is None , since unsupervised training will be assumed and activities need to be initialised randomly. Other arguments: loss_id : Loss function for the output layer (mean squared error 'MSE' vs cross-entropy 'CE'). ode_solver : Diffrax ODE solver to be used. Default is Heun, a 2nd order explicit Runge--Kutta method. max_t1 : Maximum end of integration region (20 by default). dt : Integration step size. Defaults to None since the default stepsize_controller will automatically determine it. stepsize_controller : diffrax controller for step size integration. 
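Since the worked examples further down only show the supervised case, here is a hedged sketch of the unsupervised mode mentioned in the note above (input=None with key, layer_sizes and batch_size supplied). The result dictionary keys follow the discriminative PC example below; everything else is illustrative.

import jax
import jax.numpy as jnp
import optax
import equinox as eqx
import jpc

key = jax.random.PRNGKey(0)
layer_sizes = [10, 300, 300, 784]          # generative model: latent -> image
model = jpc.make_mlp(key, layer_sizes, act_fn="tanh")

optim = optax.adam(1e-3)
opt_state = optim.init((eqx.filter(model, eqx.is_array), None))

img_batch = jnp.ones((64, 784))            # dummy unlabelled images

# Unsupervised PC step: no input/prior, so the activities (including z_0)
# are initialised randomly from a zero-mean Gaussian.
result = jpc.make_pc_step(
    model,
    optim,
    opt_state,
    output=img_batch,
    input=None,
    key=key,
    layer_sizes=layer_sizes,
    batch_size=64,
)
model, opt_state = result["model"], result["opt_state"]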
Defaults to PIDController . Note that the relative and absolute tolerances of the controller will also determine the steady state to terminate the solver. skip_model : Optional list of callable skip connection functions. key : jax.random.PRNGKey for random initialisation of activities. layer_sizes : Dimension of all layers (input, hidden and output). batch_size : Dimension of data batch for activity initialisation. sigma : Standard deviation for Gaussian to sample activities from for random initialisation. Defaults to 5e-2. record_activities : If True , returns activities at every inference iteration. record_energies : If True , returns layer-wise energies at every inference iteration. record_every : int determining the sampling frequency the integration steps. activity_norms : If True , computes l2 norm of the activities. param_norms : If True , computes l2 norm of the parameters. grad_norms : If True , computes l2 norm of parameter gradients. calculate_accuracy : If True , computes the training accuracy. Returns: Dict including model (and optional skip model) with updated parameters, optimiser, updated optimiser state, loss, energies, activities, and optionally other metrics (see other args above). Raises: ValueError for inconsistent inputs and invalid losses.","title":"make_pc_step()"},{"location":"api/Training/#jpc.make_hpc_step","text":"Updates parameters of a hybrid predictive coding network. Reference @article { tscshantz2023hybrid , title = {Hybrid predictive coding: Inferring, fast and slow} , author = {Tscshantz, Alexander and Millidge, Beren and Seth, Anil K and Buckley, Christopher L} , journal = {PLoS Computational Biology} , volume = {19} , number = {8} , pages = {e1011280} , year = {2023} , publisher = {Public Library of Science San Francisco, CA USA} } Note The input and output of the generator are the output and input of the amortiser, respectively. Main arguments: generator : List of callable layers for the generative model. amortiser : List of callable layers for model amortising the inference of the generator . optims : Optax optimisers (e.g. optax.sgd() ), one for each model. opt_states : State of Optax optimisers, one for each model. output : Observation of the generator, input to the amortiser. input : Optional prior of the generator, target for the amortiser. Other arguments: ode_solver : Diffrax ODE solver to be used. Default is Heun, a 2nd order explicit Runge--Kutta method.. max_t1 : Maximum end of integration region (300 by default). dt : Integration step size. Defaults to None since the default stepsize_controller will automatically determine it. stepsize_controller : diffrax controller for step size integration. Defaults to PIDController . Note that the relative and absolute tolerances of the controller will also determine the steady state to terminate the solver. record_activities : If True , returns activities at every inference iteration. record_energies : If True , returns layer-wise energies at every inference iteration. Returns: Dict including models with updated parameters, optimiser and state for each model, model activities, last inference step for the generator, MSE losses, and energies.","title":"make_hpc_step()"},{"location":"api/Updates/","text":"Updates \u00a4 jpc . update_activities ( params : Tuple [ PyTree [ Callable ], Optional [ PyTree [ Callable ]]], activities : PyTree [ ArrayLike ], optim : optax . _src . base . GradientTransformation | optax . _src . base . GradientTransformationExtraArgs , opt_state : Union [ jax . Array , numpy . 
ndarray , numpy . bool , numpy . number , Iterable [ ArrayTree ], Mapping [ Any , ArrayTree ]], output : ArrayLike , input : Optional [ ArrayLike ] = None ) -> Dict \u00a4 Updates activities of a predictive coding network with a given Optax optimiser. Main arguments: params : Tuple with callable model layers and optional skip connections. activities : List of activities for each layer free to vary. optim : Optax optimiser, e.g. optax.sgd() . opt_state : State of Optax optimiser. output : Observation or target of the generative model. input : Optional prior of the generative model. Returns: Dictionary with energy, updated activities, activity gradients, optimiser, and updated optimiser state. jpc . update_params ( params : Tuple [ PyTree [ Callable ], Optional [ PyTree [ Callable ]]], activities : PyTree [ ArrayLike ], optim : optax . _src . base . GradientTransformation | optax . _src . base . GradientTransformationExtraArgs , opt_state : Union [ jax . Array , numpy . ndarray , numpy . bool , numpy . number , Iterable [ ArrayTree ], Mapping [ Any , ArrayTree ]], output : ArrayLike , input : Optional [ ArrayLike ] = None ) -> Dict \u00a4 Updates parameters of a predictive coding network with a given Optax optimiser. Main arguments: params : Tuple with callable model layers and optional skip connections. activities : List of activities for each layer free to vary. optim : Optax optimiser, e.g. optax.sgd() . opt_state : State of Optax optimiser. output : Observation or target of the generative model. input : Optional prior of the generative model. Returns: Dictionary with model (and optional skip model) with updated parameters, parameter gradients, optimiser, and updated optimiser state.","title":"Updates"},{"location":"api/Updates/#updates","text":"","title":"Updates"},{"location":"api/Updates/#jpc.update_activities","text":"Updates activities of a predictive coding network with a given Optax optimiser. Main arguments: params : Tuple with callable model layers and optional skip connections. activities : List of activities for each layer free to vary. optim : Optax optimiser, e.g. optax.sgd() . opt_state : State of Optax optimiser. output : Observation or target of the generative model. input : Optional prior of the generative model. Returns: Dictionary with energy, updated activities, activity gradients, optimiser, and updated optimiser state.","title":"update_activities()"},{"location":"api/Updates/#jpc.update_params","text":"Updates parameters of a predictive coding network with a given Optax optimiser. Main arguments: params : Tuple with callable model layers and optional skip connections. activities : List of activities for each layer free to vary. optim : Optax optimiser, e.g. optax.sgd() . opt_state : State of Optax optimiser. output : Observation or target of the generative model. input : Optional prior of the generative model. Returns: Dictionary with model (and optional skip model) with updated parameters, parameter gradients, optimiser, and updated optimiser state.","title":"update_params()"},{"location":"api/make_mlp/","text":"make_mlp \u00a4 jpc . make_mlp ( key : PRNGKeyArray , layer_sizes : PyTree [ int ], act_fn : str , use_bias : bool = True ) -> PyTree [ typing . Callable ] \u00a4 Creates a multi-layer perceptron compatible with predictive coding updates. Main arguments: key : jax.random.PRNGKey for parameter initialisation. layer_sizes : Dimension of all layers (input, hidden and output). Options are linear , tanh and relu . 
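For finer control than make_pc_step, the two update functions above can be chained into a manual inference-then-learning step. The sketch below is an assumption-heavy illustration: the keys used to read the returned dictionaries ("activities", "model", "opt_state") are hypothetical, since the documentation describes the dictionaries' contents but not their key names, and the optimiser-state setup mirrors the examples below.

import jax
import jax.numpy as jnp
import optax
import equinox as eqx
import jpc

key = jax.random.PRNGKey(0)
model = jpc.make_mlp(key, [784, 300, 300, 10], act_fn="relu")
params = (model, None)                            # (layers, optional skip model)

x = jnp.ones((64, 784))                           # dummy inputs
y = jax.nn.one_hot(jnp.zeros(64, dtype=int), 10)  # dummy targets

activity_optim, param_optim = optax.sgd(1e-1), optax.adam(1e-3)
activities = jpc.init_activities_with_ffwd(model=model, input=x)
activity_opt_state = activity_optim.init(activities)
param_opt_state = param_optim.init((eqx.filter(model, eqx.is_array), None))

# Inference phase: several activity updates at fixed parameters.
for _ in range(20):
    out = jpc.update_activities(
        params=params,
        activities=activities,
        optim=activity_optim,
        opt_state=activity_opt_state,
        output=y,
        input=x,
    )
    # NOTE: the key names below are assumptions about the returned dict.
    activities = out["activities"]
    activity_opt_state = out["opt_state"]

# Learning phase: one parameter update at the relaxed activities.
out = jpc.update_params(
    params=params,
    activities=activities,
    optim=param_optim,
    opt_state=param_opt_state,
    output=y,
    input=x,
)
model = out["model"]                # assumed key name; rebuild params = (model, None)
param_opt_state = out["opt_state"]  # assumed key name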
act_fn : Activation function for all layers except the output. use_bias : True by default. Returns: List of callable fully connected layers.","title":"make_mlp"},{"location":"api/make_mlp/#make_mlp","text":"","title":"make_mlp"},{"location":"api/make_mlp/#jpc.make_mlp","text":"Creates a multi-layer perceptron compatible with predictive coding updates. Main arguments: key : jax.random.PRNGKey for parameter initialisation. layer_sizes : Dimension of all layers (input, hidden and output). Options are linear , tanh and relu . act_fn : Activation function for all layers except the output. use_bias : True by default. Returns: List of callable fully connected layers.","title":"make_mlp()"},{"location":"examples/discriminative_pc/","text":"(function() { function addWidgetsRenderer() { var requireJsScript = document.createElement('script'); requireJsScript.src = 'https://cdnjs.cloudflare.com/ajax/libs/require.js/2.3.4/require.min.js'; var mimeElement = document.querySelector('script[type=\"application/vnd.jupyter.widget-view+json\"]'); var jupyterWidgetsScript = document.createElement('script'); var widgetRendererSrc = 'https://unpkg.com/@jupyter-widgets/html-manager@*/dist/embed-amd.js'; var widgetState; // Fallback for older version: try { widgetState = mimeElement && JSON.parse(mimeElement.innerHTML); if (widgetState && (widgetState.version_major < 2 || !widgetState.version_major)) { widgetRendererSrc = 'jupyter-js-widgets@*/dist/embed.js'; } } catch(e) {} jupyterWidgetsScript.src = widgetRendererSrc; document.body.appendChild(requireJsScript); document.body.appendChild(jupyterWidgetsScript); } document.addEventListener('DOMContentLoaded', addWidgetsRenderer); }()); Discriminative PC on MNIST \u00a4 This notebook demonstrates how to train a neural network with predictive coding (PC) to discriminate or classify MNIST digits. %% capture ! pip install torch == 2.3.1 ! pip install torchvision == 0.18.1 import jpc import jax import equinox as eqx import equinox.nn as nn import optax import torch from torch.utils.data import DataLoader from torchvision import datasets , transforms import warnings warnings . simplefilter ( 'ignore' ) # ignore warnings Hyperparameters \u00a4 We define some global parameters, including network architecture, learning rate, batch size etc. SEED = 0 LAYER_SIZES = [ 784 , 300 , 300 , 10 ] ACT_FN = \"relu\" LEARNING_RATE = 1e-3 BATCH_SIZE = 64 TEST_EVERY = 100 N_TRAIN_ITERS = 300 Dataset \u00a4 Some utils to fetch MNIST. #@title data utils def get_mnist_loaders ( batch_size ): train_data = MNIST ( train = True , normalise = True ) test_data = MNIST ( train = False , normalise = True ) train_loader = DataLoader ( dataset = train_data , batch_size = batch_size , shuffle = True , drop_last = True ) test_loader = DataLoader ( dataset = test_data , batch_size = batch_size , shuffle = True , drop_last = True ) return train_loader , test_loader class MNIST ( datasets . MNIST ): def __init__ ( self , train , normalise = True , save_dir = \"data\" ): if normalise : transform = transforms . Compose ( [ transforms . ToTensor (), transforms . Normalize ( mean = ( 0.1307 ), std = ( 0.3081 ) ) ] ) else : transform = transforms . Compose ([ transforms . ToTensor ()]) super () . __init__ ( save_dir , download = True , train = train , transform = transform ) def __getitem__ ( self , index ): img , label = super () . __getitem__ ( index ) img = torch . flatten ( img ) label = one_hot ( label ) return img , label def one_hot ( labels , n_classes = 10 ): arr = torch . 
eye ( n_classes ) return arr [ labels ] Network \u00a4 For jpc to work, we need to provide a network with callable layers. This is easy to do with the PyTorch-like nn.Sequential() in Equinox . For example, we can define a ReLU MLP with two hidden layers as follows key = jax . random . PRNGKey ( SEED ) _ , * subkeys = jax . random . split ( key , 4 ) network = [ nn . Sequential ( [ nn . Linear ( 784 , 300 , key = subkeys [ 0 ]), nn . Lambda ( jax . nn . relu ) ], ), nn . Sequential ( [ nn . Linear ( 300 , 300 , key = subkeys [ 1 ]), nn . Lambda ( jax . nn . relu ) ], ), nn . Linear ( 300 , 10 , key = subkeys [ 2 ]), ] print ( network ) [Sequential( layers=( Linear( weight=f32[300,784], bias=f32[300], in_features=784, out_features=300, use_bias=True ), Lambda(fn=) ) ), Sequential( layers=( Linear( weight=f32[300,300], bias=f32[300], in_features=300, out_features=300, use_bias=True ), Lambda(fn=) ) ), Linear( weight=f32[10,300], bias=f32[10], in_features=300, out_features=10, use_bias=True )] You can also use the utility jpc.get_fc_network to define an MLP or fully connected network with some activation functions. network = jpc . make_mlp ( key , LAYER_SIZES , act_fn = \"relu\" ) print ( network ) [Sequential( layers=( Linear( weight=f32[300,784], bias=f32[300], in_features=784, out_features=300, use_bias=True ), Lambda(fn=) ) ), Sequential( layers=( Linear( weight=f32[300,300], bias=f32[300], in_features=300, out_features=300, use_bias=True ), Lambda(fn=) ) ), Sequential( layers=( Linear( weight=f32[10,300], bias=f32[10], in_features=300, out_features=10, use_bias=True ), Lambda(fn=Identity()) ) )] Train and test \u00a4 A PC network can be trained in a single line of code with jpc.make_pc_step() . See the documentation for more. Similarly, we can use jpc.test_discriminative_pc() to compute the network accuracy. Note that these functions are already \"jitted\" for performance. Below we simply wrap each of these functions in our training and test loops, respectively. def evaluate ( model , test_loader ): avg_test_loss , avg_test_acc = 0 , 0 for batch_id , ( img_batch , label_batch ) in enumerate ( test_loader ): img_batch , label_batch = img_batch . numpy (), label_batch . numpy () test_loss , test_acc = jpc . test_discriminative_pc ( model = model , input = img_batch , output = label_batch ) avg_test_loss += test_loss avg_test_acc += test_acc return avg_test_loss / len ( test_loader ), avg_test_acc / len ( test_loader ) def train ( model , lr , batch_size , test_every , n_train_iters ): optim = optax . adam ( lr ) opt_state = optim . init ( ( eqx . filter ( model , eqx . is_array ), None ) ) train_loader , test_loader = get_mnist_loaders ( batch_size ) for iter , ( img_batch , label_batch ) in enumerate ( train_loader ): img_batch , label_batch = img_batch . numpy (), label_batch . numpy () result = jpc . make_pc_step ( model , optim , opt_state , output = label_batch , input = img_batch ) model , optim , opt_state = result [ \"model\" ], result [ \"optim\" ], result [ \"opt_state\" ] train_loss = result [ \"loss\" ] if (( iter + 1 ) % test_every ) == 0 : avg_test_loss , avg_test_acc = evaluate ( model , test_loader ) print ( f \"Train iter { iter + 1 } , train loss= { train_loss : 4f } , \" f \"avg test accuracy= { avg_test_acc : 4f } \" ) if ( iter + 1 ) >= n_train_iters : break Run \u00a4 import warnings with warnings . catch_warnings (): warnings . 
simplefilter ( 'ignore' ) train ( model = network , lr = LEARNING_RATE , batch_size = BATCH_SIZE , test_every = TEST_EVERY , n_train_iters = N_TRAIN_ITERS ) Train iter 100, train loss=0.018149, avg test accuracy=93.790062 Train iter 200, train loss=0.012088, avg test accuracy=95.142227 Train iter 300, train loss=0.016424, avg test accuracy=95.723160","title":"Discriminative PC"},{"location":"examples/discriminative_pc/#discriminative-pc-on-mnist","text":"This notebook demonstrates how to train a neural network with predictive coding (PC) to discriminate or classify MNIST digits. %% capture ! pip install torch == 2.3.1 ! pip install torchvision == 0.18.1 import jpc import jax import equinox as eqx import equinox.nn as nn import optax import torch from torch.utils.data import DataLoader from torchvision import datasets , transforms import warnings warnings . simplefilter ( 'ignore' ) # ignore warnings","title":"Discriminative PC on MNIST"},{"location":"examples/discriminative_pc/#hyperparameters","text":"We define some global parameters, including network architecture, learning rate, batch size etc. SEED = 0 LAYER_SIZES = [ 784 , 300 , 300 , 10 ] ACT_FN = \"relu\" LEARNING_RATE = 1e-3 BATCH_SIZE = 64 TEST_EVERY = 100 N_TRAIN_ITERS = 300","title":"Hyperparameters"},{"location":"examples/discriminative_pc/#dataset","text":"Some utils to fetch MNIST. #@title data utils def get_mnist_loaders ( batch_size ): train_data = MNIST ( train = True , normalise = True ) test_data = MNIST ( train = False , normalise = True ) train_loader = DataLoader ( dataset = train_data , batch_size = batch_size , shuffle = True , drop_last = True ) test_loader = DataLoader ( dataset = test_data , batch_size = batch_size , shuffle = True , drop_last = True ) return train_loader , test_loader class MNIST ( datasets . MNIST ): def __init__ ( self , train , normalise = True , save_dir = \"data\" ): if normalise : transform = transforms . Compose ( [ transforms . ToTensor (), transforms . Normalize ( mean = ( 0.1307 ), std = ( 0.3081 ) ) ] ) else : transform = transforms . Compose ([ transforms . ToTensor ()]) super () . __init__ ( save_dir , download = True , train = train , transform = transform ) def __getitem__ ( self , index ): img , label = super () . __getitem__ ( index ) img = torch . flatten ( img ) label = one_hot ( label ) return img , label def one_hot ( labels , n_classes = 10 ): arr = torch . eye ( n_classes ) return arr [ labels ]","title":"Dataset"},{"location":"examples/discriminative_pc/#network","text":"For jpc to work, we need to provide a network with callable layers. This is easy to do with the PyTorch-like nn.Sequential() in Equinox . For example, we can define a ReLU MLP with two hidden layers as follows key = jax . random . PRNGKey ( SEED ) _ , * subkeys = jax . random . split ( key , 4 ) network = [ nn . Sequential ( [ nn . Linear ( 784 , 300 , key = subkeys [ 0 ]), nn . Lambda ( jax . nn . relu ) ], ), nn . Sequential ( [ nn . Linear ( 300 , 300 , key = subkeys [ 1 ]), nn . Lambda ( jax . nn . relu ) ], ), nn . 
Linear ( 300 , 10 , key = subkeys [ 2 ]), ] print ( network ) [Sequential( layers=( Linear( weight=f32[300,784], bias=f32[300], in_features=784, out_features=300, use_bias=True ), Lambda(fn=) ) ), Sequential( layers=( Linear( weight=f32[300,300], bias=f32[300], in_features=300, out_features=300, use_bias=True ), Lambda(fn=) ) ), Linear( weight=f32[10,300], bias=f32[10], in_features=300, out_features=10, use_bias=True )] You can also use the utility jpc.get_fc_network to define an MLP or fully connected network with some activation functions. network = jpc . make_mlp ( key , LAYER_SIZES , act_fn = \"relu\" ) print ( network ) [Sequential( layers=( Linear( weight=f32[300,784], bias=f32[300], in_features=784, out_features=300, use_bias=True ), Lambda(fn=) ) ), Sequential( layers=( Linear( weight=f32[300,300], bias=f32[300], in_features=300, out_features=300, use_bias=True ), Lambda(fn=) ) ), Sequential( layers=( Linear( weight=f32[10,300], bias=f32[10], in_features=300, out_features=10, use_bias=True ), Lambda(fn=Identity()) ) )]","title":"Network"},{"location":"examples/discriminative_pc/#train-and-test","text":"A PC network can be trained in a single line of code with jpc.make_pc_step() . See the documentation for more. Similarly, we can use jpc.test_discriminative_pc() to compute the network accuracy. Note that these functions are already \"jitted\" for performance. Below we simply wrap each of these functions in our training and test loops, respectively. def evaluate ( model , test_loader ): avg_test_loss , avg_test_acc = 0 , 0 for batch_id , ( img_batch , label_batch ) in enumerate ( test_loader ): img_batch , label_batch = img_batch . numpy (), label_batch . numpy () test_loss , test_acc = jpc . test_discriminative_pc ( model = model , input = img_batch , output = label_batch ) avg_test_loss += test_loss avg_test_acc += test_acc return avg_test_loss / len ( test_loader ), avg_test_acc / len ( test_loader ) def train ( model , lr , batch_size , test_every , n_train_iters ): optim = optax . adam ( lr ) opt_state = optim . init ( ( eqx . filter ( model , eqx . is_array ), None ) ) train_loader , test_loader = get_mnist_loaders ( batch_size ) for iter , ( img_batch , label_batch ) in enumerate ( train_loader ): img_batch , label_batch = img_batch . numpy (), label_batch . numpy () result = jpc . make_pc_step ( model , optim , opt_state , output = label_batch , input = img_batch ) model , optim , opt_state = result [ \"model\" ], result [ \"optim\" ], result [ \"opt_state\" ] train_loss = result [ \"loss\" ] if (( iter + 1 ) % test_every ) == 0 : avg_test_loss , avg_test_acc = evaluate ( model , test_loader ) print ( f \"Train iter { iter + 1 } , train loss= { train_loss : 4f } , \" f \"avg test accuracy= { avg_test_acc : 4f } \" ) if ( iter + 1 ) >= n_train_iters : break","title":"Train and test"},{"location":"examples/discriminative_pc/#run","text":"import warnings with warnings . catch_warnings (): warnings . 
simplefilter ( 'ignore' ) train ( model = network , lr = LEARNING_RATE , batch_size = BATCH_SIZE , test_every = TEST_EVERY , n_train_iters = N_TRAIN_ITERS ) Train iter 100, train loss=0.018149, avg test accuracy=93.790062 Train iter 200, train loss=0.012088, avg test accuracy=95.142227 Train iter 300, train loss=0.016424, avg test accuracy=95.723160","title":"Run"},{"location":"examples/hybrid_pc/","text":"(function() { function addWidgetsRenderer() { var requireJsScript = document.createElement('script'); requireJsScript.src = 'https://cdnjs.cloudflare.com/ajax/libs/require.js/2.3.4/require.min.js'; var mimeElement = document.querySelector('script[type=\"application/vnd.jupyter.widget-view+json\"]'); var jupyterWidgetsScript = document.createElement('script'); var widgetRendererSrc = 'https://unpkg.com/@jupyter-widgets/html-manager@*/dist/embed-amd.js'; var widgetState; // Fallback for older version: try { widgetState = mimeElement && JSON.parse(mimeElement.innerHTML); if (widgetState && (widgetState.version_major < 2 || !widgetState.version_major)) { widgetRendererSrc = 'jupyter-js-widgets@*/dist/embed.js'; } } catch(e) {} jupyterWidgetsScript.src = widgetRendererSrc; document.body.appendChild(requireJsScript); document.body.appendChild(jupyterWidgetsScript); } document.addEventListener('DOMContentLoaded', addWidgetsRenderer); }()); Hybrid PC on MNIST \u00a4 This notebook demonstrates how to train a hybrid predictive coding network that can both generate and classify MNIST digits. %% capture ! pip install torch == 2.3.1 ! pip install torchvision == 0.18.1 ! pip install matplotlib == 3.0.0 import jpc import jax import jax.numpy as jnp import equinox as eqx import equinox.nn as nn import optax import torch from torch.utils.data import DataLoader from torchvision import datasets , transforms import matplotlib.pyplot as plt import warnings warnings . simplefilter ( 'ignore' ) # ignore warnings Hyperparameters \u00a4 We define some global parameters, including network architecture, learning rate, batch size etc. SEED = 0 LAYER_SIZES = [ 10 , 300 , 300 , 784 ] ACT_FN = \"relu\" LEARNING_RATE = 1e-3 BATCH_SIZE = 64 MAX_T1 = 50 TEST_EVERY = 100 N_TRAIN_ITERS = 300 Dataset \u00a4 Some utils to fetch MNIST. #@title data utils def get_mnist_loaders ( batch_size ): train_data = MNIST ( train = True , normalise = True ) test_data = MNIST ( train = False , normalise = True ) train_loader = DataLoader ( dataset = train_data , batch_size = batch_size , shuffle = True , drop_last = True ) test_loader = DataLoader ( dataset = test_data , batch_size = batch_size , shuffle = True , drop_last = True ) return train_loader , test_loader class MNIST ( datasets . MNIST ): def __init__ ( self , train , normalise = True , save_dir = \"data\" ): if normalise : transform = transforms . Compose ( [ transforms . ToTensor (), transforms . Normalize ( mean = ( 0.1307 ), std = ( 0.3081 ) ) ] ) else : transform = transforms . Compose ([ transforms . ToTensor ()]) super () . __init__ ( save_dir , download = True , train = train , transform = transform ) def __getitem__ ( self , index ): img , label = super () . __getitem__ ( index ) img = torch . flatten ( img ) label = one_hot ( label ) return img , label def one_hot ( labels , n_classes = 10 ): arr = torch . eye ( n_classes ) return arr [ labels ] def plot_mnist_imgs ( imgs , labels , n_imgs = 10 ): plt . figure ( figsize = ( 20 , 2 )) for i in range ( n_imgs ): plt . subplot ( 1 , n_imgs , i + 1 ) plt . xticks ([]) plt . yticks ([]) plt . grid ( False ) plt . 
imshow ( imgs [ i ] . reshape ( 28 , 28 ), cmap = plt . cm . binary_r ) plt . xlabel ( jnp . argmax ( labels , axis = 1 )[ i ]) plt . show () Network \u00a4 For jpc to work, we need to provide a network with callable layers. This is easy to do with the PyTorch-like nn.Sequential() in Equinox . For example, we can define a ReLU MLP with two hidden layers as follows key = jax . random . PRNGKey ( SEED ) _ , * subkeys = jax . random . split ( key , 4 ) network = [ nn . Sequential ( [ nn . Linear ( 784 , 300 , key = subkeys [ 0 ]), nn . Lambda ( jax . nn . relu ) ], ), nn . Sequential ( [ nn . Linear ( 300 , 300 , key = subkeys [ 1 ]), nn . Lambda ( jax . nn . relu ) ], ), nn . Linear ( 300 , 10 , key = subkeys [ 2 ]), ] print ( network ) [Sequential( layers=( Linear( weight=f32[300,784], bias=f32[300], in_features=784, out_features=300, use_bias=True ), Lambda(fn=) ) ), Sequential( layers=( Linear( weight=f32[300,300], bias=f32[300], in_features=300, out_features=300, use_bias=True ), Lambda(fn=) ) ), Linear( weight=f32[10,300], bias=f32[10], in_features=300, out_features=10, use_bias=True )] You can also use the utility jpc.get_fc_network to define an MLP or fully connected network with some activation functions. network = jpc . make_mlp ( key , LAYER_SIZES , act_fn = \"relu\" ) print ( network ) [Sequential( layers=( Linear( weight=f32[300,10], bias=f32[300], in_features=10, out_features=300, use_bias=True ), Lambda(fn=) ) ), Sequential( layers=( Linear( weight=f32[300,300], bias=f32[300], in_features=300, out_features=300, use_bias=True ), Lambda(fn=) ) ), Sequential( layers=( Linear( weight=f32[784,300], bias=f32[784], in_features=300, out_features=784, use_bias=True ), Lambda(fn=Identity()) ) )] Train and test \u00a4 A hybrid PC network can be trained in a single line of code with jpc.make_hpc_step() . See the documentation for more. Similarly, we can use jpc.test_hpc() to compute different test metrics. Note that these functions are already \"jitted\" for performance. Below we simply wrap each of these functions in our training and test loops, respectively. def evaluate ( key , layer_sizes , batch_size , generator , amortiser , test_loader ): amort_accs , hpc_accs , gen_accs = 0 , 0 , 0 for batch_id , ( img_batch , label_batch ) in enumerate ( test_loader ): img_batch , label_batch = img_batch . numpy (), label_batch . numpy () amort_acc , hpc_acc , gen_acc , img_preds = jpc . test_hpc ( key = key , layer_sizes = layer_sizes , batch_size = batch_size , generator = generator , amortiser = amortiser , input = label_batch , output = img_batch ) amort_accs += amort_acc hpc_accs += hpc_acc gen_accs += gen_acc return ( amort_accs / len ( test_loader ), hpc_accs / len ( test_loader ), gen_accs / len ( test_loader ), label_batch , img_preds ) def train ( seed , layer_sizes , act_fn , batch_size , lr , max_t1 , test_every , n_train_iters ): key = jax . random . PRNGKey ( seed ) key , * subkey = jax . random . split ( key , 3 ) generator = jpc . make_mlp ( subkey [ 0 ], layer_sizes , act_fn ) amortiser = jpc . make_mlp ( subkey [ 1 ], layer_sizes [:: - 1 ], act_fn ) gen_optim = optax . adam ( lr ) amort_optim = optax . adam ( lr ) optims = [ gen_optim , amort_optim ] gen_opt_state = gen_optim . init ( ( eqx . filter ( generator , eqx . is_array ), None ) ) amort_opt_state = amort_optim . init ( eqx . filter ( amortiser , eqx . 
is_array )) opt_states = [ gen_opt_state , amort_opt_state ] train_loader , test_loader = get_mnist_loaders ( batch_size ) for iter , ( img_batch , label_batch ) in enumerate ( train_loader ): img_batch , label_batch = img_batch . numpy (), label_batch . numpy () result = jpc . make_hpc_step ( generator = generator , amortiser = amortiser , optims = optims , opt_states = opt_states , input = label_batch , output = img_batch , max_t1 = max_t1 ) generator , amortiser = result [ \"generator\" ], result [ \"amortiser\" ] optims , opt_states = result [ \"optims\" ], result [ \"opt_states\" ] gen_loss , amort_loss = result [ \"losses\" ] if (( iter + 1 ) % test_every ) == 0 : amort_acc , hpc_acc , gen_acc , label_batch , img_preds = evaluate ( key , layer_sizes , batch_size , generator , amortiser , test_loader ) print ( f \"Iter { iter + 1 } , gen loss= { gen_loss : 4f } , \" f \"amort loss= { amort_loss : 4f } , \" f \"avg amort test accuracy= { amort_acc : 4f } , \" f \"avg hpc test accuracy= { hpc_acc : 4f } , \" f \"avg gen test accuracy= { gen_acc : 4f } , \" ) if ( iter + 1 ) >= n_train_iters : break plot_mnist_imgs ( img_preds , label_batch ) return amortiser , generator Run \u00a4 network = train ( seed = SEED , layer_sizes = LAYER_SIZES , act_fn = ACT_FN , batch_size = BATCH_SIZE , lr = LEARNING_RATE , max_t1 = MAX_T1 , test_every = TEST_EVERY , n_train_iters = N_TRAIN_ITERS ) Iter 100, gen loss=0.558071, amort loss=0.056306, avg amort test accuracy=76.782852, avg hpc test accuracy=79.727562, avg gen test accuracy=78.675880, Iter 200, gen loss=0.622492, amort loss=0.039034, avg amort test accuracy=83.503609, avg hpc test accuracy=81.740784, avg gen test accuracy=81.109779, Iter 300, gen loss=0.548741, amort loss=0.039427, avg amort test accuracy=85.987579, avg hpc test accuracy=82.311699, avg gen test accuracy=81.209938,","title":"Hybrid PC"},{"location":"examples/hybrid_pc/#hybrid-pc-on-mnist","text":"This notebook demonstrates how to train a hybrid predictive coding network that can both generate and classify MNIST digits. %% capture ! pip install torch == 2.3.1 ! pip install torchvision == 0.18.1 ! pip install matplotlib == 3.0.0 import jpc import jax import jax.numpy as jnp import equinox as eqx import equinox.nn as nn import optax import torch from torch.utils.data import DataLoader from torchvision import datasets , transforms import matplotlib.pyplot as plt import warnings warnings . simplefilter ( 'ignore' ) # ignore warnings","title":"Hybrid PC on MNIST"},{"location":"examples/hybrid_pc/#hyperparameters","text":"We define some global parameters, including network architecture, learning rate, batch size etc. SEED = 0 LAYER_SIZES = [ 10 , 300 , 300 , 784 ] ACT_FN = \"relu\" LEARNING_RATE = 1e-3 BATCH_SIZE = 64 MAX_T1 = 50 TEST_EVERY = 100 N_TRAIN_ITERS = 300","title":"Hyperparameters"},{"location":"examples/hybrid_pc/#dataset","text":"Some utils to fetch MNIST. #@title data utils def get_mnist_loaders ( batch_size ): train_data = MNIST ( train = True , normalise = True ) test_data = MNIST ( train = False , normalise = True ) train_loader = DataLoader ( dataset = train_data , batch_size = batch_size , shuffle = True , drop_last = True ) test_loader = DataLoader ( dataset = test_data , batch_size = batch_size , shuffle = True , drop_last = True ) return train_loader , test_loader class MNIST ( datasets . MNIST ): def __init__ ( self , train , normalise = True , save_dir = \"data\" ): if normalise : transform = transforms . Compose ( [ transforms . ToTensor (), transforms . 
Normalize ( mean = ( 0.1307 ), std = ( 0.3081 ) ) ] ) else : transform = transforms . Compose ([ transforms . ToTensor ()]) super () . __init__ ( save_dir , download = True , train = train , transform = transform ) def __getitem__ ( self , index ): img , label = super () . __getitem__ ( index ) img = torch . flatten ( img ) label = one_hot ( label ) return img , label def one_hot ( labels , n_classes = 10 ): arr = torch . eye ( n_classes ) return arr [ labels ] def plot_mnist_imgs ( imgs , labels , n_imgs = 10 ): plt . figure ( figsize = ( 20 , 2 )) for i in range ( n_imgs ): plt . subplot ( 1 , n_imgs , i + 1 ) plt . xticks ([]) plt . yticks ([]) plt . grid ( False ) plt . imshow ( imgs [ i ] . reshape ( 28 , 28 ), cmap = plt . cm . binary_r ) plt . xlabel ( jnp . argmax ( labels , axis = 1 )[ i ]) plt . show ()","title":"Dataset"},{"location":"examples/hybrid_pc/#network","text":"For jpc to work, we need to provide a network with callable layers. This is easy to do with the PyTorch-like nn.Sequential() in Equinox . For example, we can define a ReLU MLP with two hidden layers as follows key = jax . random . PRNGKey ( SEED ) _ , * subkeys = jax . random . split ( key , 4 ) network = [ nn . Sequential ( [ nn . Linear ( 784 , 300 , key = subkeys [ 0 ]), nn . Lambda ( jax . nn . relu ) ], ), nn . Sequential ( [ nn . Linear ( 300 , 300 , key = subkeys [ 1 ]), nn . Lambda ( jax . nn . relu ) ], ), nn . Linear ( 300 , 10 , key = subkeys [ 2 ]), ] print ( network ) [Sequential( layers=( Linear( weight=f32[300,784], bias=f32[300], in_features=784, out_features=300, use_bias=True ), Lambda(fn=) ) ), Sequential( layers=( Linear( weight=f32[300,300], bias=f32[300], in_features=300, out_features=300, use_bias=True ), Lambda(fn=) ) ), Linear( weight=f32[10,300], bias=f32[10], in_features=300, out_features=10, use_bias=True )] You can also use the utility jpc.get_fc_network to define an MLP or fully connected network with some activation functions. network = jpc . make_mlp ( key , LAYER_SIZES , act_fn = \"relu\" ) print ( network ) [Sequential( layers=( Linear( weight=f32[300,10], bias=f32[300], in_features=10, out_features=300, use_bias=True ), Lambda(fn=) ) ), Sequential( layers=( Linear( weight=f32[300,300], bias=f32[300], in_features=300, out_features=300, use_bias=True ), Lambda(fn=) ) ), Sequential( layers=( Linear( weight=f32[784,300], bias=f32[784], in_features=300, out_features=784, use_bias=True ), Lambda(fn=Identity()) ) )]","title":"Network"},{"location":"examples/hybrid_pc/#train-and-test","text":"A hybrid PC network can be trained in a single line of code with jpc.make_hpc_step() . See the documentation for more. Similarly, we can use jpc.test_hpc() to compute different test metrics. Note that these functions are already \"jitted\" for performance. Below we simply wrap each of these functions in our training and test loops, respectively. def evaluate ( key , layer_sizes , batch_size , generator , amortiser , test_loader ): amort_accs , hpc_accs , gen_accs = 0 , 0 , 0 for batch_id , ( img_batch , label_batch ) in enumerate ( test_loader ): img_batch , label_batch = img_batch . numpy (), label_batch . numpy () amort_acc , hpc_acc , gen_acc , img_preds = jpc . 
test_hpc ( key = key , layer_sizes = layer_sizes , batch_size = batch_size , generator = generator , amortiser = amortiser , input = label_batch , output = img_batch ) amort_accs += amort_acc hpc_accs += hpc_acc gen_accs += gen_acc return ( amort_accs / len ( test_loader ), hpc_accs / len ( test_loader ), gen_accs / len ( test_loader ), label_batch , img_preds ) def train ( seed , layer_sizes , act_fn , batch_size , lr , max_t1 , test_every , n_train_iters ): key = jax . random . PRNGKey ( seed ) key , * subkey = jax . random . split ( key , 3 ) generator = jpc . make_mlp ( subkey [ 0 ], layer_sizes , act_fn ) amortiser = jpc . make_mlp ( subkey [ 1 ], layer_sizes [:: - 1 ], act_fn ) gen_optim = optax . adam ( lr ) amort_optim = optax . adam ( lr ) optims = [ gen_optim , amort_optim ] gen_opt_state = gen_optim . init ( ( eqx . filter ( generator , eqx . is_array ), None ) ) amort_opt_state = amort_optim . init ( eqx . filter ( amortiser , eqx . is_array )) opt_states = [ gen_opt_state , amort_opt_state ] train_loader , test_loader = get_mnist_loaders ( batch_size ) for iter , ( img_batch , label_batch ) in enumerate ( train_loader ): img_batch , label_batch = img_batch . numpy (), label_batch . numpy () result = jpc . make_hpc_step ( generator = generator , amortiser = amortiser , optims = optims , opt_states = opt_states , input = label_batch , output = img_batch , max_t1 = max_t1 ) generator , amortiser = result [ \"generator\" ], result [ \"amortiser\" ] optims , opt_states = result [ \"optims\" ], result [ \"opt_states\" ] gen_loss , amort_loss = result [ \"losses\" ] if (( iter + 1 ) % test_every ) == 0 : amort_acc , hpc_acc , gen_acc , label_batch , img_preds = evaluate ( key , layer_sizes , batch_size , generator , amortiser , test_loader ) print ( f \"Iter { iter + 1 } , gen loss= { gen_loss : 4f } , \" f \"amort loss= { amort_loss : 4f } , \" f \"avg amort test accuracy= { amort_acc : 4f } , \" f \"avg hpc test accuracy= { hpc_acc : 4f } , \" f \"avg gen test accuracy= { gen_acc : 4f } , \" ) if ( iter + 1 ) >= n_train_iters : break plot_mnist_imgs ( img_preds , label_batch ) return amortiser , generator","title":"Train and test"},{"location":"examples/hybrid_pc/#run","text":"network = train ( seed = SEED , layer_sizes = LAYER_SIZES , act_fn = ACT_FN , batch_size = BATCH_SIZE , lr = LEARNING_RATE , max_t1 = MAX_T1 , test_every = TEST_EVERY , n_train_iters = N_TRAIN_ITERS ) Iter 100, gen loss=0.558071, amort loss=0.056306, avg amort test accuracy=76.782852, avg hpc test accuracy=79.727562, avg gen test accuracy=78.675880, Iter 200, gen loss=0.622492, amort loss=0.039034, avg amort test accuracy=83.503609, avg hpc test accuracy=81.740784, avg gen test accuracy=81.109779, Iter 300, gen loss=0.548741, amort loss=0.039427, avg amort test accuracy=85.987579, avg hpc test accuracy=82.311699, avg gen test accuracy=81.209938,","title":"Run"},{"location":"examples/linear_net_theoretical_energy/","text":"(function() { function addWidgetsRenderer() { var requireJsScript = document.createElement('script'); requireJsScript.src = 'https://cdnjs.cloudflare.com/ajax/libs/require.js/2.3.4/require.min.js'; var mimeElement = document.querySelector('script[type=\"application/vnd.jupyter.widget-view+json\"]'); var jupyterWidgetsScript = document.createElement('script'); var widgetRendererSrc = 'https://unpkg.com/@jupyter-widgets/html-manager@*/dist/embed-amd.js'; var widgetState; // Fallback for older version: try { widgetState = mimeElement && JSON.parse(mimeElement.innerHTML); if (widgetState && 
(widgetState.version_major < 2 || !widgetState.version_major)) { widgetRendererSrc = 'jupyter-js-widgets@*/dist/embed.js'; } } catch(e) {} jupyterWidgetsScript.src = widgetRendererSrc; document.body.appendChild(requireJsScript); document.body.appendChild(jupyterWidgetsScript); } document.addEventListener('DOMContentLoaded', addWidgetsRenderer); }()); Theoretical energy of deep linear networks \u00a4 %% capture ! pip install torch == 2.3.1 ! pip install torchvision == 0.18.1 ! pip install plotly == 5.11.0 ! pip install - U kaleido import jpc import jax from jax import vmap import jax.numpy as jnp import equinox as eqx import equinox.nn as nn import optax import torch from torch.utils.data import DataLoader from torchvision import datasets , transforms import plotly.graph_objs as go import plotly.io as pio pio . renderers . default = 'iframe' import warnings warnings . simplefilter ( 'ignore' ) # ignore warnings Hyperparameters \u00a4 We define some global parameters, including network architecture, learning rate, batch size etc. SEED = 0 LEARNING_RATE = 1e-3 BATCH_SIZE = 64 MAX_T1 = 300 TEST_EVERY = 10 N_TRAIN_ITERS = 100 Dataset \u00a4 Some utils to fetch MNIST. def get_mnist_loaders ( batch_size ): train_data = MNIST ( train = True , normalise = True ) test_data = MNIST ( train = False , normalise = True ) train_loader = DataLoader ( dataset = train_data , batch_size = batch_size , shuffle = True , drop_last = True ) test_loader = DataLoader ( dataset = test_data , batch_size = batch_size , shuffle = True , drop_last = True ) return train_loader , test_loader class MNIST ( datasets . MNIST ): def __init__ ( self , train , normalise = True , save_dir = \"data\" ): if normalise : transform = transforms . Compose ( [ transforms . ToTensor (), transforms . Normalize ( mean = ( 0.1307 ), std = ( 0.3081 ) ) ] ) else : transform = transforms . Compose ([ transforms . ToTensor ()]) super () . __init__ ( save_dir , download = True , train = train , transform = transform ) def __getitem__ ( self , index ): img , label = super () . __getitem__ ( index ) img = torch . flatten ( img ) label = one_hot ( label ) return img , label def one_hot ( labels , n_classes = 10 ): arr = torch . eye ( n_classes ) return arr [ labels ] Plotting \u00a4 def plot_total_energies ( energies ): n_train_iters = len ( energies [ \"theory\" ]) train_iters = [ b + 1 for b in range ( n_train_iters )] fig = go . Figure () for energy_type , energy in energies . items (): is_theory = energy_type == \"theory\" fig . add_traces ( go . Scatter ( x = train_iters , y = energy , name = energy_type , mode = \"lines\" , line = dict ( width = 3 , dash = \"dash\" if is_theory else \"solid\" , color = \"rgb(27, 158, 119)\" if is_theory else \"#00CC96\" ), legendrank = 1 if is_theory else 2 ) ) fig . update_layout ( height = 300 , width = 450 , xaxis = dict ( title = \"Training iteration\" , tickvals = [ 1 , int ( train_iters [ - 1 ] / 2 ), train_iters [ - 1 ]], ticktext = [ 1 , int ( train_iters [ - 1 ] / 2 ), train_iters [ - 1 ]], ), yaxis = dict ( title = \"Energy\" , nticks = 3 ), font = dict ( size = 16 ), ) fig . write_image ( \"dln_total_energy.pdf\" ) return fig Linear network \u00a4 key = jax . random . PRNGKey ( 0 ) width , n_hidden = 300 , 10 network = jpc . make_mlp ( key , [ 784 ] + [ width ] * n_hidden + [ 10 ], act_fn = \"linear\" , use_bias = False ) Train and test \u00a4 A PC network can be trained in a single line of code with jpc.make_pc_step() . See the documentation for more. 
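For the deep linear network defined above, a single such step might look like the following minimal sketch. The random batch is only a stand-in for one MNIST batch from the loaders defined earlier (an assumption made to keep the snippet self-contained); the training loop below does the same thing while also logging test accuracy, and the closed-form value from jpc.linear_equilib_energy can be compared against the numerically equilibrated energy recorded by jpc.make_pc_step.

import jax.random as jr
import jax.numpy as jnp
import equinox as eqx
import optax
import jpc

key = jr.PRNGKey(0)
x_key, y_key, model_key = jr.split(key, 3)

# random batch standing in for one batch of flattened MNIST images and one-hot labels
x = jr.normal(x_key, (64, 784))
y = jnp.eye(10)[jr.randint(y_key, (64,), 0, 10)]

# deep linear network as defined above (no biases)
model = jpc.make_mlp(model_key, [784] + [300] * 10 + [10], act_fn="linear", use_bias=False)

optim = optax.adam(1e-3)
opt_state = optim.init((eqx.filter(model, eqx.is_array), None))

# theoretical equilibrated energy for the current weights
theory_energy = jpc.linear_equilib_energy(network=model, x=x, y=y)

# one PC parameter update, recording the energies during inference
result = jpc.make_pc_step(
    model,
    optim,
    opt_state,
    output=y,
    input=x,
    max_t1=300,
    record_energies=True
)
model, optim, opt_state = result["model"], result["optim"], result["opt_state"]

# total numerical energy at the last inference step, comparable to theory_energy
numerical_energy = result["energies"][:, result["t_max"] - 1].sum()
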
Similarly, we can use jpc.test_discriminative_pc() to compute the network accuracy. Note that these functions are already \"jitted\" for performance. Below we simply wrap each of these functions in our training and test loops, respectively. def evaluate ( model , test_loader ): avg_test_loss , avg_test_acc = 0 , 0 for batch_id , ( img_batch , label_batch ) in enumerate ( test_loader ): img_batch , label_batch = img_batch . numpy (), label_batch . numpy () test_loss , test_acc = jpc . test_discriminative_pc ( model = model , output = label_batch , input = img_batch ) avg_test_loss += test_loss avg_test_acc += test_acc return avg_test_loss / len ( test_loader ), avg_test_acc / len ( test_loader ) def train ( model , lr , batch_size , max_t1 , test_every , n_train_iters ): optim = optax . adam ( lr ) opt_state = optim . init ( ( eqx . filter ( model , eqx . is_array ), None ) ) train_loader , test_loader = get_mnist_loaders ( batch_size ) num_total_energies , theory_total_energies = [], [] for iter , ( img_batch , label_batch ) in enumerate ( train_loader ): img_batch , label_batch = img_batch . numpy (), label_batch . numpy () theory_total_energies . append ( jpc . linear_equilib_energy ( network = model , x = img_batch , y = label_batch ) ) result = jpc . make_pc_step ( model , optim , opt_state , output = label_batch , input = img_batch , max_t1 = max_t1 , record_energies = True ) model , optim , opt_state = result [ \"model\" ], result [ \"optim\" ], result [ \"opt_state\" ] train_loss , t_max = result [ \"loss\" ], result [ \"t_max\" ] num_total_energies . append ( result [ \"energies\" ][:, t_max - 1 ] . sum ()) if (( iter + 1 ) % test_every ) == 0 : avg_test_loss , avg_test_acc = evaluate ( model , test_loader ) print ( f \"Train iter { iter + 1 } , train loss= { train_loss : 4f } , \" f \"avg test accuracy= { avg_test_acc : 4f } \" ) if ( iter + 1 ) >= n_train_iters : break return { \"experiment\" : jnp . array ( num_total_energies ), \"theory\" : jnp . array ( theory_total_energies ) } Run \u00a4 energies = train ( model = network , lr = LEARNING_RATE , batch_size = BATCH_SIZE , test_every = TEST_EVERY , max_t1 = MAX_T1 , n_train_iters = N_TRAIN_ITERS ) plot_total_energies ( energies ) Train iter 10, train loss=0.067622, avg test accuracy=59.535255 Train iter 20, train loss=0.054068, avg test accuracy=75.230370 Train iter 30, train loss=0.063356, avg test accuracy=77.453926 Train iter 40, train loss=0.051848, avg test accuracy=80.048080 Train iter 50, train loss=0.061488, avg test accuracy=82.061295 Train iter 60, train loss=0.044830, avg test accuracy=80.789261 Train iter 70, train loss=0.045716, avg test accuracy=84.174683 Train iter 80, train loss=0.053921, avg test accuracy=82.041267 Train iter 90, train loss=0.040125, avg test accuracy=83.072914 Train iter 100, train loss=0.050980, avg test accuracy=83.974358","title":"Linear theoretical energy"},{"location":"examples/linear_net_theoretical_energy/#theoretical-energy-of-deep-linear-networks","text":"%% capture ! pip install torch == 2.3.1 ! pip install torchvision == 0.18.1 ! pip install plotly == 5.11.0 ! pip install - U kaleido import jpc import jax from jax import vmap import jax.numpy as jnp import equinox as eqx import equinox.nn as nn import optax import torch from torch.utils.data import DataLoader from torchvision import datasets , transforms import plotly.graph_objs as go import plotly.io as pio pio . renderers . default = 'iframe' import warnings warnings . 
simplefilter ( 'ignore' ) # ignore warnings","title":"Theoretical energy of deep linear networks"},{"location":"examples/linear_net_theoretical_energy/#hyperparameters","text":"We define some global parameters, including network architecture, learning rate, batch size etc. SEED = 0 LEARNING_RATE = 1e-3 BATCH_SIZE = 64 MAX_T1 = 300 TEST_EVERY = 10 N_TRAIN_ITERS = 100","title":"Hyperparameters"},{"location":"examples/linear_net_theoretical_energy/#dataset","text":"Some utils to fetch MNIST. def get_mnist_loaders ( batch_size ): train_data = MNIST ( train = True , normalise = True ) test_data = MNIST ( train = False , normalise = True ) train_loader = DataLoader ( dataset = train_data , batch_size = batch_size , shuffle = True , drop_last = True ) test_loader = DataLoader ( dataset = test_data , batch_size = batch_size , shuffle = True , drop_last = True ) return train_loader , test_loader class MNIST ( datasets . MNIST ): def __init__ ( self , train , normalise = True , save_dir = \"data\" ): if normalise : transform = transforms . Compose ( [ transforms . ToTensor (), transforms . Normalize ( mean = ( 0.1307 ), std = ( 0.3081 ) ) ] ) else : transform = transforms . Compose ([ transforms . ToTensor ()]) super () . __init__ ( save_dir , download = True , train = train , transform = transform ) def __getitem__ ( self , index ): img , label = super () . __getitem__ ( index ) img = torch . flatten ( img ) label = one_hot ( label ) return img , label def one_hot ( labels , n_classes = 10 ): arr = torch . eye ( n_classes ) return arr [ labels ]","title":"Dataset"},{"location":"examples/linear_net_theoretical_energy/#plotting","text":"def plot_total_energies ( energies ): n_train_iters = len ( energies [ \"theory\" ]) train_iters = [ b + 1 for b in range ( n_train_iters )] fig = go . Figure () for energy_type , energy in energies . items (): is_theory = energy_type == \"theory\" fig . add_traces ( go . Scatter ( x = train_iters , y = energy , name = energy_type , mode = \"lines\" , line = dict ( width = 3 , dash = \"dash\" if is_theory else \"solid\" , color = \"rgb(27, 158, 119)\" if is_theory else \"#00CC96\" ), legendrank = 1 if is_theory else 2 ) ) fig . update_layout ( height = 300 , width = 450 , xaxis = dict ( title = \"Training iteration\" , tickvals = [ 1 , int ( train_iters [ - 1 ] / 2 ), train_iters [ - 1 ]], ticktext = [ 1 , int ( train_iters [ - 1 ] / 2 ), train_iters [ - 1 ]], ), yaxis = dict ( title = \"Energy\" , nticks = 3 ), font = dict ( size = 16 ), ) fig . write_image ( \"dln_total_energy.pdf\" ) return fig","title":"Plotting"},{"location":"examples/linear_net_theoretical_energy/#linear-network","text":"key = jax . random . PRNGKey ( 0 ) width , n_hidden = 300 , 10 network = jpc . make_mlp ( key , [ 784 ] + [ width ] * n_hidden + [ 10 ], act_fn = \"linear\" , use_bias = False )","title":"Linear network"},{"location":"examples/linear_net_theoretical_energy/#train-and-test","text":"A PC network can be trained in a single line of code with jpc.make_pc_step() . See the documentation for more. Similarly, we can use jpc.test_discriminative_pc() to compute the network accuracy. Note that these functions are already \"jitted\" for performance. Below we simply wrap each of these functions in our training and test loops, respectively. def evaluate ( model , test_loader ): avg_test_loss , avg_test_acc = 0 , 0 for batch_id , ( img_batch , label_batch ) in enumerate ( test_loader ): img_batch , label_batch = img_batch . numpy (), label_batch . numpy () test_loss , test_acc = jpc . 
test_discriminative_pc ( model = model , output = label_batch , input = img_batch ) avg_test_loss += test_loss avg_test_acc += test_acc return avg_test_loss / len ( test_loader ), avg_test_acc / len ( test_loader ) def train ( model , lr , batch_size , max_t1 , test_every , n_train_iters ): optim = optax . adam ( lr ) opt_state = optim . init ( ( eqx . filter ( model , eqx . is_array ), None ) ) train_loader , test_loader = get_mnist_loaders ( batch_size ) num_total_energies , theory_total_energies = [], [] for iter , ( img_batch , label_batch ) in enumerate ( train_loader ): img_batch , label_batch = img_batch . numpy (), label_batch . numpy () theory_total_energies . append ( jpc . linear_equilib_energy ( network = model , x = img_batch , y = label_batch ) ) result = jpc . make_pc_step ( model , optim , opt_state , output = label_batch , input = img_batch , max_t1 = max_t1 , record_energies = True ) model , optim , opt_state = result [ \"model\" ], result [ \"optim\" ], result [ \"opt_state\" ] train_loss , t_max = result [ \"loss\" ], result [ \"t_max\" ] num_total_energies . append ( result [ \"energies\" ][:, t_max - 1 ] . sum ()) if (( iter + 1 ) % test_every ) == 0 : avg_test_loss , avg_test_acc = evaluate ( model , test_loader ) print ( f \"Train iter { iter + 1 } , train loss= { train_loss : 4f } , \" f \"avg test accuracy= { avg_test_acc : 4f } \" ) if ( iter + 1 ) >= n_train_iters : break return { \"experiment\" : jnp . array ( num_total_energies ), \"theory\" : jnp . array ( theory_total_energies ) }","title":"Train and test"},{"location":"examples/linear_net_theoretical_energy/#run","text":"energies = train ( model = network , lr = LEARNING_RATE , batch_size = BATCH_SIZE , test_every = TEST_EVERY , max_t1 = MAX_T1 , n_train_iters = N_TRAIN_ITERS ) plot_total_energies ( energies ) Train iter 10, train loss=0.067622, avg test accuracy=59.535255 Train iter 20, train loss=0.054068, avg test accuracy=75.230370 Train iter 30, train loss=0.063356, avg test accuracy=77.453926 Train iter 40, train loss=0.051848, avg test accuracy=80.048080 Train iter 50, train loss=0.061488, avg test accuracy=82.061295 Train iter 60, train loss=0.044830, avg test accuracy=80.789261 Train iter 70, train loss=0.045716, avg test accuracy=84.174683 Train iter 80, train loss=0.053921, avg test accuracy=82.041267 Train iter 90, train loss=0.040125, avg test accuracy=83.072914 Train iter 100, train loss=0.050980, avg test accuracy=83.974358","title":"Run"},{"location":"examples/supervised_generative_pc/","text":"(function() { function addWidgetsRenderer() { var requireJsScript = document.createElement('script'); requireJsScript.src = 'https://cdnjs.cloudflare.com/ajax/libs/require.js/2.3.4/require.min.js'; var mimeElement = document.querySelector('script[type=\"application/vnd.jupyter.widget-view+json\"]'); var jupyterWidgetsScript = document.createElement('script'); var widgetRendererSrc = 'https://unpkg.com/@jupyter-widgets/html-manager@*/dist/embed-amd.js'; var widgetState; // Fallback for older version: try { widgetState = mimeElement && JSON.parse(mimeElement.innerHTML); if (widgetState && (widgetState.version_major < 2 || !widgetState.version_major)) { widgetRendererSrc = 'jupyter-js-widgets@*/dist/embed.js'; } } catch(e) {} jupyterWidgetsScript.src = widgetRendererSrc; document.body.appendChild(requireJsScript); document.body.appendChild(jupyterWidgetsScript); } document.addEventListener('DOMContentLoaded', addWidgetsRenderer); }()); Generative PC on MNIST \u00a4 This notebook demonstrates how to train 
a neural network with predictive coding to generate MNIST digits. %% capture ! pip install torch == 2.3.1 ! pip install torchvision == 0.18.1 ! pip install matplotlib == 3.0.0 import jpc import jax import jax.numpy as jnp import equinox as eqx import equinox.nn as nn from diffrax import Heun import optax import torch from torch.utils.data import DataLoader from torchvision import datasets , transforms import matplotlib.pyplot as plt import warnings warnings . simplefilter ( 'ignore' ) # ignore warnings Hyperparameters \u00a4 We define some global parameters, including network architecture, learning rate, batch size etc. SEED = 0 LAYER_SIZES = [ 10 , 300 , 300 , 784 ] ACT_FN = \"relu\" LEARNING_RATE = 1e-3 BATCH_SIZE = 64 MAX_T1 = 100 TEST_EVERY = 50 N_TRAIN_ITERS = 200 Dataset \u00a4 Some utils to fetch and plot MNIST. #@title data utils def get_mnist_loaders ( batch_size ): train_data = MNIST ( train = True , normalise = True ) test_data = MNIST ( train = False , normalise = True ) train_loader = DataLoader ( dataset = train_data , batch_size = batch_size , shuffle = True , drop_last = True ) test_loader = DataLoader ( dataset = test_data , batch_size = batch_size , shuffle = True , drop_last = True ) return train_loader , test_loader class MNIST ( datasets . MNIST ): def __init__ ( self , train , normalise = True , save_dir = \"data\" ): if normalise : transform = transforms . Compose ( [ transforms . ToTensor (), transforms . Normalize ( mean = ( 0.1307 ), std = ( 0.3081 ) ) ] ) else : transform = transforms . Compose ([ transforms . ToTensor ()]) super () . __init__ ( save_dir , download = True , train = train , transform = transform ) def __getitem__ ( self , index ): img , label = super () . __getitem__ ( index ) img = torch . flatten ( img ) label = one_hot ( label ) return img , label def one_hot ( labels , n_classes = 10 ): arr = torch . eye ( n_classes ) return arr [ labels ] def plot_mnist_img_preds ( imgs , labels , n_imgs = 10 ): plt . figure ( figsize = ( 20 , 2 )) for i in range ( n_imgs ): plt . subplot ( 1 , n_imgs , i + 1 ) plt . xticks ([]) plt . yticks ([]) plt . grid ( False ) plt . imshow ( imgs [ i ] . reshape ( 28 , 28 ), cmap = plt . cm . binary_r ) plt . xlabel ( jnp . argmax ( labels , axis = 1 )[ i ], fontsize = 16 ) plt . show () Network \u00a4 For jpc to work, we need to provide a network with callable layers. This is easy to do with the PyTorch-like nn.Sequential() in Equinox . For example, we can define a ReLU MLP with two hidden layers as follows key = jax . random . PRNGKey ( SEED ) key , * subkeys = jax . random . split ( key , 4 ) network = [ nn . Sequential ( [ nn . Linear ( 10 , 300 , key = subkeys [ 0 ]), nn . Lambda ( jax . nn . relu ) ], ), nn . Sequential ( [ nn . Linear ( 300 , 300 , key = subkeys [ 1 ]), nn . Lambda ( jax . nn . relu ) ], ), nn . Linear ( 300 , 784 , key = subkeys [ 2 ]), ] print ( network ) [Sequential( layers=( Linear( weight=f32[300,10], bias=f32[300], in_features=10, out_features=300, use_bias=True ), Lambda(fn=) ) ), Sequential( layers=( Linear( weight=f32[300,300], bias=f32[300], in_features=300, out_features=300, use_bias=True ), Lambda(fn=) ) ), Linear( weight=f32[784,300], bias=f32[784], in_features=300, out_features=784, use_bias=True )] You can also use the utility jpc.get_fc_network to define an MLP or fully connected network with some activation functions. network = jpc . 
make_mlp ( key , LAYER_SIZES , act_fn = \"relu\" ) print ( network ) [Sequential( layers=( Linear( weight=f32[300,10], bias=f32[300], in_features=10, out_features=300, use_bias=True ), Lambda(fn=) ) ), Sequential( layers=( Linear( weight=f32[300,300], bias=f32[300], in_features=300, out_features=300, use_bias=True ), Lambda(fn=) ) ), Sequential( layers=( Linear( weight=f32[784,300], bias=f32[784], in_features=300, out_features=784, use_bias=True ), Lambda(fn=Identity()) ) )] Train and test \u00a4 A PC network can be trained in a single line of code with jpc.make_pc_step() . See the documentation for more. Similarly, we can use jpc.test_generative_pc() to get some test metrics including accuracy of inferred labels and image predictions. Note that these functions are already \"jitted\" for performance. Below we simply wrap each of these functions in our training and test loops, respectively. Note that to train in an unsupervised way, you can simply remove the input from jpc.make_pc_step() and the evaluate() script. def evaluate ( key , layer_sizes , batch_size , network , test_loader , max_t1 ): test_acc = 0 for batch_id , ( img_batch , label_batch ) in enumerate ( test_loader ): img_batch , label_batch = img_batch . numpy (), label_batch . numpy () acc , img_preds = jpc . test_generative_pc ( model = network , input = label_batch , output = img_batch , key = key , layer_sizes = layer_sizes , batch_size = batch_size , max_t1 = max_t1 ) test_acc += acc avg_test_acc = test_acc / len ( test_loader ) return avg_test_acc , label_batch , img_preds def train ( key , layer_sizes , batch_size , network , lr , max_t1 , test_every , n_train_iters ): optim = optax . adam ( lr ) opt_state = optim . init ( ( eqx . filter ( network , eqx . is_array ), None ) ) train_loader , test_loader = get_mnist_loaders ( batch_size ) for iter , ( img_batch , label_batch ) in enumerate ( train_loader ): img_batch , label_batch = img_batch . numpy (), label_batch . numpy () result = jpc . make_pc_step ( model = network , optim = optim , opt_state = opt_state , input = label_batch , output = img_batch , max_t1 = max_t1 ) network , optim , opt_state = result [ \"model\" ], result [ \"optim\" ], result [ \"opt_state\" ] train_loss = result [ \"loss\" ] if (( iter + 1 ) % test_every ) == 0 : avg_test_acc , test_label_batch , img_preds = evaluate ( key , layer_sizes , batch_size , network , test_loader , max_t1 = max_t1 ) print ( f \"Train iter { iter + 1 } , train loss= { train_loss : 4f } , \" f \"avg test accuracy= { avg_test_acc : 4f } \" ) if ( iter + 1 ) >= n_train_iters : break plot_mnist_img_preds ( img_preds , test_label_batch ) return network Run \u00a4 network = train ( key = key , layer_sizes = LAYER_SIZES , batch_size = BATCH_SIZE , network = network , lr = LEARNING_RATE , max_t1 = MAX_T1 , test_every = TEST_EVERY , n_train_iters = N_TRAIN_ITERS ) Train iter 50, train loss=0.631369, avg test accuracy=74.959938 Train iter 100, train loss=0.607500, avg test accuracy=79.206734 Train iter 150, train loss=0.577637, avg test accuracy=80.418671 Train iter 200, train loss=0.555235, avg test accuracy=79.236778","title":"Supervised generative PC"},{"location":"examples/supervised_generative_pc/#generative-pc-on-mnist","text":"This notebook demonstrates how to train a neural network with predictive coding to generate MNIST digits. %% capture ! pip install torch == 2.3.1 ! pip install torchvision == 0.18.1 ! 
pip install matplotlib == 3.0.0 import jpc import jax import jax.numpy as jnp import equinox as eqx import equinox.nn as nn from diffrax import Heun import optax import torch from torch.utils.data import DataLoader from torchvision import datasets , transforms import matplotlib.pyplot as plt import warnings warnings . simplefilter ( 'ignore' ) # ignore warnings","title":"Generative PC on MNIST"},{"location":"examples/supervised_generative_pc/#hyperparameters","text":"We define some global parameters, including network architecture, learning rate, batch size etc. SEED = 0 LAYER_SIZES = [ 10 , 300 , 300 , 784 ] ACT_FN = \"relu\" LEARNING_RATE = 1e-3 BATCH_SIZE = 64 MAX_T1 = 100 TEST_EVERY = 50 N_TRAIN_ITERS = 200","title":"Hyperparameters"},{"location":"examples/supervised_generative_pc/#dataset","text":"Some utils to fetch and plot MNIST. #@title data utils def get_mnist_loaders ( batch_size ): train_data = MNIST ( train = True , normalise = True ) test_data = MNIST ( train = False , normalise = True ) train_loader = DataLoader ( dataset = train_data , batch_size = batch_size , shuffle = True , drop_last = True ) test_loader = DataLoader ( dataset = test_data , batch_size = batch_size , shuffle = True , drop_last = True ) return train_loader , test_loader class MNIST ( datasets . MNIST ): def __init__ ( self , train , normalise = True , save_dir = \"data\" ): if normalise : transform = transforms . Compose ( [ transforms . ToTensor (), transforms . Normalize ( mean = ( 0.1307 ), std = ( 0.3081 ) ) ] ) else : transform = transforms . Compose ([ transforms . ToTensor ()]) super () . __init__ ( save_dir , download = True , train = train , transform = transform ) def __getitem__ ( self , index ): img , label = super () . __getitem__ ( index ) img = torch . flatten ( img ) label = one_hot ( label ) return img , label def one_hot ( labels , n_classes = 10 ): arr = torch . eye ( n_classes ) return arr [ labels ] def plot_mnist_img_preds ( imgs , labels , n_imgs = 10 ): plt . figure ( figsize = ( 20 , 2 )) for i in range ( n_imgs ): plt . subplot ( 1 , n_imgs , i + 1 ) plt . xticks ([]) plt . yticks ([]) plt . grid ( False ) plt . imshow ( imgs [ i ] . reshape ( 28 , 28 ), cmap = plt . cm . binary_r ) plt . xlabel ( jnp . argmax ( labels , axis = 1 )[ i ], fontsize = 16 ) plt . show ()","title":"Dataset"},{"location":"examples/supervised_generative_pc/#network","text":"For jpc to work, we need to provide a network with callable layers. This is easy to do with the PyTorch-like nn.Sequential() in Equinox . For example, we can define a ReLU MLP with two hidden layers as follows key = jax . random . PRNGKey ( SEED ) key , * subkeys = jax . random . split ( key , 4 ) network = [ nn . Sequential ( [ nn . Linear ( 10 , 300 , key = subkeys [ 0 ]), nn . Lambda ( jax . nn . relu ) ], ), nn . Sequential ( [ nn . Linear ( 300 , 300 , key = subkeys [ 1 ]), nn . Lambda ( jax . nn . relu ) ], ), nn . Linear ( 300 , 784 , key = subkeys [ 2 ]), ] print ( network ) [Sequential( layers=( Linear( weight=f32[300,10], bias=f32[300], in_features=10, out_features=300, use_bias=True ), Lambda(fn=) ) ), Sequential( layers=( Linear( weight=f32[300,300], bias=f32[300], in_features=300, out_features=300, use_bias=True ), Lambda(fn=) ) ), Linear( weight=f32[784,300], bias=f32[784], in_features=300, out_features=784, use_bias=True )] You can also use the utility jpc.get_fc_network to define an MLP or fully connected network with some activation functions. network = jpc . 
make_mlp ( key , LAYER_SIZES , act_fn = \"relu\" ) print ( network ) [Sequential( layers=( Linear( weight=f32[300,10], bias=f32[300], in_features=10, out_features=300, use_bias=True ), Lambda(fn=) ) ), Sequential( layers=( Linear( weight=f32[300,300], bias=f32[300], in_features=300, out_features=300, use_bias=True ), Lambda(fn=) ) ), Sequential( layers=( Linear( weight=f32[784,300], bias=f32[784], in_features=300, out_features=784, use_bias=True ), Lambda(fn=Identity()) ) )]","title":"Network"},{"location":"examples/supervised_generative_pc/#train-and-test","text":"A PC network can be trained in a single line of code with jpc.make_pc_step() . See the documentation for more. Similarly, we can use jpc.test_generative_pc() to get some test metrics including accuracy of inferred labels and image predictions. Note that these functions are already \"jitted\" for performance. Below we simply wrap each of these functions in our training and test loops, respectively. Note that to train in an unsupervised way, you can simply remove the input from jpc.make_pc_step() and the evaluate() script. def evaluate ( key , layer_sizes , batch_size , network , test_loader , max_t1 ): test_acc = 0 for batch_id , ( img_batch , label_batch ) in enumerate ( test_loader ): img_batch , label_batch = img_batch . numpy (), label_batch . numpy () acc , img_preds = jpc . test_generative_pc ( model = network , input = label_batch , output = img_batch , key = key , layer_sizes = layer_sizes , batch_size = batch_size , max_t1 = max_t1 ) test_acc += acc avg_test_acc = test_acc / len ( test_loader ) return avg_test_acc , label_batch , img_preds def train ( key , layer_sizes , batch_size , network , lr , max_t1 , test_every , n_train_iters ): optim = optax . adam ( lr ) opt_state = optim . init ( ( eqx . filter ( network , eqx . is_array ), None ) ) train_loader , test_loader = get_mnist_loaders ( batch_size ) for iter , ( img_batch , label_batch ) in enumerate ( train_loader ): img_batch , label_batch = img_batch . numpy (), label_batch . numpy () result = jpc . 
make_pc_step ( model = network , optim = optim , opt_state = opt_state , input = label_batch , output = img_batch , max_t1 = max_t1 ) network , optim , opt_state = result [ \"model\" ], result [ \"optim\" ], result [ \"opt_state\" ] train_loss = result [ \"loss\" ] if (( iter + 1 ) % test_every ) == 0 : avg_test_acc , test_label_batch , img_preds = evaluate ( key , layer_sizes , batch_size , network , test_loader , max_t1 = max_t1 ) print ( f \"Train iter { iter + 1 } , train loss= { train_loss : 4f } , \" f \"avg test accuracy= { avg_test_acc : 4f } \" ) if ( iter + 1 ) >= n_train_iters : break plot_mnist_img_preds ( img_preds , test_label_batch ) return network","title":"Train and test"},{"location":"examples/supervised_generative_pc/#run","text":"network = train ( key = key , layer_sizes = LAYER_SIZES , batch_size = BATCH_SIZE , network = network , lr = LEARNING_RATE , max_t1 = MAX_T1 , test_every = TEST_EVERY , n_train_iters = N_TRAIN_ITERS ) Train iter 50, train loss=0.631369, avg test accuracy=74.959938 Train iter 100, train loss=0.607500, avg test accuracy=79.206734 Train iter 150, train loss=0.577637, avg test accuracy=80.418671 Train iter 200, train loss=0.555235, avg test accuracy=79.236778","title":"Run"},{"location":"examples/unsupervised_generative_pc/","text":"(function() { function addWidgetsRenderer() { var requireJsScript = document.createElement('script'); requireJsScript.src = 'https://cdnjs.cloudflare.com/ajax/libs/require.js/2.3.4/require.min.js'; var mimeElement = document.querySelector('script[type=\"application/vnd.jupyter.widget-view+json\"]'); var jupyterWidgetsScript = document.createElement('script'); var widgetRendererSrc = 'https://unpkg.com/@jupyter-widgets/html-manager@*/dist/embed-amd.js'; var widgetState; // Fallback for older version: try { widgetState = mimeElement && JSON.parse(mimeElement.innerHTML); if (widgetState && (widgetState.version_major < 2 || !widgetState.version_major)) { widgetRendererSrc = 'jupyter-js-widgets@*/dist/embed.js'; } } catch(e) {} jupyterWidgetsScript.src = widgetRendererSrc; document.body.appendChild(requireJsScript); document.body.appendChild(jupyterWidgetsScript); } document.addEventListener('DOMContentLoaded', addWidgetsRenderer); }()); Unsupervised generative PC on MNIST \u00a4 This notebook demonstrates how to train a neural network with predictive coding to encode MNIST digits in an unsupervised manner. %% capture ! pip install torch == 2.3.1 ! pip install torchvision == 0.18.1 ! pip install matplotlib == 3.0.0 import jpc import jax import jax.numpy as jnp import equinox as eqx import equinox.nn as nn import optax import torch from torch.utils.data import DataLoader from torchvision import datasets , transforms import matplotlib.pyplot as plt import matplotlib.colors as mcolors import warnings warnings . simplefilter ( 'ignore' ) # ignore warnings Hyperparameters \u00a4 We define some global parameters, including network architecture, learning rate, batch size etc. SEED = 0 LAYER_SIZES = [ 50 , 100 , 100 , 784 ] ACT_FN = \"relu\" LEARNING_RATE = 1e-3 BATCH_SIZE = 64 MAX_T1 = 100 N_TRAIN_ITERS = 300 Dataset \u00a4 Some utils to fetch and plot MNIST. 
#@title data utils def get_mnist_loaders ( batch_size ): train_data = MNIST ( train = True , normalise = True ) test_data = MNIST ( train = False , normalise = True ) train_loader = DataLoader ( dataset = train_data , batch_size = batch_size , shuffle = True , drop_last = True ) test_loader = DataLoader ( dataset = test_data , batch_size = batch_size , shuffle = True , drop_last = True ) return train_loader , test_loader class MNIST ( datasets . MNIST ): def __init__ ( self , train , normalise = True , save_dir = \"data\" ): if normalise : transform = transforms . Compose ( [ transforms . ToTensor (), transforms . Normalize ( mean = ( 0.1307 ), std = ( 0.3081 ) ) ] ) else : transform = transforms . Compose ([ transforms . ToTensor ()]) super () . __init__ ( save_dir , download = True , train = train , transform = transform ) def __getitem__ ( self , index ): img , _ = super () . __getitem__ ( index ) img = torch . flatten ( img ) return img Plotting \u00a4 def plot_train_energies ( train_energies , ts ): t_max = int ( ts [ 0 ]) norm = mcolors . Normalize ( vmin = 0 , vmax = len ( energies ) - 1 ) fig , ax = plt . subplots ( figsize = ( 8 , 4 )) cmap_blues = plt . get_cmap ( \"Blues\" ) cmap_reds = plt . get_cmap ( \"Reds\" ) cmap_greens = plt . get_cmap ( \"Greens\" ) legend_handles = [] legend_labels = [] for t , energies_iter in enumerate ( energies ): line1 , = ax . plot ( energies_iter [ 0 , : t_max ], color = cmap_blues ( norm ( t ))) line2 , = ax . plot ( energies_iter [ 1 , : t_max ], color = cmap_reds ( norm ( t ))) line3 , = ax . plot ( energies_iter [ 2 , : t_max ], color = cmap_greens ( norm ( t ))) if t == 70 : legend_handles . append ( line1 ) legend_labels . append ( \"$\\ell_1$\" ) legend_handles . append ( line2 ) legend_labels . append ( \"$\\ell_2$\" ) legend_handles . append ( line3 ) legend_labels . append ( \"$\\ell_3$\" ) ax . legend ( legend_handles , legend_labels , loc = \"best\" , fontsize = 16 ) sm = plt . cm . ScalarMappable ( cmap = plt . get_cmap ( \"Greys\" ), norm = norm ) sm . _A = [] cbar = fig . colorbar ( sm , ax = ax ) cbar . set_label ( \"Training iteration\" , fontsize = 16 , labelpad = 14 ) cbar . ax . tick_params ( labelsize = 14 ) plt . gca () . tick_params ( axis = \"both\" , which = \"major\" , labelsize = 16 ) ax . set_xlabel ( \"Inference iterations\" , fontsize = 18 , labelpad = 14 ) ax . set_ylabel ( \"Energy\" , fontsize = 18 , labelpad = 14 ) ax . set_yscale ( \"log\" ) plt . show () Network \u00a4 For jpc to work, we need to provide a network with callable layers. This is easy to do with the PyTorch-like nn.Sequential() in Equinox . For example, we can define a ReLU MLP with two hidden layers as follows key = jax . random . PRNGKey ( SEED ) key , * subkeys = jax . random . split ( key , 4 ) network = [ nn . Sequential ( [ nn . Linear ( 10 , 300 , key = subkeys [ 0 ]), nn . Lambda ( jax . nn . relu ) ], ), nn . Sequential ( [ nn . Linear ( 300 , 300 , key = subkeys [ 1 ]), nn . Lambda ( jax . nn . relu ) ], ), nn . 
Linear ( 300 , 784 , key = subkeys [ 2 ]), ] print ( network ) [Sequential( layers=( Linear( weight=f32[300,10], bias=f32[300], in_features=10, out_features=300, use_bias=True ), Lambda(fn=) ) ), Sequential( layers=( Linear( weight=f32[300,300], bias=f32[300], in_features=300, out_features=300, use_bias=True ), Lambda(fn=) ) ), Linear( weight=f32[784,300], bias=f32[784], in_features=300, out_features=784, use_bias=True )] You can also use the utility jpc.get_fc_network to define an MLP or fully connected network with some activation functions. network = jpc . make_mlp ( key , LAYER_SIZES , act_fn = \"relu\" ) print ( network ) [Sequential( layers=( Linear( weight=f32[100,50], bias=f32[100], in_features=50, out_features=100, use_bias=True ), Lambda(fn=) ) ), Sequential( layers=( Linear( weight=f32[100,100], bias=f32[100], in_features=100, out_features=100, use_bias=True ), Lambda(fn=) ) ), Sequential( layers=( Linear( weight=f32[784,100], bias=f32[784], in_features=100, out_features=784, use_bias=True ), Lambda(fn=Identity()) ) )] Train \u00a4 def train ( key , layer_sizes , batch_size , network , lr , max_t1 , n_train_iters ): optim = optax . adam ( lr ) opt_state = optim . init ( ( eqx . filter ( network , eqx . is_array ), None ) ) train_loader , test_loader = get_mnist_loaders ( batch_size ) train_energies , ts = [], [] for iter , img_batch in enumerate ( train_loader ): img_batch = img_batch . numpy () result = jpc . make_pc_step ( key = key , layer_sizes = layer_sizes , batch_size = batch_size , model = network , optim = optim , opt_state = opt_state , output = img_batch , max_t1 = max_t1 , record_activities = True , record_energies = True ) network , optim , opt_state = result [ \"model\" ], result [ \"optim\" ], result [ \"opt_state\" ] train_energies . append ( result [ \"energies\" ]) ts . append ( result [ \"t_max\" ]) if ( iter + 1 ) >= n_train_iters : break return result [ \"model\" ], train_energies , ts Run \u00a4 network , energies , ts = train ( key = key , layer_sizes = LAYER_SIZES , batch_size = BATCH_SIZE , network = network , lr = LEARNING_RATE , max_t1 = MAX_T1 , n_train_iters = N_TRAIN_ITERS ) plot_train_energies ( energies , ts )","title":"Unsupervised generative PC"},{"location":"examples/unsupervised_generative_pc/#unsupervised-generative-pc-on-mnist","text":"This notebook demonstrates how to train a neural network with predictive coding to encode MNIST digits in an unsupervised manner. %% capture ! pip install torch == 2.3.1 ! pip install torchvision == 0.18.1 ! pip install matplotlib == 3.0.0 import jpc import jax import jax.numpy as jnp import equinox as eqx import equinox.nn as nn import optax import torch from torch.utils.data import DataLoader from torchvision import datasets , transforms import matplotlib.pyplot as plt import matplotlib.colors as mcolors import warnings warnings . simplefilter ( 'ignore' ) # ignore warnings","title":"Unsupervised generative PC on MNIST"},{"location":"examples/unsupervised_generative_pc/#hyperparameters","text":"We define some global parameters, including network architecture, learning rate, batch size etc. SEED = 0 LAYER_SIZES = [ 50 , 100 , 100 , 784 ] ACT_FN = \"relu\" LEARNING_RATE = 1e-3 BATCH_SIZE = 64 MAX_T1 = 100 N_TRAIN_ITERS = 300","title":"Hyperparameters"},{"location":"examples/unsupervised_generative_pc/#dataset","text":"Some utils to fetch and plot MNIST. 
#@title data utils def get_mnist_loaders ( batch_size ): train_data = MNIST ( train = True , normalise = True ) test_data = MNIST ( train = False , normalise = True ) train_loader = DataLoader ( dataset = train_data , batch_size = batch_size , shuffle = True , drop_last = True ) test_loader = DataLoader ( dataset = test_data , batch_size = batch_size , shuffle = True , drop_last = True ) return train_loader , test_loader class MNIST ( datasets . MNIST ): def __init__ ( self , train , normalise = True , save_dir = \"data\" ): if normalise : transform = transforms . Compose ( [ transforms . ToTensor (), transforms . Normalize ( mean = ( 0.1307 ), std = ( 0.3081 ) ) ] ) else : transform = transforms . Compose ([ transforms . ToTensor ()]) super () . __init__ ( save_dir , download = True , train = train , transform = transform ) def __getitem__ ( self , index ): img , _ = super () . __getitem__ ( index ) img = torch . flatten ( img ) return img","title":"Dataset"},{"location":"examples/unsupervised_generative_pc/#plotting","text":"def plot_train_energies ( train_energies , ts ): t_max = int ( ts [ 0 ]) norm = mcolors . Normalize ( vmin = 0 , vmax = len ( energies ) - 1 ) fig , ax = plt . subplots ( figsize = ( 8 , 4 )) cmap_blues = plt . get_cmap ( \"Blues\" ) cmap_reds = plt . get_cmap ( \"Reds\" ) cmap_greens = plt . get_cmap ( \"Greens\" ) legend_handles = [] legend_labels = [] for t , energies_iter in enumerate ( energies ): line1 , = ax . plot ( energies_iter [ 0 , : t_max ], color = cmap_blues ( norm ( t ))) line2 , = ax . plot ( energies_iter [ 1 , : t_max ], color = cmap_reds ( norm ( t ))) line3 , = ax . plot ( energies_iter [ 2 , : t_max ], color = cmap_greens ( norm ( t ))) if t == 70 : legend_handles . append ( line1 ) legend_labels . append ( \"$\\ell_1$\" ) legend_handles . append ( line2 ) legend_labels . append ( \"$\\ell_2$\" ) legend_handles . append ( line3 ) legend_labels . append ( \"$\\ell_3$\" ) ax . legend ( legend_handles , legend_labels , loc = \"best\" , fontsize = 16 ) sm = plt . cm . ScalarMappable ( cmap = plt . get_cmap ( \"Greys\" ), norm = norm ) sm . _A = [] cbar = fig . colorbar ( sm , ax = ax ) cbar . set_label ( \"Training iteration\" , fontsize = 16 , labelpad = 14 ) cbar . ax . tick_params ( labelsize = 14 ) plt . gca () . tick_params ( axis = \"both\" , which = \"major\" , labelsize = 16 ) ax . set_xlabel ( \"Inference iterations\" , fontsize = 18 , labelpad = 14 ) ax . set_ylabel ( \"Energy\" , fontsize = 18 , labelpad = 14 ) ax . set_yscale ( \"log\" ) plt . show ()","title":"Plotting"},{"location":"examples/unsupervised_generative_pc/#network","text":"For jpc to work, we need to provide a network with callable layers. This is easy to do with the PyTorch-like nn.Sequential() in Equinox . For example, we can define a ReLU MLP with two hidden layers as follows key = jax . random . PRNGKey ( SEED ) key , * subkeys = jax . random . split ( key , 4 ) network = [ nn . Sequential ( [ nn . Linear ( 10 , 300 , key = subkeys [ 0 ]), nn . Lambda ( jax . nn . relu ) ], ), nn . Sequential ( [ nn . Linear ( 300 , 300 , key = subkeys [ 1 ]), nn . Lambda ( jax . nn . relu ) ], ), nn . 
Linear ( 300 , 784 , key = subkeys [ 2 ]), ] print ( network ) [Sequential( layers=( Linear( weight=f32[300,10], bias=f32[300], in_features=10, out_features=300, use_bias=True ), Lambda(fn=) ) ), Sequential( layers=( Linear( weight=f32[300,300], bias=f32[300], in_features=300, out_features=300, use_bias=True ), Lambda(fn=) ) ), Linear( weight=f32[784,300], bias=f32[784], in_features=300, out_features=784, use_bias=True )] You can also use the utility jpc.get_fc_network to define an MLP or fully connected network with some activation functions. network = jpc . make_mlp ( key , LAYER_SIZES , act_fn = \"relu\" ) print ( network ) [Sequential( layers=( Linear( weight=f32[100,50], bias=f32[100], in_features=50, out_features=100, use_bias=True ), Lambda(fn=) ) ), Sequential( layers=( Linear( weight=f32[100,100], bias=f32[100], in_features=100, out_features=100, use_bias=True ), Lambda(fn=) ) ), Sequential( layers=( Linear( weight=f32[784,100], bias=f32[784], in_features=100, out_features=784, use_bias=True ), Lambda(fn=Identity()) ) )]","title":"Network"},{"location":"examples/unsupervised_generative_pc/#train","text":"def train ( key , layer_sizes , batch_size , network , lr , max_t1 , n_train_iters ): optim = optax . adam ( lr ) opt_state = optim . init ( ( eqx . filter ( network , eqx . is_array ), None ) ) train_loader , test_loader = get_mnist_loaders ( batch_size ) train_energies , ts = [], [] for iter , img_batch in enumerate ( train_loader ): img_batch = img_batch . numpy () result = jpc . make_pc_step ( key = key , layer_sizes = layer_sizes , batch_size = batch_size , model = network , optim = optim , opt_state = opt_state , output = img_batch , max_t1 = max_t1 , record_activities = True , record_energies = True ) network , optim , opt_state = result [ \"model\" ], result [ \"optim\" ], result [ \"opt_state\" ] train_energies . append ( result [ \"energies\" ]) ts . append ( result [ \"t_max\" ]) if ( iter + 1 ) >= n_train_iters : break return result [ \"model\" ], train_energies , ts","title":"Train"},{"location":"examples/unsupervised_generative_pc/#run","text":"network , energies , ts = train ( key = key , layer_sizes = LAYER_SIZES , batch_size = BATCH_SIZE , network = network , lr = LEARNING_RATE , max_t1 = MAX_T1 , n_train_iters = N_TRAIN_ITERS ) plot_train_energies ( energies , ts )","title":"Run"}]} \ No newline at end of file +{"config":{"indexing":"full","lang":["en"],"min_search_length":3,"prebuild_index":false,"separator":"[\\s\\-]+"},"docs":[{"location":"","text":"Getting started \u00a4 JPC is a J AX library for training neural networks with P redictive C oding (PC). It is built on top of three main libraries: Equinox , to define neural networks with PyTorch-like syntax, Diffrax , to solve the PC inference (activity) dynamics, and Optax , for parameter optimisation. JPC provides a simple , relatively fast and flexible API for training of a variety of PCNs including discriminative, generative and hybrid models. Like JAX, JPC is completely functional, and the core library is <1000 lines of code. Unlike existing implementations, JPC leverages ordinary differential equation (ODE) solvers to integrate the inference dynamics of PC networks (PCNs), which we find can provide significant speed-ups compared to standard optimisers, especially for deeper models. JPC also provides some analytical tools that can be used to study and diagnose issues with PCNs. 
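As a small illustration of these analytical tools (a sketch with a toy input/output pair standing in for real data), the closed-form results for deep linear networks documented in the API reference can be called directly on a model built with jpc.make_mlp:

import jax.random as jr
import jax.numpy as jnp
import jpc

# toy batch of one input/output pair, standing in for real data
x = jnp.array([[1., 1., 1.]])
y = -x

# deep linear network without biases, as assumed by the analytical expressions
key = jr.PRNGKey(0)
model = jpc.make_mlp(key, layer_sizes=[3, 5, 5, 3], act_fn="linear", use_bias=False)

# theoretical PC energy at the inference equilibrium
theory_energy = jpc.linear_equilib_energy(network=model, x=x, y=y)

# theoretical equilibrium of the activities, one array per layer
theory_activities = jpc.linear_activities_solution(network=model, x=x, y=y)
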
\ud83d\udcbb Installation \u00a4 pip install jpc Requires Python 3.9+, JAX 0.4.23+, Equinox 0.11.2+, Diffrax 0.6.0+, and Optax 0.2.4+. For GPU usage, upgrade jax to the appropriate cuda version (12 as an example here). pip install --upgrade \"jax[cuda12]\" \u26a1\ufe0f Quick example \u00a4 Use jpc.make_pc_step to update the parameters of any neural network compatible with PC updates (see examples) import jax.random as jr import jax.numpy as jnp import equinox as eqx import optax import jpc # toy data x = jnp . array ([ 1. , 1. , 1. ]) y = - x # define model and optimiser key = jr . PRNGKey ( 0 ) model = jpc . make_mlp ( key , layer_sizes = [ 3 , 5 , 5 , 3 ], act_fn = \"relu\" ) optim = optax . adam ( 1e-3 ) opt_state = optim . init ( ( eqx . filter ( model , eqx . is_array ), None ) ) # perform one training step with PC result = jpc . make_pc_step ( model = model , optim = optim , opt_state = opt_state , output = y , input = x ) # updated model and optimiser model = result [ \"model\" ] optim , opt_state = result [ \"optim\" ], result [ \"opt_state\" ] Under the hood, jpc.make_pc_step integrates the inference (activity) dynamics using a Diffrax ODE solver, and updates model parameters at the numerical solution of the activities with a given Optax optimiser. NOTE : All convenience training and test functions including make_pc_step are already \"jitted\" (for increased performance) for the user's convenience. \ud83e\udde0\ufe0f Predictive coding primer \u00a4 ... \ud83d\ude80 Advanced usage \u00a4 Advanced users can access all the underlying functions of jpc.make_pc_step as well as additional features. A custom PC training step looks like the following: import jpc # 1. initialise activities with a feedforward pass activities = jpc . init_activities_with_ffwd ( model = model , input = x ) # 2. run inference to equilibrium equilibrated_activities = jpc . solve_inference ( params = ( model , None ), activities = activities , output = y , input = x ) # 3. update parameters at the activities' solution with PC result = jpc . update_params ( params = ( model , None ), activities = equilibrated_activities , optim = optim , opt_state = opt_state , output = y , input = x ) which can be embedded in a jitted function with any other additional computations. \ud83d\udcc4 Citation \u00a4 If you found this library useful in your work, please cite (arXiv link): @article { innocenti2024jpc , title = {JPC: Flexible Inference for Predictive Coding Networks in JAX} , author = {Innocenti, Francesco and Kinghorn, Paul and Yun-Farmbrough, Will and Singh, Ryan and De Llanza Varona, Miguel and Buckley, Christopher} , journal = {arXiv preprint} , year = {2024} } Also consider starring the project on GitHub ! \u2b50\ufe0f \u23ed\ufe0f Next steps \u00a4","title":"Getting started"},{"location":"#getting-started","text":"JPC is a J AX library for training neural networks with P redictive C oding (PC). It is built on top of three main libraries: Equinox , to define neural networks with PyTorch-like syntax, Diffrax , to solve the PC inference (activity) dynamics, and Optax , for parameter optimisation. JPC provides a simple , relatively fast and flexible API for training of a variety of PCNs including discriminative, generative and hybrid models. Like JAX, JPC is completely functional, and the core library is <1000 lines of code. 
Unlike existing implementations, JPC leverages ordinary differential equation (ODE) solvers to integrate the inference dynamics of PC networks (PCNs), which we find can provide significant speed-ups compared to standard optimisers, especially for deeper models. JPC also provides some analytical tools that can be used to study and diagnose issues with PCNs.","title":"Getting started"},{"location":"#installation","text":"pip install jpc Requires Python 3.9+, JAX 0.4.23+, Equinox 0.11.2+, Diffrax 0.6.0+, and Optax 0.2.4+. For GPU usage, upgrade jax to the appropriate cuda version (12 as an example here). pip install --upgrade \"jax[cuda12]\"","title":"\ud83d\udcbb Installation"},{"location":"#quick-example","text":"Use jpc.make_pc_step to update the parameters of any neural network compatible with PC updates (see examples) import jax.random as jr import jax.numpy as jnp import equinox as eqx import optax import jpc # toy data x = jnp . array ([ 1. , 1. , 1. ]) y = - x # define model and optimiser key = jr . PRNGKey ( 0 ) model = jpc . make_mlp ( key , layer_sizes = [ 3 , 5 , 5 , 3 ], act_fn = \"relu\" ) optim = optax . adam ( 1e-3 ) opt_state = optim . init ( ( eqx . filter ( model , eqx . is_array ), None ) ) # perform one training step with PC result = jpc . make_pc_step ( model = model , optim = optim , opt_state = opt_state , output = y , input = x ) # updated model and optimiser model = result [ \"model\" ] optim , opt_state = result [ \"optim\" ], result [ \"opt_state\" ] Under the hood, jpc.make_pc_step integrates the inference (activity) dynamics using a Diffrax ODE solver, and updates model parameters at the numerical solution of the activities with a given Optax optimiser. NOTE : All convenience training and test functions including make_pc_step are already \"jitted\" (for increased performance) for the user's convenience.","title":"\u26a1\ufe0f Quick example"},{"location":"#predictive-coding-primer","text":"...","title":"\ud83e\udde0\ufe0f Predictive coding primer"},{"location":"#advanced-usage","text":"Advanced users can access all the underlying functions of jpc.make_pc_step as well as additional features. A custom PC training step looks like the following: import jpc # 1. initialise activities with a feedforward pass activities = jpc . init_activities_with_ffwd ( model = model , input = x ) # 2. run inference to equilibrium equilibrated_activities = jpc . solve_inference ( params = ( model , None ), activities = activities , output = y , input = x ) # 3. update parameters at the activities' solution with PC result = jpc . update_params ( params = ( model , None ), activities = equilibrated_activities , optim = optim , opt_state = opt_state , output = y , input = x ) which can be embedded in a jitted function with any other additional computations.","title":"\ud83d\ude80 Advanced usage"},{"location":"#citation","text":"If you found this library useful in your work, please cite (arXiv link): @article { innocenti2024jpc , title = {JPC: Flexible Inference for Predictive Coding Networks in JAX} , author = {Innocenti, Francesco and Kinghorn, Paul and Yun-Farmbrough, Will and Singh, Ryan and De Llanza Varona, Miguel and Buckley, Christopher} , journal = {arXiv preprint} , year = {2024} } Also consider starring the project on GitHub ! 
\u2b50\ufe0f","title":"\ud83d\udcc4 Citation"},{"location":"#next-steps","text":"","title":"\u23ed\ufe0f Next steps"},{"location":"FAQs/","text":"","title":"FAQs"},{"location":"advanced_usage/","text":"","title":"Advanced usage"},{"location":"basic_usage/","text":"Info JPC provides two types of API depending on the use case: a simple, basic API that allows to train and test models with predictive coding with a few lines of code a more advanced and flexible API allowing for Describe purposes/use cases of both basic and advanced. Basic usage \u00a4 JPC provides a single convenience function jpc.make_pc_step() to train predictive coding networks (PCNs) on classification and generation tasks, in a supervised as well as unsupervised manner. import jpc relu_net = jpc . get_fc_network ( key , [ 10 , 100 , 100 , 10 ], \"relu\" ) result = jpc . make_pc_step ( model = relu_net , optim = optim , opt_state = opt_state , y = y , x = x ) At a minimum, jpc.make_pc_step() takes a model, an optax optimiser and its state, and an output target. Under the hood, jpc.make_pc_step() uses diffrax to solve the activity (inference) dynamics of PC. The arguments can be changed import jpc result = jpc . make_pc_step ( model = network , optim = optim , opt_state = opt_state , y = y , x = x , solver = other_solver , dt = 1e-1 , ) Moreover, JPC provides a similar function for training a hybrid PCN import jax import jax.numpy as jnp from equinox import nn as nn # some datasets x = jnp . array ([ 1. , 1. , 1. ]) y = - x # network key = jax . random . key ( 0 ) _ , * subkeys = jax . random . split ( key ) network = [ nn . Sequential ( [ nn . Linear ( 3 , 100 , key = subkeys [ 0 ]), nn . Lambda ( jax . nn . relu )], ), nn . Linear ( 100 , 3 , key = subkeys [ 1 ]), ]","title":"Basic usage"},{"location":"basic_usage/#basic-usage","text":"JPC provides a single convenience function jpc.make_pc_step() to train predictive coding networks (PCNs) on classification and generation tasks, in a supervised as well as unsupervised manner. import jpc relu_net = jpc . get_fc_network ( key , [ 10 , 100 , 100 , 10 ], \"relu\" ) result = jpc . make_pc_step ( model = relu_net , optim = optim , opt_state = opt_state , y = y , x = x ) At a minimum, jpc.make_pc_step() takes a model, an optax optimiser and its state, and an output target. Under the hood, jpc.make_pc_step() uses diffrax to solve the activity (inference) dynamics of PC. The arguments can be changed import jpc result = jpc . make_pc_step ( model = network , optim = optim , opt_state = opt_state , y = y , x = x , solver = other_solver , dt = 1e-1 , ) Moreover, JPC provides a similar function for training a hybrid PCN import jax import jax.numpy as jnp from equinox import nn as nn # some datasets x = jnp . array ([ 1. , 1. , 1. ]) y = - x # network key = jax . random . key ( 0 ) _ , * subkeys = jax . random . split ( key ) network = [ nn . Sequential ( [ nn . Linear ( 3 , 100 , key = subkeys [ 0 ]), nn . Lambda ( jax . nn . relu )], ), nn . Linear ( 100 , 3 , key = subkeys [ 1 ]), ]","title":"Basic usage"},{"location":"extending_jpc/","text":"","title":"Extending JPC"},{"location":"api/Analytical%20tools/","text":"Analytical tools \u00a4 jpc . linear_equilib_energy ( network : PyTree [ equinox . nn . _linear . Linear ], x : ArrayLike , y : ArrayLike ) -> Array \u00a4 Computes the theoretical equilibrated PC energy for a deep linear network (DLN). 
\\[ \\mathcal{F}^* = 1/N \\sum_i^N (\\mathbf{y}_i - W_{L:1}\\mathbf{x}_i)^T S^{-1}(\\mathbf{y}_i - W_{L:1}\\mathbf{x}_i) \\] where the rescaling is \\(S = I_{d_y} + \\sum_{\\ell=2}^L (W_{L:\\ell})(W_{L:\\ell})^T\\) , and we use the shorthand \\(W_{L:\\ell} = W_L W_{L-1} \\dots W_\\ell\\) . See reference below. Note This expression assumes no biases. Reference @article { innocenti2024only , title = {Only Strict Saddles in the Energy Landscape of Predictive Coding Networks?} , author = {Innocenti, Francesco and Achour, El Mehdi and Singh, Ryan and Buckley, Christopher L} , journal = {arXiv preprint arXiv:2408.11979} , year = {2024} } Main arguments: network : Linear network defined as a list of Equinox Linear layers. x : Network input. y : Network output. Returns: Mean total analytical energy over a batch or dataset. jpc . linear_activities_solution ( network : PyTree [ equinox . nn . _linear . Linear ], x : ArrayLike , y : ArrayLike ) -> PyTree [ Array ] \u00a4 Computes the theoretical solution for the PC activities of a deep linear network (DLN). \\[ \\mathbf{z}^* = A^{-1} \\mathbf{b} \\] where \\(A\\) is a sparse block diagonal matrix depending only on the weights, and \\(\\mathbf{b} = [W_1 \\mathbf{x}, \\mathbf{0}, \\dots, W_L^T \\mathbf{y}]^T\\) . In particular, \\(A_{\\ell,k} = I + W_\\ell^T W_\\ell\\) if \\(\\ell = k\\) , \\(A_{\\ell,k} = -W_\\ell\\) if \\(\\ell = k+1\\) , \\(A_{\\ell,k} = -W_\\ell^T\\) if \\(\\ell = k-1\\) , and \\(\\mathbf{0}\\) otherwise, for \\(\\ell, k \\in [2, \\dots, L]\\) . Note This expression assumes no biases. Main arguments: network : Linear network defined as a list of Equinox Linear layers. x : Network input. y : Network output. Returns: List of theoretical activities for each layer.","title":"Analytical tools"},{"location":"api/Analytical%20tools/#analytical-tools","text":"","title":"Analytical tools"},{"location":"api/Analytical%20tools/#jpc.linear_equilib_energy","text":"Computes the theoretical equilibrated PC energy for a deep linear network (DLN). \\[ \\mathcal{F}^* = 1/N \\sum_i^N (\\mathbf{y}_i - W_{L:1}\\mathbf{x}_i)^T S^{-1}(\\mathbf{y}_i - W_{L:1}\\mathbf{x}_i) \\] where the rescaling is \\(S = I_{d_y} + \\sum_{\\ell=2}^L (W_{L:\\ell})(W_{L:\\ell})^T\\) , and we use the shorthand \\(W_{L:\\ell} = W_L W_{L-1} \\dots W_\\ell\\) . See reference below. Note This expression assumes no biases. Reference @article { innocenti2024only , title = {Only Strict Saddles in the Energy Landscape of Predictive Coding Networks?} , author = {Innocenti, Francesco and Achour, El Mehdi and Singh, Ryan and Buckley, Christopher L} , journal = {arXiv preprint arXiv:2408.11979} , year = {2024} } Main arguments: network : Linear network defined as a list of Equinox Linear layers. x : Network input. y : Network output. Returns: Mean total analytical energy over a batch or dataset.","title":"linear_equilib_energy()"},{"location":"api/Analytical%20tools/#jpc.linear_activities_solution","text":"Computes the theoretical solution for the PC activities of a deep linear network (DLN). \\[ \\mathbf{z}^* = A^{-1} \\mathbf{b} \\] where \\(A\\) is a sparse block diagonal matrix depending only on the weights, and \\(\\mathbf{b} = [W_1 \\mathbf{x}, \\mathbf{0}, \\dots, W_L^T \\mathbf{y}]^T\\) . In particular, \\(A_{\\ell,k} = I + W_\\ell^T W_\\ell\\) if \\(\\ell = k\\) , \\(A_{\\ell,k} = -W_\\ell\\) if \\(\\ell = k+1\\) , \\(A_{\\ell,k} = -W_\\ell^T\\) if \\(\\ell = k-1\\) , and \\(\\mathbf{0}\\) otherwise, for \\(\\ell, k \\in [2, \\dots, L]\\) . 
Note This expression assumes no biases. Main arguments: network : Linear network defined as a list of Equinox Linear layers. x : Network input. y : Network output. Returns: List of theoretical activities for each layer.","title":"linear_activities_solution()"},{"location":"api/Energy%20functions/","text":"Energy functions \u00a4 jpc . pc_energy_fn ( params : Tuple [ PyTree [ Callable ], Optional [ PyTree [ Callable ]]], activities : PyTree [ ArrayLike ], y : ArrayLike , x : Optional [ ArrayLike ] = None , loss : str = 'MSE' , record_layers : bool = False ) -> Array | Array \u00a4 Computes the free energy for a feedforward neural network of the form \\[ \\mathcal{F}(\\mathbf{z}; \u03b8) = 1/N \\sum_i^N \\sum_{\\ell=1}^L || \\mathbf{z}_{i, \\ell} - f_\\ell(\\mathbf{z}_{i, \\ell-1}; \u03b8) ||^2 \\] given parameters \\(\u03b8\\) , free activities \\(\\mathbf{z}\\) , output \\(\\mathbf{z}_L = \\mathbf{y}\\) and optional input \\(\\mathbf{z}_0 = \\mathbf{x}\\) for supervised training. The activity of each layer \\(\\mathbf{z}_\\ell\\) is some function of the previous layer, e.g. ReLU \\((W_\\ell \\mathbf{z}_{\\ell-1} + \\mathbf{b}_\\ell)\\) for a fully connected layer with biases and ReLU as activation. Note The input \\(x\\) and output \\(y\\) correspond to the prior and observation of the generative model, respectively. Main arguments: params : Tuple with callable model (e.g. neural network) layers and optional skip connections. activities : List of activities for each layer free to vary. y : Observation or target of the generative model. x : Optional prior of the generative model (for supervised training). Other arguments: loss : Loss function to use at the output layer (mean squared error 'MSE' vs cross-entropy 'CE'). ??? cite \"Reference\" @article { tscshantz2023hybrid , title = {Hybrid predictive coding: Inferring, fast and slow} , author = {Tscshantz, Alexander and Millidge, Beren and Seth, Anil K and Buckley, Christopher L} , journal = {PLoS Computational Biology} , volume = {19} , number = {8} , pages = {e1011280} , year = {2023} , publisher = {Public Library of Science San Francisco, CA USA} } - record_layers : If True , returns energies for each layer. Returns: The total or layer-wise energy normalised by the batch size. jpc . hpc_energy_fn ( model : PyTree [ typing . Callable ], equilib_activities : PyTree [ ArrayLike ], amort_activities : PyTree [ ArrayLike ], x : ArrayLike , y : Optional [ ArrayLike ] = None , record_layers : bool = False ) -> Array | Array \u00a4 Computes the free energy of an amortised PC network \\[ \\mathcal{F}(\\mathbf{z}^*, \\hat{\\mathbf{z}}; \u03b8) = 1/N \\sum_i^N \\sum_{\\ell=1}^L || \\mathbf{z}^*_{i, \\ell} - f_\\ell(\\hat{\\mathbf{z}}_{i, \\ell-1}; \u03b8) ||^2 \\] given the equilibrated activities of the generator \\(\\mathbf{z}^*\\) (target for the amortiser), the feedforward guesses of the amortiser \\(\\hat{\\mathbf{z}}\\) , the amortiser's parameters \\(\u03b8\\) , input \\(\\mathbf{z}_0 = \\mathbf{x}\\) , and optional output \\(\\mathbf{z}_L = \\mathbf{y}\\) for supervised training. Note The input \\(x\\) and output \\(y\\) are reversed compared to pc_energy_fn ( \\(x\\) is the generator's target and \\(y\\) is its optional input or prior). Just think of \\(x\\) and \\(y\\) as the actual input and output of the amortiser, respectively. Main arguments: model : List of callable model (e.g. neural network) layers. equilib_activities : List of equilibrated activities reached by the generator and target for the amortiser. 
amort_activities : List of amortiser's feedforward guesses (initialisation) for the network activities. x : Input to the amortiser. y : Optional target of the amortiser (for supervised training). Other arguments: record_layers : If True , returns energies for each layer. Returns: The total or layer-wise energy normalised by batch size.","title":"Energy functions"},{"location":"api/Energy%20functions/#energy-functions","text":"","title":"Energy functions"},{"location":"api/Energy%20functions/#jpc.pc_energy_fn","text":"Computes the free energy for a feedforward neural network of the form \\[ \\mathcal{F}(\\mathbf{z}; \u03b8) = 1/N \\sum_i^N \\sum_{\\ell=1}^L || \\mathbf{z}_{i, \\ell} - f_\\ell(\\mathbf{z}_{i, \\ell-1}; \u03b8) ||^2 \\] given parameters \\(\u03b8\\) , free activities \\(\\mathbf{z}\\) , output \\(\\mathbf{z}_L = \\mathbf{y}\\) and optional input \\(\\mathbf{z}_0 = \\mathbf{x}\\) for supervised training. The activity of each layer \\(\\mathbf{z}_\\ell\\) is some function of the previous layer, e.g. ReLU \\((W_\\ell \\mathbf{z}_{\\ell-1} + \\mathbf{b}_\\ell)\\) for a fully connected layer with biases and ReLU as activation. Note The input \\(x\\) and output \\(y\\) correspond to the prior and observation of the generative model, respectively. Main arguments: params : Tuple with callable model (e.g. neural network) layers and optional skip connections. activities : List of activities for each layer free to vary. y : Observation or target of the generative model. x : Optional prior of the generative model (for supervised training). Other arguments: loss : Loss function to use at the output layer (mean squared error 'MSE' vs cross-entropy 'CE'). ??? cite \"Reference\" @article { tscshantz2023hybrid , title = {Hybrid predictive coding: Inferring, fast and slow} , author = {Tscshantz, Alexander and Millidge, Beren and Seth, Anil K and Buckley, Christopher L} , journal = {PLoS Computational Biology} , volume = {19} , number = {8} , pages = {e1011280} , year = {2023} , publisher = {Public Library of Science San Francisco, CA USA} } - record_layers : If True , returns energies for each layer. Returns: The total or layer-wise energy normalised by the batch size.","title":"pc_energy_fn()"},{"location":"api/Energy%20functions/#jpc.hpc_energy_fn","text":"Computes the free energy of an amortised PC network \\[ \\mathcal{F}(\\mathbf{z}^*, \\hat{\\mathbf{z}}; \u03b8) = 1/N \\sum_i^N \\sum_{\\ell=1}^L || \\mathbf{z}^*_{i, \\ell} - f_\\ell(\\hat{\\mathbf{z}}_{i, \\ell-1}; \u03b8) ||^2 \\] given the equilibrated activities of the generator \\(\\mathbf{z}^*\\) (target for the amortiser), the feedforward guesses of the amortiser \\(\\hat{\\mathbf{z}}\\) , the amortiser's parameters \\(\u03b8\\) , input \\(\\mathbf{z}_0 = \\mathbf{x}\\) , and optional output \\(\\mathbf{z}_L = \\mathbf{y}\\) for supervised training. Note The input \\(x\\) and output \\(y\\) are reversed compared to pc_energy_fn ( \\(x\\) is the generator's target and \\(y\\) is its optional input or prior). Just think of \\(x\\) and \\(y\\) as the actual input and output of the amortiser, respectively. Main arguments: model : List of callable model (e.g. neural network) layers. equilib_activities : List of equilibrated activities reached by the generator and target for the amortiser. amort_activities : List of amortiser's feedforward guesses (initialisation) for the network activities. x : Input to the amortiser. y : Optional target of the amortiser (for supervised training). 
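As a rough illustration of pc_energy_fn (a sketch under assumptions, not canonical library usage): build a network with jpc.make_mlp, initialise the activities with a feedforward pass, and evaluate the energy on a batch. The layer sizes and random data are placeholders, and the (model, None) tuple assumes no skip connections.

```python
import jax
import jax.random as jr
import jpc

key = jr.PRNGKey(0)
model = jpc.make_mlp(key, [784, 300, 300, 10], act_fn="relu")

# Placeholder batch of 64 flattened "images" and one-hot "labels".
x = jr.normal(jr.PRNGKey(1), (64, 784))
y = jax.nn.one_hot(jr.randint(jr.PRNGKey(2), (64,), 0, 10), 10)

# At a feedforward initialisation every hidden-layer prediction error vanishes,
# so the energy reduces to the output loss between the prediction and y.
activities = jpc.init_activities_with_ffwd(model=model, input=x)
energy = jpc.pc_energy_fn(
    params=(model, None),  # no skip connections
    activities=activities,
    y=y,
    x=x,
)
```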
Other arguments: record_layers : If True , returns energies for each layer. Returns: The total or layer-wise energy normalised by batch size.","title":"hpc_energy_fn()"},{"location":"api/Gradients/","text":"Gradients \u00a4 jpc . neg_activity_grad ( t : float | int , activities : PyTree [ ArrayLike ], args : Tuple [ Tuple [ PyTree [ Callable ], Optional [ PyTree [ Callable ]]], ArrayLike , Optional [ ArrayLike ], str , diffrax . _step_size_controller . base . AbstractStepSizeController ]) -> PyTree [ Array ] \u00a4 Computes the negative gradient of the energy with respect to the activities \\(- \\partial \\mathcal{F} / \\partial \\mathbf{z}\\) . This defines an ODE system to be integrated by solve_pc_inference . Main arguments: t : Time step of the ODE system, used for downstream integration by diffrax.diffeqsolve . activities : List of activities for each layer free to vary. args : 5-Tuple with (i) Tuple with callable model layers and optional skip connections, (ii) network output (observation), (iii) network input (prior), (iv) Loss specified at the output layer (MSE vs cross-entropy), and (v) diffrax controller for step size integration. Returns: List of negative gradients of the energy w.r.t. the activities. jpc . compute_pc_param_grads ( params : Tuple [ PyTree [ Callable ], Optional [ PyTree [ Callable ]]], activities : PyTree [ ArrayLike ], y : ArrayLike , x : Optional [ ArrayLike ] = None , loss_id : str = 'MSE' ) -> Tuple [ PyTree [ Array ], PyTree [ Array ]] \u00a4 Computes the gradient of the PC energy with respect to model parameters \\(\\partial \\mathcal{F} / \\partial \u03b8\\) . Main arguments: params : Tuple with callable model layers and optional skip connections. activities : List of activities for each layer free to vary. y : Observation or target of the generative model. x : Optional prior of the generative model. Other arguments: loss_id : Loss function for the output layer (mean squared error 'MSE' vs cross-entropy 'CE'). Returns: List of parameter gradients for each network layer. jpc . compute_hpc_param_grads ( model : PyTree [ typing . Callable ], equilib_activities : PyTree [ ArrayLike ], amort_activities : PyTree [ ArrayLike ], x : ArrayLike , y : Optional [ ArrayLike ] = None ) -> PyTree [ Array ] \u00a4 Computes the gradient of the hybrid energy with respect to an amortiser's parameters \\(\\partial \\mathcal{F} / \\partial \u03b8\\) . Main arguments: model : List of callable model (e.g. neural network) layers. equilib_activities : List of equilibrated activities reached by the generator and target for the amortiser. amort_activities : List of amortiser's feedforward guesses (initialisation) for the network activities. x : Input to the amortiser. y : Optional target of the amortiser (for supervised training). Note The input \\(x\\) and output \\(y\\) are reversed compared to compute_pc_param_grads ( \\(x\\) is the generator's target and \\(y\\) is its optional input or prior). Just think of \\(x\\) and \\(y\\) as the actual input and output of the amortiser, respectively. Returns: List of parameter gradients for each network layer.","title":"Gradients"},{"location":"api/Gradients/#gradients","text":"","title":"Gradients"},{"location":"api/Gradients/#jpc.neg_activity_grad","text":"Computes the negative gradient of the energy with respect to the activities \\(- \\partial \\mathcal{F} / \\partial \\mathbf{z}\\) . This defines an ODE system to be integrated by solve_pc_inference . 
Main arguments: t : Time step of the ODE system, used for downstream integration by diffrax.diffeqsolve . activities : List of activities for each layer free to vary. args : 5-Tuple with (i) Tuple with callable model layers and optional skip connections, (ii) network output (observation), (iii) network input (prior), (iv) Loss specified at the output layer (MSE vs cross-entropy), and (v) diffrax controller for step size integration. Returns: List of negative gradients of the energy w.r.t. the activities.","title":"neg_activity_grad()"},{"location":"api/Gradients/#jpc.compute_pc_param_grads","text":"Computes the gradient of the PC energy with respect to model parameters \\(\\partial \\mathcal{F} / \\partial \u03b8\\) . Main arguments: params : Tuple with callable model layers and optional skip connections. activities : List of activities for each layer free to vary. y : Observation or target of the generative model. x : Optional prior of the generative model. Other arguments: loss_id : Loss function for the output layer (mean squared error 'MSE' vs cross-entropy 'CE'). Returns: List of parameter gradients for each network layer.","title":"compute_pc_param_grads()"},{"location":"api/Gradients/#jpc.compute_hpc_param_grads","text":"Computes the gradient of the hybrid energy with respect to an amortiser's parameters \\(\\partial \\mathcal{F} / \\partial \u03b8\\) . Main arguments: model : List of callable model (e.g. neural network) layers. equilib_activities : List of equilibrated activities reached by the generator and target for the amortiser. amort_activities : List of amortiser's feedforward guesses (initialisation) for the network activities. x : Input to the amortiser. y : Optional target of the amortiser (for supervised training). Note The input \\(x\\) and output \\(y\\) are reversed compared to compute_pc_param_grads ( \\(x\\) is the generator's target and \\(y\\) is its optional input or prior). Just think of \\(x\\) and \\(y\\) as the actual input and output of the amortiser, respectively. Returns: List of parameter gradients for each network layer.","title":"compute_hpc_param_grads()"},{"location":"api/Inference/","text":"Inference \u00a4 jpc . solve_inference ( params : Tuple [ PyTree [ Callable ], Optional [ PyTree [ Callable ]]], activities : PyTree [ ArrayLike ], output : ArrayLike , input : Optional [ ArrayLike ] = None , loss_id : str = 'MSE' , solver : AbstractSolver = Heun ( scan_kind = None ), max_t1 : int = 20 , dt : float | int = None , stepsize_controller : AbstractStepSizeController = PIDController ( rtol = 0.001 , atol = 0.001 , pcoeff = 0 , icoeff = 1 , dcoeff = 0 , dtmin = None , dtmax = None , force_dtmin = True , step_ts = None , jump_ts = None , factormin = 0.2 , factormax = 10.0 , norm =< function rms_norm > , safety = 0.9 , error_order = None ), record_iters : bool = False , record_every : int = None ) -> PyTree [ Array ] \u00a4 Solves the inference (activity) dynamics of a predictive coding network. This is a wrapper around diffrax.diffeqsolve to integrate the gradient ODE system _neg_activity_grad defining the PC inference dynamics \\[ \\partial \\mathbf{z} / \\partial t = - \\partial \\mathcal{F} / \\partial \\mathbf{z} \\] where \\(\\mathcal{F}\\) is the free energy, \\(\\mathbf{z}\\) are the activities, with \\(\\mathbf{z}_L\\) clamped to some target and \\(\\mathbf{z}_0\\) optionally equal to some prior. Main arguments: params : Tuple with callable model layers and optional skip connections. activities : List of activities for each layer free to vary. 
output : Observation or target of the generative model. input : Optional prior of the generative model. Other arguments: loss : Loss function to use at the output layer (mean squared error 'MSE' vs cross-entropy 'CE'). solver : Diffrax (ODE) solver to be used. Default is Heun, a 2nd order explicit Runge--Kutta method. max_t1 : Maximum end of integration region (500 by default). dt : Integration step size. Defaults to None since the default stepsize_controller will automatically determine it. stepsize_controller : diffrax controller for step size integration. Defaults to PIDController . Note that the relative and absolute tolerances of the controller will also determine the steady state to terminate the solver. record_iters : If True , returns all integration steps. record_every : int determining the sampling frequency the integration steps. Returns: List with solution of the activity dynamics for each layer.","title":"Inference"},{"location":"api/Inference/#inference","text":"","title":"Inference"},{"location":"api/Inference/#jpc.solve_inference","text":"Solves the inference (activity) dynamics of a predictive coding network. This is a wrapper around diffrax.diffeqsolve to integrate the gradient ODE system _neg_activity_grad defining the PC inference dynamics \\[ \\partial \\mathbf{z} / \\partial t = - \\partial \\mathcal{F} / \\partial \\mathbf{z} \\] where \\(\\mathcal{F}\\) is the free energy, \\(\\mathbf{z}\\) are the activities, with \\(\\mathbf{z}_L\\) clamped to some target and \\(\\mathbf{z}_0\\) optionally equal to some prior. Main arguments: params : Tuple with callable model layers and optional skip connections. activities : List of activities for each layer free to vary. output : Observation or target of the generative model. input : Optional prior of the generative model. Other arguments: loss : Loss function to use at the output layer (mean squared error 'MSE' vs cross-entropy 'CE'). solver : Diffrax (ODE) solver to be used. Default is Heun, a 2nd order explicit Runge--Kutta method. max_t1 : Maximum end of integration region (500 by default). dt : Integration step size. Defaults to None since the default stepsize_controller will automatically determine it. stepsize_controller : diffrax controller for step size integration. Defaults to PIDController . Note that the relative and absolute tolerances of the controller will also determine the steady state to terminate the solver. record_iters : If True , returns all integration steps. record_every : int determining the sampling frequency the integration steps. Returns: List with solution of the activity dynamics for each layer.","title":"solve_inference()"},{"location":"api/Initialisation/","text":"Initialisation \u00a4 jpc . init_activities_with_ffwd ( model : PyTree [ typing . Callable ], input : ArrayLike , skip_model : Optional [ PyTree [ Callable ]] = None ) -> PyTree [ Array ] \u00a4 Initialises layers' activity with a feedforward pass \\(\\{ f_\\ell(\\mathbf{z}_{\\ell-1}) \\}_{\\ell=1}^L\\) where \\(\\mathbf{z}_0 = \\mathbf{x}\\) is the input. Main arguments: model : List of callable model (e.g. neural network) layers. input : input to the model. Other arguments: skip_model : Optional skip connection model. Returns: List with activity values of each layer. jpc . init_activities_from_normal ( key : PRNGKeyArray , layer_sizes : PyTree [ int ], mode : str , batch_size : int , sigma : Array = 0.05 ) -> PyTree [ Array ] \u00a4 Initialises network activities from a zero-mean Gaussian \\(\\sim \\mathcal{N}(0, \\sigma^2)\\) . 
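A hedged sketch of the inference step described above: initialise the activities with a feedforward pass and integrate the activity dynamics with solve_inference using its default Heun solver and PID step-size controller. The network, layer sizes and random data are again illustrative assumptions.

```python
import jax
import jax.random as jr
import jpc

key = jr.PRNGKey(0)
model = jpc.make_mlp(key, [784, 300, 300, 10], act_fn="relu")

x = jr.normal(jr.PRNGKey(1), (64, 784))                           # placeholder inputs
y = jax.nn.one_hot(jr.randint(jr.PRNGKey(2), (64,), 0, 10), 10)   # placeholder targets

# Relax the activities towards equilibrium, dz/dt = -dF/dz, with z_L clamped
# to the target y and z_0 set to the prior x.
activities = jpc.init_activities_with_ffwd(model=model, input=x)
equilib_activities = jpc.solve_inference(
    params=(model, None),
    activities=activities,
    output=y,
    input=x,
)
```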
Main arguments: key : jax.random.PRNGKey for sampling. layer_sizes : List with dimension of all layers (input, hidden and output). mode : If supervised , all hidden layers are initialised. If unsupervised the input layer \\(\\mathbf{z}_0\\) is also initialised. batch_size : Dimension of data batch. sigma : Standard deviation for Gaussian to sample activities from. Defaults to 5e-2. Returns: List of randomly initialised activities for each layer. jpc . init_activities_with_amort ( amortiser : PyTree [ typing . Callable ], generator : PyTree [ typing . Callable ], input : ArrayLike ) -> PyTree [ Array ] \u00a4 Initialises layers' activity with an amortised network \\(\\{ f_{L-\\ell+1}(\\mathbf{z}_{L-\\ell}) \\}_{\\ell=1}^L\\) where \\(\\mathbf{z}_0 = \\mathbf{y}\\) is the input or generator's target. Note The output order is reversed for downstream use by the generator. Main arguments: amortiser : List of callable layers for model amortising the inference of the generator . generator : List of callable layers for the generative model. input : Input to the amortiser. Returns: List with amortised initialisation of each layer.","title":"Initialisation"},{"location":"api/Initialisation/#initialisation","text":"","title":"Initialisation"},{"location":"api/Initialisation/#jpc.init_activities_with_ffwd","text":"Initialises layers' activity with a feedforward pass \\(\\{ f_\\ell(\\mathbf{z}_{\\ell-1}) \\}_{\\ell=1}^L\\) where \\(\\mathbf{z}_0 = \\mathbf{x}\\) is the input. Main arguments: model : List of callable model (e.g. neural network) layers. input : input to the model. Other arguments: skip_model : Optional skip connection model. Returns: List with activity values of each layer.","title":"init_activities_with_ffwd()"},{"location":"api/Initialisation/#jpc.init_activities_from_normal","text":"Initialises network activities from a zero-mean Gaussian \\(\\sim \\mathcal{N}(0, \\sigma^2)\\) . Main arguments: key : jax.random.PRNGKey for sampling. layer_sizes : List with dimension of all layers (input, hidden and output). mode : If supervised , all hidden layers are initialised. If unsupervised the input layer \\(\\mathbf{z}_0\\) is also initialised. batch_size : Dimension of data batch. sigma : Standard deviation for Gaussian to sample activities from. Defaults to 5e-2. Returns: List of randomly initialised activities for each layer.","title":"init_activities_from_normal()"},{"location":"api/Initialisation/#jpc.init_activities_with_amort","text":"Initialises layers' activity with an amortised network \\(\\{ f_{L-\\ell+1}(\\mathbf{z}_{L-\\ell}) \\}_{\\ell=1}^L\\) where \\(\\mathbf{z}_0 = \\mathbf{y}\\) is the input or generator's target. Note The output order is reversed for downstream use by the generator. Main arguments: amortiser : List of callable layers for model amortising the inference of the generator . generator : List of callable layers for the generative model. input : Input to the amortiser. Returns: List with amortised initialisation of each layer.","title":"init_activities_with_amort()"},{"location":"api/Testing/","text":"Testing \u00a4 jpc . test_discriminative_pc ( model : PyTree [ typing . Callable ], output : ArrayLike , input : ArrayLike , loss : str = 'MSE' , skip_model : Optional [ PyTree [ Callable ]] = None ) -> Tuple [ Array , Array ] \u00a4 Computes test metrics for a discriminative predictive coding network. Main arguments: model : List of callable model (e.g. neural network) layers. output : Observation or target of the generative model. 
input : Optional prior of the generative model. Other arguments: loss : - loss : Loss function to use at the output layer (mean squared error 'MSE' vs cross-entropy 'CE'). skip_model : Optional list of callable skip connection functions. Returns: Test loss and accuracy of output predictions. jpc . test_generative_pc ( model : PyTree [ typing . Callable ], output : ArrayLike , input : ArrayLike , key : PRNGKeyArray , layer_sizes : PyTree [ int ], batch_size : int , sigma : Array = 0.05 , ode_solver : AbstractSolver = Heun ( scan_kind = None ), max_t1 : int = 500 , dt : Array | int = None , stepsize_controller : AbstractStepSizeController = PIDController ( rtol = 0.001 , atol = 0.001 , pcoeff = 0 , icoeff = 1 , dcoeff = 0 , dtmin = None , dtmax = None , force_dtmin = True , step_ts = None , jump_ts = None , factormin = 0.2 , factormax = 10.0 , norm =< function rms_norm > , safety = 0.9 , error_order = None ), skip_model : Optional [ PyTree [ Callable ]] = None ) -> Tuple [ Array , Array ] \u00a4 Computes test metrics for a generative predictive coding network. Gets output predictions (e.g. of an image given a label) with a feedforward pass and calculates accuracy of inferred input (e.g. of a label given an image). Main arguments: model : List of callable model (e.g. neural network) layers. output : Observation or target of the generative model. input : Optional prior of the generative model. key : jax.random.PRNGKey for random initialisation of activities. layer_sizes : Dimension of all layers (input, hidden and output). batch_size : Dimension of data batch for activity initialisation. Other arguments: sigma : Standard deviation for Gaussian to sample activities from. Defaults to 5e-2. ode_solver : Diffrax ODE solver to be used. Default is Heun, a 2nd order explicit Runge--Kutta method. max_t1 : Maximum end of integration region (500 by default). dt : Integration step size. Defaults to None since the default stepsize_controller will automatically determine it. stepsize_controller : diffrax controller for step size integration. Defaults to PIDController . Note that the relative and absolute tolerances of the controller will also determine the steady state to terminate the solver. Returns: Accuracy and output predictions. jpc . test_hpc ( generator : PyTree [ typing . Callable ], amortiser : PyTree [ typing . Callable ], output : ArrayLike , input : ArrayLike , key : PRNGKeyArray , layer_sizes : PyTree [ int ], batch_size : int , sigma : Array = 0.05 , ode_solver : AbstractSolver = Heun ( scan_kind = None ), max_t1 : int = 500 , dt : Array | int = None , stepsize_controller : AbstractStepSizeController = PIDController ( rtol = 0.001 , atol = 0.001 , pcoeff = 0 , icoeff = 1 , dcoeff = 0 , dtmin = None , dtmax = None , force_dtmin = True , step_ts = None , jump_ts = None , factormin = 0.2 , factormax = 10.0 , norm =< function rms_norm > , safety = 0.9 , error_order = None )) -> Tuple [ Array , Array , Array , Array ] \u00a4 Computes test metrics for hybrid predictive coding trained in a supervised manner. Calculates input accuracy of (i) amortiser, (ii) generator, and (iii) hybrid (amortiser + generator). Also returns output predictions (e.g. of an image given a label) with a feedforward pass of the generator. Note The input and output of the generator are the output and input of the amortiser, respectively. Main arguments: generator : List of callable layers for the generative model. amortiser : List of callable layers for model amortising the inference of the generator . 
output : Observation or target of the generative model. input : Optional prior of the generator, target for the amortiser. key : jax.random.PRNGKey for random initialisation of activities. layer_sizes : Dimension of all layers (input, hidden and output). batch_size : Dimension of data batch for initialisation of activities. Other arguments: sigma : Standard deviation for Gaussian to sample activities from. Defaults to 5e-2. ode_solver : Diffrax ODE solver to be used. Default is Heun, a 2nd order explicit Runge--Kutta method. max_t1 : Maximum end of integration region (500 by default). dt : Integration step size. Defaults to None since the default stepsize_controller will automatically determine it. stepsize_controller : diffrax controller for step size integration. Defaults to PIDController . Note that the relative and absolute tolerances of the controller will also determine the steady state to terminate the solver. Returns: Accuracies of all models and output predictions.","title":"Testing"},{"location":"api/Testing/#testing","text":"","title":"Testing"},{"location":"api/Testing/#jpc.test_discriminative_pc","text":"Computes test metrics for a discriminative predictive coding network. Main arguments: model : List of callable model (e.g. neural network) layers. output : Observation or target of the generative model. input : Optional prior of the generative model. Other arguments: loss : - loss : Loss function to use at the output layer (mean squared error 'MSE' vs cross-entropy 'CE'). skip_model : Optional list of callable skip connection functions. Returns: Test loss and accuracy of output predictions.","title":"test_discriminative_pc()"},{"location":"api/Testing/#jpc.test_generative_pc","text":"Computes test metrics for a generative predictive coding network. Gets output predictions (e.g. of an image given a label) with a feedforward pass and calculates accuracy of inferred input (e.g. of a label given an image). Main arguments: model : List of callable model (e.g. neural network) layers. output : Observation or target of the generative model. input : Optional prior of the generative model. key : jax.random.PRNGKey for random initialisation of activities. layer_sizes : Dimension of all layers (input, hidden and output). batch_size : Dimension of data batch for activity initialisation. Other arguments: sigma : Standard deviation for Gaussian to sample activities from. Defaults to 5e-2. ode_solver : Diffrax ODE solver to be used. Default is Heun, a 2nd order explicit Runge--Kutta method. max_t1 : Maximum end of integration region (500 by default). dt : Integration step size. Defaults to None since the default stepsize_controller will automatically determine it. stepsize_controller : diffrax controller for step size integration. Defaults to PIDController . Note that the relative and absolute tolerances of the controller will also determine the steady state to terminate the solver. Returns: Accuracy and output predictions.","title":"test_generative_pc()"},{"location":"api/Testing/#jpc.test_hpc","text":"Computes test metrics for hybrid predictive coding trained in a supervised manner. Calculates input accuracy of (i) amortiser, (ii) generator, and (iii) hybrid (amortiser + generator). Also returns output predictions (e.g. of an image given a label) with a feedforward pass of the generator. Note The input and output of the generator are the output and input of the amortiser, respectively. Main arguments: generator : List of callable layers for the generative model. 
amortiser : List of callable layers for model amortising the inference of the generator . output : Observation or target of the generative model. input : Optional prior of the generator, target for the amortiser. key : jax.random.PRNGKey for random initialisation of activities. layer_sizes : Dimension of all layers (input, hidden and output). batch_size : Dimension of data batch for initialisation of activities. Other arguments: sigma : Standard deviation for Gaussian to sample activities from. Defaults to 5e-2. ode_solver : Diffrax ODE solver to be used. Default is Heun, a 2nd order explicit Runge--Kutta method. max_t1 : Maximum end of integration region (500 by default). dt : Integration step size. Defaults to None since the default stepsize_controller will automatically determine it. stepsize_controller : diffrax controller for step size integration. Defaults to PIDController . Note that the relative and absolute tolerances of the controller will also determine the steady state to terminate the solver. Returns: Accuracies of all models and output predictions.","title":"test_hpc()"},{"location":"api/Training/","text":"Training \u00a4 jpc . make_pc_step ( model : PyTree [ typing . Callable ], optim : optax . _src . base . GradientTransformation | optax . _src . base . GradientTransformationExtraArgs , opt_state : Union [ jax . Array , numpy . ndarray , numpy . bool , numpy . number , Iterable [ ArrayTree ], Mapping [ Any , ArrayTree ]], output : ArrayLike , input : Optional [ ArrayLike ] = None , loss_id : str = 'MSE' , ode_solver : AbstractSolver = Heun ( scan_kind = None ), max_t1 : int = 20 , dt : Array | int = None , stepsize_controller : AbstractStepSizeController = PIDController ( rtol = 0.001 , atol = 0.001 , pcoeff = 0 , icoeff = 1 , dcoeff = 0 , dtmin = None , dtmax = None , force_dtmin = True , step_ts = None , jump_ts = None , factormin = 0.2 , factormax = 10.0 , norm =< function rms_norm > , safety = 0.9 , error_order = None ), skip_model : Optional [ PyTree [ Callable ]] = None , key : Optional [ PRNGKeyArray ] = None , layer_sizes : Optional [ PyTree [ int ]] = None , batch_size : Optional [ int ] = None , sigma : Array = 0.05 , record_activities : bool = False , record_energies : bool = False , record_every : int = None , activity_norms : bool = False , param_norms : bool = False , grad_norms : bool = False , calculate_accuracy : bool = False ) -> Dict \u00a4 Updates network parameters with predictive coding. Main arguments: model : List of callable model (e.g. neural network) layers. optim : Optax optimiser, e.g. optax.sgd() . opt_state : State of Optax optimiser. output : Observation or target of the generative model. input : Optional prior of the generative model. Note key , layer_sizes and batch_size must be passed if input is None , since unsupervised training will be assumed and activities need to be initialised randomly. Other arguments: loss_id : Loss function for the output layer (mean squared error 'MSE' vs cross-entropy 'CE'). ode_solver : Diffrax ODE solver to be used. Default is Heun, a 2nd order explicit Runge--Kutta method. max_t1 : Maximum end of integration region (20 by default). dt : Integration step size. Defaults to None since the default stepsize_controller will automatically determine it. stepsize_controller : diffrax controller for step size integration. Defaults to PIDController . Note that the relative and absolute tolerances of the controller will also determine the steady state to terminate the solver. 
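Below is a condensed sketch of a single supervised update with make_pc_step, following the same pattern as the discriminative PC example further down this page; the random batch stands in for real MNIST data.

```python
import jax
import jax.random as jr
import equinox as eqx
import optax
import jpc

key = jr.PRNGKey(0)
model = jpc.make_mlp(key, [784, 300, 300, 10], act_fn="relu")

optim = optax.adam(1e-3)
opt_state = optim.init((eqx.filter(model, eqx.is_array), None))

img_batch = jr.normal(jr.PRNGKey(1), (64, 784))                             # placeholder images
label_batch = jax.nn.one_hot(jr.randint(jr.PRNGKey(2), (64,), 0, 10), 10)   # placeholder labels

# One PC step: relax the activities, then update the parameters.
result = jpc.make_pc_step(
    model,
    optim,
    opt_state,
    output=label_batch,
    input=img_batch,
)
model, opt_state = result["model"], result["opt_state"]
train_loss = result["loss"]
```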
skip_model : Optional list of callable skip connection functions. key : jax.random.PRNGKey for random initialisation of activities. layer_sizes : Dimension of all layers (input, hidden and output). batch_size : Dimension of data batch for activity initialisation. sigma : Standard deviation for Gaussian to sample activities from for random initialisation. Defaults to 5e-2. record_activities : If True , returns activities at every inference iteration. record_energies : If True , returns layer-wise energies at every inference iteration. record_every : int determining the sampling frequency the integration steps. activity_norms : If True , computes l2 norm of the activities. param_norms : If True , computes l2 norm of the parameters. grad_norms : If True , computes l2 norm of parameter gradients. calculate_accuracy : If True , computes the training accuracy. Returns: Dict including model (and optional skip model) with updated parameters, optimiser, updated optimiser state, loss, energies, activities, and optionally other metrics (see other args above). Raises: ValueError for inconsistent inputs and invalid losses. jpc . make_hpc_step ( generator : PyTree [ typing . Callable ], amortiser : PyTree [ typing . Callable ], optims : Tuple [ optax . _src . base . GradientTransformationExtraArgs ], opt_states : Tuple [ Union [ jax . Array , numpy . ndarray , numpy . bool , numpy . number , Iterable [ ArrayTree ], Mapping [ Any , ArrayTree ]]], output : ArrayLike , input : Optional [ ArrayLike ] = None , ode_solver : AbstractSolver = Heun ( scan_kind = None ), max_t1 : int = 300 , dt : Array | int = None , stepsize_controller : AbstractStepSizeController = PIDController ( rtol = 0.001 , atol = 0.001 , pcoeff = 0 , icoeff = 1 , dcoeff = 0 , dtmin = None , dtmax = None , force_dtmin = True , step_ts = None , jump_ts = None , factormin = 0.2 , factormax = 10.0 , norm =< function rms_norm > , safety = 0.9 , error_order = None ), record_activities : bool = False , record_energies : bool = False ) -> Dict \u00a4 Updates parameters of a hybrid predictive coding network. Reference @article { tscshantz2023hybrid , title = {Hybrid predictive coding: Inferring, fast and slow} , author = {Tscshantz, Alexander and Millidge, Beren and Seth, Anil K and Buckley, Christopher L} , journal = {PLoS Computational Biology} , volume = {19} , number = {8} , pages = {e1011280} , year = {2023} , publisher = {Public Library of Science San Francisco, CA USA} } Note The input and output of the generator are the output and input of the amortiser, respectively. Main arguments: generator : List of callable layers for the generative model. amortiser : List of callable layers for model amortising the inference of the generator . optims : Optax optimisers (e.g. optax.sgd() ), one for each model. opt_states : State of Optax optimisers, one for each model. output : Observation of the generator, input to the amortiser. input : Optional prior of the generator, target for the amortiser. Other arguments: ode_solver : Diffrax ODE solver to be used. Default is Heun, a 2nd order explicit Runge--Kutta method.. max_t1 : Maximum end of integration region (300 by default). dt : Integration step size. Defaults to None since the default stepsize_controller will automatically determine it. stepsize_controller : diffrax controller for step size integration. Defaults to PIDController . Note that the relative and absolute tolerances of the controller will also determine the steady state to terminate the solver. 
record_activities : If True , returns activities at every inference iteration. record_energies : If True , returns layer-wise energies at every inference iteration. Returns: Dict including models with updated parameters, optimiser and state for each model, model activities, last inference step for the generator, MSE losses, and energies.","title":"Training"},{"location":"api/Training/#training","text":"","title":"Training"},{"location":"api/Training/#jpc.make_pc_step","text":"Updates network parameters with predictive coding. Main arguments: model : List of callable model (e.g. neural network) layers. optim : Optax optimiser, e.g. optax.sgd() . opt_state : State of Optax optimiser. output : Observation or target of the generative model. input : Optional prior of the generative model. Note key , layer_sizes and batch_size must be passed if input is None , since unsupervised training will be assumed and activities need to be initialised randomly. Other arguments: loss_id : Loss function for the output layer (mean squared error 'MSE' vs cross-entropy 'CE'). ode_solver : Diffrax ODE solver to be used. Default is Heun, a 2nd order explicit Runge--Kutta method. max_t1 : Maximum end of integration region (20 by default). dt : Integration step size. Defaults to None since the default stepsize_controller will automatically determine it. stepsize_controller : diffrax controller for step size integration. Defaults to PIDController . Note that the relative and absolute tolerances of the controller will also determine the steady state to terminate the solver. skip_model : Optional list of callable skip connection functions. key : jax.random.PRNGKey for random initialisation of activities. layer_sizes : Dimension of all layers (input, hidden and output). batch_size : Dimension of data batch for activity initialisation. sigma : Standard deviation for Gaussian to sample activities from for random initialisation. Defaults to 5e-2. record_activities : If True , returns activities at every inference iteration. record_energies : If True , returns layer-wise energies at every inference iteration. record_every : int determining the sampling frequency the integration steps. activity_norms : If True , computes l2 norm of the activities. param_norms : If True , computes l2 norm of the parameters. grad_norms : If True , computes l2 norm of parameter gradients. calculate_accuracy : If True , computes the training accuracy. Returns: Dict including model (and optional skip model) with updated parameters, optimiser, updated optimiser state, loss, energies, activities, and optionally other metrics (see other args above). Raises: ValueError for inconsistent inputs and invalid losses.","title":"make_pc_step()"},{"location":"api/Training/#jpc.make_hpc_step","text":"Updates parameters of a hybrid predictive coding network. Reference @article { tscshantz2023hybrid , title = {Hybrid predictive coding: Inferring, fast and slow} , author = {Tscshantz, Alexander and Millidge, Beren and Seth, Anil K and Buckley, Christopher L} , journal = {PLoS Computational Biology} , volume = {19} , number = {8} , pages = {e1011280} , year = {2023} , publisher = {Public Library of Science San Francisco, CA USA} } Note The input and output of the generator are the output and input of the amortiser, respectively. Main arguments: generator : List of callable layers for the generative model. amortiser : List of callable layers for model amortising the inference of the generator . optims : Optax optimisers (e.g. optax.sgd() ), one for each model. 
opt_states : State of Optax optimisers, one for each model. output : Observation of the generator, input to the amortiser. input : Optional prior of the generator, target for the amortiser. Other arguments: ode_solver : Diffrax ODE solver to be used. Default is Heun, a 2nd order explicit Runge--Kutta method.. max_t1 : Maximum end of integration region (300 by default). dt : Integration step size. Defaults to None since the default stepsize_controller will automatically determine it. stepsize_controller : diffrax controller for step size integration. Defaults to PIDController . Note that the relative and absolute tolerances of the controller will also determine the steady state to terminate the solver. record_activities : If True , returns activities at every inference iteration. record_energies : If True , returns layer-wise energies at every inference iteration. Returns: Dict including models with updated parameters, optimiser and state for each model, model activities, last inference step for the generator, MSE losses, and energies.","title":"make_hpc_step()"},{"location":"api/Updates/","text":"Updates \u00a4 jpc . update_activities ( params : Tuple [ PyTree [ Callable ], Optional [ PyTree [ Callable ]]], activities : PyTree [ ArrayLike ], optim : optax . _src . base . GradientTransformation | optax . _src . base . GradientTransformationExtraArgs , opt_state : Union [ jax . Array , numpy . ndarray , numpy . bool , numpy . number , Iterable [ ArrayTree ], Mapping [ Any , ArrayTree ]], output : ArrayLike , input : Optional [ ArrayLike ] = None ) -> Dict \u00a4 Updates activities of a predictive coding network with a given Optax optimiser. Main arguments: params : Tuple with callable model layers and optional skip connections. activities : List of activities for each layer free to vary. optim : Optax optimiser, e.g. optax.sgd() . opt_state : State of Optax optimiser. output : Observation or target of the generative model. input : Optional prior of the generative model. Returns: Dictionary with energy, updated activities, activity gradients, optimiser, and updated optimiser state. jpc . update_params ( params : Tuple [ PyTree [ Callable ], Optional [ PyTree [ Callable ]]], activities : PyTree [ ArrayLike ], optim : optax . _src . base . GradientTransformation | optax . _src . base . GradientTransformationExtraArgs , opt_state : Union [ jax . Array , numpy . ndarray , numpy . bool , numpy . number , Iterable [ ArrayTree ], Mapping [ Any , ArrayTree ]], output : ArrayLike , input : Optional [ ArrayLike ] = None ) -> Dict \u00a4 Updates parameters of a predictive coding network with a given Optax optimiser. Main arguments: params : Tuple with callable model layers and optional skip connections. activities : List of activities for each layer free to vary. optim : Optax optimiser, e.g. optax.sgd() . opt_state : State of Optax optimiser. output : Observation or target of the generative model. input : Optional prior of the generative model. Returns: Dictionary with model (and optional skip model) with updated parameters, parameter gradients, optimiser, and updated optimiser state.","title":"Updates"},{"location":"api/Updates/#updates","text":"","title":"Updates"},{"location":"api/Updates/#jpc.update_activities","text":"Updates activities of a predictive coding network with a given Optax optimiser. Main arguments: params : Tuple with callable model layers and optional skip connections. activities : List of activities for each layer free to vary. optim : Optax optimiser, e.g. optax.sgd() . 
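For finer-grained control than make_pc_step, update_activities and update_params can be composed into a manual training step. The sketch below is an assumption-laden illustration: the separate Optax optimisers, their state initialisation, and the use of a feedforward initialisation are choices made for this example rather than requirements of the API, and the returned dictionaries are not indexed because their exact keys are not listed above.

```python
import jax
import jax.random as jr
import equinox as eqx
import optax
import jpc

key = jr.PRNGKey(0)
model = jpc.make_mlp(key, [784, 300, 300, 10], act_fn="relu")

x = jr.normal(jr.PRNGKey(1), (64, 784))                           # placeholder inputs
y = jax.nn.one_hot(jr.randint(jr.PRNGKey(2), (64,), 0, 10), 10)   # placeholder targets

# Separate optimisers for the activities (inference) and the parameters (learning).
activities = jpc.init_activities_with_ffwd(model=model, input=x)
activity_optim = optax.sgd(1e-1)
activity_opt_state = activity_optim.init(activities)         # assumed initialisation
param_optim = optax.adam(1e-3)
param_opt_state = param_optim.init((eqx.filter(model, eqx.is_array), None))

# One inference (activity) update followed by one parameter update. Each call
# returns a dictionary holding the updated quantities, gradients, optimiser
# and optimiser state, as described above.
inference_result = jpc.update_activities(
    params=(model, None),
    activities=activities,
    optim=activity_optim,
    opt_state=activity_opt_state,
    output=y,
    input=x,
)
param_result = jpc.update_params(
    params=(model, None),
    activities=activities,
    optim=param_optim,
    opt_state=param_opt_state,
    output=y,
    input=x,
)
```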
opt_state : State of Optax optimiser. output : Observation or target of the generative model. input : Optional prior of the generative model. Returns: Dictionary with energy, updated activities, activity gradients, optimiser, and updated optimiser state.","title":"update_activities()"},{"location":"api/Updates/#jpc.update_params","text":"Updates parameters of a predictive coding network with a given Optax optimiser. Main arguments: params : Tuple with callable model layers and optional skip connections. activities : List of activities for each layer free to vary. optim : Optax optimiser, e.g. optax.sgd() . opt_state : State of Optax optimiser. output : Observation or target of the generative model. input : Optional prior of the generative model. Returns: Dictionary with model (and optional skip model) with updated parameters, parameter gradients, optimiser, and updated optimiser state.","title":"update_params()"},{"location":"api/make_mlp/","text":"make_mlp \u00a4 jpc . make_mlp ( key : PRNGKeyArray , layer_sizes : PyTree [ int ], act_fn : str , use_bias : bool = True ) -> PyTree [ typing . Callable ] \u00a4 Creates a multi-layer perceptron compatible with predictive coding updates. Main arguments: key : jax.random.PRNGKey for parameter initialisation. layer_sizes : Dimension of all layers (input, hidden and output). Options are linear , tanh and relu . act_fn : Activation function for all layers except the output. use_bias : True by default. Returns: List of callable fully connected layers.","title":"make_mlp"},{"location":"api/make_mlp/#make_mlp","text":"","title":"make_mlp"},{"location":"api/make_mlp/#jpc.make_mlp","text":"Creates a multi-layer perceptron compatible with predictive coding updates. Main arguments: key : jax.random.PRNGKey for parameter initialisation. layer_sizes : Dimension of all layers (input, hidden and output). Options are linear , tanh and relu . act_fn : Activation function for all layers except the output. use_bias : True by default. Returns: List of callable fully connected layers.","title":"make_mlp()"},{"location":"examples/discriminative_pc/","text":"(function() { function addWidgetsRenderer() { var requireJsScript = document.createElement('script'); requireJsScript.src = 'https://cdnjs.cloudflare.com/ajax/libs/require.js/2.3.4/require.min.js'; var mimeElement = document.querySelector('script[type=\"application/vnd.jupyter.widget-view+json\"]'); var jupyterWidgetsScript = document.createElement('script'); var widgetRendererSrc = 'https://unpkg.com/@jupyter-widgets/html-manager@*/dist/embed-amd.js'; var widgetState; // Fallback for older version: try { widgetState = mimeElement && JSON.parse(mimeElement.innerHTML); if (widgetState && (widgetState.version_major < 2 || !widgetState.version_major)) { widgetRendererSrc = 'jupyter-js-widgets@*/dist/embed.js'; } } catch(e) {} jupyterWidgetsScript.src = widgetRendererSrc; document.body.appendChild(requireJsScript); document.body.appendChild(jupyterWidgetsScript); } document.addEventListener('DOMContentLoaded', addWidgetsRenderer); }()); Discriminative PC on MNIST \u00a4 This notebook demonstrates how to train a neural network with predictive coding (PC) to discriminate or classify MNIST digits. %% capture ! pip install torch == 2.3.1 ! pip install torchvision == 0.18.1 import jpc import jax import equinox as eqx import equinox.nn as nn import optax import torch from torch.utils.data import DataLoader from torchvision import datasets , transforms import warnings warnings . 
simplefilter ( 'ignore' ) # ignore warnings Hyperparameters \u00a4 We define some global parameters, including network architecture, learning rate, batch size etc. SEED = 0 LAYER_SIZES = [ 784 , 300 , 300 , 10 ] ACT_FN = \"relu\" LEARNING_RATE = 1e-3 BATCH_SIZE = 64 TEST_EVERY = 100 N_TRAIN_ITERS = 300 Dataset \u00a4 Some utils to fetch MNIST. #@title data utils def get_mnist_loaders ( batch_size ): train_data = MNIST ( train = True , normalise = True ) test_data = MNIST ( train = False , normalise = True ) train_loader = DataLoader ( dataset = train_data , batch_size = batch_size , shuffle = True , drop_last = True ) test_loader = DataLoader ( dataset = test_data , batch_size = batch_size , shuffle = True , drop_last = True ) return train_loader , test_loader class MNIST ( datasets . MNIST ): def __init__ ( self , train , normalise = True , save_dir = \"data\" ): if normalise : transform = transforms . Compose ( [ transforms . ToTensor (), transforms . Normalize ( mean = ( 0.1307 ), std = ( 0.3081 ) ) ] ) else : transform = transforms . Compose ([ transforms . ToTensor ()]) super () . __init__ ( save_dir , download = True , train = train , transform = transform ) def __getitem__ ( self , index ): img , label = super () . __getitem__ ( index ) img = torch . flatten ( img ) label = one_hot ( label ) return img , label def one_hot ( labels , n_classes = 10 ): arr = torch . eye ( n_classes ) return arr [ labels ] Network \u00a4 For jpc to work, we need to provide a network with callable layers. This is easy to do with the PyTorch-like nn.Sequential() in Equinox . For example, we can define a ReLU MLP with two hidden layers as follows key = jax . random . PRNGKey ( SEED ) _ , * subkeys = jax . random . split ( key , 4 ) network = [ nn . Sequential ( [ nn . Linear ( 784 , 300 , key = subkeys [ 0 ]), nn . Lambda ( jax . nn . relu ) ], ), nn . Sequential ( [ nn . Linear ( 300 , 300 , key = subkeys [ 1 ]), nn . Lambda ( jax . nn . relu ) ], ), nn . Linear ( 300 , 10 , key = subkeys [ 2 ]), ] print ( network ) [Sequential( layers=( Linear( weight=f32[300,784], bias=f32[300], in_features=784, out_features=300, use_bias=True ), Lambda(fn=) ) ), Sequential( layers=( Linear( weight=f32[300,300], bias=f32[300], in_features=300, out_features=300, use_bias=True ), Lambda(fn=) ) ), Linear( weight=f32[10,300], bias=f32[10], in_features=300, out_features=10, use_bias=True )] You can also use the utility jpc.get_fc_network to define an MLP or fully connected network with some activation functions. network = jpc . make_mlp ( key , LAYER_SIZES , act_fn = \"relu\" ) print ( network ) [Sequential( layers=( Linear( weight=f32[300,784], bias=f32[300], in_features=784, out_features=300, use_bias=True ), Lambda(fn=) ) ), Sequential( layers=( Linear( weight=f32[300,300], bias=f32[300], in_features=300, out_features=300, use_bias=True ), Lambda(fn=) ) ), Sequential( layers=( Linear( weight=f32[10,300], bias=f32[10], in_features=300, out_features=10, use_bias=True ), Lambda(fn=Identity()) ) )] Train and test \u00a4 A PC network can be trained in a single line of code with jpc.make_pc_step() . See the documentation for more. Similarly, we can use jpc.test_discriminative_pc() to compute the network accuracy. Note that these functions are already \"jitted\" for performance. Below we simply wrap each of these functions in our training and test loops, respectively. 
def evaluate ( model , test_loader ): avg_test_loss , avg_test_acc = 0 , 0 for batch_id , ( img_batch , label_batch ) in enumerate ( test_loader ): img_batch , label_batch = img_batch . numpy (), label_batch . numpy () test_loss , test_acc = jpc . test_discriminative_pc ( model = model , input = img_batch , output = label_batch ) avg_test_loss += test_loss avg_test_acc += test_acc return avg_test_loss / len ( test_loader ), avg_test_acc / len ( test_loader ) def train ( model , lr , batch_size , test_every , n_train_iters ): optim = optax . adam ( lr ) opt_state = optim . init ( ( eqx . filter ( model , eqx . is_array ), None ) ) train_loader , test_loader = get_mnist_loaders ( batch_size ) for iter , ( img_batch , label_batch ) in enumerate ( train_loader ): img_batch , label_batch = img_batch . numpy (), label_batch . numpy () result = jpc . make_pc_step ( model , optim , opt_state , output = label_batch , input = img_batch ) model , optim , opt_state = result [ \"model\" ], result [ \"optim\" ], result [ \"opt_state\" ] train_loss = result [ \"loss\" ] if (( iter + 1 ) % test_every ) == 0 : avg_test_loss , avg_test_acc = evaluate ( model , test_loader ) print ( f \"Train iter { iter + 1 } , train loss= { train_loss : 4f } , \" f \"avg test accuracy= { avg_test_acc : 4f } \" ) if ( iter + 1 ) >= n_train_iters : break Run \u00a4 import warnings with warnings . catch_warnings (): warnings . simplefilter ( 'ignore' ) train ( model = network , lr = LEARNING_RATE , batch_size = BATCH_SIZE , test_every = TEST_EVERY , n_train_iters = N_TRAIN_ITERS ) Train iter 100, train loss=0.018149, avg test accuracy=93.790062 Train iter 200, train loss=0.012088, avg test accuracy=95.142227 Train iter 300, train loss=0.016424, avg test accuracy=95.723160","title":"Discriminative PC"},{"location":"examples/discriminative_pc/#discriminative-pc-on-mnist","text":"This notebook demonstrates how to train a neural network with predictive coding (PC) to discriminate or classify MNIST digits. %% capture ! pip install torch == 2.3.1 ! pip install torchvision == 0.18.1 import jpc import jax import equinox as eqx import equinox.nn as nn import optax import torch from torch.utils.data import DataLoader from torchvision import datasets , transforms import warnings warnings . simplefilter ( 'ignore' ) # ignore warnings","title":"Discriminative PC on MNIST"},{"location":"examples/discriminative_pc/#hyperparameters","text":"We define some global parameters, including network architecture, learning rate, batch size etc. SEED = 0 LAYER_SIZES = [ 784 , 300 , 300 , 10 ] ACT_FN = \"relu\" LEARNING_RATE = 1e-3 BATCH_SIZE = 64 TEST_EVERY = 100 N_TRAIN_ITERS = 300","title":"Hyperparameters"},{"location":"examples/discriminative_pc/#dataset","text":"Some utils to fetch MNIST. #@title data utils def get_mnist_loaders ( batch_size ): train_data = MNIST ( train = True , normalise = True ) test_data = MNIST ( train = False , normalise = True ) train_loader = DataLoader ( dataset = train_data , batch_size = batch_size , shuffle = True , drop_last = True ) test_loader = DataLoader ( dataset = test_data , batch_size = batch_size , shuffle = True , drop_last = True ) return train_loader , test_loader class MNIST ( datasets . MNIST ): def __init__ ( self , train , normalise = True , save_dir = \"data\" ): if normalise : transform = transforms . Compose ( [ transforms . ToTensor (), transforms . Normalize ( mean = ( 0.1307 ), std = ( 0.3081 ) ) ] ) else : transform = transforms . Compose ([ transforms . ToTensor ()]) super () . 
__init__ ( save_dir , download = True , train = train , transform = transform ) def __getitem__ ( self , index ): img , label = super () . __getitem__ ( index ) img = torch . flatten ( img ) label = one_hot ( label ) return img , label def one_hot ( labels , n_classes = 10 ): arr = torch . eye ( n_classes ) return arr [ labels ]","title":"Dataset"},{"location":"examples/discriminative_pc/#network","text":"For jpc to work, we need to provide a network with callable layers. This is easy to do with the PyTorch-like nn.Sequential() in Equinox . For example, we can define a ReLU MLP with two hidden layers as follows key = jax . random . PRNGKey ( SEED ) _ , * subkeys = jax . random . split ( key , 4 ) network = [ nn . Sequential ( [ nn . Linear ( 784 , 300 , key = subkeys [ 0 ]), nn . Lambda ( jax . nn . relu ) ], ), nn . Sequential ( [ nn . Linear ( 300 , 300 , key = subkeys [ 1 ]), nn . Lambda ( jax . nn . relu ) ], ), nn . Linear ( 300 , 10 , key = subkeys [ 2 ]), ] print ( network ) [Sequential( layers=( Linear( weight=f32[300,784], bias=f32[300], in_features=784, out_features=300, use_bias=True ), Lambda(fn=) ) ), Sequential( layers=( Linear( weight=f32[300,300], bias=f32[300], in_features=300, out_features=300, use_bias=True ), Lambda(fn=) ) ), Linear( weight=f32[10,300], bias=f32[10], in_features=300, out_features=10, use_bias=True )] You can also use the utility jpc.get_fc_network to define an MLP or fully connected network with some activation functions. network = jpc . make_mlp ( key , LAYER_SIZES , act_fn = \"relu\" ) print ( network ) [Sequential( layers=( Linear( weight=f32[300,784], bias=f32[300], in_features=784, out_features=300, use_bias=True ), Lambda(fn=) ) ), Sequential( layers=( Linear( weight=f32[300,300], bias=f32[300], in_features=300, out_features=300, use_bias=True ), Lambda(fn=) ) ), Sequential( layers=( Linear( weight=f32[10,300], bias=f32[10], in_features=300, out_features=10, use_bias=True ), Lambda(fn=Identity()) ) )]","title":"Network"},{"location":"examples/discriminative_pc/#train-and-test","text":"A PC network can be trained in a single line of code with jpc.make_pc_step() . See the documentation for more. Similarly, we can use jpc.test_discriminative_pc() to compute the network accuracy. Note that these functions are already \"jitted\" for performance. Below we simply wrap each of these functions in our training and test loops, respectively. def evaluate ( model , test_loader ): avg_test_loss , avg_test_acc = 0 , 0 for batch_id , ( img_batch , label_batch ) in enumerate ( test_loader ): img_batch , label_batch = img_batch . numpy (), label_batch . numpy () test_loss , test_acc = jpc . test_discriminative_pc ( model = model , input = img_batch , output = label_batch ) avg_test_loss += test_loss avg_test_acc += test_acc return avg_test_loss / len ( test_loader ), avg_test_acc / len ( test_loader ) def train ( model , lr , batch_size , test_every , n_train_iters ): optim = optax . adam ( lr ) opt_state = optim . init ( ( eqx . filter ( model , eqx . is_array ), None ) ) train_loader , test_loader = get_mnist_loaders ( batch_size ) for iter , ( img_batch , label_batch ) in enumerate ( train_loader ): img_batch , label_batch = img_batch . numpy (), label_batch . numpy () result = jpc . 
make_pc_step ( model , optim , opt_state , output = label_batch , input = img_batch ) model , optim , opt_state = result [ \"model\" ], result [ \"optim\" ], result [ \"opt_state\" ] train_loss = result [ \"loss\" ] if (( iter + 1 ) % test_every ) == 0 : avg_test_loss , avg_test_acc = evaluate ( model , test_loader ) print ( f \"Train iter { iter + 1 } , train loss= { train_loss : 4f } , \" f \"avg test accuracy= { avg_test_acc : 4f } \" ) if ( iter + 1 ) >= n_train_iters : break","title":"Train and test"},{"location":"examples/discriminative_pc/#run","text":"import warnings with warnings . catch_warnings (): warnings . simplefilter ( 'ignore' ) train ( model = network , lr = LEARNING_RATE , batch_size = BATCH_SIZE , test_every = TEST_EVERY , n_train_iters = N_TRAIN_ITERS ) Train iter 100, train loss=0.018149, avg test accuracy=93.790062 Train iter 200, train loss=0.012088, avg test accuracy=95.142227 Train iter 300, train loss=0.016424, avg test accuracy=95.723160","title":"Run"},{"location":"examples/hybrid_pc/","text":"(function() { function addWidgetsRenderer() { var requireJsScript = document.createElement('script'); requireJsScript.src = 'https://cdnjs.cloudflare.com/ajax/libs/require.js/2.3.4/require.min.js'; var mimeElement = document.querySelector('script[type=\"application/vnd.jupyter.widget-view+json\"]'); var jupyterWidgetsScript = document.createElement('script'); var widgetRendererSrc = 'https://unpkg.com/@jupyter-widgets/html-manager@*/dist/embed-amd.js'; var widgetState; // Fallback for older version: try { widgetState = mimeElement && JSON.parse(mimeElement.innerHTML); if (widgetState && (widgetState.version_major < 2 || !widgetState.version_major)) { widgetRendererSrc = 'jupyter-js-widgets@*/dist/embed.js'; } } catch(e) {} jupyterWidgetsScript.src = widgetRendererSrc; document.body.appendChild(requireJsScript); document.body.appendChild(jupyterWidgetsScript); } document.addEventListener('DOMContentLoaded', addWidgetsRenderer); }()); Hybrid PC on MNIST \u00a4 This notebook demonstrates how to train a hybrid predictive coding network that can both generate and classify MNIST digits. %% capture ! pip install torch == 2.3.1 ! pip install torchvision == 0.18.1 ! pip install matplotlib == 3.0.0 import jpc import jax import jax.numpy as jnp import equinox as eqx import equinox.nn as nn import optax import torch from torch.utils.data import DataLoader from torchvision import datasets , transforms import matplotlib.pyplot as plt import warnings warnings . simplefilter ( 'ignore' ) # ignore warnings Hyperparameters \u00a4 We define some global parameters, including network architecture, learning rate, batch size etc. SEED = 0 LAYER_SIZES = [ 10 , 300 , 300 , 784 ] ACT_FN = \"relu\" LEARNING_RATE = 1e-3 BATCH_SIZE = 64 MAX_T1 = 50 TEST_EVERY = 100 N_TRAIN_ITERS = 300 Dataset \u00a4 Some utils to fetch MNIST. #@title data utils def get_mnist_loaders ( batch_size ): train_data = MNIST ( train = True , normalise = True ) test_data = MNIST ( train = False , normalise = True ) train_loader = DataLoader ( dataset = train_data , batch_size = batch_size , shuffle = True , drop_last = True ) test_loader = DataLoader ( dataset = test_data , batch_size = batch_size , shuffle = True , drop_last = True ) return train_loader , test_loader class MNIST ( datasets . MNIST ): def __init__ ( self , train , normalise = True , save_dir = \"data\" ): if normalise : transform = transforms . Compose ( [ transforms . ToTensor (), transforms . 
Normalize ( mean = ( 0.1307 ), std = ( 0.3081 ) ) ] ) else : transform = transforms . Compose ([ transforms . ToTensor ()]) super () . __init__ ( save_dir , download = True , train = train , transform = transform ) def __getitem__ ( self , index ): img , label = super () . __getitem__ ( index ) img = torch . flatten ( img ) label = one_hot ( label ) return img , label def one_hot ( labels , n_classes = 10 ): arr = torch . eye ( n_classes ) return arr [ labels ] def plot_mnist_imgs ( imgs , labels , n_imgs = 10 ): plt . figure ( figsize = ( 20 , 2 )) for i in range ( n_imgs ): plt . subplot ( 1 , n_imgs , i + 1 ) plt . xticks ([]) plt . yticks ([]) plt . grid ( False ) plt . imshow ( imgs [ i ] . reshape ( 28 , 28 ), cmap = plt . cm . binary_r ) plt . xlabel ( jnp . argmax ( labels , axis = 1 )[ i ]) plt . show () Network \u00a4 For jpc to work, we need to provide a network with callable layers. This is easy to do with the PyTorch-like nn.Sequential() in Equinox . For example, we can define a ReLU MLP with two hidden layers as follows key = jax . random . PRNGKey ( SEED ) _ , * subkeys = jax . random . split ( key , 4 ) network = [ nn . Sequential ( [ nn . Linear ( 784 , 300 , key = subkeys [ 0 ]), nn . Lambda ( jax . nn . relu ) ], ), nn . Sequential ( [ nn . Linear ( 300 , 300 , key = subkeys [ 1 ]), nn . Lambda ( jax . nn . relu ) ], ), nn . Linear ( 300 , 10 , key = subkeys [ 2 ]), ] print ( network ) [Sequential( layers=( Linear( weight=f32[300,784], bias=f32[300], in_features=784, out_features=300, use_bias=True ), Lambda(fn=) ) ), Sequential( layers=( Linear( weight=f32[300,300], bias=f32[300], in_features=300, out_features=300, use_bias=True ), Lambda(fn=) ) ), Linear( weight=f32[10,300], bias=f32[10], in_features=300, out_features=10, use_bias=True )] You can also use the utility jpc.get_fc_network to define an MLP or fully connected network with some activation functions. network = jpc . make_mlp ( key , LAYER_SIZES , act_fn = \"relu\" ) print ( network ) [Sequential( layers=( Linear( weight=f32[300,10], bias=f32[300], in_features=10, out_features=300, use_bias=True ), Lambda(fn=) ) ), Sequential( layers=( Linear( weight=f32[300,300], bias=f32[300], in_features=300, out_features=300, use_bias=True ), Lambda(fn=) ) ), Sequential( layers=( Linear( weight=f32[784,300], bias=f32[784], in_features=300, out_features=784, use_bias=True ), Lambda(fn=Identity()) ) )] Train and test \u00a4 A hybrid PC network can be trained in a single line of code with jpc.make_hpc_step() . See the documentation for more. Similarly, we can use jpc.test_hpc() to compute different test metrics. Note that these functions are already \"jitted\" for performance. Below we simply wrap each of these functions in our training and test loops, respectively. def evaluate ( key , layer_sizes , batch_size , generator , amortiser , test_loader ): amort_accs , hpc_accs , gen_accs = 0 , 0 , 0 for batch_id , ( img_batch , label_batch ) in enumerate ( test_loader ): img_batch , label_batch = img_batch . numpy (), label_batch . numpy () amort_acc , hpc_acc , gen_acc , img_preds = jpc . 
test_hpc ( key = key , layer_sizes = layer_sizes , batch_size = batch_size , generator = generator , amortiser = amortiser , input = label_batch , output = img_batch ) amort_accs += amort_acc hpc_accs += hpc_acc gen_accs += gen_acc return ( amort_accs / len ( test_loader ), hpc_accs / len ( test_loader ), gen_accs / len ( test_loader ), label_batch , img_preds ) def train ( seed , layer_sizes , act_fn , batch_size , lr , max_t1 , test_every , n_train_iters ): key = jax . random . PRNGKey ( seed ) key , * subkey = jax . random . split ( key , 3 ) generator = jpc . make_mlp ( subkey [ 0 ], layer_sizes , act_fn ) amortiser = jpc . make_mlp ( subkey [ 1 ], layer_sizes [:: - 1 ], act_fn ) gen_optim = optax . adam ( lr ) amort_optim = optax . adam ( lr ) optims = [ gen_optim , amort_optim ] gen_opt_state = gen_optim . init ( ( eqx . filter ( generator , eqx . is_array ), None ) ) amort_opt_state = amort_optim . init ( eqx . filter ( amortiser , eqx . is_array )) opt_states = [ gen_opt_state , amort_opt_state ] train_loader , test_loader = get_mnist_loaders ( batch_size ) for iter , ( img_batch , label_batch ) in enumerate ( train_loader ): img_batch , label_batch = img_batch . numpy (), label_batch . numpy () result = jpc . make_hpc_step ( generator = generator , amortiser = amortiser , optims = optims , opt_states = opt_states , input = label_batch , output = img_batch , max_t1 = max_t1 ) generator , amortiser = result [ \"generator\" ], result [ \"amortiser\" ] optims , opt_states = result [ \"optims\" ], result [ \"opt_states\" ] gen_loss , amort_loss = result [ \"losses\" ] if (( iter + 1 ) % test_every ) == 0 : amort_acc , hpc_acc , gen_acc , label_batch , img_preds = evaluate ( key , layer_sizes , batch_size , generator , amortiser , test_loader ) print ( f \"Iter { iter + 1 } , gen loss= { gen_loss : 4f } , \" f \"amort loss= { amort_loss : 4f } , \" f \"avg amort test accuracy= { amort_acc : 4f } , \" f \"avg hpc test accuracy= { hpc_acc : 4f } , \" f \"avg gen test accuracy= { gen_acc : 4f } , \" ) if ( iter + 1 ) >= n_train_iters : break plot_mnist_imgs ( img_preds , label_batch ) return amortiser , generator Run \u00a4 network = train ( seed = SEED , layer_sizes = LAYER_SIZES , act_fn = ACT_FN , batch_size = BATCH_SIZE , lr = LEARNING_RATE , max_t1 = MAX_T1 , test_every = TEST_EVERY , n_train_iters = N_TRAIN_ITERS ) Iter 100, gen loss=0.558071, amort loss=0.056306, avg amort test accuracy=76.782852, avg hpc test accuracy=79.727562, avg gen test accuracy=78.675880, Iter 200, gen loss=0.622492, amort loss=0.039034, avg amort test accuracy=83.503609, avg hpc test accuracy=81.740784, avg gen test accuracy=81.109779, Iter 300, gen loss=0.548741, amort loss=0.039427, avg amort test accuracy=85.987579, avg hpc test accuracy=82.311699, avg gen test accuracy=81.209938,","title":"Hybrid PC"},{"location":"examples/hybrid_pc/#hybrid-pc-on-mnist","text":"This notebook demonstrates how to train a hybrid predictive coding network that can both generate and classify MNIST digits. %% capture ! pip install torch == 2.3.1 ! pip install torchvision == 0.18.1 ! pip install matplotlib == 3.0.0 import jpc import jax import jax.numpy as jnp import equinox as eqx import equinox.nn as nn import optax import torch from torch.utils.data import DataLoader from torchvision import datasets , transforms import matplotlib.pyplot as plt import warnings warnings . 
simplefilter ( 'ignore' ) # ignore warnings","title":"Hybrid PC on MNIST"},{"location":"examples/hybrid_pc/#hyperparameters","text":"We define some global parameters, including network architecture, learning rate, batch size etc. SEED = 0 LAYER_SIZES = [ 10 , 300 , 300 , 784 ] ACT_FN = \"relu\" LEARNING_RATE = 1e-3 BATCH_SIZE = 64 MAX_T1 = 50 TEST_EVERY = 100 N_TRAIN_ITERS = 300","title":"Hyperparameters"},{"location":"examples/hybrid_pc/#dataset","text":"Some utils to fetch MNIST. #@title data utils def get_mnist_loaders ( batch_size ): train_data = MNIST ( train = True , normalise = True ) test_data = MNIST ( train = False , normalise = True ) train_loader = DataLoader ( dataset = train_data , batch_size = batch_size , shuffle = True , drop_last = True ) test_loader = DataLoader ( dataset = test_data , batch_size = batch_size , shuffle = True , drop_last = True ) return train_loader , test_loader class MNIST ( datasets . MNIST ): def __init__ ( self , train , normalise = True , save_dir = \"data\" ): if normalise : transform = transforms . Compose ( [ transforms . ToTensor (), transforms . Normalize ( mean = ( 0.1307 ), std = ( 0.3081 ) ) ] ) else : transform = transforms . Compose ([ transforms . ToTensor ()]) super () . __init__ ( save_dir , download = True , train = train , transform = transform ) def __getitem__ ( self , index ): img , label = super () . __getitem__ ( index ) img = torch . flatten ( img ) label = one_hot ( label ) return img , label def one_hot ( labels , n_classes = 10 ): arr = torch . eye ( n_classes ) return arr [ labels ] def plot_mnist_imgs ( imgs , labels , n_imgs = 10 ): plt . figure ( figsize = ( 20 , 2 )) for i in range ( n_imgs ): plt . subplot ( 1 , n_imgs , i + 1 ) plt . xticks ([]) plt . yticks ([]) plt . grid ( False ) plt . imshow ( imgs [ i ] . reshape ( 28 , 28 ), cmap = plt . cm . binary_r ) plt . xlabel ( jnp . argmax ( labels , axis = 1 )[ i ]) plt . show ()","title":"Dataset"},{"location":"examples/hybrid_pc/#network","text":"For jpc to work, we need to provide a network with callable layers. This is easy to do with the PyTorch-like nn.Sequential() in Equinox . For example, we can define a ReLU MLP with two hidden layers as follows key = jax . random . PRNGKey ( SEED ) _ , * subkeys = jax . random . split ( key , 4 ) network = [ nn . Sequential ( [ nn . Linear ( 784 , 300 , key = subkeys [ 0 ]), nn . Lambda ( jax . nn . relu ) ], ), nn . Sequential ( [ nn . Linear ( 300 , 300 , key = subkeys [ 1 ]), nn . Lambda ( jax . nn . relu ) ], ), nn . Linear ( 300 , 10 , key = subkeys [ 2 ]), ] print ( network ) [Sequential( layers=( Linear( weight=f32[300,784], bias=f32[300], in_features=784, out_features=300, use_bias=True ), Lambda(fn=) ) ), Sequential( layers=( Linear( weight=f32[300,300], bias=f32[300], in_features=300, out_features=300, use_bias=True ), Lambda(fn=) ) ), Linear( weight=f32[10,300], bias=f32[10], in_features=300, out_features=10, use_bias=True )] You can also use the utility jpc.get_fc_network to define an MLP or fully connected network with some activation functions. network = jpc . 
make_mlp ( key , LAYER_SIZES , act_fn = \"relu\" ) print ( network ) [Sequential( layers=( Linear( weight=f32[300,10], bias=f32[300], in_features=10, out_features=300, use_bias=True ), Lambda(fn=) ) ), Sequential( layers=( Linear( weight=f32[300,300], bias=f32[300], in_features=300, out_features=300, use_bias=True ), Lambda(fn=) ) ), Sequential( layers=( Linear( weight=f32[784,300], bias=f32[784], in_features=300, out_features=784, use_bias=True ), Lambda(fn=Identity()) ) )]","title":"Network"},{"location":"examples/hybrid_pc/#train-and-test","text":"A hybrid PC network can be trained in a single line of code with jpc.make_hpc_step() . See the documentation for more. Similarly, we can use jpc.test_hpc() to compute different test metrics. Note that these functions are already \"jitted\" for performance. Below we simply wrap each of these functions in our training and test loops, respectively. def evaluate ( key , layer_sizes , batch_size , generator , amortiser , test_loader ): amort_accs , hpc_accs , gen_accs = 0 , 0 , 0 for batch_id , ( img_batch , label_batch ) in enumerate ( test_loader ): img_batch , label_batch = img_batch . numpy (), label_batch . numpy () amort_acc , hpc_acc , gen_acc , img_preds = jpc . test_hpc ( key = key , layer_sizes = layer_sizes , batch_size = batch_size , generator = generator , amortiser = amortiser , input = label_batch , output = img_batch ) amort_accs += amort_acc hpc_accs += hpc_acc gen_accs += gen_acc return ( amort_accs / len ( test_loader ), hpc_accs / len ( test_loader ), gen_accs / len ( test_loader ), label_batch , img_preds ) def train ( seed , layer_sizes , act_fn , batch_size , lr , max_t1 , test_every , n_train_iters ): key = jax . random . PRNGKey ( seed ) key , * subkey = jax . random . split ( key , 3 ) generator = jpc . make_mlp ( subkey [ 0 ], layer_sizes , act_fn ) amortiser = jpc . make_mlp ( subkey [ 1 ], layer_sizes [:: - 1 ], act_fn ) gen_optim = optax . adam ( lr ) amort_optim = optax . adam ( lr ) optims = [ gen_optim , amort_optim ] gen_opt_state = gen_optim . init ( ( eqx . filter ( generator , eqx . is_array ), None ) ) amort_opt_state = amort_optim . init ( eqx . filter ( amortiser , eqx . is_array )) opt_states = [ gen_opt_state , amort_opt_state ] train_loader , test_loader = get_mnist_loaders ( batch_size ) for iter , ( img_batch , label_batch ) in enumerate ( train_loader ): img_batch , label_batch = img_batch . numpy (), label_batch . numpy () result = jpc . 
make_hpc_step ( generator = generator , amortiser = amortiser , optims = optims , opt_states = opt_states , input = label_batch , output = img_batch , max_t1 = max_t1 ) generator , amortiser = result [ \"generator\" ], result [ \"amortiser\" ] optims , opt_states = result [ \"optims\" ], result [ \"opt_states\" ] gen_loss , amort_loss = result [ \"losses\" ] if (( iter + 1 ) % test_every ) == 0 : amort_acc , hpc_acc , gen_acc , label_batch , img_preds = evaluate ( key , layer_sizes , batch_size , generator , amortiser , test_loader ) print ( f \"Iter { iter + 1 } , gen loss= { gen_loss : 4f } , \" f \"amort loss= { amort_loss : 4f } , \" f \"avg amort test accuracy= { amort_acc : 4f } , \" f \"avg hpc test accuracy= { hpc_acc : 4f } , \" f \"avg gen test accuracy= { gen_acc : 4f } , \" ) if ( iter + 1 ) >= n_train_iters : break plot_mnist_imgs ( img_preds , label_batch ) return amortiser , generator","title":"Train and test"},{"location":"examples/hybrid_pc/#run","text":"network = train ( seed = SEED , layer_sizes = LAYER_SIZES , act_fn = ACT_FN , batch_size = BATCH_SIZE , lr = LEARNING_RATE , max_t1 = MAX_T1 , test_every = TEST_EVERY , n_train_iters = N_TRAIN_ITERS ) Iter 100, gen loss=0.558071, amort loss=0.056306, avg amort test accuracy=76.782852, avg hpc test accuracy=79.727562, avg gen test accuracy=78.675880, Iter 200, gen loss=0.622492, amort loss=0.039034, avg amort test accuracy=83.503609, avg hpc test accuracy=81.740784, avg gen test accuracy=81.109779, Iter 300, gen loss=0.548741, amort loss=0.039427, avg amort test accuracy=85.987579, avg hpc test accuracy=82.311699, avg gen test accuracy=81.209938,","title":"Run"},{"location":"examples/implementing_pc_from_scratch/","text":"(function() { function addWidgetsRenderer() { var requireJsScript = document.createElement('script'); requireJsScript.src = 'https://cdnjs.cloudflare.com/ajax/libs/require.js/2.3.4/require.min.js'; var mimeElement = document.querySelector('script[type=\"application/vnd.jupyter.widget-view+json\"]'); var jupyterWidgetsScript = document.createElement('script'); var widgetRendererSrc = 'https://unpkg.com/@jupyter-widgets/html-manager@*/dist/embed-amd.js'; var widgetState; // Fallback for older version: try { widgetState = mimeElement && JSON.parse(mimeElement.innerHTML); if (widgetState && (widgetState.version_major < 2 || !widgetState.version_major)) { widgetRendererSrc = 'jupyter-js-widgets@*/dist/embed.js'; } } catch(e) {} jupyterWidgetsScript.src = widgetRendererSrc; document.body.appendChild(requireJsScript); document.body.appendChild(jupyterWidgetsScript); } document.addEventListener('DOMContentLoaded', addWidgetsRenderer); }()); \u2699\ufe0f Implementing Predictive Coding from Scratch (in JAX) \u00a4 In this notebook, we walk through how to implement Predictive Coding (PC) from scratch using JAX. In particular, we are going to train a simple feedforward network to classify MNIST digits with PC. It might be a good idea to revisit the introductory lecture on PC from this week. If you're not familiar with JAX, have a look at their docs , but we will explain all the necessary concepts below. JAX is basically numpy for GPUs and other hardware accelerators. We will also use: Equinox , which allows you to define neural nets with PyTorch-like syntax; and Optax , which provides a range of common machine learning optimisers such as gradient descent and Adam. Installations & imports \u00a4 #@title installations %% capture ! pip install torch == 2.3.1 ! 
pip install torchvision == 0.18.1 #@title imports import jax.random as jr import jax.numpy as jnp from jax import vmap , grad from jax.tree_util import tree_map import equinox as eqx import equinox.nn as nn from equinox import filter_grad import optax import torch from torch.utils.data import DataLoader from torchvision import datasets , transforms import warnings warnings . simplefilter ( \"ignore\" ) Hyperparameters \u00a4 We define some global parameters related to the data, network, optimisers, etc. SEED = 827 INPUT_DIM , OUTPUT_DIM = 28 * 28 , 10 NETWORK_WIDTH = 300 ACTIVITY_LR = 1e-1 INFERENCE_STEPS = 20 PARAM_LR = 1e-3 BATCH_SIZE = 64 TEST_EVERY = 50 N_TRAIN_ITERS = 500 Dataset \u00a4 Some utils to fetch MNIST. #@title data utils def get_mnist_loaders ( batch_size ): train_data = MNIST ( train = True , normalise = True ) test_data = MNIST ( train = False , normalise = True ) train_loader = DataLoader ( dataset = train_data , batch_size = batch_size , shuffle = True , drop_last = True ) test_loader = DataLoader ( dataset = test_data , batch_size = batch_size , shuffle = True , drop_last = True ) return train_loader , test_loader class MNIST ( datasets . MNIST ): def __init__ ( self , train , normalise = True , save_dir = \"data\" ): if normalise : transform = transforms . Compose ( [ transforms . ToTensor (), transforms . Normalize ( mean = ( 0.1307 ), std = ( 0.3081 ) ) ] ) else : transform = transforms . Compose ([ transforms . ToTensor ()]) super () . __init__ ( save_dir , download = True , train = train , transform = transform ) def __getitem__ ( self , index ): img , label = super () . __getitem__ ( index ) img = torch . flatten ( img ) label = one_hot ( label ) return img , label def one_hot ( labels , n_classes = 10 ): arr = torch . eye ( n_classes ) return arr [ labels ] PC energy \u00a4 First, recall that PC can be derived as a variational inference algorithm under certain assumptions. In particular, if we assume * a dirac delta (point mass) posterior and * a hierarchical Gaussian generative model, we get the standard PC energy \\begin{equation} \\mathcal{F} = \\frac{1}{2N}\\sum_{i=1}^{N} \\sum_{\\ell=1}^L ||\\mathbf{z} {\\ell, i} - f \\ell(W_\\ell \\mathbf{z} {\\ell-1, i} + \\mathbf{b} \\ell)||^2_2 \\end{equation} which is just a sum of squared prediction errors at each network layer. Here we are being a little bit more precise than in the lecture, including multiple ( \\(N\\) ) data points and biases \\(\\mathbf{b}_\\ell\\) . \ud83e\udd14 Food for thought : Think about how the form of this energy could change depending other assumptions we make about the generative model. See, for example, Learning on Arbitrary Graph Topologies via Predictive Coding by Salvatori et al. (2022). Let's start by implementing this energy below. The function simply takes the model (with all the parameters), some initialised activities, and some input and output. Given these, it simply sums the prediction error at each layer. NOTE : below we use vmap , one of the core JAX transforms that allows you to vectorise operations, in this case for multiple data points or over a batch. See https://jax.readthedocs.io/en/latest/automatic-vectorization.html for more details. def pc_energy_fn ( model , activities , input , output ): batch_size = output . shape [ 0 ] n_activity_layers = len ( activities ) - 1 n_layers = len ( model ) - 1 eL = output - vmap ( model [ - 1 ])( activities [ - 2 ]) energies = [ jnp . 
sum ( eL ** 2 )] for act_l , net_l in zip ( range ( 1 , n_activity_layers ), range ( 1 , n_layers ) ): err = activities [ act_l ] - vmap ( model [ net_l ])( activities [ act_l - 1 ]) energies . append ( jnp . sum ( err ** 2 )) e1 = activities [ 0 ] - vmap ( model [ 0 ])( input ) energies . append ( jnp . sum ( e1 ** 2 )) return jnp . sum ( jnp . array ( energies )) / batch_size Now let's test it. To do so, we first need a model. Below we use Equinox to create a simple feedforward network with 2 hidden layers and tanh activations. Note that we split the model into different parts with nn.Sequential to define the activities which PC will optimise over (during inference, more on this below). \u2753 Question : Think about other ways in which we could split the layers, for example by separating the non-linearities. Can you think of potential issues with this? # jax uses explicit random number generators (see https://jax.readthedocs.io/en/latest/random-numbers.html) key = jr . PRNGKey ( SEED ) subkeys = jr . split ( key , 3 ) model = [ nn . Sequential ( [ nn . Linear ( INPUT_DIM , NETWORK_WIDTH , key = subkeys [ 0 ]), nn . Lambda ( jnp . tanh ) ], ), nn . Sequential ( [ nn . Linear ( NETWORK_WIDTH , NETWORK_WIDTH , key = subkeys [ 1 ]), nn . Lambda ( jnp . tanh ) ], ), nn . Linear ( NETWORK_WIDTH , OUTPUT_DIM , key = subkeys [ 2 ]), ] model [Sequential( layers=( Linear( weight=f32[300,784], bias=f32[300], in_features=784, out_features=300, use_bias=True ), Lambda(fn=) ) ), Sequential( layers=( Linear( weight=f32[300,300], bias=f32[300], in_features=300, out_features=300, use_bias=True ), Lambda(fn=) ) ), Linear( weight=f32[10,300], bias=f32[10], in_features=300, out_features=10, use_bias=True )] The last thing with need is to initialise the activities. For this, we will use a feedforward pass as often done in practice. \u2753 Question : Can you think of other ways of initialising the activities? def init_activities_with_ffwd ( model , input ): activities = [ vmap ( model [ 0 ])( input )] for l in range ( 1 , len ( model )): layer_output = vmap ( model [ l ])( activities [ l - 1 ]) activities . append ( layer_output ) return activities Let's test it on an MNIST sample. # get a data sample train_loader , test_loader = get_mnist_loaders ( BATCH_SIZE ) img_batch , label_batch = next ( iter ( train_loader )) # we need to turn the torch.Tensor data into numpy arrays for jax img_batch , label_batch = img_batch . numpy (), label_batch . numpy () # let's check our initialised activities activities = init_activities_with_ffwd ( model , img_batch ) for i , a in enumerate ( activities ): print ( f \"activity z at layer { i + 1 } : { a . shape } \" ) activity z at layer 1: (64, 300) activity z at layer 2: (64, 300) activity z at layer 3: (64, 10) Ok so now we have everything to test our PC energy function: model, activities, and some data. pc_energy_fn ( model = model , activities = activities , input = img_batch , output = label_batch ) Array(1.2335204, dtype=float32) And it works! Energy gradients \u00a4 How do we minimise the PC energy we defined above (Eq. 1)? Recall from the lecture that we do this in two phases: first with respect to the activities (inference) and then with respect to the weights (learning). \\[\\begin{equation} \\textit{Inference:} - \\frac{\\partial \\mathcal{F}}{\\partial \\mathbf{z}_\\ell} \\end{equation}\\] \\[\\begin{equation} \\textit{Learning:} - \\frac{\\partial \\mathcal{F}}{\\partial W_\\ell} \\end{equation}\\] So we just need to take these gradients of the energy. 
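For intuition, it can help to see what one of these gradients looks like in closed form. Writing \\(\\boldsymbol{\\epsilon}_\\ell = \\mathbf{z}_\\ell - f_\\ell(W_\\ell \\mathbf{z}_{\\ell-1} + \\mathbf{b}_\\ell)\\) for the prediction error at layer \\(\\ell\\), a standard derivation from Eq. 1 (per data point, dropping the \\(1/N\\) average and assuming elementwise activations) gives, for every free layer \\(1 \\leq \\ell < L\\), \\[\\begin{equation} \\frac{\\partial \\mathcal{F}}{\\partial \\mathbf{z}_\\ell} = \\boldsymbol{\\epsilon}_\\ell - W_{\\ell+1}^\\top \\big( \\boldsymbol{\\epsilon}_{\\ell+1} \\odot f'_{\\ell+1}(W_{\\ell+1} \\mathbf{z}_\\ell + \\mathbf{b}_{\\ell+1}) \\big) \\end{equation}\\] so each activity is pushed to reduce its own prediction error while staying consistent with the error of the layer above. We will not derive these gradients by hand, though.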
We are going to use autodiff, which JAX embeds by design (https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html). If you're familiar with PyTorch, you are probably used to loss.backward() for this, which might feel obstruse at times. JAX, on the other hand, is a fully functional (as opposed to object-oriented) language whose syntax is very close to the maths as you can see below. # note how close this code is to the maths # this can be read as \"take the gradient of the energy... # ...with the respect to the 2nd argument (the activities) def compute_activity_grad ( model , activities , input , output ): return grad ( pc_energy_fn , argnums = 1 )( model , activities , input , output ) Let's test this out. dFdzs = compute_activity_grad ( model = model , activities = activities , input = img_batch , output = label_batch ) for i , dFdz in enumerate ( dFdzs ): print ( f \"activity gradient dFdz shape at layer { i + 1 } : { dFdz . shape } \" ) activity gradient dFdz shape at layer 1: (64, 300) activity gradient dFdz shape at layer 2: (64, 300) activity gradient dFdz shape at layer 3: (64, 10) Now we do the same and take the gradient of the energy with respect to the parameters. Technical note : below we use Equinox's convenience function filter_grad rather than JAX's native grad . This is because things like activation functions do not have parameters and so we do not want to differentiate them. filter_grad automatically filters these non-differentiable objects for us, while grad alone would throw an error. # note that, compared to the previous function,... # ...we just change the argument with respect to which... # ...we are differentiating (the first, or in this case the model) def compute_param_grad ( model , activities , input , output ): return filter_grad ( pc_energy_fn )( model , activities , input , output ) And let's test it. param_grads = compute_param_grad ( model = model , activities = activities , input = img_batch , output = label_batch ) Updates \u00a4 Before putting everything together, let's wrap our gradients into update functions. This will also allow us to use JAX's jit primitive, which essentially compiles your code the first time it's executed so that it can be run more efficiently the next time (see https://jax.readthedocs.io/en/latest/jit-compilation.html for more details). These functions take an (Optax) optimiser such as gradient descent in addition to the previous arguments (model, activities and data). @eqx . filter_jit def update_activities ( model , activities , optim , opt_state , input , output ): activity_grads = compute_activity_grad ( model = model , activities = activities , input = input , output = output ) activity_updates , activity_opt_state = optim . update ( updates = activity_grads , state = opt_state , params = activities ) activities = eqx . apply_updates ( model = activities , updates = activity_updates ) return activities , optim , opt_state # note that the only difference with the above function is... # ...the variable we are updating (parameters vs activities) @eqx . filter_jit def update_params ( model , activities , optim , opt_state , input , output ): param_grads = compute_param_grad ( model = model , activities = activities , input = input , output = output ) param_updates , param_opt_state = optim . update ( updates = param_grads , state = opt_state , params = model ) model = eqx . 
apply_updates ( model = model , updates = param_updates ) return model , optim , opt_state Putting everything together: Training and testing \u00a4 Now that we have our activity and parameter updates, we just need to wrap them in a training and test loop. # note: the accuracy computation below could be sped up... # ...with jit in a separate function def evaluate ( model , test_loader ): avg_test_acc = 0 for test_iter , ( img_batch , label_batch ) in enumerate ( test_loader ): img_batch , label_batch = img_batch . numpy (), label_batch . numpy () preds = init_activities_with_ffwd ( model , img_batch )[ - 1 ] test_acc = jnp . mean ( jnp . argmax ( label_batch , axis = 1 ) == jnp . argmax ( preds , axis = 1 ) ) * 100 avg_test_acc += test_acc return avg_test_acc / len ( test_loader ) def train ( model , activity_lr , inference_steps , param_lr , batch_size , test_every , n_train_iters ): # define optimisers for activities and parameters activity_optim = optax . sgd ( activity_lr ) param_optim = optax . adam ( param_lr ) param_opt_state = param_optim . init ( eqx . filter ( model , eqx . is_array )) train_loader , test_loader = get_mnist_loaders ( batch_size ) for train_iter , ( img_batch , label_batch ) in enumerate ( train_loader ): img_batch , label_batch = img_batch . numpy (), label_batch . numpy () # initialise activities activities = init_activities_with_ffwd ( model , img_batch ) activity_opt_state = activity_optim . init ( activities ) train_loss = jnp . mean (( label_batch - activities [ - 1 ]) ** 2 ) # inference for t in range ( inference_steps ): activities , activity_optim , activity_opt_state = update_activities ( model = model , activities = activities , optim = activity_optim , opt_state = activity_opt_state , input = img_batch , output = label_batch ) # learning model , param_optim , param_opt_state = update_params ( model = model , activities = activities , # note how we use the optimised activities optim = param_optim , opt_state = param_opt_state , input = img_batch , output = label_batch ) if (( train_iter + 1 ) % test_every ) == 0 : avg_test_acc = evaluate ( model , test_loader ) print ( f \"Train iter { train_iter + 1 } , train loss= { train_loss : 4f } , \" f \"avg test accuracy= { avg_test_acc : 4f } \" ) if ( train_iter + 1 ) >= n_train_iters : break Run \u00a4 Let's test our implementation. train ( model = model , activity_lr = ACTIVITY_LR , inference_steps = INFERENCE_STEPS , param_lr = PARAM_LR , batch_size = BATCH_SIZE , test_every = TEST_EVERY , n_train_iters = N_TRAIN_ITERS ) Train iter 50, train loss=0.065566, avg test accuracy=72.726364 Train iter 100, train loss=0.046521, avg test accuracy=76.292068 Train iter 150, train loss=0.042710, avg test accuracy=86.568512 Train iter 200, train loss=0.029598, avg test accuracy=89.082535 Train iter 250, train loss=0.031486, avg test accuracy=89.222755 Train iter 300, train loss=0.016624, avg test accuracy=91.296074 Train iter 350, train loss=0.025201, avg test accuracy=92.648239 Train iter 400, train loss=0.018597, avg test accuracy=92.968750 Train iter 450, train loss=0.019027, avg test accuracy=94.130608 Train iter 500, train loss=0.014850, avg test accuracy=93.760017 \ud83e\udd73 Great, we see that our model is training! You can probably improve the performance by tweaking some of the hyperparameters (e.g. try a higher number of inference steps). Even if you didn't follow all the implementation details, you should now have at least an idea of how PC works in practice. 
Indeed, this is basically the core code behind a new PC library our lab will soon release: JPC . Play around with the notebook examples there where you can learn how to train a variety of PC networks.","title":"Implementing pc from scratch"},{"location":"examples/implementing_pc_from_scratch/#implementing-predictive-coding-from-scratch-in-jax","text":"In this notebook, we walk through how to implement Predictive Coding (PC) from scratch using JAX. In particular, we are going to train a simple feedforward network to classify MNIST digits with PC. It might be a good idea to revisit the introductory lecture on PC from this week. If you're not familiar with JAX, have a look at their docs , but we will explain all the necessary concepts below. JAX is basically numpy for GPUs and other hardware accelerators. We will also use: Equinox , which allows you to define neural nets with PyTorch-like syntax; and Optax , which provides a range of common machine learning optimisers such as gradient descent and Adam.","title":"\u2699\ufe0f Implementing Predictive Coding from Scratch (in JAX)"},{"location":"examples/implementing_pc_from_scratch/#installations-imports","text":"#@title installations %% capture ! pip install torch == 2.3.1 ! pip install torchvision == 0.18.1 #@title imports import jax.random as jr import jax.numpy as jnp from jax import vmap , grad from jax.tree_util import tree_map import equinox as eqx import equinox.nn as nn from equinox import filter_grad import optax import torch from torch.utils.data import DataLoader from torchvision import datasets , transforms import warnings warnings . simplefilter ( \"ignore\" )","title":"Installations & imports"},{"location":"examples/implementing_pc_from_scratch/#hyperparameters","text":"We define some global parameters related to the data, network, optimisers, etc. SEED = 827 INPUT_DIM , OUTPUT_DIM = 28 * 28 , 10 NETWORK_WIDTH = 300 ACTIVITY_LR = 1e-1 INFERENCE_STEPS = 20 PARAM_LR = 1e-3 BATCH_SIZE = 64 TEST_EVERY = 50 N_TRAIN_ITERS = 500","title":"Hyperparameters"},{"location":"examples/implementing_pc_from_scratch/#dataset","text":"Some utils to fetch MNIST. #@title data utils def get_mnist_loaders ( batch_size ): train_data = MNIST ( train = True , normalise = True ) test_data = MNIST ( train = False , normalise = True ) train_loader = DataLoader ( dataset = train_data , batch_size = batch_size , shuffle = True , drop_last = True ) test_loader = DataLoader ( dataset = test_data , batch_size = batch_size , shuffle = True , drop_last = True ) return train_loader , test_loader class MNIST ( datasets . MNIST ): def __init__ ( self , train , normalise = True , save_dir = \"data\" ): if normalise : transform = transforms . Compose ( [ transforms . ToTensor (), transforms . Normalize ( mean = ( 0.1307 ), std = ( 0.3081 ) ) ] ) else : transform = transforms . Compose ([ transforms . ToTensor ()]) super () . __init__ ( save_dir , download = True , train = train , transform = transform ) def __getitem__ ( self , index ): img , label = super () . __getitem__ ( index ) img = torch . flatten ( img ) label = one_hot ( label ) return img , label def one_hot ( labels , n_classes = 10 ): arr = torch . eye ( n_classes ) return arr [ labels ]","title":"Dataset"},{"location":"examples/implementing_pc_from_scratch/#pc-energy","text":"First, recall that PC can be derived as a variational inference algorithm under certain assumptions. 
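(For reference, the hierarchical Gaussian generative model mentioned in the next sentence can be sketched, with identity covariances and \\(\\mathbf{z}_0\\) clamped to the input, as \\[\\begin{equation} p(\\mathbf{z}_1, \\dots, \\mathbf{z}_L \\mid \\mathbf{z}_0) = \\prod_{\\ell=1}^L \\mathcal{N}\\big(\\mathbf{z}_\\ell ; f_\\ell(W_\\ell \\mathbf{z}_{\\ell-1} + \\mathbf{b}_\\ell), \\mathbf{I}\\big) \\end{equation}\\] whose negative log density is, up to additive constants and the average over data points, exactly the energy given below.)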
In particular, if we assume * a Dirac delta (point mass) posterior and * a hierarchical Gaussian generative model, we get the standard PC energy \\begin{equation} \\mathcal{F} = \\frac{1}{2N}\\sum_{i=1}^{N} \\sum_{\\ell=1}^L ||\\mathbf{z}_{\\ell, i} - f_\\ell(W_\\ell \\mathbf{z}_{\\ell-1, i} + \\mathbf{b}_\\ell)||^2_2 \\end{equation} which is just a sum of squared prediction errors at each network layer. Here we are being a little bit more precise than in the lecture, including multiple ( \\(N\\) ) data points and biases \\(\\mathbf{b}_\\ell\\) . \ud83e\udd14 Food for thought : Think about how the form of this energy could change depending on other assumptions we make about the generative model. See, for example, Learning on Arbitrary Graph Topologies via Predictive Coding by Salvatori et al. (2022). Let's start by implementing this energy below. The function simply takes the model (with all the parameters), some initialised activities, and some input and output. Given these, it sums the squared prediction error at each layer. NOTE : below we use vmap , one of the core JAX transforms that allows you to vectorise operations, in this case for multiple data points or over a batch. See https://jax.readthedocs.io/en/latest/automatic-vectorization.html for more details. def pc_energy_fn ( model , activities , input , output ): batch_size = output . shape [ 0 ] n_activity_layers = len ( activities ) - 1 n_layers = len ( model ) - 1 eL = output - vmap ( model [ - 1 ])( activities [ - 2 ]) energies = [ jnp . sum ( eL ** 2 )] for act_l , net_l in zip ( range ( 1 , n_activity_layers ), range ( 1 , n_layers ) ): err = activities [ act_l ] - vmap ( model [ net_l ])( activities [ act_l - 1 ]) energies . append ( jnp . sum ( err ** 2 )) e1 = activities [ 0 ] - vmap ( model [ 0 ])( input ) energies . append ( jnp . sum ( e1 ** 2 )) return jnp . sum ( jnp . array ( energies )) / batch_size Now let's test it. To do so, we first need a model. Below we use Equinox to create a simple feedforward network with 2 hidden layers and tanh activations. Note that we split the model into different parts with nn.Sequential to define the activities which PC will optimise over (during inference, more on this below). \u2753 Question : Think about other ways in which we could split the layers, for example by separating the non-linearities. Can you think of potential issues with this? # jax uses explicit random number generators (see https://jax.readthedocs.io/en/latest/random-numbers.html) key = jr . PRNGKey ( SEED ) subkeys = jr . split ( key , 3 ) model = [ nn . Sequential ( [ nn . Linear ( INPUT_DIM , NETWORK_WIDTH , key = subkeys [ 0 ]), nn . Lambda ( jnp . tanh ) ], ), nn . Sequential ( [ nn . Linear ( NETWORK_WIDTH , NETWORK_WIDTH , key = subkeys [ 1 ]), nn . Lambda ( jnp . tanh ) ], ), nn . Linear ( NETWORK_WIDTH , OUTPUT_DIM , key = subkeys [ 2 ]), ] model [Sequential( layers=( Linear( weight=f32[300,784], bias=f32[300], in_features=784, out_features=300, use_bias=True ), Lambda(fn=) ) ), Sequential( layers=( Linear( weight=f32[300,300], bias=f32[300], in_features=300, out_features=300, use_bias=True ), Lambda(fn=) ) ), Linear( weight=f32[10,300], bias=f32[10], in_features=300, out_features=10, use_bias=True )] The last thing we need is to initialise the activities. For this, we will use a feedforward pass as often done in practice. \u2753 Question : Can you think of other ways of initialising the activities? 
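One simple alternative, for instance, is to initialise the activities with small random values and let inference settle from there. The sketch below is purely illustrative and is not used anywhere in this notebook; the helper name, noise scale and hard-coded layer widths are our own choices.

import jax.random as jr

def init_activities_randomly(key, widths, batch_size, scale=0.1):
    # widths: output size of each model part, e.g. [300, 300, 10] for the model above
    keys = jr.split(key, len(widths))
    # one (batch_size, width) array of small Gaussian noise per layer
    return [scale * jr.normal(k, (batch_size, w)) for k, w in zip(keys, widths)]

# hypothetical usage:
# activities = init_activities_randomly(jr.PRNGKey(SEED), [NETWORK_WIDTH, NETWORK_WIDTH, OUTPUT_DIM], BATCH_SIZE)

Random starting points typically need more inference steps to settle than a feedforward pass, which is one reason the latter is common in practice. In the rest of the notebook we stick with the feedforward initialisation: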
def init_activities_with_ffwd ( model , input ): activities = [ vmap ( model [ 0 ])( input )] for l in range ( 1 , len ( model )): layer_output = vmap ( model [ l ])( activities [ l - 1 ]) activities . append ( layer_output ) return activities Let's test it on an MNIST sample. # get a data sample train_loader , test_loader = get_mnist_loaders ( BATCH_SIZE ) img_batch , label_batch = next ( iter ( train_loader )) # we need to turn the torch.Tensor data into numpy arrays for jax img_batch , label_batch = img_batch . numpy (), label_batch . numpy () # let's check our initialised activities activities = init_activities_with_ffwd ( model , img_batch ) for i , a in enumerate ( activities ): print ( f \"activity z at layer { i + 1 } : { a . shape } \" ) activity z at layer 1: (64, 300) activity z at layer 2: (64, 300) activity z at layer 3: (64, 10) Ok so now we have everything to test our PC energy function: model, activities, and some data. pc_energy_fn ( model = model , activities = activities , input = img_batch , output = label_batch ) Array(1.2335204, dtype=float32) And it works!","title":"PC energy"},{"location":"examples/implementing_pc_from_scratch/#energy-gradients","text":"How do we minimise the PC energy we defined above (Eq. 1)? Recall from the lecture that we do this in two phases: first with respect to the activities (inference) and then with respect to the weights (learning). \\[\\begin{equation} \\textit{Inference:} - \\frac{\\partial \\mathcal{F}}{\\partial \\mathbf{z}_\\ell} \\end{equation}\\] \\[\\begin{equation} \\textit{Learning:} - \\frac{\\partial \\mathcal{F}}{\\partial W_\\ell} \\end{equation}\\] So we just need to take these gradients of the energy. We are going to use autodiff, which JAX embeds by design (https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html). If you're familiar with PyTorch, you are probably used to loss.backward() for this, which might feel obstruse at times. JAX, on the other hand, is a fully functional (as opposed to object-oriented) language whose syntax is very close to the maths as you can see below. # note how close this code is to the maths # this can be read as \"take the gradient of the energy... # ...with the respect to the 2nd argument (the activities) def compute_activity_grad ( model , activities , input , output ): return grad ( pc_energy_fn , argnums = 1 )( model , activities , input , output ) Let's test this out. dFdzs = compute_activity_grad ( model = model , activities = activities , input = img_batch , output = label_batch ) for i , dFdz in enumerate ( dFdzs ): print ( f \"activity gradient dFdz shape at layer { i + 1 } : { dFdz . shape } \" ) activity gradient dFdz shape at layer 1: (64, 300) activity gradient dFdz shape at layer 2: (64, 300) activity gradient dFdz shape at layer 3: (64, 10) Now we do the same and take the gradient of the energy with respect to the parameters. Technical note : below we use Equinox's convenience function filter_grad rather than JAX's native grad . This is because things like activation functions do not have parameters and so we do not want to differentiate them. filter_grad automatically filters these non-differentiable objects for us, while grad alone would throw an error. # note that, compared to the previous function,... # ...we just change the argument with respect to which... 
# ...we are differentiating (the first, or in this case the model) def compute_param_grad ( model , activities , input , output ): return filter_grad ( pc_energy_fn )( model , activities , input , output ) And let's test it. param_grads = compute_param_grad ( model = model , activities = activities , input = img_batch , output = label_batch )","title":"Energy gradients"},{"location":"examples/implementing_pc_from_scratch/#updates","text":"Before putting everything together, let's wrap our gradients into update functions. This will also allow us to use JAX's jit primitive, which essentially compiles your code the first time it's executed so that it can be run more efficiently the next time (see https://jax.readthedocs.io/en/latest/jit-compilation.html for more details). These functions take an (Optax) optimiser such as gradient descent in addition to the previous arguments (model, activities and data). @eqx . filter_jit def update_activities ( model , activities , optim , opt_state , input , output ): activity_grads = compute_activity_grad ( model = model , activities = activities , input = input , output = output ) activity_updates , activity_opt_state = optim . update ( updates = activity_grads , state = opt_state , params = activities ) activities = eqx . apply_updates ( model = activities , updates = activity_updates ) return activities , optim , opt_state # note that the only difference with the above function is... # ...the variable we are updating (parameters vs activities) @eqx . filter_jit def update_params ( model , activities , optim , opt_state , input , output ): param_grads = compute_param_grad ( model = model , activities = activities , input = input , output = output ) param_updates , param_opt_state = optim . update ( updates = param_grads , state = opt_state , params = model ) model = eqx . apply_updates ( model = model , updates = param_updates ) return model , optim , opt_state","title":"Updates"},{"location":"examples/implementing_pc_from_scratch/#putting-everything-together-training-and-testing","text":"Now that we have our activity and parameter updates, we just need to wrap them in a training and test loop. # note: the accuracy computation below could be sped up... # ...with jit in a separate function def evaluate ( model , test_loader ): avg_test_acc = 0 for test_iter , ( img_batch , label_batch ) in enumerate ( test_loader ): img_batch , label_batch = img_batch . numpy (), label_batch . numpy () preds = init_activities_with_ffwd ( model , img_batch )[ - 1 ] test_acc = jnp . mean ( jnp . argmax ( label_batch , axis = 1 ) == jnp . argmax ( preds , axis = 1 ) ) * 100 avg_test_acc += test_acc return avg_test_acc / len ( test_loader ) def train ( model , activity_lr , inference_steps , param_lr , batch_size , test_every , n_train_iters ): # define optimisers for activities and parameters activity_optim = optax . sgd ( activity_lr ) param_optim = optax . adam ( param_lr ) param_opt_state = param_optim . init ( eqx . filter ( model , eqx . is_array )) train_loader , test_loader = get_mnist_loaders ( batch_size ) for train_iter , ( img_batch , label_batch ) in enumerate ( train_loader ): img_batch , label_batch = img_batch . numpy (), label_batch . numpy () # initialise activities activities = init_activities_with_ffwd ( model , img_batch ) activity_opt_state = activity_optim . init ( activities ) train_loss = jnp . 
mean (( label_batch - activities [ - 1 ]) ** 2 ) # inference for t in range ( inference_steps ): activities , activity_optim , activity_opt_state = update_activities ( model = model , activities = activities , optim = activity_optim , opt_state = activity_opt_state , input = img_batch , output = label_batch ) # learning model , param_optim , param_opt_state = update_params ( model = model , activities = activities , # note how we use the optimised activities optim = param_optim , opt_state = param_opt_state , input = img_batch , output = label_batch ) if (( train_iter + 1 ) % test_every ) == 0 : avg_test_acc = evaluate ( model , test_loader ) print ( f \"Train iter { train_iter + 1 } , train loss= { train_loss : 4f } , \" f \"avg test accuracy= { avg_test_acc : 4f } \" ) if ( train_iter + 1 ) >= n_train_iters : break","title":"Putting everything together: Training and testing"},{"location":"examples/implementing_pc_from_scratch/#run","text":"Let's test our implementation. train ( model = model , activity_lr = ACTIVITY_LR , inference_steps = INFERENCE_STEPS , param_lr = PARAM_LR , batch_size = BATCH_SIZE , test_every = TEST_EVERY , n_train_iters = N_TRAIN_ITERS ) Train iter 50, train loss=0.065566, avg test accuracy=72.726364 Train iter 100, train loss=0.046521, avg test accuracy=76.292068 Train iter 150, train loss=0.042710, avg test accuracy=86.568512 Train iter 200, train loss=0.029598, avg test accuracy=89.082535 Train iter 250, train loss=0.031486, avg test accuracy=89.222755 Train iter 300, train loss=0.016624, avg test accuracy=91.296074 Train iter 350, train loss=0.025201, avg test accuracy=92.648239 Train iter 400, train loss=0.018597, avg test accuracy=92.968750 Train iter 450, train loss=0.019027, avg test accuracy=94.130608 Train iter 500, train loss=0.014850, avg test accuracy=93.760017 \ud83e\udd73 Great, we see that our model is training! You can probably improve the performance by tweaking some of the hyperparameters (e.g. try a higher number of inference steps). Even if you didn't follow all the implementation details, you should now have at least an idea of how PC works in practice. Indeed, this is basically the core code behind a new PC library our lab will soon release: JPC . Play around with the notebook examples there where you can learn how to train a variety of PC networks.","title":"Run"},{"location":"examples/linear_net_theoretical_activities/","text":"(function() { function addWidgetsRenderer() { var requireJsScript = document.createElement('script'); requireJsScript.src = 'https://cdnjs.cloudflare.com/ajax/libs/require.js/2.3.4/require.min.js'; var mimeElement = document.querySelector('script[type=\"application/vnd.jupyter.widget-view+json\"]'); var jupyterWidgetsScript = document.createElement('script'); var widgetRendererSrc = 'https://unpkg.com/@jupyter-widgets/html-manager@*/dist/embed-amd.js'; var widgetState; // Fallback for older version: try { widgetState = mimeElement && JSON.parse(mimeElement.innerHTML); if (widgetState && (widgetState.version_major < 2 || !widgetState.version_major)) { widgetRendererSrc = 'jupyter-js-widgets@*/dist/embed.js'; } } catch(e) {} jupyterWidgetsScript.src = widgetRendererSrc; document.body.appendChild(requireJsScript); document.body.appendChild(jupyterWidgetsScript); } document.addEventListener('DOMContentLoaded', addWidgetsRenderer); }()); Theoretical activities of deep linear networks \u00a4 %% capture ! pip install torch == 2.3.1 ! pip install torchvision == 0.18.1 ! pip install plotly == 5.11.0 ! 
pip install - U kaleido import jpc import jax from jax import vmap import jax.numpy as jnp import equinox as eqx import equinox.nn as nn import optax import torch from torch.utils.data import DataLoader from torchvision import datasets , transforms import plotly.graph_objs as go import plotly.io as pio pio . renderers . default = 'iframe' Hyperparameters \u00a4 We define some global parameters, including network architecture, learning rate, batch size etc. SEED = 0 LEARNING_RATE = 1e-3 BATCH_SIZE = 64 TEST_EVERY = 10 N_TRAIN_ITERS = 20 Dataset \u00a4 Some utils to fetch MNIST. #@title data utils def get_mnist_loaders ( batch_size ): train_data = MNIST ( train = True , normalise = True ) test_data = MNIST ( train = False , normalise = True ) train_loader = DataLoader ( dataset = train_data , batch_size = batch_size , shuffle = True , drop_last = True ) test_loader = DataLoader ( dataset = test_data , batch_size = batch_size , shuffle = True , drop_last = True ) return train_loader , test_loader class MNIST ( datasets . MNIST ): def __init__ ( self , train , normalise = True , save_dir = \"data\" ): if normalise : transform = transforms . Compose ( [ transforms . ToTensor (), transforms . Normalize ( mean = ( 0.1307 ), std = ( 0.3081 ) ) ] ) else : transform = transforms . Compose ([ transforms . ToTensor ()]) super () . __init__ ( save_dir , download = True , train = train , transform = transform ) def __getitem__ ( self , index ): img , label = super () . __getitem__ ( index ) img = torch . flatten ( img ) label = one_hot ( label ) return img , label def one_hot ( labels , n_classes = 10 ): arr = torch . eye ( n_classes ) return arr [ labels ] Plotting \u00a4 def plot_layer_energies ( energies ): n_train_iters = energies [ \"theory\" ] . shape [ 0 ] n_energies = energies [ \"theory\" ] . shape [ 1 ] train_iters = [ b + 1 for b in range ( n_train_iters )] colors = [ '#636EFA' , '#EF553B' , '#00CC96' , '#AB63FA' , '#FFA15A' , '#19D3F3' , '#FF6692' , '#B6E880' , '#FF97FF' , '#FECB52' , '#8C564B' ] fig = go . Figure () for n in range ( n_energies ): fig . add_traces ( go . Scatter ( x = train_iters , y = energies [ \"theory\" ][:, n ], mode = \"lines\" , line = dict ( width = 2 , dash = \"dash\" , color = colors [ n ] ), showlegend = False ) ) fig . add_traces ( go . Scatter ( x = train_iters , y = energies [ \"experiment\" ][:, n ], name = f \"$\\Large {{ \\ell_ { n + 1 } }} $\" , mode = \"lines\" , line = dict ( width = 3 , dash = \"solid\" , color = colors [ n ] ), ) ) fig . update_layout ( height = 350 , width = 475 , xaxis = dict ( title = \"Training iteration\" , tickvals = [ 1 , int ( train_iters [ - 1 ] / 2 ), train_iters [ - 1 ]], ticktext = [ 1 , int ( train_iters [ - 1 ] / 2 ), train_iters [ - 1 ]], ), yaxis = dict ( title = \"Energy\" , nticks = 3 ), font = dict ( size = 16 ), ) fig . write_image ( \"dln_layer_energies_example.pdf\" ) return fig Linear network \u00a4 key = jax . random . PRNGKey ( 0 ) subkeys = jax . random . split ( key , 21 ) network = [ eqx . nn . Linear ( 784 , 300 , key = subkeys [ 0 ], use_bias = False ), eqx . nn . Linear ( 300 , 300 , key = subkeys [ 1 ], use_bias = False ), eqx . nn . Linear ( 300 , 300 , key = subkeys [ 2 ], use_bias = False ), eqx . nn . Linear ( 300 , 300 , key = subkeys [ 3 ], use_bias = False ), eqx . nn . Linear ( 300 , 300 , key = subkeys [ 4 ], use_bias = False ), eqx . nn . Linear ( 300 , 300 , key = subkeys [ 5 ], use_bias = False ), eqx . nn . Linear ( 300 , 300 , key = subkeys [ 6 ], use_bias = False ), eqx . nn . 
Linear ( 300 , 300 , key = subkeys [ 7 ], use_bias = False ), eqx . nn . Linear ( 300 , 300 , key = subkeys [ 8 ], use_bias = False ), eqx . nn . Linear ( 300 , 300 , key = subkeys [ 9 ], use_bias = False ), eqx . nn . Linear ( 300 , 300 , key = subkeys [ 10 ], use_bias = False ), eqx . nn . Linear ( 300 , 300 , key = subkeys [ 11 ], use_bias = False ), eqx . nn . Linear ( 300 , 300 , key = subkeys [ 12 ], use_bias = False ), eqx . nn . Linear ( 300 , 300 , key = subkeys [ 13 ], use_bias = False ), eqx . nn . Linear ( 300 , 300 , key = subkeys [ 14 ], use_bias = False ), eqx . nn . Linear ( 300 , 300 , key = subkeys [ 15 ], use_bias = False ), eqx . nn . Linear ( 300 , 300 , key = subkeys [ 16 ], use_bias = False ), eqx . nn . Linear ( 300 , 300 , key = subkeys [ 17 ], use_bias = False ), eqx . nn . Linear ( 300 , 300 , key = subkeys [ 18 ], use_bias = False ), eqx . nn . Linear ( 300 , 300 , key = subkeys [ 19 ], use_bias = False ), eqx . nn . Linear ( 300 , 10 , key = subkeys [ 20 ], use_bias = False ), ] Train and test \u00a4 A PC network can be trained in a single line of code with jpc.make_pc_step() . See the documentation for more. Similarly, we can use jpc.test_discriminative_pc() to compute the network accuracy. Note that these functions are already \"jitted\" for performance. Below we simply wrap each of these functions in our training and test loops, respectively. def evaluate ( model , test_loader ): test_acc = 0 for batch_id , ( img_batch , label_batch ) in enumerate ( test_loader ): img_batch = img_batch . numpy () label_batch = label_batch . numpy () test_acc += jpc . test_discriminative_pc ( model = model , y = label_batch , x = img_batch ) return test_acc / len ( test_loader ) def train ( model , lr , batch_size , t1 , test_every , n_train_iters ): optim = optax . adam ( lr ) opt_state = optim . init ( eqx . filter ( model , eqx . is_array )) train_loader , test_loader = get_mnist_loaders ( batch_size ) As , cond_numbers = [], [] num_energies , theory_energies = [], [] num_total_energies , theory_total_energies = [], [] for iter , ( img_batch , label_batch ) in enumerate ( train_loader ): img_batch = img_batch . numpy () label_batch = label_batch . numpy () theory_total_energies . append ( jpc . linear_equilib_energy ( network = model , x = img_batch , y = label_batch ) ) A = jpc . linear_activities_coeff_matrix ([ l . weight for l in model ]) As . append ( A ) cond_numbers . append ( jnp . linalg . cond ( A )) theory_activities = jpc . linear_activities_solution ( network = model , x = img_batch , y = label_batch ) theory_energies . append ( jnp . flip ( jpc . pc_energy_fn ( model , theory_activities , x = img_batch , y = label_batch , record_layers = True )) ) result = jpc . make_pc_step ( model , optim , opt_state , y = label_batch , x = img_batch , t1 = t1 , record_energies = True ) model , optim , opt_state = result [ \"model\" ], result [ \"optim\" ], result [ \"opt_state\" ] train_loss , t_max = result [ \"loss\" ], result [ \"t_max\" ] num_total_energies . append ( result [ \"energies\" ][:, t_max - 1 ] . sum ()) num_energies . append ( result [ \"energies\" ][:, t_max - 1 ]) if (( iter + 1 ) % test_every ) == 0 : avg_test_acc = evaluate ( model , test_loader ) print ( f \"Train iter { iter + 1 } , train loss= { train_loss : 4f } , \" f \"avg test accuracy= { avg_test_acc : 4f } \" ) if ( iter + 1 ) >= n_train_iters : break return { \"experiment\" : jnp . array ( num_energies ), \"theory\" : jnp . 
array ( theory_energies ) }, As , cond_numbers Run \u00a4 energies , As , cond_numbers = train ( model = network , lr = LEARNING_RATE , batch_size = BATCH_SIZE , t1 = 300 , test_every = TEST_EVERY , n_train_iters = N_TRAIN_ITERS ) /Users/fi69/PycharmProjects/jpc/venv/lib/python3.10/site-packages/jax/_src/core.py:678: FutureWarning: unhashable type: . Attempting to hash a tracer will lead to an error in a future JAX release. Train iter 10, train loss=0.091816, avg test accuracy=0.270333 Train iter 20, train loss=0.089849, avg test accuracy=0.335337 import matplotlib.pyplot as plt plt . plot ( cond_numbers ) [] n_train_iters = energies [ \"theory\" ] . shape [ 0 ] n_energies = energies [ \"theory\" ] . shape [ 1 ] train_iters = [ b + 1 for b in range ( n_train_iters )] colors = [ '#636EFA' , '#EF553B' , '#00CC96' , '#AB63FA' , '#FFA15A' , '#19D3F3' , '#FF6692' , '#B6E880' , '#FF97FF' , '#FECB52' , '#8C564B' , '#636EFA' , '#EF553B' , '#00CC96' , '#AB63FA' , '#FFA15A' , '#19D3F3' , '#FF6692' , '#B6E880' , '#FF97FF' , '#FECB52' , '#8C564B' , '#636EFA' , '#EF553B' , '#00CC96' , '#AB63FA' , '#FFA15A' , '#19D3F3' , '#FF6692' , '#B6E880' , '#FF97FF' , '#FECB52' , '#8C564B' ] fig = go . Figure () for n in range ( n_energies ): fig . add_traces ( go . Scatter ( x = train_iters , y = energies [ \"theory\" ][:, n ], mode = \"lines\" , line = dict ( width = 3 , dash = \"dash\" , color = colors [ n ] ), showlegend = False ) ) fig . add_traces ( go . Scatter ( x = train_iters , y = energies [ \"experiment\" ][:, n ], name = f \"$\\Large {{ \\ell_ {{ { n + 1 } }}}} $\" , mode = \"lines\" , line = dict ( width = 2 , dash = \"solid\" , color = colors [ n ] ), ) ) fig . update_layout ( height = 400 , width = 600 , xaxis = dict ( title = \"Training iteration\" , tickvals = [ 1 , int ( train_iters [ - 1 ] / 2 ), train_iters [ - 1 ]], ticktext = [ 1 , int ( train_iters [ - 1 ] / 2 ), train_iters [ - 1 ]], ), yaxis = dict ( title = \"Energy\" , nticks = 3 ), font = dict ( size = 16 ) )","title":"Linear net theoretical activities"},{"location":"examples/linear_net_theoretical_activities/#theoretical-activities-of-deep-linear-networks","text":"%% capture ! pip install torch == 2.3.1 ! pip install torchvision == 0.18.1 ! pip install plotly == 5.11.0 ! pip install - U kaleido import jpc import jax from jax import vmap import jax.numpy as jnp import equinox as eqx import equinox.nn as nn import optax import torch from torch.utils.data import DataLoader from torchvision import datasets , transforms import plotly.graph_objs as go import plotly.io as pio pio . renderers . default = 'iframe'","title":"Theoretical activities of deep linear networks"},{"location":"examples/linear_net_theoretical_activities/#hyperparameters","text":"We define some global parameters, including network architecture, learning rate, batch size etc. SEED = 0 LEARNING_RATE = 1e-3 BATCH_SIZE = 64 TEST_EVERY = 10 N_TRAIN_ITERS = 20","title":"Hyperparameters"},{"location":"examples/linear_net_theoretical_activities/#dataset","text":"Some utils to fetch MNIST. #@title data utils def get_mnist_loaders ( batch_size ): train_data = MNIST ( train = True , normalise = True ) test_data = MNIST ( train = False , normalise = True ) train_loader = DataLoader ( dataset = train_data , batch_size = batch_size , shuffle = True , drop_last = True ) test_loader = DataLoader ( dataset = test_data , batch_size = batch_size , shuffle = True , drop_last = True ) return train_loader , test_loader class MNIST ( datasets . 
MNIST ): def __init__ ( self , train , normalise = True , save_dir = \"data\" ): if normalise : transform = transforms . Compose ( [ transforms . ToTensor (), transforms . Normalize ( mean = ( 0.1307 ), std = ( 0.3081 ) ) ] ) else : transform = transforms . Compose ([ transforms . ToTensor ()]) super () . __init__ ( save_dir , download = True , train = train , transform = transform ) def __getitem__ ( self , index ): img , label = super () . __getitem__ ( index ) img = torch . flatten ( img ) label = one_hot ( label ) return img , label def one_hot ( labels , n_classes = 10 ): arr = torch . eye ( n_classes ) return arr [ labels ]","title":"Dataset"},{"location":"examples/linear_net_theoretical_activities/#plotting","text":"def plot_layer_energies ( energies ): n_train_iters = energies [ \"theory\" ] . shape [ 0 ] n_energies = energies [ \"theory\" ] . shape [ 1 ] train_iters = [ b + 1 for b in range ( n_train_iters )] colors = [ '#636EFA' , '#EF553B' , '#00CC96' , '#AB63FA' , '#FFA15A' , '#19D3F3' , '#FF6692' , '#B6E880' , '#FF97FF' , '#FECB52' , '#8C564B' ] fig = go . Figure () for n in range ( n_energies ): fig . add_traces ( go . Scatter ( x = train_iters , y = energies [ \"theory\" ][:, n ], mode = \"lines\" , line = dict ( width = 2 , dash = \"dash\" , color = colors [ n ] ), showlegend = False ) ) fig . add_traces ( go . Scatter ( x = train_iters , y = energies [ \"experiment\" ][:, n ], name = f \"$\\Large {{ \\ell_ { n + 1 } }} $\" , mode = \"lines\" , line = dict ( width = 3 , dash = \"solid\" , color = colors [ n ] ), ) ) fig . update_layout ( height = 350 , width = 475 , xaxis = dict ( title = \"Training iteration\" , tickvals = [ 1 , int ( train_iters [ - 1 ] / 2 ), train_iters [ - 1 ]], ticktext = [ 1 , int ( train_iters [ - 1 ] / 2 ), train_iters [ - 1 ]], ), yaxis = dict ( title = \"Energy\" , nticks = 3 ), font = dict ( size = 16 ), ) fig . write_image ( \"dln_layer_energies_example.pdf\" ) return fig","title":"Plotting"},{"location":"examples/linear_net_theoretical_activities/#linear-network","text":"key = jax . random . PRNGKey ( 0 ) subkeys = jax . random . split ( key , 21 ) network = [ eqx . nn . Linear ( 784 , 300 , key = subkeys [ 0 ], use_bias = False ), eqx . nn . Linear ( 300 , 300 , key = subkeys [ 1 ], use_bias = False ), eqx . nn . Linear ( 300 , 300 , key = subkeys [ 2 ], use_bias = False ), eqx . nn . Linear ( 300 , 300 , key = subkeys [ 3 ], use_bias = False ), eqx . nn . Linear ( 300 , 300 , key = subkeys [ 4 ], use_bias = False ), eqx . nn . Linear ( 300 , 300 , key = subkeys [ 5 ], use_bias = False ), eqx . nn . Linear ( 300 , 300 , key = subkeys [ 6 ], use_bias = False ), eqx . nn . Linear ( 300 , 300 , key = subkeys [ 7 ], use_bias = False ), eqx . nn . Linear ( 300 , 300 , key = subkeys [ 8 ], use_bias = False ), eqx . nn . Linear ( 300 , 300 , key = subkeys [ 9 ], use_bias = False ), eqx . nn . Linear ( 300 , 300 , key = subkeys [ 10 ], use_bias = False ), eqx . nn . Linear ( 300 , 300 , key = subkeys [ 11 ], use_bias = False ), eqx . nn . Linear ( 300 , 300 , key = subkeys [ 12 ], use_bias = False ), eqx . nn . Linear ( 300 , 300 , key = subkeys [ 13 ], use_bias = False ), eqx . nn . Linear ( 300 , 300 , key = subkeys [ 14 ], use_bias = False ), eqx . nn . Linear ( 300 , 300 , key = subkeys [ 15 ], use_bias = False ), eqx . nn . Linear ( 300 , 300 , key = subkeys [ 16 ], use_bias = False ), eqx . nn . Linear ( 300 , 300 , key = subkeys [ 17 ], use_bias = False ), eqx . nn . Linear ( 300 , 300 , key = subkeys [ 18 ], use_bias = False ), eqx . nn . 
Linear ( 300 , 300 , key = subkeys [ 19 ], use_bias = False ), eqx . nn . Linear ( 300 , 10 , key = subkeys [ 20 ], use_bias = False ), ]","title":"Linear network"},{"location":"examples/linear_net_theoretical_activities/#train-and-test","text":"A PC network can be trained in a single line of code with jpc.make_pc_step() . See the documentation for more. Similarly, we can use jpc.test_discriminative_pc() to compute the network accuracy. Note that these functions are already \"jitted\" for performance. Below we simply wrap each of these functions in our training and test loops, respectively. def evaluate ( model , test_loader ): test_acc = 0 for batch_id , ( img_batch , label_batch ) in enumerate ( test_loader ): img_batch = img_batch . numpy () label_batch = label_batch . numpy () test_acc += jpc . test_discriminative_pc ( model = model , y = label_batch , x = img_batch ) return test_acc / len ( test_loader ) def train ( model , lr , batch_size , t1 , test_every , n_train_iters ): optim = optax . adam ( lr ) opt_state = optim . init ( eqx . filter ( model , eqx . is_array )) train_loader , test_loader = get_mnist_loaders ( batch_size ) As , cond_numbers = [], [] num_energies , theory_energies = [], [] num_total_energies , theory_total_energies = [], [] for iter , ( img_batch , label_batch ) in enumerate ( train_loader ): img_batch = img_batch . numpy () label_batch = label_batch . numpy () theory_total_energies . append ( jpc . linear_equilib_energy ( network = model , x = img_batch , y = label_batch ) ) A = jpc . linear_activities_coeff_matrix ([ l . weight for l in model ]) As . append ( A ) cond_numbers . append ( jnp . linalg . cond ( A )) theory_activities = jpc . linear_activities_solution ( network = model , x = img_batch , y = label_batch ) theory_energies . append ( jnp . flip ( jpc . pc_energy_fn ( model , theory_activities , x = img_batch , y = label_batch , record_layers = True )) ) result = jpc . make_pc_step ( model , optim , opt_state , y = label_batch , x = img_batch , t1 = t1 , record_energies = True ) model , optim , opt_state = result [ \"model\" ], result [ \"optim\" ], result [ \"opt_state\" ] train_loss , t_max = result [ \"loss\" ], result [ \"t_max\" ] num_total_energies . append ( result [ \"energies\" ][:, t_max - 1 ] . sum ()) num_energies . append ( result [ \"energies\" ][:, t_max - 1 ]) if (( iter + 1 ) % test_every ) == 0 : avg_test_acc = evaluate ( model , test_loader ) print ( f \"Train iter { iter + 1 } , train loss= { train_loss : 4f } , \" f \"avg test accuracy= { avg_test_acc : 4f } \" ) if ( iter + 1 ) >= n_train_iters : break return { \"experiment\" : jnp . array ( num_energies ), \"theory\" : jnp . array ( theory_energies ) }, As , cond_numbers","title":"Train and test"},{"location":"examples/linear_net_theoretical_activities/#run","text":"energies , As , cond_numbers = train ( model = network , lr = LEARNING_RATE , batch_size = BATCH_SIZE , t1 = 300 , test_every = TEST_EVERY , n_train_iters = N_TRAIN_ITERS ) /Users/fi69/PycharmProjects/jpc/venv/lib/python3.10/site-packages/jax/_src/core.py:678: FutureWarning: unhashable type: . Attempting to hash a tracer will lead to an error in a future JAX release. Train iter 10, train loss=0.091816, avg test accuracy=0.270333 Train iter 20, train loss=0.089849, avg test accuracy=0.335337 import matplotlib.pyplot as plt plt . plot ( cond_numbers ) [] n_train_iters = energies [ \"theory\" ] . shape [ 0 ] n_energies = energies [ \"theory\" ] . 
shape [ 1 ] train_iters = [ b + 1 for b in range ( n_train_iters )] colors = [ '#636EFA' , '#EF553B' , '#00CC96' , '#AB63FA' , '#FFA15A' , '#19D3F3' , '#FF6692' , '#B6E880' , '#FF97FF' , '#FECB52' , '#8C564B' , '#636EFA' , '#EF553B' , '#00CC96' , '#AB63FA' , '#FFA15A' , '#19D3F3' , '#FF6692' , '#B6E880' , '#FF97FF' , '#FECB52' , '#8C564B' , '#636EFA' , '#EF553B' , '#00CC96' , '#AB63FA' , '#FFA15A' , '#19D3F3' , '#FF6692' , '#B6E880' , '#FF97FF' , '#FECB52' , '#8C564B' ] fig = go . Figure () for n in range ( n_energies ): fig . add_traces ( go . Scatter ( x = train_iters , y = energies [ \"theory\" ][:, n ], mode = \"lines\" , line = dict ( width = 3 , dash = \"dash\" , color = colors [ n ] ), showlegend = False ) ) fig . add_traces ( go . Scatter ( x = train_iters , y = energies [ \"experiment\" ][:, n ], name = f \"$\\Large {{ \\ell_ {{ { n + 1 } }}}} $\" , mode = \"lines\" , line = dict ( width = 2 , dash = \"solid\" , color = colors [ n ] ), ) ) fig . update_layout ( height = 400 , width = 600 , xaxis = dict ( title = \"Training iteration\" , tickvals = [ 1 , int ( train_iters [ - 1 ] / 2 ), train_iters [ - 1 ]], ticktext = [ 1 , int ( train_iters [ - 1 ] / 2 ), train_iters [ - 1 ]], ), yaxis = dict ( title = \"Energy\" , nticks = 3 ), font = dict ( size = 16 ) )","title":"Run"},{"location":"examples/linear_net_theoretical_energy/","text":"(function() { function addWidgetsRenderer() { var requireJsScript = document.createElement('script'); requireJsScript.src = 'https://cdnjs.cloudflare.com/ajax/libs/require.js/2.3.4/require.min.js'; var mimeElement = document.querySelector('script[type=\"application/vnd.jupyter.widget-view+json\"]'); var jupyterWidgetsScript = document.createElement('script'); var widgetRendererSrc = 'https://unpkg.com/@jupyter-widgets/html-manager@*/dist/embed-amd.js'; var widgetState; // Fallback for older version: try { widgetState = mimeElement && JSON.parse(mimeElement.innerHTML); if (widgetState && (widgetState.version_major < 2 || !widgetState.version_major)) { widgetRendererSrc = 'jupyter-js-widgets@*/dist/embed.js'; } } catch(e) {} jupyterWidgetsScript.src = widgetRendererSrc; document.body.appendChild(requireJsScript); document.body.appendChild(jupyterWidgetsScript); } document.addEventListener('DOMContentLoaded', addWidgetsRenderer); }()); Theoretical energy of deep linear networks \u00a4 %% capture ! pip install torch == 2.3.1 ! pip install torchvision == 0.18.1 ! pip install plotly == 5.11.0 ! pip install - U kaleido import jpc import jax from jax import vmap import jax.numpy as jnp import equinox as eqx import equinox.nn as nn import optax import torch from torch.utils.data import DataLoader from torchvision import datasets , transforms import plotly.graph_objs as go import plotly.io as pio pio . renderers . default = 'iframe' import warnings warnings . simplefilter ( 'ignore' ) # ignore warnings Hyperparameters \u00a4 We define some global parameters, including network architecture, learning rate, batch size etc. SEED = 0 LEARNING_RATE = 1e-3 BATCH_SIZE = 64 MAX_T1 = 300 TEST_EVERY = 10 N_TRAIN_ITERS = 100 Dataset \u00a4 Some utils to fetch MNIST. 
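Both this notebook and the previous one study deep linear networks, and it helps to keep in mind that a bias-free linear stack collapses to a single matrix product, f(x) = W_L ... W_2 W_1 x; the closed-form energy and activity results rely on this linearity. A minimal sketch of that equivalence (a hypothetical three-layer stack, not the exact architecture used below):

```python
import jax
import jax.numpy as jnp
import equinox as eqx

key = jax.random.PRNGKey(0)
k1, k2, k3 = jax.random.split(key, 3)
layers = [
    eqx.nn.Linear(784, 300, key=k1, use_bias=False),
    eqx.nn.Linear(300, 300, key=k2, use_bias=False),
    eqx.nn.Linear(300, 10, key=k3, use_bias=False),
]

x = jnp.ones(784)

# Layer-by-layer forward pass
h = x
for layer in layers:
    h = layer(h)

# Equivalent single matrix product
W_prod = layers[2].weight @ layers[1].weight @ layers[0].weight

print(jnp.allclose(h, W_prod @ x, atol=1e-4))  # True, up to float error
```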
def get_mnist_loaders ( batch_size ): train_data = MNIST ( train = True , normalise = True ) test_data = MNIST ( train = False , normalise = True ) train_loader = DataLoader ( dataset = train_data , batch_size = batch_size , shuffle = True , drop_last = True ) test_loader = DataLoader ( dataset = test_data , batch_size = batch_size , shuffle = True , drop_last = True ) return train_loader , test_loader class MNIST ( datasets . MNIST ): def __init__ ( self , train , normalise = True , save_dir = \"data\" ): if normalise : transform = transforms . Compose ( [ transforms . ToTensor (), transforms . Normalize ( mean = ( 0.1307 ), std = ( 0.3081 ) ) ] ) else : transform = transforms . Compose ([ transforms . ToTensor ()]) super () . __init__ ( save_dir , download = True , train = train , transform = transform ) def __getitem__ ( self , index ): img , label = super () . __getitem__ ( index ) img = torch . flatten ( img ) label = one_hot ( label ) return img , label def one_hot ( labels , n_classes = 10 ): arr = torch . eye ( n_classes ) return arr [ labels ] Plotting \u00a4 def plot_total_energies ( energies ): n_train_iters = len ( energies [ \"theory\" ]) train_iters = [ b + 1 for b in range ( n_train_iters )] fig = go . Figure () for energy_type , energy in energies . items (): is_theory = energy_type == \"theory\" fig . add_traces ( go . Scatter ( x = train_iters , y = energy , name = energy_type , mode = \"lines\" , line = dict ( width = 3 , dash = \"dash\" if is_theory else \"solid\" , color = \"rgb(27, 158, 119)\" if is_theory else \"#00CC96\" ), legendrank = 1 if is_theory else 2 ) ) fig . update_layout ( height = 300 , width = 450 , xaxis = dict ( title = \"Training iteration\" , tickvals = [ 1 , int ( train_iters [ - 1 ] / 2 ), train_iters [ - 1 ]], ticktext = [ 1 , int ( train_iters [ - 1 ] / 2 ), train_iters [ - 1 ]], ), yaxis = dict ( title = \"Energy\" , nticks = 3 ), font = dict ( size = 16 ), ) fig . write_image ( \"dln_total_energy.pdf\" ) return fig Linear network \u00a4 key = jax . random . PRNGKey ( 0 ) width , n_hidden = 300 , 10 network = jpc . make_mlp ( key , [ 784 ] + [ width ] * n_hidden + [ 10 ], act_fn = \"linear\" , use_bias = False ) Train and test \u00a4 A PC network can be trained in a single line of code with jpc.make_pc_step() . See the documentation for more. Similarly, we can use jpc.test_discriminative_pc() to compute the network accuracy. Note that these functions are already \"jitted\" for performance. Below we simply wrap each of these functions in our training and test loops, respectively. def evaluate ( model , test_loader ): avg_test_loss , avg_test_acc = 0 , 0 for batch_id , ( img_batch , label_batch ) in enumerate ( test_loader ): img_batch , label_batch = img_batch . numpy (), label_batch . numpy () test_loss , test_acc = jpc . test_discriminative_pc ( model = model , output = label_batch , input = img_batch ) avg_test_loss += test_loss avg_test_acc += test_acc return avg_test_loss / len ( test_loader ), avg_test_acc / len ( test_loader ) def train ( model , lr , batch_size , max_t1 , test_every , n_train_iters ): optim = optax . adam ( lr ) opt_state = optim . init ( ( eqx . filter ( model , eqx . is_array ), None ) ) train_loader , test_loader = get_mnist_loaders ( batch_size ) num_total_energies , theory_total_energies = [], [] for iter , ( img_batch , label_batch ) in enumerate ( train_loader ): img_batch , label_batch = img_batch . numpy (), label_batch . numpy () theory_total_energies . append ( jpc . 
linear_equilib_energy ( network = model , x = img_batch , y = label_batch ) ) result = jpc . make_pc_step ( model , optim , opt_state , output = label_batch , input = img_batch , max_t1 = max_t1 , record_energies = True ) model , optim , opt_state = result [ \"model\" ], result [ \"optim\" ], result [ \"opt_state\" ] train_loss , t_max = result [ \"loss\" ], result [ \"t_max\" ] num_total_energies . append ( result [ \"energies\" ][:, t_max - 1 ] . sum ()) if (( iter + 1 ) % test_every ) == 0 : avg_test_loss , avg_test_acc = evaluate ( model , test_loader ) print ( f \"Train iter { iter + 1 } , train loss= { train_loss : 4f } , \" f \"avg test accuracy= { avg_test_acc : 4f } \" ) if ( iter + 1 ) >= n_train_iters : break return { \"experiment\" : jnp . array ( num_total_energies ), \"theory\" : jnp . array ( theory_total_energies ) } Run \u00a4 energies = train ( model = network , lr = LEARNING_RATE , batch_size = BATCH_SIZE , test_every = TEST_EVERY , max_t1 = MAX_T1 , n_train_iters = N_TRAIN_ITERS ) plot_total_energies ( energies ) Train iter 10, train loss=0.067622, avg test accuracy=59.535255 Train iter 20, train loss=0.054068, avg test accuracy=75.230370 Train iter 30, train loss=0.063356, avg test accuracy=77.453926 Train iter 40, train loss=0.051848, avg test accuracy=80.048080 Train iter 50, train loss=0.061488, avg test accuracy=82.061295 Train iter 60, train loss=0.044830, avg test accuracy=80.789261 Train iter 70, train loss=0.045716, avg test accuracy=84.174683 Train iter 80, train loss=0.053921, avg test accuracy=82.041267 Train iter 90, train loss=0.040125, avg test accuracy=83.072914 Train iter 100, train loss=0.050980, avg test accuracy=83.974358","title":"Linear theoretical energy"},{"location":"examples/linear_net_theoretical_energy/#theoretical-energy-of-deep-linear-networks","text":"%% capture ! pip install torch == 2.3.1 ! pip install torchvision == 0.18.1 ! pip install plotly == 5.11.0 ! pip install - U kaleido import jpc import jax from jax import vmap import jax.numpy as jnp import equinox as eqx import equinox.nn as nn import optax import torch from torch.utils.data import DataLoader from torchvision import datasets , transforms import plotly.graph_objs as go import plotly.io as pio pio . renderers . default = 'iframe' import warnings warnings . simplefilter ( 'ignore' ) # ignore warnings","title":"Theoretical energy of deep linear networks"},{"location":"examples/linear_net_theoretical_energy/#hyperparameters","text":"We define some global parameters, including network architecture, learning rate, batch size etc. SEED = 0 LEARNING_RATE = 1e-3 BATCH_SIZE = 64 MAX_T1 = 300 TEST_EVERY = 10 N_TRAIN_ITERS = 100","title":"Hyperparameters"},{"location":"examples/linear_net_theoretical_energy/#dataset","text":"Some utils to fetch MNIST. def get_mnist_loaders ( batch_size ): train_data = MNIST ( train = True , normalise = True ) test_data = MNIST ( train = False , normalise = True ) train_loader = DataLoader ( dataset = train_data , batch_size = batch_size , shuffle = True , drop_last = True ) test_loader = DataLoader ( dataset = test_data , batch_size = batch_size , shuffle = True , drop_last = True ) return train_loader , test_loader class MNIST ( datasets . MNIST ): def __init__ ( self , train , normalise = True , save_dir = \"data\" ): if normalise : transform = transforms . Compose ( [ transforms . ToTensor (), transforms . Normalize ( mean = ( 0.1307 ), std = ( 0.3081 ) ) ] ) else : transform = transforms . Compose ([ transforms . ToTensor ()]) super () . 
__init__ ( save_dir , download = True , train = train , transform = transform ) def __getitem__ ( self , index ): img , label = super () . __getitem__ ( index ) img = torch . flatten ( img ) label = one_hot ( label ) return img , label def one_hot ( labels , n_classes = 10 ): arr = torch . eye ( n_classes ) return arr [ labels ]","title":"Dataset"},{"location":"examples/linear_net_theoretical_energy/#plotting","text":"def plot_total_energies ( energies ): n_train_iters = len ( energies [ \"theory\" ]) train_iters = [ b + 1 for b in range ( n_train_iters )] fig = go . Figure () for energy_type , energy in energies . items (): is_theory = energy_type == \"theory\" fig . add_traces ( go . Scatter ( x = train_iters , y = energy , name = energy_type , mode = \"lines\" , line = dict ( width = 3 , dash = \"dash\" if is_theory else \"solid\" , color = \"rgb(27, 158, 119)\" if is_theory else \"#00CC96\" ), legendrank = 1 if is_theory else 2 ) ) fig . update_layout ( height = 300 , width = 450 , xaxis = dict ( title = \"Training iteration\" , tickvals = [ 1 , int ( train_iters [ - 1 ] / 2 ), train_iters [ - 1 ]], ticktext = [ 1 , int ( train_iters [ - 1 ] / 2 ), train_iters [ - 1 ]], ), yaxis = dict ( title = \"Energy\" , nticks = 3 ), font = dict ( size = 16 ), ) fig . write_image ( \"dln_total_energy.pdf\" ) return fig","title":"Plotting"},{"location":"examples/linear_net_theoretical_energy/#linear-network","text":"key = jax . random . PRNGKey ( 0 ) width , n_hidden = 300 , 10 network = jpc . make_mlp ( key , [ 784 ] + [ width ] * n_hidden + [ 10 ], act_fn = \"linear\" , use_bias = False )","title":"Linear network"},{"location":"examples/linear_net_theoretical_energy/#train-and-test","text":"A PC network can be trained in a single line of code with jpc.make_pc_step() . See the documentation for more. Similarly, we can use jpc.test_discriminative_pc() to compute the network accuracy. Note that these functions are already \"jitted\" for performance. Below we simply wrap each of these functions in our training and test loops, respectively. def evaluate ( model , test_loader ): avg_test_loss , avg_test_acc = 0 , 0 for batch_id , ( img_batch , label_batch ) in enumerate ( test_loader ): img_batch , label_batch = img_batch . numpy (), label_batch . numpy () test_loss , test_acc = jpc . test_discriminative_pc ( model = model , output = label_batch , input = img_batch ) avg_test_loss += test_loss avg_test_acc += test_acc return avg_test_loss / len ( test_loader ), avg_test_acc / len ( test_loader ) def train ( model , lr , batch_size , max_t1 , test_every , n_train_iters ): optim = optax . adam ( lr ) opt_state = optim . init ( ( eqx . filter ( model , eqx . is_array ), None ) ) train_loader , test_loader = get_mnist_loaders ( batch_size ) num_total_energies , theory_total_energies = [], [] for iter , ( img_batch , label_batch ) in enumerate ( train_loader ): img_batch , label_batch = img_batch . numpy (), label_batch . numpy () theory_total_energies . append ( jpc . linear_equilib_energy ( network = model , x = img_batch , y = label_batch ) ) result = jpc . make_pc_step ( model , optim , opt_state , output = label_batch , input = img_batch , max_t1 = max_t1 , record_energies = True ) model , optim , opt_state = result [ \"model\" ], result [ \"optim\" ], result [ \"opt_state\" ] train_loss , t_max = result [ \"loss\" ], result [ \"t_max\" ] num_total_energies . append ( result [ \"energies\" ][:, t_max - 1 ] . 
sum ()) if (( iter + 1 ) % test_every ) == 0 : avg_test_loss , avg_test_acc = evaluate ( model , test_loader ) print ( f \"Train iter { iter + 1 } , train loss= { train_loss : 4f } , \" f \"avg test accuracy= { avg_test_acc : 4f } \" ) if ( iter + 1 ) >= n_train_iters : break return { \"experiment\" : jnp . array ( num_total_energies ), \"theory\" : jnp . array ( theory_total_energies ) }","title":"Train and test"},{"location":"examples/linear_net_theoretical_energy/#run","text":"energies = train ( model = network , lr = LEARNING_RATE , batch_size = BATCH_SIZE , test_every = TEST_EVERY , max_t1 = MAX_T1 , n_train_iters = N_TRAIN_ITERS ) plot_total_energies ( energies ) Train iter 10, train loss=0.067622, avg test accuracy=59.535255 Train iter 20, train loss=0.054068, avg test accuracy=75.230370 Train iter 30, train loss=0.063356, avg test accuracy=77.453926 Train iter 40, train loss=0.051848, avg test accuracy=80.048080 Train iter 50, train loss=0.061488, avg test accuracy=82.061295 Train iter 60, train loss=0.044830, avg test accuracy=80.789261 Train iter 70, train loss=0.045716, avg test accuracy=84.174683 Train iter 80, train loss=0.053921, avg test accuracy=82.041267 Train iter 90, train loss=0.040125, avg test accuracy=83.072914 Train iter 100, train loss=0.050980, avg test accuracy=83.974358","title":"Run"},{"location":"examples/supervised_generative_pc/","text":"(function() { function addWidgetsRenderer() { var requireJsScript = document.createElement('script'); requireJsScript.src = 'https://cdnjs.cloudflare.com/ajax/libs/require.js/2.3.4/require.min.js'; var mimeElement = document.querySelector('script[type=\"application/vnd.jupyter.widget-view+json\"]'); var jupyterWidgetsScript = document.createElement('script'); var widgetRendererSrc = 'https://unpkg.com/@jupyter-widgets/html-manager@*/dist/embed-amd.js'; var widgetState; // Fallback for older version: try { widgetState = mimeElement && JSON.parse(mimeElement.innerHTML); if (widgetState && (widgetState.version_major < 2 || !widgetState.version_major)) { widgetRendererSrc = 'jupyter-js-widgets@*/dist/embed.js'; } } catch(e) {} jupyterWidgetsScript.src = widgetRendererSrc; document.body.appendChild(requireJsScript); document.body.appendChild(jupyterWidgetsScript); } document.addEventListener('DOMContentLoaded', addWidgetsRenderer); }()); Generative PC on MNIST \u00a4 This notebook demonstrates how to train a neural network with predictive coding to generate MNIST digits. %% capture ! pip install torch == 2.3.1 ! pip install torchvision == 0.18.1 ! pip install matplotlib == 3.0.0 import jpc import jax import jax.numpy as jnp import equinox as eqx import equinox.nn as nn from diffrax import Heun import optax import torch from torch.utils.data import DataLoader from torchvision import datasets , transforms import matplotlib.pyplot as plt import warnings warnings . simplefilter ( 'ignore' ) # ignore warnings Hyperparameters \u00a4 We define some global parameters, including network architecture, learning rate, batch size etc. SEED = 0 LAYER_SIZES = [ 10 , 300 , 300 , 784 ] ACT_FN = \"relu\" LEARNING_RATE = 1e-3 BATCH_SIZE = 64 MAX_T1 = 100 TEST_EVERY = 50 N_TRAIN_ITERS = 200 Dataset \u00a4 Some utils to fetch and plot MNIST. 
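Note the direction of LAYER_SIZES above: the generative network runs from the 10-dimensional label to the 784-dimensional image, the reverse of a discriminative classifier. To make the shapes concrete, here is a small sketch of a plain feedforward pass through an untrained network built with jpc.make_mlp as used later in this notebook (the model instance is hypothetical, only the shapes matter):

```python
import jax
import jax.numpy as jnp
import jpc

key = jax.random.PRNGKey(0)
network = jpc.make_mlp(key, [10, 300, 300, 784], act_fn="relu")

# A batch of one-hot labels is the input to the generator
labels = jnp.zeros((4, 10)).at[jnp.arange(4), jnp.arange(4)].set(1.0)

# Feedforward pass through the callable layers (vmap over the batch dimension)
preds = labels
for layer in network:
    preds = jax.vmap(layer)(preds)

print(preds.shape)  # (4, 784): one predicted flattened image per label
```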
#@title data utils def get_mnist_loaders ( batch_size ): train_data = MNIST ( train = True , normalise = True ) test_data = MNIST ( train = False , normalise = True ) train_loader = DataLoader ( dataset = train_data , batch_size = batch_size , shuffle = True , drop_last = True ) test_loader = DataLoader ( dataset = test_data , batch_size = batch_size , shuffle = True , drop_last = True ) return train_loader , test_loader class MNIST ( datasets . MNIST ): def __init__ ( self , train , normalise = True , save_dir = \"data\" ): if normalise : transform = transforms . Compose ( [ transforms . ToTensor (), transforms . Normalize ( mean = ( 0.1307 ), std = ( 0.3081 ) ) ] ) else : transform = transforms . Compose ([ transforms . ToTensor ()]) super () . __init__ ( save_dir , download = True , train = train , transform = transform ) def __getitem__ ( self , index ): img , label = super () . __getitem__ ( index ) img = torch . flatten ( img ) label = one_hot ( label ) return img , label def one_hot ( labels , n_classes = 10 ): arr = torch . eye ( n_classes ) return arr [ labels ] def plot_mnist_img_preds ( imgs , labels , n_imgs = 10 ): plt . figure ( figsize = ( 20 , 2 )) for i in range ( n_imgs ): plt . subplot ( 1 , n_imgs , i + 1 ) plt . xticks ([]) plt . yticks ([]) plt . grid ( False ) plt . imshow ( imgs [ i ] . reshape ( 28 , 28 ), cmap = plt . cm . binary_r ) plt . xlabel ( jnp . argmax ( labels , axis = 1 )[ i ], fontsize = 16 ) plt . show () Network \u00a4 For jpc to work, we need to provide a network with callable layers. This is easy to do with the PyTorch-like nn.Sequential() in Equinox . For example, we can define a ReLU MLP with two hidden layers as follows key = jax . random . PRNGKey ( SEED ) key , * subkeys = jax . random . split ( key , 4 ) network = [ nn . Sequential ( [ nn . Linear ( 10 , 300 , key = subkeys [ 0 ]), nn . Lambda ( jax . nn . relu ) ], ), nn . Sequential ( [ nn . Linear ( 300 , 300 , key = subkeys [ 1 ]), nn . Lambda ( jax . nn . relu ) ], ), nn . Linear ( 300 , 784 , key = subkeys [ 2 ]), ] print ( network ) [Sequential( layers=( Linear( weight=f32[300,10], bias=f32[300], in_features=10, out_features=300, use_bias=True ), Lambda(fn=) ) ), Sequential( layers=( Linear( weight=f32[300,300], bias=f32[300], in_features=300, out_features=300, use_bias=True ), Lambda(fn=) ) ), Linear( weight=f32[784,300], bias=f32[784], in_features=300, out_features=784, use_bias=True )] You can also use the utility jpc.get_fc_network to define an MLP or fully connected network with some activation functions. network = jpc . make_mlp ( key , LAYER_SIZES , act_fn = \"relu\" ) print ( network ) [Sequential( layers=( Linear( weight=f32[300,10], bias=f32[300], in_features=10, out_features=300, use_bias=True ), Lambda(fn=) ) ), Sequential( layers=( Linear( weight=f32[300,300], bias=f32[300], in_features=300, out_features=300, use_bias=True ), Lambda(fn=) ) ), Sequential( layers=( Linear( weight=f32[784,300], bias=f32[784], in_features=300, out_features=784, use_bias=True ), Lambda(fn=Identity()) ) )] Train and test \u00a4 A PC network can be trained in a single line of code with jpc.make_pc_step() . See the documentation for more. Similarly, we can use jpc.test_generative_pc() to get some test metrics including accuracy of inferred labels and image predictions. Note that these functions are already \"jitted\" for performance. Below we simply wrap each of these functions in our training and test loops, respectively. 
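To make the one-line-of-code claim concrete, the snippet below sketches a single parameter update, assuming network, label_batch and img_batch (as NumPy arrays) are in scope as in the rest of this notebook; the training loop that follows wraps exactly this kind of call.

```python
import equinox as eqx
import optax
import jpc

optim = optax.adam(1e-3)
# Optimiser state is initialised over (parameters, None), following the loop below
opt_state = optim.init((eqx.filter(network, eqx.is_array), None))

result = jpc.make_pc_step(
    model=network,
    optim=optim,
    opt_state=opt_state,
    input=label_batch,   # one-hot labels clamp the top of the generator
    output=img_batch,    # flattened images clamp the bottom
    max_t1=100,          # upper limit on the inference dynamics, as used below
)
network, opt_state = result["model"], result["opt_state"]
print(result["loss"])
```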
Note that to train in an unsupervised way, you can simply remove the input from jpc.make_pc_step() and the evaluate() script. def evaluate ( key , layer_sizes , batch_size , network , test_loader , max_t1 ): test_acc = 0 for batch_id , ( img_batch , label_batch ) in enumerate ( test_loader ): img_batch , label_batch = img_batch . numpy (), label_batch . numpy () acc , img_preds = jpc . test_generative_pc ( model = network , input = label_batch , output = img_batch , key = key , layer_sizes = layer_sizes , batch_size = batch_size , max_t1 = max_t1 ) test_acc += acc avg_test_acc = test_acc / len ( test_loader ) return avg_test_acc , label_batch , img_preds def train ( key , layer_sizes , batch_size , network , lr , max_t1 , test_every , n_train_iters ): optim = optax . adam ( lr ) opt_state = optim . init ( ( eqx . filter ( network , eqx . is_array ), None ) ) train_loader , test_loader = get_mnist_loaders ( batch_size ) for iter , ( img_batch , label_batch ) in enumerate ( train_loader ): img_batch , label_batch = img_batch . numpy (), label_batch . numpy () result = jpc . make_pc_step ( model = network , optim = optim , opt_state = opt_state , input = label_batch , output = img_batch , max_t1 = max_t1 ) network , optim , opt_state = result [ \"model\" ], result [ \"optim\" ], result [ \"opt_state\" ] train_loss = result [ \"loss\" ] if (( iter + 1 ) % test_every ) == 0 : avg_test_acc , test_label_batch , img_preds = evaluate ( key , layer_sizes , batch_size , network , test_loader , max_t1 = max_t1 ) print ( f \"Train iter { iter + 1 } , train loss= { train_loss : 4f } , \" f \"avg test accuracy= { avg_test_acc : 4f } \" ) if ( iter + 1 ) >= n_train_iters : break plot_mnist_img_preds ( img_preds , test_label_batch ) return network Run \u00a4 network = train ( key = key , layer_sizes = LAYER_SIZES , batch_size = BATCH_SIZE , network = network , lr = LEARNING_RATE , max_t1 = MAX_T1 , test_every = TEST_EVERY , n_train_iters = N_TRAIN_ITERS ) Train iter 50, train loss=0.631369, avg test accuracy=74.959938 Train iter 100, train loss=0.607500, avg test accuracy=79.206734 Train iter 150, train loss=0.577637, avg test accuracy=80.418671 Train iter 200, train loss=0.555235, avg test accuracy=79.236778","title":"Supervised generative PC"},{"location":"examples/supervised_generative_pc/#generative-pc-on-mnist","text":"This notebook demonstrates how to train a neural network with predictive coding to generate MNIST digits. %% capture ! pip install torch == 2.3.1 ! pip install torchvision == 0.18.1 ! pip install matplotlib == 3.0.0 import jpc import jax import jax.numpy as jnp import equinox as eqx import equinox.nn as nn from diffrax import Heun import optax import torch from torch.utils.data import DataLoader from torchvision import datasets , transforms import matplotlib.pyplot as plt import warnings warnings . simplefilter ( 'ignore' ) # ignore warnings","title":"Generative PC on MNIST"},{"location":"examples/supervised_generative_pc/#hyperparameters","text":"We define some global parameters, including network architecture, learning rate, batch size etc. SEED = 0 LAYER_SIZES = [ 10 , 300 , 300 , 784 ] ACT_FN = \"relu\" LEARNING_RATE = 1e-3 BATCH_SIZE = 64 MAX_T1 = 100 TEST_EVERY = 50 N_TRAIN_ITERS = 200","title":"Hyperparameters"},{"location":"examples/supervised_generative_pc/#dataset","text":"Some utils to fetch and plot MNIST. 
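As in the previous notebooks, the wrapper below flattens each 28x28 image to a 784-vector and converts the integer class to a one-hot vector; a quick illustration of that encoding:

```python
import torch

# Same construction as the one_hot helper below: row `label` of the identity matrix
print(torch.eye(10)[3])
# tensor([0., 0., 0., 1., 0., 0., 0., 0., 0., 0.])

# Each image is flattened before being handed to the network
img = torch.rand(1, 28, 28)
print(torch.flatten(img).shape)  # torch.Size([784])
```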
#@title data utils def get_mnist_loaders ( batch_size ): train_data = MNIST ( train = True , normalise = True ) test_data = MNIST ( train = False , normalise = True ) train_loader = DataLoader ( dataset = train_data , batch_size = batch_size , shuffle = True , drop_last = True ) test_loader = DataLoader ( dataset = test_data , batch_size = batch_size , shuffle = True , drop_last = True ) return train_loader , test_loader class MNIST ( datasets . MNIST ): def __init__ ( self , train , normalise = True , save_dir = \"data\" ): if normalise : transform = transforms . Compose ( [ transforms . ToTensor (), transforms . Normalize ( mean = ( 0.1307 ), std = ( 0.3081 ) ) ] ) else : transform = transforms . Compose ([ transforms . ToTensor ()]) super () . __init__ ( save_dir , download = True , train = train , transform = transform ) def __getitem__ ( self , index ): img , label = super () . __getitem__ ( index ) img = torch . flatten ( img ) label = one_hot ( label ) return img , label def one_hot ( labels , n_classes = 10 ): arr = torch . eye ( n_classes ) return arr [ labels ] def plot_mnist_img_preds ( imgs , labels , n_imgs = 10 ): plt . figure ( figsize = ( 20 , 2 )) for i in range ( n_imgs ): plt . subplot ( 1 , n_imgs , i + 1 ) plt . xticks ([]) plt . yticks ([]) plt . grid ( False ) plt . imshow ( imgs [ i ] . reshape ( 28 , 28 ), cmap = plt . cm . binary_r ) plt . xlabel ( jnp . argmax ( labels , axis = 1 )[ i ], fontsize = 16 ) plt . show ()","title":"Dataset"},{"location":"examples/supervised_generative_pc/#network","text":"For jpc to work, we need to provide a network with callable layers. This is easy to do with the PyTorch-like nn.Sequential() in Equinox . For example, we can define a ReLU MLP with two hidden layers as follows key = jax . random . PRNGKey ( SEED ) key , * subkeys = jax . random . split ( key , 4 ) network = [ nn . Sequential ( [ nn . Linear ( 10 , 300 , key = subkeys [ 0 ]), nn . Lambda ( jax . nn . relu ) ], ), nn . Sequential ( [ nn . Linear ( 300 , 300 , key = subkeys [ 1 ]), nn . Lambda ( jax . nn . relu ) ], ), nn . Linear ( 300 , 784 , key = subkeys [ 2 ]), ] print ( network ) [Sequential( layers=( Linear( weight=f32[300,10], bias=f32[300], in_features=10, out_features=300, use_bias=True ), Lambda(fn=) ) ), Sequential( layers=( Linear( weight=f32[300,300], bias=f32[300], in_features=300, out_features=300, use_bias=True ), Lambda(fn=) ) ), Linear( weight=f32[784,300], bias=f32[784], in_features=300, out_features=784, use_bias=True )] You can also use the utility jpc.get_fc_network to define an MLP or fully connected network with some activation functions. network = jpc . make_mlp ( key , LAYER_SIZES , act_fn = \"relu\" ) print ( network ) [Sequential( layers=( Linear( weight=f32[300,10], bias=f32[300], in_features=10, out_features=300, use_bias=True ), Lambda(fn=) ) ), Sequential( layers=( Linear( weight=f32[300,300], bias=f32[300], in_features=300, out_features=300, use_bias=True ), Lambda(fn=) ) ), Sequential( layers=( Linear( weight=f32[784,300], bias=f32[784], in_features=300, out_features=784, use_bias=True ), Lambda(fn=Identity()) ) )]","title":"Network"},{"location":"examples/supervised_generative_pc/#train-and-test","text":"A PC network can be trained in a single line of code with jpc.make_pc_step() . See the documentation for more. Similarly, we can use jpc.test_generative_pc() to get some test metrics including accuracy of inferred labels and image predictions. Note that these functions are already \"jitted\" for performance. 
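The test side is just as compact; the following is a condensed sketch of what evaluate() below does for a single test batch, assuming network, key, label_batch and img_batch are in scope as elsewhere in this notebook:

```python
import jpc

# Accuracy of the labels inferred from the images, plus the generated image predictions
acc, img_preds = jpc.test_generative_pc(
    model=network,
    input=label_batch,
    output=img_batch,
    key=key,
    layer_sizes=[10, 300, 300, 784],
    batch_size=64,
    max_t1=100,
)
print(acc)
# img_preds can be passed straight to plot_mnist_img_preds for visual inspection
```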
Below we simply wrap each of these functions in our training and test loops, respectively. Note that to train in an unsupervised way, you can simply remove the input from jpc.make_pc_step() and the evaluate() script. def evaluate ( key , layer_sizes , batch_size , network , test_loader , max_t1 ): test_acc = 0 for batch_id , ( img_batch , label_batch ) in enumerate ( test_loader ): img_batch , label_batch = img_batch . numpy (), label_batch . numpy () acc , img_preds = jpc . test_generative_pc ( model = network , input = label_batch , output = img_batch , key = key , layer_sizes = layer_sizes , batch_size = batch_size , max_t1 = max_t1 ) test_acc += acc avg_test_acc = test_acc / len ( test_loader ) return avg_test_acc , label_batch , img_preds def train ( key , layer_sizes , batch_size , network , lr , max_t1 , test_every , n_train_iters ): optim = optax . adam ( lr ) opt_state = optim . init ( ( eqx . filter ( network , eqx . is_array ), None ) ) train_loader , test_loader = get_mnist_loaders ( batch_size ) for iter , ( img_batch , label_batch ) in enumerate ( train_loader ): img_batch , label_batch = img_batch . numpy (), label_batch . numpy () result = jpc . make_pc_step ( model = network , optim = optim , opt_state = opt_state , input = label_batch , output = img_batch , max_t1 = max_t1 ) network , optim , opt_state = result [ \"model\" ], result [ \"optim\" ], result [ \"opt_state\" ] train_loss = result [ \"loss\" ] if (( iter + 1 ) % test_every ) == 0 : avg_test_acc , test_label_batch , img_preds = evaluate ( key , layer_sizes , batch_size , network , test_loader , max_t1 = max_t1 ) print ( f \"Train iter { iter + 1 } , train loss= { train_loss : 4f } , \" f \"avg test accuracy= { avg_test_acc : 4f } \" ) if ( iter + 1 ) >= n_train_iters : break plot_mnist_img_preds ( img_preds , test_label_batch ) return network","title":"Train and test"},{"location":"examples/supervised_generative_pc/#run","text":"network = train ( key = key , layer_sizes = LAYER_SIZES , batch_size = BATCH_SIZE , network = network , lr = LEARNING_RATE , max_t1 = MAX_T1 , test_every = TEST_EVERY , n_train_iters = N_TRAIN_ITERS ) Train iter 50, train loss=0.631369, avg test accuracy=74.959938 Train iter 100, train loss=0.607500, avg test accuracy=79.206734 Train iter 150, train loss=0.577637, avg test accuracy=80.418671 Train iter 200, train loss=0.555235, avg test accuracy=79.236778","title":"Run"},{"location":"examples/unsupervised_generative_pc/","text":"(function() { function addWidgetsRenderer() { var requireJsScript = document.createElement('script'); requireJsScript.src = 'https://cdnjs.cloudflare.com/ajax/libs/require.js/2.3.4/require.min.js'; var mimeElement = document.querySelector('script[type=\"application/vnd.jupyter.widget-view+json\"]'); var jupyterWidgetsScript = document.createElement('script'); var widgetRendererSrc = 'https://unpkg.com/@jupyter-widgets/html-manager@*/dist/embed-amd.js'; var widgetState; // Fallback for older version: try { widgetState = mimeElement && JSON.parse(mimeElement.innerHTML); if (widgetState && (widgetState.version_major < 2 || !widgetState.version_major)) { widgetRendererSrc = 'jupyter-js-widgets@*/dist/embed.js'; } } catch(e) {} jupyterWidgetsScript.src = widgetRendererSrc; document.body.appendChild(requireJsScript); document.body.appendChild(jupyterWidgetsScript); } document.addEventListener('DOMContentLoaded', addWidgetsRenderer); }()); Unsupervised generative PC on MNIST \u00a4 This notebook demonstrates how to train a neural network with predictive coding to encode 
MNIST digits in an unsupervised manner. %% capture ! pip install torch == 2.3.1 ! pip install torchvision == 0.18.1 ! pip install matplotlib == 3.0.0 import jpc import jax import jax.numpy as jnp import equinox as eqx import equinox.nn as nn import optax import torch from torch.utils.data import DataLoader from torchvision import datasets , transforms import matplotlib.pyplot as plt import matplotlib.colors as mcolors import warnings warnings . simplefilter ( 'ignore' ) # ignore warnings Hyperparameters \u00a4 We define some global parameters, including network architecture, learning rate, batch size etc. SEED = 0 LAYER_SIZES = [ 50 , 100 , 100 , 784 ] ACT_FN = \"relu\" LEARNING_RATE = 1e-3 BATCH_SIZE = 64 MAX_T1 = 100 N_TRAIN_ITERS = 300 Dataset \u00a4 Some utils to fetch and plot MNIST. #@title data utils def get_mnist_loaders ( batch_size ): train_data = MNIST ( train = True , normalise = True ) test_data = MNIST ( train = False , normalise = True ) train_loader = DataLoader ( dataset = train_data , batch_size = batch_size , shuffle = True , drop_last = True ) test_loader = DataLoader ( dataset = test_data , batch_size = batch_size , shuffle = True , drop_last = True ) return train_loader , test_loader class MNIST ( datasets . MNIST ): def __init__ ( self , train , normalise = True , save_dir = \"data\" ): if normalise : transform = transforms . Compose ( [ transforms . ToTensor (), transforms . Normalize ( mean = ( 0.1307 ), std = ( 0.3081 ) ) ] ) else : transform = transforms . Compose ([ transforms . ToTensor ()]) super () . __init__ ( save_dir , download = True , train = train , transform = transform ) def __getitem__ ( self , index ): img , _ = super () . __getitem__ ( index ) img = torch . flatten ( img ) return img Plotting \u00a4 def plot_train_energies ( train_energies , ts ): t_max = int ( ts [ 0 ]) norm = mcolors . Normalize ( vmin = 0 , vmax = len ( energies ) - 1 ) fig , ax = plt . subplots ( figsize = ( 8 , 4 )) cmap_blues = plt . get_cmap ( \"Blues\" ) cmap_reds = plt . get_cmap ( \"Reds\" ) cmap_greens = plt . get_cmap ( \"Greens\" ) legend_handles = [] legend_labels = [] for t , energies_iter in enumerate ( energies ): line1 , = ax . plot ( energies_iter [ 0 , : t_max ], color = cmap_blues ( norm ( t ))) line2 , = ax . plot ( energies_iter [ 1 , : t_max ], color = cmap_reds ( norm ( t ))) line3 , = ax . plot ( energies_iter [ 2 , : t_max ], color = cmap_greens ( norm ( t ))) if t == 70 : legend_handles . append ( line1 ) legend_labels . append ( \"$\\ell_1$\" ) legend_handles . append ( line2 ) legend_labels . append ( \"$\\ell_2$\" ) legend_handles . append ( line3 ) legend_labels . append ( \"$\\ell_3$\" ) ax . legend ( legend_handles , legend_labels , loc = \"best\" , fontsize = 16 ) sm = plt . cm . ScalarMappable ( cmap = plt . get_cmap ( \"Greys\" ), norm = norm ) sm . _A = [] cbar = fig . colorbar ( sm , ax = ax ) cbar . set_label ( \"Training iteration\" , fontsize = 16 , labelpad = 14 ) cbar . ax . tick_params ( labelsize = 14 ) plt . gca () . tick_params ( axis = \"both\" , which = \"major\" , labelsize = 16 ) ax . set_xlabel ( \"Inference iterations\" , fontsize = 18 , labelpad = 14 ) ax . set_ylabel ( \"Energy\" , fontsize = 18 , labelpad = 14 ) ax . set_yscale ( \"log\" ) plt . show () Network \u00a4 For jpc to work, we need to provide a network with callable layers. This is easy to do with the PyTorch-like nn.Sequential() in Equinox . For example, we can define a ReLU MLP with two hidden layers as follows key = jax . random . 
PRNGKey ( SEED ) key , * subkeys = jax . random . split ( key , 4 ) network = [ nn . Sequential ( [ nn . Linear ( 10 , 300 , key = subkeys [ 0 ]), nn . Lambda ( jax . nn . relu ) ], ), nn . Sequential ( [ nn . Linear ( 300 , 300 , key = subkeys [ 1 ]), nn . Lambda ( jax . nn . relu ) ], ), nn . Linear ( 300 , 784 , key = subkeys [ 2 ]), ] print ( network ) [Sequential( layers=( Linear( weight=f32[300,10], bias=f32[300], in_features=10, out_features=300, use_bias=True ), Lambda(fn=) ) ), Sequential( layers=( Linear( weight=f32[300,300], bias=f32[300], in_features=300, out_features=300, use_bias=True ), Lambda(fn=) ) ), Linear( weight=f32[784,300], bias=f32[784], in_features=300, out_features=784, use_bias=True )] You can also use the utility jpc.get_fc_network to define an MLP or fully connected network with some activation functions. network = jpc . make_mlp ( key , LAYER_SIZES , act_fn = \"relu\" ) print ( network ) [Sequential( layers=( Linear( weight=f32[100,50], bias=f32[100], in_features=50, out_features=100, use_bias=True ), Lambda(fn=) ) ), Sequential( layers=( Linear( weight=f32[100,100], bias=f32[100], in_features=100, out_features=100, use_bias=True ), Lambda(fn=) ) ), Sequential( layers=( Linear( weight=f32[784,100], bias=f32[784], in_features=100, out_features=784, use_bias=True ), Lambda(fn=Identity()) ) )] Train \u00a4 def train ( key , layer_sizes , batch_size , network , lr , max_t1 , n_train_iters ): optim = optax . adam ( lr ) opt_state = optim . init ( ( eqx . filter ( network , eqx . is_array ), None ) ) train_loader , test_loader = get_mnist_loaders ( batch_size ) train_energies , ts = [], [] for iter , img_batch in enumerate ( train_loader ): img_batch = img_batch . numpy () result = jpc . make_pc_step ( key = key , layer_sizes = layer_sizes , batch_size = batch_size , model = network , optim = optim , opt_state = opt_state , output = img_batch , max_t1 = max_t1 , record_activities = True , record_energies = True ) network , optim , opt_state = result [ \"model\" ], result [ \"optim\" ], result [ \"opt_state\" ] train_energies . append ( result [ \"energies\" ]) ts . append ( result [ \"t_max\" ]) if ( iter + 1 ) >= n_train_iters : break return result [ \"model\" ], train_energies , ts Run \u00a4 network , energies , ts = train ( key = key , layer_sizes = LAYER_SIZES , batch_size = BATCH_SIZE , network = network , lr = LEARNING_RATE , max_t1 = MAX_T1 , n_train_iters = N_TRAIN_ITERS ) plot_train_energies ( energies , ts )","title":"Unsupervised generative PC"},{"location":"examples/unsupervised_generative_pc/#unsupervised-generative-pc-on-mnist","text":"This notebook demonstrates how to train a neural network with predictive coding to encode MNIST digits in an unsupervised manner. %% capture ! pip install torch == 2.3.1 ! pip install torchvision == 0.18.1 ! pip install matplotlib == 3.0.0 import jpc import jax import jax.numpy as jnp import equinox as eqx import equinox.nn as nn import optax import torch from torch.utils.data import DataLoader from torchvision import datasets , transforms import matplotlib.pyplot as plt import matplotlib.colors as mcolors import warnings warnings . simplefilter ( 'ignore' ) # ignore warnings","title":"Unsupervised generative PC on MNIST"},{"location":"examples/unsupervised_generative_pc/#hyperparameters","text":"We define some global parameters, including network architecture, learning rate, batch size etc. 
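One detail worth highlighting from the training loop above: jpc.make_pc_step can record the activities and per-layer energies of the inference dynamics, which is exactly what plot_train_energies visualises. A condensed sketch of a single unsupervised step, assuming key, network, optim, opt_state and a NumPy img_batch are in scope as in this notebook:

```python
import jpc

result = jpc.make_pc_step(
    key=key,
    layer_sizes=[50, 100, 100, 784],
    batch_size=64,
    model=network,
    optim=optim,
    opt_state=opt_state,
    output=img_batch,        # no input: unsupervised, so the top-layer activities are inferred
    max_t1=100,
    record_activities=True,
    record_energies=True,
)
energies, t_max = result["energies"], result["t_max"]
# energies is indexed as [layer, inference step] in the loop above (shape assumed)
print(t_max)
```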
SEED = 0 LAYER_SIZES = [ 50 , 100 , 100 , 784 ] ACT_FN = \"relu\" LEARNING_RATE = 1e-3 BATCH_SIZE = 64 MAX_T1 = 100 N_TRAIN_ITERS = 300","title":"Hyperparameters"},{"location":"examples/unsupervised_generative_pc/#dataset","text":"Some utils to fetch and plot MNIST. #@title data utils def get_mnist_loaders ( batch_size ): train_data = MNIST ( train = True , normalise = True ) test_data = MNIST ( train = False , normalise = True ) train_loader = DataLoader ( dataset = train_data , batch_size = batch_size , shuffle = True , drop_last = True ) test_loader = DataLoader ( dataset = test_data , batch_size = batch_size , shuffle = True , drop_last = True ) return train_loader , test_loader class MNIST ( datasets . MNIST ): def __init__ ( self , train , normalise = True , save_dir = \"data\" ): if normalise : transform = transforms . Compose ( [ transforms . ToTensor (), transforms . Normalize ( mean = ( 0.1307 ), std = ( 0.3081 ) ) ] ) else : transform = transforms . Compose ([ transforms . ToTensor ()]) super () . __init__ ( save_dir , download = True , train = train , transform = transform ) def __getitem__ ( self , index ): img , _ = super () . __getitem__ ( index ) img = torch . flatten ( img ) return img","title":"Dataset"},{"location":"examples/unsupervised_generative_pc/#plotting","text":"def plot_train_energies ( train_energies , ts ): t_max = int ( ts [ 0 ]) norm = mcolors . Normalize ( vmin = 0 , vmax = len ( energies ) - 1 ) fig , ax = plt . subplots ( figsize = ( 8 , 4 )) cmap_blues = plt . get_cmap ( \"Blues\" ) cmap_reds = plt . get_cmap ( \"Reds\" ) cmap_greens = plt . get_cmap ( \"Greens\" ) legend_handles = [] legend_labels = [] for t , energies_iter in enumerate ( energies ): line1 , = ax . plot ( energies_iter [ 0 , : t_max ], color = cmap_blues ( norm ( t ))) line2 , = ax . plot ( energies_iter [ 1 , : t_max ], color = cmap_reds ( norm ( t ))) line3 , = ax . plot ( energies_iter [ 2 , : t_max ], color = cmap_greens ( norm ( t ))) if t == 70 : legend_handles . append ( line1 ) legend_labels . append ( \"$\\ell_1$\" ) legend_handles . append ( line2 ) legend_labels . append ( \"$\\ell_2$\" ) legend_handles . append ( line3 ) legend_labels . append ( \"$\\ell_3$\" ) ax . legend ( legend_handles , legend_labels , loc = \"best\" , fontsize = 16 ) sm = plt . cm . ScalarMappable ( cmap = plt . get_cmap ( \"Greys\" ), norm = norm ) sm . _A = [] cbar = fig . colorbar ( sm , ax = ax ) cbar . set_label ( \"Training iteration\" , fontsize = 16 , labelpad = 14 ) cbar . ax . tick_params ( labelsize = 14 ) plt . gca () . tick_params ( axis = \"both\" , which = \"major\" , labelsize = 16 ) ax . set_xlabel ( \"Inference iterations\" , fontsize = 18 , labelpad = 14 ) ax . set_ylabel ( \"Energy\" , fontsize = 18 , labelpad = 14 ) ax . set_yscale ( \"log\" ) plt . show ()","title":"Plotting"},{"location":"examples/unsupervised_generative_pc/#network","text":"For jpc to work, we need to provide a network with callable layers. This is easy to do with the PyTorch-like nn.Sequential() in Equinox . For example, we can define a ReLU MLP with two hidden layers as follows key = jax . random . PRNGKey ( SEED ) key , * subkeys = jax . random . split ( key , 4 ) network = [ nn . Sequential ( [ nn . Linear ( 10 , 300 , key = subkeys [ 0 ]), nn . Lambda ( jax . nn . relu ) ], ), nn . Sequential ( [ nn . Linear ( 300 , 300 , key = subkeys [ 1 ]), nn . Lambda ( jax . nn . relu ) ], ), nn . 
Linear ( 300 , 784 , key = subkeys [ 2 ]), ] print ( network ) [Sequential( layers=( Linear( weight=f32[300,10], bias=f32[300], in_features=10, out_features=300, use_bias=True ), Lambda(fn=) ) ), Sequential( layers=( Linear( weight=f32[300,300], bias=f32[300], in_features=300, out_features=300, use_bias=True ), Lambda(fn=) ) ), Linear( weight=f32[784,300], bias=f32[784], in_features=300, out_features=784, use_bias=True )] You can also use the utility jpc.get_fc_network to define an MLP or fully connected network with some activation functions. network = jpc . make_mlp ( key , LAYER_SIZES , act_fn = \"relu\" ) print ( network ) [Sequential( layers=( Linear( weight=f32[100,50], bias=f32[100], in_features=50, out_features=100, use_bias=True ), Lambda(fn=) ) ), Sequential( layers=( Linear( weight=f32[100,100], bias=f32[100], in_features=100, out_features=100, use_bias=True ), Lambda(fn=) ) ), Sequential( layers=( Linear( weight=f32[784,100], bias=f32[784], in_features=100, out_features=784, use_bias=True ), Lambda(fn=Identity()) ) )]","title":"Network"},{"location":"examples/unsupervised_generative_pc/#train","text":"def train ( key , layer_sizes , batch_size , network , lr , max_t1 , n_train_iters ): optim = optax . adam ( lr ) opt_state = optim . init ( ( eqx . filter ( network , eqx . is_array ), None ) ) train_loader , test_loader = get_mnist_loaders ( batch_size ) train_energies , ts = [], [] for iter , img_batch in enumerate ( train_loader ): img_batch = img_batch . numpy () result = jpc . make_pc_step ( key = key , layer_sizes = layer_sizes , batch_size = batch_size , model = network , optim = optim , opt_state = opt_state , output = img_batch , max_t1 = max_t1 , record_activities = True , record_energies = True ) network , optim , opt_state = result [ \"model\" ], result [ \"optim\" ], result [ \"opt_state\" ] train_energies . append ( result [ \"energies\" ]) ts . append ( result [ \"t_max\" ]) if ( iter + 1 ) >= n_train_iters : break return result [ \"model\" ], train_energies , ts","title":"Train"},{"location":"examples/unsupervised_generative_pc/#run","text":"network , energies , ts = train ( key = key , layer_sizes = LAYER_SIZES , batch_size = BATCH_SIZE , network = network , lr = LEARNING_RATE , max_t1 = MAX_T1 , n_train_iters = N_TRAIN_ITERS ) plot_train_energies ( energies , ts )","title":"Run"}]} \ No newline at end of file diff --git a/sitemap.xml b/sitemap.xml index 9f3324e..95a3763 100644 --- a/sitemap.xml +++ b/sitemap.xml @@ -95,4 +95,14 @@ 2024-11-25 daily + + None + 2024-11-25 + daily + + + None + 2024-11-25 + daily + \ No newline at end of file diff --git a/sitemap.xml.gz b/sitemap.xml.gz index 50bc90b..4b9f8af 100644 Binary files a/sitemap.xml.gz and b/sitemap.xml.gz differ