diff --git a/sslearn/model_selection/_split.py b/sslearn/model_selection/_split.py index 0e5cd26..731caa2 100644 --- a/sslearn/model_selection/_split.py +++ b/sslearn/model_selection/_split.py @@ -138,6 +138,8 @@ def artificial_ssl_dataset(X, y, label_rate=0.1, random_state=None, force_minimu if force_minimum is not None: label = np.concatenate((selected, label)) + y_unlabel_original = y[unlabel] + # Create the label and unlabel sets X_label, y_label, X_unlabel, y_unlabel = X[label], y[label],\ X[unlabel], np.array([-1] * len(unlabel)) @@ -151,9 +153,9 @@ def artificial_ssl_dataset(X, y, label_rate=0.1, random_state=None, force_minimu y = pd.Series(y) if indexes: - return X, y, X_unlabel, y_unlabel, label, unlabel + return X, y, X_unlabel, y_unlabel_original, label, unlabel - return X, y, X_unlabel, y_unlabel + return X, y, X_unlabel, y_unlabel_original """