Skip to content

Commit

Permalink
Update utils docs.
Browse files Browse the repository at this point in the history
  • Loading branch information
francesco-innocenti committed Nov 27, 2024
1 parent db11f5d commit 1001494
Show file tree
Hide file tree
Showing 4 changed files with 192 additions and 4 deletions.
35 changes: 35 additions & 0 deletions docs/api/Utils.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
# Utils

::: jpc.make_mlp

---

::: jpc.get_act_fn

---

::: jpc.mse_loss

---

::: jpc.cross_entropy_loss

---

::: jpc.compute_accuracy

---

::: jpc.get_t_max

---

::: jpc.compute_infer_energies

---

::: jpc.compute_activity_norms

---

::: jpc.compute_param_norms
3 changes: 0 additions & 3 deletions docs/api/make_mlp.md

This file was deleted.

156 changes: 156 additions & 0 deletions experiments/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,16 @@
import random
import numpy as np
from torch import manual_seed
import jax.random as jr
import jax.numpy as jnp
from jax.tree_util import tree_leaves
import equinox as eqx
from diffrax import Euler, Heun, Midpoint, Ralston, Bosh3, Tsit5, Dopri5, Dopri8
from jpc import (
linear_activities_coeff_matrix,
compute_activity_grad,
compute_pc_param_grads
)


def setup_mlp_experiment(
Expand Down Expand Up @@ -45,6 +54,88 @@ def setup_mlp_experiment(
str(seed)
)

def setup_mlp_experiment_test(
results_dir,
dataset,
n_hidden,
act_fn,
weight_init_type,
activity_init,
max_t1,
lr,
weight_decay,
activity_optim_id,
seed
):
print(
f"""
Starting experiment with configuration:
Dataset: {dataset}
N hidden: {n_hidden}
Act fn: {act_fn}
Max t1: {max_t1}
Activity optim: {activity_optim_id}
Seed: {seed}
"""
)
return os.path.join(
results_dir,
dataset,
f"{n_hidden}_n_hidden",
act_fn,
f"max_t1_{max_t1}",
activity_optim_id,
str(seed)
)


def setup_cnn_experiment(
results_dir,
dataset,
use_skips,
act_fn,
init_type,
loss,
optim_id,
lr,
batch_size,
ode_solver,
max_t1,
seed
):
print(
f"""
Starting experiment with configuration:
Dataset: {dataset}
Use skips: {use_skips}
Act fn: {act_fn}
Init type: {init_type}
Loss: {loss}
Optim: {optim_id}
Learning rate: {lr}
Batch size: {batch_size}
ODE solver: {ode_solver}
Max t1: {max_t1}
Seed: {seed}
"""
)
return os.path.join(
results_dir,
dataset,
"skips" if use_skips else "no_skips",
act_fn,
f"{init_type}_init",
f"{init_type}_loss",
optim_id,
f"lr_{lr}",
f"batch_{batch_size}",
ode_solver,
f"max_t1_{max_t1}",
str(seed)
)


def set_seed(seed):
np.random.seed(seed)
Expand All @@ -69,3 +160,68 @@ def get_ode_solver(name):
return Dopri5()
elif name == "Dopri8":
return Dopri8()


def origin_init(weight, std_dev, key):
if len(weight.shape) == 2:
out, in_ = weight.shape
return std_dev * jr.normal(key, shape=(out, in_))
elif len(weight.shape) == 4:
out, in_, kh, kw = weight.shape
return std_dev * jr.normal(key, shape=(out, in_, kh, kw))


def init_weights(model, init_fn, std_dev, key):
is_linear_or_conv = lambda x: isinstance(x, (eqx.nn.Linear, eqx.nn.Conv2d))
get_weights = lambda m: [x.weight
for x in tree_leaves(m, is_leaf=is_linear_or_conv)
if is_linear_or_conv(x)]
weights = get_weights(model)
new_weights = [init_fn(weight, std_dev, subkey)
for weight, subkey in zip(weights, jr.split(key, len(weights)))]
new_model = eqx.tree_at(get_weights, model, new_weights)
return new_model


def get_network_weights(network):
weights = [network[l][0].weight for l in range(len(network))]
return weights


@eqx.filter_jit
def compute_network_metrics(network):
weights = get_network_weights(network)
coeff_matrix = linear_activities_coeff_matrix(weights)
rank = jnp.linalg.matrix_rank(coeff_matrix)
cond_num = jnp.linalg.cond(coeff_matrix)
return {
"coeff_matrix": coeff_matrix,
"rank": rank,
"cond_num": cond_num
}


def get_min_iter(lists):
min_iter = 100000
for i in lists:
if len(i) < min_iter:
min_iter = len(i)
return min_iter


def get_min_iter_metrics(metrics):
n_seeds = len(metrics)
min_iter = get_min_iter(lists=metrics)

min_iter_metrics = np.zeros((n_seeds, min_iter))
for seed in range(n_seeds):
min_iter_metrics[seed, :] = metrics[seed][:min_iter]

return min_iter_metrics


def compute_metric_stats(metric):
min_iter_metrics = get_min_iter_metrics(metrics=metric)
metric_means = min_iter_metrics.mean(axis=0)
metric_stds = min_iter_metrics.std(axis=0)
return metric_means, metric_stds
2 changes: 1 addition & 1 deletion mkdocs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ nav:
- 🌱 Basic API:
- 'api/Training.md'
- 'api/Testing.md'
- 'api/make_mlp.md'
- 'api/Utils.md'
- 🚀 Advanced API:
- 'api/Initialisation.md'
- 'api/Energy functions.md'
Expand Down

0 comments on commit 1001494

Please sign in to comment.