Skip to content

Commit

Permalink
[FEATURE] Add support for Proteus training files (#376)
Browse files Browse the repository at this point in the history
Add support for Proteus training files
  • Loading branch information
sdatkinson authored Feb 7, 2024
1 parent 2959e53 commit 16a2108
Show file tree
Hide file tree
Showing 3 changed files with 139 additions and 34 deletions.
121 changes: 101 additions & 20 deletions nam/train/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ def assign_hash(path):
"7c3b6119c74465f79d96c761a0e27370": Version(1, 1, 1),
"ede3b9d82135ce10c7ace3bb27469422": Version(2, 0, 0),
"36cd1af62985c2fac3e654333e36431e": Version(3, 0, 0),
"80e224bd5622fd6153ff1fd9f34cb3bd": Version(4, 0, 0),
}.get(file_hash)
if version is None:
print(
Expand All @@ -80,7 +81,8 @@ def assign_hash(path):

def detect_weak(input_path) -> Optional[Version]:
def assign_hash(path):
Hashes = Tuple[Optional[str], Optional[str]]
Hash = Optional[str]
Hashes = Tuple[Hash, Hash]

def _hash(x: np.ndarray) -> str:
return hashlib.md5(x).hexdigest()
Expand Down Expand Up @@ -133,16 +135,28 @@ def assign_hashes_v3(path) -> Hashes:
end_hash = _hash(x[start_of_end_interval:])
return start_hash, end_hash

def assign_hash_v4(path) -> Hash:
    """
    Weak hash for V4 (GuitarML Proteus) files.

    Use this to create recognized hashes for new files.

    :param path: Audio file to hash (passed straight to ``wav_to_np``).
    :return: Hash of the first second of audio, or ``None`` when the file's
        sample rate differs from ``_V4_DATA_INFO.rate``.
    """
    audio, wav_info = wav_to_np(path, info=True)
    expected_rate = _V4_DATA_INFO.rate
    if wav_info.rate != expected_rate:
        return None
    # Only the first second (where the starting blips are) is hashed; nothing
    # else in the file is inspected.
    one_second = int(1 * expected_rate)
    return _hash(audio[:one_second])

start_hash_v1, end_hash_v1 = assign_hashes_v1(path)
start_hash_v2, end_hash_v2 = assign_hashes_v2(path)
start_hash_v3, end_hash_v3 = assign_hashes_v3(path)
hash_v4 = assign_hash_v4(path)
return (
start_hash_v1,
end_hash_v1,
start_hash_v2,
end_hash_v2,
start_hash_v3,
end_hash_v3,
hash_v4,
)

(
Expand All @@ -152,6 +166,7 @@ def assign_hashes_v3(path) -> Hashes:
end_hash_v2,
start_hash_v3,
end_hash_v3,
hash_v4,
) = assign_hash(input_path)
print(
"Weak hashes:\n"
Expand All @@ -161,9 +176,11 @@ def assign_hashes_v3(path) -> Hashes:
f" End (v2) : {end_hash_v2}\n"
f" Start (v3) : {start_hash_v3}\n"
f" End (v3) : {end_hash_v3}\n"
f" Proteus : {hash_v4}\n"
)

# Check for matches, starting with most recent
# Check for matches, starting with most recent. Proteus last since its match is
# the most permissive.
version = {
(
"dadb5d62f6c3973a59bf01439799809b",
Expand Down Expand Up @@ -192,6 +209,9 @@ def assign_hashes_v3(path) -> Hashes:
"8458126969a3f9d8e19a53554eb1fd52",
): Version(1, 1, 1),
}.get((start_hash_v1, end_hash_v1))
if version is not None:
return version
version = {"46151c8030798081acc00a725325a07d": Version(4, 0, 0)}.get(hash_v4)
return version

version = detect_strong(input_path)
Expand All @@ -211,17 +231,6 @@ def assign_hashes_v3(path) -> Hashes:
class _DataInfo(BaseModel):
"""
:param major_version: Data major version
:param rate: Sample rate, in Hz
:param t_blips: How long the blips are, in samples
:param first_blips_start: When the first blips section starts, in samples
:param t_validate: Validation signal length, in samples
:param train_start: Where training signal starts, in samples.
:param validation_start: Where validation signal starts, in samples. Less than zero
(from the end of the array).
:param noise_interval: Inside which we quantify the noise level
:param blip_locations: In samples, absolute location in the file. Negative values
mean from the end instead of from the start (typical "Python" negative
indexing).
"""

major_version: int
Expand Down Expand Up @@ -286,6 +295,30 @@ class _DataInfo(BaseModel):
noise_interval=(492_000, 498_000),
blip_locations=((504_000, 552_000),),
)
# V4 (aka GuitarML Proteus)
# https://github.com/GuitarML/Releases/releases/download/v1.0.0/Proteus_Capture_Utility.zip
# * 44.1k
# * Odd length...
# * There's a blip on sample zero. This has to be ignored or else over-compensated
# latencies will come out wrong!
# (0:00-0:01) Blips at 0:00.0 and 0:00.5
# (0:01-0:09) Sine sweeps
# (0:09-0:17) White noise
# (0:17-0:20) Rising white noise (to 0:20.333 appx)
# (0:20-3:30.858) General training data (ends on sample 9,298,872)
# I'm arbitrarily assigning the last 10 seconds as validation data.
# Data layout of the V4 (GuitarML Proteus) capture file. Positions and lengths
# are in samples at 44.1 kHz; negative positions count from the end of the file.
_V4_DATA_INFO = _DataInfo(
    major_version=4,
    rate=44_100.0,
    # The blip section is samples 1 through 44_099: the rest of the first second
    # once the blip on sample zero is skipped.
    t_blips=44_099,  # Need to ignore the first blip!
    first_blips_start=1,  # Need to ignore the first blip!
    # 441_000 samples = 10 seconds at 44.1 kHz.
    t_validate=441_000,
    # Blips are problematic for training because they don't have preceding silence
    train_start=44_100,
    # The last 10 seconds of the file are used as validation data.
    validation_start=-441_000,
    # Interval in which the noise level is quantified; it lies between the
    # sample-zero blip and the blip at 22_050.
    noise_interval=(6_000, 12_000),
    # Only the blip at 0.5 s (sample 22_050) is listed; the blip on sample zero
    # is deliberately left out.
    blip_locations=((22_050,),),
)

