GUI: advanced option to enable training-time upsampling so that it is compatible with all training WAVs #354

Closed

wants to merge 5 commits into from

Changes from all commits
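This PR threads a new resample_rate option from the GUI's advanced options through nam.train.core into the dataset in nam/data.py, which then resamples the training WAVs with scipy.signal.resample before windowing. Roughly, the data config gains one extra key; the sketch below is illustrative only (the paths and ny value are placeholders, not values taken from the trainer):

train_data_config = {
    "x_path": "input.wav",   # placeholder
    "y_path": "output.wav",  # placeholder
    "ny": 8192,              # placeholder window length
    "resample_rate": 96000,  # new key: train at 96 kHz even though the WAVs are 48 kHz
}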
33 changes: 31 additions & 2 deletions nam/data.py
@@ -20,6 +20,9 @@

from ._core import InitializableFromConfig

# For upsampling during training (import the submodule explicitly so scipy.signal is available)
import scipy.signal

logger = logging.getLogger(__name__)

REQUIRED_RATE = 48_000 # FIXME not "required" anymore!
@@ -239,6 +242,7 @@ def __init__(
input_gain: float = 0.0,
sample_rate: Optional[int] = None,
rate: Optional[int] = None,
resample_rate: Optional[int] = None,
require_input_pre_silence: Optional[float] = _DEFAULT_REQUIRE_INPUT_PRE_SILENCE,
):
"""
@@ -269,6 +273,7 @@ def __init__(
into the input with which the guitar was originally recorded.)
:param sample_rate: Sample rate for the data
:param rate: Sample rate for the data (deprecated)
:param resample_rate: If provided and different from sample_rate, resample the data to this rate for training
:param require_input_pre_silence: If provided, require that this much time (in
seconds) preceding the start of the data set (`start`) have a silent input.
If it's not, then raise an exception because the output due to it will leak
@@ -277,6 +282,7 @@
self._validate_x_y(x, y)
self._validate_start_stop(x, y, start, stop)
self._sample_rate = self._validate_sample_rate(sample_rate, rate)
self._resample_rate = resample_rate
if not isinstance(delay_interpolation_method, _DelayInterpolationMethod):
delay_interpolation_method = _DelayInterpolationMethod(
delay_interpolation_method
@@ -294,11 +300,30 @@ def __init__(
self._x_path = x_path
self._y_path = y_path
self._validate_inputs_after_processing(x, y, nx, ny)
self._x = x
self._y = y

if self._resample_rate is None or self._resample_rate in (0, self._sample_rate):
self._x = x
self._y = y
self._resample_rate = self._sample_rate  # No resampling; normalize None/0 to the native rate
else:
# Resample x and y for training, e.g. from 48 kHz to 96 kHz
print(f"Resampling for training: original rate {self._sample_rate}, new rate {self._resample_rate}")
self._x = self._upsample(x, original_rate=self._sample_rate, new_rate=self._resample_rate)
self._y = self._upsample(y, original_rate=self._sample_rate, new_rate=self._resample_rate)
self._sample_rate = self._resample_rate

self._nx = nx
self._ny = ny if ny is not None else len(x) - nx + 1

def _upsample(self, signal: torch.Tensor, original_rate: int, new_rate: int) -> torch.Tensor:
"""
Upsample a signal from original_rate to new_rate using scipy.signal.resample (Fourier-method resampling).
"""
# Number of samples in the resampled signal so that its duration is preserved
num_samples = int(len(signal) * new_rate / original_rate)
signal_np = signal.detach().cpu().numpy()
upsampled_signal_np = scipy.signal.resample(signal_np, num_samples)
return torch.Tensor(upsampled_signal_np)

def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
"""
:return:
@@ -353,6 +378,9 @@ def parse_config(cls, config):
config["x_path"], info=True, rate=config.get("rate")
)
rate = x_wavinfo.rate
resample_rate = config.get("resample_rate")
try:
y = wav_to_tensor(
config["y_path"],
@@ -405,6 +433,7 @@ def sample_to_time(s, rate):
"x_path": config["x_path"],
"y_path": config["y_path"],
"sample_rate": rate,
"resample_rate": resample_rate,
"require_input_pre_silence": config.get(
"require_input_pre_silence", _DEFAULT_REQUIRE_INPUT_PRE_SILENCE
),
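For reference, the _upsample helper added above is a thin wrapper around scipy.signal.resample (Fourier-method resampling). A minimal standalone sketch of the same step, using a made-up one-second signal and assumed rates:

import scipy.signal
import torch

original_rate, new_rate = 48_000, 96_000
x = torch.rand(original_rate)  # one second of fake audio at 48 kHz
num_samples = int(len(x) * new_rate / original_rate)
x_up = torch.Tensor(scipy.signal.resample(x.detach().cpu().numpy(), num_samples))
assert len(x_up) == new_rate  # still one second of audio, now at 96 kHz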
12 changes: 8 additions & 4 deletions nam/train/core.py
@@ -801,6 +801,7 @@ def _get_configs(
lr_decay: float,
batch_size: int,
fit_cab: bool,
resample_rate: Optional[int] = None,
):
def get_kwargs(data_info: _DataInfo):
if data_info.major_version == 1:
@@ -827,8 +828,8 @@ def get_kwargs(data_info: _DataInfo):
]
train_kwargs, validation_kwargs = get_kwargs(data_info)
data_config = {
"train": {"ny": ny, **train_kwargs},
"validation": {"ny": None, **validation_kwargs},
"train": {"ny": ny, "resample_rate": resample_rate, **train_kwargs},
"validation": {"ny": None, "resample_rate": resample_rate, **validation_kwargs},
"common": {
"x_path": input_path,
"y_path": output_path,
@@ -885,14 +886,15 @@ def get_kwargs(data_info: _DataInfo):
"drop_last": True,
"num_workers": 0,
},
"val_dataloader": {},
"val_dataloader": {
},
"trainer": {"max_epochs": epochs, **device_config},
}
return data_config, model_config, learning_config


def _get_dataloaders(
data_config: Dict, learning_config: Dict, model: Model
data_config: Dict, learning_config: Dict, model: Model, resample_rate: Optional[int] = None
) -> Tuple[DataLoader, DataLoader]:
data_config, learning_config = [deepcopy(c) for c in (data_config, learning_config)]
data_config["common"]["nx"] = model.net.receptive_field
@@ -1015,6 +1017,7 @@ def train(
ignore_checks: bool = False,
local: bool = False,
fit_cab: bool = False,
resample_rate: Optional[int] = None,
) -> Optional[Model]:
if seed is not None:
torch.manual_seed(seed)
@@ -1060,6 +1063,7 @@
lr_decay,
batch_size,
fit_cab,
resample_rate=resample_rate,
)

print("Starting training. It's time to kick ass and chew bubblegum!")
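Downstream, resample_rate is just one more keyword argument on train(), forwarded to _get_configs and from there into both the train and validation data configs. A minimal sketch of a local call, assuming the trainer's usual input/output/train-path arguments (those names are not part of this diff, and the paths are placeholders):

from nam.train import core

model = core.train(
    input_path="input.wav",
    output_path="output.wav",
    train_path="./train_output",
    local=True,
    resample_rate=96_000,  # new: resample 48 kHz WAVs to 96 kHz for training
)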
26 changes: 25 additions & 1 deletion nam/train/gui.py
@@ -62,8 +62,14 @@ def _ensure_graceful_shutdowns():
_BUTTON_HEIGHT = 2
_TEXT_WIDTH = 70

class _SampleRates(Enum):
S48K = "48000"
S96K = "96000"
S192K = "192000"

_DEFAULT_DELAY = None
_DEFAULT_IGNORE_CHECKS = False
_DEFAULT_SAMPLE_RATE = _SampleRates.S48K

_ADVANCED_OPTIONS_LEFT_WIDTH = 12
_ADVANCED_OPTIONS_RIGHT_WIDTH = 12
@@ -76,6 +82,7 @@ class _AdvancedOptions(object):
num_epochs: int
delay: Optional[int]
ignore_checks: bool
sample_rate: _SampleRates


class _PathType(Enum):
@@ -84,6 +91,9 @@ class _PathType(Enum):
MULTIFILE = "multifile"


class _PathButton(object):
"""
Button and the path
@@ -223,6 +233,7 @@ def __init__(self):
_DEFAULT_NUM_EPOCHS,
_DEFAULT_DELAY,
_DEFAULT_IGNORE_CHECKS,
_DEFAULT_SAMPLE_RATE,
)
# Window to edit them:
self._frame_advanced_options = tk.Frame(self._root)
@@ -337,6 +348,7 @@ def _train(self):
architecture = self.advanced_options.architecture
delay = self.advanced_options.delay
file_list = self._path_button_output.val
sample_rate = self.advanced_options.sample_rate

# Advanced-er options
# If you're poking around looking for these, then maybe it's time to learn to
@@ -370,6 +382,7 @@
].variable.get(),
local=True,
fit_cab=self._checkboxes[_CheckboxKeys.FIT_CAB].variable.get(),
resample_rate=int(sample_rate.value),
)
if trained_model is None:
print("Model training failed! Skip exporting...")
@@ -526,7 +539,6 @@ def get(self):
except tk.TclError:
return None


class _AdvancedOptionsGUI(object):
"""
A window to hold advanced options (architecture, number of epochs, delay, and resample rate)
@@ -569,6 +581,17 @@ def __init__(self, parent: _GUI):
type=_int_or_null,
)

# Resample: option menu for the training sample rate
self._frame_resample = tk.Frame(self._root)
self._frame_resample.pack()
self._sample_rate = _LabeledOptionMenu(
self._frame_resample,
"Resample",
_SampleRates,
default=self._parent.advanced_options.sample_rate,
)

# "Ok": apply and destroy
self._frame_ok = tk.Frame(self._root)
self._frame_ok.pack()
@@ -590,6 +613,7 @@ def _apply_and_destroy(self):
Set values to parent and destroy this object
"""
self._parent.advanced_options.architecture = self._architecture.get()
self._parent.advanced_options.sample_rate = self._sample_rate.get()
epochs = self._epochs.get()
if epochs is not None:
self._parent.advanced_options.num_epochs = epochs
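In the GUI, the Resample option menu stores a _SampleRates member; its string value is converted to an int only when it is handed to the trainer (resample_rate=int(sample_rate.value) above). A small self-contained sketch of that conversion, mirroring the enum from this diff:

from enum import Enum

class _SampleRates(Enum):
    S48K = "48000"
    S96K = "96000"
    S192K = "192000"

selected = _SampleRates.S96K         # what the option menu would hand back
resample_rate = int(selected.value)  # 96000, passed to core.train(..., resample_rate=96000)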