Skip to content

Commit

Permalink
skip test if nnpops is not installed
Browse files Browse the repository at this point in the history
  • Loading branch information
stefdoerr committed Dec 3, 2024
1 parent 0b94d88 commit fd329ee
Showing 1 changed file with 16 additions and 5 deletions.
21 changes: 16 additions & 5 deletions tests/test_cfconv.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,21 +4,32 @@

import pytest
from pytest import mark
import torch as pt
from torchmdnet.models.torchmd_gn import CFConv as RefCFConv
from torchmdnet.models.utils import OptimizedDistance, GaussianSmearing, ShiftedSoftplus

from NNPOps.CFConv import CFConv
from NNPOps.CFConvNeighbors import CFConvNeighbors
try:
import NNPOps

nnpops_available = True
except ImportError:
nnpops_available = False


@pytest.mark.skipif(not nnpops_available, reason="NNPOps not available")
@mark.parametrize("device", ["cpu", "cuda"])
@mark.parametrize(
["num_atoms", "num_filters", "num_rbfs"],
[(3, 5, 7), (3, 7, 5), (5, 3, 7), (5, 7, 3), (7, 3, 5), (7, 5, 3)],
)
@mark.parametrize("cutoff_upper", [5.0, 10.0])
def test_cfconv(device, num_atoms, num_filters, num_rbfs, cutoff_upper):
import torch as pt
from torchmdnet.models.torchmd_gn import CFConv as RefCFConv
from torchmdnet.models.utils import (
OptimizedDistance,
GaussianSmearing,
ShiftedSoftplus,
)
from NNPOps.CFConv import CFConv
from NNPOps.CFConvNeighbors import CFConvNeighbors

if not pt.cuda.is_available() and device == "cuda":
pytest.skip("No GPU")
Expand Down

0 comments on commit fd329ee

Please sign in to comment.