Skip to content

Commit

Permalink
Better in-out support for pandas.
Browse files Browse the repository at this point in the history
  • Loading branch information
chkoar committed Feb 3, 2020
1 parent 42cd496 commit 55d8f27
Show file tree
Hide file tree
Showing 6 changed files with 75 additions and 100 deletions.
49 changes: 7 additions & 42 deletions imblearn/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from sklearn.utils.multiclass import check_classification_targets

from .utils import check_sampling_strategy, check_target_type
from .utils._validation import _OutputReconstructor


class SamplerMixin(BaseEstimator, metaclass=ABCMeta):
Expand Down Expand Up @@ -80,21 +81,10 @@ def fit_resample(self, X, y):

output = self._fit_resample(X, y)

if self._X_columns is not None or self._y_name is not None:
import pandas as pd

if self._X_columns is not None:
X_ = pd.DataFrame(output[0], columns=self._X_columns)
X_ = X_.astype(self._X_dtypes)
else:
X_ = output[0]

y_ = (label_binarize(output[1], np.unique(y))
if binarize_y else output[1])

if self._y_name is not None:
y_ = pd.Series(y_, dtype=self._y_dtype, name=self._y_name)

X_, y_ = self._reconstructor.reconstruct(output[0], y_)
return (X_, y_) if len(output) == 2 else (X_, y_, output[2])

# define an alias for back-compatibility
Expand Down Expand Up @@ -137,22 +127,7 @@ def __init__(self, sampling_strategy="auto"):
self.sampling_strategy = sampling_strategy

def _check_X_y(self, X, y, accept_sparse=None):
if hasattr(X, "loc"):
# store information to build dataframe
self._X_columns = X.columns
self._X_dtypes = X.dtypes
else:
self._X_columns = None
self._X_dtypes = None

if hasattr(y, "loc"):
# store information to build a series
self._y_name = y.name
self._y_dtype = y.dtype
else:
self._y_name = None
self._y_dtype = None

