From de57edab89742605418443a2541dae554e18c723 Mon Sep 17 00:00:00 2001 From: fyng Date: Tue, 23 Jan 2024 15:15:29 -0500 Subject: [PATCH 01/42] add visnet template --- mtenn/config.py | 120 ++++++++++++++++++- mtenn/conversion_utils/visnet.py | 195 +++++++++++++++++++++++++++++++ 2 files changed, 314 insertions(+), 1 deletion(-) create mode 100644 mtenn/conversion_utils/visnet.py diff --git a/mtenn/config.py b/mtenn/config.py index bf92c73..4a87d84 100644 --- a/mtenn/config.py +++ b/mtenn/config.py @@ -39,7 +39,7 @@ class ModelType(StringEnum): schnet = "schnet" e3nn = "e3nn" INVALID = "INVALID" - + visnet = "visnet" class StrategyConfig(StringEnum): """ @@ -749,3 +749,121 @@ def _build(self, mtenn_params={}): pred_readout=pred_readout, comb_readout=comb_readout, ) + + + +#FIXME: Update to ViSNet API +class ViSNetModelConfig(ModelConfigBase): + """ + Class for constructing a SchNet ML model. Default values here are the default values + given in PyG. + """ + + model_type: ModelType = Field(ModelType.visnet, const=True) + lmax: int = Field(1, description="The maximum degree of the spherical harmonics.") + vecnorm_type: str | None = Field( + None, description="The type of normalization to apply to the vectors." + ) + trainable_vecnorm: bool = Field( + False, description="Whether the normalization weights are trainable." + ) + num_heads: int = Field(8, description="The number of attention heads.") + num_layers: int = Field(6, description="The number of layers in the network.") + hidden_channels: int = Field( + 128, description="The number of hidden channels in the node embeddings." + ) + num_rbf: int = Field(32, description="The number of radial basis functions.") + trainable_rbf: bool = Field( + False, description="Whether the radial basis function parameters are trainable." 
+ ) + max_z: int = Field(100, description="The maximum atomic numbers.") + cutoff: float = Field(5.0, description="The cutoff distance.") + max_num_neighbors: int = Field( + 32, + description="The maximum number of neighbors considered for each atom." + ) + vertex: bool = Field( + False, + description="Whether to use vertex geometric features." + ) + atomref: torch.Tensor | None = Field( + None, + description=( + "Reference values for single-atom properties. Should have length of 100 to " + "match with PyG." + ) + ) + reduce_op: str = Field( + "sum", + description="The type of reduction operation to apply. ['sum', 'mean']" + ) + mean: float | None = Field(0.0, description="The mean of the output distribution.") + std: float | None = Field(1.0, description="The standard deviation of the output distribution.") + derivative: bool = Field( + False, + description="Whether to compute the derivative of the output with respect to the positions." + ) + + @root_validator(pre=False) + def validate(cls, values): + # Make sure the grouped stuff is properly assigned + ModelConfigBase._check_grouped(values) + + # Make sure atomref length is correct (this is required by PyG) + atomref = values["atomref"] + if (atomref is not None) and (len(atomref) != 100): + raise ValueError(f"atomref must be length 100 (got {len(atomref)})") + + return values + + def _build(self, mtenn_params={}): + """ + Build an MTENN ViSNet Model from this config. + + Parameters + ---------- + mtenn_params: dict + Dict giving the MTENN Readout. 
This will be passed by the `build` method in + the abstract base class + + Returns + ------- + mtenn.model.Model + MTENN ViSNet Model/GroupedModel + """ + from mtenn.conversion_utils import ViSNet + + # Create an MTENN ViSNet model from PyG ViSNet model + model = ViSNet( + lmax=self.lmax, + vecnorm_type=self.vecnorm_type, + trainable_vector_norm=self.trainable_vecnorm, + num_heads=self.num_heads, + num_layers=self.num_layers, + hidden_channels=self.hidden_channels, + num_rbf=self.num_rbf, + trainable_rbf=self.trainable_rbf, + max_z=self.max_z, + cutoff=self.cutoff, + max_num_neighbors=self.max_num_neighbors, + vertex=self.vertex, + reduce_op=self.reduce_op, + mean=self.mean, + std=self.std, + derivative=self.derivative, + atomref=self.atomref, + ) + + combination = mtenn_params.get("combination", None) + pred_readout = mtenn_params.get("pred_readout", None) + comb_readout = mtenn_params.get("comb_readout", None) + + return ViSNet.get_model( + model=model, + grouped=self.grouped, + fix_device=True, + strategy=self.strategy, + combination=combination, + pred_readout=pred_readout, + comb_readout=comb_readout, + ) \ No newline at end of file diff --git a/mtenn/conversion_utils/visnet.py b/mtenn/conversion_utils/visnet.py new file mode 100644 index 0000000..c7d3459 --- /dev/null +++ b/mtenn/conversion_utils/visnet.py @@ -0,0 +1,195 @@ +""" +Representation and strategy for ViSNet model. 
+""" +from copy import deepcopy +import torch +from torch_geometric.nn.models import ViSNet as PygVisNet + +from mtenn.model import GroupedModel, Model +from mtenn.strategy import ComplexOnlyStrategy, ConcatStrategy, DeltaStrategy + + +class ViSNet(PygVisNet): + def __init__(self, *args, model=None, **kwargs): + ## If no model is passed, construct default SchNet model, otherwise copy + ## all parameters and weights over + if model is None: + super(ViSNet, self).__init__(*args, **kwargs) + else: + try: + atomref = model.atomref.weight.detach().clone() + except AttributeError: + atomref = None + model_params = ( + model.lmax, + model.vecnorm_type, + model.trainable_vecnorm, + model.num_heads, + model.num_layers, + model.hidden_channels, + model.num_rbf, + model.trainable_rbf, + model.max_z, + model.cutoff, + model.max_num_neighbors, + model.vertex, + model.reduce_op, + model.mean, + model.std, + model.derivative, + atomref, + ) + super(ViSNet, self).__init__(*model_params) + self.load_state_dict(model.state_dict()) + + def forward(self, data): + return super(ViSNet, self).forward(data["z"], data["pos"]) + + def _get_representation(self): + """ + Input model, remove last layer. + + Parameters + ---------- + model: SchNet + SchNet model + + Returns + ------- + SchNet + Copied SchNet model with the last layer replaced by an Identity module + """ + + ## Copy model so initial model isn't affected + model_copy = deepcopy(self) + + # FIXME: change to visnet X + ## Replace final linear layer with an identity module + model_copy.lin2 = torch.nn.Identity() + + return model_copy + + def _get_energy_func(self): + """ + Return last layer of the model. + + Parameters + ---------- + model: SchNet + SchNet model + + Returns + ------- + torch.nn.modules.linear.Linear + Copy of `model`'s last layer + """ + # FIXME: change to visnet X + return deepcopy(self.lin2) + + def _get_delta_strategy(self): + """ + Build a DeltaStrategy object based on the passed model. 
+ + Parameters + ---------- + model: SchNet + SchNet model + + Returns + ------- + DeltaStrategy + DeltaStrategy built from `model` + """ + + return DeltaStrategy(self._get_energy_func()) + + def _get_complex_only_strategy(self): + """ + Build a ComplexOnlyStrategy object based on the passed model. + + Returns + ------- + ComplexOnlyStrategy + ComplexOnlyStrategy built from `self` + """ + + return ComplexOnlyStrategy(self._get_energy_func()) + + @staticmethod + def get_model( + model=None, + grouped=False, + fix_device=False, + strategy: str = "delta", + combination=None, + pred_readout=None, + comb_readout=None, + ): + """ + Exposed function to build a Model object from a SchNet object. If none + is provided, a default model is initialized. + + Parameters + ---------- + model: SchNet, optional + SchNet model to use to build the Model object. If left as none, a + default model will be initialized and used + grouped: bool, default=False + Whether this model should accept groups of inputs or one input at a + time. + fix_device: bool, default=False + If True, make sure the input is on the same device as the model, + copying over as necessary. + strategy: str, default='delta' + Strategy to use to combine representation of the different parts. + Options are ['delta', 'concat', 'complex'] + combination: Combination, optional + Combination object to use to combine predictions in a group. A value + must be passed if `grouped` is `True`. + pred_readout : Readout + Readout object for the energy predictions. If `grouped` is `False`, + this option will still be used in the construction of the `Model` + object. + comb_readout : Readout + Readout object for the combination output. 
+ + Returns + ------- + Model + Model object containing the desired Representation and Strategy + """ + if model is None: + model = ViSNet() + + ## First get representation module + representation = model._get_representation() + + ## Construct strategy module based on model and + ## representation (if necessary) + strategy = strategy.lower() + if strategy == "delta": + strategy = model._get_delta_strategy() + elif strategy == "concat": + strategy = ConcatStrategy() + elif strategy == "complex": + strategy = model._get_complex_only_strategy() + else: + raise ValueError(f"Unknown strategy: {strategy}") + + ## Check on `combination` + if grouped and (combination is None): + raise ValueError( + "Must pass a value for `combination` if `grouped` is `True`." + ) + + if grouped: + return GroupedModel( + representation, + strategy, + combination, + pred_readout, + comb_readout, + fix_device, + ) + else: + return Model(representation, strategy, pred_readout, fix_device) From eadb44c88271b256081d06e13b85e1372e7cb679 Mon Sep 17 00:00:00 2001 From: fyng Date: Wed, 24 Jan 2024 16:32:16 -0500 Subject: [PATCH 02/42] ViSNet outputs vector embedding --- mtenn/conversion_utils/visnet.py | 70 +++++++++++++++++++++++++++----- 1 file changed, 59 insertions(+), 11 deletions(-) diff --git a/mtenn/conversion_utils/visnet.py b/mtenn/conversion_utils/visnet.py index c7d3459..634c221 100644 --- a/mtenn/conversion_utils/visnet.py +++ b/mtenn/conversion_utils/visnet.py @@ -3,18 +3,42 @@ """ from copy import deepcopy import torch +from torch.autograd import grad from torch_geometric.nn.models import ViSNet as PygVisNet +from torch_geometric.utils import scatter from mtenn.model import GroupedModel, Model from mtenn.strategy import ComplexOnlyStrategy, ConcatStrategy, DeltaStrategy +class UpdateAtoms(torch.nn.Module): + def __init__(self, prior_model, std): + super(UpdateAtoms, self).__init__() + self.prior_model = prior_model + self.std = std -class ViSNet(PygVisNet): + def forward(self, x, z): 
+ x = x * self.std + + if self.prior_model is not None: + x = self.prior_model(x, z) + + return x + + +class EquivariantRepToScaler(torch.nn.module): + # Wrapper for PygVisNet.EquivariantScalar to implement forward() method + def __init__(self, equv_layer): + super(EquivariantRepToScaler, self).__init__() + self.equv_layer = equv_layer + def forward(self, x, v): + return self.equv_layer.pre_reduce(x, v) + +class ViSNet(torch.nn.Module): def __init__(self, *args, model=None, **kwargs): ## If no model is passed, construct default SchNet model, otherwise copy ## all parameters and weights over if model is None: - super(ViSNet, self).__init__(*args, **kwargs) + self.visnet = PygVisNet(*args, **kwargs) else: try: atomref = model.atomref.weight.detach().clone() @@ -36,14 +60,43 @@ def __init__(self, *args, model=None, **kwargs): model.reduce_op, model.mean, model.std, - model.derivative, + model.derivative, # not used. originally calculates "force" from energy atomref, ) - super(ViSNet, self).__init__(*model_params) + self.visnet = PygVisNet(*model_params) self.load_state_dict(model.state_dict()) + # self.readout = UpdateAtoms(self.visnet.prior_model, self.visnet.std) + self.readout = EquivariantRepToScaler(self.visnet.output_model) + + def forward(self, data): - return super(ViSNet, self).forward(data["z"], data["pos"]) + # return super(ViSNet, self).forward(data["z"], data["pos"]) + """ + Computes the energies or properties (forces) for a batch of + molecules. + + Args: + z (torch.Tensor): The atomic numbers. + pos (torch.Tensor): The coordinates of the atoms. + batch (torch.Tensor): A batch vector, + which assigns each node to a specific example. + + Returns: + x (torch.Tensor): Scalar output based on node features and vector features. + dx (torch.Tensor, optional): The negative derivative of x. 
+ """ + pos = data["pos"] + z = data["z"] + + # all atom in one pass from the same molecule + # TODO: set separate batch for ligand and protein + batch = torch.zeros(z.shape[0], device=z.device) + x, v = self.visnet.representation_model(z, pos, batch) + x = self.readout(x, v) + + # x = self.visnet.output_model.pre_reduce(x, v) + # return self.readout(x, z) def _get_representation(self): """ @@ -63,10 +116,6 @@ def _get_representation(self): ## Copy model so initial model isn't affected model_copy = deepcopy(self) - # FIXME: change to visnet X - ## Replace final linear layer with an identity module - model_copy.lin2 = torch.nn.Identity() - return model_copy def _get_energy_func(self): @@ -83,8 +132,7 @@ def _get_energy_func(self): torch.nn.modules.linear.Linear Copy of `model`'s last layer """ - # FIXME: change to visnet X - return deepcopy(self.lin2) + return deepcopy(self.readout) def _get_delta_strategy(self): """ From f61d2d1cfcb20961ec015b869ddf7caa3695e8ec Mon Sep 17 00:00:00 2001 From: fyng Date: Thu, 25 Jan 2024 08:45:56 -0500 Subject: [PATCH 03/42] add visnet to ModelType --- mtenn/config.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/mtenn/config.py b/mtenn/config.py index 4a87d84..360f237 100644 --- a/mtenn/config.py +++ b/mtenn/config.py @@ -39,7 +39,7 @@ class ModelType(StringEnum): schnet = "schnet" e3nn = "e3nn" INVALID = "INVALID" - visnet = "visnet" + ViSNet = "visnet" class StrategyConfig(StringEnum): """ @@ -751,8 +751,6 @@ def _build(self, mtenn_params={}): ) - -#FIXME: Update to ViSNet API class ViSNetModelConfig(ModelConfigBase): """ Class for constructing a SchNet ML model. 
Default values here are the default values From bdcadfbd4b0ddc765d2f22d2d9039c6519e9a6ba Mon Sep 17 00:00:00 2001 From: fyng Date: Thu, 25 Jan 2024 11:56:18 -0500 Subject: [PATCH 04/42] standardize naming for visnet --- mtenn/config.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mtenn/config.py b/mtenn/config.py index 360f237..ab1803c 100644 --- a/mtenn/config.py +++ b/mtenn/config.py @@ -39,7 +39,7 @@ class ModelType(StringEnum): schnet = "schnet" e3nn = "e3nn" INVALID = "INVALID" - ViSNet = "visnet" + visnet = "visnet" class StrategyConfig(StringEnum): """ @@ -784,7 +784,7 @@ class ViSNetModelConfig(ModelConfigBase): False, description="Whether to use vertex geometric features." ) - atomref: torch.Tensor | None = Field( + atomref: list[float] | None = Field( None, description=( "Reference values for single-atom properties. Should have length of 100 to " From 6a0d11155ac6e2d72f76ed5c586257e61f00e5dd Mon Sep 17 00:00:00 2001 From: fyng Date: Thu, 25 Jan 2024 11:56:26 -0500 Subject: [PATCH 05/42] add visnet to init --- mtenn/conversion_utils/__init__.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/mtenn/conversion_utils/__init__.py b/mtenn/conversion_utils/__init__.py index f9c3001..8a84f04 100644 --- a/mtenn/conversion_utils/__init__.py +++ b/mtenn/conversion_utils/__init__.py @@ -1,5 +1,6 @@ from .e3nn import E3NN from .gat import GAT from .schnet import SchNet +from .visnet import ViSNet -__all__ = ["E3NN", "GAT", "SchNet"] +__all__ = ["E3NN", "GAT", "SchNet", "ViSNet"] From 9c9d96aa4b3a12f86345cc3f01e5db45aa0ff724 Mon Sep 17 00:00:00 2001 From: fyng Date: Thu, 25 Jan 2024 17:26:44 -0500 Subject: [PATCH 06/42] fix typo --- mtenn/config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mtenn/config.py b/mtenn/config.py index ab1803c..98b177f 100644 --- a/mtenn/config.py +++ b/mtenn/config.py @@ -835,7 +835,7 @@ def _build(self, mtenn_params={}): model = ViSNet( lmax=self.lmax, 
vecnorm_type=self.vecnorm_type, - trainable_vector_norm=self.trainable_vecnorm, + trainable_vecnorm=self.trainable_vecnorm, num_heads=self.num_heads, num_layers=self.num_layers, hidden_channels=self.hidden_channels, From 14864db59d0047e0da458e778346a370e188a041 Mon Sep 17 00:00:00 2001 From: fyng Date: Thu, 25 Jan 2024 17:28:47 -0500 Subject: [PATCH 07/42] fix last layer of model --- mtenn/conversion_utils/visnet.py | 47 ++++++++++++-------------------- 1 file changed, 18 insertions(+), 29 deletions(-) diff --git a/mtenn/conversion_utils/visnet.py b/mtenn/conversion_utils/visnet.py index 634c221..b15ac85 100644 --- a/mtenn/conversion_utils/visnet.py +++ b/mtenn/conversion_utils/visnet.py @@ -9,32 +9,25 @@ from mtenn.model import GroupedModel, Model from mtenn.strategy import ComplexOnlyStrategy, ConcatStrategy, DeltaStrategy - -class UpdateAtoms(torch.nn.Module): - def __init__(self, prior_model, std): - super(UpdateAtoms, self).__init__() - self.prior_model = prior_model - self.std = std - - def forward(self, x, z): - x = x * self.std - - if self.prior_model is not None: - x = self.prior_model(x, z) - - return x -class EquivariantRepToScaler(torch.nn.module): +class EquivariantVecToScaler(torch.nn.Module): # Wrapper for PygVisNet.EquivariantScalar to implement forward() method - def __init__(self, equv_layer): - super(EquivariantRepToScaler, self).__init__() - self.equv_layer = equv_layer - def forward(self, x, v): - return self.equv_layer.pre_reduce(x, v) + def __init__(self, mean, reduce_op): + super(EquivariantVecToScaler, self).__init__() + self.mean = mean + self.reduce_op = reduce_op + def forward(self, x): + # dummy variable. 
all atoms from the same molecule and the same batch + batch = torch.zeros(x.shape[0], dtype=torch.int64, device=x.device) + + y = scatter(x, batch, dim=0, reduce=self.reduce_op) + return y + self.mean + class ViSNet(torch.nn.Module): def __init__(self, *args, model=None, **kwargs): + super().__init__() ## If no model is passed, construct default SchNet model, otherwise copy ## all parameters and weights over if model is None: @@ -66,12 +59,9 @@ def __init__(self, *args, model=None, **kwargs): self.visnet = PygVisNet(*model_params) self.load_state_dict(model.state_dict()) - # self.readout = UpdateAtoms(self.visnet.prior_model, self.visnet.std) - self.readout = EquivariantRepToScaler(self.visnet.output_model) - + self.readout = EquivariantVecToScaler(self.visnet.mean, self.visnet.reduce_op) def forward(self, data): - # return super(ViSNet, self).forward(data["z"], data["pos"]) """ Computes the energies or properties (forces) for a batch of molecules. @@ -84,7 +74,6 @@ def forward(self, data): Returns: x (torch.Tensor): Scalar output based on node features and vector features. - dx (torch.Tensor, optional): The negative derivative of x. """ pos = data["pos"] z = data["z"] @@ -93,10 +82,10 @@ def forward(self, data): # TODO: set separate batch for ligand and protein batch = torch.zeros(z.shape[0], device=z.device) x, v = self.visnet.representation_model(z, pos, batch) - x = self.readout(x, v) + x = self.visnet.output_model.pre_reduce(x, v) + x = x * self.visnet.std - # x = self.visnet.output_model.pre_reduce(x, v) - # return self.readout(x, z) + return x def _get_representation(self): """ @@ -120,7 +109,7 @@ def _get_representation(self): def _get_energy_func(self): """ - Return last layer of the model. 
+ Return last layer of the model (outputs scalar value) Parameters ---------- From 81f8818ad64c661a805e64156e1d67c6a363a2c4 Mon Sep 17 00:00:00 2001 From: fyng Date: Fri, 26 Jan 2024 16:27:31 -0500 Subject: [PATCH 08/42] install packages from pytorch-nightly --- devtools/conda-envs/test_env.yaml | 1 + 1 file changed, 1 insertion(+) diff --git a/devtools/conda-envs/test_env.yaml b/devtools/conda-envs/test_env.yaml index b8221be..e171b76 100644 --- a/devtools/conda-envs/test_env.yaml +++ b/devtools/conda-envs/test_env.yaml @@ -1,5 +1,6 @@ name: test channels: + - pytorch-nightly - conda-forge dependencies: - pytorch From be05591007f049a716ab22756c1ad09d701125f2 Mon Sep 17 00:00:00 2001 From: fyng Date: Fri, 26 Jan 2024 16:53:47 -0500 Subject: [PATCH 09/42] pip install pytorch geometric nightly instead --- devtools/conda-envs/test_env.yaml | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/devtools/conda-envs/test_env.yaml b/devtools/conda-envs/test_env.yaml index e171b76..35ee854 100644 --- a/devtools/conda-envs/test_env.yaml +++ b/devtools/conda-envs/test_env.yaml @@ -1,10 +1,10 @@ name: test channels: - - pytorch-nightly - conda-forge dependencies: + - pip - pytorch - - pytorch_geometric + # - pytorch_geometric - pytorch_cluster - pytorch_scatter - pytorch_sparse @@ -20,3 +20,6 @@ dependencies: - pytest-cov - codecov - pydantic >=1.10.8,<2.0.0a0 + + - pip: + - pyg-nightly \ No newline at end of file From 7be95b7535da5b449a9d125611106a5f754aca75 Mon Sep 17 00:00:00 2001 From: fyng Date: Fri, 26 Jan 2024 17:14:00 -0500 Subject: [PATCH 10/42] try importing visnet, implement has_visnet_flag --- mtenn/conversion_utils/visnet.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/mtenn/conversion_utils/visnet.py b/mtenn/conversion_utils/visnet.py index b15ac85..7745f2f 100644 --- a/mtenn/conversion_utils/visnet.py +++ b/mtenn/conversion_utils/visnet.py @@ -4,12 +4,17 @@ from copy import deepcopy import torch from 
torch.autograd import grad -from torch_geometric.nn.models import ViSNet as PygVisNet from torch_geometric.utils import scatter from mtenn.model import GroupedModel, Model from mtenn.strategy import ComplexOnlyStrategy, ConcatStrategy, DeltaStrategy - + +HAS_VISNET_FLAG = False +try: + from torch_geometric.nn.models import ViSNet as PygVisNet + HAS_VISNET_FLAG = True +except ImportError: + pass class EquivariantVecToScaler(torch.nn.Module): # Wrapper for PygVisNet.EquivariantScalar to implement forward() method From f78ed8c6b8fd047bfdb2cefa0c3a88958c48229f Mon Sep 17 00:00:00 2001 From: fyng Date: Mon, 29 Jan 2024 11:22:25 -0500 Subject: [PATCH 11/42] add test for visnet --- mtenn/tests/test_model_config.py | 23 ++++++++++++++++++++++- 1 file changed, 22 insertions(+), 1 deletion(-) diff --git a/mtenn/tests/test_model_config.py b/mtenn/tests/test_model_config.py index dd16da7..4aebf39 100644 --- a/mtenn/tests/test_model_config.py +++ b/mtenn/tests/test_model_config.py @@ -1,4 +1,4 @@ -from mtenn.config import GATModelConfig, E3NNModelConfig, SchNetModelConfig +from mtenn.config import GATModelConfig, E3NNModelConfig, SchNetModelConfig, ViSNetModelConfig def test_random_seed_gat(): @@ -65,3 +65,24 @@ def test_random_seed_schnet(): for p1, p2 in zip(set_model1.parameters(), set_model2.parameters()) ] assert sum(set_equal) == len(set_equal) + +def test_random_seed_visnet(): + rand_config = ViSNetModelConfig() + set_config = ViSNetModelConfig(rand_seed=10) + + rand_model1 = rand_config.build() + rand_model2 = rand_config.build() + set_model1 = set_config.build() + set_model2 = set_config.build() + + rand_equal = [ + (p1 == p2).all() + for p1, p2 in zip(rand_model1.parameters(), rand_model2.parameters()) + ] + assert sum(rand_equal) < len(rand_equal) + + set_equal = [ + (p1 == p2).all() + for p1, p2 in zip(set_model1.parameters(), set_model2.parameters()) + ] + assert sum(set_equal) == len(set_equal) \ No newline at end of file From 
4f8975f71a8912a91ab96900cee95abd5e378b43 Mon Sep 17 00:00:00 2001 From: fyng Date: Mon, 29 Jan 2024 13:34:27 -0500 Subject: [PATCH 12/42] VisNet import error handling for older PyG version --- mtenn/config.py | 77 ++--- mtenn/conversion_utils/__init__.py | 4 +- mtenn/conversion_utils/visnet.py | 440 ++++++++++++++--------------- 3 files changed, 265 insertions(+), 256 deletions(-) diff --git a/mtenn/config.py b/mtenn/config.py index 98b177f..613f1e9 100644 --- a/mtenn/config.py +++ b/mtenn/config.py @@ -5,7 +5,6 @@ from pydantic import BaseModel, Field, root_validator import random from typing import Callable, ClassVar - import mtenn import numpy as np import torch @@ -753,7 +752,7 @@ def _build(self, mtenn_params={}): class ViSNetModelConfig(ModelConfigBase): """ - Class for constructing a SchNet ML model. Default values here are the default values + Class for constructing a VisNet ML model. Default values here are the default values given in PyG. """ @@ -813,6 +812,8 @@ def validate(cls, values): raise ValueError(f"atomref must be length 100 (got {len(atomref)})") return values + + def _build(self, mtenn_params={}): """ @@ -829,39 +830,45 @@ def _build(self, mtenn_params={}): mtenn.model.Model MTENN ViSNet Model/GroupedModel """ - from mtenn.conversion_utils import ViSNet - # Create an MTENN ViSNet model from PyG ViSNet model - model = ViSNet( - lmax=self.lmax, - vecnorm_type=self.vecnorm_type, - trainable_vecnorm=self.trainable_vecnorm, - num_heads=self.num_heads, - num_layers=self.num_layers, - hidden_channels=self.hidden_channels, - num_rbf=self.num_rbf, - trainable_rbf=self.trainable_rbf, - max_z=self.max_z, - cutoff=self.cutoff, - max_num_neighbors=self.max_num_neighbors, - vertex=self.vertex, - reduce_op=self.reduce_op, - mean=self.mean, - std=self.std, - derivative=self.derivative, - atomref=self.atomref, - ) - combination = mtenn_params.get("combination", None) - pred_readout = mtenn_params.get("pred_readout", None) - comb_readout = 
mtenn_params.get("comb_readout", None) + from mtenn.conversion_utils.visnet import HAS_VISNET_FLAG + if HAS_VISNET_FLAG: + from mtenn.conversion_utils import ViSNet + + model = ViSNet( + lmax=self.lmax, + vecnorm_type=self.vecnorm_type, + trainable_vecnorm=self.trainable_vecnorm, + num_heads=self.num_heads, + num_layers=self.num_layers, + hidden_channels=self.hidden_channels, + num_rbf=self.num_rbf, + trainable_rbf=self.trainable_rbf, + max_z=self.max_z, + cutoff=self.cutoff, + max_num_neighbors=self.max_num_neighbors, + vertex=self.vertex, + reduce_op=self.reduce_op, + mean=self.mean, + std=self.std, + derivative=self.derivative, + atomref=self.atomref, + ) + combination = mtenn_params.get("combination", None) + pred_readout = mtenn_params.get("pred_readout", None) + comb_readout = mtenn_params.get("comb_readout", None) + + return ViSNet.get_model( + model=model, + grouped=self.grouped, + fix_device=True, + strategy=self.strategy, + combination=combination, + pred_readout=pred_readout, + comb_readout=comb_readout, + ) - return ViSNet.get_model( - model=model, - grouped=self.grouped, - fix_device=True, - strategy=self.strategy, - combination=combination, - pred_readout=pred_readout, - comb_readout=comb_readout, - ) \ No newline at end of file + else: + raise ImportError("ViSNet not found. 
Is your PyG >=2.5.0?") + diff --git a/mtenn/conversion_utils/__init__.py b/mtenn/conversion_utils/__init__.py index 8a84f04..41438d7 100644 --- a/mtenn/conversion_utils/__init__.py +++ b/mtenn/conversion_utils/__init__.py @@ -1,6 +1,8 @@ from .e3nn import E3NN from .gat import GAT from .schnet import SchNet -from .visnet import ViSNet +from .visnet import HAS_VISNET_FLAG +if HAS_VISNET_FLAG: + from .visnet import ViSNet __all__ = ["E3NN", "GAT", "SchNet", "ViSNet"] diff --git a/mtenn/conversion_utils/visnet.py b/mtenn/conversion_utils/visnet.py index 7745f2f..5d994e4 100644 --- a/mtenn/conversion_utils/visnet.py +++ b/mtenn/conversion_utils/visnet.py @@ -15,223 +15,223 @@ HAS_VISNET_FLAG = True except ImportError: pass - -class EquivariantVecToScaler(torch.nn.Module): - # Wrapper for PygVisNet.EquivariantScalar to implement forward() method - def __init__(self, mean, reduce_op): - super(EquivariantVecToScaler, self).__init__() - self.mean = mean - self.reduce_op = reduce_op - def forward(self, x): - # dummy variable. all atoms from the same molecule and the same batch - batch = torch.zeros(x.shape[0], dtype=torch.int64, device=x.device) - - y = scatter(x, batch, dim=0, reduce=self.reduce_op) - return y + self.mean - - -class ViSNet(torch.nn.Module): - def __init__(self, *args, model=None, **kwargs): - super().__init__() - ## If no model is passed, construct default SchNet model, otherwise copy - ## all parameters and weights over - if model is None: - self.visnet = PygVisNet(*args, **kwargs) - else: - try: - atomref = model.atomref.weight.detach().clone() - except AttributeError: - atomref = None - model_params = ( - model.lmax, - model.vecnorm_type, - model.trainable_vecnorm, - model.num_heads, - model.num_layers, - model.hidden_channels, - model.num_rbf, - model.trainable_rbf, - model.max_z, - model.cutoff, - model.max_num_neighbors, - model.vertex, - model.reduce_op, - model.mean, - model.std, - model.derivative, # not used. 
originally calculates "force" from energy - atomref, - ) - self.visnet = PygVisNet(*model_params) - self.load_state_dict(model.state_dict()) - - self.readout = EquivariantVecToScaler(self.visnet.mean, self.visnet.reduce_op) - - def forward(self, data): - """ - Computes the energies or properties (forces) for a batch of - molecules. - - Args: - z (torch.Tensor): The atomic numbers. - pos (torch.Tensor): The coordinates of the atoms. - batch (torch.Tensor): A batch vector, - which assigns each node to a specific example. - - Returns: - x (torch.Tensor): Scalar output based on node features and vector features. - """ - pos = data["pos"] - z = data["z"] - - # all atom in one pass from the same molecule - # TODO: set separate batch for ligand and protein - batch = torch.zeros(z.shape[0], device=z.device) - x, v = self.visnet.representation_model(z, pos, batch) - x = self.visnet.output_model.pre_reduce(x, v) - x = x * self.visnet.std - - return x - - def _get_representation(self): - """ - Input model, remove last layer. - - Parameters - ---------- - model: SchNet - SchNet model - - Returns - ------- - SchNet - Copied SchNet model with the last layer replaced by an Identity module - """ - - ## Copy model so initial model isn't affected - model_copy = deepcopy(self) - - return model_copy - - def _get_energy_func(self): - """ - Return last layer of the model (outputs scalar value) - - Parameters - ---------- - model: SchNet - SchNet model - - Returns - ------- - torch.nn.modules.linear.Linear - Copy of `model`'s last layer - """ - return deepcopy(self.readout) - - def _get_delta_strategy(self): - """ - Build a DeltaStrategy object based on the passed model. - - Parameters - ---------- - model: SchNet - SchNet model - - Returns - ------- - DeltaStrategy - DeltaStrategy built from `model` - """ - - return DeltaStrategy(self._get_energy_func()) - - def _get_complex_only_strategy(self): - """ - Build a ComplexOnlyStrategy object based on the passed model. 
- - Returns - ------- - ComplexOnlyStrategy - ComplexOnlyStrategy built from `self` - """ - - return ComplexOnlyStrategy(self._get_energy_func()) - - @staticmethod - def get_model( - model=None, - grouped=False, - fix_device=False, - strategy: str = "delta", - combination=None, - pred_readout=None, - comb_readout=None, - ): - """ - Exposed function to build a Model object from a SchNet object. If none - is provided, a default model is initialized. - - Parameters - ---------- - model: SchNet, optional - SchNet model to use to build the Model object. If left as none, a - default model will be initialized and used - grouped: bool, default=False - Whether this model should accept groups of inputs or one input at a - time. - fix_device: bool, default=False - If True, make sure the input is on the same device as the model, - copying over as necessary. - strategy: str, default='delta' - Strategy to use to combine representation of the different parts. - Options are ['delta', 'concat', 'complex'] - combination: Combination, optional - Combination object to use to combine predictions in a group. A value - must be passed if `grouped` is `True`. - pred_readout : Readout - Readout object for the energy predictions. If `grouped` is `False`, - this option will still be used in the construction of the `Model` - object. - comb_readout : Readout - Readout object for the combination output. 
- - Returns - ------- - Model - Model object containing the desired Representation and Strategy - """ - if model is None: - model = ViSNet() - - ## First get representation module - representation = model._get_representation() - - ## Construct strategy module based on model and - ## representation (if necessary) - strategy = strategy.lower() - if strategy == "delta": - strategy = model._get_delta_strategy() - elif strategy == "concat": - strategy = ConcatStrategy() - elif strategy == "complex": - strategy = model._get_complex_only_strategy() - else: - raise ValueError(f"Unknown strategy: {strategy}") - - ## Check on `combination` - if grouped and (combination is None): - raise ValueError( - "Must pass a value for `combination` if `grouped` is `True`." - ) - - if grouped: - return GroupedModel( - representation, - strategy, - combination, - pred_readout, - comb_readout, - fix_device, - ) - else: - return Model(representation, strategy, pred_readout, fix_device) +if HAS_VISNET_FLAG: + class EquivariantVecToScaler(torch.nn.Module): + # Wrapper for PygVisNet.EquivariantScalar to implement forward() method + def __init__(self, mean, reduce_op): + super(EquivariantVecToScaler, self).__init__() + self.mean = mean + self.reduce_op = reduce_op + def forward(self, x): + # dummy variable. 
all atoms from the same molecule and the same batch + batch = torch.zeros(x.shape[0], dtype=torch.int64, device=x.device) + + y = scatter(x, batch, dim=0, reduce=self.reduce_op) + return y + self.mean + + + class ViSNet(torch.nn.Module): + def __init__(self, *args, model=None, **kwargs): + super().__init__() + ## If no model is passed, construct default SchNet model, otherwise copy + ## all parameters and weights over + if model is None: + self.visnet = PygVisNet(*args, **kwargs) + else: + try: + atomref = model.atomref.weight.detach().clone() + except AttributeError: + atomref = None + model_params = ( + model.lmax, + model.vecnorm_type, + model.trainable_vecnorm, + model.num_heads, + model.num_layers, + model.hidden_channels, + model.num_rbf, + model.trainable_rbf, + model.max_z, + model.cutoff, + model.max_num_neighbors, + model.vertex, + model.reduce_op, + model.mean, + model.std, + model.derivative, # not used. originally calculates "force" from energy + atomref, + ) + self.visnet = PygVisNet(*model_params) + self.load_state_dict(model.state_dict()) + + self.readout = EquivariantVecToScaler(self.visnet.mean, self.visnet.reduce_op) + + def forward(self, data): + """ + Computes the energies or properties (forces) for a batch of + molecules. + + Args: + z (torch.Tensor): The atomic numbers. + pos (torch.Tensor): The coordinates of the atoms. + batch (torch.Tensor): A batch vector, + which assigns each node to a specific example. + + Returns: + x (torch.Tensor): Scalar output based on node features and vector features. + """ + pos = data["pos"] + z = data["z"] + + # all atom in one pass from the same molecule + # TODO: set separate batch for ligand and protein + batch = torch.zeros(z.shape[0], device=z.device) + x, v = self.visnet.representation_model(z, pos, batch) + x = self.visnet.output_model.pre_reduce(x, v) + x = x * self.visnet.std + + return x + + def _get_representation(self): + """ + Input model, remove last layer. 
+ + Parameters + ---------- + model: SchNet + SchNet model + + Returns + ------- + SchNet + Copied SchNet model with the last layer replaced by an Identity module + """ + + ## Copy model so initial model isn't affected + model_copy = deepcopy(self) + + return model_copy + + def _get_energy_func(self): + """ + Return last layer of the model (outputs scalar value) + + Parameters + ---------- + model: SchNet + SchNet model + + Returns + ------- + torch.nn.modules.linear.Linear + Copy of `model`'s last layer + """ + return deepcopy(self.readout) + + def _get_delta_strategy(self): + """ + Build a DeltaStrategy object based on the passed model. + + Parameters + ---------- + model: SchNet + SchNet model + + Returns + ------- + DeltaStrategy + DeltaStrategy built from `model` + """ + + return DeltaStrategy(self._get_energy_func()) + + def _get_complex_only_strategy(self): + """ + Build a ComplexOnlyStrategy object based on the passed model. + + Returns + ------- + ComplexOnlyStrategy + ComplexOnlyStrategy built from `self` + """ + + return ComplexOnlyStrategy(self._get_energy_func()) + + @staticmethod + def get_model( + model=None, + grouped=False, + fix_device=False, + strategy: str = "delta", + combination=None, + pred_readout=None, + comb_readout=None, + ): + """ + Exposed function to build a Model object from a SchNet object. If none + is provided, a default model is initialized. + + Parameters + ---------- + model: SchNet, optional + SchNet model to use to build the Model object. If left as none, a + default model will be initialized and used + grouped: bool, default=False + Whether this model should accept groups of inputs or one input at a + time. + fix_device: bool, default=False + If True, make sure the input is on the same device as the model, + copying over as necessary. + strategy: str, default='delta' + Strategy to use to combine representation of the different parts. 
+ Options are ['delta', 'concat', 'complex'] + combination: Combination, optional + Combination object to use to combine predictions in a group. A value + must be passed if `grouped` is `True`. + pred_readout : Readout + Readout object for the energy predictions. If `grouped` is `False`, + this option will still be used in the construction of the `Model` + object. + comb_readout : Readout + Readout object for the combination output. + + Returns + ------- + Model + Model object containing the desired Representation and Strategy + """ + if model is None: + model = ViSNet() + + ## First get representation module + representation = model._get_representation() + + ## Construct strategy module based on model and + ## representation (if necessary) + strategy = strategy.lower() + if strategy == "delta": + strategy = model._get_delta_strategy() + elif strategy == "concat": + strategy = ConcatStrategy() + elif strategy == "complex": + strategy = model._get_complex_only_strategy() + else: + raise ValueError(f"Unknown strategy: {strategy}") + + ## Check on `combination` + if grouped and (combination is None): + raise ValueError( + "Must pass a value for `combination` if `grouped` is `True`." 
+ ) + + if grouped: + return GroupedModel( + representation, + strategy, + combination, + pred_readout, + comb_readout, + fix_device, + ) + else: + return Model(representation, strategy, pred_readout, fix_device) From 6c15f042a9018f6a664ed4b8118f86b6b0c3a962 Mon Sep 17 00:00:00 2001 From: fyng Date: Mon, 29 Jan 2024 14:06:16 -0500 Subject: [PATCH 13/42] change variable name --- mtenn/config.py | 4 ++-- mtenn/conversion_utils/__init__.py | 4 ++-- mtenn/conversion_utils/visnet.py | 6 +++--- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/mtenn/config.py b/mtenn/config.py index 613f1e9..7523460 100644 --- a/mtenn/config.py +++ b/mtenn/config.py @@ -832,8 +832,8 @@ def _build(self, mtenn_params={}): """ # Create an MTENN ViSNet model from PyG ViSNet model - from mtenn.conversion_utils.visnet import HAS_VISNET_FLAG - if HAS_VISNET_FLAG: + from mtenn.conversion_utils.visnet import HAS_VISNET + if HAS_VISNET: from mtenn.conversion_utils import ViSNet model = ViSNet( diff --git a/mtenn/conversion_utils/__init__.py b/mtenn/conversion_utils/__init__.py index 41438d7..0d2dcb1 100644 --- a/mtenn/conversion_utils/__init__.py +++ b/mtenn/conversion_utils/__init__.py @@ -1,8 +1,8 @@ from .e3nn import E3NN from .gat import GAT from .schnet import SchNet -from .visnet import HAS_VISNET_FLAG -if HAS_VISNET_FLAG: +from .visnet import HAS_VISNET +if HAS_VISNET: from .visnet import ViSNet __all__ = ["E3NN", "GAT", "SchNet", "ViSNet"] diff --git a/mtenn/conversion_utils/visnet.py b/mtenn/conversion_utils/visnet.py index 5d994e4..ebdabbd 100644 --- a/mtenn/conversion_utils/visnet.py +++ b/mtenn/conversion_utils/visnet.py @@ -9,13 +9,13 @@ from mtenn.model import GroupedModel, Model from mtenn.strategy import ComplexOnlyStrategy, ConcatStrategy, DeltaStrategy -HAS_VISNET_FLAG = False +HAS_VISNET = False try: from torch_geometric.nn.models import ViSNet as PygVisNet - HAS_VISNET_FLAG = True + HAS_VISNET = True except ImportError: pass -if HAS_VISNET_FLAG: +if HAS_VISNET: 
class EquivariantVecToScaler(torch.nn.Module): # Wrapper for PygVisNet.EquivariantScalar to implement forward() method def __init__(self, mean, reduce_op): From 549e8f7e0d0da607089ec413f2573fbafa58fbed Mon Sep 17 00:00:00 2001 From: fyng Date: Mon, 29 Jan 2024 14:11:26 -0500 Subject: [PATCH 14/42] add import guard to VisNet test --- mtenn/tests/test_model_config.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/mtenn/tests/test_model_config.py b/mtenn/tests/test_model_config.py index 4aebf39..ebf91f0 100644 --- a/mtenn/tests/test_model_config.py +++ b/mtenn/tests/test_model_config.py @@ -1,5 +1,5 @@ from mtenn.config import GATModelConfig, E3NNModelConfig, SchNetModelConfig, ViSNetModelConfig - +from mtenn.conversion_utils.visnet import HAS_VISNET def test_random_seed_gat(): rand_config = GATModelConfig() @@ -66,6 +66,7 @@ def test_random_seed_schnet(): ] assert sum(set_equal) == len(set_equal) +@pytest.mark.skipif(not HAS_VISNET, reason="requires VisNet from nightly PyG") def test_random_seed_visnet(): rand_config = ViSNetModelConfig() set_config = ViSNetModelConfig(rand_seed=10) From efad5c3b740f282d03f69d99ae212c7ffef9f474 Mon Sep 17 00:00:00 2001 From: fyng Date: Mon, 29 Jan 2024 14:19:56 -0500 Subject: [PATCH 15/42] create two test environment files --- devtools/conda-envs/test_env-nightly.yaml | 24 +++++++++++++++++++++++ devtools/conda-envs/test_env.yaml | 8 ++------ 2 files changed, 26 insertions(+), 6 deletions(-) create mode 100644 devtools/conda-envs/test_env-nightly.yaml diff --git a/devtools/conda-envs/test_env-nightly.yaml b/devtools/conda-envs/test_env-nightly.yaml new file mode 100644 index 0000000..ba2d308 --- /dev/null +++ b/devtools/conda-envs/test_env-nightly.yaml @@ -0,0 +1,24 @@ +name: test +channels: + - conda-forge +dependencies: + - pip + - pytorch + - pytorch_cluster + - pytorch_scatter + - pytorch_sparse + - numpy + - h5py + - e3nn + - dgllife + - dgl + - rdkit + - ase + # testing dependencies + - pytest + - 
pytest-cov + - codecov + - pydantic >=1.10.8,<2.0.0a0 + + - pip: + - pyg-nightly \ No newline at end of file diff --git a/devtools/conda-envs/test_env.yaml b/devtools/conda-envs/test_env.yaml index 35ee854..9704c0f 100644 --- a/devtools/conda-envs/test_env.yaml +++ b/devtools/conda-envs/test_env.yaml @@ -2,9 +2,8 @@ name: test channels: - conda-forge dependencies: - - pip - pytorch - # - pytorch_geometric + - pytorch_geometric - pytorch_cluster - pytorch_scatter - pytorch_sparse @@ -19,7 +18,4 @@ dependencies: - pytest - pytest-cov - codecov - - pydantic >=1.10.8,<2.0.0a0 - - - pip: - - pyg-nightly \ No newline at end of file + - pydantic >=1.10.8,<2.0.0a0 \ No newline at end of file From 25be21317fd15203a08b23a1dd32b6e5c3ef2ce3 Mon Sep 17 00:00:00 2001 From: fyng Date: Mon, 29 Jan 2024 14:22:59 -0500 Subject: [PATCH 16/42] update CI to test stable and nightly builds --- .github/workflows/CI.yaml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/.github/workflows/CI.yaml b/.github/workflows/CI.yaml index b3900d6..34f0562 100644 --- a/.github/workflows/CI.yaml +++ b/.github/workflows/CI.yaml @@ -29,6 +29,7 @@ jobs: matrix: os: [macOS-latest, ubuntu-latest] python-version: ["3.10", "3.11"] + deps: ["test_env.yaml", "test_env-nightly.yaml"] steps: - name: Checkout Repository @@ -45,7 +46,7 @@ jobs: - name: Setup Micromamba uses: mamba-org/setup-micromamba@v1 with: - environment-file: devtools/conda-envs/test_env.yaml + environment-file: ${{ matrix.deps }} environment-name: test create-args: >- python==${{ matrix.python-version }} From 73765c4cf81957c1e573ea0e136baf8c4207a82a Mon Sep 17 00:00:00 2001 From: fyng Date: Mon, 29 Jan 2024 14:27:51 -0500 Subject: [PATCH 17/42] fix CI --- .github/workflows/CI.yaml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/CI.yaml b/.github/workflows/CI.yaml index 34f0562..400df22 100644 --- a/.github/workflows/CI.yaml +++ b/.github/workflows/CI.yaml @@ -22,14 +22,14 @@ 
defaults: jobs: test: - name: Test on ${{ matrix.os }}, Python ${{ matrix.python-version }} + name: Test on ${{ matrix.os }}, Python ${{ matrix.python-version }}, Env ${{ matrix.deps }} runs-on: ${{ matrix.os }} strategy: fail-fast: false matrix: os: [macOS-latest, ubuntu-latest] python-version: ["3.10", "3.11"] - deps: ["test_env.yaml", "test_env-nightly.yaml"] + deps: ["devtools/conda-envs/test_env.yaml", "devtools/conda-envs/test_env-nightly.yaml"] steps: - name: Checkout Repository @@ -69,5 +69,5 @@ jobs: with: file: ./coverage.xml flags: unittests - name: codecov-${{ matrix.os }}-py${{ matrix.python-version }} + name: codecov-${{ matrix.os }}-py${{ matrix.python-version }}-enc${{ matrix.deps }} fail_ci_if_error: false From f92491ca127746bcf4881943b6cff003d678abd9 Mon Sep 17 00:00:00 2001 From: fyng Date: Mon, 29 Jan 2024 14:30:57 -0500 Subject: [PATCH 18/42] add pytest import --- mtenn/tests/test_model_config.py | 1 + 1 file changed, 1 insertion(+) diff --git a/mtenn/tests/test_model_config.py b/mtenn/tests/test_model_config.py index ebf91f0..b6b0b8d 100644 --- a/mtenn/tests/test_model_config.py +++ b/mtenn/tests/test_model_config.py @@ -1,5 +1,6 @@ from mtenn.config import GATModelConfig, E3NNModelConfig, SchNetModelConfig, ViSNetModelConfig from mtenn.conversion_utils.visnet import HAS_VISNET +import pytest def test_random_seed_gat(): rand_config = GATModelConfig() From eb816263c8257d7c3813508c52745de76c78487b Mon Sep 17 00:00:00 2001 From: fyng Date: Mon, 29 Jan 2024 16:05:15 -0500 Subject: [PATCH 19/42] fix typo and style --- .github/workflows/CI.yaml | 2 +- mtenn/conversion_utils/visnet.py | 30 +++++++++++++++++------------- 2 files changed, 18 insertions(+), 14 deletions(-) diff --git a/.github/workflows/CI.yaml b/.github/workflows/CI.yaml index 400df22..41be6da 100644 --- a/.github/workflows/CI.yaml +++ b/.github/workflows/CI.yaml @@ -69,5 +69,5 @@ jobs: with: file: ./coverage.xml flags: unittests - name: codecov-${{ matrix.os }}-py${{ 
matrix.python-version }}-enc${{ matrix.deps }} + name: codecov-${{ matrix.os }}-py${{ matrix.python-version }}-env${{ matrix.deps }} fail_ci_if_error: false diff --git a/mtenn/conversion_utils/visnet.py b/mtenn/conversion_utils/visnet.py index ebdabbd..a2657ad 100644 --- a/mtenn/conversion_utils/visnet.py +++ b/mtenn/conversion_utils/visnet.py @@ -9,27 +9,31 @@ from mtenn.model import GroupedModel, Model from mtenn.strategy import ComplexOnlyStrategy, ConcatStrategy, DeltaStrategy +# guard required: currently require PyG nightly 2.5.0 (SEE ISSUE #123456) HAS_VISNET = False + try: from torch_geometric.nn.models import ViSNet as PygVisNet HAS_VISNET = True except ImportError: - pass -if HAS_VISNET: - class EquivariantVecToScaler(torch.nn.Module): - # Wrapper for PygVisNet.EquivariantScalar to implement forward() method - def __init__(self, mean, reduce_op): - super(EquivariantVecToScaler, self).__init__() - self.mean = mean - self.reduce_op = reduce_op - def forward(self, x): - # dummy variable. all atoms from the same molecule and the same batch - batch = torch.zeros(x.shape[0], dtype=torch.int64, device=x.device) + pass + - y = scatter(x, batch, dim=0, reduce=self.reduce_op) - return y + self.mean +class EquivariantVecToScaler(torch.nn.Module): + # Wrapper for PygVisNet.EquivariantScalar to implement forward() method + def __init__(self, mean, reduce_op): + super(EquivariantVecToScaler, self).__init__() + self.mean = mean + self.reduce_op = reduce_op + def forward(self, x): + # dummy variable. 
all atoms from the same molecule and the same batch + batch = torch.zeros(x.shape[0], dtype=torch.int64, device=x.device) + y = scatter(x, batch, dim=0, reduce=self.reduce_op) + return y + self.mean + +if HAS_VISNET: class ViSNet(torch.nn.Module): def __init__(self, *args, model=None, **kwargs): super().__init__() From 176fccc4ad1f075a54b8e67af3d8c7ad587d0eaa Mon Sep 17 00:00:00 2001 From: fyng Date: Mon, 29 Jan 2024 16:18:18 -0500 Subject: [PATCH 20/42] Comments for Issue 42 --- mtenn/config.py | 2 +- mtenn/conversion_utils/visnet.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/mtenn/config.py b/mtenn/config.py index 7523460..0ef6f38 100644 --- a/mtenn/config.py +++ b/mtenn/config.py @@ -870,5 +870,5 @@ def _build(self, mtenn_params={}): ) else: - raise ImportError("ViSNet not found. Is your PyG >=2.5.0?") + raise ImportError("ViSNet not found. Is your PyG >=2.5.0? Refer to issue #42.") diff --git a/mtenn/conversion_utils/visnet.py b/mtenn/conversion_utils/visnet.py index a2657ad..eda8040 100644 --- a/mtenn/conversion_utils/visnet.py +++ b/mtenn/conversion_utils/visnet.py @@ -9,7 +9,7 @@ from mtenn.model import GroupedModel, Model from mtenn.strategy import ComplexOnlyStrategy, ConcatStrategy, DeltaStrategy -# guard required: currently require PyG nightly 2.5.0 (SEE ISSUE #123456) +# guard required: currently require PyG nightly 2.5.0 (SEE ISSUE #42) HAS_VISNET = False try: From a277bee8647e1d435e18edb64b74a4295c7de0f8 Mon Sep 17 00:00:00 2001 From: fyng Date: Mon, 29 Jan 2024 16:18:28 -0500 Subject: [PATCH 21/42] comments for issue #42 --- mtenn/conversion_utils/__init__.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/mtenn/conversion_utils/__init__.py b/mtenn/conversion_utils/__init__.py index 0d2dcb1..58d8eab 100644 --- a/mtenn/conversion_utils/__init__.py +++ b/mtenn/conversion_utils/__init__.py @@ -1,6 +1,8 @@ from .e3nn import E3NN from .gat import GAT from .schnet import SchNet + +# refer to issue #42 from .visnet import 
HAS_VISNET if HAS_VISNET: from .visnet import ViSNet From c08327779cd8633082833bb3976457ca5f53d2fe Mon Sep 17 00:00:00 2001 From: fyng Date: Mon, 29 Jan 2024 16:19:36 -0500 Subject: [PATCH 22/42] fix a typo --- mtenn/conversion_utils/visnet.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/mtenn/conversion_utils/visnet.py b/mtenn/conversion_utils/visnet.py index eda8040..1872174 100644 --- a/mtenn/conversion_utils/visnet.py +++ b/mtenn/conversion_utils/visnet.py @@ -19,10 +19,10 @@ pass -class EquivariantVecToScaler(torch.nn.Module): +class EquivariantVecToScalar(torch.nn.Module): # Wrapper for PygVisNet.EquivariantScalar to implement forward() method def __init__(self, mean, reduce_op): - super(EquivariantVecToScaler, self).__init__() + super(EquivariantVecToScalar, self).__init__() self.mean = mean self.reduce_op = reduce_op def forward(self, x): @@ -68,7 +68,7 @@ def __init__(self, *args, model=None, **kwargs): self.visnet = PygVisNet(*model_params) self.load_state_dict(model.state_dict()) - self.readout = EquivariantVecToScaler(self.visnet.mean, self.visnet.reduce_op) + self.readout = EquivariantVecToScalar(self.visnet.mean, self.visnet.reduce_op) def forward(self, data): """ From 03d0f4de3941da9daabbe15f93956d5c57fcd054 Mon Sep 17 00:00:00 2001 From: Feiyang Huang <69661474+fyng@users.noreply.github.com> Date: Tue, 30 Jan 2024 17:18:12 -0500 Subject: [PATCH 23/42] Fix a typo Co-authored-by: kaminow <51923685+kaminow@users.noreply.github.com> --- mtenn/conversion_utils/visnet.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mtenn/conversion_utils/visnet.py b/mtenn/conversion_utils/visnet.py index 1872174..6079611 100644 --- a/mtenn/conversion_utils/visnet.py +++ b/mtenn/conversion_utils/visnet.py @@ -37,7 +37,7 @@ def forward(self, x): class ViSNet(torch.nn.Module): def __init__(self, *args, model=None, **kwargs): super().__init__() - ## If no model is passed, construct default SchNet model, otherwise copy + ## If no 
model is passed, construct default ViSNet model, otherwise copy ## all parameters and weights over if model is None: self.visnet = PygVisNet(*args, **kwargs) From 6833c21f347754ac9ed734f54a8f6028ad00fa66 Mon Sep 17 00:00:00 2001 From: Feiyang Huang <69661474+fyng@users.noreply.github.com> Date: Tue, 30 Jan 2024 17:18:49 -0500 Subject: [PATCH 24/42] Visnet mean, std cannot be None Co-authored-by: kaminow <51923685+kaminow@users.noreply.github.com> --- mtenn/config.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mtenn/config.py b/mtenn/config.py index 0ef6f38..77d3994 100644 --- a/mtenn/config.py +++ b/mtenn/config.py @@ -794,8 +794,8 @@ class ViSNetModelConfig(ModelConfigBase): "sum", description="The type of reduction operation to apply. ['sum', 'mean']" ) - mean: float | None = Field(0.0, description="The mean of the output distribution.") - std: float | None = Field(1.0, description="The standard deviation of the output distribution.") + mean: float = Field(0.0, description="The mean of the output distribution.") + std: float = Field(1.0, description="The standard deviation of the output distribution.") derivative: bool = Field( False, description="Whether to compute the derivative of the output with respect to the positions." 
From 017460d0377d94af502b6a0d2f7b590dccd54fdc Mon Sep 17 00:00:00 2001 From: Feiyang Huang <69661474+fyng@users.noreply.github.com> Date: Wed, 31 Jan 2024 10:02:43 -0500 Subject: [PATCH 25/42] VisNet accepts atomref = none Co-authored-by: kaminow <51923685+kaminow@users.noreply.github.com> --- mtenn/conversion_utils/visnet.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/mtenn/conversion_utils/visnet.py b/mtenn/conversion_utils/visnet.py index 6079611..6b6df43 100644 --- a/mtenn/conversion_utils/visnet.py +++ b/mtenn/conversion_utils/visnet.py @@ -42,10 +42,7 @@ def __init__(self, *args, model=None, **kwargs): if model is None: self.visnet = PygVisNet(*args, **kwargs) else: - try: - atomref = model.atomref.weight.detach().clone() - except AttributeError: - atomref = None + atomref = model.prior_model.atomref.weight.detach().clone() model_params = ( model.lmax, model.vecnorm_type, From 2c9bb9980d8e5ebe0ec83df3981f8a3fd842a663 Mon Sep 17 00:00:00 2001 From: fyng Date: Wed, 31 Jan 2024 10:03:26 -0500 Subject: [PATCH 26/42] atomref should match max_z --- mtenn/config.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/mtenn/config.py b/mtenn/config.py index 0ef6f38..e55661e 100644 --- a/mtenn/config.py +++ b/mtenn/config.py @@ -786,8 +786,7 @@ class ViSNetModelConfig(ModelConfigBase): atomref: list[float] | None = Field( None, description=( - "Reference values for single-atom properties. Should have length of 100 to " - "match with PyG." + "Reference values for single-atom properties. Should have length max_z" ) ) reduce_op: str = Field( @@ -808,8 +807,8 @@ def validate(cls, values): # Make sure atomref length is correct (this is required by PyG) atomref = values["atomref"] - if (atomref is not None) and (len(atomref) != 100): - raise ValueError(f"atomref must be length 100 (got {len(atomref)})") + if (atomref is not None) and (len(atomref) != values["max_z"]): + raise ValueError(f"atomref length must match max_z. 
(Expected {values['max_z']}, got {len(atomref)})") return values From 48721fb73af035794cd7e7738a49901c272ae789 Mon Sep 17 00:00:00 2001 From: fyng Date: Wed, 31 Jan 2024 10:36:52 -0500 Subject: [PATCH 27/42] clean up todo --- mtenn/conversion_utils/visnet.py | 1 - 1 file changed, 1 deletion(-) diff --git a/mtenn/conversion_utils/visnet.py b/mtenn/conversion_utils/visnet.py index 6b6df43..67644d2 100644 --- a/mtenn/conversion_utils/visnet.py +++ b/mtenn/conversion_utils/visnet.py @@ -85,7 +85,6 @@ def forward(self, data): z = data["z"] # all atom in one pass from the same molecule - # TODO: set separate batch for ligand and protein batch = torch.zeros(z.shape[0], device=z.device) x, v = self.visnet.representation_model(z, pos, batch) x = self.visnet.output_model.pre_reduce(x, v) From ffa571cee4472a50e4757c8f9b6fc1df8e8d8203 Mon Sep 17 00:00:00 2001 From: Feiyang Huang <69661474+fyng@users.noreply.github.com> Date: Wed, 31 Jan 2024 15:23:27 -0500 Subject: [PATCH 28/42] bring in prior_model from PyG visnet Co-authored-by: kaminow <51923685+kaminow@users.noreply.github.com> --- mtenn/conversion_utils/visnet.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/mtenn/conversion_utils/visnet.py b/mtenn/conversion_utils/visnet.py index 67644d2..e1508aa 100644 --- a/mtenn/conversion_utils/visnet.py +++ b/mtenn/conversion_utils/visnet.py @@ -90,6 +90,9 @@ def forward(self, data): x = self.visnet.output_model.pre_reduce(x, v) x = x * self.visnet.std + if self.visnet.prior_model is not None: + x = self.visnet.prior_model(x, z) + return x def _get_representation(self): From ec3c0c24b76c689ecc15e68bd50df1c16f5a76fc Mon Sep 17 00:00:00 2001 From: fyng Date: Wed, 31 Jan 2024 15:25:15 -0500 Subject: [PATCH 29/42] fix indentation --- mtenn/conversion_utils/visnet.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/mtenn/conversion_utils/visnet.py b/mtenn/conversion_utils/visnet.py index e1508aa..005080c 100644 --- a/mtenn/conversion_utils/visnet.py +++ 
b/mtenn/conversion_utils/visnet.py @@ -89,9 +89,8 @@ def forward(self, data): x, v = self.visnet.representation_model(z, pos, batch) x = self.visnet.output_model.pre_reduce(x, v) x = x * self.visnet.std - - if self.visnet.prior_model is not None: - x = self.visnet.prior_model(x, z) + if self.visnet.prior_model is not None: + x = self.visnet.prior_model(x, z) return x From 34c8ace053ea53fa73724125535e49e9afc1b556 Mon Sep 17 00:00:00 2001 From: fyng Date: Thu, 1 Feb 2024 14:34:49 -0500 Subject: [PATCH 30/42] add a visnet test --- mtenn/tests/test_model_config.py | 45 +++++++++++++++++++++++++++++++- 1 file changed, 44 insertions(+), 1 deletion(-) diff --git a/mtenn/tests/test_model_config.py b/mtenn/tests/test_model_config.py index b6b0b8d..63efd65 100644 --- a/mtenn/tests/test_model_config.py +++ b/mtenn/tests/test_model_config.py @@ -87,4 +87,47 @@ def test_random_seed_visnet(): (p1 == p2).all() for p1, p2 in zip(set_model1.parameters(), set_model2.parameters()) ] - assert sum(set_equal) == len(set_equal) \ No newline at end of file + assert sum(set_equal) == len(set_equal) + +@pytest.mark.skipif(not HAS_VISNET, reason="requires VisNet from nightly PyG") +def test_random_seed_visnet_from_pyg(): + from torch_geometric.nn.models import ViSNet as PyVisNet + model = PyVisNet( + lmax=1, + vecnorm_type=None, + trainable_vecnorm=False, + num_heads=8, + num_layers=6, + hidden_channels=128, + num_rbf=32, + trainable_rbf=False, + max_z=100, + cutoff=5.0, + max_num_neighbors=32, + vertex=False, + reduce_op="sum", + mean=0.0, + std=1.0, + derivative=False, + atomref=None, + ) + + rand_config = ViSNetModelConfig(model=model) + set_config = ViSNetModelConfig(rand_seed=10) + + rand_model1 = rand_config.build() + rand_model2 = rand_config.build() + set_model1 = set_config.build() + set_model2 = set_config.build() + + rand_equal = [ + (p1 == p2).all() + for p1, p2 in zip(rand_model1.parameters(), rand_model2.parameters()) + ] + assert sum(rand_equal) < len(rand_equal) + + set_equal 
= [ + (p1 == p2).all() + for p1, p2 in zip(set_model1.parameters(), set_model2.parameters()) + ] + assert sum(set_equal) == len(set_equal) From ba919381f9d442fec4afeedc0cde32281db8beb0 Mon Sep 17 00:00:00 2001 From: fyng Date: Mon, 5 Feb 2024 11:49:14 -0500 Subject: [PATCH 31/42] add import warning to visnet import guard --- mtenn/conversion_utils/visnet.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mtenn/conversion_utils/visnet.py b/mtenn/conversion_utils/visnet.py index 005080c..94de37e 100644 --- a/mtenn/conversion_utils/visnet.py +++ b/mtenn/conversion_utils/visnet.py @@ -1,6 +1,7 @@ """ Representation and strategy for ViSNet model. """ +import warning from copy import deepcopy import torch from torch.autograd import grad @@ -16,8 +17,7 @@ from torch_geometric.nn.models import ViSNet as PygVisNet HAS_VISNET = True except ImportError: - pass - + warning.warn("VisNet import error. Is your PyG >=2.5.0? Refer to issue #42", ImportWarning) class EquivariantVecToScalar(torch.nn.Module): # Wrapper for PygVisNet.EquivariantScalar to implement forward() method From a406012eb978108ed11837d7002c91c932cf5266 Mon Sep 17 00:00:00 2001 From: fyng Date: Mon, 5 Feb 2024 11:58:46 -0500 Subject: [PATCH 32/42] fix typo --- mtenn/conversion_utils/visnet.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mtenn/conversion_utils/visnet.py b/mtenn/conversion_utils/visnet.py index 94de37e..3161279 100644 --- a/mtenn/conversion_utils/visnet.py +++ b/mtenn/conversion_utils/visnet.py @@ -1,7 +1,7 @@ """ Representation and strategy for ViSNet model. """ -import warning +import warnings from copy import deepcopy import torch from torch.autograd import grad @@ -17,7 +17,7 @@ from torch_geometric.nn.models import ViSNet as PygVisNet HAS_VISNET = True except ImportError: - warning.warn("VisNet import error. Is your PyG >=2.5.0? Refer to issue #42", ImportWarning) + warnings.warn("VisNet import error. Is your PyG >=2.5.0? 
Refer to issue #42", ImportWarning) class EquivariantVecToScalar(torch.nn.Module): # Wrapper for PygVisNet.EquivariantScalar to implement forward() method From 5f161a3055c58dfef67205400fe325b7ccbeca1f Mon Sep 17 00:00:00 2001 From: fyng Date: Mon, 5 Feb 2024 16:18:38 -0500 Subject: [PATCH 33/42] remove redundant visnet set_config test --- mtenn/tests/test_model_config.py | 12 +----------- 1 file changed, 1 insertion(+), 11 deletions(-) diff --git a/mtenn/tests/test_model_config.py b/mtenn/tests/test_model_config.py index 63efd65..22c7542 100644 --- a/mtenn/tests/test_model_config.py +++ b/mtenn/tests/test_model_config.py @@ -113,21 +113,11 @@ def test_random_seed_visnet_from_pyg(): ) rand_config = ViSNetModelConfig(model=model) - set_config = ViSNetModelConfig(rand_seed=10) rand_model1 = rand_config.build() - rand_model2 = rand_config.build() - set_model1 = set_config.build() - set_model2 = set_config.build() rand_equal = [ (p1 == p2).all() - for p1, p2 in zip(rand_model1.parameters(), rand_model2.parameters()) + for p1, p2 in zip(rand_model1.parameters(), model.parameters()) ] assert sum(rand_equal) < len(rand_equal) - - set_equal = [ - (p1 == p2).all() - for p1, p2 in zip(set_model1.parameters(), set_model2.parameters()) - ] - assert sum(set_equal) == len(set_equal) From d14f65db9d5ce4607b84db257f37e442623ad9d1 Mon Sep 17 00:00:00 2001 From: fyng Date: Wed, 7 Feb 2024 13:41:23 -0500 Subject: [PATCH 34/42] fix typo in docstring --- mtenn/conversion_utils/visnet.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/mtenn/conversion_utils/visnet.py b/mtenn/conversion_utils/visnet.py index 3161279..bcba217 100644 --- a/mtenn/conversion_utils/visnet.py +++ b/mtenn/conversion_utils/visnet.py @@ -170,13 +170,13 @@ def get_model( comb_readout=None, ): """ - Exposed function to build a Model object from a SchNet object. If none + Exposed function to build a Model object from a VisNet object. If none is provided, a default model is initialized. 
Parameters ---------- - model: SchNet, optional - SchNet model to use to build the Model object. If left as none, a + model: VisNet, optional + VisNet model to use to build the Model object. If left as none, a default model will be initialized and used grouped: bool, default=False Whether this model should accept groups of inputs or one input at a From 60aa2453f86fc7d53ccba6b01b39fb6252ab9207 Mon Sep 17 00:00:00 2001 From: fyng Date: Wed, 7 Feb 2024 13:53:20 -0500 Subject: [PATCH 35/42] fix visnet instantiation from pyg test --- mtenn/tests/test_model_config.py | 52 ++++++++++++++++---------------- 1 file changed, 26 insertions(+), 26 deletions(-) diff --git a/mtenn/tests/test_model_config.py b/mtenn/tests/test_model_config.py index 22c7542..d9ef436 100644 --- a/mtenn/tests/test_model_config.py +++ b/mtenn/tests/test_model_config.py @@ -90,34 +90,34 @@ def test_random_seed_visnet(): assert sum(set_equal) == len(set_equal) @pytest.mark.skipif(not HAS_VISNET, reason="requires VisNet from nightly PyG") -def test_random_seed_visnet_from_pyg(): +def test_visnet_from_pyg(): from torch_geometric.nn.models import ViSNet as PyVisNet - model = PyVisNet( - lmax=1, - vecnorm_type=None, - trainable_vecnorm=False, - num_heads=8, - num_layers=6, - hidden_channels=128, - num_rbf=32, - trainable_rbf=False, - max_z=100, - cutoff=5.0, - max_num_neighbors=32, - vertex=False, - reduce_op="sum", - mean=0.0, - std=1.0, - derivative=False, - atomref=None, - ) - - rand_config = ViSNetModelConfig(model=model) - - rand_model1 = rand_config.build() + from mtenn.conversion_utils.visnet import ViSNet + kwargs={ + 'lmax': 1, + 'vecnorm_type': None, + 'trainable_vecnorm': False, + 'num_heads': 8, + 'num_layers': 6, + 'hidden_channels': 128, + 'num_rbf': 32, + 'trainable_rbf': False, + 'max_z': 100, + 'cutoff': 5.0, + 'max_num_neighbors': 32, + 'vertex': False, + 'reduce_op': "sum", + 'mean': 0.0, + 'std': 1.0, + 'derivative': False, + 'atomref': None, + } + + pyg_model = PyVisNet(**kwargs) + 
visnet_model = ViSNet(pyg_model) rand_equal = [ (p1 == p2).all() - for p1, p2 in zip(rand_model1.parameters(), model.parameters()) + for p1, p2 in zip(pyg_model.parameters(), visnet_model.parameters()) ] - assert sum(rand_equal) < len(rand_equal) + assert sum(rand_equal) == len(rand_equal) From c85b995536158f28d02ec56b2af0234e9ee6374b Mon Sep 17 00:00:00 2001 From: Feiyang Huang <69661474+fyng@users.noreply.github.com> Date: Fri, 9 Feb 2024 00:43:14 -0500 Subject: [PATCH 36/42] Update mtenn/tests/test_model_config.py Co-authored-by: kaminow <51923685+kaminow@users.noreply.github.com> --- mtenn/tests/test_model_config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mtenn/tests/test_model_config.py b/mtenn/tests/test_model_config.py index d9ef436..b158084 100644 --- a/mtenn/tests/test_model_config.py +++ b/mtenn/tests/test_model_config.py @@ -114,7 +114,7 @@ def test_visnet_from_pyg(): } pyg_model = PyVisNet(**kwargs) - visnet_model = ViSNet(pyg_model) + visnet_model = ViSNet(model=pyg_model) rand_equal = [ (p1 == p2).all() From ec8b6c236802dc930067b7d4d3509a7629b94b83 Mon Sep 17 00:00:00 2001 From: Feiyang Huang <69661474+fyng@users.noreply.github.com> Date: Fri, 9 Feb 2024 00:44:41 -0500 Subject: [PATCH 37/42] minor changes --- mtenn/tests/test_model_config.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mtenn/tests/test_model_config.py b/mtenn/tests/test_model_config.py index b158084..e5f5096 100644 --- a/mtenn/tests/test_model_config.py +++ b/mtenn/tests/test_model_config.py @@ -116,8 +116,8 @@ def test_visnet_from_pyg(): pyg_model = PyVisNet(**kwargs) visnet_model = ViSNet(model=pyg_model) - rand_equal = [ + params_equal = [ (p1 == p2).all() for p1, p2 in zip(pyg_model.parameters(), visnet_model.parameters()) ] - assert sum(rand_equal) == len(rand_equal) + assert sum(params_equal) == len(params_equal) From 4a539196cdbef5bdf3a4a55dd5a9c337b072ecb2 Mon Sep 17 00:00:00 2001 From: fyng Date: Fri, 9 Feb 2024 13:50:54 
-0500 Subject: [PATCH 38/42] update visnet instantiation of pyg and test --- mtenn/conversion_utils/visnet.py | 40 ++++++++++++++++---------------- mtenn/tests/test_model_config.py | 8 +++---- 2 files changed, 24 insertions(+), 24 deletions(-) diff --git a/mtenn/conversion_utils/visnet.py b/mtenn/conversion_utils/visnet.py index bcba217..2c322c7 100644 --- a/mtenn/conversion_utils/visnet.py +++ b/mtenn/conversion_utils/visnet.py @@ -43,26 +43,26 @@ def __init__(self, *args, model=None, **kwargs): self.visnet = PygVisNet(*args, **kwargs) else: atomref = model.prior_model.atomref.weight.detach().clone() - model_params = ( - model.lmax, - model.vecnorm_type, - model.trainable_vecnorm, - model.num_heads, - model.num_layers, - model.hidden_channels, - model.num_rbf, - model.trainable_rbf, - model.max_z, - model.cutoff, - model.max_num_neighbors, - model.vertex, - model.reduce_op, - model.mean, - model.std, - model.derivative, # not used. originally calculates "force" from energy - atomref, - ) - self.visnet = PygVisNet(*model_params) + model_params = { + 'lmax': model.representation_model.lmax, + 'vecnorm_type': model.representation_model.vecnorm_type, + 'trainable_vecnorm': model.representation_model.trainable_vecnorm, + 'num_heads': model.representation_model.num_heads, + 'num_layers': model.representation_model.num_layers, + 'hidden_channels': model.representation_model.hidden_channels, + 'num_rbf': model.representation_model.num_rbf, + 'trainable_rbf': model.representation_model.trainable_rbf, + 'max_z': model.representation_model.max_z, + 'cutoff': model.representation_model.cutoff, + 'reduce_op': model.representation_model.max_num_neighbors, + 'vertex': True if model.representation_model.vis_mp_layers[0].__class__.__name__ == 'ViS_MP_Vertex' else False, + 'reduce_op': model.reduce_op, + 'mean': model.mean, + 'std': model.std, + 'derivative': model.derivative, # not used. 
originally calculates "force" from energy + 'atomref': atomref, + } + self.visnet = PygVisNet(*model_params,**kwargs) self.load_state_dict(model.state_dict()) self.readout = EquivariantVecToScalar(self.visnet.mean, self.visnet.reduce_op) diff --git a/mtenn/tests/test_model_config.py b/mtenn/tests/test_model_config.py index e5f5096..338a1d9 100644 --- a/mtenn/tests/test_model_config.py +++ b/mtenn/tests/test_model_config.py @@ -92,8 +92,8 @@ def test_random_seed_visnet(): @pytest.mark.skipif(not HAS_VISNET, reason="requires VisNet from nightly PyG") def test_visnet_from_pyg(): from torch_geometric.nn.models import ViSNet as PyVisNet - from mtenn.conversion_utils.visnet import ViSNet - kwargs={ + from mtenn.conversion_utils import ViSNet + args={ 'lmax': 1, 'vecnorm_type': None, 'trainable_vecnorm': False, @@ -113,8 +113,8 @@ def test_visnet_from_pyg(): 'atomref': None, } - pyg_model = PyVisNet(**kwargs) - visnet_model = ViSNet(model=pyg_model) + pyg_model = PyVisNet(*args) + visnet_model = ViSNet(model=pyg_model, **kwargs) params_equal = [ (p1 == p2).all() From 3458514d8080ea0f67bae66bdb85a7b3b089b2cd Mon Sep 17 00:00:00 2001 From: fyng Date: Fri, 9 Feb 2024 15:07:04 -0500 Subject: [PATCH 39/42] fix mtenn visnet instantiation from pyg visnet --- mtenn/conversion_utils/visnet.py | 8 ++++---- mtenn/tests/test_model_config.py | 6 +++--- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/mtenn/conversion_utils/visnet.py b/mtenn/conversion_utils/visnet.py index 2c322c7..4b96f23 100644 --- a/mtenn/conversion_utils/visnet.py +++ b/mtenn/conversion_utils/visnet.py @@ -6,6 +6,7 @@ import torch from torch.autograd import grad from torch_geometric.utils import scatter +from torch_geometric.nn.models.visnet import ViS_MP_Vertex from mtenn.model import GroupedModel, Model from mtenn.strategy import ComplexOnlyStrategy, ConcatStrategy, DeltaStrategy @@ -28,7 +29,6 @@ def __init__(self, mean, reduce_op): def forward(self, x): # dummy variable. 
all atoms from the same molecule and the same batch batch = torch.zeros(x.shape[0], dtype=torch.int64, device=x.device) - y = scatter(x, batch, dim=0, reduce=self.reduce_op) return y + self.mean @@ -55,15 +55,15 @@ def __init__(self, *args, model=None, **kwargs): 'max_z': model.representation_model.max_z, 'cutoff': model.representation_model.cutoff, 'reduce_op': model.representation_model.max_num_neighbors, - 'vertex': True if model.representation_model.vis_mp_layers[0].__class__.__name__ == 'ViS_MP_Vertex' else False, + 'vertex': isinstance(model.representation_model.vis_mp_layers[0], ViS_MP_Vertex), 'reduce_op': model.reduce_op, 'mean': model.mean, 'std': model.std, 'derivative': model.derivative, # not used. originally calculates "force" from energy 'atomref': atomref, } - self.visnet = PygVisNet(*model_params,**kwargs) - self.load_state_dict(model.state_dict()) + self.visnet = PygVisNet(**model_params) + self.visnet.load_state_dict(model.state_dict()) self.readout = EquivariantVecToScalar(self.visnet.mean, self.visnet.reduce_op) diff --git a/mtenn/tests/test_model_config.py b/mtenn/tests/test_model_config.py index 338a1d9..2ef1919 100644 --- a/mtenn/tests/test_model_config.py +++ b/mtenn/tests/test_model_config.py @@ -93,7 +93,7 @@ def test_random_seed_visnet(): def test_visnet_from_pyg(): from torch_geometric.nn.models import ViSNet as PyVisNet from mtenn.conversion_utils import ViSNet - args={ + model_params={ 'lmax': 1, 'vecnorm_type': None, 'trainable_vecnorm': False, @@ -113,8 +113,8 @@ def test_visnet_from_pyg(): 'atomref': None, } - pyg_model = PyVisNet(*args) - visnet_model = ViSNet(model=pyg_model, **kwargs) + pyg_model = PyVisNet(**model_params) + visnet_model = ViSNet(model=pyg_model) params_equal = [ (p1 == p2).all() From cd2086bcf49efbc0a1e52583d200c2c707d21e18 Mon Sep 17 00:00:00 2001 From: fyng Date: Fri, 9 Feb 2024 15:10:50 -0500 Subject: [PATCH 40/42] guard import --- mtenn/conversion_utils/visnet.py | 2 +- 1 file changed, 1 insertion(+), 1 
deletion(-) diff --git a/mtenn/conversion_utils/visnet.py b/mtenn/conversion_utils/visnet.py index 4b96f23..feee33e 100644 --- a/mtenn/conversion_utils/visnet.py +++ b/mtenn/conversion_utils/visnet.py @@ -6,7 +6,6 @@ import torch from torch.autograd import grad from torch_geometric.utils import scatter -from torch_geometric.nn.models.visnet import ViS_MP_Vertex from mtenn.model import GroupedModel, Model from mtenn.strategy import ComplexOnlyStrategy, ConcatStrategy, DeltaStrategy @@ -16,6 +15,7 @@ try: from torch_geometric.nn.models import ViSNet as PygVisNet + from torch_geometric.nn.models.visnet import ViS_MP_Vertex HAS_VISNET = True except ImportError: warnings.warn("VisNet import error. Is your PyG >=2.5.0? Refer to issue #42", ImportWarning) From d5c1b4685c6200c5c2411eb2d8e6ba5756644060 Mon Sep 17 00:00:00 2001 From: fyng Date: Mon, 12 Feb 2024 14:54:39 -0500 Subject: [PATCH 41/42] update docstrings --- mtenn/conversion_utils/visnet.py | 53 ++++++++++---------------------- 1 file changed, 17 insertions(+), 36 deletions(-) diff --git a/mtenn/conversion_utils/visnet.py b/mtenn/conversion_utils/visnet.py index feee33e..cd1e079 100644 --- a/mtenn/conversion_utils/visnet.py +++ b/mtenn/conversion_utils/visnet.py @@ -72,14 +72,15 @@ def forward(self, data): Computes the energies or properties (forces) for a batch of molecules. - Args: - z (torch.Tensor): The atomic numbers. - pos (torch.Tensor): The coordinates of the atoms. - batch (torch.Tensor): A batch vector, - which assigns each node to a specific example. - - Returns: - x (torch.Tensor): Scalar output based on node features and vector features. + Args + ------- + Data. A dictionary of atomic point clouds. Contains the following fields: + z (torch.Tensor): The atomic numbers. + pos (torch.Tensor): The coordinates of the atoms. + + Returns + ------- + x (torch.Tensor): vector output based on node features and vector features. 
""" pos = data["pos"] z = data["z"] @@ -98,15 +99,10 @@ def _get_representation(self): """ Input model, remove last layer. - Parameters - ---------- - model: SchNet - SchNet model - Returns ------- - SchNet - Copied SchNet model with the last layer replaced by an Identity module + ViSNet + Copied ViSNet model, removing the last MLP layer that takes vector representation to scalar output. """ ## Copy model so initial model isn't affected @@ -118,15 +114,10 @@ def _get_energy_func(self): """ Return last layer of the model (outputs scalar value) - Parameters - ---------- - model: SchNet - SchNet model - Returns ------- - torch.nn.modules.linear.Linear - Copy of `model`'s last layer + torch.nn.Module + Copy of `model`'s last layer, which is an instance of EquivariantVecToScalar() class """ return deepcopy(self.readout) @@ -134,11 +125,6 @@ def _get_delta_strategy(self): """ Build a DeltaStrategy object based on the passed model. - Parameters - ---------- - model: SchNet - SchNet model - Returns ------- DeltaStrategy @@ -176,11 +162,9 @@ def get_model( Parameters ---------- model: VisNet, optional - VisNet model to use to build the Model object. If left as none, a - default model will be initialized and used + VisNet model to use to build the Model object. If left as none, a default model will be initialized and used grouped: bool, default=False - Whether this model should accept groups of inputs or one input at a - time. + Whether this model should accept groups of inputs or one input at a time. fix_device: bool, default=False If True, make sure the input is on the same device as the model, copying over as necessary. @@ -188,12 +172,9 @@ def get_model( Strategy to use to combine representation of the different parts. Options are ['delta', 'concat', 'complex'] combination: Combination, optional - Combination object to use to combine predictions in a group. A value - must be passed if `grouped` is `True`. + Combination object to use to combine predictions in a group. 
A value must be passed if `grouped` is `True`. pred_readout : Readout - Readout object for the energy predictions. If `grouped` is `False`, - this option will still be used in the construction of the `Model` - object. + Readout object for the energy predictions. If `grouped` is `False`, this option will still be used in the construction of the `Model` object. comb_readout : Readout Readout object for the combination output. From 3924d3f96ec448a71af49b6d69ef1d84f546547d Mon Sep 17 00:00:00 2001 From: fyng Date: Mon, 12 Feb 2024 15:11:30 -0500 Subject: [PATCH 42/42] update doc strings --- mtenn/conversion_utils/visnet.py | 21 +++++++++++++-------- 1 file changed, 13 insertions(+), 8 deletions(-) diff --git a/mtenn/conversion_utils/visnet.py b/mtenn/conversion_utils/visnet.py index cd1e079..ac31032 100644 --- a/mtenn/conversion_utils/visnet.py +++ b/mtenn/conversion_utils/visnet.py @@ -69,18 +69,23 @@ def __init__(self, *args, model=None, **kwargs): def forward(self, data): """ - Computes the energies or properties (forces) for a batch of - molecules. + Computes the vector representation of a molecule from its atomic numbers and positional coordinates. The output vector can be used for further molecular analysis or prediction tasks. - Args - ------- - Data. A dictionary of atomic point clouds. Contains the following fields: - z (torch.Tensor): The atomic numbers. - pos (torch.Tensor): The coordinates of the atoms. + Parameters + ---------- + data : dict[str, torch.Tensor] + A dictionary containing the atomic point clouds of a molecule. It should have the following keys: + - z (torch.Tensor): The atomic numbers of the atoms in the molecule. + - pos (torch.Tensor): The 3D coordinates (x, y, z) of each atom in the molecule. Returns ------- - x (torch.Tensor): vector output based on node features and vector features.
+ torch.Tensor + A tensor representing the vector output of the molecule. + + Notes + ----- + Assumes all atoms passed in `data` belong to the same molecule and are processed in a single batch. """ pos = data["pos"] z = data["z"]