Skip to content

Commit

Permalink
Matrix and bootstrap solvers docstrings, return types, and modifying …
Browse files Browse the repository at this point in the history
…the confidence interval method to return transitions
  • Loading branch information
dmnapolitano committed Jan 30, 2024
1 parent d62680b commit 9c3379d
Showing 1 changed file with 59 additions and 16 deletions.
75 changes: 59 additions & 16 deletions src/elexsolver/TransitionMatrixSolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,32 +14,42 @@


class TransitionMatrixSolver(TransitionSolver):
"""
Matrix regression transition solver using CVXPY.
"""

def __init__(self, strict: bool = True, lam: float | None = None):
"""
`lam` != 0 will enable L2 regularization (Ridge).
Parameters
----------
`strict` : bool, default True
If `True`, solution will be constrainted so that all coefficients are >= 0,
<= 1, and the sum of each row equals 1.
`lam` : float, optional
`lam` != 0 will enable L2 regularization (Ridge).
"""
super().__init__()
self._strict = strict
self._lambda = lam

@staticmethod
def __get_constraints(coef: np.ndarray, strict: bool):
def __get_constraints(coef: np.ndarray, strict: bool) -> list:
if strict:
return [0 <= coef, coef <= 1, cp.sum(coef, axis=1) == 1]
return [cp.sum(coef, axis=1) <= 1.1, cp.sum(coef, axis=1) >= 0.9]

def __standard_objective(self, A: np.ndarray, B: np.ndarray, beta: np.ndarray):
def __standard_objective(self, A: np.ndarray, B: np.ndarray, beta: np.ndarray) -> cp.Minimize:
loss_function = cp.norm(A @ beta - B, "fro")
return cp.Minimize(loss_function)

def __ridge_objective(self, A: np.ndarray, B: np.ndarray, beta: np.ndarray):
def __ridge_objective(self, A: np.ndarray, B: np.ndarray, beta: np.ndarray) -> cp.Minimize:
# Based on https://www.cvxpy.org/examples/machine_learning/ridge_regression.html
lam = cp.Parameter(nonneg=True, value=self._lambda)
loss_function = cp.pnorm(A @ beta - B, p=2) ** 2
regularizer = cp.pnorm(beta, p=2) ** 2
return cp.Minimize(loss_function + lam * regularizer)

def __solve(self, A: np.ndarray, B: np.ndarray, weights: np.ndarray):
def __solve(self, A: np.ndarray, B: np.ndarray, weights: np.ndarray) -> np.ndarray:
transition_matrix = cp.Variable((A.shape[1], B.shape[1]), pos=True)
Aw = np.dot(weights, A)
Bw = np.dot(weights, B)
Expand All @@ -62,11 +72,7 @@ def __solve(self, A: np.ndarray, B: np.ndarray, weights: np.ndarray):

return transition_matrix.value

def fit_predict(self, X: np.ndarray, Y: np.ndarray, weights: np.ndarray | None = None):
"""
X and Y are matrixes (numpy or pandas.DataFrame) of integers.
weights is a list, numpy array, or pandas.Series with the same length as both X and Y.
"""
def fit_predict(self, X: np.ndarray, Y: np.ndarray, weights: np.ndarray | None = None) -> np.ndarray:
self._check_data_type(X)
self._check_data_type(Y)
self._check_any_element_nan_or_inf(X)
Expand Down Expand Up @@ -104,7 +110,24 @@ def fit_predict(self, X: np.ndarray, Y: np.ndarray, weights: np.ndarray | None =


class BootstrapTransitionMatrixSolver(TransitionSolver):
"""
Bootstrap version of the matrix regression transition solver.
"""

def __init__(self, B: int = 1000, strict: bool = True, verbose: bool = True, lam: int | None = None):
"""
Parameters
----------
`B` : int, default 1000
Number of bootstrap samples to draw and matrix solver models to fit/predict.
`strict` : bool, default True
If `True`, solution will be constrainted so that all coefficients are >= 0,
<= 1, and the sum of each row equals 1.
`verbose` : bool, default True
If `False`, this will reduce the amount of logging produced for each of the `B` bootstrap samples.
`lam` : float, optional
`lam` != 0 will enable L2 regularization (Ridge).
"""
super().__init__()
self._strict = strict
self._B = B
Expand All @@ -113,8 +136,9 @@ def __init__(self, B: int = 1000, strict: bool = True, verbose: bool = True, lam

# class members that are instantiated during model-fit
self._predicted_percentages = None
self._X_expected_totals = None

def fit_predict(self, X: np.ndarray, Y: np.ndarray, weights: np.ndarray | None = None):
def fit_predict(self, X: np.ndarray, Y: np.ndarray, weights: np.ndarray | None = None) -> np.ndarray:
self._predicted_percentages = []

# assuming pandas.DataFrame
Expand All @@ -123,7 +147,7 @@ def fit_predict(self, X: np.ndarray, Y: np.ndarray, weights: np.ndarray | None =
if not isinstance(Y, np.ndarray):
Y = Y.to_numpy()

X_expected_totals = X.sum(axis=0) / X.sum(axis=0).sum()
self._X_expected_totals = X.sum(axis=0) / X.sum(axis=0).sum()

tm = TransitionMatrixSolver(strict=self._strict, lam=self._lambda)
self._predicted_percentages.append(tm.fit_predict(X, Y, weights=weights))
Expand All @@ -138,19 +162,38 @@ def fit_predict(self, X: np.ndarray, Y: np.ndarray, weights: np.ndarray | None =
self._predicted_percentages.append(tm.fit_predict(X_resampled, Y_resampled, weights=None))

percentages = np.mean(self._predicted_percentages, axis=0)
self._transitions = np.diag(X_expected_totals) @ percentages
self._transitions = np.diag(self._X_expected_totals) @ percentages
return percentages

def get_confidence_interval(self, alpha: float):
# TODO: option to return transitions as well as (or instead of) percentages
def get_confidence_interval(self, alpha: float, transitions: bool = False) -> (np.ndarray, np.ndarray):
"""
Parameters
----------
`alpha` : float
Value between [0, 1). If greater than 1, will be divided by 100.
`transitions` : bool, default False
If True, the returned matrix will represent transitions, not percentages.
Returns
-------
A tuple of two np.ndarray matrices of float. Element 0 has the lower bound and 1 has the upper bound.
"""
if alpha > 1:
alpha = alpha / 100
if alpha < 0 or alpha >= 1:
raise ValueError(f"Invalid confidence interval {alpha}.")

p_lower = ((1.0 - alpha) / 2.0) * 100
p_upper = ((1.0 + alpha) / 2.0) * 100
return (

percentages = (
np.percentile(self._predicted_percentages, p_lower, axis=0),
np.percentile(self._predicted_percentages, p_upper, axis=0),
)

if transitions:
return (
np.diag(self._X_expected_totals) @ percentages[0],
np.diag(self._X_expected_totals) @ percentages[1],
)
return percentages

0 comments on commit 9c3379d

Please sign in to comment.