diff --git a/.github/workflows/CI.yaml b/.github/workflows/CI.yaml index b3900d6..41be6da 100644 --- a/.github/workflows/CI.yaml +++ b/.github/workflows/CI.yaml @@ -22,13 +22,14 @@ defaults: jobs: test: - name: Test on ${{ matrix.os }}, Python ${{ matrix.python-version }} + name: Test on ${{ matrix.os }}, Python ${{ matrix.python-version }}, Env ${{ matrix.deps }} runs-on: ${{ matrix.os }} strategy: fail-fast: false matrix: os: [macOS-latest, ubuntu-latest] python-version: ["3.10", "3.11"] + deps: ["devtools/conda-envs/test_env.yaml", "devtools/conda-envs/test_env-nightly.yaml"] steps: - name: Checkout Repository @@ -45,7 +46,7 @@ jobs: - name: Setup Micromamba uses: mamba-org/setup-micromamba@v1 with: - environment-file: devtools/conda-envs/test_env.yaml + environment-file: ${{ matrix.deps }} environment-name: test create-args: >- python==${{ matrix.python-version }} @@ -68,5 +69,5 @@ jobs: with: file: ./coverage.xml flags: unittests - name: codecov-${{ matrix.os }}-py${{ matrix.python-version }} + name: codecov-${{ matrix.os }}-py${{ matrix.python-version }}-env${{ matrix.deps }} fail_ci_if_error: false diff --git a/devtools/conda-envs/test_env-nightly.yaml b/devtools/conda-envs/test_env-nightly.yaml new file mode 100644 index 0000000..ba2d308 --- /dev/null +++ b/devtools/conda-envs/test_env-nightly.yaml @@ -0,0 +1,24 @@ +name: test +channels: + - conda-forge +dependencies: + - pip + - pytorch + - pytorch_cluster + - pytorch_scatter + - pytorch_sparse + - numpy + - h5py + - e3nn + - dgllife + - dgl + - rdkit + - ase + # testing dependencies + - pytest + - pytest-cov + - codecov + - pydantic >=1.10.8,<2.0.0a0 + + - pip: + - pyg-nightly \ No newline at end of file diff --git a/devtools/conda-envs/test_env.yaml b/devtools/conda-envs/test_env.yaml index b8221be..9704c0f 100644 --- a/devtools/conda-envs/test_env.yaml +++ b/devtools/conda-envs/test_env.yaml @@ -18,4 +18,4 @@ dependencies: - pytest - pytest-cov - codecov - - pydantic >=1.10.8,<2.0.0a0 + - pydantic 
class ViSNetModelConfig(ModelConfigBase):
    """
    Class for constructing a ViSNet ML model. Default values here are the default
    values given in PyG.

    Building the model requires a PyG version that ships ViSNet; ``_build`` raises
    ``ImportError`` otherwise.
    """

    # NOTE(review): ``const=True`` pins this field to ``ModelType.visnet``.
    model_type: ModelType = Field(ModelType.visnet, const=True)
    lmax: int = Field(1, description="The maximum degree of the spherical harmonics.")
    vecnorm_type: str | None = Field(
        None, description="The type of normalization to apply to the vectors."
    )
    trainable_vecnorm: bool = Field(
        False, description="Whether the normalization weights are trainable."
    )
    num_heads: int = Field(8, description="The number of attention heads.")
    num_layers: int = Field(6, description="The number of layers in the network.")
    hidden_channels: int = Field(
        128, description="The number of hidden channels in the node embeddings."
    )
    num_rbf: int = Field(32, description="The number of radial basis functions.")
    trainable_rbf: bool = Field(
        False,
        description="Whether the radial basis function parameters are trainable.",
    )
    max_z: int = Field(100, description="The maximum atomic numbers.")
    cutoff: float = Field(5.0, description="The cutoff distance.")
    max_num_neighbors: int = Field(
        32,
        description="The maximum number of neighbors considered for each atom.",
    )
    vertex: bool = Field(
        False,
        description="Whether to use vertex geometric features.",
    )
    atomref: list[float] | None = Field(
        None,
        description=(
            "Reference values for single-atom properties. Should have length max_z"
        ),
    )
    reduce_op: str = Field(
        "sum",
        description="The type of reduction operation to apply. ['sum', 'mean']",
    )
    mean: float = Field(0.0, description="The mean of the output distribution.")
    std: float = Field(
        1.0, description="The standard deviation of the output distribution."
    )
    derivative: bool = Field(
        False,
        description="Whether to compute the derivative of the output with respect to the positions.",
    )

    # NOTE(review): this validator is named ``validate``, which is also the name of
    # a pydantic ``BaseModel`` classmethod -- confirm the shadowing is intentional.
    @root_validator(pre=False)
    def validate(cls, values):
        """
        Check the grouped-model settings and the ``atomref`` length.

        Parameters
        ----------
        values : dict
            Field values collected by pydantic

        Returns
        -------
        dict
            The unmodified ``values``

        Raises
        ------
        ValueError
            If ``atomref`` is given and its length differs from ``max_z``
        """
        # Make sure the grouped stuff is properly assigned
        ModelConfigBase._check_grouped(values)

        # Make sure atomref length is correct (this is required by PyG).
        # Use .get() so that a field which already failed its own validation (and
        # is therefore absent from ``values``) doesn't surface as a KeyError that
        # hides the real validation error.
        atomref = values.get("atomref")
        max_z = values.get("max_z")
        if (atomref is not None) and (max_z is not None) and (len(atomref) != max_z):
            raise ValueError(f"atomref length must match max_z. (Expected {max_z}, got {len(atomref)})")

        return values

    def _build(self, mtenn_params=None):
        """
        Build an MTENN ViSNet Model from this config.

        Parameters
        ----------
        mtenn_params: dict, optional
            Dict giving the MTENN Readout. This will be passed by the `build` method
            in the abstract base class. Recognized keys are ``combination``,
            ``pred_readout`` and ``comb_readout``; missing keys default to ``None``

        Returns
        -------
        mtenn.model.Model
            MTENN ViSNet Model/GroupedModel

        Raises
        ------
        ImportError
            If the installed PyG does not provide ViSNet
        """
        # Avoid a shared mutable default argument
        if mtenn_params is None:
            mtenn_params = {}

        # Imported lazily: ViSNet is only defined when PyG ships it
        from mtenn.conversion_utils.visnet import HAS_VISNET

        if not HAS_VISNET:
            raise ImportError("ViSNet not found. Is your PyG >=2.5.0? Refer to issue #42.")

        from mtenn.conversion_utils import ViSNet

        # Create an MTENN ViSNet model from PyG ViSNet model
        model = ViSNet(
            lmax=self.lmax,
            vecnorm_type=self.vecnorm_type,
            trainable_vecnorm=self.trainable_vecnorm,
            num_heads=self.num_heads,
            num_layers=self.num_layers,
            hidden_channels=self.hidden_channels,
            num_rbf=self.num_rbf,
            trainable_rbf=self.trainable_rbf,
            max_z=self.max_z,
            cutoff=self.cutoff,
            max_num_neighbors=self.max_num_neighbors,
            vertex=self.vertex,
            reduce_op=self.reduce_op,
            mean=self.mean,
            std=self.std,
            derivative=self.derivative,
            atomref=self.atomref,
        )

        return ViSNet.get_model(
            model=model,
            grouped=self.grouped,
            fix_device=True,
            strategy=self.strategy,
            combination=mtenn_params.get("combination", None),
            pred_readout=mtenn_params.get("pred_readout", None),
            comb_readout=mtenn_params.get("comb_readout", None),
        )
