From 029aab5182ba106d5250982b9075d85f3a42dc65 Mon Sep 17 00:00:00 2001 From: Francesco Innocenti Date: Wed, 27 Nov 2024 16:35:31 +0000 Subject: [PATCH] Reorganise utils. --- jpc/_utils.py | 34 +++++++++++++++++----------------- 1 file changed, 17 insertions(+), 17 deletions(-) diff --git a/jpc/_utils.py b/jpc/_utils.py index 2064a2d..7962381 100644 --- a/jpc/_utils.py +++ b/jpc/_utils.py @@ -88,17 +88,10 @@ def cross_entropy_loss(logits: ArrayLike, labels: ArrayLike) -> Scalar: return - jnp.mean(jnp.sum(labels * log_probs, axis=-1)) -def compute_activity_norms(activities: PyTree[Array]) -> Array: - """Calculates l2 norm of activities at each layer.""" - return jnp.array([ - jnp.mean( - jnp.linalg.norm( - a, - axis=-1, - ord=2 - ) - ) for a in tree_leaves(activities) - ]) +def compute_accuracy(truths: ArrayLike, preds: ArrayLike) -> Scalar: + return jnp.mean( + jnp.argmax(truths, axis=1) == jnp.argmax(preds, axis=1) + ) * 100 def get_t_max(activities_iters: PyTree[Array]) -> Array: @@ -161,6 +154,19 @@ def loop_body(state): return energies_iters[::-1, :] +def compute_activity_norms(activities: PyTree[Array]) -> Array: + """Calculates l2 norm of activities at each layer.""" + return jnp.array([ + jnp.mean( + jnp.linalg.norm( + a, + axis=-1, + ord=2 + ) + ) for a in tree_leaves(activities) + ]) + + def compute_param_norms(params): """Calculates l2 norm of all model parameters.""" def process_model_params(model_params): @@ -178,9 +184,3 @@ def process_model_params(model_params): skip_model_params is not None else None) return model_norms, skip_model_norms - - -def compute_accuracy(truths: ArrayLike, preds: ArrayLike) -> Scalar: - return jnp.mean( - jnp.argmax(truths, axis=1) == jnp.argmax(preds, axis=1) - ) * 100