Skip to content

Commit

Permalink
first version of the sum and difference STFT Loss
Browse files Browse the repository at this point in the history
  • Loading branch information
csteinmetz1 committed Nov 16, 2020
1 parent 9159c0a commit db19085
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 23 deletions.
43 changes: 37 additions & 6 deletions auraloss/freq.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import torch
import numpy as np

from .perceptual import SumAndDifference

def stft(x, fft_size, hop_size, win_length, window):
"""Perform STFT and convert to magnitude spectrogram.
Args:
Expand Down Expand Up @@ -219,13 +221,42 @@ def forward(self, input, target):
return sc_loss, mag_loss

class SumAndDifferenceSTFTLoss(torch.nn.Module):
    """Sum and difference stereo STFT loss module.

    The stereo input and target are each split into a sum signal and a
    difference signal, and a multi-resolution STFT loss is evaluated on
    each of the two pairs.

    See [Steinmetz et al., 2020](https://arxiv.org/abs/2010.10291)
    """

    # NOTE: the list defaults are kept exactly as in the original signature;
    # this class only forwards them to MultiResolutionSTFTLoss.
    def __init__(self,
                 fft_sizes=[1024, 2048, 512],
                 hop_sizes=[120, 240, 50],
                 win_lengths=[600, 1200, 240],
                 window="hann_window"):
        """Initialize sum and difference stereo STFT loss module.

        Args:
            fft_sizes (list): List of FFT sizes.
            hop_sizes (list): List of hop sizes.
            win_lengths (list): List of window lengths.
            window (str): Window function type.
        """
        super(SumAndDifferenceSTFTLoss, self).__init__()
        # Splits a stereo tensor into (sum, difference) signals.
        self.sd = SumAndDifference()
        # The same multi-resolution loss is reused for both signal pairs.
        self.mrstft = MultiResolutionSTFTLoss(fft_sizes,
                                              hop_sizes,
                                              win_lengths,
                                              window)

    def forward(self, input, target):
        """Calculate forward propagation.

        Args:
            input (Tensor): Predicted signal (B, C, T).
            target (Tensor): Groundtruth signal (B, C, T).

        Returns:
            Tensor: Multi resolution sum signal loss value.
            Tensor: Multi resolution difference signal loss value.
        """
        input_sum, input_diff = self.sd(input)
        target_sum, target_diff = self.sd(target)

        # The multi-resolution loss returns several loss terms; stack them
        # and add them up so each signal pair yields a single scalar.
        sum_loss = torch.stack(self.mrstft(input_sum, target_sum)).sum()
        diff_loss = torch.stack(self.mrstft(input_diff, target_diff)).sum()

        return sum_loss, diff_loss
31 changes: 14 additions & 17 deletions auraloss/perceptual.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,32 +10,29 @@ def __init__(self):
"""Initialize sum and difference extraction module."""
super(SumAndDifference, self).__init__()

def forward(self, x):
    """Split a stereo signal into its sum and difference signals.

    Args:
        x (Tensor): Stereo signal (B, #channels, #samples), where
            #channels must be 2.

    Returns:
        Tensor: Sum signal (B, 1, #samples).
        Tensor: Difference signal (B, 1, #samples).

    Raises:
        ValueError: If dimension 1 of ``x`` does not have size 2.
    """
    if not (x.size(1) == 2):  # inputs must be stereo
        raise ValueError(f"Input must be stereo: {x.size(1)} channel(s).")

    # sum/diff drop the channel axis; unsqueeze(1) restores it as a
    # single channel so the outputs stay three-dimensional.
    sum_sig = self.sum(x).unsqueeze(1)
    diff_sig = self.diff(x).unsqueeze(1)

    return sum_sig, diff_sig

@staticmethod
def sum(x):
    """Return the channel sum (channel 0 + channel 1) of a (B, 2, T) tensor.

    Declared static, so it takes no ``self``; the result has shape (B, T).
    """
    return x[:, 0, :] + x[:, 1, :]

@staticmethod
def diff(x):
    """Return the channel difference (channel 0 - channel 1) of a (B, 2, T) tensor.

    Declared static, so it takes no ``self``; the result has shape (B, T).
    """
    return x[:, 0, :] - x[:, 1, :]


Expand Down
9 changes: 9 additions & 0 deletions tests/simple_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
"""Smoke test: evaluate SumAndDifferenceSTFTLoss on random stereo tensors."""
import torch
import auraloss

# Random stereo batches shaped (batch, channels, samples).
# Named `pred` rather than `input` so the builtin is not shadowed.
pred = torch.rand(8, 2, 44100)
target = torch.rand(8, 2, 44100)

loss_fn = auraloss.freq.SumAndDifferenceSTFTLoss()

# Prints the (sum loss, difference loss) pair returned by the module.
print(loss_fn(pred, target))

0 comments on commit db19085

Please sign in to comment.