-
Notifications
You must be signed in to change notification settings - Fork 17
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Feat set seed #929
Feat set seed #929
Changes from 7 commits
fc43b5c
a7c8f07
d88524f
261851f
87cc881
75cf950
22f48d9
1fc6d3a
ab858ca
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3,6 +3,9 @@ | |
import pandas as pd | ||
import pytz | ||
import datetime | ||
import torch | ||
import random | ||
from opensoundscape.ml import cnn, cnn_architectures | ||
|
||
from opensoundscape import utils | ||
|
||
|
@@ -168,3 +171,66 @@ def test_make_clip_df_from_label_df(silence_10s_mp3_str, metadata_wav_str): | |
# should copy labels for each file to all clips of that file | ||
# duplicate file should have labels from _first_ occurrence in label_df | ||
assert np.array_equal(clip_df["a"].values, [0, 0, 0, 0, 2, 2]) | ||
|
||
|
||
|
||
# The @pytest.mark.parametrize decorator loops trough each value in list when running pytest. | ||
# If you add --verbose, it also prints if it passed for each value in the list for each function | ||
# that takes it as input. | ||
|
||
@pytest.mark.parametrize("input", [1, 11, 13, 42, 59, 666, 1234]) | ||
def test_torch_rand(input): | ||
utils.set_seed(input) | ||
tr1 = torch.rand(100) | ||
|
||
utils.set_seed(input) | ||
tr2 = torch.rand(100) | ||
|
||
utils.set_seed(input + 1) | ||
tr3 = torch.rand(100) | ||
|
||
assert all(tr1 == tr2) & any(tr1 != tr3) | ||
|
||
@pytest.mark.parametrize("input", [1, 11, 13, 42, 59, 666, 1234]) | ||
def test_numpy_random_rand(input): | ||
utils.set_seed(input) | ||
nr1 = np.random.rand(100) | ||
|
||
utils.set_seed(input) | ||
nr2 = np.random.rand(100) | ||
|
||
utils.set_seed(input + 1) | ||
nr3 = np.random.rand(100) | ||
|
||
assert all(nr1 == nr2) & any(nr1 != nr3) | ||
|
||
@pytest.mark.parametrize("input", [1, 11, 13, 42, 59, 666, 1234]) | ||
def test_radom_sample(input): | ||
list1000 = list(range(1, 1000)) | ||
|
||
utils.set_seed(input) | ||
rs1 = random.sample(list1000, 100) | ||
|
||
utils.set_seed(input) | ||
rs2 = random.sample(list1000, 100) | ||
|
||
utils.set_seed(input + 1) | ||
rs3 = random.sample(list1000, 100) | ||
|
||
assert (rs1 == rs2) & (rs1 != rs3) | ||
|
||
@pytest.mark.parametrize("input", [1, 11, 13, 42, 59, 666, 1234]) | ||
def test_cnn(input): | ||
utils.set_seed(input) | ||
model_resnet1 = cnn_architectures.resnet18(num_classes=10, weights=None) | ||
lw1 = model_resnet1.layer1[0].conv1.weight | ||
|
||
utils.set_seed(input) | ||
model_resnet2 = cnn_architectures.resnet18(num_classes=10, weights=None) | ||
lw2 = model_resnet2.layer1[0].conv1.weight | ||
|
||
utils.set_seed(input + 1) | ||
model_resnet3 = cnn_architectures.resnet18(num_classes=10, weights=None) | ||
lw3 = model_resnet3.layer1[0].conv1.weight | ||
|
||
assert torch.all(lw1 == lw2) & torch.any(lw1 != lw3) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. don't understand the logic here, can you add a comment? is the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. In summary, also assert that at least one element is different if different seeds. Added as a comment in tests/test_utils.py. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
imports should always be at the top of a file. VS Code will help detect duplicate imports
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Oh, sorry, missed that. Just pushed the correction.