Skip to content

Commit

Permalink
Adjust test to change in data format
Browse files Browse the repository at this point in the history
  • Loading branch information
lisa-sousa committed Sep 4, 2024
1 parent 30b0539 commit f1c4b49
Showing 1 changed file with 5 additions and 13 deletions.
18 changes: 5 additions & 13 deletions test/test_statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,9 +57,7 @@ def test_calculate_global_feature_importance():
y = pd.Series([0, 0, 0, 0, 0, 0])
cluster_labels = np.array([0, 0, 0, 1, 1, 1])

X_ranked, p_value_of_features = calculate_global_feature_importance(
X, y, cluster_labels, model_type
)
X_ranked, p_value_of_features = calculate_global_feature_importance(X, y, cluster_labels, model_type)

X_ranked.drop("cluster", axis=1, inplace=True)
assert list(X_ranked.columns) == [
Expand Down Expand Up @@ -90,19 +88,13 @@ def test_calculate_local_feature_importance():
y = pd.Series([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1])
cluster_labels = np.array([0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2])

X_ranked, p_value_of_features = calculate_global_feature_importance(
X, y, cluster_labels, model_type
)
X_ranked, p_value_of_features = calculate_global_feature_importance(X, y, cluster_labels, model_type)
for column in X.columns:
if p_value_of_features[column] > thr_pvalue:
if p_value_of_features.loc["p_value", column] > thr_pvalue:
X.drop(column, axis=1, inplace=True)

p_value_of_features_per_cluster = calculate_local_feature_importance(
X_ranked, bootstraps
)
p_value_of_features_per_cluster = calculate_local_feature_importance(X_ranked, bootstraps)
importance = 1 - p_value_of_features_per_cluster
result = importance.transpose().median()

assert (
sum(result > 0.9) == 2
), "error: wrong number of features with highest feature importance"
assert sum(result > 0.9) == 2, "error: wrong number of features with highest feature importance"

0 comments on commit f1c4b49

Please sign in to comment.