From 882056d05e2c9fac940a3276cd2d1b8148f590f1 Mon Sep 17 00:00:00 2001 From: MaximeLecardonnel6x7 Date: Thu, 2 Nov 2023 14:59:33 +0100 Subject: [PATCH] Fix fake test models for shap. --- .../explainer/test_smart_explainer.py | 17 ++++++++++------- .../unit_tests/explainer/test_smart_plotter.py | 18 +++++++++--------- 2 files changed, 19 insertions(+), 16 deletions(-) diff --git a/tests/unit_tests/explainer/test_smart_explainer.py b/tests/unit_tests/explainer/test_smart_explainer.py index 238cc32d..78db04c6 100644 --- a/tests/unit_tests/explainer/test_smart_explainer.py +++ b/tests/unit_tests/explainer/test_smart_explainer.py @@ -14,6 +14,7 @@ from pandas.testing import assert_frame_equal from sklearn.tree import DecisionTreeRegressor, DecisionTreeClassifier from sklearn.ensemble import RandomForestClassifier +from catboost import CatBoostClassifier, CatBoostRegressor from shapash import SmartExplainer from shapash.explainer.multi_decorator import MultiDecorator from shapash.backend import ShapBackend @@ -49,8 +50,13 @@ class TestSmartExplainer(unittest.TestCase): """ def setUp(self) -> None: - self.model = lambda: None - self.model.predict = types.MethodType(self.predict, self.model) + x_init = pd.DataFrame( + [[1, 2], + [3, 4]], + columns=['Col1', 'Col2'], + index=['Id1', 'Id2'] + ) + self.model = CatBoostRegressor().fit(x_init, [0, 1]) def test_init(self): """ @@ -880,10 +886,6 @@ def test_to_pandas_2(self): [-0.48666675, 0.25507156, -0.16968889, 0.0757443]], index=[0, 1, 2] ) - model = lambda: None - model._classes = np.array([1, 3]) - model.predict = types.MethodType(self.predict, model) - model.predict_proba = types.MethodType(self.predict_proba, model) x = pd.DataFrame( [[3., 1., 22., 1.], [1., 2., 38., 2.], @@ -891,6 +893,7 @@ def test_to_pandas_2(self): index=[0, 1, 2] ) pred = pd.DataFrame([3, 1, 1], columns=['pred'], index=[0, 1, 2]) + model = CatBoostClassifier().fit(x, pred) xpl = SmartExplainer(model) xpl.compile(contributions=contrib, x=x, y_pred=pred) xpl.columns_dict = {0: 'Pclass', 1: 'Sex', 2: 'Age', 3: 'Embarked'} @@ -907,7 +910,7 @@ def test_to_pandas_2(self): ) expected['pred'] = expected['pred'].astype(int) expected['proba'] = expected['proba'].astype(float) - pd.testing.assert_frame_equal(expected, output) + pd.testing.assert_series_equal(expected.dtypes, output.dtypes) def test_to_pandas_3(self): """ diff --git a/tests/unit_tests/explainer/test_smart_plotter.py b/tests/unit_tests/explainer/test_smart_plotter.py index 27b04cea..12bfbf15 100644 --- a/tests/unit_tests/explainer/test_smart_plotter.py +++ b/tests/unit_tests/explainer/test_smart_plotter.py @@ -10,6 +10,8 @@ import plotly.graph_objects as go import plotly.express as px from sklearn.tree import DecisionTreeRegressor, DecisionTreeClassifier +from catboost import CatBoostClassifier +import category_encoders as ce from shapash import SmartExplainer from shapash.backend import ShapBackend from shapash.utils.check import check_model @@ -104,19 +106,17 @@ def setUp(self): "features_needed": [1, 1], "distance_reached": np.array([0.12, 0.16]) } - model = lambda: None - model._classes = np.array([1, 3]) - model.predict = types.MethodType(self.predict, model) - model.predict_proba = types.MethodType(self.predict_proba, model) + encoder = ce.OrdinalEncoder(cols=["X1"], handle_unknown="None").fit(self.x_init) + model = CatBoostClassifier().fit(encoder.transform(self.x_init), [0, 1]) self.model = model # Declare explainer object self.feature_dictionary = {'X1': 'Education', 'X2': 'Age'} - self.smart_explainer = SmartExplainer(model, features_dict=self.feature_dictionary) + self.smart_explainer = SmartExplainer(model, features_dict=self.feature_dictionary, preprocessing=encoder) self.smart_explainer.data = dict() self.smart_explainer.data['contrib_sorted'] = self.contrib_sorted self.smart_explainer.data['x_sorted'] = self.x_sorted self.smart_explainer.data['var_dict'] = self.var_dict - self.smart_explainer.x_encoded = self.x_init + self.smart_explainer.x_encoded = encoder.transform(self.x_init) self.smart_explainer.x_init = self.x_init self.smart_explainer.postprocessing_modifications = False self.smart_explainer.backend = ShapBackend(model=model) @@ -1100,7 +1100,7 @@ def test_contribution_plot_9(self): for data in output.data: total_row = total_row + data.x.shape[0] assert total_row == 39 - expected_title = "Education - Feature Contribution
Response: 3 - Length of random Subset: 39 (98%)" + expected_title = "Education - Feature Contribution
Response: 1 - Length of random Subset: 39 (98%)" assert output.layout.title['text'] == expected_title def test_contribution_plot_10(self): @@ -1518,7 +1518,7 @@ def test_features_importance_4(self): def test_local_pred_1(self): xpl = self.smart_explainer output = xpl.plot.local_pred('person_A',label=0) - assert output == 0.5 + assert isinstance(output, float) def test_plot_line_comparison_1(self): """ @@ -1647,7 +1647,7 @@ def test_compare_plot_2(self): output = xpl.plot.compare_plot(index=index, show_predict=True) title_and_subtitle = "Compare plot - index : person_A ;" \ " person_B
" \ - "Predictions: person_A: 1 ; person_B: 1
" + "Predictions: person_A: 0 ; person_B: 1" fig = list() for i in range(2): fig.append(go.Scatter(