Skip to content

Commit

Permalink
Merge pull request #929 from kitzeslab/feat_set_seed
Browse files Browse the repository at this point in the history
Feat set seed
  • Loading branch information
sammlapp authored Apr 15, 2024
2 parents f1dfaed + ab858ca commit 5f6ee99
Show file tree
Hide file tree
Showing 2 changed files with 91 additions and 0 deletions.
19 changes: 19 additions & 0 deletions opensoundscape/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
import soundfile
import librosa
from matplotlib.colors import LinearSegmentedColormap
import torch
import random


class GetDurationError(ValueError):
Expand Down Expand Up @@ -329,3 +331,20 @@ def generate_opacity_colormaps(
colormaps.append(cmap)

return colormaps


def set_seed(seed, verbose=True):
    """Set random state across different libraries for reproducibility

    Seeds Python's `random` module, NumPy's global random state, and PyTorch's
    CPU and CUDA generators with the same value, and configures cuDNN to
    behave deterministically. All of these are process-wide settings.

    Args:
        seed (int): Number to fix random number generators to a specific start.
        verbose (bool, optional): Print set seed. Defaults to True.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seeds every CUDA device; per the PyTorch docs this call is silently
    # ignored when CUDA is not available.
    torch.cuda.manual_seed_all(seed)

    torch.backends.cudnn.deterministic = True
    # cuDNN benchmark mode auto-tunes the convolution algorithm and may choose
    # a different one from run to run; it must be off for reproducible results.
    torch.backends.cudnn.benchmark = False

    # Report only once the seeds have actually been applied.
    if verbose:
        print(f"Random state set with seed {seed}")
72 changes: 72 additions & 0 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@
import pandas as pd
import pytz
import datetime
import torch
import random
from opensoundscape.ml import cnn, cnn_architectures

from opensoundscape import utils

Expand Down Expand Up @@ -168,3 +171,72 @@ def test_make_clip_df_from_label_df(silence_10s_mp3_str, metadata_wav_str):
# should copy labels for each file to all clips of that file
# duplicate file should have labels from _first_ occurrence in label_df
assert np.array_equal(clip_df["a"].values, [0, 0, 0, 0, 2, 2])


# The @pytest.mark.parametrize decorator loops through each value in the list when running pytest.
# If you add --verbose, pytest also reports, for each function that takes the list as input,
# whether it passed for each value in the list.

# For all utils.set_seed() tests, assert that results are deterministic for the same seed AND
# that, for different seeds, at least one element of the tensor/array is different.


@pytest.mark.parametrize("input", [1, 11, 13, 42, 59, 666, 1234])
def test_torch_rand(input):
    """torch.rand repeats under the same seed and changes under a different seed."""
    utils.set_seed(input)
    tr1 = torch.rand(100)

    utils.set_seed(input)
    tr2 = torch.rand(100)

    utils.set_seed(input + 1)
    tr3 = torch.rand(100)

    # Separate asserts so a failure shows which property broke:
    # same seed -> identical tensor; different seed -> at least one element differs.
    assert torch.equal(tr1, tr2)
    assert not torch.equal(tr1, tr3)


@pytest.mark.parametrize("input", [1, 11, 13, 42, 59, 666, 1234])
def test_numpy_random_rand(input):
    """np.random.rand repeats under the same seed and changes under a different seed."""
    utils.set_seed(input)
    nr1 = np.random.rand(100)

    utils.set_seed(input)
    nr2 = np.random.rand(100)

    utils.set_seed(input + 1)
    nr3 = np.random.rand(100)

    # Separate asserts so a failure shows which property broke:
    # same seed -> identical array; different seed -> at least one element differs.
    assert np.array_equal(nr1, nr2)
    assert not np.array_equal(nr1, nr3)


@pytest.mark.parametrize("input", [1, 11, 13, 42, 59, 666, 1234])
def test_radom_sample(input):
    """random.sample repeats under the same seed and changes under a different seed."""
    list1000 = list(range(1, 1000))

    utils.set_seed(input)
    rs1 = random.sample(list1000, 100)

    utils.set_seed(input)
    rs2 = random.sample(list1000, 100)

    utils.set_seed(input + 1)
    rs3 = random.sample(list1000, 100)

    # Plain list comparisons in separate asserts (instead of bitwise `&` on
    # two bools) so a failure shows which property broke.
    assert rs1 == rs2
    assert rs1 != rs3


@pytest.mark.parametrize("input", [1, 11, 13, 42, 59, 666, 1234])
def test_cnn(input):
    """Initial resnet18 weights repeat under the same seed and change under a different seed.

    Compares the weight of layer1[0].conv1 across three models built with
    weights=None: two after seeding with `input`, one after seeding with
    `input + 1`.
    """

    def first_conv_weight(seed_value):
        # Seed, build a fresh model, and hand back the weight being compared.
        utils.set_seed(seed_value)
        model = cnn_architectures.resnet18(num_classes=10, weights=None)
        return model.layer1[0].conv1.weight

    baseline = first_conv_weight(input)
    repeat = first_conv_weight(input)
    shifted = first_conv_weight(input + 1)

    matches_repeat = torch.all(baseline == repeat)
    differs_from_shifted = torch.any(baseline != shifted)
    assert matches_repeat & differs_from_shifted

0 comments on commit 5f6ee99

Please sign in to comment.