"""
Representation and strategy for ViSNet model.
"""
import warnings
from copy import deepcopy
import torch
from torch.autograd import grad
from torch_geometric.utils import scatter

from mtenn.model import GroupedModel, Model
from mtenn.strategy import ComplexOnlyStrategy, ConcatStrategy, DeltaStrategy

# guard required: currently require PyG nightly 2.5.0 (SEE ISSUE #42)
HAS_VISNET = False

try:
    from torch_geometric.nn.models import ViSNet as PygVisNet
    from torch_geometric.nn.models.visnet import ViS_MP_Vertex

    HAS_VISNET = True
except ImportError:
    warnings.warn(
        "VisNet import error. Is your PyG >=2.5.0? Refer to issue #42", ImportWarning
    )


class EquivariantVecToScalar(torch.nn.Module):
    """
    Readout that reduces per-atom values to a single value and adds ``mean``.

    Parameters
    ----------
    mean : float or torch.Tensor
        Offset added to the reduced output
    reduce_op : str
        Reduction passed to ``torch_geometric.utils.scatter`` (e.g. "sum", "mean")
    """

    def __init__(self, mean, reduce_op):
        super().__init__()
        self.mean = mean
        self.reduce_op = reduce_op

    def forward(self, x):
        """
        Reduce ``x`` over its first dimension and add the stored mean.

        Parameters
        ----------
        x : torch.Tensor
            Tensor whose first dimension indexes atoms

        Returns
        -------
        torch.Tensor
            Reduced tensor with a leading dimension of size 1, offset by ``mean``
        """
        # dummy variable. all atoms from the same molecule and the same batch
        batch = torch.zeros(x.shape[0], dtype=torch.int64, device=x.device)
        y = scatter(x, batch, dim=0, reduce=self.reduce_op)

        # ``mean`` is stored as a plain attribute, so it does not follow
        # ``Module.to()``; move it next to the data before adding.
        mean = self.mean
        if isinstance(mean, torch.Tensor):
            mean = mean.to(y.device)
        return y + mean


