From 580845335e8e151801ef6db4a16d4f3a426f7a7d Mon Sep 17 00:00:00 2001 From: Hugo MacDermott-Opeskin Date: Mon, 30 Sep 2024 13:07:13 +1000 Subject: [PATCH 1/8] migrate? --- devtools/conda-envs/mtenn.yaml | 4 +++- devtools/conda-envs/test_env.yaml | 2 +- mtenn/config.py | 30 ++++++++++++++++-------------- 3 files changed, 20 insertions(+), 16 deletions(-) diff --git a/devtools/conda-envs/mtenn.yaml b/devtools/conda-envs/mtenn.yaml index 8139a84..2902743 100644 --- a/devtools/conda-envs/mtenn.yaml +++ b/devtools/conda-envs/mtenn.yaml @@ -3,10 +3,11 @@ channels: - conda-forge dependencies: - pytorch - - pytorch_geometric + - pytorch_geometric >=2.5.0 - pytorch_cluster - pytorch_scatter - pytorch_sparse + - pydantic >=2.0.0a0 - numpy - h5py - e3nn @@ -14,3 +15,4 @@ dependencies: - dgl - rdkit - ase + - fsspec diff --git a/devtools/conda-envs/test_env.yaml b/devtools/conda-envs/test_env.yaml index 9fce2b8..e140284 100644 --- a/devtools/conda-envs/test_env.yaml +++ b/devtools/conda-envs/test_env.yaml @@ -7,6 +7,7 @@ dependencies: - pytorch_cluster - pytorch_scatter - pytorch_sparse + - pydantic >=2.0.0a0 - numpy - h5py - e3nn @@ -19,5 +20,4 @@ dependencies: - pytest - pytest-cov - codecov - - pydantic >=1.10.8,<2.0.0a0 - fsspec diff --git a/mtenn/config.py b/mtenn/config.py index a8d6eef..9ffcb10 100644 --- a/mtenn/config.py +++ b/mtenn/config.py @@ -15,9 +15,9 @@ import abc from enum import Enum -from pydantic import BaseModel, Field, root_validator +from pydantic import model_validator, ConfigDict, BaseModel, Field import random -from typing import Callable, ClassVar +from typing import Literal, Callable, ClassVar import mtenn.combination import mtenn.readout import mtenn.model @@ -140,7 +140,7 @@ class ModelConfigBase(BaseModel): to implement the ``_build`` method in order to be used. """ - model_type: ModelType = Field(ModelType.INVALID, const=True, allow_mutation=False) + model_type: ModelType = Field(ModelType.INVALID, const=True, frozen=True) # Random seed optional for reproducibility rand_seed: int | None = Field( @@ -240,9 +240,7 @@ class ModelConfigBase(BaseModel): "``comb_substrate``." ), ) - - class Config: - validate_assignment = True + model_config = ConfigDict(validate_assignment=True) def build(self) -> mtenn.model.Model: """ @@ -436,7 +434,7 @@ class GATModelConfig(ModelConfigBase): "biases": bool, } #: :meta private: - model_type: ModelType = Field(ModelType.GAT, const=True) + model_type: Literal[ModelType.GAT] = ModelType.GAT in_feats: int = Field( _CanonicalAtomFeaturizer().feat_size(), @@ -527,7 +525,8 @@ class GATModelConfig(ModelConfigBase): # num_layers _from_num_layers = False - @root_validator(pre=False) + @model_validator() + @classmethod def massage_into_lists(cls, values) -> GATModelConfig: """ Validator to handle unifying all the values into the proper list forms based on @@ -681,7 +680,7 @@ class SchNetModelConfig(ModelConfigBase): given in PyG. """ - model_type: ModelType = Field(ModelType.schnet, const=True) + model_type: Literal[ModelType.schnet] = ModelType.schnet hidden_channels: int = Field(128, description="Hidden embedding size.") num_filters: int = Field( @@ -738,7 +737,8 @@ class SchNetModelConfig(ModelConfigBase): ), ) - @root_validator(pre=False) + @model_validator() + @classmethod def validate(cls, values): # Make sure the grouped stuff is properly assigned ModelConfigBase._check_grouped(values) @@ -816,7 +816,7 @@ class E3NNModelConfig(ModelConfigBase): Class for constructing an e3nn ML model. """ - model_type: ModelType = Field(ModelType.e3nn, const=True) + model_type: Literal[ModelType.e3nn] = ModelType.e3nn num_atom_types: int = Field( 100, @@ -862,7 +862,8 @@ class E3NNModelConfig(ModelConfigBase): num_neighbors: float = Field(25, description="Typical number of neighbor nodes.") num_nodes: float = Field(4700, description="Typical number of nodes in a graph.") - @root_validator(pre=False) + @model_validator() + @classmethod def massage_irreps(cls, values): """ Check that the value given for ``irreps_hidden`` can be converted into an Irreps @@ -994,7 +995,7 @@ class ViSNetModelConfig(ModelConfigBase): given in PyG. """ - model_type: ModelType = Field(ModelType.visnet, const=True) + model_type: Literal[ModelType.visnet] = ModelType.visnet lmax: int = Field(1, description="The maximum degree of the spherical harmonics.") vecnorm_type: str | None = Field( None, description="The type of normalization to apply to the vectors." @@ -1041,7 +1042,8 @@ class ViSNetModelConfig(ModelConfigBase): ), ) - @root_validator(pre=False) + @model_validator() + @classmethod def validate(cls, values): """ Check that ``atomref`` and ``max_z`` agree. From 71ede46726c8cbe9984d299bbda9ad0487b6a94b Mon Sep 17 00:00:00 2001 From: Hugo MacDermott-Opeskin Date: Mon, 30 Sep 2024 13:15:26 +1000 Subject: [PATCH 2/8] remove missed const --- mtenn/config.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/mtenn/config.py b/mtenn/config.py index 9ffcb10..933d56f 100644 --- a/mtenn/config.py +++ b/mtenn/config.py @@ -140,7 +140,8 @@ class ModelConfigBase(BaseModel): to implement the ``_build`` method in order to be used. """ - model_type: ModelType = Field(ModelType.INVALID, const=True, frozen=True) + model_type: Literal[ModelType.INVALID] = ModelType.INVALID + # Random seed optional for reproducibility rand_seed: int | None = Field( From e1885d3021ca70ee08bd11d10377241eec2139f7 Mon Sep 17 00:00:00 2001 From: Hugo MacDermott-Opeskin Date: Mon, 30 Sep 2024 13:27:44 +1000 Subject: [PATCH 3/8] before --- mtenn/config.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/mtenn/config.py b/mtenn/config.py index 933d56f..849612b 100644 --- a/mtenn/config.py +++ b/mtenn/config.py @@ -526,7 +526,7 @@ class GATModelConfig(ModelConfigBase): # num_layers _from_num_layers = False - @model_validator() + @model_validator(mode="before") @classmethod def massage_into_lists(cls, values) -> GATModelConfig: """ @@ -738,7 +738,7 @@ class SchNetModelConfig(ModelConfigBase): ), ) - @model_validator() + @model_validator(mode="before") @classmethod def validate(cls, values): # Make sure the grouped stuff is properly assigned @@ -863,7 +863,7 @@ class E3NNModelConfig(ModelConfigBase): num_neighbors: float = Field(25, description="Typical number of neighbor nodes.") num_nodes: float = Field(4700, description="Typical number of nodes in a graph.") - @model_validator() + @model_validator(mode="before") @classmethod def massage_irreps(cls, values): """ @@ -1043,7 +1043,7 @@ class ViSNetModelConfig(ModelConfigBase): ), ) - @model_validator() + @model_validator(mode="before") @classmethod def validate(cls, values): """ From 6e8684dd11597b4099b2f776e92fd65eade94ffa Mon Sep 17 00:00:00 2001 From: Hugo MacDermott-Opeskin Date: Mon, 30 Sep 2024 13:48:29 +1000 Subject: [PATCH 4/8] try after? --- mtenn/config.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/mtenn/config.py b/mtenn/config.py index 849612b..fc1b273 100644 --- a/mtenn/config.py +++ b/mtenn/config.py @@ -526,7 +526,7 @@ class GATModelConfig(ModelConfigBase): # num_layers _from_num_layers = False - @model_validator(mode="before") + @model_validator(mode="after") @classmethod def massage_into_lists(cls, values) -> GATModelConfig: """ @@ -738,7 +738,7 @@ class SchNetModelConfig(ModelConfigBase): ), ) - @model_validator(mode="before") + @model_validator(mode="after") @classmethod def validate(cls, values): # Make sure the grouped stuff is properly assigned @@ -863,7 +863,7 @@ class E3NNModelConfig(ModelConfigBase): num_neighbors: float = Field(25, description="Typical number of neighbor nodes.") num_nodes: float = Field(4700, description="Typical number of nodes in a graph.") - @model_validator(mode="before") + @model_validator(mode="after") @classmethod def massage_irreps(cls, values): """ @@ -1043,7 +1043,7 @@ class ViSNetModelConfig(ModelConfigBase): ), ) - @model_validator(mode="before") + @model_validator(mode="after") @classmethod def validate(cls, values): """ From 6f55371ff423b7078316a078db7abe7d6e42656b Mon Sep 17 00:00:00 2001 From: Hugo MacDermott-Opeskin Date: Mon, 30 Sep 2024 18:14:33 +1000 Subject: [PATCH 5/8] try before again --- mtenn/config.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/mtenn/config.py b/mtenn/config.py index fc1b273..849612b 100644 --- a/mtenn/config.py +++ b/mtenn/config.py @@ -526,7 +526,7 @@ class GATModelConfig(ModelConfigBase): # num_layers _from_num_layers = False - @model_validator(mode="after") + @model_validator(mode="before") @classmethod def massage_into_lists(cls, values) -> GATModelConfig: """ @@ -738,7 +738,7 @@ class SchNetModelConfig(ModelConfigBase): ), ) - @model_validator(mode="after") + @model_validator(mode="before") @classmethod def validate(cls, values): # Make sure the grouped stuff is properly assigned @@ -863,7 +863,7 @@ class E3NNModelConfig(ModelConfigBase): num_neighbors: float = Field(25, description="Typical number of neighbor nodes.") num_nodes: float = Field(4700, description="Typical number of nodes in a graph.") - @model_validator(mode="after") + @model_validator(mode="before") @classmethod def massage_irreps(cls, values): """ @@ -1043,7 +1043,7 @@ class ViSNetModelConfig(ModelConfigBase): ), ) - @model_validator(mode="after") + @model_validator(mode="before") @classmethod def validate(cls, values): """ From 9859fa2780bc60b5ab33239ca8558dfb1ed7c0af Mon Sep 17 00:00:00 2001 From: hmacdope Date: Thu, 10 Oct 2024 14:32:15 +1100 Subject: [PATCH 6/8] try --- mtenn/config.py | 30 +++++++++++++++--------------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/mtenn/config.py b/mtenn/config.py index 849612b..14d59be 100644 --- a/mtenn/config.py +++ b/mtenn/config.py @@ -393,7 +393,7 @@ def _check_grouped(values): Makes sure that a Combination method is passed if using a GroupedModel. Only needs to be called for structure-based models. """ - if values["grouped"] and (not values["combination"]): + if values.grouped and not values.combination: raise ValueError("combination must be specified for a GroupedModel.") @@ -526,7 +526,7 @@ class GATModelConfig(ModelConfigBase): # num_layers _from_num_layers = False - @model_validator(mode="before") + @model_validator(mode="after") @classmethod def massage_into_lists(cls, values) -> GATModelConfig: """ @@ -567,13 +567,13 @@ def massage_into_lists(cls, values) -> GATModelConfig: elif list_lens_set == {1}: # If all lists have only one value, we defer to the value passed to # num_layers, as described in the class docstring - num_layers = values["num_layers"] - values["_from_num_layers"] = True + num_layers = values.num_layers + values._from_num_layers = True else: num_layers = max(list_lens_set) - values["_from_num_layers"] = False + values._from_num_layers = False - values["num_layers"] = num_layers + values.num_layers = num_layers # If we just want a model with one layer, can return early since we've already # converted everything into lists if num_layers == 1: @@ -738,14 +738,14 @@ class SchNetModelConfig(ModelConfigBase): ), ) - @model_validator(mode="before") + @model_validator(mode="after") @classmethod def validate(cls, values): # Make sure the grouped stuff is properly assigned ModelConfigBase._check_grouped(values) # Make sure atomref length is correct (this is required by PyG) - atomref = values["atomref"] + atomref = values.atomref if (atomref is not None) and (len(atomref) != 100): raise ValueError(f"atomref must be length 100 (got {len(atomref)})") @@ -863,7 +863,7 @@ class E3NNModelConfig(ModelConfigBase): num_neighbors: float = Field(25, description="Typical number of neighbor nodes.") num_nodes: float = Field(4700, description="Typical number of nodes in a graph.") - @model_validator(mode="before") + @model_validator(mode="after") @classmethod def massage_irreps(cls, values): """ @@ -876,7 +876,7 @@ def massage_irreps(cls, values): ModelConfigBase._check_grouped(values) # Now deal with irreps - irreps = values["irreps_hidden"] + irreps = values.irreps_hidden # First see if this string should be converted into a dict if isinstance(irreps, str): if ":" in irreps: @@ -925,7 +925,7 @@ def massage_irreps(cls, values): except ValueError: raise ValueError(f"Couldn't parse irreps dict: {orig_irreps}") - values["irreps_hidden"] = irreps + values.irreps_hidden = irreps return values def _build(self, mtenn_params={}): @@ -1043,7 +1043,7 @@ class ViSNetModelConfig(ModelConfigBase): ), ) - @model_validator(mode="before") + @model_validator(mode="after") @classmethod def validate(cls, values): """ @@ -1053,10 +1053,10 @@ def validate(cls, values): ModelConfigBase._check_grouped(values) # Make sure atomref length is correct (this is required by PyG) - atomref = values["atomref"] - if (atomref is not None) and (len(atomref) != values["max_z"]): + atomref = values.atomref + if (atomref is not None) and (len(atomref) != values.max_z): raise ValueError( - f"atomref length must match max_z. (Expected {values['max_z']}, got {len(atomref)})" + f"atomref length must match max_z. (Expected {values.max_z}, got {len(atomref)})" ) return values From 74a1070a10f4428cfcf559f1b372d69efd5a5595 Mon Sep 17 00:00:00 2001 From: hmacdope Date: Thu, 10 Oct 2024 15:26:52 +1100 Subject: [PATCH 7/8] ugly but works --- mtenn/config.py | 23 +++++++++++++---------- 1 file changed, 13 insertions(+), 10 deletions(-) diff --git a/mtenn/config.py b/mtenn/config.py index 14d59be..ef73761 100644 --- a/mtenn/config.py +++ b/mtenn/config.py @@ -527,14 +527,15 @@ class GATModelConfig(ModelConfigBase): _from_num_layers = False @model_validator(mode="after") - @classmethod - def massage_into_lists(cls, values) -> GATModelConfig: + def massage_into_lists(self) -> GATModelConfig: """ Validator to handle unifying all the values into the proper list forms based on the rules described in the class docstring. """ + values = self.dict() + # First convert string lists to actual lists - for param, param_type in cls.LIST_PARAMS.items(): + for param, param_type in self.LIST_PARAMS.items(): param_val = values[param] if isinstance(param_val, str): try: @@ -548,7 +549,7 @@ def massage_into_lists(cls, values) -> GATModelConfig: # Get sizes of all lists list_lens = {} - for p in cls.LIST_PARAMS: + for p in self.LIST_PARAMS: param_val = values[p] if not isinstance(param_val, list): # Shouldn't be possible at this point but just in case @@ -567,24 +568,26 @@ def massage_into_lists(cls, values) -> GATModelConfig: elif list_lens_set == {1}: # If all lists have only one value, we defer to the value passed to # num_layers, as described in the class docstring - num_layers = values.num_layers - values._from_num_layers = True + num_layers = values["num_layers"] + values["_from_num_layers"] = True else: num_layers = max(list_lens_set) - values._from_num_layers = False + values["_from_num_layers"] = False - values.num_layers = num_layers + values["num_layers"] = num_layers # If we just want a model with one layer, can return early since we've already # converted everything into lists if num_layers == 1: - return values + # update self with the new values + self.__dict__.update(values) + # Adjust any length 1 list to be the right length for p, list_len in list_lens.items(): if list_len == 1: values[p] = values[p] * num_layers - return values + return self.__dict__.update(values) def _build(self, mtenn_params={}): """ From 965f0b76c6bb5ae07a2af215a9113dac34209580 Mon Sep 17 00:00:00 2001 From: hmacdope Date: Fri, 11 Oct 2024 07:01:03 +1100 Subject: [PATCH 8/8] kaminow code review --- mtenn/config.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/mtenn/config.py b/mtenn/config.py index ef73761..a990f3a 100644 --- a/mtenn/config.py +++ b/mtenn/config.py @@ -587,7 +587,8 @@ def massage_into_lists(self) -> GATModelConfig: if list_len == 1: values[p] = values[p] * num_layers - return self.__dict__.update(values) + self.__dict__.update(values) + return self def _build(self, mtenn_params={}): """