Skip to content

Commit

Permalink
Merge pull request MAIF#522 from guerinclement/feature/lint
Browse files Browse the repository at this point in the history
Feature/lint
  • Loading branch information
guillaume-vignal authored Feb 8, 2024
2 parents 1f85eee + 98d7285 commit 17b92bd
Show file tree
Hide file tree
Showing 22 changed files with 89 additions and 76 deletions.
14 changes: 7 additions & 7 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -40,13 +40,13 @@ repos:
# - id: mypy
# args: [--ignore-missing-imports, --disallow-untyped-defs, --show-error-codes, --no-site-packages]
# files: src
# - repo: https://github.com/PyCQA/flake8
# rev: 6.0.0
# hooks:
# - id: flake8
# exclude: ^tests/
# args: ['--ignore=E501,D2,D3,D4,D104,D100,D106,D107,W503,D105,E203']
# additional_dependencies: [ flake8-docstrings, "flake8-bugbear==22.8.23" ]
- repo: https://github.com/PyCQA/flake8
rev: 6.0.0
hooks:
- id: flake8
exclude: ^tests/
args: ['--ignore=E501,D2,D3,D4,D104,D100,D106,D107,W503,D105,E203', '--per-file-ignores=__init__.py:F401']
additional_dependencies: [ flake8-docstrings, "flake8-bugbear==22.8.23" ]
- repo: https://github.com/pre-commit/mirrors-isort
rev: v5.4.2
hooks:
Expand Down
5 changes: 3 additions & 2 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

sys.path.insert(0, "..")

import shapash
import shapash # noqa: E402

# -- Project information -----------------------------------------------------

Expand Down Expand Up @@ -106,14 +106,15 @@
todo_include_todos = True

# -- Additional html pages -------------------------------------------------
import subprocess
import subprocess # noqa: E402

# Generates the report example in the documentation
subprocess.call(["python", "../tutorial/generate_report/shapash_report_example.py"])
html_extra_path = ["../tutorial/report/output/report.html"]


def setup_tutorials():
"""Set up the tutorials"""
import pathlib
import shutil

Expand Down
14 changes: 14 additions & 0 deletions shapash/backend/base_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,20 @@ def __init__(self, model: Any, preprocessing: Optional[Any] = None):

