⚙️ Implementing Predictive Coding from Scratch (in JAX)
In this notebook, we walk through how to implement Predictive Coding (PC) from scratch using JAX. In particular, we are going to train a simple feedforward network to classify MNIST digits with PC. It might be a good idea to revisit the introductory lecture on PC from this week.
If you're not familiar with JAX, have a look at their docs, but we will explain all the necessary concepts below. JAX is basically NumPy for GPUs and other hardware accelerators. We will also use Equinox, which allows you to define neural nets with PyTorch-like syntax, and Optax, which provides a range of common machine learning optimisers such as gradient descent and Adam.
Installations & imports
#@title installations


%%capture
!pip install torch==2.3.1
!pip install torchvision==0.18.1

#@title imports


import jax.random as jr
import jax.numpy as jnp
from jax import vmap, grad
from jax.tree_util import tree_map

import equinox as eqx
import equinox.nn as nn
from equinox import filter_grad
import optax

import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

import warnings
warnings.simplefilter("ignore")

Hyperparameters
We define some global parameters related to the data, network, optimisers, etc.
SEED = 827

INPUT_DIM, OUTPUT_DIM = 28*28, 10
NETWORK_WIDTH = 300

ACTIVITY_LR = 1e-1
INFERENCE_STEPS = 20

PARAM_LR = 1e-3
BATCH_SIZE = 64

TEST_EVERY = 50
N_TRAIN_ITERS = 500

Dataset
Some utils to fetch MNIST.
#@title data utils


def get_mnist_loaders(batch_size):
    train_data = MNIST(train=True, normalise=True)
    test_data = MNIST(train=False, normalise=True)
    train_loader = DataLoader(
        dataset=train_data,
        batch_size=batch_size,
        shuffle=True,
        drop_last=True
    )
    test_loader = DataLoader(
        dataset=test_data,
        batch_size=batch_size,
        shuffle=True,
        drop_last=True
    )
    return train_loader, test_loader


class MNIST(datasets.MNIST):
    def __init__(self, train, normalise=True, save_dir="data"):
        if normalise:
            transform = transforms.Compose(
                [
                    transforms.ToTensor(),
                    transforms.Normalize(
                        mean=(0.1307), std=(0.3081)
                    )
                ]
            )
        else:
            transform = transforms.Compose([transforms.ToTensor()])
        super().__init__(save_dir, download=True, train=train, transform=transform)

    def __getitem__(self, index):
        img, label = super().__getitem__(index)
        img = torch.flatten(img)
        label = one_hot(label)
        return img, label


def one_hot(labels, n_classes=10):
    arr = torch.eye(n_classes)
    return arr[labels]

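As a quick illustrative check (not needed for training), one_hot maps an integer label to a 10-dimensional vector with a 1 at the label's index:

# illustrative check: label 3 becomes a 10-dimensional one-hot vector
one_hot(3)  # tensor([0., 0., 0., 1., 0., 0., 0., 0., 0., 0.])
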
PC energy
First, recall that PC can be derived as a variational inference algorithm under certain assumptions. In particular, if we assume
* a Dirac delta (point mass) posterior and
* a hierarchical Gaussian generative model,
we get the standard PC energy
\begin{equation}
    \mathcal{F} = \frac{1}{2N}\sum_{i=1}^{N} \sum_{\ell=1}^L ||\mathbf{z}_{\ell, i} - f_\ell(W_\ell \mathbf{z}_{\ell-1, i} + \mathbf{b}_\ell)||^2_2
\end{equation}
which is just a sum of squared prediction errors at each network layer. Here we are being a little bit more precise than in the lecture, including multiple (\(N\)) data points and biases \(\mathbf{b}_\ell\).
🤔 Food for thought: Think about how the form of this energy could change depending on other assumptions we make about the generative model. See, for example, Learning on Arbitrary Graph Topologies via Predictive Coding by Salvatori et al. (2022).
Let's start by implementing this energy below. The function takes the model (with all its parameters), some initialised activities, and some input and output. Given these, it simply sums the squared prediction errors across layers.
NOTE: below we use vmap, one of the core JAX transforms, which allows you to vectorise operations, in this case over a batch of data points. See https://jax.readthedocs.io/en/latest/automatic-vectorization.html for more details.
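If you haven't used vmap before, here is a minimal illustration (separate from the PC code): a function written for a single vector is mapped over a batch dimension.

# illustrative example: vmap turns a single-sample function into a batched one
single_norm = lambda x: jnp.sum(x ** 2)   # expects one vector of shape (dim,)
batched_norm = vmap(single_norm)          # now accepts arrays of shape (batch, dim)
batched_norm(jnp.ones((4, 3)))            # -> Array([3., 3., 3., 3.])
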
def pc_energy_fn(model, activities, input, output):
    batch_size = output.shape[0]
    n_activity_layers = len(activities) - 1
    n_layers = len(model) - 1

    # prediction error at the output layer (the output is clamped to the data)
    eL = output - vmap(model[-1])(activities[-2])
    energies = [jnp.sum(eL ** 2)]

    # prediction errors at the hidden layers
    for act_l, net_l in zip(
        range(1, n_activity_layers),
        range(1, n_layers)
    ):
        err = activities[act_l] - vmap(model[net_l])(activities[act_l - 1])
        energies.append(jnp.sum(err ** 2))

    # prediction error at the first layer (predicted from the clamped input)
    e1 = activities[0] - vmap(model[0])(input)
    energies.append(jnp.sum(e1 ** 2))

    return jnp.sum(jnp.array(energies)) / batch_size

Now let's test it. To do so, we first need a model. Below we use Equinox to create a simple feedforward network with 2 hidden layers and tanh activations. Note that we split the model into different parts with nn.Sequential to define the activities which PC will optimise over (during inference; more on this below).
❓ Question: Think about other ways in which we could split the layers, for example by separating the non-linearities. Can you think of potential issues with this?
# jax uses explicit random number generators (see https://jax.readthedocs.io/en/latest/random-numbers.html)
key = jr.PRNGKey(SEED)
subkeys = jr.split(key, 3)

model = [
    nn.Sequential(
        [
            nn.Linear(INPUT_DIM, NETWORK_WIDTH, key=subkeys[0]),
            nn.Lambda(jnp.tanh)
        ],
    ),
    nn.Sequential(
        [
            nn.Linear(NETWORK_WIDTH, NETWORK_WIDTH, key=subkeys[1]),
            nn.Lambda(jnp.tanh)
        ],
    ),
    nn.Linear(NETWORK_WIDTH, OUTPUT_DIM, key=subkeys[2]),
]
model

The last thing we need is to initialise the activities. For this, we will use a feedforward pass, as is often done in practice.
❓ Question: Can you think of other ways of initialising the activities?
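One alternative, sketched below purely for illustration (and not used in the rest of the notebook), is to initialise the activities with small Gaussian noise. The function name is hypothetical, and the layer widths are hard-coded to match the model defined above.

# hypothetical alternative (not used below): random Gaussian initialisation
def init_activities_with_noise(key, batch_size, scale=0.05):
    dims = [NETWORK_WIDTH, NETWORK_WIDTH, OUTPUT_DIM]  # output sizes of the 3 model stages
    keys = jr.split(key, len(dims))
    return [scale * jr.normal(k, (batch_size, d)) for k, d in zip(keys, dims)]
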
def init_activities_with_ffwd(model, input):
    activities = [vmap(model[0])(input)]
    for l in range(1, len(model)):
        layer_output = vmap(model[l])(activities[l - 1])
        activities.append(layer_output)

    return activities

Let's test it on an MNIST sample.
# get a data sample
train_loader, test_loader = get_mnist_loaders(BATCH_SIZE)
img_batch, label_batch = next(iter(train_loader))

# we need to turn the torch.Tensor data into numpy arrays for jax
img_batch, label_batch = img_batch.numpy(), label_batch.numpy()

# let's check our initialised activities
activities = init_activities_with_ffwd(model, img_batch)
for i, a in enumerate(activities):
    print(f"activity z at layer {i+1}: {a.shape}")

OK, so now we have everything we need to test our PC energy function: a model, activities, and some data.
pc_energy_fn(
    model=model,
    activities=activities,
    input=img_batch,
    output=label_batch
)

And it works!
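In fact, we can sanity-check this number. Because the activities were initialised with a feedforward pass, every hidden-layer prediction error is exactly zero, so the energy should reduce to the output prediction error alone (summed over classes and averaged over the batch). A quick check using the variables defined above:

# since hidden errors vanish under feedforward initialisation,...
# ...the energy should equal the batch-averaged squared output error
output_error = label_batch - activities[-1]
jnp.sum(output_error ** 2) / label_batch.shape[0]
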
Energy gradients
How do we minimise the PC energy we defined above (Eq. 1)? Recall from the lecture that we do this in two phases: first with respect to the activities (inference) and then with respect to the weights (learning).
So we just need to take these gradients of the energy. We are going to use autodiff, which JAX supports by design (https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html). If you're familiar with PyTorch, you are probably used to calling loss.backward() for this, which can feel opaque at times. JAX, on the other hand, takes a functional (as opposed to object-oriented) approach whose syntax stays very close to the maths, as you can see below.
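As a one-line warm-up (separate from the PC code), grad takes a scalar-valued function and returns a new function that computes its gradient:

# illustrative example: the gradient of sum(x^2) is 2x
f = lambda x: jnp.sum(x ** 2)
grad(f)(jnp.array([1., 2., 3.]))  # -> Array([2., 4., 6.])
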
# note how close this code is to the maths:
# it can be read as "take the gradient of the energy...
# ...with respect to the 2nd argument (the activities)"

def compute_activity_grad(model, activities, input, output):
    return grad(pc_energy_fn, argnums=1)(
        model,
        activities,
        input,
        output
    )

Let's test this out.
dFdzs = compute_activity_grad(
    model=model,
    activities=activities,
    input=img_batch,
    output=label_batch
)
for i, dFdz in enumerate(dFdzs):
    print(f"activity gradient dFdz shape at layer {i+1}: {dFdz.shape}")

Now we do the same and take the gradient of the energy with respect to the parameters.
Technical note: below we use Equinox's convenience function filter_grad rather than JAX's native grad. This is because things like activation functions do not have parameters and so we do not want to differentiate them. filter_grad automatically filters out these non-differentiable objects for us, while grad alone would throw an error.
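To make this concrete, here is a toy sketch (separate from the PC code, with hypothetical names) of differentiating with respect to a layer that contains a non-array leaf:

# toy example: nn.Lambda(jnp.tanh) stores a function (a non-array leaf),...
# ...which filter_grad simply skips, whereas grad alone would raise an error
toy_layer = nn.Sequential(
    [nn.Linear(2, 2, key=jr.PRNGKey(0)), nn.Lambda(jnp.tanh)]
)

def toy_energy(layer, x):
    return jnp.sum(layer(x) ** 2)

toy_grads = filter_grad(toy_energy)(toy_layer, jnp.ones(2))
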
# note that, compared to the previous function,...
# ...we just change the argument with respect to which...
# ...we are differentiating (the first, or in this case the model)

def compute_param_grad(model, activities, input, output):
    return filter_grad(pc_energy_fn)(
        model,
        activities,
        input,
        output
    )

And let's test it.
param_grads = compute_param_grad(
    model=model,
    activities=activities,
    input=img_batch,
    output=label_batch
)

Updates
Before putting everything together, let's wrap our gradients into update functions. This will also allow us to use JAX's jit primitive, which essentially compiles your code the first time it's executed so that it can be run more efficiently the next time (see https://jax.readthedocs.io/en/latest/jit-compilation.html for more details).
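As a tiny illustration (not part of the training code), the first call to a jitted function triggers compilation, and subsequent calls with the same input shapes reuse the compiled version:

# illustrative only: compilation happens on the first call
@eqx.filter_jit
def dummy_energy(x):
    return jnp.sum(jnp.tanh(x) ** 2)

dummy_energy(jnp.ones(10))  # first call: traced and compiled
dummy_energy(jnp.ones(10))  # later calls: reuse the compiled computation
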
These functions take an (Optax) optimiser such as gradient descent in addition to the previous arguments (model, activities and data).
@eqx.filter_jit
def update_activities(model, activities, optim, opt_state, input, output):
    activity_grads = compute_activity_grad(
        model=model,
        activities=activities,
        input=input,
        output=output
    )
    activity_updates, activity_opt_state = optim.update(
        updates=activity_grads,
        state=opt_state,
        params=activities
    )
    activities = eqx.apply_updates(
        model=activities,
        updates=activity_updates
    )
    # return the *updated* optimiser state
    return activities, optim, activity_opt_state


# note that the only difference with the above function is...
# ...the variable we are updating (parameters vs activities)
@eqx.filter_jit
def update_params(model, activities, optim, opt_state, input, output):
    param_grads = compute_param_grad(
        model=model,
        activities=activities,
        input=input,
        output=output
    )
    param_updates, param_opt_state = optim.update(
        updates=param_grads,
        state=opt_state,
        params=model
    )
    model = eqx.apply_updates(
        model=model,
        updates=param_updates
    )
    # return the *updated* optimiser state
    return model, optim, param_opt_state

Putting everything together: Training and testing
Now that we have our activity and parameter updates, we just need to wrap them in a training and test loop.
# note: the accuracy computation below could be sped up...
# ...with jit in a separate function

def evaluate(model, test_loader):
    avg_test_acc = 0
    for test_iter, (img_batch, label_batch) in enumerate(test_loader):
        img_batch, label_batch = img_batch.numpy(), label_batch.numpy()

        preds = init_activities_with_ffwd(model, img_batch)[-1]
        test_acc = jnp.mean(
            jnp.argmax(label_batch, axis=1) == jnp.argmax(preds, axis=1)
        ) * 100
        avg_test_acc += test_acc

    return avg_test_acc / len(test_loader)
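

# as hinted in the comment above, a hedged sketch (with a hypothetical name)...
# ...of how the accuracy computation could be moved into its own jitted function
@eqx.filter_jit
def compute_accuracy(model, img_batch, label_batch):
    preds = init_activities_with_ffwd(model, img_batch)[-1]
    return jnp.mean(
        jnp.argmax(label_batch, axis=1) == jnp.argmax(preds, axis=1)
    ) * 100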


def train(
    model,
    activity_lr,
    inference_steps,
    param_lr,
    batch_size,
    test_every,
    n_train_iters
):
    # define optimisers for activities and parameters
    activity_optim = optax.sgd(activity_lr)
    param_optim = optax.adam(param_lr)
    param_opt_state = param_optim.init(eqx.filter(model, eqx.is_array))

    train_loader, test_loader = get_mnist_loaders(batch_size)
    for train_iter, (img_batch, label_batch) in enumerate(train_loader):
        img_batch, label_batch = img_batch.numpy(), label_batch.numpy()

        # initialise activities
        activities = init_activities_with_ffwd(model, img_batch)
        activity_opt_state = activity_optim.init(activities)

        train_loss = jnp.mean((label_batch - activities[-1])**2)

        # inference
        for t in range(inference_steps):
            activities, activity_optim, activity_opt_state = update_activities(
                model=model,
                activities=activities,
                optim=activity_optim,
                opt_state=activity_opt_state,
                input=img_batch,
                output=label_batch
            )

        # learning
        model, param_optim, param_opt_state = update_params(
            model=model,
            activities=activities,  # note how we use the optimised activities
            optim=param_optim,
            opt_state=param_opt_state,
            input=img_batch,
            output=label_batch
        )

        if ((train_iter+1) % test_every) == 0:
            avg_test_acc = evaluate(model, test_loader)
            print(
                f"Train iter {train_iter+1}, train loss={train_loss:.4f}, "
                f"avg test accuracy={avg_test_acc:.4f}"
            )
        if (train_iter+1) >= n_train_iters:
            break

Run
Let's test our implementation.
train(
    model=model,
    activity_lr=ACTIVITY_LR,
    inference_steps=INFERENCE_STEPS,
    param_lr=PARAM_LR,
    batch_size=BATCH_SIZE,
    test_every=TEST_EVERY,
    n_train_iters=N_TRAIN_ITERS
)

🥳 Great, we see that our model is training! You can probably improve the performance by tweaking some of the hyperparameters (e.g. try a higher number of inference steps).
Even if you didn't follow all the implementation details, you should now have at least an idea of how PC works in practice. Indeed, this is basically the core code behind a new PC library our lab will soon release: JPC. Play around with the notebook examples there, where you can learn how to train a variety of PC networks.