Refactor likelihood module (#77)
hmgaudecker authored Sep 11, 2024
1 parent 4e64952 commit df4f68a
Showing 12 changed files with 244 additions and 237 deletions.
2 changes: 1 addition & 1 deletion docs/source/getting_started/tutorial.ipynb
@@ -20,7 +20,7 @@
     "import yaml\n",
     "\n",
     "from skillmodels.config import TEST_DIR\n",
-    "from skillmodels.likelihood_function import get_maximization_inputs"
+    "from skillmodels.maximization_inputs import get_maximization_inputs"
 ]
 },
 {
@@ -22,7 +22,7 @@
     "import yaml\n",
     "\n",
     "from skillmodels.config import TEST_DIR\n",
-    "from skillmodels.likelihood_function import get_maximization_inputs\n",
+    "from skillmodels.maximization_inputs import get_maximization_inputs\n",
     "from skillmodels.simulate_data import simulate_dataset\n",
     "from skillmodels.visualize_factor_distributions import (\n",
     "    bivariate_density_contours,\n",
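The notebook hunks change nothing but the import path: `get_maximization_inputs` moved from `skillmodels.likelihood_function` to the new `skillmodels.maximization_inputs` module. For downstream scripts the migration is a one-line edit, sketched here (only the import lines are part of this commit):

```python
# Before this commit:
# from skillmodels.likelihood_function import get_maximization_inputs

# After this commit:
from skillmodels.maximization_inputs import get_maximization_inputs
```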
2 changes: 1 addition & 1 deletion src/skillmodels/__init__.py
@@ -6,7 +6,7 @@
 contextlib.suppress(Exception)
 
 from skillmodels.filtered_states import get_filtered_states
-from skillmodels.likelihood_function import get_maximization_inputs
+from skillmodels.maximization_inputs import get_maximization_inputs
 from skillmodels.simulate_data import simulate_dataset
 
 __all__ = ["get_maximization_inputs", "simulate_dataset", "get_filtered_states"]
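Because the package root re-exports the function (see `__all__` above), code that used the top-level import is unaffected by the move, e.g.:

```python
# Works identically before and after the refactor:
from skillmodels import get_filtered_states, get_maximization_inputs, simulate_dataset
```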
2 changes: 1 addition & 1 deletion src/skillmodels/filtered_states.py
@@ -1,7 +1,7 @@
 import jax.numpy as jnp
 import numpy as np
 
-from skillmodels.likelihood_function import get_maximization_inputs
+from skillmodels.maximization_inputs import get_maximization_inputs
 from skillmodels.params_index import get_params_index
 from skillmodels.parse_params import create_parsing_info, parse_params
 from skillmodels.process_debug_data import create_state_ranges
247 changes: 20 additions & 227 deletions src/skillmodels/likelihood_function.py
@@ -2,191 +2,50 @@
 
 import jax
 import jax.numpy as jnp
-import numpy as np
-import pandas as pd
 
-import skillmodels.likelihood_function_debug as lfd
 from skillmodels.clipping import soft_clipping
-from skillmodels.constraints import add_bounds, get_constraints
 from skillmodels.kalman_filters import (
-    calculate_sigma_scaling_factor_and_weights,
     kalman_predict,
     kalman_update,
 )
-from skillmodels.params_index import get_params_index
-from skillmodels.parse_params import create_parsing_info, parse_params
-from skillmodels.process_data import process_data
-from skillmodels.process_debug_data import process_debug_data
-from skillmodels.process_model import process_model
+from skillmodels.parse_params import parse_params
 
 jax.config.update("jax_enable_x64", False)  # noqa: FBT003
 
 
-def get_maximization_inputs(model_dict, data):
-    """Create inputs for optimagic's maximize function.
-    Args:
-        model_dict (dict): The model specification. See: :ref:`model_specs`
-        data (DataFrame): dataset in long format.
-    Returns a dictionary with keys:
-        loglike (function): A jax jitted function that takes an optimagic-style
-            params dataframe as only input and returns a dict with entries:
-            - "value": The scalar log likelihood
-            - "contributions": An array with the log likelihood per observation
-        debug_loglike (function): Similar to loglike, with the following differences:
-            - It is not jitted and thus faster on the first call and debuggable
-            - It will add intermediate results as additional entries in the returned
-              dictionary. Those can be used for debugging and plotting.
-        gradient (function): The gradient of the scalar log likelihood
-            function with respect to the parameters.
-        loglike_and_gradient (function): Combination of loglike and
-            loglike_gradient that is faster than calling the two functions separately.
-        constraints (list): List of optimagic constraints that are implied by the
-            model specification.
-        params_template (pd.DataFrame): Parameter DataFrame with correct index and
-            bounds but with empty value column.
-    """
-    model = process_model(model_dict)
-    p_index = get_params_index(
-        model["update_info"],
-        model["labels"],
-        model["dimensions"],
-        model["transition_info"],
-    )
-
-    parsing_info = create_parsing_info(
-        p_index,
-        model["update_info"],
-        model["labels"],
-        model["anchoring"],
-    )
-    measurements, controls, observed_factors = process_data(
-        data,
-        model["labels"],
-        model["update_info"],
-        model["anchoring"],
-    )
-
-    sigma_scaling_factor, sigma_weights = calculate_sigma_scaling_factor_and_weights(
-        model["dimensions"]["n_latent_factors"],
-        model["estimation_options"]["sigma_points_scale"],
-    )
-
-    partialed_get_jnp_params_vec = functools.partial(
-        _get_jnp_params_vec,
-        target_index=p_index,
-    )
-
-    partialed_loglikes = {}
-    for n, fun in {
-        "ll": _log_likelihood_jax,
-        "llo": _log_likelihood_obs_jax,
-        "debug_ll": lfd._log_likelihood_jax,
-    }.items():
-        partialed_loglikes[n] = _partial_some_log_likelihood_jax(
-            fun=fun,
-            parsing_info=parsing_info,
-            measurements=measurements,
-            controls=controls,
-            observed_factors=observed_factors,
-            model=model,
-            sigma_weights=sigma_weights,
-            sigma_scaling_factor=sigma_scaling_factor,
-        )
-
-    _jitted_loglike = jax.jit(partialed_loglikes["ll"])
-    _jitted_loglikeobs = jax.jit(partialed_loglikes["llo"])
-    _gradient = jax.jit(jax.grad(partialed_loglikes["ll"]))
-
-    def loglike(params):
-        params_vec = partialed_get_jnp_params_vec(params)
-        return float(_jitted_loglike(params_vec))
-
-    def loglikeobs(params):
-        params_vec = partialed_get_jnp_params_vec(params)
-        return _to_numpy(_jitted_loglikeobs(params_vec))
-
-    def loglike_and_gradient(params):
-        params_vec = partialed_get_jnp_params_vec(params)
-        crit = float(_jitted_loglike(params_vec))
-        grad = _to_numpy(_gradient(params_vec))
-        return crit, grad
-
-    def debug_loglike(params):
-        params_vec = partialed_get_jnp_params_vec(params)
-        jax_output = partialed_loglikes["debug_ll"](params_vec)
-        tmp = _to_numpy(jax_output)
-        tmp["value"] = float(tmp["value"])
-        return process_debug_data(debug_data=tmp, model=model)
-
-    constr = get_constraints(
-        dimensions=model["dimensions"],
-        labels=model["labels"],
-        anchoring_info=model["anchoring"],
-        update_info=model["update_info"],
-        normalizations=model["normalizations"],
-    )
-
-    params_template = pd.DataFrame(columns=["value"], index=p_index)
-    params_template = add_bounds(
-        params_template,
-        model["estimation_options"]["bounds_distance"],
-    )
-
-    out = {
-        "loglike": loglike,
-        "loglikeobs": loglikeobs,
-        "debug_loglike": debug_loglike,
-        "loglike_and_gradient": loglike_and_gradient,
-        "constraints": constr,
-        "params_template": params_template,
-    }
-
-    return out
-
-
-def _partial_some_log_likelihood_jax(
-    fun,
+def log_likelihood(
+    params,
     parsing_info,
     measurements,
     controls,
-    observed_factors,
-    model,
-    sigma_weights,
+    transition_func,
     sigma_scaling_factor,
+    sigma_weights,
+    dimensions,
+    labels,
+    estimation_options,
+    is_measurement_iteration,
+    is_predict_iteration,
+    iteration_to_period,
+    observed_factors,
 ):
-    update_info = model["update_info"]
-    is_measurement_iteration = (update_info["purpose"] == "measurement").to_numpy()
-    _periods = pd.Series(update_info.index.get_level_values("period").to_numpy())
-    is_predict_iteration = ((_periods - _periods.shift(-1)) == -1).to_numpy()
-    last_period = model["labels"]["periods"][-1]
-    # iteration_to_period is used as an indexer to loop over arrays of different lengths
-    # in a jax.lax.scan. It needs to work for arrays of length n_periods and not raise
-    # IndexErrors on tracer arrays of length n_periods - 1 (i.e. n_transitions).
-    # To achieve that, we replace the last period by -1.
-    iteration_to_period = _periods.replace(last_period, -1).to_numpy()
-
-    return functools.partial(
-        fun,
+    return log_likelihood_obs(
+        params=params,
         parsing_info=parsing_info,
         measurements=measurements,
         controls=controls,
-        transition_func=model["transition_info"]["func"],
+        transition_func=transition_func,
         sigma_scaling_factor=sigma_scaling_factor,
         sigma_weights=sigma_weights,
-        dimensions=model["dimensions"],
-        labels=model["labels"],
-        estimation_options=model["estimation_options"],
+        dimensions=dimensions,
+        labels=labels,
+        estimation_options=estimation_options,
         is_measurement_iteration=is_measurement_iteration,
         is_predict_iteration=is_predict_iteration,
         iteration_to_period=iteration_to_period,
         observed_factors=observed_factors,
-    )
+    ).sum()
 
 
-def _log_likelihood_obs_jax(
+def log_likelihood_obs(
     params,
     parsing_info,
     measurements,
@@ -287,40 +146,6 @@ def _log_likelihood_obs_jax(
     ).sum(axis=0)
 
 
-def _log_likelihood_jax(
-    params,
-    parsing_info,
-    measurements,
-    controls,
-    transition_func,
-    sigma_scaling_factor,
-    sigma_weights,
-    dimensions,
-    labels,
-    estimation_options,
-    is_measurement_iteration,
-    is_predict_iteration,
-    iteration_to_period,
-    observed_factors,
-):
-    return _log_likelihood_obs_jax(
-        params=params,
-        parsing_info=parsing_info,
-        measurements=measurements,
-        controls=controls,
-        transition_func=transition_func,
-        sigma_scaling_factor=sigma_scaling_factor,
-        sigma_weights=sigma_weights,
-        dimensions=dimensions,
-        labels=labels,
-        estimation_options=estimation_options,
-        is_measurement_iteration=is_measurement_iteration,
-        is_predict_iteration=is_predict_iteration,
-        iteration_to_period=iteration_to_period,
-        observed_factors=observed_factors,
-    ).sum()
-
-
 def _scan_body(
     carry,
     loop_args,
@@ -427,35 +252,3 @@ def _one_arg_predict(kwargs, transition_func):
         **kwargs,
     )
     return new_states, new_upper_chols, kwargs["states"]
-
-
-def _to_numpy(obj):
-    if isinstance(obj, dict):
-        res = {}
-        for key, value in obj.items():
-            if np.isscalar(value):
-                res[key] = value
-            else:
-                res[key] = np.array(value)
-
-    elif np.isscalar(obj):
-        res = obj
-    else:
-        res = np.array(obj)
-
-    return res
-
-
-def _get_jnp_params_vec(params, target_index):
-    if set(params.index) != set(target_index):
-        additional_entries = params.index.difference(target_index).tolist()
-        missing_entries = target_index.difference(params.index).tolist()
-        msg = "Invalid params DataFrame. "
-        if additional_entries:
-            msg += f"Your params have additional entries: {additional_entries}. "
-        if missing_entries:
-            msg += f"Your params have missing entries: {missing_entries}. "
-        raise ValueError(msg)
-
-    vec = jnp.array(params.reindex(target_index)["value"].to_numpy())
-    return vec
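One detail of the removed code is worth spelling out: `iteration_to_period` replaces the last period's index with -1 so that a single indexer can address both per-period arrays (length `n_periods`) and per-transition arrays (length `n_periods - 1`) inside a `jax.lax.scan`, where Python-level bounds checks are unavailable. Below is a self-contained sketch of the idea, using toy arrays rather than skillmodels' actual data structures (the precomputation itself presumably now lives in the new `maximization_inputs` module):

```python
import jax
import jax.numpy as jnp
import numpy as np

# Toy setup: 4 periods imply 3 transitions, so per-period and per-transition
# arrays differ in length by one.
n_periods = 4
periods = np.arange(n_periods)  # [0, 1, 2, 3]
# Replace the last period by -1, as the removed code did via
# `_periods.replace(last_period, -1)`.
iteration_to_period = jnp.asarray(np.where(periods == n_periods - 1, -1, periods))

per_period = jnp.array([10.0, 11.0, 12.0, 13.0])  # length n_periods
per_transition = jnp.array([0.1, 0.2, 0.3])  # length n_periods - 1

def body(carry, t):
    # Index -1 selects the last element of an array of either length, so the
    # same indexer is valid for both arrays. The value read from
    # `per_transition` in the final iteration is computed but masked out,
    # mirroring how the real code skips the predict step after the last period.
    update = per_period[t]
    predict = jnp.where(t >= 0, per_transition[t], 0.0)
    return carry + update + predict, carry

total, _ = jax.lax.scan(body, 0.0, iteration_to_period)
```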
2 changes: 1 addition & 1 deletion src/skillmodels/likelihood_function_debug.py
@@ -9,7 +9,7 @@
 from skillmodels.parse_params import parse_params
 
 
-def _log_likelihood_jax(
+def log_likelihood(
     params,
     parsing_info,
     measurements,
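With the same rename applied here, the debug path (`likelihood_function_debug.log_likelihood`) mirrors the jitted path's new public name. For orientation, here is a sketch of how the relocated `get_maximization_inputs` is typically wired into optimagic, per the docstring of the removed function above. The algorithm choice and start values are placeholders, and the argument names assume a recent optimagic release (older estimagic-style versions spell them differently), so treat this as an assumption rather than part of the commit:

```python
import optimagic as om

from skillmodels import get_maximization_inputs

# model_dict and data as in the tutorial notebooks.
inputs = get_maximization_inputs(model_dict, data)

params = inputs["params_template"].copy()
params["value"] = 0.1  # placeholder start values; real models need informed ones

res = om.maximize(
    fun=inputs["loglike"],
    params=params,
    algorithm="scipy_lbfgsb",
    constraints=inputs["constraints"],
    # Supplying the combined function avoids evaluating the log likelihood and
    # its gradient separately:
    fun_and_jac=inputs["loglike_and_gradient"],
)
```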