@abstractmethod
def run_explainer(self, x: pd.DataFrame) -> dict:
"""
Computes local contributions.
Must be implemented by a child class
Parameters
----------
x : pd.DataFrame
The observations dataframe used by the model
Returns
-------
explain_data : dict
dict containing local contributions
"""
raise NotImplementedError(
f"`{self.__class__.__name__}` is a subclass of BaseBackend and "
f"must implement the `_run_explainer` method"
Expand Down
4 changes: 2 additions & 2 deletions shapash/backend/lime_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,14 @@
except ImportError:
is_lime_available = False

from typing import Any, List, Optional, Union

import pandas as pd

from shapash.backend.base_backend import BaseBackend


class LimeBackend(BaseBackend):
"""The Lime Backend"""

column_aggregation = "sum"
name = "lime"
support_groups = False
Expand Down
2 changes: 2 additions & 0 deletions shapash/backend/shap_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@


class ShapBackend(BaseBackend):
"""The Shap Backend"""

# When grouping features contributions together, Shap uses the sum of the contributions
# of the features that belong to the group
column_aggregation = "sum"
Expand Down
3 changes: 2 additions & 1 deletion shapash/explainer/consistency.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,12 @@
from plotly.subplots import make_subplots
from sklearn.manifold import MDS

from shapash import SmartExplainer
from shapash.style.style_utils import colors_loading, define_style, select_palette


class Consistency:
"""Consistency class"""

def __init__(self):
self._palette_name = list(colors_loading().keys())[0]
self._style_dict = define_style(select_palette(colors_loading(), self._palette_name))
Expand Down
3 changes: 3 additions & 0 deletions shapash/explainer/smart_explainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -390,6 +390,9 @@ def _compile_additional_data(self, additional_data):
return additional_data

def define_style(self, palette_name=None, colors_dict=None):
"""
Set the color set to use in plots.
"""
if palette_name is None and colors_dict is None:
raise ValueError("At least one of palette_name or colors_dict parameters must be defined")
new_palette_name = palette_name or self.palette_name
Expand Down
15 changes: 2 additions & 13 deletions shapash/explainer/smart_plotter.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,9 @@
from plotly import graph_objs as go
from plotly.offline import plot
from plotly.subplots import make_subplots
from scipy.optimize import fsolve

from shapash.manipulation.select_lines import select_lines
from shapash.manipulation.summarize import compute_corr, compute_features_import, project_feature_values_1d
from shapash.manipulation.summarize import compute_corr, project_feature_values_1d
from shapash.style.style_utils import colors_loading, define_style, select_palette
from shapash.utils.utils import (
add_line_break,
Expand Down Expand Up @@ -2377,17 +2376,7 @@ def cluster_corr(corr, degree, inplace=False):
if facet_col:
features_to_hide += [facet_col]

# We use phik by default as it is a convenient method for numeric and categorical data
if how == "phik":
try:
from phik import phik_matrix

compute_method = "phik"
except (ImportError, ModuleNotFoundError):
warnings.warn('Cannot compute phik correlations. Install phik using "pip install phik".', UserWarning)
compute_method = "pearson"
else:
compute_method = how
compute_method = how

hovertemplate = "<b>%{text}<br />Correlation: %{z}</b><extra></extra>"

Expand Down
12 changes: 7 additions & 5 deletions shapash/explainer/smart_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
from shapash.utils.check import (
check_consistency_model_features,
check_consistency_model_label,
check_contribution_object,
check_features_name,
check_label_dict,
check_mask_params,
Expand Down Expand Up @@ -115,13 +114,12 @@ def __init__(
preprocessing=None,
postprocessing=None,
features_groups=None,
mask_params={"features_to_hide": None, "threshold": None, "positive": None, "max_contrib": None},
mask_params=None,
):

params_dict = [features_dict, features_types, label_dict, columns_dict, postprocessing]

for params in params_dict:
if params is not None and isinstance(params, dict) == False:
if (params is not None) and (not isinstance(params, dict)):
raise ValueError(
"""
{} must be a dict.
Expand All @@ -140,7 +138,11 @@ def __init__(
self.label_dict = label_dict
self.check_label_dict()
self.columns_dict = columns_dict
self.mask_params = mask_params
self.mask_params = (
mask_params
if mask_params is not None
else {"features_to_hide": None, "threshold": None, "positive": None, "max_contrib": None}
)
self.check_mask_params()
self.postprocessing = postprocessing
self.features_groups = features_groups
Expand Down
17 changes: 13 additions & 4 deletions shapash/manipulation/summarize.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""
Summarize Module
"""

import warnings

import numpy as np
Expand Down Expand Up @@ -31,9 +32,9 @@ def summarize_el(dataframe, mask, prefix):
Result of the summarize step
"""
matrix = dataframe.where(mask.to_numpy()).values.tolist()
summarized_matrix = [[x for x in l if str(x) != "nan"] for l in matrix]
summarized_matrix = [[x for x in ll if str(x) != "nan"] for ll in matrix]
# Padding to create pd.DataFrame
max_length = max(len(l) for l in summarized_matrix)
max_length = max(len(ll) for ll in summarized_matrix)
for elem in summarized_matrix:
elem.extend([np.nan] * (max_length - len(elem)))
# Create DataFrame
Expand Down Expand Up @@ -203,9 +204,17 @@ def compute_corr(df, compute_method):
# Remove user warnings (when not enough values to compute correlation).
warnings.filterwarnings("ignore")
if compute_method == "phik":
from phik import phik_matrix
try:
from phik import phik_matrix

return phik_matrix(df, verbose=False)
except ImportError:
warnings.warn(
'Cannot compute phik correlations. Falling back to pearson. Install phik using "pip install phik".',
UserWarning,
)
return df.corr()

return phik_matrix(df, verbose=False)
elif compute_method == "pearson":
return df.corr()
else:
Expand Down
2 changes: 1 addition & 1 deletion shapash/report/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ def check_report_requirements():
pkg = req.split("=")[0]
try:
importlib.import_module(pkg.lower())
except (ModuleNotFoundError, ImportError):
except ImportError:
raise ModuleNotFoundError(
f"The following package is necessary to generate the Shapash Report : {pkg}. "
f"Try 'pip install shapash[report]' to install all required packages."
Expand Down
4 changes: 2 additions & 2 deletions shapash/report/project_report.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import plotly

from shapash import SmartExplainer
from shapash.report.common import VarType, compute_col_types, display_value, get_callable, series_dtype
from shapash.report.common import compute_col_types, display_value, get_callable, series_dtype
from shapash.report.data_analysis import perform_global_dataframe_analysis, perform_univariate_dataframe_analysis
from shapash.report.plots import generate_confusion_matrix_plot, generate_fig_univariate
from shapash.report.visualisation import (
Expand Down Expand Up @@ -218,7 +218,7 @@ def display_model_analysis(self):

print_md(f"**Library :** {self.explainer.model.__class__.__module__}")

for name, module in sorted(sys.modules.items()):
for _, module in sorted(sys.modules.items()):
if (
hasattr(module, "__version__")
and self.explainer.model.__class__.__module__.split(".")[0] == module.__name__
Expand Down
4 changes: 4 additions & 0 deletions shapash/report/visualisation.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ def print_html(text: str):


def print_css_style():
"""Print the CSS"""
print_html(
"""
<style type="text/css">
Expand Down Expand Up @@ -89,6 +90,7 @@ def print_css_style():


def print_javascript_misc():
"""Print the JS"""
print_html(
"""
<script>
Expand Down Expand Up @@ -130,6 +132,7 @@ def convert_fig_to_html(fig):


def html_str_df_and_image(df: pd.DataFrame, fig: plt.Figure) -> str:
"""Convert dataframe to HTML display"""
return f"""
<div class="row-fluid" style="margin-top:5px;">
<div class="col-sm-6">{df.to_html(classes="greyGridTable")}</div>
Expand All @@ -139,4 +142,5 @@ def html_str_df_and_image(df: pd.DataFrame, fig: plt.Figure) -> str:


def print_figure(fig):
"""Print a figure as HTML"""
print_html(convert_fig_to_html(fig))
24 changes: 6 additions & 18 deletions shapash/utils/check.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,10 @@
"""
Check Module
"""
import copy

import numpy as np
import pandas as pd

from shapash.utils.category_encoder_backend import (
dummies_category_encoder,
no_dummies_category_encoder,
supported_category_encoder,
)
from shapash.utils.columntransformer_backend import (
columntransformer,
get_feature_names,
get_list_features_names,
no_dummies_sklearn,
supported_sklearn,
)
from shapash.utils.category_encoder_backend import supported_category_encoder
from shapash.utils.columntransformer_backend import columntransformer, get_list_features_names
from shapash.utils.model import extract_features_model
from shapash.utils.model_synoptic import dict_model_feature
from shapash.utils.transform import check_transformers, preprocessing_tolist
Expand Down Expand Up @@ -172,7 +159,7 @@ def check_contribution_object(case, classes, contributions):
List of labels if the model used is for classification problem, None otherwise.
contributions : pandas.DataFrame, np.ndarray or list
"""
if case == "regression" and isinstance(contributions, (np.ndarray, pd.DataFrame)) == False:
if (case == "regression") and (not isinstance(contributions, (np.ndarray, pd.DataFrame))):
raise ValueError(
"""
Type of contributions parameter specified is not compatible with
Expand Down Expand Up @@ -473,7 +460,8 @@ def check_features_name(columns_dict, features_dict, features):


def check_additional_data(x, additional_data):
"""Checks if additional_data is a pandas DataFrame and has the same index as x"""
if not isinstance(additional_data, pd.DataFrame):
raise ValueError(f"additional_data must be a pd.Dataframe.")
raise ValueError("additional_data must be a pd.Dataframe.")
if not additional_data.index.equals(x.index):
raise ValueError(f"x and additional_data should have the same index.")
raise ValueError("x and additional_data should have the same index.")
4 changes: 2 additions & 2 deletions shapash/utils/columntransformer_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,7 @@ def calc_inv_contrib_ct(x_contrib, encoding, agg_columns):
elif str(type(ct_encoding)) == category_encoder_binary:
try:
col_origin = ct_encoding.base_n_encoder.mapping[i_enc].get("mapping").columns.tolist()
except:
except Exception:
col_origin = ct_encoding.mapping[i_enc].get("mapping").columns.tolist()
else:
col_origin = ct_encoding.mapping[i_enc].get("mapping").columns.tolist()
Expand Down Expand Up @@ -466,7 +466,7 @@ def get_col_mapping_ct(encoder, x_encoded):
elif estimator == "passthrough":
try:
features_out = encoder.feature_names_in_[features]
except:
except Exception:
features_out = encoder._feature_names_in[features] # for oldest sklearn version
for f_name in features_out:
dict_col_mapping[f_name] = [x_encoded.columns.to_list()[idx_encoded]]
Expand Down
2 changes: 1 addition & 1 deletion shapash/utils/explanation_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,7 +285,7 @@ def get_min_nb_features(selection, contributions, mode, distance):
output_value = np.sum(contributions[i, :])

score = 0
for j, idx in enumerate(ids):
for j, idx in enumerate(ids): # noqa: B007
# j : number of features needed
# idx : positions of the j top shap values
score += contributions[i, idx]
Expand Down
2 changes: 1 addition & 1 deletion shapash/utils/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import yaml

_is_yaml_available = True
except (ImportError, ModuleNotFoundError):
except ImportError:
_is_yaml_available = False


Expand Down
1 change: 1 addition & 0 deletions shapash/utils/threading.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ def __init__(self, *args, **keywords):
self.__run_backup = None

def start(self):
"""Starts the thread"""
self.__run_backup = self.run
self.run = self.__run
threading.Thread.start(self)
Expand Down
6 changes: 3 additions & 3 deletions shapash/utils/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,8 +206,8 @@ def check_transformers(list_encoding):
# check that encoding don't use ColumnTransformer and Category encoding at the same time
if use_ct and use_ce:
raise Exception(
f"Can't support ColumnTransformer and Category encoding at the same time. "
f"Use Category encoding in ColumnTransformer"
"Can't support ColumnTransformer and Category encoding at the same time. "
"Use Category encoding in ColumnTransformer"
)

# check that Category encoding is apply on different columns
Expand Down Expand Up @@ -263,7 +263,7 @@ def apply_postprocessing(x_init, postprocessing):

elif dict_postprocessing["type"] == "regex":
new_preds[feature_name] = new_preds[feature_name].apply(
lambda x: re.sub(dict_postprocessing["rule"]["in"], dict_postprocessing["rule"]["out"], x)
lambda x, d=dict_postprocessing: re.sub(d["rule"]["in"], d["rule"]["out"], x)
)

elif dict_postprocessing["type"] == "case":
Expand Down
Loading

0 comments on commit 17b92bd

Please sign in to comment.