-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathtransforms.py
173 lines (141 loc) · 6.01 KB
/
transforms.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
# Code taken/adapted from https://towardsdatascience.com/audio-deep-learning-made-simple-sound-classification-step-by-step-cebc936bbe5
import pandas as pd
import random
import torch
import torchaudio
from torchaudio import transforms
from IPython.display import Audio
class AudioUtil():
    """Stateless helpers for loading, normalising and augmenting audio clips.

    Throughout this class an ``aud`` value is a ``(sig, sr)`` tuple where
    ``sig`` is a 2-D tensor indexed as ``[channel, sample]`` and ``sr`` is the
    sample rate in samples per second.

    Every helper is a static method, so it can be called either on the class
    (``AudioUtil.open(path)``) or on an instance.
    """

    # ----------------------------
    # Load an audio file. Return the signal as a tensor and the sample rate
    # ----------------------------
    @staticmethod
    def open(audio_file):
        """Load ``audio_file`` with torchaudio and return ``(sig, sr)``."""
        sig, sr = torchaudio.load(audio_file)
        return (sig, sr)

    # ----------------------------
    # Convert the given audio to the desired number of channels
    # ----------------------------
    @staticmethod
    def rechannel(aud, new_channel):
        """Return ``aud`` converted to exactly ``new_channel`` channels.

        Extra channels are dropped (keeping the leading ones); missing
        channels are produced by repeating the existing ones in order, so
        mono -> stereo duplicates the single channel. The input tuple is
        returned unchanged when no conversion is needed.
        """
        sig, sr = aud
        num_channels = sig.shape[0]

        # Nothing to do; a signal without channels has nothing to duplicate.
        if num_channels in (0, new_channel):
            return aud

        if num_channels > new_channel:
            # Down-mix by keeping only the first `new_channel` channels.
            resig = sig[:new_channel, :]
        else:
            # Up-mix by cycling through the existing channels until there
            # are enough of them (ceil division gives the repeat count).
            repeats = -(-new_channel // num_channels)
            resig = sig.repeat(repeats, 1)[:new_channel, :]

        return (resig, sr)

    # ----------------------------
    # Resample the signal to a new sample rate
    # ----------------------------
    @staticmethod
    def resample(aud, newsr):
        """Return ``aud`` resampled to ``newsr``.

        The input tuple is returned unchanged when it is already at the
        requested rate. All channels are resampled in one call: the
        transform operates along the last (time) dimension, so there is no
        need to split the channels or to build the transform more than once.
        """
        sig, sr = aud
        if sr == newsr:
            # Nothing to do
            return aud

        resig = torchaudio.transforms.Resample(sr, newsr)(sig)
        return (resig, newsr)

    # ----------------------------
    # Pad (or truncate) the signal to a fixed length 'max_ms' in milliseconds
    # ----------------------------
    @staticmethod
    def pad_trunc(aud, max_ms):
        """Return ``aud`` with its signal forced to ``max_ms`` milliseconds.

        Longer signals are cut at the end. Shorter signals are zero-padded,
        with the padding split at a random point between the beginning and
        the end of the signal.
        """
        sig, sr = aud
        num_rows, sig_len = sig.shape

        # Multiply before dividing: `sr // 1000` would throw away the
        # sub-kHz part of the rate (44100 -> 44) and shorten the target.
        max_len = int(sr * max_ms // 1000)

        if sig_len > max_len:
            # Truncate the signal to the given length
            sig = sig[:, :max_len]
        elif sig_len < max_len:
            # Length of padding to add at the beginning and end of the signal
            pad_begin_len = random.randint(0, max_len - sig_len)
            pad_end_len = max_len - sig_len - pad_begin_len

            # Pad with 0s that share the signal's dtype and device so the
            # concatenation below never mixes tensor types.
            pad_begin = sig.new_zeros((num_rows, pad_begin_len))
            pad_end = sig.new_zeros((num_rows, pad_end_len))
            sig = torch.cat((pad_begin, sig, pad_end), 1)

        return (sig, sr)

    # ----------------------------
    # Shifts the signal along the time axis by a random amount. Values at the
    # end are 'wrapped around' to the start of the transformed signal.
    # ----------------------------
    @staticmethod
    def time_shift(aud, shift_limit):
        """Return ``aud`` circularly shifted in time by a random amount.

        The shift is a random fraction, between 0 and ``shift_limit``, of the
        signal length.
        """
        sig, sr = aud
        _, sig_len = sig.shape
        shift_amt = int(random.random() * shift_limit * sig_len)
        # Roll along the time axis only. Without `dims` the tensor is
        # flattened first, which would leak samples from one channel into
        # the next.
        return (sig.roll(shift_amt, dims=-1), sr)

    # ----------------------------
    # Generate a Spectrogram
    # ----------------------------
    @staticmethod
    def spectro_gram(aud, n_mels=64, n_fft=1024, hop_len=None, top_db=80):
        """Return the Mel spectrogram of ``aud`` converted to decibels.

        ``n_mels``, ``n_fft`` and ``hop_len`` are forwarded to
        ``MelSpectrogram`` (``hop_len=None`` leaves the choice to
        torchaudio); ``top_db`` is forwarded to ``AmplitudeToDB``.
        """
        sig, sr = aud

        # spec has shape [channel, n_mels, time], where channel is mono, stereo etc
        spec = transforms.MelSpectrogram(
            sr, n_fft=n_fft, hop_length=hop_len, n_mels=n_mels
        )(sig)

        # Convert to decibels
        spec = transforms.AmplitudeToDB(top_db=top_db)(spec)
        return (spec)

    # ----------------------------
    # Augment the Spectrogram by masking out some sections of it in both the frequency
    # dimension (ie. horizontal bars) and the time dimension (vertical bars) to prevent
    # overfitting and to help the model generalise better. The masked sections are
    # replaced with the mean value.
    # ----------------------------
    @staticmethod
    def spectro_augment(spec, max_mask_pct=0.1, n_freq_masks=1, n_time_masks=1):
        """Return ``spec`` with frequency and time masks applied.

        ``spec`` must be 3-D, ``[channel, n_mels, time]``. ``n_freq_masks``
        frequency masks and ``n_time_masks`` time masks are applied in turn;
        the mask size parameter is ``max_mask_pct`` of the corresponding
        dimension, and masked cells are filled with the mean of the input
        spectrogram.
        """
        _, n_mels, n_steps = spec.shape
        mask_value = spec.mean()
        aug_spec = spec

        freq_mask_param = max_mask_pct * n_mels
        for _ in range(n_freq_masks):
            aug_spec = transforms.FrequencyMasking(freq_mask_param)(aug_spec, mask_value)

        time_mask_param = max_mask_pct * n_steps
        for _ in range(n_time_masks):
            aug_spec = transforms.TimeMasking(time_mask_param)(aug_spec, mask_value)

        return aug_spec
from torch.utils.data import DataLoader, Dataset, random_split
import torchaudio
# ----------------------------
# Sound Dataset
# ----------------------------
class SoundDS(Dataset):
    """Dataset that turns rows of a DataFrame into augmented Mel spectrograms.

    ``df`` must provide a ``'file path'`` column (path of the audio file to
    load) and a ``'class'`` column (the label returned with each item).

    Args:
        df: table describing the audio clips, one row per clip.
        duration: length every clip is padded/truncated to, in milliseconds.
        sr: sample rate every clip is resampled to.
        channel: number of channels every clip is converted to.
        shift_pct: upper bound, as a fraction of the clip length, of the
            random time shift applied to each clip.
    """

    def __init__(self, df, duration=4000, sr=44100, channel=2, shift_pct=0.4):
        self.df = df
        self.duration = duration
        self.sr = sr
        self.channel = channel
        self.shift_pct = shift_pct

    # ----------------------------
    # Number of items in dataset
    # ----------------------------
    def __len__(self):
        """Return the number of rows in the underlying table."""
        return len(self.df)

    # ----------------------------
    # Get i'th item in dataset
    # ----------------------------
    def __getitem__(self, idx):
        """Return ``(augmented_spectrogram, class_id)`` for the ``idx``-th row.

        Rows are addressed by position, so the dataset also works with a
        DataFrame whose index is not 0..n-1 (for example after filtering or
        a train/test split).
        """
        # Path of the audio file, exactly as stored in the table
        audio_file = self.df['file path'].iloc[idx]
        # Get the Class ID
        class_id = self.df['class'].iloc[idx]

        aud = AudioUtil.open(audio_file)
        # Some sounds have a higher sample rate, or fewer channels compared to the
        # majority. So make all sounds have the same number of channels and same
        # sample rate. Unless the sample rate is the same, the pad_trunc will still
        # result in arrays of different lengths, even though the sound duration is
        # the same.
        reaud = AudioUtil.resample(aud, self.sr)
        rechan = AudioUtil.rechannel(reaud, self.channel)

        dur_aud = AudioUtil.pad_trunc(rechan, self.duration)
        shift_aud = AudioUtil.time_shift(dur_aud, self.shift_pct)
        sgram = AudioUtil.spectro_gram(shift_aud, n_mels=64, n_fft=1024, hop_len=None)
        aug_sgram = AudioUtil.spectro_augment(sgram, max_mask_pct=0.1, n_freq_masks=2, n_time_masks=2)

        return aug_sgram, class_id