Skip to content

Commit

Permalink
Artificial SSL Dataset now works with indexes #13
Browse files Browse the repository at this point in the history
  • Loading branch information
jlgarridol committed Feb 6, 2024
1 parent 2df3c8a commit 878ac4e
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 1 deletion.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Added
- Add a parameter to `artificial_ssl_dataset` to force a minimum number of instances of each class. Issue #11
- Add a parameter to `artificial_ssl_dataset` to return indexes. Issue #13

### Changed
- `artificial_ssl_dataset` now generates the dataset through an index-based process. Issue #13

### Fixed
- DeTriTraining is now vectorized and faster than before.
Expand Down
45 changes: 44 additions & 1 deletion sslearn/model_selection/_split.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def split(self, X, y):
yield X_, y_, label, unlabel


def artificial_ssl_dataset(X, y, label_rate=0.1, random_state=None, force_minimum=None, **kwards):
def artificial_ssl_dataset(X, y, label_rate=0.1, random_state=None, force_minimum=None, indexes=False, **kwards):
"""Create an artificial Semi-supervised dataset from a supervised dataset.
Parameters
Expand All @@ -63,6 +63,8 @@ def artificial_ssl_dataset(X, y, label_rate=0.1, random_state=None, force_minimu
Controls the shuffling applied to the data before applying the split. Pass an int for reproducible output across multiple function calls, by default None
force_minimum: int, optional
Force a minimum of instances of each class, by default None
indexes: bool, optional
If True, return the indexes of the labeled and unlabeled instances, by default False
shuffle: bool, default=True
Whether or not to shuffle the data before splitting. If shuffle=False then stratify must be None.
stratify: array-like, default=None
Expand All @@ -78,12 +80,49 @@ def artificial_ssl_dataset(X, y, label_rate=0.1, random_state=None, force_minimu
The feature set for each y mark as unlabel
y_unlabel: ndarray
The true label for each y in the same order.
label: ndarray (optional)
The training set indexes for split mark as labeled.
unlabel: ndarray (optional)
The training set indexes for split mark as unlabeled.
"""
assert (label_rate > 0) and (label_rate < 1),\
"Label rate must be in (0, 1)."
assert "test_size" not in kwards and "train_size" not in kwards,\
"Test size and train size are illegal parameters in this method."

indices = np.arange(len(y))

if force_minimum is not None:
try:
selected = __random_select_n_instances(y, force_minimum, random_state)
except ValueError:
raise ValueError("The number of instances of each class is less than force_minimum.")

# Remove selected instances from indices
indices = np.delete(indices, selected, axis=0)

# Train test split with indexes
label, unlabel = ms.train_test_split(indices, train_size=label_rate,
random_state=random_state, **kwards)

if force_minimum is not None:
label = np.concatenate((selected, label))

# Create the label and unlabel sets
X_label, y_label, X_unlabel, y_unlabel = X[label], y[label],\
X[unlabel], np.array([-1] * len(unlabel))

# Create the artificial dataset
X = np.concatenate((X_label, X_unlabel), axis=0)
y = np.concatenate((y_label, y_unlabel), axis=0)

if indexes:
return X, y, X_unlabel, y_unlabel, label, unlabel

return X, y, X_unlabel, y_unlabel


"""
if force_minimum is not None:
try:
selected = __random_select_n_instances(y, force_minimum, random_state)
Expand All @@ -106,8 +145,12 @@ def artificial_ssl_dataset(X, y, label_rate=0.1, random_state=None, force_minimu
if force_minimum is not None:
X = np.concatenate((X, X_selected), axis=0)
y = np.concatenate((y, y_selected), axis=0)
if indexes:
return X, y, X_unlabel, true_label, X_label, X_unlabel
return X, y, X_unlabel, true_label
"""

def __random_select_n_instances(y, n, random_state):

Expand Down
13 changes: 13 additions & 0 deletions test/test_model_selection.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,19 @@ def test_artificial_ssl_dataset_with_force_minimum():

pytest.raises(ValueError, artificial_ssl_dataset, X, y, label_rate=0.02, force_minimum=2)

def test_artificial_ssl_dataset_with_indexes():
    """Check the `indexes` flag of `artificial_ssl_dataset`.

    With ``indexes=True`` the function returns six values, the last two
    being the labeled and unlabeled index arrays. With ``indexes=False``
    it returns only the four dataset arrays.
    """
    X, y = load_iris(return_X_y=True)

    # Keep the original X/y intact so the second call below runs on the
    # same supervised dataset rather than on the already-masked output.
    X_ssl, y_ssl, X_unlabel, true_label, label, unlabel = artificial_ssl_dataset(
        X, y, label_rate=0.1, indexes=True
    )

    # One index per unlabeled row.
    assert X_unlabel.shape[0] == unlabel.shape[0]
    # The returned dataset is the labeled rows followed by the unlabeled
    # rows, so both index arrays together account for every returned row.
    assert label.shape[0] + unlabel.shape[0] == X_ssl.shape[0]
    assert y_ssl.shape[0] == X_ssl.shape[0]

    # Without indexes only the four dataset arrays are returned. Asserting
    # the length directly fails loudly if extra values are ever returned,
    # instead of relying on an unpacking error that may never be raised.
    result = artificial_ssl_dataset(X, y, label_rate=0.1, indexes=False)
    assert len(result) == 4, \
        "indexes=False must return only X, y, X_unlabel and y_unlabel."

def test_StratifiedKFoldSS():
X, y = load_iris(return_X_y=True)
splits = 5
Expand Down

0 comments on commit 878ac4e

Please sign in to comment.