Commit

Merge pull request #22 from HelmholtzAI-Consultants-Munich/dev
Dev
lisa-sousa authored Sep 10, 2024
2 parents 469fd7e + 974bf53 commit 77ed171
Showing 11 changed files with 522 additions and 652 deletions.
16 changes: 16 additions & 0 deletions fgclustering/_version.py
@@ -0,0 +1,16 @@
# file generated by setuptools_scm
# don't change, don't track in version control
TYPE_CHECKING = False
if TYPE_CHECKING:
from typing import Tuple, Union
VERSION_TUPLE = Tuple[Union[int, str], ...]
else:
VERSION_TUPLE = object

version: str
__version__: str
__version_tuple__: VERSION_TUPLE
version_tuple: VERSION_TUPLE

__version__ = version = '1.0.4.dev21+gbe0cbb8.d20240902'
__version_tuple__ = version_tuple = (1, 0, 4, 'dev21', 'gbe0cbb8.d20240902')
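For reference, a minimal sketch of how this generated module can be read at runtime (an illustration, assuming the package is installed; the attribute names are exactly those defined above):

    from fgclustering._version import __version__, __version_tuple__

    print(__version__)        # '1.0.4.dev21+gbe0cbb8.d20240902'
    print(__version_tuple__)  # (1, 0, 4, 'dev21', 'gbe0cbb8.d20240902')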
109 changes: 56 additions & 53 deletions fgclustering/forest_guided_clustering.py
@@ -2,12 +2,16 @@
# imports
############################################

import numpy as np
import pandas as pd
import kmedoids
import fgclustering.utils as utils
import fgclustering.optimizer as optimizer
import fgclustering.plotting as plotting
import fgclustering.statistics as stats
from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor

from typing import Union

import warnings

@@ -40,7 +44,13 @@ class FgClustering:
or `sklearn.ensemble.RandomForestClassifier`.
"""

-    def __init__(self, model, data, target_column, random_state=42):
+    def __init__(
+        self,
+        model: Union[RandomForestClassifier, RandomForestRegressor],
+        data: pd.DataFrame,
+        target_column: Union[str, np.ndarray, pd.Series],
+        random_state: int = 42,
+    ):
self.random_state = random_state

# check if random forest is regressor or classifier
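To illustrate the newly typed constructor, a minimal usage sketch (not part of this diff; the iris data, the model settings, and the package-level FgClustering export are assumptions):

    import pandas as pd
    from sklearn.datasets import load_iris
    from sklearn.ensemble import RandomForestClassifier
    from fgclustering import FgClustering

    # toy data: feature DataFrame plus a "target" column referenced by target_column
    X, y = load_iris(return_X_y=True, as_frame=True)
    data = X.copy()
    data["target"] = y

    model = RandomForestClassifier(random_state=42).fit(X, y)
    fgc = FgClustering(model=model, data=data, target_column="target")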
@@ -72,16 +82,16 @@ def __init__(self, model, data, target_column, random_state=42):

def run(
self,
-        number_of_clusters=None,
-        max_K=8,
-        method_clustering="pam",
-        init_clustering="random",
-        max_iter_clustering=100,
-        discart_value_JI=0.6,
-        bootstraps_JI=100,
-        bootstraps_p_value=100,
-        n_jobs=1,
-        verbose=1,
+        number_of_clusters: int = None,
+        max_K: int = 8,
+        method_clustering: str = "pam",
+        init_clustering: str = "random",
+        max_iter_clustering: int = 100,
+        discart_value_JI: float = 0.6,
+        bootstraps_JI: int = 100,
+        bootstraps_p_value: int = 100,
+        n_jobs: int = 1,
+        verbose: int = 1,
):
"""Runs the forest-guided clustering model. The optimal number of clusters for a k-medoids clustering is computed,
based on the distance matrix computed from the Random Forest proximity matrix.
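Continuing the sketch above, run() can then be invoked with the typed defaults shown in the new signature (again an illustration; the argument values simply restate the defaults):

    fgc.run(max_K=8, method_clustering="pam", discart_value_JI=0.6, bootstraps_JI=100, n_jobs=1)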
@@ -228,63 +238,56 @@ def calculate_statistics(self, data, target_column, bootstraps_p_value=100):
data_clustering_ranked=self.data_clustering_ranked, bootstraps_p_value=bootstraps_p_value
)

-    def plot_global_feature_importance(self, thr_pvalue=1, top_n=None, save=None):
-        """
-        Plots global feature importance based on p-values. The p-values are computed using ANOVA (for continuous variables)
-        or Chi-Square (for categorical variables) tests. Feature importance is defined as 1 minus the p-value.
-
-        :param thr_pvalue: Threshold p-value for filtering features. Features with p-values below this threshold are considered
-            significant for plotting. Defaults to 1 (no filtering).
-        :type thr_pvalue: float, optional
-        :param top_n: Number of top features to display in the plot. If None, all significant features are displayed. Defaults
-            to None.
-        :type top_n: int, optional
-        :param save: Filename to save the plot. If None, the plot is not saved. Defaults to None.
-        :type save: str, optional
-        """
-        # drop insignificant features
-        selected_features = self.p_value_of_features_ranked.loc["p_value"] < thr_pvalue
-        selected_features = self.p_value_of_features_ranked.columns[selected_features].tolist()
-
-        # select top n features for plotting
-        if top_n:
-            selected_features = selected_features[:top_n]
-
-        plotting._plot_global_feature_importance(
-            self.p_value_of_features_ranked[selected_features], top_n, save
-        )
-
-    def plot_local_feature_importance(self, thr_pvalue=1, top_n=None, num_cols=4, save=None):
-        """
-        Plot local feature importance to display the importance of each feature within each cluster.
-        Importance is measured by the variance and impurity of the feature within the cluster;
-        a higher feature importance indicates lower variance/impurity.
-
-        :param thr_pvalue: P-value threshold for filtering features. Only features with p-values below this threshold
-            are considered significant and plotted. Defaults to 1 (no filtering).
-        :type thr_pvalue: float, optional
-        :param top_n: Number of top features to display in the plot. If None, all significant features are included.
-            Defaults to None.
-        :type top_n: int, optional
-        :param num_cols: Number of plots per row in the output figure. Defaults to 4.
-        :type num_cols: int, optional
-        :param save: Filename to save the plot. If None, the plot will not be saved. Defaults to None.
-        :type save: str, optional
-        """
-        # drop insignificant features
-        selected_features = self.p_value_of_features_ranked.loc["p_value"] < thr_pvalue
-        selected_features = self.p_value_of_features_ranked.columns[selected_features].tolist()
-
-        # select top n features for plotting
-        if top_n:
-            selected_features = selected_features[:top_n]
-
-        plotting._plot_local_feature_importance(
-            self.p_value_of_features_per_cluster.loc[selected_features], thr_pvalue, top_n, num_cols, save
-        )
+    def plot_feature_importance(
+        self, thr_pvalue: float = 1, top_n: int = None, num_cols: int = 4, save: str = None
+    ):
+        """
+        Plot feature importance based on p-values for global and local feature importance.
+        For the global feature importance, p-values are computed using ANOVA (for continuous variables)
+        or Chi-Square (for categorical variables) tests. For the local feature importance, p-values are measured by the
+        variance and impurity of the feature within the cluster, i.e. a smaller p-value indicates lower variance/impurity.
+        Feature importance is defined as a log transformation of the p-value with a small offset:
+
+        $transformed\_value = -\log_{10}(p\_value + \epsilon) / -\log_{10}(\epsilon)$
+
+        where $\epsilon$ is a small positive constant (1e-50) that avoids issues with log10(0) without introducing
+        significant distortion, because $\epsilon$ is very small.
+
+        Displays both global and local importance for the top n selected features.
+
+        :param thr_pvalue: P-value threshold for display. Only features with p-values below this threshold
+            are considered significant. Defaults to 1 (no filtering).
+        :type thr_pvalue: float, optional
+        :param top_n: Number of top features to display in the plot. If None, all features are included.
+            Defaults to None.
+        :type top_n: int, optional
+        :param num_cols: Number of plots per row in the output figure. Defaults to 4.
+        :type num_cols: int, optional
+        :param save: Filename to save the plot. If None, the plot will not be saved. Defaults to None.
+        :type save: str, optional
+        """
+        # select top n features for plotting
+        selected_features = self.p_value_of_features_ranked.columns.tolist()
+        if top_n:
+            selected_features = selected_features[:top_n]
+
+        plotting._plot_feature_importance(
+            self.p_value_of_features_ranked[selected_features],
+            self.p_value_of_features_per_cluster.loc[selected_features],
+            thr_pvalue,
+            top_n,
+            num_cols,
+            save,
+        )
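The merged method can then be called on the fitted object (a sketch continuing the earlier example; the argument values are arbitrary):

    fgc.plot_feature_importance(thr_pvalue=0.05, top_n=10, num_cols=4, save=None)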

def plot_decision_paths(
-        self, distributions=True, heatmap=True, thr_pvalue=1, top_n=None, num_cols=6, save=None
+        self,
+        distributions: bool = True,
+        heatmap: bool = True,
+        thr_pvalue: float = 1,
+        top_n: int = None,
+        num_cols: int = 6,
+        save: str = None,
):
"""
Plot decision paths of the Random Forest model. This function generates visualizations
@@ -322,12 +325,12 @@ def plot_decision_paths(

selected_features = ["cluster", "target"] + selected_features

-        if heatmap:
-            plotting._plot_heatmap(
-                self.data_clustering_ranked[selected_features], thr_pvalue, top_n, self.model_type, save
-            )
-
         if distributions:
             plotting._plot_distributions(
                 self.data_clustering_ranked[selected_features], thr_pvalue, top_n, num_cols, save
             )
+
+        if heatmap:
+            plotting._plot_heatmap(
+                self.data_clustering_ranked[selected_features], thr_pvalue, top_n, self.model_type, save
+            )
155 changes: 89 additions & 66 deletions fgclustering/plotting.py
@@ -17,87 +17,108 @@
############################################


-def _plot_global_feature_importance(p_value_of_features_ranked, top_n, save):
-    """Plot global feature importance based on p-values given as input.
-
-    :param p_value_of_features: dictionary where keys are names of features and values are p-values of these features
-    :type p_value_of_features: dict
-    :param save: Filename to save plot.
-    :type save: str
-    """
-    # Transform the DataFrame and calculate feature importance
-    importance = pd.melt(p_value_of_features_ranked, var_name="Feature", value_name="p_value")
-    importance["Importance"] = 1 - importance["p_value"]
-    importance.sort_values(by="Importance", ascending=False, inplace=True)
-
-    # Determine figure size dynamically based on the number of features
-    n_features = len(importance)
-    figure_height = max(6.5, int(np.ceil(5 * n_features / 25)))
-
-    # Plotting
-    plt.figure(figsize=(6.5, figure_height))
-    sns.set_theme(style="whitegrid")
-    sns.barplot(data=importance, x="Importance", y="Feature", color="#3470a3")
-
-    plt.title(f"Global Feature Importance for {'top ' + str(top_n) if top_n else 'all'} features")
-    plt.tight_layout()
-
-    # Save plot if a filename is provided
-    if save:
-        plt.savefig(f"{save}_global_feature_importance.png", bbox_inches="tight", dpi=300)
-
-    plt.show()
-
-
-def _plot_local_feature_importance(p_value_of_features_per_cluster, thr_pvalue, top_n, num_cols, save):
-    """Plot local feature importance to show the importance of each feature for each cluster.
-
-    :param p_value_of_features_per_cluster: p-value matrix of all features per cluster.
-    :type p_value_of_features_per_cluster: pandas.DataFrame
-    :param thr_pvalue: P-value threshold used for feature filtering
-    :type thr_pvalue: float, optional
-    :param num_cols: Number of plots in one row.
-    :type num_cols: int
-    :param save: Filename to save plot.
-    :type save: str
-    """
-    importance = 1 - p_value_of_features_per_cluster
-
-    # Reshape and sort the data
-    X_barplot = (
-        importance.melt(ignore_index=False, var_name="Cluster", value_name="Importance")
-        .rename_axis("Feature")
-        .reset_index(level=0, inplace=False)
-        .sort_values("Importance", ascending=False)
-    )
-
-    # Determine figure size based on the number of features and clusters
-    n_features = len(X_barplot["Feature"].unique())
-    figure_height = max(6.5, int(np.ceil(5 * n_features / 25)))
-    num_cols = min(num_cols, len(importance.columns))
-
-    # Set up the Seaborn theme and create the FacetGrid
-    sns.set_theme(style="whitegrid")
-    g = sns.FacetGrid(X_barplot, col="Cluster", sharey=False, col_wrap=num_cols, height=figure_height)
-    g.map(sns.barplot, "Importance", "Feature", order=None, color="#3470a3")
-
-    # Label the axes and set the plot titles
-    g.set_titles(col_template="Cluster {col_name}")
-    plt.suptitle(
-        f"Local Feature Importance - Showing {'top ' + str(top_n) if top_n else 'all'} features with p-value < {thr_pvalue}",
-        fontsize=14,
-    )
-    plt.tight_layout(rect=[0, 0, 1, 0.95])
-
-    # Save the plot if a save path is provided
-    if save:
-        plt.savefig(f"{save}_local_feature_importance.png", bbox_inches="tight", dpi=300)
-
-    plt.show()
+def log_transform(p_values: list, epsilon: float = 1e-50):
+    """
+    Apply a log transformation to p-values to enhance numerical stability and highlight differences.
+    Adds a small constant `epsilon` to avoid taking the log of zero and normalizes by dividing by
+    the log of `epsilon`.
+
+    :param p_values: List of p-values to be transformed.
+    :type p_values: list
+    :param epsilon: Small constant added to p-values to avoid log of zero. Defaults to 1e-50.
+    :type epsilon: float, optional
+    :return: Transformed p-values after log transformation.
+    :rtype: numpy.ndarray
+    """
+    # add a small constant epsilon
+    p_values = np.clip(p_values, epsilon, 1)
+    return -np.log(p_values) / -np.log(epsilon)
+
+
+def _plot_feature_importance(
+    p_value_of_features_ranked: pd.DataFrame,
+    p_value_of_features_per_cluster: pd.DataFrame,
+    thr_pvalue: float,
+    top_n: int,
+    num_cols: int,
+    save: str,
+):
+    """
+    Generate and display a plot showing the importance of features based on p-values.
+    The plot includes both global feature importance and local feature importance for each cluster.
+    Global importance is based on all clusters combined, while local importance is specific to each cluster.
+
+    :param p_value_of_features_ranked: DataFrame containing p-values of features, ranked by p-value.
+    :type p_value_of_features_ranked: pandas.DataFrame
+    :param p_value_of_features_per_cluster: DataFrame containing p-values of features for each cluster.
+    :type p_value_of_features_per_cluster: pandas.DataFrame
+    :param thr_pvalue: P-value threshold for display. Only features with p-values below this threshold
+        are considered significant. Defaults to 1 (no filtering).
+    :type thr_pvalue: float, optional
+    :param top_n: Number of top features to display in the plot. If None, all features are included.
+        Defaults to None.
+    :type top_n: int, optional
+    :param num_cols: Number of plots per row in the output figure. Defaults to 4.
+    :type num_cols: int, optional
+    :param save: Filename to save the plot. If None, the plot will not be saved. Defaults to None.
+    :type save: str, optional
+    """
+    # Determine figure size dynamically based on the number of features
+    num_features = len(p_value_of_features_ranked.columns)
+    figsize_width = 6.5
+    figsize_height = max(figsize_width, int(np.ceil(5 * num_features / 25)))
+
+    num_subplots = 1 + p_value_of_features_per_cluster.shape[1]
+    num_cols = min(num_cols, num_subplots)
+    num_rows = int(np.ceil(num_subplots / num_cols))
+
+    plt.figure(figsize=(num_cols * figsize_width, num_rows * figsize_height))
+    plt.subplots_adjust(top=0.95, hspace=0.8, wspace=0.8)
+    plt.suptitle(
+        f"Feature Importance - Showing {'top ' + str(top_n) if top_n else 'all'} features",
+        fontsize=14,
+    )
+    sns.set_theme(style="whitegrid")
+
+    # Plot global feature importance
+    importance_global = pd.DataFrame(
+        {
+            "Feature": p_value_of_features_ranked.columns,
+            "Importance": log_transform(p_value_of_features_ranked.loc["p_value"].to_list()),
+        }
+    ).sort_values(by="Importance", ascending=False)
+
+    ax = plt.subplot(num_rows, num_cols, 1)
+    sns.barplot(data=importance_global, x="Importance", y="Feature", color="#3470a3")
+    ax.axvline(x=log_transform(thr_pvalue), color="red", linestyle="--", label=f"thr p-value = {thr_pvalue}")
+    ax.set_title(f"Cluster all")
+    ax.legend(bbox_to_anchor=(1, 1), loc=2)
+
+    # Plot local feature importance
+    for n, cluster in enumerate(p_value_of_features_per_cluster.columns):
+        importance_local = pd.DataFrame(
+            {
+                "Feature": p_value_of_features_per_cluster.index,
+                "Importance": log_transform(p_value_of_features_per_cluster[cluster].to_list()),
+            }
+        ).sort_values(by="Importance", ascending=False)
+        ax = plt.subplot(num_rows, num_cols, n + 2)
+        sns.barplot(data=importance_local, x="Importance", y="Feature", color="#3470a3")
+        ax.axvline(x=log_transform(thr_pvalue), color="red", linestyle="--")
+        ax.set_title(f"Cluster {cluster}")
+        # ax.legend(bbox_to_anchor=(1, 1), loc=2)
+
+    plt.tight_layout(rect=[0, 0, 1, 0.95])
+
+    if save is not None:
+        plt.savefig(f"{save}_feature_importance.png", bbox_inches="tight", dpi=300)
+    plt.show()
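Since log_transform defines the new importance scale, a small numeric check may help (computed from the function body above; the ratio of natural logs equals the ratio of base-10 logs, so this matches the docstring's log10 formulation):

    import numpy as np

    def log_transform(p_values, epsilon=1e-50):
        p_values = np.clip(p_values, epsilon, 1)
        return -np.log(p_values) / -np.log(epsilon)

    print(log_transform([1.0, 0.05, 1e-10, 0.0]))
    # -> approximately [0.    0.026 0.2   1.   ]
    #    p = 1.0   -> 0.0 (no importance)
    #    p = 0.05  -> ~0.026
    #    p = 1e-10 -> 10/50 = 0.2
    #    p = 0.0   -> clipped to epsilon, maximum importance of 1.0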


-def _plot_heatmap(data_clustering_ranked, thr_pvalue, top_n, model_type, save):
+def _plot_heatmap(
+    data_clustering_ranked: pd.DataFrame, thr_pvalue: float, top_n: int, model_type: str, save: str
+):
"""Plot feature heatmap sorted by clusters, where features are filtered and ranked
    with statistical tests (ANOVA for continuous features, chi square for categorical features).
@@ -209,7 +230,9 @@ def _plot_heatmap(data_clustering_ranked, thr_pvalue, top_n, model_type, save):
plt.show()


-def _plot_distributions(data_clustering_ranked, thr_pvalue, top_n, num_cols, save):
+def _plot_distributions(
+    data_clustering_ranked: pd.DataFrame, thr_pvalue: float, top_n: int, num_cols: int, save: str
+):
"""Plot feature boxplots (for continuous features) or barplots (for categorical features) divided by clusters,
where features are filtered and ranked by p-value of a statistical test (ANOVA for continuous features,
chi square for categorical features).
1 change: 0 additions & 1 deletion fgclustering/statistics.py
@@ -229,7 +229,6 @@ def calculate_global_feature_importance(X, y, cluster_labels, model_type):
f"Feature {feature} has dytpye {data_feature.dtype} but has to be of type category or numeric!"
)

-    print(p_value_of_features)
# Convert p-value dictionary to a DataFrame and sort by p-value
p_value_of_features_ranked = (
pd.DataFrame.from_dict(p_value_of_features, orient="index", columns=["p_value"])
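The ranking chain is truncated by the diff view here; a minimal sketch of the visible pattern (the toy dictionary and the sort/transpose tail are assumptions, inferred from how p_value_of_features_ranked is accessed via .loc["p_value"] elsewhere in this PR):

    import pandas as pd

    p_value_of_features = {"feature_a": 0.03, "feature_b": 0.40, "feature_c": 0.001}

    p_value_of_features_ranked = (
        pd.DataFrame.from_dict(p_value_of_features, orient="index", columns=["p_value"])
        .sort_values(by="p_value")   # most significant features first
        .transpose()                 # one row "p_value", features as columns
    )
    print(p_value_of_features_ranked.columns.tolist())  # ['feature_c', 'feature_a', 'feature_b']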