From 3557e88d2ac0c32670b42ba0b5ab6fd7faea8ae9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fabian-Robert=20St=C3=B6ter?= Date: Thu, 11 Mar 2021 15:22:10 +0100 Subject: [PATCH] update to torch 1.8.0 (#79) * remove bg_iterator from data test * test 1.8.0 * remove in-place for 1.8 jit * remove soundfile backend deprecation * update setup.py requirements * add 1.7.0 tests * add tqdm * update outdir behaviour * report device * add another cli test * relax requirement * update version * remove 1.7.0 support * also from setup.py --- .github/workflows/test_unittests.yml | 4 +++- openunmix/cli.py | 5 ++--- openunmix/data.py | 5 +---- openunmix/model.py | 4 ++-- openunmix/transforms.py | 7 ++++--- openunmix/utils.py | 2 +- scripts/train.py | 1 - setup.py | 9 +++++---- tests/cli_test.sh | 1 + tests/test_datasets.py | 3 --- tests/test_io.py | 2 -- tests/test_regression.py | 2 +- 12 files changed, 20 insertions(+), 25 deletions(-) diff --git a/.github/workflows/test_unittests.yml b/.github/workflows/test_unittests.yml index cba7e985..996a2ed5 100644 --- a/.github/workflows/test_unittests.yml +++ b/.github/workflows/test_unittests.yml @@ -11,7 +11,7 @@ jobs: strategy: matrix: python-version: [3.6, 3.7, 3.8] - pytorch-version: ["1.7.0"] + pytorch-version: ["1.8.0"] # Timeout: https://stackoverflow.com/a/59076067/4521646 timeout-minutes: 10 @@ -34,6 +34,8 @@ jobs: python -m pip install coverage codecov --upgrade-strategy only-if-needed --quiet if [ $TORCH_INSTALL == "1.7.0" ]; then INSTALL="torch==1.7.0+cpu torchaudio==0.7.0 -f https://download.pytorch.org/whl/torch_stable.html" + elif [ $TORCH_INSTALL == "1.8.0" ]; then + INSTALL="torch==1.8.0+cpu torchaudio==0.8.0 -f https://download.pytorch.org/whl/torch_stable.html" else INSTALL="--pre torch torchaudio -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html" fi diff --git a/openunmix/cli.py b/openunmix/cli.py index 9b655340..08d10b5c 100644 --- a/openunmix/cli.py +++ b/openunmix/cli.py @@ -113,14 +113,13 @@ def separate(): "for deployment.", ) args = parser.parse_args() - torchaudio.USE_SOUNDFILE_LEGACY_INTERFACE = False if args.audio_backend != "stempeg": torchaudio.set_audio_backend(args.audio_backend) use_cuda = not args.no_cuda and torch.cuda.is_available() device = torch.device("cuda" if use_cuda else "cpu") - + print("Using ", device) # parsing the output dict aggregate_dict = None if args.aggregate is None else json.loads(args.aggregate) @@ -173,7 +172,7 @@ def separate(): else: outdir = Path(Path(input_file).stem + "_" + model_path.stem) else: - outdir = Path(args.outdir) + outdir = Path(args.outdir) / Path(input_file).stem outdir.mkdir(exist_ok=True, parents=True) # write out estimates diff --git a/openunmix/data.py b/openunmix/data.py index 89dde3d6..730b58b6 100644 --- a/openunmix/data.py +++ b/openunmix/data.py @@ -7,7 +7,6 @@ import torch.utils.data import torchaudio import tqdm -from torchaudio.datasets.utils import bg_iterator def load_info(path: str) -> dict: @@ -941,7 +940,6 @@ def __len__(self): args, _ = parser.parse_known_args() - torchaudio.USE_SOUNDFILE_LEGACY_INTERFACE = False torchaudio.set_audio_backend(args.audio_backend) train_dataset, valid_dataset, args = load_datasets(parser, args) @@ -970,6 +968,5 @@ def __len__(self): num_workers=4, ) - train_sampler = bg_iterator(train_sampler, 4) for x, y in tqdm.tqdm(train_sampler): - pass + print(x.shape) diff --git a/openunmix/model.py b/openunmix/model.py index b75d4e23..0aca6bff 100644 --- a/openunmix/model.py +++ b/openunmix/model.py @@ -124,8 +124,8 @@ def forward(self, x: Tensor) -> Tensor: # crop x = x[..., : self.nb_bins] # shift and scale input to mean=0 std=1 (across all bins) - x += self.input_mean - x *= self.input_scale + x = x + self.input_mean + x = x * self.input_scale # to (nb_frames*nb_samples, nb_channels*nb_bins) # and encode to (nb_frames*nb_samples, hidden_size) diff --git a/openunmix/transforms.py b/openunmix/transforms.py index 821c6354..369064a0 100644 --- a/openunmix/transforms.py +++ b/openunmix/transforms.py @@ -96,7 +96,7 @@ def forward(self, x: Tensor) -> Tensor: # pack batch x = x.view(-1, shape[-1]) - stft_f = torch.stft( + complex_stft = torch.stft( x, n_fft=self.n_fft, hop_length=self.n_hop, @@ -105,8 +105,9 @@ def forward(self, x: Tensor) -> Tensor: normalized=False, onesided=True, pad_mode="reflect", + return_complex=True, ) - + stft_f = torch.view_as_real(complex_stft) # unpack batch stft_f = stft_f.view(shape[:-1] + stft_f.shape[-3:]) return stft_f @@ -158,7 +159,7 @@ def forward(self, X: Tensor, length: Optional[int] = None) -> Tensor: X = X.reshape(-1, shape[-3], shape[-2], shape[-1]) y = torch.istft( - X, + torch.view_as_complex(X), n_fft=self.n_fft, hop_length=self.n_hop, window=self.window, diff --git a/openunmix/utils.py b/openunmix/utils.py index 08bb442f..bece46f8 100644 --- a/openunmix/utils.py +++ b/openunmix/utils.py @@ -295,7 +295,7 @@ def preprocess( audio = torch.repeat_interleave(audio, 2, dim=1) if rate != model_rate: - print("resampling") + warnings.warn("resample to model sample rate") # we have to resample to model samplerate if needed # this makes sure we resample input only once resampler = torchaudio.transforms.Resample( diff --git a/scripts/train.py b/scripts/train.py index 9fae3935..a6b4dc56 100644 --- a/scripts/train.py +++ b/scripts/train.py @@ -202,7 +202,6 @@ def main(): args, _ = parser.parse_known_args() - torchaudio.USE_SOUNDFILE_LEGACY_INTERFACE = False torchaudio.set_audio_backend(args.audio_backend) use_cuda = not args.no_cuda and torch.cuda.is_available() print("Using GPU:", use_cuda) diff --git a/setup.py b/setup.py index 8cd17bb5..ada7575e 100644 --- a/setup.py +++ b/setup.py @@ -1,6 +1,6 @@ from setuptools import setup, find_packages -umx_version = "1.1.0" +umx_version = "1.1.1" with open("README.md", encoding="utf-8") as fh: long_description = fh.read() @@ -18,8 +18,8 @@ python_requires=">=3.6", install_requires=[ "numpy", - "torchaudio>=0.7.0", - "torch>=1.7.0", + "torchaudio>=0.8.0", + "torch>=1.8.0", ], extras_require={ "asteroid": ["asteroid-filterbanks>=0.3.2"], @@ -27,8 +27,9 @@ "pytest", "musdb>=0.4.0", "museval>=0.4.0", - "onnx", "asteroid-filterbanks>=0.3.2", + "onnx", + "tqdm", ], "stempeg": ["stempeg"], "evaluation": ["musdb>=0.4.0", "museval>=0.4.0"], diff --git a/tests/cli_test.sh b/tests/cli_test.sh index 8d2b8c6d..780fcdf5 100644 --- a/tests/cli_test.sh +++ b/tests/cli_test.sh @@ -2,3 +2,4 @@ python -m pip install -e .['stempeg'] --quiet # run umx on url coverage run -a `which umx` https://samples.ffmpeg.org/A-codecs/wavpcm/test-96.wav --audio-backend stempeg +coverage run -a `which umx` https://samples.ffmpeg.org/A-codecs/wavpcm/test-96.wav --audio-backend stempeg --outdir out --niter 0 diff --git a/tests/test_datasets.py b/tests/test_datasets.py index 2dac4d11..f4fbce01 100644 --- a/tests/test_datasets.py +++ b/tests/test_datasets.py @@ -17,7 +17,6 @@ def test_musdb(): def test_trackfolder_fix(torch_backend): - torchaudio.USE_SOUNDFILE_LEGACY_INTERFACE = False torchaudio.set_audio_backend(torch_backend) train_dataset = data.FixedSourcesTrackFolderDataset( @@ -33,7 +32,6 @@ def test_trackfolder_fix(torch_backend): def test_trackfolder_var(torch_backend): - torchaudio.USE_SOUNDFILE_LEGACY_INTERFACE = False torchaudio.set_audio_backend(torch_backend) train_dataset = data.VariableSourcesTrackFolderDataset( @@ -48,7 +46,6 @@ def test_trackfolder_var(torch_backend): def test_sourcefolder(torch_backend): - torchaudio.USE_SOUNDFILE_LEGACY_INTERFACE = False torchaudio.set_audio_backend(torch_backend) train_dataset = data.SourceFolderDataset( diff --git a/tests/test_io.py b/tests/test_io.py index bfcbcd2b..ffaf80b3 100644 --- a/tests/test_io.py +++ b/tests/test_io.py @@ -24,7 +24,6 @@ def dur(request): @pytest.fixture(params=[True, False]) def info(request, torch_backend): - torchaudio.USE_SOUNDFILE_LEGACY_INTERFACE = False torchaudio.set_audio_backend(torch_backend) if request.param: @@ -34,7 +33,6 @@ def info(request, torch_backend): def test_loadwav(dur, info, torch_backend): - torchaudio.USE_SOUNDFILE_LEGACY_INTERFACE = False torchaudio.set_audio_backend(torch_backend) audio, _ = data.load_audio(audio_path, dur=dur, info=info) rate = 8000.0 diff --git a/tests/test_regression.py b/tests/test_regression.py index c0c80528..78d25cb2 100644 --- a/tests/test_regression.py +++ b/tests/test_regression.py @@ -91,4 +91,4 @@ def test_spectrogram(mus, method): ref = torch.load(spec_path) dut = encoder(audio).permute(3, 0, 1, 2) - assert torch.allclose(ref, dut, atol=1e-4, rtol=1e-5) + assert torch.allclose(ref, dut, atol=1e-4, rtol=1e-3)