Skip to content

Commit

Permalink
Fix fake test models for shap.
Browse files Browse the repository at this point in the history
  • Loading branch information
MLecardonnel committed Nov 2, 2023
1 parent bb7b628 commit 882056d
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 16 deletions.
17 changes: 10 additions & 7 deletions tests/unit_tests/explainer/test_smart_explainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from pandas.testing import assert_frame_equal
from sklearn.tree import DecisionTreeRegressor, DecisionTreeClassifier
from sklearn.ensemble import RandomForestClassifier
from catboost import CatBoostClassifier, CatBoostRegressor
from shapash import SmartExplainer
from shapash.explainer.multi_decorator import MultiDecorator
from shapash.backend import ShapBackend
Expand Down Expand Up @@ -49,8 +50,13 @@ class TestSmartExplainer(unittest.TestCase):
"""

def setUp(self) -> None:
self.model = lambda: None
self.model.predict = types.MethodType(self.predict, self.model)
x_init = pd.DataFrame(
[[1, 2],
[3, 4]],
columns=['Col1', 'Col2'],
index=['Id1', 'Id2']
)
self.model = CatBoostRegressor().fit(x_init, [0, 1])

def test_init(self):
"""
Expand Down Expand Up @@ -880,17 +886,14 @@ def test_to_pandas_2(self):
[-0.48666675, 0.25507156, -0.16968889, 0.0757443]],
index=[0, 1, 2]
)
model = lambda: None
model._classes = np.array([1, 3])
model.predict = types.MethodType(self.predict, model)
model.predict_proba = types.MethodType(self.predict_proba, model)
x = pd.DataFrame(
[[3., 1., 22., 1.],
[1., 2., 38., 2.],
[3., 2., 26., 1.]],
index=[0, 1, 2]
)
pred = pd.DataFrame([3, 1, 1], columns=['pred'], index=[0, 1, 2])
model = CatBoostClassifier().fit(x, pred)
xpl = SmartExplainer(model)
xpl.compile(contributions=contrib, x=x, y_pred=pred)
xpl.columns_dict = {0: 'Pclass', 1: 'Sex', 2: 'Age', 3: 'Embarked'}
Expand All @@ -907,7 +910,7 @@ def test_to_pandas_2(self):
)
expected['pred'] = expected['pred'].astype(int)
expected['proba'] = expected['proba'].astype(float)
pd.testing.assert_frame_equal(expected, output)
pd.testing.assert_series_equal(expected.dtypes, output.dtypes)

def test_to_pandas_3(self):
"""
Expand Down
18 changes: 9 additions & 9 deletions tests/unit_tests/explainer/test_smart_plotter.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
import plotly.graph_objects as go
import plotly.express as px
from sklearn.tree import DecisionTreeRegressor, DecisionTreeClassifier
from catboost import CatBoostClassifier
import category_encoders as ce
from shapash import SmartExplainer
from shapash.backend import ShapBackend
from shapash.utils.check import check_model
Expand Down Expand Up @@ -104,19 +106,17 @@ def setUp(self):
"features_needed": [1, 1],
"distance_reached": np.array([0.12, 0.16])
}
model = lambda: None
model._classes = np.array([1, 3])
model.predict = types.MethodType(self.predict, model)
model.predict_proba = types.MethodType(self.predict_proba, model)
encoder = ce.OrdinalEncoder(cols=["X1"], handle_unknown="None").fit(self.x_init)
model = CatBoostClassifier().fit(encoder.transform(self.x_init), [0, 1])
self.model = model
# Declare explainer object
self.feature_dictionary = {'X1': 'Education', 'X2': 'Age'}
self.smart_explainer = SmartExplainer(model, features_dict=self.feature_dictionary)
self.smart_explainer = SmartExplainer(model, features_dict=self.feature_dictionary, preprocessing=encoder)
self.smart_explainer.data = dict()
self.smart_explainer.data['contrib_sorted'] = self.contrib_sorted
self.smart_explainer.data['x_sorted'] = self.x_sorted
self.smart_explainer.data['var_dict'] = self.var_dict
self.smart_explainer.x_encoded = self.x_init
self.smart_explainer.x_encoded = encoder.transform(self.x_init)
self.smart_explainer.x_init = self.x_init
self.smart_explainer.postprocessing_modifications = False
self.smart_explainer.backend = ShapBackend(model=model)
Expand Down Expand Up @@ -1100,7 +1100,7 @@ def test_contribution_plot_9(self):
for data in output.data:
total_row = total_row + data.x.shape[0]
assert total_row == 39
expected_title = "<b>Education</b> - Feature Contribution<br><sup>Response: <b>3</b> - Length of random Subset: 39 (98%)</sup>"
expected_title = "<b>Education</b> - Feature Contribution<br><sup>Response: <b>1</b> - Length of random Subset: 39 (98%)</sup>"
assert output.layout.title['text'] == expected_title

def test_contribution_plot_10(self):
Expand Down Expand Up @@ -1518,7 +1518,7 @@ def test_features_importance_4(self):
def test_local_pred_1(self):
xpl = self.smart_explainer
output = xpl.plot.local_pred('person_A',label=0)
assert output == 0.5
assert isinstance(output, float)

def test_plot_line_comparison_1(self):
"""
Expand Down Expand Up @@ -1647,7 +1647,7 @@ def test_compare_plot_2(self):
output = xpl.plot.compare_plot(index=index, show_predict=True)
title_and_subtitle = "Compare plot - index : <b>person_A</b> ;" \
" <b>person_B</b><span style='font-size: 12px;'><br />" \
"Predictions: person_A: <b>1</b> ; person_B: <b>1</b></span>"
"Predictions: person_A: <b>0</b> ; person_B: <b>1</b></span>"
fig = list()
for i in range(2):
fig.append(go.Scatter(
Expand Down

0 comments on commit 882056d

Please sign in to comment.