Skip to content

Commit

Permalink
Expand on advance usage docs.
Browse files Browse the repository at this point in the history
  • Loading branch information
francesco-innocenti committed Nov 27, 2024
1 parent 1c2d507 commit db11f5d
Showing 1 changed file with 67 additions and 0 deletions.
67 changes: 67 additions & 0 deletions docs/advanced_usage.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
# Advanced usage

Advanced users can access all the underlying functions of `jpc.make_pc_step` as
well as additional features. A custom PC training step looks like the following:
```py
import jpc

# 1. initialise activities with a feedforward pass
activities = jpc.init_activities_with_ffwd(model=model, input=x)

# 2. run inference to equilibrium
equilibrated_activities = jpc.solve_inference(
params=(model, None),
activities=activities,
output=y,
input=x
)

# 3. update parameters at the activities' solution with PC
param_update_result = jpc.update_params(
params=(model, None),
activities=equilibrated_activities,
optim=param_optim,
opt_state=param_opt_state,
output=y,
input=x
)

# updated model and optimiser
model = param_update_result["model"]
param_optim = param_update_result["optim"]
param_opt_state = param_update_result["opt_state"]
```
which can be embedded in a jitted function with any other additional
computations. One can also use any Optax optimiser to equilibrate the inference
dynamics by replacing the function in step 2, as shown below.
```py
activity_optim = optax.sgd(1e-3)

# 1. initialise activities
...

# 2. infer with gradient descent
activity_opt_state = activity_optim.init(activities)

for t in range(T):
activity_update_result = jpc.update_activities(
params=(model, None),
activities=activities,
optim=activity_optim,
opt_state=activity_opt_state,
output=y,
input=x
)
# updated activities and optimiser
activities = activity_update_result["activities"]
activity_optim = activity_update_result["optim"]
activity_opt_state = activity_update_result["opt_state"]

# 3. update parameters at the activities' solution with PC
...
```
JPC also comes with some analytical tools that can be used to study and
potentially diagnose issues with PCNs (see [docs
](https://thebuckleylab.github.io/jpc/api/Analytical%20tools/)
and [example notebook
](https://thebuckleylab.github.io/jpc/examples/linear_net_theoretical_energy/)).

0 comments on commit db11f5d

Please sign in to comment.