Skip to content

Commit

Permalink
fixing documentation and functions
Browse files Browse the repository at this point in the history
  • Loading branch information
Yasha Ektefaie committed Dec 24, 2024
1 parent a145195 commit 205a5bb
Showing 1 changed file with 60 additions and 5 deletions.
65 changes: 60 additions & 5 deletions spectrae/spectra.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,8 @@ def generate_spectra_split(self,
random_seed: int = 42,
test_size: float = 0.2,
degree_choosing: bool = False,
minimum: int = None):
minimum: int = None,
path_to_save: str = None):

print(f"Generating SPECTRA split for spectral parameter {spectral_parameter} and dataset {self.dataset.name}")
result = run_independent_set(spectral_parameter, self.SPG,
Expand All @@ -143,13 +144,67 @@ def generate_spectra_split(self,
print(f"Number of samples in independent set: {len(result)}")
train, test = self.spectra_train_test_split(result, test_size=test_size, random_state=random_seed)
stats = self.get_stats(train, test, spectral_parameter)
return train, test, stats
if path_to_save is None:
return train, test, stats
else:
i = 0
if not os.path.exists(f"{path_to_save}/SP_{spectral_parameter}_{i}"):
os.makedirs(f"{path_to_save}/SP_{spectral_parameter}_{i}")

pickle.dump(train, open(f"{path_to_save}/SP_{spectral_parameter}_{i}/train.pkl", "wb"))
pickle.dump(test, open(f"{path_to_save}/SP_{spectral_parameter}_{i}/test.pkl", "wb"))
pickle.dump(stats, open(f"{path_to_save}/SP_{spectral_parameter}_{i}/stats.pkl", "wb"))

def get_stats(self, train, test, spectral_parameter, chunksize = 10000000, show_progress = False):
def get_stats(self, train: List,
test: List,
spectral_parameter: float,
chunksize: int = 10000000,
show_progress: bool = False,
sample_values: bool = False):

"""
Computes statistics for the given train and test splits.
Args:
train (List): A list of training sample IDs or sample indices. (see sample_values)
test (List): A list of test sample IDs or sample indices. (see sample_values)
spectral_parameter (float): The spectral parameter used for computation.
chunksize (int, optional): The size of chunks to process at a time. Default is 10,000,000. Decrease if you get a OOM error.
show_progress (bool, optional): Whether to show progress during computation. Default is False.
sample_values (bool, optional): True if you are passing sample IDs, False if you are passing sample indices. Default is False.
Returns:
Dict[str, Any]: A dictionary containing the computed statistics. The keys and values depend on whether the data is binary or not.
If not binary:
- 'SPECTRA_parameter' (float): The spectral parameter used.
- 'train_size' (int): The size of the training set.
- 'test_size' (int): The size of the testing set.
- 'cross_split_overlap' (float): The cross-split overlap value.
- 'std_css' (float): The standard deviation of the cross-split similarity.
- 'max_css' (float): The maximum cross-split similarity.
- 'min_css' (float): The minimum cross-split similarity.
If binary:
- 'SPECTRA_parameter' (float): The spectral parameter used.
- 'train_size' (int): The size of the training set.
- 'test_size' (int): The size of the testing set.
- 'cross_split_overlap' (float): The cross-split overlap value.
- 'num_similar' (int): The number of similar items.
- 'num_total' (int): The total number of items.
Raises:
ValueError: If the train or test lists are empty.
"""

train_size = len(train)
test_size = len(test)

if sample_values:
train = self.get_sample_indices(train)
test = self.get_sample_indices(test)

if not self.binary:
cross_split_overlap, std_css, max_css, min_css = self.cross_split_overlap(self.get_sample_indices(train), self.get_sample_indices(test), chunksize, show_progress)
cross_split_overlap, std_css, max_css, min_css = self.cross_split_overlap(train, test, chunksize, show_progress)
stats = {'SPECTRA_parameter': spectral_parameter,
'train_size': train_size,
'test_size': test_size,
Expand All @@ -158,7 +213,7 @@ def get_stats(self, train, test, spectral_parameter, chunksize = 10000000, show_
'max_css': max_css,
'min_css': min_css}
else:
cross_split_overlap, num_similar, num_total = self.cross_split_overlap(self.get_sample_indices(train), self.get_sample_indices(test), chunksize, show_progress)
cross_split_overlap, num_similar, num_total = self.cross_split_overlap(train, test, chunksize, show_progress)
stats = {'SPECTRA_parameter': spectral_parameter,
'train_size': train_size,
'test_size': test_size,
Expand Down

0 comments on commit 205a5bb

Please sign in to comment.