Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat set seed #929

Merged
merged 9 commits into from
Apr 15, 2024
25 changes: 24 additions & 1 deletion opensoundscape/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@
import soundfile
import librosa
from matplotlib.colors import LinearSegmentedColormap

import torch
import random

class GetDurationError(ValueError):
    """Raised when librosa.get_duration(path=f) fails for a file."""
Expand Down Expand Up @@ -329,3 +330,25 @@ def generate_opacity_colormaps(
colormaps.append(cmap)

return colormaps


import numpy as np
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Imports should always be at the top of a file. VS Code will help detect duplicate imports.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, sorry, missed that. Just pushed the correction.

import torch
import random


def set_seed(seed, verbose=True):
    """Fix the random state of Python, NumPy and PyTorch for reproducibility.

    Seeds NumPy's global generator, Python's `random` module and PyTorch
    (via `torch.manual_seed` and `torch.cuda.manual_seed_all`), and sets
    `torch.backends.cudnn.deterministic` to True.

    Args:
        seed (int): value handed to each library's seeding function.
        verbose (bool, optional): if True, print the seed that was set.
            Defaults to True.
    """
    if verbose:
        print(f"Random state set with seed {seed}")

    # set the cuDNN determinism flag before any generator is seeded
    torch.backends.cudnn.deterministic = True

    # apply the same seed to every library, one seeding function at a time
    seeders = (
        np.random.seed,
        random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)
66 changes: 66 additions & 0 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@
import pandas as pd
import pytz
import datetime
import torch
import random
from opensoundscape.ml import cnn, cnn_architectures

from opensoundscape import utils

Expand Down Expand Up @@ -168,3 +171,66 @@ def test_make_clip_df_from_label_df(silence_10s_mp3_str, metadata_wav_str):
# should copy labels for each file to all clips of that file
# duplicate file should have labels from _first_ occurrence in label_df
assert np.array_equal(clip_df["a"].values, [0, 0, 0, 0, 2, 2])



# The @pytest.mark.parametrize decorator loops through each value in the list when running pytest.
# If you add --verbose, it also prints whether the test passed for each value in the list,
# for each function that takes it as input.

@pytest.mark.parametrize("input", [1, 11, 13, 42, 59, 666, 1234])
def test_torch_rand(input):
    """torch.rand is repeatable after set_seed and changes with the seed."""
    utils.set_seed(input)
    tr1 = torch.rand(100)

    utils.set_seed(input)
    tr2 = torch.rand(100)

    utils.set_seed(input + 1)
    tr3 = torch.rand(100)

    # same seed: the two draws must match element for element
    assert torch.equal(tr1, tr2)
    # different seed: at least one element must differ (not necessarily all)
    assert not torch.equal(tr1, tr3)

@pytest.mark.parametrize("input", [1, 11, 13, 42, 59, 666, 1234])
def test_numpy_random_rand(input):
    """np.random.rand is repeatable after set_seed and changes with the seed."""
    utils.set_seed(input)
    nr1 = np.random.rand(100)

    utils.set_seed(input)
    nr2 = np.random.rand(100)

    utils.set_seed(input + 1)
    nr3 = np.random.rand(100)

    # same seed: the two draws must match element for element
    assert np.array_equal(nr1, nr2)
    # different seed: at least one element must differ (not necessarily all)
    assert not np.array_equal(nr1, nr3)

@pytest.mark.parametrize("input", [1, 11, 13, 42, 59, 666, 1234])
def test_radom_sample(input):
    """random.sample is repeatable after set_seed and changes with the seed."""
    list1000 = list(range(1, 1000))

    utils.set_seed(input)
    rs1 = random.sample(list1000, 100)

    utils.set_seed(input)
    rs2 = random.sample(list1000, 100)

    utils.set_seed(input + 1)
    rs3 = random.sample(list1000, 100)

    # same seed: identical sample, in the same order
    assert rs1 == rs2
    # different seed: the samples must not be identical
    assert rs1 != rs3

@pytest.mark.parametrize("input", [1, 11, 13, 42, 59, 666, 1234])
def test_cnn(input):
    """Initial resnet18 weights are repeatable after set_seed and change with the seed."""
    utils.set_seed(input)
    model_resnet1 = cnn_architectures.resnet18(num_classes=10, weights=None)
    lw1 = model_resnet1.layer1[0].conv1.weight

    utils.set_seed(input)
    model_resnet2 = cnn_architectures.resnet18(num_classes=10, weights=None)
    lw2 = model_resnet2.layer1[0].conv1.weight

    utils.set_seed(input + 1)
    model_resnet3 = cnn_architectures.resnet18(num_classes=10, weights=None)
    lw3 = model_resnet3.layer1[0].conv1.weight

    # same seed: the first conv layer's weights must match element for element
    assert torch.equal(lw1, lw2)
    # different seed: the weight tensors must differ in at least one element;
    # this does not require every weight to be different
    assert not torch.equal(lw1, lw3)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't understand the logic here; can you add a comment? Is the `any` asserting that none of the weights should be the same? That doesn't seem like the correct logic; rather, we want to assert that the entire array isn't equal to the other.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In summary, it also asserts that at least one element is different when the seeds are different. Added as a comment in tests/test_utils.py.

Loading