
Reduce the peak memory usage of skillmodels (#76) and refactor the likelihood functions modules (#77)

* Add a version of the memory consumption test that fails if there are increases in the repo.
* Decorate `kalman_update` and `_calculate_sigma_points` with `jax.checkpoint`; `prevent_cse=False` can be used because they are called inside `lax.scan`, as per https://jax.readthedocs.io/en/latest/_autosummary/jax.checkpoint.html#jax.checkpoint.
* Use `jax.checkpoint` on `kalman_predict` as well; this requires making `transition_func` its first argument and using `jnp.array` in tests.
* Add provisions for testing on GPUs; timing information is not useful yet.
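
The checkpointing pattern described in the bullets above can be sketched roughly as follows. This is a minimal illustration under stated assumptions, not code from this commit: `linear_transition`, `predict_step`, and `loss` are hypothetical names, and the toy transition only stands in for the real Kalman predict step. It shows a callable passed as the first, static argument of a `jax.checkpoint`-decorated function that is called inside `jax.lax.scan`.

```python
import functools

import jax
import jax.numpy as jnp


# Hypothetical stand-in for a transition function; not skillmodels code.
def linear_transition(coeffs, states):
    return states * coeffs


# static_argnums=0 treats the callable as a static (hashable) argument
# instead of tracing it. prevent_cse=False is used here because the
# function is only ever called inside jax.lax.scan below.
@functools.partial(jax.checkpoint, static_argnums=0, prevent_cse=False)
def predict_step(transition_func, states, coeffs):
    return jnp.tanh(transition_func(coeffs, states))


def loss(coeffs, initial_states, n_periods=5):
    def body(states, _):
        # Intermediates of predict_step are recomputed during the backward
        # pass rather than stored for every scan iteration.
        new_states = predict_step(linear_transition, states, coeffs)
        return new_states, new_states.sum()

    _, per_period = jax.lax.scan(body, initial_states, None, length=n_periods)
    return per_period.sum()


# Inputs are created with jnp.array, mirroring the note about the tests.
grad = jax.grad(loss)(jnp.array([0.5, 1.5]), jnp.array([1.0, 2.0]))
```

Checkpointing trades extra computation in the backward pass for lower peak memory; the gradient values themselves should match those of an undecorated version of the same function.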
hmgaudecker authored Sep 11, 2024
1 parent bbd4ce4 commit b7975a6
Showing 19 changed files with 3,026 additions and 1,071 deletions.
6 changes: 3 additions & 3 deletions .github/workflows/main.yml
@@ -28,14 +28,14 @@ jobs:
- uses: actions/checkout@v4
- uses: prefix-dev/setup-pixi@v0.8.0
with:
pixi-version: v0.28.2
pixi-version: v0.29.0
cache: true
cache-write: ${{ github.event_name == 'push' && github.ref_name == 'main' }}
environments: test
environments: test-cpu
activate-environment: true
- name: Run pytest
shell: bash -l {0}
run: pixi run -e test tests-with-cov
run: pixi run -e test-cpu tests-with-cov
- name: Upload coverage report
if: runner.os == 'Linux' && matrix.python-version == '3.12'
uses: codecov/codecov-action@v4
4 changes: 2 additions & 2 deletions README.md
@@ -73,10 +73,10 @@ entry for a suggested citation. The suggested citation will be updated once the
becomes part of a published paper.

```
@Unpublished{Gabler2018,
@Unpublished{Gabler2024,
Title = {A Python Library to Estimate Nonlinear Dynamic Latent Factor Models},
Author = {Janos Gabler},
Year = {2018},
Year = {2024},
Url = {https://github.com/OpenSourceEconomics/skillmodels}
}
```
2 changes: 1 addition & 1 deletion docs/source/getting_started/tutorial.ipynb
@@ -20,7 +20,7 @@
"import yaml\n",
"\n",
"from skillmodels.config import TEST_DIR\n",
"from skillmodels.likelihood_function import get_maximization_inputs"
"from skillmodels.maximization_inputs import get_maximization_inputs"
]
},
{
@@ -22,7 +22,7 @@
"import yaml\n",
"\n",
"from skillmodels.config import TEST_DIR\n",
"from skillmodels.likelihood_function import get_maximization_inputs\n",
"from skillmodels.maximization_inputs import get_maximization_inputs\n",
"from skillmodels.simulate_data import simulate_dataset\n",
"from skillmodels.visualize_factor_distributions import (\n",
" bivariate_density_contours,\n",
3,477 changes: 2,710 additions & 767 deletions pixi.lock

Large diffs are not rendered by default.

30 changes: 27 additions & 3 deletions pyproject.toml
@@ -106,15 +106,33 @@ scipy = "<=1.13"
# Development Dependencies (pypi)
# --------------------------------------------------------------------------------------

[tool.pixi.pypi-dependencies]
jax = { version = ">=0.4.20", extras = ["cpu"] }
[tool.pixi.target.unix.dependencies]
jax = ">=0.4.20"
jaxlib = ">=0.4.20"

# Development Dependencies (pypi)
# --------------------------------------------------------------------------------------

[tool.pixi.pypi-dependencies]
pdbp = "*"
skillmodels = {path = ".", editable = true}

[tool.pixi.target.win-64.pypi-dependencies]
jax = { version = ">=0.4.20", extras = ["cpu"] }
jaxlib = ">=0.4.20"

# Features and Tasks
# --------------------------------------------------------------------------------------

[tool.pixi.feature.cuda]
platforms = ["linux-64"]
system-requirements = {cuda = "12"}

[tool.pixi.feature.cuda.target.linux-64.dependencies]
cuda-nvcc = ">=12"
jax = ">=0.4.20"
jaxlib = { version = ">=0.4.20", build = "cuda12*" }

[tool.pixi.feature.test.dependencies]
pytest = "*"
pytest-cov = "*"
@@ -128,6 +146,10 @@ pytest-memray = "*"
tests = "pytest tests"
tests-with-cov = "pytest tests --cov-report=xml --cov=./"
mem = "pytest -x -s --pdb --memray --fail-on-increase tests/test_likelihood_regression.py::test_likelihood_contributions_large_nobs"
mem-on-clean-repo = "git status --porcelain && git diff-index --quiet HEAD -- && git rev-parse HEAD && pytest -x -s --pdb --memray --fail-on-increase tests/test_likelihood_regression.py::test_likelihood_contributions_large_nobs"

[tool.pixi.feature.cuda.tasks]
mem-cuda = "pytest -x -s --pdb --memray --fail-on-increase tests/test_likelihood_regression.py::test_likelihood_contributions_large_nobs"

[tool.pixi.feature.mypy.dependencies]
mypy = "*"
@@ -141,8 +163,10 @@ mypy = "mypy src"
# --------------------------------------------------------------------------------------

[tool.pixi.environments]
cuda = {features = ["cuda"], solve-group = "cuda"}
mypy = {features = ["mypy"], solve-group = "default"}
test = {features = ["test"], solve-group = "default"}
test-cpu = {features = ["test"], solve-group = "default"}
test-gpu = {features = ["test", "cuda"], solve-group = "cuda"}


# ======================================================================================
2 changes: 1 addition & 1 deletion src/skillmodels/__init__.py
@@ -6,7 +6,7 @@
contextlib.suppress(Exception)

from skillmodels.filtered_states import get_filtered_states
from skillmodels.likelihood_function import get_maximization_inputs
from skillmodels.maximization_inputs import get_maximization_inputs
from skillmodels.simulate_data import simulate_dataset

__all__ = ["get_maximization_inputs", "simulate_dataset", "get_filtered_states"]
2 changes: 1 addition & 1 deletion src/skillmodels/filtered_states.py
@@ -1,7 +1,7 @@
import jax.numpy as jnp
import numpy as np

from skillmodels.likelihood_function import get_maximization_inputs
from skillmodels.maximization_inputs import get_maximization_inputs
from skillmodels.params_index import get_params_index
from skillmodels.parse_params import create_parsing_info, parse_params
from skillmodels.process_debug_data import create_state_ranges
23 changes: 11 additions & 12 deletions src/skillmodels/kalman_filters.py
@@ -1,3 +1,5 @@
import functools

import jax
import jax.numpy as jnp

@@ -9,6 +11,7 @@
# ======================================================================================


@functools.partial(jax.checkpoint, prevent_cse=False)
def kalman_update(
states,
upper_chols,
@@ -152,12 +155,13 @@ def calculate_sigma_scaling_factor_and_weights(n_states, kappa=2):
return scaling_factor, weights


@functools.partial(jax.checkpoint, static_argnums=0, prevent_cse=False)
def kalman_predict(
transition_func,
states,
upper_chols,
sigma_scaling_factor,
sigma_weights,
transition_info,
trans_coeffs,
shock_sds,
anchoring_scaling_factors,
@@ -167,6 +171,7 @@ def kalman_predict(
"""Make a unscented Kalman predict.
Args:
transition_func (Callable): The transition function.
states (jax.numpy.array): Array of shape (n_obs, n_mixtures, n_states) with
pre-update states estimates.
upper_chols (jax.numpy.array): Array of shape (n_obs, n_mixtures, n_states,
@@ -177,9 +182,6 @@ def kalman_predict(
the sigma_point algorithm chosen.
sigma_weights (jax.numpy.array): 1d array of length n_sigma with non-negative
sigma weights.
transition_info (dict): Dict with the entries "func" (the actual transition
function) and "columns" (a dictionary mapping factors that are needed
as individual columns to positions in the factor array).
trans_coeffs (tuple): Tuple of 1d jax.numpy.arrays with transition parameters.
anchoring_scaling_factors (jax.numpy.array): Array of shape (2, n_fac) with
the scaling factors for anchoring. The first row corresponds to the input
@@ -203,7 +205,7 @@ def kalman_predict(
)
transformed = transform_sigma_points(
sigma_points,
transition_info,
transition_func,
trans_coeffs,
anchoring_scaling_factors,
anchoring_constants,
@@ -225,6 +227,7 @@ def kalman_predict(
return predicted_states, predicted_covs


@functools.partial(jax.checkpoint, prevent_cse=False)
def _calculate_sigma_points(states, upper_chols, scaling_factor, observed_factors):
"""Calculate the array of sigma_points for the unscented transform.
@@ -272,7 +275,7 @@ def _calculate_sigma_points(states, upper_chols, scaling_factor, observed_factor

def transform_sigma_points(
sigma_points,
transition_info,
transition_func,
trans_coeffs,
anchoring_scaling_factors,
anchoring_constants,
@@ -281,9 +284,7 @@ def transform_sigma_points(
Args:
sigma_points (jax.numpy.array) of shape n_obs, n_mixtures, n_sigma, n_fac.
transition_info (dict): Dict with the entries "func" (the actual transition
function) and "columns" (a dictionary mapping factors that are needed
as individual columns to positions in the factor array).
transition_func (Callable): The transition function.
trans_coeffs (tuple): Tuple of 1d jax.numpy.arrays with transition parameters.
anchoring_scaling_factors (jax.numpy.array): Array of shape (2, n_states) with
the scaling factors for anchoring. The first row corresponds to the input
@@ -303,9 +304,7 @@ def transform_sigma_points(

anchored = flat_sigma_points * anchoring_scaling_factors[0] + anchoring_constants[0]

transition_function = transition_info["func"]

transformed_anchored = transition_function(trans_coeffs, anchored)
transformed_anchored = transition_func(trans_coeffs, anchored)

n_observed = transformed_anchored.shape[-1]