self._reconstructor = _OutputReconstructor(X, y)
if accept_sparse is None:
accept_sparse = ["csr", "csc"]
y, binarize_y = check_target_type(y, indicate_one_vs_all=True)
Expand Down Expand Up @@ -265,8 +240,8 @@ def fit_resample(self, X, y):
y_resampled : array-like of shape (n_samples_new,)
The corresponding label of `X_resampled`.
"""
# store the columns name to reconstruct a dataframe
self._columns = X.columns if hasattr(X, "loc") else None
self._reconstructor = _OutputReconstructor(X, y)

if self.validate:
check_classification_targets(y)
X, y, binarize_y = self._check_X_y(
Expand All @@ -280,22 +255,12 @@ def fit_resample(self, X, y):
output = self._fit_resample(X, y)

if self.validate:
if self._X_columns is not None or self._y_name is not None:
import pandas as pd

if self._X_columns is not None:
X_ = pd.DataFrame(output[0], columns=self._X_columns)
X_ = X_.astype(self._X_dtypes)
else:
X_ = output[0]

y_ = (label_binarize(output[1], np.unique(y))
if binarize_y else output[1])

if self._y_name is not None:
y_ = pd.Series(y_, dtype=self._y_dtype, name=self._y_name)

X_, y_ = self._reconstructor.reconstruct(output[0], y_)
return (X_, y_) if len(output) == 2 else (X_, y_, output[2])

return output

def _fit_resample(self, X, y):
Expand Down
18 changes: 2 additions & 16 deletions imblearn/over_sampling/_random_over_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from ..utils import check_target_type
from ..utils import Substitution
from ..utils._docstring import _random_state_docstring
from ..utils._validation import _OutputReconstructor


@Substitution(
Expand Down Expand Up @@ -75,22 +76,7 @@ def __init__(self, sampling_strategy="auto", random_state=None):
self.random_state = random_state

def _check_X_y(self, X, y):
if hasattr(X, "loc"):
# store information to build dataframe
self._X_columns = X.columns
self._X_dtypes = X.dtypes
else:
self._X_columns = None
self._X_dtypes = None

if hasattr(y, "loc"):
# store information to build a series
self._y_name = y.name
self._y_dtype = y.dtype
else:
self._y_name = None
self._y_dtype = None

self._reconstructor = _OutputReconstructor(X, y)
y, binarize_y = check_target_type(y, indicate_one_vs_all=True)
X = check_array(X, accept_sparse=["csr", "csc"], dtype=None,
force_all_finite=False)
Expand Down
18 changes: 2 additions & 16 deletions imblearn/over_sampling/_smote.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from ..utils import Substitution
from ..utils._docstring import _n_jobs_docstring
from ..utils._docstring import _random_state_docstring
from ..utils._validation import _OutputReconstructor


class BaseSMOTE(BaseOverSampler):
Expand Down Expand Up @@ -891,22 +892,7 @@ def _check_X_y(self, X, y):
"""Overwrite the checking to let pass some string for categorical
features.
"""
if hasattr(X, "loc"):
# store information to build dataframe
self._X_columns = X.columns
self._X_dtypes = X.dtypes
else:
self._X_columns = None
self._X_dtypes = None

if hasattr(y, "loc"):
# store information to build a series
self._y_name = y.name
self._y_dtype = y.dtype
else:
self._y_name = None
self._y_dtype = None

self._reconstructor = _OutputReconstructor(X, y)
y, binarize_y = check_target_type(y, indicate_one_vs_all=True)
X, y = check_X_y(X, y, accept_sparse=["csr", "csc"], dtype=None)
return X, y, binarize_y
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from ...utils import check_target_type
from ...utils import Substitution
from ...utils._docstring import _random_state_docstring
from ...utils._validation import _OutputReconstructor


@Substitution(
Expand Down Expand Up @@ -81,22 +82,7 @@ def __init__(
self.replacement = replacement

def _check_X_y(self, X, y):
if hasattr(X, "loc"):
# store information to build dataframe
self._X_columns = X.columns
self._X_dtypes = X.dtypes
else:
self._X_columns = None
self._X_dtypes = None

if hasattr(y, "loc"):
# store information to build a series
self._y_name = y.name
self._y_dtype = y.dtype
else:
self._y_name = None
self._y_dtype = None

self._reconstructor = _OutputReconstructor(X, y)
y, binarize_y = check_target_type(y, indicate_one_vs_all=True)
X = check_array(X, accept_sparse=["csr", "csc"], dtype=None,
force_all_finite=False)
Expand Down
45 changes: 45 additions & 0 deletions imblearn/utils/_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,51 @@
TARGET_KIND = ("binary", "multiclass", "multilabel-indicator")


class _OutputReconstructor:
"""A class for converting input types to numpy and back."""

def __init__(self, X, y):
self.x_props = self._gets_props(X)
self.y_props = self._gets_props(y)

def reconstruct(self, X, y):
X = self._transfrom(X, self.x_props)
y = self._transfrom(y, self.y_props)
return X, y

def _gets_props(self, array):
props = {}
props["type"] = array.__class__.__name__
props["columns"] = getattr(array, "columns", None)
props["name"] = getattr(array, "name", None)
props["dtypes"] = getattr(array, "dtypes", None)
return props

def _transfrom(self, array, props):
type_ = props["type"].lower()
msg="Could not convert to {}".format(type_)
if type_ == "list":
ret = array.tolist()
elif type_ == "dataframe":
try:
import pandas as pd
ret = pd.DataFrame(array, columns=props["columns"])
ret = ret.astype(props["dtypes"])
except Exception:
warnings.warn(msg)
elif type_ == "series":
try:
import pandas as pd
ret = pd.Series(array,
dtype=props["dtypes"],
name=props["name"])
except Exception:
warnings.warn(msg)
else:
ret = array
return ret


def check_neighbors_object(nn_name, nn_object, additional_neighbor=0):
"""Check the objects is consistent to be a NN.
Expand Down
27 changes: 17 additions & 10 deletions imblearn/utils/estimator_checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,8 +242,9 @@ def check_samplers_pandas(name, Sampler):
weights=[0.2, 0.3, 0.5],
random_state=0,
)
X_pd = pd.DataFrame(X, columns=[str(i) for i in range(X.shape[1])])
y_pd = pd.Series(y, name="class")
X_df = pd.DataFrame(X, columns=[str(i) for i in range(X.shape[1])])
y_df = pd.DataFrame(y)
y_s = pd.Series(y, name="class")
sampler = Sampler()
if isinstance(Sampler(), NearMiss):
samplers = [Sampler(version=version) for version in (1, 2, 3)]
Expand All @@ -253,16 +254,22 @@ def check_samplers_pandas(name, Sampler):

for sampler in samplers:
set_random_state(sampler)
X_res_pd, y_res_pd = sampler.fit_resample(X_pd, y_pd)
X_res_df, y_res_s = sampler.fit_resample(X_df, y_s)
X_res_df, y_res_df = sampler.fit_resample(X_df, y_df)
X_res, y_res = sampler.fit_resample(X, y)

# check that we return a pandas dataframe if a dataframe was given in
assert isinstance(X_res_pd, pd.DataFrame)
assert isinstance(y_res_pd, pd.Series)
assert X_pd.columns.to_list() == X_res_pd.columns.to_list()
assert y_pd.name == y_res_pd.name
assert_allclose(X_res_pd.to_numpy(), X_res)
assert_allclose(y_res_pd.to_numpy(), y_res)
# check that we return the same type for dataframes or seires types
assert isinstance(X_res_df, pd.DataFrame)
assert isinstance(y_res_df, pd.DataFrame)
assert isinstance(y_res_s, pd.Series)

assert X_df.columns.to_list() == X_res_df.columns.to_list()
assert y_df.columns.to_list() == y_res_df.columns.to_list()
assert y_s.name == y_res_s.name

assert_allclose(X_res_df.to_numpy(), X_res)
assert_allclose(y_res_df.to_numpy().ravel(), y_res)
assert_allclose(y_res_s.to_numpy(), y_res)


def check_samplers_multiclass_ova(name, Sampler):
Expand Down

0 comments on commit 55d8f27

Please sign in to comment.