Skip to content

Commit

Permalink
[BREAKING] Get rid of REQUIRED_RATE (#375)
Browse files Browse the repository at this point in the history
Get rid of `REQUIRED_RATE`
  • Loading branch information
sdatkinson authored Feb 7, 2024
1 parent 7808e90 commit 2959e53
Show file tree
Hide file tree
Showing 7 changed files with 56 additions and 55 deletions.
52 changes: 19 additions & 33 deletions nam/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,6 @@

logger = logging.getLogger(__name__)

REQUIRED_RATE = 48_000 # FIXME not "required" anymore!
_DEFAULT_RATE = REQUIRED_RATE # There we go :)
_REQUIRED_CHANNELS = 1 # Mono


Expand Down Expand Up @@ -242,8 +240,7 @@ def __init__(
x_path: Optional[Union[str, Path]] = None,
y_path: Optional[Union[str, Path]] = None,
input_gain: float = 0.0,
sample_rate: Optional[int] = None,
rate: Optional[int] = None,
sample_rate: Optional[float] = None,
require_input_pre_silence: Optional[float] = _DEFAULT_REQUIRE_INPUT_PRE_SILENCE,
):
"""
Expand Down Expand Up @@ -283,16 +280,13 @@ def __init__(
completely dry signal (i.e. connecting the interface output directly back
into the input with which the guitar was originally recorded.)
:param sample_rate: Sample rate for the data
:param rate: Sample rate for the data (deprecated)
:param require_input_pre_silence: If provided, require that this much time (in
seconds) preceding the start of the data set (`start`) have a silent input.
If it's not, then raise an exception because the output due to it will leak
into the data set that we're trying to use. If `None`, don't assert.
"""
self._validate_x_y(x, y)
self._sample_rate = self._validate_sample_rate(
sample_rate, rate, default=_DEFAULT_RATE
)
self._sample_rate = sample_rate
start, stop = self._validate_start_stop(
x,
y,
Expand All @@ -302,15 +296,15 @@ def __init__(
stop_samples,
start_seconds,
stop_seconds,
self._sample_rate,
self.sample_rate,
)
if not isinstance(delay_interpolation_method, _DelayInterpolationMethod):
delay_interpolation_method = _DelayInterpolationMethod(
delay_interpolation_method
)
if require_input_pre_silence is not None:
self._validate_preceding_silence(
x, start, int(require_input_pre_silence * self._sample_rate)
x, start, require_input_pre_silence, self.sample_rate
)
x, y = [z[start:stop] for z in (x, y)]
if delay is not None and delay != 0:
Expand Down Expand Up @@ -377,9 +371,12 @@ def y_offset(self) -> int:
@classmethod
def parse_config(cls, config):
config = deepcopy(config)
sample_rate = cls._validate_sample_rate(
config.pop("sample_rate", None), config.pop("rate", None)
)
if "rate" in config:
raise ValueError(
"use of `rate` was deprecated in version 0.8. Use `sample_rate` "
"instead."
)
sample_rate = config.pop("sample_rate", None)
x, x_wavinfo = wav_to_tensor(config.pop("x_path"), info=True, rate=sample_rate)
sample_rate = x_wavinfo.rate
try:
Expand Down Expand Up @@ -469,25 +466,6 @@ def _apply_delay_float(
y = _interpolate_delay(y, delay, method)
return x, y

@classmethod
def _validate_sample_rate(
cls, sample_rate: Optional[float], rate: Optional[int], default=None
) -> float:
if sample_rate is None and rate is None: # Default value
return default
if rate is not None:
if sample_rate is not None:
raise ValueError(
"Provided both sample_rate and rate. Provide only sample_rate!"
)
else:
logger.warning(
"Use of 'rate' is deprecated and will be removed. Use sample_rate instead"
)
return float(rate)
else:
return sample_rate

@classmethod
def _validate_start_stop(
cls,
Expand Down Expand Up @@ -632,19 +610,27 @@ def _validate_inputs_after_processing(self, x, y, nx, ny):

@classmethod
def _validate_preceding_silence(
cls, x: torch.Tensor, start: Optional[int], silent_samples: int
cls, x: torch.Tensor, start: Optional[int], silent_seconds: float, sample_rate: Optional[float]
):
"""
Make sure that the input is silent before the starting index.
If it's not, then the output from that non-silent input will leak into the data
set and couldn't be predicted!
This assumes that silence is indeed required. If it's not, then don't call this!
See: Issue #252
:param x: Input
:param start: Where the data starts
:param silent_samples: How many are expected to be silent
"""
if sample_rate is None:
raise ValueError(
f"Pre-silence was required for {silent_seconds} seconds, but no sample "
"rate was provided!"
)
silent_samples = int(silent_seconds * sample_rate)
if start is None:
return
raw_check_start = start - silent_samples
Expand Down
8 changes: 6 additions & 2 deletions nam/models/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import torch.nn as nn

from .._core import InitializableFromConfig
from ..data import REQUIRED_RATE, wav_to_tensor
from ..data import wav_to_tensor
from ._exportable import Exportable


Expand Down Expand Up @@ -133,7 +133,11 @@ def _export_input_output_args(self) -> Tuple[Any]:

def _export_input_output(self) -> Tuple[np.ndarray, np.ndarray]:
args = self._export_input_output_args()
rate = REQUIRED_RATE
rate = self.sample_rate
if rate is None:
raise RuntimeError(
"Cannot export model's input and output without a sample rate."
)
x = torch.cat(
[
torch.zeros((rate,)),
Expand Down
8 changes: 6 additions & 2 deletions nam/models/conv_net.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@


from .. import __version__
from ..data import REQUIRED_RATE, wav_to_tensor
from ..data import wav_to_tensor
from ._activations import get_activation
from ._base import BaseNet
from ._names import ACTIVATION_NAME, BATCHNORM_NAME, CONV_NAME
Expand Down Expand Up @@ -217,7 +217,11 @@ def _export_input_signal(self):
"""
:return: (L,)
"""
rate = REQUIRED_RATE
rate = self.sample_rate
if rate is None:
raise RuntimeError(
"Cannot export model's input and output without a sample rate."
)
return torch.cat(
[
torch.zeros((rate,)),
Expand Down
13 changes: 8 additions & 5 deletions nam/train/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,17 @@
from pytorch_lightning.utilities.warnings import PossibleUserWarning
from torch.utils.data import DataLoader

from ..data import REQUIRED_RATE, Split, init_dataset, wav_to_np, wav_to_tensor
from ..data import Split, init_dataset, wav_to_np, wav_to_tensor
from ..models import Model
from ..models.losses import esr
from ..util import filter_warnings
from ._version import Version

__all__ = ["train"]

# Training using the simplified trainers in NAM is done at 48k.
STANDARD_SAMPLE_RATE = 48_000.0


class Architecture(Enum):
STANDARD = "standard"
Expand Down Expand Up @@ -222,7 +225,7 @@ class _DataInfo(BaseModel):
"""

major_version: int
rate: Optional[int]
rate: Optional[float]
t_blips: int
first_blips_start: int
t_validate: int
Expand All @@ -234,7 +237,7 @@ class _DataInfo(BaseModel):

_V1_DATA_INFO = _DataInfo(
major_version=1,
rate=REQUIRED_RATE,
rate=STANDARD_SAMPLE_RATE,
t_blips=48_000,
first_blips_start=0,
t_validate=432_000,
Expand All @@ -254,7 +257,7 @@ class _DataInfo(BaseModel):
# (3:09-3:11) Blips at 3:09.5 and 3:10.5
_V2_DATA_INFO = _DataInfo(
major_version=2,
rate=REQUIRED_RATE,
rate=STANDARD_SAMPLE_RATE,
t_blips=96_000,
first_blips_start=0,
t_validate=432_000,
Expand All @@ -274,7 +277,7 @@ class _DataInfo(BaseModel):
# (3:01-3:10) Validation 2
_V3_DATA_INFO = _DataInfo(
major_version=3,
rate=REQUIRED_RATE,
rate=STANDARD_SAMPLE_RATE,
t_blips=96_000,
first_blips_start=480_000,
t_validate=432_000,
Expand Down
2 changes: 1 addition & 1 deletion tests/test_bin/test_train/test_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import pytest
import torch

from nam.data import REQUIRED_RATE, np_to_wav
from nam.data import np_to_wav

_BIN_TRAIN_MAIN_PY_PATH = Path(__file__).absolute().parent.parent.parent.parent / Path(
"bin", "train", "main.py"
Expand Down
25 changes: 14 additions & 11 deletions tests/test_nam/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@

from nam import data

_sample_rates = (44_100, 48_000, 88_200, 96_000)
_SAMPLE_RATES = (44_100.0, 48_000.0, 88_200.0, 96_000.0)
_DEFAULT_SAMPLE_RATE = 48_000.0


class _XYMethod(Enum):
Expand Down Expand Up @@ -85,11 +86,11 @@ def test_apply_delay_int_positive(self):

def test_init(self):
x, y = self._create_xy()
data.Dataset(x, y, 3, None)
data.Dataset(x, y, 3, None, sample_rate=_DEFAULT_SAMPLE_RATE)

def test_init_sample_rate(self):
x, y = self._create_xy()
sample_rate = 48_000.0
sample_rate = _DEFAULT_SAMPLE_RATE
d = data.Dataset(x, y, 3, None, sample_rate=sample_rate)
assert hasattr(d, "sample_rate")
assert isinstance(d.sample_rate, float)
Expand All @@ -100,7 +101,7 @@ def test_init_zero_delay(self):
Assert https://github.com/sdatkinson/neural-amp-modeler/issues/15 fixed
"""
x, y = self._create_xy()
data.Dataset(x, y, 3, None, delay=0)
data.Dataset(x, y, 3, None, delay=0, sample_rate=_DEFAULT_SAMPLE_RATE)

def test_input_gain(self):
"""
Expand All @@ -112,14 +113,16 @@ def test_input_gain(self):
nx = 3
ny = None
args = (x, y, nx, ny)
d1 = data.Dataset(*args)
d2 = data.Dataset(*args, input_gain=input_gain)
d1 = data.Dataset(*args, sample_rate=_DEFAULT_SAMPLE_RATE)
d2 = data.Dataset(
*args, sample_rate=_DEFAULT_SAMPLE_RATE, input_gain=input_gain
)

sample_x1 = d1[0][0]
sample_x2 = d2[0][0]
assert torch.allclose(sample_x1 * x_scale, sample_x2)

@pytest.mark.parametrize("sample_rate", _sample_rates)
@pytest.mark.parametrize("sample_rate", _SAMPLE_RATES)
def test_sample_rates(self, sample_rate: int):
"""
Test that datasets with various sample rates can be made
Expand Down Expand Up @@ -155,7 +158,7 @@ def test_validate_start(self, n: int, start: int, valid: bool):
"""

def init():
data.Dataset(x, y, nx, ny, start=start)
data.Dataset(x, y, nx, ny, start=start, sample_rate=_DEFAULT_SAMPLE_RATE)

nx = 1
ny = None
Expand Down Expand Up @@ -239,7 +242,7 @@ def f():
)
def test_validate_stop(self, n: int, stop: int, valid: bool):
def init():
data.Dataset(x, y, nx, ny, stop=stop)
data.Dataset(x, y, nx, ny, stop=stop, sample_rate=_DEFAULT_SAMPLE_RATE)

nx = 1
ny = None
Expand All @@ -257,7 +260,7 @@ def init():
)
def test_validate_x_y(self, lenx: int, leny: int, valid: bool):
def init():
data.Dataset(x, y, nx, ny)
data.Dataset(x, y, nx, ny, sample_rate=_DEFAULT_SAMPLE_RATE)

x, y = self._create_xy()
assert len(x) >= lenx, "Invalid test!"
Expand Down Expand Up @@ -345,7 +348,7 @@ def test_np_to_wav_to_np(self, tmpdir):
# Check if the two arrays are equal
assert y == pytest.approx(x, abs=self.tolerance)

@pytest.mark.parametrize("sample_rate", _sample_rates)
@pytest.mark.parametrize("sample_rate", _SAMPLE_RATES)
def test_np_to_wav_to_np_sample_rates(self, sample_rate: int):
with TemporaryDirectory() as tmpdir:
# Create random numpy array
Expand Down
3 changes: 2 additions & 1 deletion tests/test_nam/test_train/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,8 @@ def test_validation_preceded_by_silence(self):
Dataset._validate_preceding_silence(
x,
data_info.validation_start,
int(_DEFAULT_REQUIRE_INPUT_PRE_SILENCE * data_info.rate),
_DEFAULT_REQUIRE_INPUT_PRE_SILENCE,
data_info.rate,
)

return C
Expand Down

0 comments on commit 2959e53

Please sign in to comment.