Skip to content

Commit

Permalink
try importing visnet, implement has_visnet_flag
Browse files Browse the repository at this point in the history
  • Loading branch information
fyng committed Jan 26, 2024
1 parent be05591 commit 7be95b7
Showing 1 changed file with 7 additions and 2 deletions.
9 changes: 7 additions & 2 deletions mtenn/conversion_utils/visnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,17 @@
from copy import deepcopy
import torch
from torch.autograd import grad
from torch_geometric.nn.models import ViSNet as PygVisNet
from torch_geometric.utils import scatter

from mtenn.model import GroupedModel, Model
from mtenn.strategy import ComplexOnlyStrategy, ConcatStrategy, DeltaStrategy


HAS_VISNET_FLAG = False
try:
from torch_geometric.nn.models import ViSNet as PygVisNet
HAS_VISNET_FLAG = True
except ImportError:
pass

class EquivariantVecToScaler(torch.nn.Module):
# Wrapper for PygVisNet.EquivariantScalar to implement forward() method
Expand Down

0 comments on commit 7be95b7

Please sign in to comment.