From 51fe3117093c836bba9665f491de14de052925fd Mon Sep 17 00:00:00 2001 From: zoj <44142765+zoj613@users.noreply.github.com> Date: Thu, 4 Feb 2021 21:12:20 +0200 Subject: [PATCH] ENH: Add the SMOTERSB oversampling technique --- README.rst | 5 +- doc/api.rst | 1 + doc/bibtex/refs.bib | 21 +- doc/over_sampling.rst | 3 +- doc/whats_new/v0.7.rst | 3 + imblearn/over_sampling/__init__.py | 2 + imblearn/over_sampling/_smote.py | 246 ++++++++++++++++++ imblearn/over_sampling/tests/test_smotersb.py | 97 +++++++ references.bib | 21 +- 9 files changed, 395 insertions(+), 4 deletions(-) create mode 100644 imblearn/over_sampling/tests/test_smotersb.py diff --git a/README.rst b/README.rst index d86120e2e..62dbcb59d 100644 --- a/README.rst +++ b/README.rst @@ -156,6 +156,7 @@ Below is a list of the methods currently implemented in this module. 6. ADASYN - Adaptive synthetic sampling approach for imbalanced learning [15]_ 7. KMeans-SMOTE [17]_ 8. ROSE - Random OverSampling Examples [19]_ + 9. SMOTERSB - SMOTE + Rough Set Theory lower bounds [20]_ * Over-sampling followed by under-sampling 1. SMOTE + Tomek links [12]_ @@ -213,4 +214,6 @@ References: .. [18] : Seiffert, C., Khoshgoftaar, T. M., Van Hulse, J., & Napolitano, A. "RUSBoost: A hybrid approach to alleviating class imbalance." IEEE Transactions on Systems, Man, and Cybernetics-Part A: Systems and Humans 40.1 (2010): 185-197. -.. [19] : Menardi, G., Torelli, N.: "Training and assessing classification rules with unbalanced data", Data Mining and Knowledge Discovery, 28, (2014): 92–122 \ No newline at end of file +.. [19] : Menardi, G., Torelli, N.: "Training and assessing classification rules with unbalanced data", Data Mining and Knowledge Discovery, 28, (2014): 92–122 + +.. [20] : Ramentol, E., Caballero, Y., Bello, R. et al. SMOTE-RSB*: a hybrid preprocessing approach based on oversampling and undersampling for high imbalanced data-sets using SMOTE and rough sets theory. Knowl Inf Syst 33, 245–265 (2012). https://doi.org/10.1007/s10115-011-0465-6. diff --git a/doc/api.rst b/doc/api.rst index 07ac6413c..af7ff79bd 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -77,6 +77,7 @@ Prototype selection over_sampling.SMOTENC over_sampling.SVMSMOTE over_sampling.ROSE + over_sampling.SMOTERSB .. _combine_ref: diff --git a/doc/bibtex/refs.bib b/doc/bibtex/refs.bib index 87eb3e30c..f4d645204 100644 --- a/doc/bibtex/refs.bib +++ b/doc/bibtex/refs.bib @@ -207,4 +207,23 @@ @article{torelli2014rose issn = {1573-756X}, url = {https://doi.org/10.1007/s10618-012-0295-5}, doi = {10.1007/s10618-012-0295-5} -} \ No newline at end of file +} + +@article{ramentol2012smotersb, + author = {Ramentol, Enislay and Caballero, Yail\'{e} and Bello, Rafael and + Herrera, Francisco}, + title = {SMOTE-RSB*: A Hybrid Preprocessing Approach Based on Oversampling + and Undersampling for High Imbalanced Data-Sets Using SMOTE and + Rough Sets Theory}, + year = {2012}, + publisher = {Springer-Verlag}, + address = {Berlin, Heidelberg}, + volume = {33}, + number = {2}, + issn = {0219-1377}, + url = {https://doi.org/10.1007/s10115-011-0465-6}, + doi = {10.1007/s10115-011-0465-6}, + journal = {Knowl. Inf. Syst.}, + month = {Nov.}, + pages = {245–265}, +} diff --git a/doc/over_sampling.rst b/doc/over_sampling.rst index a154a62dc..9325aa201 100644 --- a/doc/over_sampling.rst +++ b/doc/over_sampling.rst @@ -151,7 +151,8 @@ nearest neighbors class. Those variants are presented in the figure below. The :class:`BorderlineSMOTE` :cite:`han2005borderline`, -:class:`SVMSMOTE` :cite:`nguyen2009borderline`, and +:class:`SVMSMOTE` :cite:`nguyen2009borderline`, +:class:`SMOTERSB` :cite:`ramentol2012smotersb` and :class:`KMeansSMOTE` :cite:`last2017oversampling` offer some variant of the SMOTE algorithm:: diff --git a/doc/whats_new/v0.7.rst b/doc/whats_new/v0.7.rst index 6cb266854..9ed2faebd 100644 --- a/doc/whats_new/v0.7.rst +++ b/doc/whats_new/v0.7.rst @@ -76,6 +76,9 @@ Enhancements dictionary instead of a string. :pr:`770` by :user:`Guillaume Lemaitre `. +- Added the `SMOTERSB` class, implementing SMOTE with Rough Set Theory + :pr:`778` by :user:`Zolisa Bleki ` + Deprecation ........... diff --git a/imblearn/over_sampling/__init__.py b/imblearn/over_sampling/__init__.py index 3be402135..8f94a527e 100644 --- a/imblearn/over_sampling/__init__.py +++ b/imblearn/over_sampling/__init__.py @@ -10,6 +10,7 @@ from ._smote import KMeansSMOTE from ._smote import SVMSMOTE from ._smote import SMOTENC +from ._smote import SMOTERSB from ._rose import ROSE __all__ = [ @@ -20,5 +21,6 @@ "BorderlineSMOTE", "SVMSMOTE", "SMOTENC", + "SMOTERSB", "ROSE" ] diff --git a/imblearn/over_sampling/_smote.py b/imblearn/over_sampling/_smote.py index cfd1c4217..726990b7c 100644 --- a/imblearn/over_sampling/_smote.py +++ b/imblearn/over_sampling/_smote.py @@ -4,10 +4,12 @@ # Fernando Nogueira # Christos Aridas # Dzianis Dudnik +# Zolisa Bleki # License: MIT import math from collections import Counter +from collections.abc import Callable import numpy as np from scipy import sparse @@ -1334,3 +1336,247 @@ def _fit_resample(self, X, y): y_resampled = np.hstack((y_resampled, y_new)) return X_resampled, y_resampled + + +@Substitution( + sampling_strategy=BaseOverSampler._sampling_strategy_docstring, + n_jobs=_n_jobs_docstring, + random_state=_random_state_docstring, +) +class SMOTERSB(BaseSMOTE): + """SMOTE oversampling technique with Rough Set Theory. + + This is an implementation of the algorithm described in [1]_, which + generates samples using vanilla SMOTE and then only retain synthetic + samples that have low similarity with the majority class(es) as determined + by some inseperability relation. + + Read more in the :ref:`User Guide `. + + Parameters + ---------- + {sampling_strategy} + + {random_state} + + k_neighbors : int or object, default=2 + If ``int``, number of nearest neighbours to used to construct synthetic + samples. If object, an estimator that inherits from + :class:`~sklearn.neighbors.base.KNeighborsMixin` that will be used to + find the k_neighbors. + + {n_jobs} + + similarity_threshold : float, default=0.4 + The base threshold set to determine if synthetic samples generated by + SMOTE belong in the lower approximation set. The value must be between + zero and 0.9. If not given, the value defaults to ``0.4``. + + threshold_increment : float, default=0.05 + The value to increment the similiarity_threshold by if no synthetic + samples have a similarity value below the current value of the + threshold. If not given, the value defaults to ``0.05``. + + equivalence_set : list, default=None + A list of feature indices to consider when constructing the + similarity matrix between SMOTE synthetic samples and the majority + class samples. If not given, all features in the fitting data set are + used. + + similarity_func : callback, default=None + A callback used to construct the similarity matrix between two 2d + arrays. The callback must have the signature ``func(X1, X2)``. + If not given, The similarity function defined in [1]_ is used. + + See Also + -------- + ROSE : Random Over-Sampling Examples. + + SMOTE : Over-sample using SMOTE. + + SVMSMOTE : Over-sample using SVM-SMOTE variant. + + BorderlineSMOTE : Over-sample using Borderline-SMOTE variant. + + ADASYN : Over-sample using ADASYN. + + KMeansSMOTE : Over-sample applying a clustering before to oversample using + SMOTE. + + References + ---------- + .. [1] Ramentol, E., Caballero, Y., Bello, R. et al. SMOTE-RSB *: a + hybrid preprocessing approach based on oversampling and undersampling + for high imbalanced data-sets using SMOTE and rough sets theory. Knowl + Inf Syst 33, 245–265 (2012). https://doi.org/10.1007/s10115-011-0465-6. + + Examples + -------- + >>> from collections import Counter + >>> from sklearn.datasets import make_classification + >>> from imblearn.over_sampling import SMOTERSB + >>> X, y = make_classification(n_classes=2, class_sep=2, + ... weights=[0.1, 0.9], n_informative=3, n_redundant=1, flip_y=0, + ... n_features=20, n_clusters_per_class=1, n_samples=1000, random_state=10) + >>> print('Original dataset shape %s' % Counter(y)) + Original dataset shape Counter({{1: 900, 0: 100}}) + >>> sm = SMOTERSB(random_state=42) + >>> X_res, y_res = sm.fit_resample(X, y) + >>> print('Resampled dataset shape %s' % Counter(y_res)) + Resampled dataset shape Counter({{0: 900, 1: 900}}) + """ + + @_deprecate_positional_args + def __init__( + self, + *, + sampling_strategy="auto", + random_state=None, + k_neighbors=2, + n_jobs=None, + similarity_threshold=0.4, + threshold_increment=0.05, + equivalence_set=None, + similarity_func=None, + ): + super().__init__( + sampling_strategy=sampling_strategy, + random_state=random_state, + k_neighbors=k_neighbors, + n_jobs=n_jobs, + ) + self.similarity_threshold = similarity_threshold + self.threshold_increment = threshold_increment + self.equivalence_set = equivalence_set + self.similarity_func = similarity_func + + def _validate_params(self, X): + if not (0.1 <= self.similarity_threshold <= 0.9): + raise ValueError( + "`similarity threshold` must lie in the interval [0, 0.9)" + ) + if self.equivalence_set is None: + self.feature_indices_ = range(X.shape[1]) + elif not isinstance(self.equivalence_set, list): + raise TypeError("`equivalence_set` must be a list of indices") + elif np.any(max(self.equivalence_set) - 1 >= X.shape[1] - 1): + raise ValueError("`An index of `equivalence_set` is out of bounds") + else: + self.feature_indices_ = self.equivalence_set + + if (self.similarity_func is not None and + not isinstance(self.similarity_func, Callable)): + raise TypeError("`similarity_func` must be a callable") + + # VERY slow! cython might be better suited for this function + def _make_similarity_matrix(self, X_s, X_m, maxmin_diff): + """ + Construct the similarity matrix between `X_s` and X_m` arrays. + + The similarity matrix is constructed as described in section 3 of + Ramentol et al. (2012). + + Parameters + ---------- + X_s : 2d array + An array of synthetic samples from a minority class generated by + SMOTE. + X_m : 2d array + An array of original samples from the non-minority class(es). + maxmin_diff : 1d array + An array containing the difference between the minimum and maximum + values of each feature in the fitting dataset. + """ + size = (X_s.shape[0], X_m.shape[0]) + n_features = X_s.shape[1] + matrix = np.empty(size) + for i in range(size[0]): + xs = X_s[i] + for j in range(size[1]): + total = 0 + xm = X_m[j] + diff = xs - xm + for k in range(n_features): + if float(xs[k]).is_integer() and float(xm[k]).is_integer(): + total += 1 if diff[k] == 0 else 0 + else: + total += 1 - abs(diff[k]) / maxmin_diff[k] + matrix[i, j] = total / n_features + return matrix + + def _get_nonminority_class_indices(self, y): + """Get the indices of non-minority class data points. + + Returns the indices in the original data that correspond to the + data points of the non-minority class(es). + + Parameters + ---------- + y : 1d array + Array of assigned the assigned class of each data point. + """ + tot_classes = set(np.unique(y)) + smote_classes = set(self.sampling_strategy_) + nonsmote_classes = tot_classes.symmetric_difference(smote_classes) + indices = np.argwhere(y == list(nonsmote_classes)).ravel() + return indices + + def _fit_resample(self, X, y): + self._validate_estimator() + self._validate_params(X) + + X_resampled = [X.copy()] + y_resampled = [y.copy()] + + maxmin_diff = X.max(axis=0) - X.min(axis=0) + + neg_indices = self._get_nonminority_class_indices(y) + + for class_sample, n_samples in self.sampling_strategy_.items(): + if n_samples == 0: + continue + target_class_indices = np.flatnonzero(y == class_sample) + X_class = _safe_indexing(X, target_class_indices) + + self.nn_k_.fit(X_class) + nns = self.nn_k_.kneighbors(X_class, return_distance=False)[:, 1:] + X_new, y_new = self._make_samples( + X_class, y.dtype, class_sample, X_class, nns, n_samples, 1.0 + ) + + # construct the similairy matrix + if self.similarity_func is not None: + mat = self.similarity_func( + X_new[:, self.feature_indices_], + X[neg_indices][:, self.feature_indices_] + ) + else: + mat = self._make_similarity_matrix( + X_new[:, self.feature_indices_], + X[neg_indices][:, self.feature_indices_], + maxmin_diff=maxmin_diff + ) + # keep only those synthetic samples whose similiary value is + # below ``similarity_val``. + similarity_val = self.similarity_threshold + has_lower_approx = False + while similarity_val <= 0.9 and not has_lower_approx: + mask = np.less(mat, similarity_val).sum(axis=1) + if np.any(mask): + kept_indices = np.argwhere(mask)[:, 0] + X_resampled.append(X_new[kept_indices]) + y_resampled.append(y_new[kept_indices]) + has_lower_approx = True + similarity_val += self.threshold_increment + + if not has_lower_approx: + X_resampled.append(X_new) + y_resampled.append(y_new) + + if sparse.issparse(X): + X_resampled = sparse.vstack(X_resampled, format=X.format) + else: + X_resampled = np.vstack(X_resampled) + y_resampled = np.hstack(y_resampled) + + return X_resampled, y_resampled diff --git a/imblearn/over_sampling/tests/test_smotersb.py b/imblearn/over_sampling/tests/test_smotersb.py new file mode 100644 index 000000000..fe1c8c530 --- /dev/null +++ b/imblearn/over_sampling/tests/test_smotersb.py @@ -0,0 +1,97 @@ +"""Test the module SMOTE-RSB.""" +# Authors: Zolisa Bleki +# License: MIT +import numpy as np +import pytest + +from imblearn.over_sampling import SMOTERSB +from sklearn.metrics.pairwise import cosine_similarity +from sklearn.utils._testing import assert_allclose +from sklearn.utils._testing import assert_array_equal + + +@pytest.fixture +def data(): + X = np.array([[0.11622591, -0.0317206], + [0.77481731, 0.60935141], + [1.25192108, -0.22367336], + [0.53366841, -0.30312976], + [1.52091956, -0.49283504], + [-0.28162401, -2.10400981], + [0.83680821, 1.72827342], + [0.3084254, 0.33299982], + [0.70472253, -0.73309052], + [0.28893132, -0.38761769], + [1.15514042, 0.0129463], + [0.88407872, 0.35454207], + [1.31301027, -0.92648734], + [-1.11515198, -0.93689695], + [-0.18410027, -0.45194484], + [0.9281014, 0.53085498], + [-0.14374509, 0.27370049], + [-0.41635887, -0.38299653], + [0.08711622, 0.93259929], + [1.70580611, -0.11219234]]) + y = np.array([0, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 0, 1, 0]) + return X, y + + +def test_smote_rsb(data): + X, y = data + RND_SEED = 0 + R_TOL = 1e-4 + rsb = SMOTERSB(random_state=RND_SEED) + X_res, y_res = rsb.fit_resample(X, y) + X_gt = np.array( + [ + [0.11622591, -0.0317206], + [0.77481731, 0.60935141], + [1.25192108, -0.22367336], + [0.53366841, -0.30312976], + [1.52091956, -0.49283504], + [-0.28162401, -2.10400981], + [0.83680821, 1.72827342], + [0.3084254, 0.33299982], + [0.70472253, -0.73309052], + [0.28893132, -0.38761769], + [1.15514042, 0.0129463], + [0.88407872, 0.35454207], + [1.31301027, -0.92648734], + [-1.11515198, -0.93689695], + [-0.18410027, -0.45194484], + [0.9281014, 0.53085498], + [-0.14374509, 0.27370049], + [-0.41635887, -0.38299653], + [0.08711622, 0.93259929], + [1.70580611, -0.11219234], + [-0.09533627, -0.17126026], + [1.45849179, -0.17293647], + [0.8379596, -0.26946767], + [0.38584956, -0.20702218], + ] + ) + + y_gt = np.array( + [0, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, + 0, 1, 1, 1, 1, 0, 1, 0, 0, 0, 0, 0] + ) + assert_allclose(X_res, X_gt, rtol=R_TOL) + assert_array_equal(y_res, y_gt) + + with pytest.raises(ValueError): + SMOTERSB(similarity_threshold=1).fit_resample(X, y) + + with pytest.raises(TypeError): + SMOTERSB(equivalence_set=(0, 1)).fit_resample(X, y) + + with pytest.raises(ValueError): + SMOTERSB(equivalence_set=[0, 5]).fit_resample(X, y) + + with pytest.raises(TypeError): + fake_callable = "func()" + SMOTERSB(similarity_func=fake_callable).fit_resample(X, y) + + # different similarity matrix generates different values + X_, y_ = SMOTERSB(similarity_func=cosine_similarity).fit_resample(X, y) + assert y_.shape[0] == y_res.shape[0] + assert not np.allclose(X_, X_res, rtol=R_TOL) diff --git a/references.bib b/references.bib index c4432827a..d4318c19e 100644 --- a/references.bib +++ b/references.bib @@ -198,4 +198,23 @@ @article{torelli2014rose issn = {1573-756X}, url = {https://doi.org/10.1007/s10618-012-0295-5}, doi = {10.1007/s10618-012-0295-5} -} \ No newline at end of file +} + +@article{ramentol2012smotersb, + author = {Ramentol, Enislay and Caballero, Yail\'{e} and Bello, Rafael and + Herrera, Francisco}, + title = {SMOTE-RSB*: A Hybrid Preprocessing Approach Based on Oversampling + and Undersampling for High Imbalanced Data-Sets Using SMOTE and + Rough Sets Theory}, + year = {2012}, + publisher = {Springer-Verlag}, + address = {Berlin, Heidelberg}, + volume = {33}, + number = {2}, + issn = {0219-1377}, + url = {https://doi.org/10.1007/s10115-011-0465-6}, + doi = {10.1007/s10115-011-0465-6}, + journal = {Knowl. Inf. Syst.}, + month = {Nov.}, + pages = {245–265}, +}