diff --git a/devtools/conda-envs/mtenn.yaml b/devtools/conda-envs/mtenn.yaml index 8139a84..2902743 100644 --- a/devtools/conda-envs/mtenn.yaml +++ b/devtools/conda-envs/mtenn.yaml @@ -3,10 +3,11 @@ channels: - conda-forge dependencies: - pytorch - - pytorch_geometric + - pytorch_geometric >=2.5.0 - pytorch_cluster - pytorch_scatter - pytorch_sparse + - pydantic >=2.0.0a0 - numpy - h5py - e3nn @@ -14,3 +15,4 @@ dependencies: - dgl - rdkit - ase + - fsspec diff --git a/devtools/conda-envs/test_env.yaml b/devtools/conda-envs/test_env.yaml index 9fce2b8..e140284 100644 --- a/devtools/conda-envs/test_env.yaml +++ b/devtools/conda-envs/test_env.yaml @@ -7,6 +7,7 @@ dependencies: - pytorch_cluster - pytorch_scatter - pytorch_sparse + - pydantic >=2.0.0a0 - numpy - h5py - e3nn @@ -19,5 +20,4 @@ dependencies: - pytest - pytest-cov - codecov - - pydantic >=1.10.8,<2.0.0a0 - fsspec diff --git a/mtenn/config.py b/mtenn/config.py index a8d6eef..a990f3a 100644 --- a/mtenn/config.py +++ b/mtenn/config.py @@ -15,9 +15,9 @@ import abc from enum import Enum -from pydantic import BaseModel, Field, root_validator +from pydantic import model_validator, ConfigDict, BaseModel, Field import random -from typing import Callable, ClassVar +from typing import Literal, Callable, ClassVar import mtenn.combination import mtenn.readout import mtenn.model @@ -140,7 +140,8 @@ class ModelConfigBase(BaseModel): to implement the ``_build`` method in order to be used. """ - model_type: ModelType = Field(ModelType.INVALID, const=True, allow_mutation=False) + model_type: Literal[ModelType.INVALID] = ModelType.INVALID + # Random seed optional for reproducibility rand_seed: int | None = Field( @@ -240,9 +241,7 @@ class ModelConfigBase(BaseModel): "``comb_substrate``." ), ) - - class Config: - validate_assignment = True + model_config = ConfigDict(validate_assignment=True) def build(self) -> mtenn.model.Model: """ @@ -394,7 +393,7 @@ def _check_grouped(values): Makes sure that a Combination method is passed if using a GroupedModel. Only needs to be called for structure-based models. """ - if values["grouped"] and (not values["combination"]): + if values.grouped and not values.combination: raise ValueError("combination must be specified for a GroupedModel.") @@ -436,7 +435,7 @@ class GATModelConfig(ModelConfigBase): "biases": bool, } #: :meta private: - model_type: ModelType = Field(ModelType.GAT, const=True) + model_type: Literal[ModelType.GAT] = ModelType.GAT in_feats: int = Field( _CanonicalAtomFeaturizer().feat_size(), @@ -527,14 +526,16 @@ class GATModelConfig(ModelConfigBase): # num_layers _from_num_layers = False - @root_validator(pre=False) - def massage_into_lists(cls, values) -> GATModelConfig: + @model_validator(mode="after") + def massage_into_lists(self) -> GATModelConfig: """ Validator to handle unifying all the values into the proper list forms based on the rules described in the class docstring. """ + values = self.dict() + # First convert string lists to actual lists - for param, param_type in cls.LIST_PARAMS.items(): + for param, param_type in self.LIST_PARAMS.items(): param_val = values[param] if isinstance(param_val, str): try: @@ -548,7 +549,7 @@ def massage_into_lists(cls, values) -> GATModelConfig: # Get sizes of all lists list_lens = {} - for p in cls.LIST_PARAMS: + for p in self.LIST_PARAMS: param_val = values[p] if not isinstance(param_val, list): # Shouldn't be possible at this point but just in case @@ -577,14 +578,17 @@ def massage_into_lists(cls, values) -> GATModelConfig: # If we just want a model with one layer, can return early since we've already # converted everything into lists if num_layers == 1: - return values + # update self with the new values + self.__dict__.update(values) + # Adjust any length 1 list to be the right length for p, list_len in list_lens.items(): if list_len == 1: values[p] = values[p] * num_layers - return values + self.__dict__.update(values) + return self def _build(self, mtenn_params={}): """ @@ -681,7 +685,7 @@ class SchNetModelConfig(ModelConfigBase): given in PyG. """ - model_type: ModelType = Field(ModelType.schnet, const=True) + model_type: Literal[ModelType.schnet] = ModelType.schnet hidden_channels: int = Field(128, description="Hidden embedding size.") num_filters: int = Field( @@ -738,13 +742,14 @@ class SchNetModelConfig(ModelConfigBase): ), ) - @root_validator(pre=False) + @model_validator(mode="after") + @classmethod def validate(cls, values): # Make sure the grouped stuff is properly assigned ModelConfigBase._check_grouped(values) # Make sure atomref length is correct (this is required by PyG) - atomref = values["atomref"] + atomref = values.atomref if (atomref is not None) and (len(atomref) != 100): raise ValueError(f"atomref must be length 100 (got {len(atomref)})") @@ -816,7 +821,7 @@ class E3NNModelConfig(ModelConfigBase): Class for constructing an e3nn ML model. """ - model_type: ModelType = Field(ModelType.e3nn, const=True) + model_type: Literal[ModelType.e3nn] = ModelType.e3nn num_atom_types: int = Field( 100, @@ -862,7 +867,8 @@ class E3NNModelConfig(ModelConfigBase): num_neighbors: float = Field(25, description="Typical number of neighbor nodes.") num_nodes: float = Field(4700, description="Typical number of nodes in a graph.") - @root_validator(pre=False) + @model_validator(mode="after") + @classmethod def massage_irreps(cls, values): """ Check that the value given for ``irreps_hidden`` can be converted into an Irreps @@ -874,7 +880,7 @@ def massage_irreps(cls, values): ModelConfigBase._check_grouped(values) # Now deal with irreps - irreps = values["irreps_hidden"] + irreps = values.irreps_hidden # First see if this string should be converted into a dict if isinstance(irreps, str): if ":" in irreps: @@ -923,7 +929,7 @@ def massage_irreps(cls, values): except ValueError: raise ValueError(f"Couldn't parse irreps dict: {orig_irreps}") - values["irreps_hidden"] = irreps + values.irreps_hidden = irreps return values def _build(self, mtenn_params={}): @@ -994,7 +1000,7 @@ class ViSNetModelConfig(ModelConfigBase): given in PyG. """ - model_type: ModelType = Field(ModelType.visnet, const=True) + model_type: Literal[ModelType.visnet] = ModelType.visnet lmax: int = Field(1, description="The maximum degree of the spherical harmonics.") vecnorm_type: str | None = Field( None, description="The type of normalization to apply to the vectors." @@ -1041,7 +1047,8 @@ class ViSNetModelConfig(ModelConfigBase): ), ) - @root_validator(pre=False) + @model_validator(mode="after") + @classmethod def validate(cls, values): """ Check that ``atomref`` and ``max_z`` agree. @@ -1050,10 +1057,10 @@ def validate(cls, values): ModelConfigBase._check_grouped(values) # Make sure atomref length is correct (this is required by PyG) - atomref = values["atomref"] - if (atomref is not None) and (len(atomref) != values["max_z"]): + atomref = values.atomref + if (atomref is not None) and (len(atomref) != values.max_z): raise ValueError( - f"atomref length must match max_z. (Expected {values['max_z']}, got {len(atomref)})" + f"atomref length must match max_z. (Expected {values.max_z}, got {len(atomref)})" ) return values