if HAS_VISNET:

    class ViSNet(torch.nn.Module):
        """
        MTENN wrapper around the PyG ViSNet model.

        ``forward`` returns the per-atom output (before reduction over atoms); the
        reduction and mean offset live in ``self.readout``.
        """

        def __init__(self, *args, model=None, **kwargs):
            """
            Parameters
            ----------
            *args, **kwargs
                Passed to the PyG ViSNet constructor when ``model`` is None
            model : PyG ViSNet, optional
                Existing model whose hyperparameters and weights are copied
            """
            super().__init__()
            ## If no model is passed, construct default ViSNet model, otherwise copy
            ## all parameters and weights over
            if model is None:
                self.visnet = PygVisNet(*args, **kwargs)
            else:
                # ``forward`` treats the prior model as optional, so do the same here
                if model.prior_model is None:
                    atomref = None
                else:
                    atomref = model.prior_model.atomref.weight.detach().clone()

                rep = model.representation_model
                model_params = {
                    "lmax": rep.lmax,
                    "vecnorm_type": rep.vecnorm_type,
                    "trainable_vecnorm": rep.trainable_vecnorm,
                    "num_heads": rep.num_heads,
                    "num_layers": rep.num_layers,
                    "hidden_channels": rep.hidden_channels,
                    "num_rbf": rep.num_rbf,
                    "trainable_rbf": rep.trainable_rbf,
                    "max_z": rep.max_z,
                    "cutoff": rep.cutoff,
                    # This was previously stored under a duplicate "reduce_op" key,
                    # which silently dropped the source model's neighbor limit.
                    "max_num_neighbors": rep.max_num_neighbors,
                    "vertex": isinstance(rep.vis_mp_layers[0], ViS_MP_Vertex),
                    "reduce_op": model.reduce_op,
                    "mean": model.mean,
                    "std": model.std,
                    # not used. originally calculates "force" from energy
                    "derivative": model.derivative,
                    "atomref": atomref,
                }
                self.visnet = PygVisNet(**model_params)
                self.visnet.load_state_dict(model.state_dict())

            self.readout = EquivariantVecToScalar(
                self.visnet.mean, self.visnet.reduce_op
            )

        def forward(self, data):
            """
            Compute the per-atom representation of a molecule from its atomic
            numbers and positional coordinates.

            Parameters
            ----------
            data : dict[str, torch.Tensor]
                A dictionary containing the atomic point clouds of a molecule. It
                should have the following keys:
                - z (torch.Tensor): The atomic numbers of the atoms in the molecule.
                - pos (torch.Tensor): The 3D coordinates (x, y, z) of each atom.

            Returns
            -------
            torch.Tensor
                Per-atom output, scaled by ``std`` and passed through the prior
                model if one is present

            Notes
            -----
            Assumes all atoms passed in `data` belong to the same molecule and are
            processed in a single batch.
            """
            pos = data["pos"]
            z = data["z"]

            # all atom in one pass from the same molecule; batch assignments are
            # indices, so they need an integer dtype
            batch = torch.zeros(z.shape[0], dtype=torch.int64, device=z.device)
            x, v = self.visnet.representation_model(z, pos, batch)
            x = self.visnet.output_model.pre_reduce(x, v)
            x = x * self.visnet.std
            if self.visnet.prior_model is not None:
                x = self.visnet.prior_model(x, z)

            return x

        def _get_representation(self):
            """
            Return a copy of this module to use as the representation block.

            Returns
            -------
            ViSNet
                Deep copy of `self`, so the initial model isn't affected
            """
            return deepcopy(self)

        def _get_energy_func(self):
            """
            Return last layer of the model (outputs scalar value)

            Returns
            -------
            torch.nn.Module
                Copy of `self.readout`, an instance of EquivariantVecToScalar
            """
            return deepcopy(self.readout)

        def _get_delta_strategy(self):
            """
            Build a DeltaStrategy object based on this model.

            Returns
            -------
            DeltaStrategy
                DeltaStrategy built from `self`
            """
            return DeltaStrategy(self._get_energy_func())

        def _get_complex_only_strategy(self):
            """
            Build a ComplexOnlyStrategy object based on this model.

            Returns
            -------
            ComplexOnlyStrategy
                ComplexOnlyStrategy built from `self`
            """
            return ComplexOnlyStrategy(self._get_energy_func())

        @staticmethod
        def get_model(
            model=None,
            grouped=False,
            fix_device=False,
            strategy: str = "delta",
            combination=None,
            pred_readout=None,
            comb_readout=None,
        ):
            """
            Exposed function to build a Model object from a ViSNet object. If none
            is provided, a default model is initialized.

            Parameters
            ----------
            model: ViSNet, optional
                ViSNet model to use to build the Model object. If left as none, a
                default model will be initialized and used
            grouped: bool, default=False
                Whether this model should accept groups of inputs or one input at a
                time.
            fix_device: bool, default=False
                If True, make sure the input is on the same device as the model,
                copying over as necessary.
            strategy: str, default='delta'
                Strategy to use to combine representation of the different parts.
                Options are ['delta', 'concat', 'complex']
            combination: Combination, optional
                Combination object to use to combine predictions in a group. A value
                must be passed if `grouped` is `True`.
            pred_readout : Readout
                Readout object for the energy predictions. If `grouped` is `False`,
                this option will still be used in the construction of the `Model`
                object.
            comb_readout : Readout
                Readout object for the combination output.

            Returns
            -------
            Model
                Model object containing the desired Representation and Strategy

            Raises
            ------
            ValueError
                If `strategy` is not recognized, or if `grouped` is True and no
                `combination` is passed
            """
            if model is None:
                model = ViSNet()

            ## First get representation module
            representation = model._get_representation()

            ## Construct strategy module based on model and
            ## representation (if necessary)
            strategy = strategy.lower()
            if strategy == "delta":
                strategy_module = model._get_delta_strategy()
            elif strategy == "concat":
                strategy_module = ConcatStrategy()
            elif strategy == "complex":
                strategy_module = model._get_complex_only_strategy()
            else:
                raise ValueError(f"Unknown strategy: {strategy}")

            if not grouped:
                return Model(representation, strategy_module, pred_readout, fix_device)

            ## Check on `combination`
            if combination is None:
                raise ValueError(
                    "Must pass a value for `combination` if `grouped` is `True`."
                )

            return GroupedModel(
                representation,
                strategy_module,
                combination,
                pred_readout,
                comb_readout,
                fix_device,
            )
