Skip to content

Commit

Permalink
update to torch 1.8.0 (#79)
Browse files: browse the repository at this point in the history
* remove bg_iterator from data test

* test 1.8.0

* remove in-place operations for torch 1.8 JIT compatibility

* remove soundfile backend deprecation

* update setup.py requirements

* add 1.7.0 tests

* add tqdm

* update outdir behaviour

* report device

* add another cli test

* relax requirement

* update version

* remove 1.7.0 support

* also remove 1.7.0 support from setup.py
  • Loading branch information
faroit authored Mar 11, 2021
1 parent ab132ed commit 3557e88
Show file tree
Hide file tree
Showing 12 changed files with 20 additions and 25 deletions.
4 changes: 3 additions & 1 deletion .github/workflows/test_unittests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ jobs:
strategy:
matrix:
python-version: [3.6, 3.7, 3.8]
pytorch-version: ["1.7.0"]
pytorch-version: ["1.8.0"]

# Timeout: https://stackoverflow.com/a/59076067/4521646
timeout-minutes: 10
Expand All @@ -34,6 +34,8 @@ jobs:
python -m pip install coverage codecov --upgrade-strategy only-if-needed --quiet
if [ $TORCH_INSTALL == "1.7.0" ]; then
INSTALL="torch==1.7.0+cpu torchaudio==0.7.0 -f https://download.pytorch.org/whl/torch_stable.html"
elif [ $TORCH_INSTALL == "1.8.0" ]; then
INSTALL="torch==1.8.0+cpu torchaudio==0.8.0 -f https://download.pytorch.org/whl/torch_stable.html"
else
INSTALL="--pre torch torchaudio -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html"
fi
Expand Down
5 changes: 2 additions & 3 deletions openunmix/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,14 +113,13 @@ def separate():
"for deployment.",
)
args = parser.parse_args()
torchaudio.USE_SOUNDFILE_LEGACY_INTERFACE = False

if args.audio_backend != "stempeg":
torchaudio.set_audio_backend(args.audio_backend)

use_cuda = not args.no_cuda and torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")

print("Using ", device)
# parsing the output dict
aggregate_dict = None if args.aggregate is None else json.loads(args.aggregate)

Expand Down Expand Up @@ -173,7 +172,7 @@ def separate():
else:
outdir = Path(Path(input_file).stem + "_" + model_path.stem)
else:
outdir = Path(args.outdir)
outdir = Path(args.outdir) / Path(input_file).stem
outdir.mkdir(exist_ok=True, parents=True)

# write out estimates
Expand Down
5 changes: 1 addition & 4 deletions openunmix/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
import torch.utils.data
import torchaudio
import tqdm
from torchaudio.datasets.utils import bg_iterator


def load_info(path: str) -> dict:
Expand Down Expand Up @@ -941,7 +940,6 @@ def __len__(self):

args, _ = parser.parse_known_args()

torchaudio.USE_SOUNDFILE_LEGACY_INTERFACE = False
torchaudio.set_audio_backend(args.audio_backend)

train_dataset, valid_dataset, args = load_datasets(parser, args)
Expand Down Expand Up @@ -970,6 +968,5 @@ def __len__(self):
num_workers=4,
)

train_sampler = bg_iterator(train_sampler, 4)
for x, y in tqdm.tqdm(train_sampler):
pass
print(x.shape)
4 changes: 2 additions & 2 deletions openunmix/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,8 +124,8 @@ def forward(self, x: Tensor) -> Tensor:
# crop
x = x[..., : self.nb_bins]
# shift and scale input to mean=0 std=1 (across all bins)
x += self.input_mean
x *= self.input_scale
x = x + self.input_mean
x = x * self.input_scale

# to (nb_frames*nb_samples, nb_channels*nb_bins)
# and encode to (nb_frames*nb_samples, hidden_size)
Expand Down
7 changes: 4 additions & 3 deletions openunmix/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def forward(self, x: Tensor) -> Tensor:
# pack batch
x = x.view(-1, shape[-1])

