diff --git a/mtenn/conversion_utils/visnet.py b/mtenn/conversion_utils/visnet.py index b15ac85..7745f2f 100644 --- a/mtenn/conversion_utils/visnet.py +++ b/mtenn/conversion_utils/visnet.py @@ -4,12 +4,17 @@ from copy import deepcopy import torch from torch.autograd import grad -from torch_geometric.nn.models import ViSNet as PygVisNet from torch_geometric.utils import scatter from mtenn.model import GroupedModel, Model from mtenn.strategy import ComplexOnlyStrategy, ConcatStrategy, DeltaStrategy - + +HAS_VISNET_FLAG = False +try: + from torch_geometric.nn.models import ViSNet as PygVisNet + HAS_VISNET_FLAG = True +except ImportError: + pass class EquivariantVecToScaler(torch.nn.Module): # Wrapper for PygVisNet.EquivariantScalar to implement forward() method