# Advanced usage

Advanced users can access all the underlying functions of `jpc.make_pc_step` as
well as additional features. A custom PC training step looks like the following:
```py
import jpc

# 1. initialise activities with a feedforward pass
activities = jpc.init_activities_with_ffwd(model=model, input=x)

# 2. run inference to equilibrium
equilibrated_activities = jpc.solve_inference(
    params=(model, None),
    activities=activities,
    output=y,
    input=x
)

# 3. update parameters at the activities' solution with PC
param_update_result = jpc.update_params(
    params=(model, None),
    activities=equilibrated_activities,
    optim=param_optim,
    opt_state=param_opt_state,
    output=y,
    input=x
)

# updated model and optimiser
model = param_update_result["model"]
param_optim = param_update_result["optim"]
param_opt_state = param_update_result["opt_state"]
```
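To make the three steps concrete, here is a minimal, self-contained numpy sketch of what they compute on a toy one-hidden-layer *linear* PCN. This is an illustration of the underlying maths only, not the jpc API: all names, sizes, and learning rates below are assumptions chosen for the example.

```python
import numpy as np

# Toy linear PCN: x -> z1 -> y, with energy
#   E = 0.5*||z1 - W1 x||^2 + 0.5*||y - W2 z1||^2
rng = np.random.default_rng(0)
W1 = rng.normal(size=(4, 3)) * 0.1   # input -> hidden weights
W2 = rng.normal(size=(2, 4)) * 0.1   # hidden -> output weights
x = rng.normal(size=3)               # input (clamped)
y = rng.normal(size=2)               # target (clamped)

def energy(z1):
    # PC energy: summed squared prediction errors across layers
    return 0.5 * np.sum((z1 - W1 @ x) ** 2) + 0.5 * np.sum((y - W2 @ z1) ** 2)

# 1. initialise the hidden activity with a feedforward pass
z1 = W1 @ x
E_init = energy(z1)

# 2. inference: gradient descent on the energy wrt the activity
lr_z = 0.05
for _ in range(500):
    eps1 = z1 - W1 @ x                       # hidden-layer prediction error
    eps2 = y - W2 @ z1                       # output-layer prediction error
    z1 = z1 - lr_z * (eps1 - W2.T @ eps2)    # z1 <- z1 - lr * dE/dz1
E_eq = energy(z1)

# 3. update parameters at the (approximately) equilibrated activities
lr_w = 0.01
eps1 = z1 - W1 @ x
eps2 = y - W2 @ z1
W1 = W1 + lr_w * np.outer(eps1, x)    # W1 <- W1 - lr * dE/dW1
W2 = W2 + lr_w * np.outer(eps2, z1)   # W2 <- W2 - lr * dE/dW2
```

Inference strictly lowers the energy from its feedforward initialisation before the weights are updated, which is the same separation of timescales (fast activities, slow parameters) that the jpc functions above implement.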
This step can be embedded in a jitted function alongside any other
computations. One can also use any Optax optimiser to equilibrate the inference
dynamics by replacing the function in step 2, as shown below.
```py
import optax

activity_optim = optax.sgd(1e-3)

# 1. initialise activities
...

# 2. infer with gradient descent
activity_opt_state = activity_optim.init(activities)

for t in range(T):
    activity_update_result = jpc.update_activities(
        params=(model, None),
        activities=activities,
        optim=activity_optim,
        opt_state=activity_opt_state,
        output=y,
        input=x
    )
    # updated activities and optimiser
    activities = activity_update_result["activities"]
    activity_optim = activity_update_result["optim"]
    activity_opt_state = activity_update_result["opt_state"]

# 3. update parameters at the activities' solution with PC
...
```
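One advantage of unrolling the inference loop manually is that the fixed step count `T` can be replaced with a convergence check on the activity gradient. A minimal numpy sketch on the same kind of toy linear PCN as a standalone illustration (the tolerance, step size, and all variable names are assumptions for the example, not part of the jpc API):

```python
import numpy as np

# Toy linear PCN: x -> z1 -> y; infer z1 until the energy gradient vanishes
rng = np.random.default_rng(1)
W1 = rng.normal(size=(4, 3)) * 0.1   # input -> hidden weights
W2 = rng.normal(size=(2, 4)) * 0.1   # hidden -> output weights
x = rng.normal(size=3)               # input (clamped)
y = rng.normal(size=2)               # target (clamped)

z1 = W1 @ x                          # feedforward initialisation
lr, tol, max_steps = 0.05, 1e-8, 10_000
n_steps = 0
while n_steps < max_steps:
    eps1 = z1 - W1 @ x               # hidden-layer prediction error
    eps2 = y - W2 @ z1               # output-layer prediction error
    grad = eps1 - W2.T @ eps2        # dE/dz1
    if np.linalg.norm(grad) < tol:   # stop once activities have equilibrated
        break
    z1 = z1 - lr * grad
    n_steps += 1
```

For this small quadratic energy the loop reaches the tolerance in a few hundred steps, far fewer than a conservatively large fixed `T` would run.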
JPC also comes with some analytical tools that can be used to study and
potentially diagnose issues with PCNs (see the
[docs](https://thebuckleylab.github.io/jpc/api/Analytical%20tools/) and the
[example notebook](https://thebuckleylab.github.io/jpc/examples/linear_net_theoretical_energy/)).