From 7529e3f6db4d97e0b87f3c168c59cd9718bb78d1 Mon Sep 17 00:00:00 2001 From: Steven Atkinson Date: Wed, 18 Dec 2024 18:46:05 -0800 Subject: [PATCH] Better warning message --- nam/models/base.py | 32 ++++++++++++++++++++++---------- 1 file changed, 22 insertions(+), 10 deletions(-) diff --git a/nam/models/base.py b/nam/models/base.py index 09e65b1..c27dadd 100644 --- a/nam/models/base.py +++ b/nam/models/base.py @@ -176,6 +176,10 @@ def _export_input_output(self) -> _Tuple[_np.ndarray, _np.ndarray]: ) +def _get_torch_version() -> str: + return _torch.__version__ + + class BaseNet(_Base): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -217,17 +221,25 @@ def _forward_mps_safe(self, x: _torch.Tensor, **kwargs) -> _torch.Tensor: return self._forward(x, **kwargs) except NotImplementedError as e: if "Output channels > 65536 not supported at the MPS device." in str(e): - print( - "===WARNING===\n" - "NAM encountered a bug in PyTorch's MPS backend and will " - "switch to a fallback.\n" - f"Your version of PyTorch is {_torch.__version__}.\n" - "Please report this in an Issue at:\n" - "https://github.com/sdatkinson/neural-amp-modeler/issues/new/choose" - "\n" - "so that NAM's dependencies can avoid buggy versions of " - "PyTorch and the associated performance hit." + msg = ( + "Warning: NAM encountered a bug in PyTorch's MPS backend and " + "will switch to a fallback." ) + known_bad_versions = {"2.5.0", "2.5.1"} + torch_version = _get_torch_version() + if torch_version not in known_bad_versions: + msg += ( + "\n" + f"Your version of PyTorch is {torch_version}, which " + "wasn't known to have this problem.\n" + "Please open an Issue at:\n" + "https://github.com/sdatkinson/neural-amp-modeler/issues/507" + "\n" + f"and report your PyTorch version ({torch_version}) " + "so that we can keep track of versions of PyTorch that " + "might be avoided." + ) + print(msg) self._mps_65536_fallback = True return self._forward_mps_safe(x, **kwargs) else: