From f212e3d2afeb4998798dfd27885155f8f44b07f8 Mon Sep 17 00:00:00 2001 From: kaminow Date: Thu, 9 Nov 2023 16:51:01 -0500 Subject: [PATCH 01/12] First pass on unification. --- mtenn/conversion_utils/e3nn.py | 6 +++--- mtenn/conversion_utils/schnet.py | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/mtenn/conversion_utils/e3nn.py b/mtenn/conversion_utils/e3nn.py index 188473d..7f303fe 100644 --- a/mtenn/conversion_utils/e3nn.py +++ b/mtenn/conversion_utils/e3nn.py @@ -10,12 +10,12 @@ class E3NN(Network): - def __init__(self, model=None, model_kwargs=None): + def __init__(self, *args, model=None, **kwargs): ## If no model is passed, construct E3NN model with model_kwargs, ## otherwise copy all parameters and weights over if model is None: - super(E3NN, self).__init__(**model_kwargs) - self.model_parameters = model_kwargs + super(E3NN, self).__init__(*args, **kwargs) + self.model_parameters = kwargs else: # this will need changing to include model features of e3nn atomref = model.atomref.weight.detach().clone() diff --git a/mtenn/conversion_utils/schnet.py b/mtenn/conversion_utils/schnet.py index c696915..8beb476 100644 --- a/mtenn/conversion_utils/schnet.py +++ b/mtenn/conversion_utils/schnet.py @@ -10,11 +10,11 @@ class SchNet(PygSchNet): - def __init__(self, model=None): + def __init__(self, *args, model=None, **kwargs): ## If no model is passed, construct default SchNet model, otherwise copy ## all parameters and weights over if model is None: - super(SchNet, self).__init__() + super(SchNet, self).__init__(*args, **kwargs) else: try: atomref = model.atomref.weight.detach().clone() From 0c39b9816ce3aa54d68b845b5d15ef00f2909b40 Mon Sep 17 00:00:00 2001 From: kaminow Date: Wed, 29 Nov 2023 11:31:46 -0500 Subject: [PATCH 02/12] Add config classes. --- mtenn/config.py | 723 ++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 723 insertions(+) create mode 100644 mtenn/config.py diff --git a/mtenn/config.py b/mtenn/config.py new file mode 100644 index 0000000..1e4c755 --- /dev/null +++ b/mtenn/config.py @@ -0,0 +1,723 @@ +from __future__ import annotations + +import abc +from enum import Enum +from pydantic import BaseModel, Field, root_validator +from typing import Callable, ClassVar + +import mtenn + + +class ModelType(str, Enum): + """ + Enum for model types + + GAT: Graph Attention Network + schnet: SchNet + e3nn: E(3)-equivariant neural network + INVALID: Invalid model type to catch instantiation errors + """ + + GAT = "GAT" + schnet = "schnet" + e3nn = "e3nn" + INVALID = "INVALID" + + @classmethod + def get_values(cls) -> list[str]: + return [member.value for member in cls] + + @classmethod + def reverse_lookup(cls, value): + return cls(value) + + @classmethod + def get_names(cls) -> list[str]: + return [member.name for member in cls] + + +class StrategyConfig(str, Enum): + """ + Enum for possible MTENN Strategy classes. + """ + + # delta G strategy + delta = "delta" + # ML concatenation strategy + concat = "concat" + # Complex-only strategy + complex = "complex" + + +class ReadoutConfig(str, Enum): + """ + Enum for possible MTENN Readout classes. + """ + + pic50 = "pic50" + + +class CombinationConfig(str, Enum): + """ + Enum for possible MTENN Readout classes. + """ + + mean = "mean" + max = "max" + boltzmann = "boltzmann" + + +class ModelConfigBase(BaseModel): + model_type: ClassVar[ModelType.INVALID] = ModelType.INVALID + + # Shared parameters for MTENN + grouped: bool = Field(False, description="Model is a grouped (multi-pose) model.") + strategy: StrategyConfig = Field( + StrategyConfig.delta, + description=( + "Which Strategy to use for combining complex, protein, and ligand " + "representations in the MTENN Model. " + f"Options are [{', '.join(StrategyConfig.get_values())}]." + ), + ) + pred_readout: ReadoutConfig | None = Field( + None, + description=( + "Which Readout to use for the model predictions. This corresponds " + "to the individual pose predictions in the case of a GroupedModel. " + f"Options are [{', '.join(ReadoutConfig.get_values())}]." + ), + ) + combination: CombinationConfig | None = Field( + None, + description=( + "Which Combination to use for combining predictions in a GroupedModel. " + f"Options are [{', '.join(CombinationConfig.get_values())}]." + ), + ) + comb_readout: ReadoutConfig | None = Field( + None, + description=( + "Which Readout to use for the combined model predictions. This is only " + "relevant in the case of a GroupedModel. " + f"Options are [{', '.join(ReadoutConfig.get_values())}]." + ), + ) + + # Parameters for MaxCombination + max_comb_neg: bool = Field( + True, + description=( + "Whether to take the min instead of max when combining pose predictions " + "with MaxCombination." + ), + ) + max_comb_scale: float = Field( + 1000, + description=( + "Scaling factor for values when taking the max/min when combining pose " + "predictions with MaxCombination. A value of 1 will approximate the " + "Boltzmann mean, while a larger value will more accurately approximate the " + "max/min operation." + ), + ) + + # Parameters for PIC50Readout for pred_readout + pred_substrate: float | None = Field( + None, + description=( + "Substrate concentration to use when using the Cheng-Prusoff equation to " + "convert deltaG -> IC50 in PIC50Readout for pred_readout. Assumed to be in " + "the same units as pred_km." + ), + ) + pred_km: float | None = Field( + None, + description=( + "Km value to use when using the Cheng-Prusoff equation to convert " + "deltaG -> IC50 in PIC50Readout for pred_readout. Assumed to be in " + "the same units as pred_substrate." + ), + ) + + # Parameters for PIC50Readout for comb_readout + comb_substrate: float | None = Field( + None, + description=( + "Substrate concentration to use when using the Cheng-Prusoff equation to " + "convert deltaG -> IC50 in PIC50Readout for comb_readout. Assumed to be in " + "the same units as comb_km." + ), + ) + comb_km: float | None = Field( + None, + description=( + "Km value to use when using the Cheng-Prusoff equation to convert " + "deltaG -> IC50 in PIC50Readout for comb_readout. Assumed to be in " + "the same units as comb_substrate." + ), + ) + + @abc.abstractmethod + def _build(self, mtenn_params={}) -> mtenn.model.Model: + ... + + def build(self) -> mtenn.model.Model: + # First handle the MTENN classes + match self.combination: + case CombinationConfig.mean: + mtenn_combination = mtenn.combination.MeanCombination() + case CombinationConfig.max: + mtenn_combination = mtenn.combination.MaxCombination( + neg=self.max_comb_neg, scale=self.max_comb_scale + ) + case CombinationConfig.boltzmann: + mtenn_combination = mtenn.combination.BoltzmannCombination() + case None: + mtenn_combination = None + + match self.pred_readout: + case ReadoutConfig.pic50: + mtenn_pred_readout = mtenn.readout.PIC50Readout( + substrate=self.pred_substrate, Km=self.pred_km + ) + case None: + mtenn_pred_readout = None + + match self.comb_readout: + case ReadoutConfig.pic50: + mtenn_comb_readout = mtenn.readout.PIC50Readout( + substrate=self.comb_substrate, Km=self.comb_km + ) + case None: + mtenn_comb_readout = None + + mtenn_params = { + "combination": mtenn_combination, + "pred_readout": mtenn_pred_readout, + "comb_readout": mtenn_comb_readout, + } + + # Build the actual Model + return self._build(mtenn_params) + + def update(self, config_updates={}) -> ModelConfigBase: + return self._update(config_updates) + + def _update(self, config_updates={}) -> ModelConfigBase: + """ + Default version of this function. Just update original config with new options, + and generate new object. Designed to be overloaded if there are specific things + that a class needs to handle (see GATModelConfig as an example). + """ + + orig_config = self.dict() + + # Get new config by overwriting old stuff with any new stuff + new_config = orig_config | config_updates + + return type(self)(**new_config) + + @staticmethod + def _check_grouped(values): + """ + Makes sure that a Combination method is passed if using a GroupedModel. Only + needs to be called for structure-based models. + """ + if values["grouped"] and (not values["combination"]): + raise ValueError("combination must be specified for a GroupedModel.") + + +class GATModelConfig(ModelConfigBase): + """ + Class for constructing a GAT ML model. Note that there are two methods for defining + the size of the model: + * If single values are passed for all parameters, the value of `num_layers` will be + used as the size of the model, and each layer will have the parameters given + * If a list of values is passed for any parameters, all parameters must be lists of + the same size, or single values. For parameters that are single values, that same + value will be used for each layer. For parameters that are lists, those lists will + be used + + Parameters passed as strings are assumed to be comma-separated lists, and will first + be cast to lists of the appropriate type, and then processed as described above. + + If lists of multiple different (non-1) sizes are found, an error will be raised. + + Default values here are the default values given in DGL-LifeSci. + """ + + from dgllife.utils import CanonicalAtomFeaturizer + + LIST_PARAMS: ClassVar[dict] = { + "hidden_feats": int, + "num_heads": int, + "feat_drops": float, + "attn_drops": float, + "alphas": float, + "residuals": bool, + "agg_modes": str, + "activations": None, + "biases": bool, + } + + model_type: ClassVar[ModelType.GAT] = ModelType.GAT + + in_feats: int = Field( + CanonicalAtomFeaturizer().feat_size(), + description=( + "Input node feature size. Defaults to size of the CanonicalAtomFeaturizer." + ), + ) + num_layers: int = Field( + 2, + description=( + "Number of GAT layers. Ignored if a list of values is passed for any " + "other argument." + ), + ) + hidden_feats: str | list[int] = Field( + 32, + description=( + "Output size of each GAT layer. If an int is passed, the value for " + "num_layers will be used to determine the size of the model. If a list of " + "ints is passed, the size of the model will be inferred from the length of " + "the list." + ), + ) + num_heads: str | list[int] = Field( + 4, + description=( + "Number of attention heads for each GAT layer. Passing an int or list of " + "ints functions similarly as for hidden_feats." + ), + ) + feat_drops: str | list[float] = Field( + 0, + description=( + "Dropout of input features for each GAT layer. Passing an float or list of " + "floats functions similarly as for hidden_feats." + ), + ) + attn_drops: str | list[float] = Field( + 0, + description=( + "Dropout of attention values for each GAT layer. Passing an float or list " + "of floats functions similarly as for hidden_feats." + ), + ) + alphas: str | list[float] = Field( + 0.2, + description=( + "Hyperparameter for LeakyReLU gate for each GAT layer. Passing an float or " + "list of floats functions similarly as for hidden_feats." + ), + ) + residuals: str | list[bool] = Field( + True, + description=( + "Whether to use residual connection for each GAT layer. Passing a bool or " + "list of bools functions similarly as for hidden_feats." + ), + ) + agg_modes: str | list[str] = Field( + "flatten", + description=( + "Which aggregation mode [flatten, mean] to use for each GAT layer. " + "Passing a str or list of strs functions similarly as for hidden_feats." + ), + ) + activations: list[Callable] | None = Field( + None, + description=( + "Activation function for each GAT layer. Passing a function or " + "list of functions functions similarly as for hidden_feats." + ), + ) + biases: str | list[bool] = Field( + True, + description=( + "Whether to use bias for each GAT layer. Passing a bool or " + "list of bools functions similarly as for hidden_feats." + ), + ) + allow_zero_in_degree: bool = Field( + False, description="Allow zero in degree nodes for all graph layers." + ) + + # Internal tracker for if the parameters were originally built from lists or using + # num_layers + _from_num_layers = False + + @root_validator(pre=False) + def massage_into_lists(cls, values) -> GATModelConfig: + # First convert string lists to actual lists + for param, param_type in cls.LIST_PARAMS.items(): + param_val = values[param] + if isinstance(param_val, str): + try: + param_val = list(map(param_type, param_val.split(","))) + except ValueError: + raise ValueError( + f"Unable to parse value {param_val} for parameter {param}. " + f"Expected type of {param_type}." + ) + values[param] = param_val + + # Get sizes of all lists + list_lens = {} + for p in cls.LIST_PARAMS: + param_val = values[p] + if not isinstance(param_val, list): + # Shouldn't be possible at this point but just in case + param_val = [param_val] + values[p] = param_val + list_lens[p] = len(param_val) + + # Check that there's only one length present + list_lens_set = set(list_lens.values()) + # This could be 0 if lists of length 1 were passed, which is valid + if len(list_lens_set - {1}) > 1: + raise ValueError( + "All passed parameter lists must be the same value. " + f"Instead got list lengths of: {list_lens}" + ) + elif list_lens_set == {1}: + # If all lists have only one value, we defer to the value passed to + # num_layers, as described in the class docstring + num_layers = values["num_layers"] + values["_from_num_layers"] = True + else: + num_layers = max(list_lens_set) + values["_from_num_layers"] = False + + values["num_layers"] = num_layers + # If we just want a model with one layer, can return early since we've already + # converted everything into lists + if num_layers == 1: + return values + + # Adjust any length 1 list to be the right length + for p, list_len in list_lens.items(): + if list_len == 1: + values[p] = values[p] * num_layers + + return values + + def _build(self, mtenn_params={}): + """ + Build an MTENN GAT Model from this config. + + Parameters + ---------- + mtenn_params: dict + Dict giving the MTENN Readout. This will be passed by the `build` method in + the abstract base class + + Returns + ------- + mtenn.model.Model + MTENN GAT LigandOnlyModel + """ + from mtenn.conversion_utils import GAT + + model = GAT( + in_feats=self.in_feats, + hidden_feats=self.hidden_feats, + num_heads=self.num_heads, + feat_drops=self.feat_drops, + attn_drops=self.attn_drops, + alphas=self.alphas, + residuals=self.residuals, + agg_modes=self.agg_modes, + activations=self.activations, + biases=self.biases, + allow_zero_in_degree=self.allow_zero_in_degree, + ) + + pred_readout = mtenn_params.get("pred_readout", None) + return GAT.get_model(model=model, pred_readout=pred_readout, fix_device=True) + + def _update(self, config_updates={}) -> GATModelConfig: + orig_config = self.dict() + if self._from_num_layers: + # If originally generated from num_layers, want to pull out the first entry + # in each list param so it can be re-broadcast with (potentially) new + # num_layers + for param_name in GATModelConfig.LIST_PARAMS.keys(): + orig_config[param_name] = orig_config[param_name][0] + + # Get new config by overwriting old stuff with any new stuff + new_config = orig_config | config_updates + + # A bit hacky, maybe try and change? + if isinstance(new_config["activations"], list) and ( + new_config["activations"][0] is None + ): + new_config["activations"] = None + + return GATModelConfig(**new_config) + + +class SchNetModelConfig(ModelConfigBase): + """ + Class for constructing a SchNet ML model. Default values here are the default values + given in PyG. + """ + + model_type: ClassVar[ModelType.schnet] = ModelType.schnet + + hidden_channels: int = Field(128, description="Hidden embedding size.") + num_filters: int = Field( + 128, description="Number of filters to use in the cfconv layers." + ) + num_interactions: int = Field(6, description="Number of interaction blocks.") + num_gaussians: int = Field( + 50, description="Number of gaussians to use in the interaction blocks." + ) + interaction_graph: Callable | None = Field( + None, + description=( + "Function to compute the pairwise interaction graph and " + "interatomic distances." + ), + ) + cutoff: float = Field( + 10, description="Cutoff distance for interatomic interactions." + ) + max_num_neighbors: int = Field( + 32, description="Maximum number of neighbors to collect for each node." + ) + readout: str = Field( + "add", description="Which global aggregation to use [add, mean]." + ) + dipole: bool = Field( + False, + description=( + "Whether to use the magnitude of the dipole moment to make the " + "final prediction." + ), + ) + mean: float | None = Field( + None, + description=( + "Mean of property to predict, to be added to the model prediction before " + "returning. This value is only used if dipole is False and a value is also " + "passed for std." + ), + ) + std: float | None = Field( + None, + description=( + "Standard deviation of property to predict, used to scale the model " + "prediction before returning. This value is only used if dipole is False " + "and a value is also passed for mean." + ), + ) + atomref: list[float] | None = Field( + None, + description=( + "Reference values for single-atom properties. Should have length of 100 to " + "match with PyG." + ), + ) + + @root_validator(pre=False) + def validate(cls, values): + # Make sure the grouped stuff is properly assigned + ModelConfigBase._check_grouped(values) + + # Make sure atomref length is correct (this is required by PyG) + atomref = values["atomref"] + if (atomref is not None) and (len(atomref) != 100): + raise ValueError(f"atomref must be length 100 (got {len(atomref)})") + + return values + + def _build(self, mtenn_params={}): + """ + Build an MTENN SchNet Model from this config. + + Parameters + ---------- + mtenn_params: dict + Dict giving the MTENN Readout. This will be passed by the `build` method in + the abstract base class + + Returns + ------- + mtenn.model.Model + MTENN SchNet Model/GroupedModel + """ + from mtenn.conversion_utils import SchNet + + # Create an MTENN SchNet model from PyG SchNet model + model = SchNet( + hidden_channels=self.hidden_channels, + num_filters=self.num_filters, + num_interactions=self.num_interactions, + num_gaussians=self.num_gaussians, + interaction_graph=self.interaction_graph, + cutoff=self.cutoff, + max_num_neighbors=self.max_num_neighbors, + readout=self.readout, + dipole=self.dipole, + mean=self.mean, + std=self.std, + atomref=self.atomref, + ) + + combination = mtenn_params.get("combination", None) + pred_readout = mtenn_params.get("pred_readout", None) + comb_readout = mtenn_params.get("comb_readout", None) + + return SchNet.get_model( + model=model, + grouped=self.grouped, + fix_device=True, + strategy=self.strategy, + combination=combination, + pred_readout=pred_readout, + comb_readout=comb_readout, + ) + + +class E3NNModelConfig(ModelConfigBase): + """ + Class for constructing an e3nn ML model. + """ + + model_type: ClassVar[ModelType.e3nn] = ModelType.e3nn + + num_atom_types: int = Field( + 100, + description=( + "Number of different atom types. In general, this will just be the " + "max atomic number of all input atoms." + ), + ) + irreps_hidden: dict[str, int] | str = Field( + {"0": 10, "1": 3, "2": 2, "3": 1}, + description=( + "Irreps for the hidden layers of the network. " + "This can either take the form of an Irreps string, or a dict mapping " + "L levels (parity optional) to the number of Irreps of that level. " + "If parity is not passed for a given level, both parities will be used. If " + "you only want one parity for a given level, make sure you specify it. " + "A dict can also be specified as a string, in the format of a comma " + "separated list of :." + ), + ) + lig: bool = Field( + False, description="Include ligand labels as a node attribute information." + ) + irreps_edge_attr: int = Field( + 3, + description=( + "Which level of spherical harmonics to use for encoding edge attributes " + "internally." + ), + ) + num_layers: int = Field(3, description="Number of network layers.") + neighbor_dist: float = Field( + 10, description="Cutoff distance for including atoms as neighbors." + ) + num_basis: int = Field( + 10, description="Number of bases on which the edge length are projected." + ) + num_radial_layers: int = Field(1, description="Number of radial layers.") + num_radial_neurons: int = Field( + 128, description="Number of neurons in each radial layer." + ) + num_neighbors: float = Field(25, description="Typical number of neighbor nodes.") + num_nodes: float = Field(4700, description="Typical number of nodes in a graph.") + + @root_validator(pre=False) + def massage_irreps(cls, values): + from e3nn import o3 + + # First just check that the grouped stuff is properly assigned + ModelConfigBase._check_grouped(values) + + # Now deal with irreps + irreps = values["irreps_hidden"] + # First see if this string should be converted into a dict + if isinstance(irreps, str): + if ("," in irreps) and (":" in irreps): + orig_irreps = irreps + irreps = [i.split(":") for i in irreps.split(",")] + try: + irreps = { + irreps_l: int(num_irreps) for irreps_l, num_irreps in irreps + } + except ValueError: + raise ValueError( + f"Unable to parse irreps dict string: {orig_irreps}" + ) + else: + # If not, try and convert directly to Irreps + try: + _ = o3.Irreps(irreps) + except ValueError: + raise ValueError(f"Invalid irreps string: {irreps}") + + # If already in a good string, can just return + return values + + # If we got a dict, need to massage that into an Irreps string + # First make a copy of the input dict in case of errors + orig_irreps = irreps.copy() + # Find L levels that got an unspecified parity + unspecified_l = [k for k in irreps.keys() if ("o" not in k) and ("e" not in k)] + for irreps_l in unspecified_l: + num_irreps = irreps.pop(irreps_l) + irreps[f"{irreps_l}o"] = num_irreps + irreps[f"{irreps_l}e"] = num_irreps + + # Combine Irreps into str + irreps = "+".join( + [f"{num_irreps}x{irrep}" for irrep, num_irreps in irreps.items()] + ) + + # Make sure this Irreps string is valid + try: + _ = o3.Irreps(irreps) + except ValueError: + raise ValueError(f"Couldn't parse irreps dict: {orig_irreps}") + + values["irreps_hidden"] = irreps + return values + + def _build(self, mtenn_params={}): + from e3nn.o3 import Irreps + from mtenn.conversion_utils import E3NN + + model = E3NN( + irreps_in=f"{self.num_atom_types}x0e", + irreps_hidden=self.irreps_hidden, + irreps_out="1x0e", + irreps_node_attr="1x0e" if self.lig else None, + irreps_edge_attr=Irreps.spherical_harmonics(self.irreps_edge_attr), + layers=self.num_layers, + max_radius=self.neighbor_dist, + number_of_basis=self.num_basis, + radial_layers=self.num_radial_layers, + radial_neurons=self.num_radial_neurons, + num_neighbors=self.num_neighbors, + num_nodes=self.num_nodes, + reduce_output=True, + ) + + combination = mtenn_params.get("combination", None) + pred_readout = mtenn_params.get("pred_readout", None) + comb_readout = mtenn_params.get("comb_readout", None) + + return E3NN.get_model( + model=model, + grouped=self.grouped, + fix_device=True, + strategy=self.strategy, + combination=combination, + pred_readout=pred_readout, + comb_readout=comb_readout, + ) From 4061e2eda04ee1ded02c6359299c592701ff56f8 Mon Sep 17 00:00:00 2001 From: kaminow Date: Wed, 29 Nov 2023 11:39:02 -0500 Subject: [PATCH 03/12] Update test Model construction. --- mtenn/tests/test_combination.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/mtenn/tests/test_combination.py b/mtenn/tests/test_combination.py index 13355c3..ac57fd8 100644 --- a/mtenn/tests/test_combination.py +++ b/mtenn/tests/test_combination.py @@ -2,7 +2,6 @@ import numpy as np import pytest import torch -from torch_geometric.nn import SchNet as PygSchNet from mtenn.combination import MeanCombination, MaxCombination, BoltzmannCombination from mtenn.conversion_utils import SchNet @@ -11,7 +10,7 @@ @pytest.fixture() def models_and_inputs(): model_test = SchNet( - PygSchNet(hidden_channels=2, num_filters=2, num_interactions=2, num_gaussians=2) + hidden_channels=2, num_filters=2, num_interactions=2, num_gaussians=2 ) model_ref = deepcopy(model_test) model_ref = SchNet.get_model(model_ref, strategy="complex") From 686ec254bc2323c25ec9fe6168cc497109bc73ef Mon Sep 17 00:00:00 2001 From: kaminow Date: Wed, 29 Nov 2023 15:40:26 -0500 Subject: [PATCH 04/12] Add StringEnum class. --- mtenn/config.py | 22 ++++++++++++++++++---- 1 file changed, 18 insertions(+), 4 deletions(-) diff --git a/mtenn/config.py b/mtenn/config.py index 1e4c755..8f0724a 100644 --- a/mtenn/config.py +++ b/mtenn/config.py @@ -8,7 +8,21 @@ import mtenn -class ModelType(str, Enum): +class StringEnum(str, Enum): + @classmethod + def get_values(cls) -> list[str]: + return [member.value for member in cls] + + @classmethod + def reverse_lookup(cls, value): + return cls(value) + + @classmethod + def get_names(cls) -> list[str]: + return [member.name for member in cls] + + +class ModelType(StringEnum): """ Enum for model types @@ -36,7 +50,7 @@ def get_names(cls) -> list[str]: return [member.name for member in cls] -class StrategyConfig(str, Enum): +class StrategyConfig(StringEnum): """ Enum for possible MTENN Strategy classes. """ @@ -49,7 +63,7 @@ class StrategyConfig(str, Enum): complex = "complex" -class ReadoutConfig(str, Enum): +class ReadoutConfig(StringEnum): """ Enum for possible MTENN Readout classes. """ @@ -57,7 +71,7 @@ class ReadoutConfig(str, Enum): pic50 = "pic50" -class CombinationConfig(str, Enum): +class CombinationConfig(StringEnum): """ Enum for possible MTENN Readout classes. """ From ca5ac119988aa5db53102ab39b92bf31ae62d82f Mon Sep 17 00:00:00 2001 From: kaminow Date: Thu, 7 Dec 2023 12:03:32 -0500 Subject: [PATCH 05/12] Allow activations to be a list of None as well. --- mtenn/config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mtenn/config.py b/mtenn/config.py index 8f0724a..48c3577 100644 --- a/mtenn/config.py +++ b/mtenn/config.py @@ -341,7 +341,7 @@ class GATModelConfig(ModelConfigBase): "Passing a str or list of strs functions similarly as for hidden_feats." ), ) - activations: list[Callable] | None = Field( + activations: list[Callable] | list[None] | None = Field( None, description=( "Activation function for each GAT layer. Passing a function or " From 66713e0b1ce848461c32a05891d47c85335d965d Mon Sep 17 00:00:00 2001 From: kaminow Date: Thu, 7 Dec 2023 14:09:01 -0500 Subject: [PATCH 06/12] Change model_type to a Field so it gets exported with the ModelConfig. --- mtenn/config.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/mtenn/config.py b/mtenn/config.py index 48c3577..0808a19 100644 --- a/mtenn/config.py +++ b/mtenn/config.py @@ -82,7 +82,7 @@ class CombinationConfig(StringEnum): class ModelConfigBase(BaseModel): - model_type: ClassVar[ModelType.INVALID] = ModelType.INVALID + model_type: ModelType = Field(ModelType.INVALID, const=True, allow_mutation=False) # Shared parameters for MTENN grouped: bool = Field(False, description="Model is a grouped (multi-pose) model.") @@ -172,6 +172,9 @@ class ModelConfigBase(BaseModel): ), ) + class Config: + validate_assignment = True + @abc.abstractmethod def _build(self, mtenn_params={}) -> mtenn.model.Model: ... @@ -275,7 +278,7 @@ class GATModelConfig(ModelConfigBase): "biases": bool, } - model_type: ClassVar[ModelType.GAT] = ModelType.GAT + model_type: ModelType = Field(ModelType.GAT, const=True) in_feats: int = Field( CanonicalAtomFeaturizer().feat_size(), @@ -479,7 +482,7 @@ class SchNetModelConfig(ModelConfigBase): given in PyG. """ - model_type: ClassVar[ModelType.schnet] = ModelType.schnet + model_type: ModelType = Field(ModelType.schnet, const=True) hidden_channels: int = Field(128, description="Hidden embedding size.") num_filters: int = Field( @@ -601,7 +604,7 @@ class E3NNModelConfig(ModelConfigBase): Class for constructing an e3nn ML model. """ - model_type: ClassVar[ModelType.e3nn] = ModelType.e3nn + model_type: ModelType = Field(ModelType.e3nn, const=True) num_atom_types: int = Field( 100, From 0f8c9de00dd2660a8d6b8fcc60132b8f94cbae34 Mon Sep 17 00:00:00 2001 From: kaminow Date: Thu, 14 Dec 2023 10:59:15 -0500 Subject: [PATCH 07/12] Remove duplicated code. --- mtenn/config.py | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/mtenn/config.py b/mtenn/config.py index 0808a19..605ba1e 100644 --- a/mtenn/config.py +++ b/mtenn/config.py @@ -37,18 +37,6 @@ class ModelType(StringEnum): e3nn = "e3nn" INVALID = "INVALID" - @classmethod - def get_values(cls) -> list[str]: - return [member.value for member in cls] - - @classmethod - def reverse_lookup(cls, value): - return cls(value) - - @classmethod - def get_names(cls) -> list[str]: - return [member.name for member in cls] - class StrategyConfig(StringEnum): """ From 63e57cf2adfa7ce44da5659cbbd62752522c9365 Mon Sep 17 00:00:00 2001 From: kaminow Date: Thu, 4 Jan 2024 13:21:20 -0500 Subject: [PATCH 08/12] Adjust checking for how to process irreps string. --- mtenn/config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mtenn/config.py b/mtenn/config.py index 605ba1e..5e377f5 100644 --- a/mtenn/config.py +++ b/mtenn/config.py @@ -648,7 +648,7 @@ def massage_irreps(cls, values): irreps = values["irreps_hidden"] # First see if this string should be converted into a dict if isinstance(irreps, str): - if ("," in irreps) and (":" in irreps): + if ":" in irreps: orig_irreps = irreps irreps = [i.split(":") for i in irreps.split(",")] try: From caad805134ab1c10b2aa61136f0aa7195bb34743 Mon Sep 17 00:00:00 2001 From: kaminow Date: Fri, 5 Jan 2024 11:18:38 -0500 Subject: [PATCH 09/12] Add optional random seed for generating model weights. --- mtenn/config.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/mtenn/config.py b/mtenn/config.py index 5e377f5..5e943ac 100644 --- a/mtenn/config.py +++ b/mtenn/config.py @@ -3,9 +3,12 @@ import abc from enum import Enum from pydantic import BaseModel, Field, root_validator +import random from typing import Callable, ClassVar import mtenn +import numpy as np +import torch class StringEnum(str, Enum): @@ -72,6 +75,11 @@ class CombinationConfig(StringEnum): class ModelConfigBase(BaseModel): model_type: ModelType = Field(ModelType.INVALID, const=True, allow_mutation=False) + # Random seed optional for reproducibility + rand_seed: int | None = Field( + None, type=int, description="Random seed to set for Python, PyTorch, and NumPy." + ) + # Shared parameters for MTENN grouped: bool = Field(False, description="Model is a grouped (multi-pose) model.") strategy: StrategyConfig = Field( @@ -168,6 +176,12 @@ def _build(self, mtenn_params={}) -> mtenn.model.Model: ... def build(self) -> mtenn.model.Model: + # First set random seeds if applicable + if self.rand_seed is not None: + random.seed(self.rand_seed) + torch.manual_seed(self.rand_seed) + np.random.seed(self.rand_seed) + # First handle the MTENN classes match self.combination: case CombinationConfig.mean: From db209b1d9ab64fadd3c144b6ba7b4e49d8f1ae52 Mon Sep 17 00:00:00 2001 From: kaminow Date: Fri, 5 Jan 2024 12:04:45 -0500 Subject: [PATCH 10/12] Add model weights to config. --- mtenn/config.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/mtenn/config.py b/mtenn/config.py index 5e943ac..bf92c73 100644 --- a/mtenn/config.py +++ b/mtenn/config.py @@ -80,6 +80,9 @@ class ModelConfigBase(BaseModel): None, type=int, description="Random seed to set for Python, PyTorch, and NumPy." ) + # Model weights + model_weights: dict | None = Field(None, type=dict, description="Model weights.") + # Shared parameters for MTENN grouped: bool = Field(False, description="Model is a grouped (multi-pose) model.") strategy: StrategyConfig = Field( @@ -218,7 +221,13 @@ def build(self) -> mtenn.model.Model: } # Build the actual Model - return self._build(mtenn_params) + model = self._build(mtenn_params) + + # Set model weights + if self.model_weights: + model.load_state_dict(self.model_weights) + + return model def update(self, config_updates={}) -> ModelConfigBase: return self._update(config_updates) From 590a199e0b5d7ba769c7cd3a7b327c6eb046fac3 Mon Sep 17 00:00:00 2001 From: kaminow Date: Mon, 8 Jan 2024 10:18:56 -0500 Subject: [PATCH 11/12] Add tests to make sure the random seed is working the right way. --- mtenn/tests/test_model_config.py | 67 ++++++++++++++++++++++++++++++++ 1 file changed, 67 insertions(+) create mode 100644 mtenn/tests/test_model_config.py diff --git a/mtenn/tests/test_model_config.py b/mtenn/tests/test_model_config.py new file mode 100644 index 0000000..dd16da7 --- /dev/null +++ b/mtenn/tests/test_model_config.py @@ -0,0 +1,67 @@ +from mtenn.config import GATModelConfig, E3NNModelConfig, SchNetModelConfig + + +def test_random_seed_gat(): + rand_config = GATModelConfig() + set_config = GATModelConfig(rand_seed=10) + + rand_model1 = rand_config.build() + rand_model2 = rand_config.build() + set_model1 = set_config.build() + set_model2 = set_config.build() + + rand_equal = [ + (p1 == p2).all() + for p1, p2 in zip(rand_model1.parameters(), rand_model2.parameters()) + ] + assert sum(rand_equal) < len(rand_equal) + + set_equal = [ + (p1 == p2).all() + for p1, p2 in zip(set_model1.parameters(), set_model2.parameters()) + ] + assert sum(set_equal) == len(set_equal) + + +def test_random_seed_e3nn(): + rand_config = E3NNModelConfig() + set_config = E3NNModelConfig(rand_seed=10) + + rand_model1 = rand_config.build() + rand_model2 = rand_config.build() + set_model1 = set_config.build() + set_model2 = set_config.build() + + rand_equal = [ + (p1 == p2).all() + for p1, p2 in zip(rand_model1.parameters(), rand_model2.parameters()) + ] + assert sum(rand_equal) < len(rand_equal) + + set_equal = [ + (p1 == p2).all() + for p1, p2 in zip(set_model1.parameters(), set_model2.parameters()) + ] + assert sum(set_equal) == len(set_equal) + + +def test_random_seed_schnet(): + rand_config = SchNetModelConfig() + set_config = SchNetModelConfig(rand_seed=10) + + rand_model1 = rand_config.build() + rand_model2 = rand_config.build() + set_model1 = set_config.build() + set_model2 = set_config.build() + + rand_equal = [ + (p1 == p2).all() + for p1, p2 in zip(rand_model1.parameters(), rand_model2.parameters()) + ] + assert sum(rand_equal) < len(rand_equal) + + set_equal = [ + (p1 == p2).all() + for p1, p2 in zip(set_model1.parameters(), set_model2.parameters()) + ] + assert sum(set_equal) == len(set_equal) From 061037d2a3c3aeb8a056c96374f2711b614579c3 Mon Sep 17 00:00:00 2001 From: kaminow <51923685+kaminow@users.noreply.github.com> Date: Mon, 8 Jan 2024 10:23:49 -0500 Subject: [PATCH 12/12] Add pydantic to CI conda env. --- devtools/conda-envs/test_env.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/devtools/conda-envs/test_env.yaml b/devtools/conda-envs/test_env.yaml index bde3c8a..b8221be 100644 --- a/devtools/conda-envs/test_env.yaml +++ b/devtools/conda-envs/test_env.yaml @@ -18,4 +18,4 @@ dependencies: - pytest - pytest-cov - codecov - + - pydantic >=1.10.8,<2.0.0a0