[BUGFIX] Workaround for PyTorch MPS bug with sequences longer than 65,536 samples #506

Merged · 9 commits · Nov 24, 2024

Changes from all commits:
45 changes: 43 additions & 2 deletions nam/models/_base.py
@@ -170,6 +170,10 @@ def _export_input_output(self) -> Tuple[np.ndarray, np.ndarray]:
 
 
 class BaseNet(_Base):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self._mps_65536_fallback = False
+
     def forward(self, x: torch.Tensor, pad_start: Optional[bool] = None, **kwargs):
         pad_start = self.pad_start_default if pad_start is None else pad_start
         scalar = x.ndim == 1
@@ -179,16 +183,53 @@ def forward(self, x: torch.Tensor, pad_start: Optional[bool] = None, **kwargs):
             x = torch.cat(
                 (torch.zeros((len(x), self.receptive_field - 1)).to(x.device), x), dim=1
             )
-        y = self._forward(x, **kwargs)
+        if x.shape[1] < self.receptive_field:
+            raise ValueError(
+                f"Input has {x.shape[1]} samples, which is too few for this model with "
+                f"receptive field {self.receptive_field}!"
+            )
+        y = self._forward_mps_safe(x, **kwargs)
         if scalar:
             y = y[0]
         return y
 
     def _at_nominal_settings(self, x: torch.Tensor) -> torch.Tensor:
         return self(x)
 
+    def _forward_mps_safe(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
+        """
+        Wrap `._forward()` to protect against MPS-unsupported input lengths
+        beyond 65,536 samples.
+
+        Check this again when PyTorch 2.5.2 is released--hopefully it's fixed
+        by then.
+        """
+        if not self._mps_65536_fallback:
+            try:
+                return self._forward(x, **kwargs)
+            except NotImplementedError as e:
+                if "Output channels > 65536 not supported at the MPS device." in str(e):
+                    self._mps_65536_fallback = True
+                    return self._forward_mps_safe(x, **kwargs)
+                else:
+                    raise e
+        else:
+            # Stitch together the output one piece at a time to avoid the MPS error.
+            stride = 65_536 - (self.receptive_field - 1)
+            # We need to make sure that the last segment is big enough that we
+            # have the required history for the receptive field.
+            out_list = []
+            for i in range(0, x.shape[1], stride):
+                j = min(i + 65_536, x.shape[1])
+                xi = x[:, i:j]
+                out_list.append(self._forward(xi, **kwargs))
+                # A bit hacky, but correct.
+                if j == x.shape[1]:
+                    break
+            return torch.cat(out_list, dim=1)
+
 
     @abc.abstractmethod
-    def _forward(self, x: torch.Tensor) -> torch.Tensor:
+    def _forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
         """
         The true forward method.
 
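Why the stitching is exact: a causal model with receptive field R turns a window of W input samples into W - (R - 1) output samples, so advancing the window by exactly stride = W - (R - 1) makes consecutive output pieces abut with no gap and no overlap (with W = 65,536 and, say, R = 8,192, each piece contributes 57,345 samples). The toy check below demonstrates the scheme on CPU with a plain Conv1d standing in for the model and W shrunk from 65,536 to 64 so it runs instantly; the names and sizes are illustrative assumptions, not part of the PR.

import torch

torch.manual_seed(0)
R = 5  # receptive field
W = 64  # stands in for the 65,536-sample MPS limit
conv = torch.nn.Conv1d(1, 1, kernel_size=R)  # output length = input length - (R - 1)

x = torch.randn(1, 1, 1_000)
full = conv(x)  # one-shot reference

stride = W - (R - 1)
pieces = []
for i in range(0, x.shape[-1], stride):
    j = min(i + W, x.shape[-1])
    pieces.append(conv(x[:, :, i:j]))  # each full window yields `stride` samples
    if j == x.shape[-1]:
        break
stitched = torch.cat(pieces, dim=-1)

assert torch.allclose(full, stitched, atol=1e-6)  # matches the one-shot result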
46 changes: 9 additions & 37 deletions nam/models/linear.py
@@ -6,9 +6,6 @@
Linear model
"""

-import json
-from pathlib import Path

import numpy as np
import torch
import torch.nn as nn
@@ -30,38 +27,6 @@ def pad_start_default(self) -> bool:
     def receptive_field(self) -> int:
         return self._net.weight.shape[2]

-    def export(self, outdir: Path):
-        training = self.training
-        self.eval()
-        with open(Path(outdir, "config.json"), "w") as fp:
-            json.dump(
-                {
-                    "version": __version__,
-                    "architecture": self.__class__.__name__,
-                    "config": {
-                        "receptive_field": self.receptive_field,
-                        "bias": self._bias,
-                    },
-                },
-                fp,
-                indent=4,
-            )
-
-        params = [self._net.weight.flatten()]
-        if self._bias:
-            params.append(self._net.bias.flatten())
-        params = torch.cat(params).detach().cpu().numpy()
-        # Hope I don't regret using np.save...
-        np.save(Path(outdir, "weights.npy"), params)
-
-        # And an input/output to verify correct computation:
-        x, y = self._export_input_output()
-        np.save(Path(outdir, "input.npy"), x.detach().cpu().numpy())
-        np.save(Path(outdir, "output.npy"), y.detach().cpu().numpy())
-
-        # And resume training state
-        self.train(training)

     def export_cpp_header(self):
         raise NotImplementedError()

@@ -73,7 +38,14 @@ def _forward(self, x: torch.Tensor) -> torch.Tensor:
         return self._net(x[:, None])[:, 0]

     def _export_config(self):
-        raise NotImplementedError()
+        return {
+            "receptive_field": self.receptive_field,
+            "bias": self._bias,
+        }
 
     def _export_weights(self) -> np.ndarray:
-        raise NotImplementedError()
+        params_list = [self._net.weight.flatten()]
+        if self._bias:
+            params_list.append(self._net.bias.flatten())
+        params = torch.cat(params_list).detach().cpu().numpy()
+        return params
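With these two hooks filled in, Linear relies on the shared export path instead of the bespoke export() deleted above. That base-class method isn't shown in this diff, so the following is only a sketch of its shape, inferred from the deleted code; the standalone function, the version argument, and the file layout are assumptions, not the actual _Base implementation.

import json
from pathlib import Path

import numpy as np


def export(model, outdir: Path, version: str = "0.0.0"):
    # Sketch only: assumes `model` provides the two hooks implemented above.
    training = model.training
    model.eval()
    with open(Path(outdir, "config.json"), "w") as fp:
        json.dump(
            {
                "version": version,
                "architecture": model.__class__.__name__,
                "config": model._export_config(),  # hook 1: architecture config
            },
            fp,
            indent=4,
        )
    # Hook 2: flattened weights, saved the same way the deleted method did.
    np.save(Path(outdir, "weights.npy"), model._export_weights())
    model.train(training)  # restore training/eval state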
30 changes: 30 additions & 0 deletions tests/test_nam/test_models/_convolutional.py
@@ -0,0 +1,30 @@
# File: _convolutional.py
# Created Date: Saturday November 23rd 2024
# Author: Steven Atkinson (steven@atkinson.mn)

"""
Mix-in tests for models with a convolution layer
"""

import pytest as _pytest
import torch as _torch

from .base import Base as _Base


class Convolutional(_Base):
    @_pytest.mark.skipif(not _torch.backends.mps.is_available(), reason="MPS-specific test")
    def test_process_input_longer_than_65536(self):
        """
        Processing inputs longer than 65,536 samples using the MPS backend can
        cause problems.

        See: https://github.com/sdatkinson/neural-amp-modeler/issues/505

        Assert that precautions are taken.
        """

        x = _torch.zeros((65_536 + 1,)).to("mps")

        model = self._construct().to("mps")
        model(x)
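For reference, the failure this mix-in guards against can be triggered outside of NAM. The snippet below is a hypothetical standalone reproduction: it assumes an Apple-silicon machine and an affected PyTorch build, and sizes the input so the convolution's output exceeds 65,536 samples; on builds where the bug is fixed, it simply runs clean.

import torch

if torch.backends.mps.is_available():
    conv = torch.nn.Conv1d(1, 1, kernel_size=3).to("mps")
    x = torch.zeros((1, 1, 70_000), device="mps")  # output length 69,998 > 65,536
    try:
        conv(x)
        print("No error: this PyTorch build handles long MPS convolutions.")
    except NotImplementedError as e:
        print(e)  # "Output channels > 65536 not supported at the MPS device."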
15 changes: 6 additions & 9 deletions tests/test_nam/test_models/test_conv_net.py
@@ -2,17 +2,14 @@
# Created Date: Friday May 6th 2022
# Author: Steven Atkinson (steven@atkinson.mn)

-from pathlib import Path
-from tempfile import TemporaryDirectory
-
-import pytest
+import pytest as _pytest
 
 from nam.models import conv_net
 
-from .base import Base
+from ._convolutional import Convolutional as _Convolutional


-class TestConvNet(Base):
+class TestConvNet(_Convolutional):
     @classmethod
     def setup_class(cls):
         channels = 3
@@ -23,18 +20,18 @@ def setup_class(cls):
             {"batchnorm": False, "activation": "Tanh"},
         )

-    @pytest.mark.parametrize(
+    @_pytest.mark.parametrize(
         ("batchnorm,activation"), ((False, "ReLU"), (True, "Tanh"))
     )
     def test_init(self, batchnorm, activation):
         super().test_init(kwargs={"batchnorm": batchnorm, "activation": activation})

-    @pytest.mark.parametrize(
+    @_pytest.mark.parametrize(
         ("batchnorm,activation"), ((False, "ReLU"), (True, "Tanh"))
     )
     def test_export(self, batchnorm, activation):
         super().test_export(kwargs={"batchnorm": batchnorm, "activation": activation})


 if __name__ == "__main__":
-    pytest.main()
+    _pytest.main()
18 changes: 18 additions & 0 deletions tests/test_nam/test_models/test_linear.py
@@ -0,0 +1,18 @@
# File: test_linear.py
# Created Date: Saturday November 23rd 2024
# Author: Steven Atkinson (steven@atkinson.mn)

import pytest as _pytest

from nam.models import linear as _linear

from ._convolutional import Convolutional as _Convolutional


class TestLinear(_Convolutional):
    @classmethod
    def setup_class(cls):
        C = _linear.Linear
        args = ()
        kwargs = {"receptive_field": 2, "sample_rate": 44100}
        super().setup_class(C, args, kwargs)
29 changes: 25 additions & 4 deletions tests/test_nam/test_models/test_wavenet.py
@@ -8,11 +8,28 @@
from nam.models.wavenet import WaveNet
from nam.train.core import Architecture, get_wavenet_config

+from ._convolutional import Convolutional as _Convolutional
 
 
-# from .base import Base
-
-
-class TestWaveNet(object):
+class TestWaveNet(_Convolutional):
+    @classmethod
+    def setup_class(cls):
+        C = WaveNet
+        args = ()
+        kwargs = {
+            "layers_configs": [
+                {
+                    "input_size": 1,
+                    "condition_size": 1,
+                    "head_size": 1,
+                    "channels": 1,
+                    "kernel_size": 1,
+                    "dilations": [1]
+                }
+            ]
+        }
+        super().setup_class(C, args, kwargs)
+
     def test_import_weights(self):
         config = get_wavenet_config(Architecture.FEATHER)
         model_1 = WaveNet.init_from_config(config)
@@ -29,3 +46,7 @@
 
         assert not torch.allclose(y2_before, y1)
         assert torch.allclose(y2_after, y1)
+
+
+if __name__ == "__main__":
+    pytest.main()