from mtenn.config import ViSNetModelConfig
from mtenn.conversion_utils.visnet import HAS_VISNET
import pytest


@pytest.mark.skipif(not HAS_VISNET, reason="requires VisNet from nightly PyG")
def test_random_seed_visnet():
    """Unseeded builds should differ; builds sharing a seed should be identical."""
    rand_config = ViSNetModelConfig()
    set_config = ViSNetModelConfig(rand_seed=10)

    rand_model1 = rand_config.build()
    rand_model2 = rand_config.build()
    set_model1 = set_config.build()
    set_model2 = set_config.build()

    # At least one parameter tensor must differ between two unseeded builds
    rand_equal = [
        bool((p1 == p2).all())
        for p1, p2 in zip(rand_model1.parameters(), rand_model2.parameters())
    ]
    assert not all(rand_equal)

    # Every parameter tensor must match between two builds with the same seed
    set_equal = [
        bool((p1 == p2).all())
        for p1, p2 in zip(set_model1.parameters(), set_model2.parameters())
    ]
    assert all(set_equal)


@pytest.mark.skipif(not HAS_VISNET, reason="requires VisNet from nightly PyG")
def test_visnet_from_pyg():
    """Wrapping an existing PyG ViSNet must copy every parameter unchanged."""
    from torch_geometric.nn.models import ViSNet as PyVisNet
    from mtenn.conversion_utils import ViSNet

    model_params = {
        "lmax": 1,
        "vecnorm_type": None,
        "trainable_vecnorm": False,
        "num_heads": 8,
        "num_layers": 6,
        "hidden_channels": 128,
        "num_rbf": 32,
        "trainable_rbf": False,
        "max_z": 100,
        "cutoff": 5.0,
        "max_num_neighbors": 32,
        "vertex": False,
        "reduce_op": "sum",
        "mean": 0.0,
        "std": 1.0,
        "derivative": False,
        "atomref": None,
    }

    pyg_model = PyVisNet(**model_params)
    visnet_model = ViSNet(model=pyg_model)

    pyg_params = list(pyg_model.parameters())
    visnet_params = list(visnet_model.parameters())

    # zip() stops at the shorter input, so check the counts explicitly; otherwise
    # a wrapper with missing parameters would still pass the comparison below.
    assert len(pyg_params) == len(visnet_params)

    params_equal = [
        bool((p1 == p2).all()) for p1, p2 in zip(pyg_params, visnet_params)
    ]
    assert all(params_equal)