def set_seed(seed, verbose=True):
    """Set random state across different libraries for reproducibility.

    Seeds Python's ``random`` module, NumPy's global generator and PyTorch
    (CPU generator and all CUDA devices), and configures cuDNN for
    deterministic behaviour.

    Args:
        seed (int): Number to fix random number generators to a specific start.
        verbose (bool, optional): Print set seed. Defaults to True.
    """
    # Seed every generator this package relies on.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    # Ask cuDNN for deterministic kernels, and turn off benchmark mode:
    # with benchmarking enabled cuDNN may auto-tune to a different algorithm
    # on each run, which would undo the determinism requested above.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # Report only after all seeding calls have succeeded, so the message is
    # never printed for a seed that was rejected by one of the libraries.
    if verbose:
        print(f"Random state set with seed {seed}")
# Seeds exercised by every utils.set_seed() test below.
_SEEDS = [1, 11, 13, 42, 59, 666, 1234]


@pytest.mark.parametrize("seed", _SEEDS)
def test_torch_rand(seed):
    """torch.rand repeats under the same seed and differs under another."""
    utils.set_seed(seed)
    first = torch.rand(100)

    utils.set_seed(seed)
    repeat = torch.rand(100)

    utils.set_seed(seed + 1)
    other = torch.rand(100)

    # Separate asserts so a failure says which property broke.
    assert torch.equal(first, repeat)
    assert not torch.equal(first, other)


@pytest.mark.parametrize("seed", _SEEDS)
def test_numpy_random_rand(seed):
    """np.random.rand repeats under the same seed and differs under another."""
    utils.set_seed(seed)
    first = np.random.rand(100)

    utils.set_seed(seed)
    repeat = np.random.rand(100)

    utils.set_seed(seed + 1)
    other = np.random.rand(100)

    assert np.array_equal(first, repeat)
    assert not np.array_equal(first, other)


@pytest.mark.parametrize("seed", _SEEDS)
def test_radom_sample(seed):
    """random.sample repeats under the same seed and differs under another."""
    population = list(range(1, 1000))

    utils.set_seed(seed)
    first = random.sample(population, 100)

    utils.set_seed(seed)
    repeat = random.sample(population, 100)

    utils.set_seed(seed + 1)
    other = random.sample(population, 100)

    assert first == repeat
    assert first != other


@pytest.mark.parametrize("seed", _SEEDS)
def test_cnn(seed):
    """Freshly initialised resnet18 weights are reproducible per seed."""

    def first_conv_weight():
        # Build an untrained model and return one layer's initial weights.
        model = cnn_architectures.resnet18(num_classes=10, weights=None)
        return model.layer1[0].conv1.weight

    utils.set_seed(seed)
    first = first_conv_weight()

    utils.set_seed(seed)
    repeat = first_conv_weight()

    utils.set_seed(seed + 1)
    other = first_conv_weight()

    assert torch.equal(first, repeat)
    assert not torch.equal(first, other)