stft_f = torch.stft(
complex_stft = torch.stft(
x,
n_fft=self.n_fft,
hop_length=self.n_hop,
Expand All @@ -105,8 +105,9 @@ def forward(self, x: Tensor) -> Tensor:
normalized=False,
onesided=True,
pad_mode="reflect",
return_complex=True,
)

stft_f = torch.view_as_real(complex_stft)
# unpack batch
stft_f = stft_f.view(shape[:-1] + stft_f.shape[-3:])
return stft_f
Expand Down Expand Up @@ -158,7 +159,7 @@ def forward(self, X: Tensor, length: Optional[int] = None) -> Tensor:
X = X.reshape(-1, shape[-3], shape[-2], shape[-1])

y = torch.istft(
X,
torch.view_as_complex(X),
n_fft=self.n_fft,
hop_length=self.n_hop,
window=self.window,
Expand Down
2 changes: 1 addition & 1 deletion openunmix/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,7 +295,7 @@ def preprocess(
audio = torch.repeat_interleave(audio, 2, dim=1)

if rate != model_rate:
print("resampling")
warnings.warn("resample to model sample rate")
# we have to resample to model samplerate if needed
# this makes sure we resample input only once
resampler = torchaudio.transforms.Resample(
Expand Down
1 change: 0 additions & 1 deletion scripts/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,6 @@ def main():

args, _ = parser.parse_known_args()

torchaudio.USE_SOUNDFILE_LEGACY_INTERFACE = False
torchaudio.set_audio_backend(args.audio_backend)
use_cuda = not args.no_cuda and torch.cuda.is_available()
print("Using GPU:", use_cuda)
Expand Down
9 changes: 5 additions & 4 deletions setup.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from setuptools import setup, find_packages

umx_version = "1.1.0"
umx_version = "1.1.1"

with open("README.md", encoding="utf-8") as fh:
long_description = fh.read()
Expand All @@ -18,17 +18,18 @@
python_requires=">=3.6",
install_requires=[
"numpy",
"torchaudio>=0.7.0",
"torch>=1.7.0",
"torchaudio>=0.8.0",
"torch>=1.8.0",
],
extras_require={
"asteroid": ["asteroid-filterbanks>=0.3.2"],
"tests": [
"pytest",
"musdb>=0.4.0",
"museval>=0.4.0",
"onnx",
"asteroid-filterbanks>=0.3.2",
"onnx",
"tqdm",
],
"stempeg": ["stempeg"],
"evaluation": ["musdb>=0.4.0", "museval>=0.4.0"],
Expand Down
1 change: 1 addition & 0 deletions tests/cli_test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@ python -m pip install -e .['stempeg'] --quiet

# run umx on url
coverage run -a `which umx` https://samples.ffmpeg.org/A-codecs/wavpcm/test-96.wav --audio-backend stempeg
coverage run -a `which umx` https://samples.ffmpeg.org/A-codecs/wavpcm/test-96.wav --audio-backend stempeg --outdir out --niter 0
3 changes: 0 additions & 3 deletions tests/test_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ def test_musdb():


def test_trackfolder_fix(torch_backend):
torchaudio.USE_SOUNDFILE_LEGACY_INTERFACE = False
torchaudio.set_audio_backend(torch_backend)

train_dataset = data.FixedSourcesTrackFolderDataset(
Expand All @@ -33,7 +32,6 @@ def test_trackfolder_fix(torch_backend):


def test_trackfolder_var(torch_backend):
torchaudio.USE_SOUNDFILE_LEGACY_INTERFACE = False
torchaudio.set_audio_backend(torch_backend)

train_dataset = data.VariableSourcesTrackFolderDataset(
Expand All @@ -48,7 +46,6 @@ def test_trackfolder_var(torch_backend):


def test_sourcefolder(torch_backend):
torchaudio.USE_SOUNDFILE_LEGACY_INTERFACE = False
torchaudio.set_audio_backend(torch_backend)

train_dataset = data.SourceFolderDataset(
Expand Down
2 changes: 0 additions & 2 deletions tests/test_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ def dur(request):

@pytest.fixture(params=[True, False])
def info(request, torch_backend):
torchaudio.USE_SOUNDFILE_LEGACY_INTERFACE = False
torchaudio.set_audio_backend(torch_backend)

if request.param:
Expand All @@ -34,7 +33,6 @@ def info(request, torch_backend):


def test_loadwav(dur, info, torch_backend):
torchaudio.USE_SOUNDFILE_LEGACY_INTERFACE = False
torchaudio.set_audio_backend(torch_backend)
audio, _ = data.load_audio(audio_path, dur=dur, info=info)
rate = 8000.0
Expand Down
2 changes: 1 addition & 1 deletion tests/test_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,4 +91,4 @@ def test_spectrogram(mus, method):
ref = torch.load(spec_path)
dut = encoder(audio).permute(3, 0, 1, 2)

assert torch.allclose(ref, dut, atol=1e-4, rtol=1e-5)
assert torch.allclose(ref, dut, atol=1e-4, rtol=1e-3)

0 comments on commit 3557e88

Please sign in to comment.