diff --git a/tests/test_cfconv.py b/tests/test_cfconv.py index 20b3576a..d21c7edf 100644 --- a/tests/test_cfconv.py +++ b/tests/test_cfconv.py @@ -4,14 +4,16 @@ import pytest from pytest import mark -import torch as pt -from torchmdnet.models.torchmd_gn import CFConv as RefCFConv -from torchmdnet.models.utils import OptimizedDistance, GaussianSmearing, ShiftedSoftplus -from NNPOps.CFConv import CFConv -from NNPOps.CFConvNeighbors import CFConvNeighbors +try: + import NNPOps + nnpops_available = True +except ImportError: + nnpops_available = False + +@pytest.mark.skipif(not nnpops_available, reason="NNPOps not available") @mark.parametrize("device", ["cpu", "cuda"]) @mark.parametrize( ["num_atoms", "num_filters", "num_rbfs"], @@ -19,6 +21,15 @@ ) @mark.parametrize("cutoff_upper", [5.0, 10.0]) def test_cfconv(device, num_atoms, num_filters, num_rbfs, cutoff_upper): + import torch as pt + from torchmdnet.models.torchmd_gn import CFConv as RefCFConv + from torchmdnet.models.utils import ( + OptimizedDistance, + GaussianSmearing, + ShiftedSoftplus, + ) + from NNPOps.CFConv import CFConv + from NNPOps.CFConvNeighbors import CFConvNeighbors if not pt.cuda.is_available() and device == "cuda": pytest.skip("No GPU")