Commit 8afed30

extend batch support (#391) (#404)
* extend batch support, closes #383
* function for batch test.
* set seed.
1 parent a11a5a6 commit 8afed30

File tree

4 files changed: +154 -58 lines changed


test/test_functional.py

+66 -29
@@ -75,22 +75,17 @@ def test_compute_deltas_randn(self):
         win_length = 2 * 7 + 1
         specgram = torch.randn(channel, n_mfcc, time)
         computed = F.compute_deltas(specgram, win_length=win_length)
+
         self.assertTrue(computed.shape == specgram.shape, (computed.shape, specgram.shape))
+
         _test_torchscript_functional(F.compute_deltas, specgram, win_length=win_length)
 
     def test_batch_pitch(self):
         waveform, sample_rate = torchaudio.load(self.test_filepath)
+        self._test_batch(F.detect_pitch_frequency, waveform, sample_rate)
 
-        # Single then transform then batch
-        expected = F.detect_pitch_frequency(waveform, sample_rate)
-        expected = expected.unsqueeze(0).repeat(3, 1, 1)
-
-        # Batch then transform
-        waveform = waveform.unsqueeze(0).repeat(3, 1, 1)
-        computed = F.detect_pitch_frequency(waveform, sample_rate)
-
-        self.assertTrue(computed.shape == expected.shape, (computed.shape, expected.shape))
-        self.assertTrue(torch.allclose(computed, expected))
+    def test_jit_pitch(self):
+        waveform, sample_rate = torchaudio.load(self.test_filepath)
         _test_torchscript_functional(F.detect_pitch_frequency, waveform, sample_rate)
 
     def _compare_estimate(self, sound, estimate, atol=1e-6, rtol=1e-8):
@@ -106,22 +101,13 @@ def _test_istft_is_inverse_of_stft(self, kwargs):
         for data_size in self.data_sizes:
             for i in range(self.number_of_trials):
 
-                # Non-batch
                 sound = common_utils.random_float_tensor(i, data_size)
 
                 stft = torch.stft(sound, **kwargs)
                 estimate = torchaudio.functional.istft(stft, length=sound.size(1), **kwargs)
 
                 self._compare_estimate(sound, estimate)
 
-                # Batch
-                stft = torch.stft(sound, **kwargs)
-                stft = stft.repeat(3, 1, 1, 1, 1)
-                sound = sound.repeat(3, 1, 1)
-
-                estimate = torchaudio.functional.istft(stft, length=sound.size(1), **kwargs)
-                self._compare_estimate(sound, estimate)
-
     def test_istft_is_inverse_of_stft1(self):
         # hann_window, centered, normalized, onesided
         kwargs1 = {
@@ -338,6 +324,16 @@ def test_linearity_of_istft4(self):
         data_size = (2, 7, 3, 2)
         self._test_linearity_of_istft(data_size, kwargs4, atol=1e-5, rtol=1e-8)
 
+    def test_batch_istft(self):
+
+        stft = torch.tensor([
+            [[4., 0.], [4., 0.], [4., 0.], [4., 0.], [4., 0.]],
+            [[0., 0.], [0., 0.], [0., 0.], [0., 0.], [0., 0.]],
+            [[0., 0.], [0., 0.], [0., 0.], [0., 0.], [0., 0.]]
+        ])
+
+        self._test_batch(F.istft, stft, n_fft=4, length=4)
+
     def _test_create_fb(self, n_mels=40, sample_rate=22050, n_fft=2048, fmin=0.0, fmax=8000.0):
         # Using a decorator here causes parametrize to fail on Python 2
         if not IMPORT_LIBROSA:
@@ -438,22 +434,63 @@ def test_pitch(self):
         self.assertFalse(s)
 
         # Convert to stereo and batch for testing purposes
-        freq = freq.repeat(3, 2, 1, 1)
-        waveform = waveform.repeat(3, 2, 1, 1)
+        self._test_batch(F.detect_pitch_frequency, waveform, sample_rate)
 
-        freq2 = torchaudio.functional.detect_pitch_frequency(waveform, sample_rate)
+    def _test_batch_shape(self, functional, tensor, *args, **kwargs):
 
-        assert torch.allclose(freq, freq2, atol=1e-5)
+        kwargs_compare = {}
+        if 'atol' in kwargs:
+            atol = kwargs['atol']
+            del kwargs['atol']
+            kwargs_compare['atol'] = atol
 
-    def _test_batch(self, functional):
-        waveform, sample_rate = torchaudio.load(self.test_filepath) # (2, 278756), 44100
+        if 'rtol' in kwargs:
+            rtol = kwargs['rtol']
+            del kwargs['rtol']
+            kwargs_compare['rtol'] = rtol
 
         # Single then transform then batch
-        expected = functional(waveform).unsqueeze(0).repeat(3, 1, 1, 1)
 
-        # Batch then transform
-        waveform = waveform.unsqueeze(0).repeat(3, 1, 1)
-        computed = functional(waveform)
+        torch.random.manual_seed(42)
+        expected = functional(tensor.clone(), *args, **kwargs)
+        expected = expected.unsqueeze(0).unsqueeze(0)
+
+        # 1-Batch then transform
+
+        tensors = tensor.unsqueeze(0).unsqueeze(0)
+
+        torch.random.manual_seed(42)
+        computed = functional(tensors.clone(), *args, **kwargs)
+
+        self._compare_estimate(computed, expected, **kwargs_compare)
+
+        return tensors, expected
+
+    def _test_batch(self, functional, tensor, *args, **kwargs):
+
+        tensors, expected = self._test_batch_shape(functional, tensor, *args, **kwargs)
+
+        kwargs_compare = {}
+        if 'atol' in kwargs:
+            atol = kwargs['atol']
+            del kwargs['atol']
+            kwargs_compare['atol'] = atol
+
+        if 'rtol' in kwargs:
+            rtol = kwargs['rtol']
+            del kwargs['rtol']
+            kwargs_compare['rtol'] = rtol
+
+        # 3-Batch then transform
+
+        ind = [3] + [1] * (int(tensors.dim()) - 1)
+        tensors = tensor.repeat(*ind)
+
+        ind = [3] + [1] * (int(expected.dim()) - 1)
+        expected = expected.repeat(*ind)
+
+        torch.random.manual_seed(42)
+        computed = functional(tensors.clone(), *args, **kwargs)
 
 
 def _num_stft_bins(signal_len, fft_len, hop_length, pad):
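Note: the "single then batch" comparison that the new _test_batch/_test_batch_shape helpers automate can be reproduced outside the test suite. The snippet below is an illustrative sketch only, not part of the commit; the (2, 40, 100) shape is an arbitrary assumption and F.compute_deltas stands in for any batch-aware functional.

    import torch
    import torchaudio.functional as F

    specgram = torch.randn(2, 40, 100)  # (channel, freq, time), illustrative shape

    # Single input, then replicate the result into a batch of 3
    expected = F.compute_deltas(specgram).unsqueeze(0).repeat(3, 1, 1, 1)

    # Batch of 3, then transform
    computed = F.compute_deltas(specgram.unsqueeze(0).repeat(3, 1, 1, 1))

    assert computed.shape == expected.shape
    assert torch.allclose(computed, expected)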

test/test_transforms.py

+37
@@ -363,6 +363,19 @@ def test_compute_deltas_twochannel(self):
         computed = transform(specgram)
         self.assertTrue(computed.shape == specgram.shape, (computed.shape, specgram.shape))
 
+    def test_batch_MelScale(self):
+        specgram = torch.randn(2, 31, 2786)
+
+        # Single then transform then batch
+        expected = transforms.MelScale()(specgram).repeat(3, 1, 1, 1)
+
+        # Batch then transform
+        computed = transforms.MelScale()(specgram.repeat(3, 1, 1, 1))
+
+        # shape = (3, 2, 201, 1394)
+        self.assertTrue(computed.shape == expected.shape, (computed.shape, expected.shape))
+        self.assertTrue(torch.allclose(computed, expected))
+
     def test_batch_compute_deltas(self):
         specgram = torch.randn(2, 31, 2786)
 
@@ -422,6 +435,30 @@ def test_batch_spectrogram(self):
         self.assertTrue(computed.shape == expected.shape, (computed.shape, expected.shape))
         self.assertTrue(torch.allclose(computed, expected))
 
+    def test_batch_melspectrogram(self):
+        waveform, sample_rate = torchaudio.load(self.test_filepath)
+
+        # Single then transform then batch
+        expected = transforms.MelSpectrogram()(waveform).repeat(3, 1, 1, 1)
+
+        # Batch then transform
+        computed = transforms.MelSpectrogram()(waveform.repeat(3, 1, 1))
+
+        self.assertTrue(computed.shape == expected.shape, (computed.shape, expected.shape))
+        self.assertTrue(torch.allclose(computed, expected))
+
+    def test_batch_mfcc(self):
+        waveform, sample_rate = torchaudio.load(self.test_filepath)
+
+        # Single then transform then batch
+        expected = transforms.MFCC()(waveform).repeat(3, 1, 1, 1)
+
+        # Batch then transform
+        computed = transforms.MFCC()(waveform.repeat(3, 1, 1))
+
+        self.assertTrue(computed.shape == expected.shape, (computed.shape, expected.shape))
+        self.assertTrue(torch.allclose(computed, expected, atol=1e-5))
+
     def test_scriptmodule_TimeStretch(self):
         n_freq = 400
         hop_length = 512
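The added transform tests above all apply the same check: transform one item, replicate the result into a batch, and compare against transforming the replicated batch. A minimal standalone sketch of that pattern follows; it is illustrative only, with a random (2, 16000) waveform standing in for the test fixture and a loose atol as an assumption.

    import torch
    import torchaudio.transforms as transforms

    waveform = torch.randn(2, 16000)  # (channel, time), stand-in for the test file

    # Single then transform then batch
    expected = transforms.MelSpectrogram()(waveform).repeat(3, 1, 1, 1)

    # Batch then transform
    computed = transforms.MelSpectrogram()(waveform.repeat(3, 1, 1))

    assert computed.shape == expected.shape
    assert torch.allclose(computed, expected, atol=1e-5)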

torchaudio/functional.py

+20 -16
@@ -96,7 +96,7 @@ def istft(
 
     Args:
         stft_matrix (torch.Tensor): Output of stft where each row of a channel is a frequency and each
-            column is a window. it has a size of either (..., fft_size, n_frame, 2)
+            column is a window. It has a size of either (..., fft_size, n_frame, 2)
         n_fft (int): Size of Fourier transform
         hop_length (Optional[int]): The distance between neighboring sliding window frames.
             (Default: ``win_length // 4``)
@@ -229,7 +229,7 @@ def spectrogram(
     The spectrogram can be either magnitude-only or complex.
 
     Args:
-        waveform (torch.Tensor): Tensor of audio of dimension (..., channel, time)
+        waveform (torch.Tensor): Tensor of audio of dimension (..., time)
         pad (int): Two sided padding of signal
         window (torch.Tensor): Window tensor that is applied/multiplied to each frame/window
         n_fft (int): Size of FFT
@@ -241,8 +241,8 @@ def spectrogram(
         normalized (bool): Whether to normalize by magnitude after stft
 
     Returns:
-        torch.Tensor: Dimension (..., channel, freq, time), where channel
-        is unchanged, freq is ``n_fft // 2 + 1`` and ``n_fft`` is the number of
+        torch.Tensor: Dimension (..., freq, time), freq is
+        ``n_fft // 2 + 1`` and ``n_fft`` is the number of
         Fourier bins, and time is the number of window hops (n_frame).
     """
 
@@ -613,7 +613,7 @@ def biquad(waveform, b0, b1, b2, a0, a1, a2):
     https://en.wikipedia.org/wiki/Digital_biquad_filter
 
     Args:
-        waveform (torch.Tensor): audio waveform of dimension of `(channel, time)`
+        waveform (torch.Tensor): audio waveform of dimension of `(..., time)`
         b0 (float): numerator coefficient of current input, x[n]
         b1 (float): numerator coefficient of input one time step ago x[n-1]
         b2 (float): numerator coefficient of input two time steps ago x[n-2]
@@ -622,7 +622,7 @@ def biquad(waveform, b0, b1, b2, a0, a1, a2):
         a2 (float): denominator coefficient of current output y[n-2]
 
     Returns:
-        output_waveform (torch.Tensor): Dimension of `(channel, time)`
+        output_waveform (torch.Tensor): Dimension of `(..., time)`
     """
 
     device = waveform.device
@@ -646,13 +646,13 @@ def highpass_biquad(waveform, sample_rate, cutoff_freq, Q=0.707):
     r"""Design biquad highpass filter and perform filtering. Similar to SoX implementation.
 
     Args:
-        waveform (torch.Tensor): audio waveform of dimension of `(channel, time)`
+        waveform (torch.Tensor): audio waveform of dimension of `(..., time)`
         sample_rate (int): sampling rate of the waveform, e.g. 44100 (Hz)
         cutoff_freq (float): filter cutoff frequency
        Q (float): https://en.wikipedia.org/wiki/Q_factor
 
     Returns:
-        output_waveform (torch.Tensor): Dimension of `(channel, time)`
+        output_waveform (torch.Tensor): Dimension of `(..., time)`
     """
 
     GAIN = 1.
@@ -675,13 +675,13 @@ def lowpass_biquad(waveform, sample_rate, cutoff_freq, Q=0.707):
     r"""Design biquad lowpass filter and perform filtering. Similar to SoX implementation.
 
     Args:
-        waveform (torch.Tensor): audio waveform of dimension of `(channel, time)`
+        waveform (torch.Tensor): audio waveform of dimension of `(..., time)`
         sample_rate (int): sampling rate of the waveform, e.g. 44100 (Hz)
         cutoff_freq (float): filter cutoff frequency
        Q (float): https://en.wikipedia.org/wiki/Q_factor
 
     Returns:
-        output_waveform (torch.Tensor): Dimension of `(channel, time)`
+        output_waveform (torch.Tensor): Dimension of `(..., time)`
     """
 
     GAIN = 1.
@@ -704,14 +704,14 @@ def equalizer_biquad(waveform, sample_rate, center_freq, gain, Q=0.707):
     r"""Design biquad peaking equalizer filter and perform filtering. Similar to SoX implementation.
 
     Args:
-        waveform (torch.Tensor): audio waveform of dimension of `(channel, time)`
+        waveform (torch.Tensor): audio waveform of dimension of `(..., time)`
         sample_rate (int): sampling rate of the waveform, e.g. 44100 (Hz)
         center_freq (float): filter's central frequency
         gain (float): desired gain at the boost (or attenuation) in dB
         q_factor (float): https://en.wikipedia.org/wiki/Q_factor
 
     Returns:
-        output_waveform (torch.Tensor): Dimension of `(channel, time)`
+        output_waveform (torch.Tensor): Dimension of `(..., time)`
     """
     w0 = 2 * math.pi * center_freq / sample_rate
     A = math.exp(gain / 40.0 * math.log(10))
@@ -800,7 +800,7 @@ def mask_along_axis(specgram, mask_param, mask_value, axis):
     # unpack batch
     specgram = specgram.reshape(shape[:-2] + specgram.shape[-2:])
 
-    return specgram.reshape(shape[:-2] + specgram.shape[-2:])
+    return specgram
 
 
 def compute_deltas(specgram, win_length=5, mode="replicate"):
@@ -860,7 +860,7 @@ def gain(waveform, gain_db=1.0):
     r"""Apply amplification or attenuation to the whole waveform.
 
     Args:
-        waveform (torch.Tensor): Tensor of audio of dimension (channel, time).
+        waveform (torch.Tensor): Tensor of audio of dimension (..., time).
         gain_db (float) Gain adjustment in decibels (dB) (Default: `1.0`).
 
     Returns:
@@ -913,7 +913,7 @@ def _apply_probability_distribution(waveform, density_function="TPDF"):
     The relationship of probabilities of results follows a bell-shaped,
     or Gaussian curve, typical of dither generated by analog sources.
     Args:
-        waveform (torch.Tensor): Tensor of audio of dimension (channel, time)
+        waveform (torch.Tensor): Tensor of audio of dimension (..., time)
         probability_density_function (string): The density function of a
             continuous random variable (Default: `TPDF`)
             Options: Triangular Probability Density Function - `TPDF`
@@ -922,6 +922,8 @@ def _apply_probability_distribution(waveform, density_function="TPDF"):
     Returns:
         torch.Tensor: waveform dithered with TPDF
     """
+
+    # pack batch
     shape = waveform.size()
     waveform = waveform.reshape(-1, shape[-1])
 
@@ -961,6 +963,8 @@ def _apply_probability_distribution(waveform, density_function="TPDF"):
 
     quantised_signal_scaled = torch.round(signal_scaled_dis)
     quantised_signal = quantised_signal_scaled / down_scaling
+
+    # unpack batch
     return quantised_signal.reshape(shape[:-1] + quantised_signal.shape[-1:])
 
 
@@ -970,7 +974,7 @@ def dither(waveform, density_function="TPDF", noise_shaping=False):
     particular bit-depth by eliminating nonlinear truncation distortion
     (i.e. adding minimally perceived noise to mask distortion caused by quantization).
     Args:
-        waveform (torch.Tensor): Tensor of audio of dimension (channel, time)
+        waveform (torch.Tensor): Tensor of audio of dimension (..., time)
         density_function (string): The density function of a
             continuous random variable (Default: `TPDF`)
             Options: Triangular Probability Density Function - `TPDF`
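The docstring changes in this file replace `(channel, time)` with `(..., time)`: these functionals now document support for any number of leading batch dimensions. A rough sketch of what that enables is below; the shapes, sample rate, cutoff, and tolerance are illustrative assumptions, not values from the commit.

    import torch
    import torchaudio.functional as F

    batch = torch.randn(3, 2, 8000)  # (batch, channel, time), illustrative

    # Filtering the whole batch at once should match filtering each item separately
    filtered = F.lowpass_biquad(batch, sample_rate=8000, cutoff_freq=400.0)
    per_item = torch.stack([F.lowpass_biquad(x, 8000, 400.0) for x in batch])

    assert filtered.shape == batch.shape
    assert torch.allclose(filtered, per_item, atol=1e-5)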