_DELAY_CALIBRATION_ABS_THRESHOLD = 0.0003
_DELAY_CALIBRATION_REL_THRESHOLD = 0.001
Expand Down Expand Up @@ -393,6 +426,7 @@ def report_any_delay_warnings(delays: Sequence[int]):
# Version-specific delay calibrators: each one binds a data-info record as the
# first argument of the shared _calibrate_delay_v_all implementation.
_calibrate_delay_v1 = partial(_calibrate_delay_v_all, _V1_DATA_INFO)
_calibrate_delay_v2 = partial(_calibrate_delay_v_all, _V2_DATA_INFO)
_calibrate_delay_v3 = partial(_calibrate_delay_v_all, _V3_DATA_INFO)
_calibrate_delay_v4 = partial(_calibrate_delay_v_all, _V4_DATA_INFO)


def _plot_delay_v_all(
Expand Down Expand Up @@ -445,6 +479,7 @@ def _plot_delay_v_all(
# Version-specific delay plotters: each one binds a data-info record as the
# first argument of the shared _plot_delay_v_all implementation.
_plot_delay_v1 = partial(_plot_delay_v_all, _V1_DATA_INFO)
_plot_delay_v2 = partial(_plot_delay_v_all, _V2_DATA_INFO)
_plot_delay_v3 = partial(_plot_delay_v_all, _V3_DATA_INFO)
_plot_delay_v4 = partial(_plot_delay_v_all, _V4_DATA_INFO)


def _calibrate_delay(
Expand All @@ -454,12 +489,16 @@ def _calibrate_delay(
output_path: str,
silent: bool = False,
) -> int:
"""
Select the calibration and plotting routines from ``input_version.major``
(Proteus files are major version 4).
"""
if input_version.major == 1:
calibrate, plot = _calibrate_delay_v1, _plot_delay_v1
elif input_version.major == 2:
calibrate, plot = _calibrate_delay_v2, _plot_delay_v2
elif input_version.major == 3:
calibrate, plot = _calibrate_delay_v3, _plot_delay_v3
elif input_version.major == 4:
calibrate, plot = _calibrate_delay_v4, _plot_delay_v4
else:
raise NotImplementedError(
f"Input calibration not implemented for input version {input_version}"
Expand Down Expand Up @@ -654,6 +693,29 @@ def _check_v3(input_path, output_path, silent: bool, *args, **kwargs) -> bool:
return True


def _check_v4(input_path, output_path, silent: bool, *args, **kwargs) -> bool:
    """
    Minimal checks for a V4 (Proteus) output file.

    Latency compensation agreement and data replicability can't be checked for
    these files, so only the sample rate and the minimum length are verified.

    :param input_path: Unused here.
    :param output_path: Output audio file to check.
    :param silent: Unused here.
    :return: ``True`` if the sample rate matches ``_V4_DATA_INFO.rate`` and the
        file is long enough; ``False`` otherwise.
    """
    print("Using Proteus audio file. Standard data checks aren't possible!")
    signal, info = wav_to_np(output_path, info=True)
    expected_rate = _V4_DATA_INFO.rate

    rate_mismatch = info.rate != expected_rate
    if rate_mismatch:
        print(
            f"Output signal has sample rate {info.rate}; expected {expected_rate}!"
        )

    # The contents don't matter beyond being long enough to hold the blip
    # (first second) and the final 10 seconds that are used as validation.
    min_samples = int((1.0 + 10.0) * expected_rate)
    too_short = len(signal) < min_samples
    if too_short:
        print(
            "File doesn't meet the minimum length requirements for latency compensation and validation signal!"
        )

    return not (rate_mismatch or too_short)


def _check(
input_path: str, output_path: str, input_version: Version, delay: int, silent: bool
) -> bool:
Expand All @@ -668,6 +730,8 @@ def _check(
f = _check_v2
elif input_version.major == 3:
f = _check_v3
elif input_version.major == 4:
f = _check_v4
else:
print(f"Checks not implemented for input version {input_version}; skip")
return True
Expand Down Expand Up @@ -821,13 +885,34 @@ def get_kwargs(data_info: _DataInfo):
train_stop = validation_start
train_kwargs = {"start": 480_000, "stop": train_stop}
validation_kwargs = {"start": validation_start}
elif data_info.major_version == 4:
validation_start = data_info.validation_start
train_stop = validation_start
train_kwargs = {"stop": train_stop}
# Proteus doesn't have silence to get a clean split. Bite the bullet.
print(
"Using Proteus files:\n"
" * There isn't a silent point to split the validation set, so some of "
"your gear's response from the train set will leak into the start of "
"the validation set and impact validation accuracy (Bypassing data "
"quality check)\n"
" * Since the validation set is different, the ESRs reported for this "
"model aren't comparable to those from the other 'NAM' training files."
)
validation_kwargs = {
"start": validation_start,
"require_input_pre_silence": False,
}
else:
raise NotImplementedError(f"kwargs for input version {input_version}")
return train_kwargs, validation_kwargs

data_info = {1: _V1_DATA_INFO, 2: _V2_DATA_INFO, 3: _V3_DATA_INFO}[
input_version.major
]
data_info = {
1: _V1_DATA_INFO,
2: _V2_DATA_INFO,
3: _V3_DATA_INFO,
4: _V4_DATA_INFO,
}[input_version.major]
train_kwargs, validation_kwargs = get_kwargs(data_info)
data_config = {
"train": {"ny": ny, **train_kwargs},
Expand Down Expand Up @@ -994,10 +1079,6 @@ def _nasty_checks_modal():
modal.mainloop()


# Example usage:
# show_modal("Hello, World!")


def train(
input_path: str,
output_path: str,
Expand Down
2 changes: 2 additions & 0 deletions tests/resources/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import pytest

__all__ = [
"requires_proteus",
"requires_v1_0_0",
"requires_v1_1_1",
"requires_v2_0_0",
Expand All @@ -33,6 +34,7 @@ def _requires_v(name: str):
requires_v1_1_1 = _requires_v("v1_1_1.wav")
requires_v2_0_0 = _requires_v("v2_0_0.wav")
requires_v3_0_0 = _requires_v("v3_0_0.wav")
requires_proteus = _requires_v("Proteus_Capture.wav")


def resource_path(name: str) -> Path:
Expand Down
50 changes: 36 additions & 14 deletions tests/test_nam/test_train/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from nam.train._version import Version

from ...resources import (
requires_proteus,
requires_v1_0_0,
requires_v1_1_1,
requires_v2_0_0,
Expand All @@ -34,6 +35,8 @@
def _resource_path(version: Version) -> Path:
    """
    Locate the test resource file for a given input version.

    :param version: Input file version.
    :return: ``resource_path`` of "v1.wav" for 1.0.0, "Proteus_Capture.wav" for
        4.0.0, and "v<major>_<minor>_<patch>.wav" for anything else.
    """
    if version == Version(1, 0, 0):
        return resource_path("v1.wav")
    if version == Version(4, 0, 0):
        return resource_path("Proteus_Capture.wav")
    underscored = str(version).replace(".", "_")
    return resource_path(f"v{underscored}.wav")
Expand Down Expand Up @@ -167,24 +170,37 @@ class TestCalibrateDelayV3(_TCalibrateDelay):
_data_info = core._V3_DATA_INFO


class TestCalibrateDelayV4(_TCalibrateDelay):
    # Binds the V4 (Proteus) calibration routine and data info for whatever
    # tests are inherited from _TCalibrateDelay (base class not shown here).
    _calibrate_delay = core._calibrate_delay_v4
    _data_info = core._V4_DATA_INFO


def _make_t_validation_dataset_class(
version: Version, decorator, data_info: core._DataInfo
):
class C(object):
@decorator
def test_validation_preceded_by_silence(self):
"""
Validate that the datasets that we've made are valid
"""
x = wav_to_tensor(_resource_path(version))
Dataset._validate_preceding_silence(
x,
data_info.validation_start,
_DEFAULT_REQUIRE_INPUT_PRE_SILENCE,
data_info.rate,
)

return C
pass

# Proteus has a bad validation split; don't define the silence test for it.
if version == Version(4, 0, 0):
return C
else:

class C2(C):
@decorator
def test_validation_preceded_by_silence(self):
"""
Validate that the datasets that we've made are valid
"""
x = wav_to_tensor(_resource_path(version))
Dataset._validate_preceding_silence(
x,
data_info.validation_start,
_DEFAULT_REQUIRE_INPUT_PRE_SILENCE,
data_info.rate,
)

return C2


TestValidationDatasetV1_0_0 = _make_t_validation_dataset_class(
Expand All @@ -207,6 +223,12 @@ def test_validation_preceded_by_silence(self):
)


# Version 4.0.0 is the GuitarML Proteus capture file; it is gated on the
# Proteus resource via requires_proteus.
TestValidationDatasetV4_0_0 = _make_t_validation_dataset_class(
    Version(4, 0, 0), requires_proteus, core._V4_DATA_INFO
)


def test_v3_check_doesnt_make_figure_if_silent(mocker):
"""
Issue 337
Expand Down

0 comments on commit 16a2108

Please sign in to comment.