[BREAKING] Remove parametric modeling code (#367)
sdatkinson authored Jan 14, 2024
1 parent ae86979 commit 090fd22
Showing 12 changed files with 10 additions and 1,222 deletions.
73 changes: 0 additions & 73 deletions bin/export/main.py

This file was deleted.

42 changes: 0 additions & 42 deletions bin/train/inputs/models/catlstm.json

This file was deleted.

15 changes: 4 additions & 11 deletions bin/train/main.py
@@ -31,9 +31,8 @@ def _ensure_graceful_shutdowns():
import torch
from torch.utils.data import DataLoader

-from nam.data import ConcatDataset, ParametricDataset, Split, init_dataset
+from nam.data import ConcatDataset, Split, init_dataset
from nam.models import Model
-from nam.models._base import BaseNet # HACK access
from nam.util import filter_warnings, timestamp

torch.manual_seed(0)
@@ -86,8 +85,7 @@ def extend_savefig(i, savefig):
tx = len(ds.x) / 48_000
print(f"Run (t={tx:.2f})")
t0 = time()
-args = (ds.vals, ds.x) if isinstance(ds, ParametricDataset) else (ds.x,)
-output = model(*args).flatten().cpu().numpy()
+output = model(ds.x).flatten().cpu().numpy()
t1 = time()
try:
rt = f"{tx / (t1 - t0):.2f}"
@@ -96,12 +94,8 @@ def extend_savefig(i, savefig):
print(f"Took {t1 - t0:.2f} ({rt}x)")

plt.figure(figsize=(16, 5))
-# plt.plot(ds.x[window_start:window_end], label="Input")
plt.plot(output[window_start:window_end], label="Prediction")
plt.plot(ds.y[window_start:window_end], linestyle="--", label="Target")
-# plt.plot(
-# ds.y[window_start:window_end] - output[window_start:window_end], label="Error"
-# )
nrmse = _rms(torch.Tensor(output) - ds.y) / _rms(ds.y)
esr = nrmse**2
plt.title(f"ESR={esr:.3f}")
@@ -227,9 +221,8 @@ def main_inner(
show=False,
)
plot(model, dataset_validation, show=not no_show)
-# Convenient export for snapshot models:
-if isinstance(model.net, BaseNet):
-model.net.export(outdir)
+# Export!
+model.net.export(outdir)


if __name__ == "__main__":
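With the `ParametricDataset` branch and the `BaseNet` isinstance guard gone, the script's evaluation-and-export path reduces to a single unconditional forward pass on `ds.x` followed by `model.net.export(outdir)`. A minimal, self-contained sketch of the ESR computation shown in the hunks above, using a toy stand-in network (the `_rms` helper and the toy module are illustrative assumptions, not code from this repository):

```python
import torch


def _rms(x: torch.Tensor) -> float:
    # Root-mean-square amplitude; assumed to match the script's `_rms` helper.
    return torch.sqrt(torch.mean(torch.square(x))).item()


class _ToyNet(torch.nn.Module):
    # Stand-in for the trained model; the real script uses its trained `Model`.
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return 0.5 * x


model = _ToyNet()
x = torch.randn(48_000)                    # stands in for ds.x
y = 0.5 * x + 0.01 * torch.randn(48_000)   # stands in for ds.y

with torch.no_grad():
    # Post-change path: one call on the input, no parametric branch.
    output = model(x).flatten().cpu().numpy()

# Error-to-signal ratio, as computed in the plotting code above.
nrmse = _rms(torch.Tensor(output) - y) / _rms(y)
esr = nrmse ** 2
print(f"ESR={esr:.3f}")
```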
110 changes: 6 additions & 104 deletions nam/data.py
@@ -150,16 +150,10 @@ def np_to_wav(

class AbstractDataset(_Dataset, abc.ABC):
@abc.abstractmethod
-def __getitem__(
-self, idx: int
-) -> Union[
-Tuple[torch.Tensor, torch.Tensor],
-Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
-]:
+def __getitem__(self, idx: int):
"""
Get input and output audio segment for training / evaluation.
:return:
-Case 1: Input (N1,), Output (N2,)
-Case 2: Parameters (D,), Input (N1,), Output (N2,)
"""
pass

@@ -226,8 +220,6 @@ class StopError(StartStopError):
class Dataset(AbstractDataset, InitializableFromConfig):
"""
Take a pair of matched audio files and serve input + output pairs.
-No conditioning parameters associated w/ the data.
"""

def __init__(
@@ -666,75 +658,6 @@ def _validate_preceding_silence(
)


-class ParametricDataset(Dataset):
-"""
-Additionally tracks some conditioning parameters
-"""
-
-def __init__(self, params: Dict[str, Union[bool, float, int]], *args, **kwargs):
-super().__init__(*args, **kwargs)
-self._keys = sorted(tuple(k for k in params.keys()))
-self._vals = torch.Tensor([float(params[k]) for k in self._keys])
-
-@classmethod
-def init_from_config(cls, config):
-if "slices" not in config:
-return super().init_from_config(config)
-else:
-return cls.init_from_config_with_slices(config)
-
-@classmethod
-def init_from_config_with_slices(cls, config):
-config, x, y, slices = cls.parse_config_with_slices(config)
-datasets = []
-for s in tqdm(slices, desc="Slices..."):
-c = deepcopy(config)
-start, stop, params = [s[k] for k in ("start", "stop", "params")]
-c.update(x=x[start:stop], y=y[start:stop], params=params)
-if "delay" in s:
-c["delay"] = s["delay"]
-datasets.append(ParametricDataset(**c))
-return ConcatDataset(datasets)
-
-@classmethod
-def parse_config(cls, config):
-assert "slices" not in config
-params = config["params"]
-return {
-"params": params,
-"id": config.get("id"),
-"common_params": config.get("common_params"),
-"param_map": config.get("param_map"),
-**super().parse_config(config),
-}
-
-@classmethod
-def parse_config_with_slices(cls, config):
-slices = config["slices"]
-config = super().parse_config(config)
-x, y = [config.pop(k) for k in "xy"]
-return config, x, y, slices
-
-def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
-"""
-:return:
-Parameter values (D,)
-Input (NX+NY-1,)
-Output (NY,)
-"""
-# FIXME don't override signature
-x, y = super().__getitem__(idx)
-return self.vals, x, y
-
-@property
-def keys(self) -> Tuple[str]:
-return self._keys
-
-@property
-def vals(self):
-return self._vals
-
-
class ConcatDataset(AbstractDataset, InitializableFromConfig):
def __init__(self, datasets: Sequence[Dataset], flatten=True):
if flatten:
@@ -815,35 +738,23 @@ def _validate_datasets(cls, datasets: Sequence[Dataset]):
raise ValueError(
f"Mismatch between ny of datasets {ref_ny.index} ({ref_ny.val}) and {i} ({d.ny})"
)
-if isinstance(d, ParametricDataset):
-val = d.keys
-if ref_keys is None:
-ref_keys = Reference(i, val)
-if val != ref_keys.val:
-raise ValueError(
-f"Mismatch between keys of datasets {ref_keys.index} "
-f"({ref_keys.val}) and {i} ({val})"
-)


-_dataset_init_registry = {
-"dataset": Dataset.init_from_config,
-"parametric": ParametricDataset.init_from_config, # To be removed in v0.8
-}
+_dataset_init_registry = {"dataset": Dataset.init_from_config}


def register_dataset_initializer(
name: str, constructor: Callable[[Any], AbstractDataset], overwrite=False
):
"""
-If you have otehr data set types, you can register their initializer by name using
+If you have other data set types, you can register their initializer by name using
this.
For example, the basic NAM is registered by default under the name "default", but if
it weren't, you could register it like this:
>>> from nam import data
->>> data.register_dataset_initializer("parametric", data.Dataset.init_from_config)
+>>> data.register_dataset_initializer("parametric", MyParametricDataset.init_from_config)
:param name: The name that'll be used in the config to ask for the data set type
:param constructor: The constructor that'll be fed the config.
@@ -856,16 +767,7 @@ def register_dataset_initializer(


def init_dataset(config, split: Split) -> AbstractDataset:
if "parametric" in config:
logger.warning(
"Using the 'parametric' keyword is deprecated and will be removed in next "
"version. Instead, register the parametric dataset type using "
"`nam.data.register_dataset_initializer()` and then specify "
'`"type": "name"` in the config, using the name you registered.'
)
name = "parametric" if config["parametric"] else "dataset"
else:
name = config.get("type", "dataset")
name = config.get("type", "dataset")
base_config = config[split.value]
common = config.get("common", {})
if isinstance(base_config, dict):
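The deprecation warning removed above spells out the migration path: instead of setting `"parametric"` in the data config, a downstream project registers its own dataset class with `register_dataset_initializer()` and selects it via `"type"`. A minimal sketch of that pattern, assuming a hypothetical `MyParametricDataset` modeled on the class deleted in this commit (the class name and its simplified `parse_config` are illustrative, not part of the package):

```python
from typing import Dict, Union

import torch

from nam import data


class MyParametricDataset(data.Dataset):
    """Hypothetical user-side replacement for the removed ParametricDataset."""

    def __init__(self, params: Dict[str, Union[bool, float, int]], *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._keys = sorted(params.keys())
        self._vals = torch.Tensor([float(params[k]) for k in self._keys])

    def __getitem__(self, idx: int):
        # Prepend the conditioning parameters to the (input, output) pair.
        x, y = super().__getitem__(idx)
        return self._vals, x, y

    @classmethod
    def parse_config(cls, config):
        return {"params": config["params"], **super().parse_config(config)}


# Register the type under a name, then request it with "type" in the data config:
data.register_dataset_initializer("parametric", MyParametricDataset.init_from_config)
# config = {"type": "parametric", "train": {...}, "validation": {...}}
# dataset = data.init_dataset(config, data.Split.TRAIN)
```

This mirrors how `init_dataset` now resolves the initializer: `name = config.get("type", "dataset")` looks the registered name up in `_dataset_init_registry`.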
5 changes: 0 additions & 5 deletions nam/models/base.py
@@ -24,8 +24,6 @@
from .conv_net import ConvNet
from .linear import Linear
from .losses import apply_pre_emphasis_filter, esr, multi_resolution_stft_loss, mse_fft
-from .parametric.catnets import CatLSTM, CatWaveNet
-from .parametric.hyper_net import HyperConvNet
from .recurrent import LSTM
from .wavenet import WaveNet

@@ -120,10 +118,7 @@ class _LossItem(NamedTuple):


_model_net_init_registry = {
"CatLSTM": CatLSTM.init_from_config,
"CatWaveNet": CatWaveNet.init_from_config,
"ConvNet": ConvNet.init_from_config,
"HyperConvNet": HyperConvNet.init_from_config,
"Linear": Linear.init_from_config,
"LSTM": LSTM.init_from_config,
"WaveNet": WaveNet.init_from_config,
3 changes: 0 additions & 3 deletions nam/models/parametric/__init__.py

This file was deleted.

(Diffs for the remaining changed files were not loaded.)
