Skip to content

Commit

Permalink
Small refactor to comparison plots for easy inheritance in premium pa…
Browse files Browse the repository at this point in the history
…ckage

Small refactor to comparison plots for easy inheritance in premium package
  • Loading branch information
nnansters committed Nov 21, 2023
1 parent e8a3bbf commit 36bdc47
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 6 deletions.
12 changes: 8 additions & 4 deletions nannyml/plots/blueprints/comparisons.py
Original file line number Diff line number Diff line change
Expand Up @@ -633,10 +633,11 @@ def render_metric_display_name(metric_display_name: Union[str, Tuple]):
class ResultCompareMixin:
def compare(self, other: Result):
return ResultComparison(
self, other, title=self._get_title(other), plot_kwargs=_get_plot_kwargs(self, other) # type: ignore
self, other, title=self.get_title(other), plot_kwargs=_get_plot_kwargs(self, other) # type: ignore
)

def _get_title(self, other: Result):
@property
def titles(self) -> Dict[type, str]:
from nannyml.data_quality.missing.result import Result as MissingValueResult
from nannyml.data_quality.unseen.result import Result as UnseenValuesResult
from nannyml.drift.multivariate.data_reconstruction import Result as DataReconstructionDriftResult
Expand All @@ -649,7 +650,7 @@ def _get_title(self, other: Result):
from nannyml.stats.std import Result as StatsStdResult
from nannyml.stats.sum import Result as StatsSumResult

_result_title_names: Dict[type, Any] = {
_titles: Dict[type, Any] = {
UnivariateDriftResult: "Univariate drift",
DataReconstructionDriftResult: "Multivariate drift",
RealizedPerformanceResult: "Realized performance",
Expand All @@ -663,7 +664,10 @@ def _get_title(self, other: Result):
StatsSumResult: "Statistics, Sum",
}

return f"<b>{_result_title_names[type(self)]}</b> vs. <b>{_result_title_names[type(other)]}</b>"
return _titles

def get_title(self, other: Result):
return f"<b>{self.titles[type(self)]}</b> vs. <b>{self.titles[type(other)]}</b>"


def _get_plot_kwargs(result: Result, other: Result) -> Dict[str, Any]:
Expand Down
4 changes: 2 additions & 2 deletions tests/performance_estimation/CBPE/test_cbpe.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,10 +227,10 @@ def test_cbpe_defaults_to_isotonic_calibrator_when_none_given(): # noqa: D103

def test_cbpe_uses_custom_calibrator_when_provided(): # noqa: D103
class TestCalibrator(Calibrator):
def fit(self, y_pred_proba: np.ndarray, y_true: np.ndarray):
def fit(self, y_pred_proba: np.ndarray, y_true: np.ndarray, *args, **kwargs):
pass

def calibrate(self, y_pred_proba: np.ndarray):
def calibrate(self, y_pred_proba: np.ndarray, *args, **kwargs):
pass

estimator = CBPE(
Expand Down

0 comments on commit 36bdc47

Please sign in to comment.