Migrate to pydantic 2.0+ #73

Merged (8 commits, Oct 10, 2024)

Changes from 7 commits
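At a high level, the diff below swaps pydantic 1 idioms for their pydantic 2 equivalents: Field(..., const=True) fields become Literal-typed fields, the inner class Config becomes model_config = ConfigDict(...), and @root_validator becomes @model_validator. A minimal sketch of these patterns, using an illustrative class and field names rather than code from this repo:

# Sketch of the pydantic 2 idioms applied in this PR (illustrative names only).
from typing import Literal

from pydantic import BaseModel, ConfigDict, Field, model_validator


class ExampleModelConfig(BaseModel):
    # pydantic 1: class Config: validate_assignment = True
    model_config = ConfigDict(validate_assignment=True)

    # pydantic 1: model_type: str = Field("example", const=True)
    model_type: Literal["example"] = "example"

    hidden_feats: int = Field(32, description="Hidden feature size.")

    # pydantic 1: @root_validator(pre=False) def check(cls, values): ...
    @model_validator(mode="after")
    def check_hidden_feats(self) -> "ExampleModelConfig":
        # After-mode validators receive the built instance and must return it.
        if self.hidden_feats <= 0:
            raise ValueError("hidden_feats must be positive")
        return self
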
4 changes: 3 additions & 1 deletion devtools/conda-envs/mtenn.yaml
@@ -3,14 +3,16 @@ channels:
- conda-forge
dependencies:
- pytorch
- pytorch_geometric
- pytorch_geometric >=2.5.0
- pytorch_cluster
- pytorch_scatter
- pytorch_sparse
- pydantic >=2.0.0a0
- numpy
- h5py
- e3nn
- dgllife
- dgl
- rdkit
- ase
- fsspec
2 changes: 1 addition & 1 deletion devtools/conda-envs/test_env.yaml
@@ -7,6 +7,7 @@ dependencies:
- pytorch_cluster
- pytorch_scatter
- pytorch_sparse
- pydantic >=2.0.0a0
- numpy
- h5py
- e3nn
@@ -19,5 +20,4 @@ dependencies:
- pytest
- pytest-cov
- codecov
- pydantic >=1.10.8,<2.0.0a0
- fsspec
58 changes: 32 additions & 26 deletions mtenn/config.py
@@ -15,9 +15,9 @@

import abc
from enum import Enum
from pydantic import BaseModel, Field, root_validator
from pydantic import model_validator, ConfigDict, BaseModel, Field
import random
from typing import Callable, ClassVar
from typing import Literal, Callable, ClassVar
import mtenn.combination
import mtenn.readout
import mtenn.model
@@ -140,7 +140,8 @@ class ModelConfigBase(BaseModel):
to implement the ``_build`` method in order to be used.
"""

model_type: ModelType = Field(ModelType.INVALID, const=True, allow_mutation=False)
model_type: Literal[ModelType.INVALID] = ModelType.INVALID


# Random seed optional for reproducibility
rand_seed: int | None = Field(
@@ -240,9 +241,7 @@ class ModelConfigBase(BaseModel):
"``comb_substrate``."
),
)

class Config:
validate_assignment = True
model_config = ConfigDict(validate_assignment=True)

def build(self) -> mtenn.model.Model:
"""
@@ -394,7 +393,7 @@ def _check_grouped(values):
Makes sure that a Combination method is passed if using a GroupedModel. Only
needs to be called for structure-based models.
"""
if values["grouped"] and (not values["combination"]):
if values.grouped and not values.combination:
raise ValueError("combination must be specified for a GroupedModel.")


@@ -436,7 +435,7 @@ class GATModelConfig(ModelConfigBase):
"biases": bool,
} #: :meta private:

model_type: ModelType = Field(ModelType.GAT, const=True)
model_type: Literal[ModelType.GAT] = ModelType.GAT

in_feats: int = Field(
_CanonicalAtomFeaturizer().feat_size(),
@@ -527,14 +526,16 @@ class GATModelConfig(ModelConfigBase):
# num_layers
_from_num_layers = False

@root_validator(pre=False)
def massage_into_lists(cls, values) -> GATModelConfig:
@model_validator(mode="after")
def massage_into_lists(self) -> GATModelConfig:
"""
Validator to handle unifying all the values into the proper list forms based on
the rules described in the class docstring.
"""
values = self.dict()

# First convert string lists to actual lists
for param, param_type in cls.LIST_PARAMS.items():
for param, param_type in self.LIST_PARAMS.items():
param_val = values[param]
if isinstance(param_val, str):
try:
@@ -548,7 +549,7 @@ def massage_into_lists(cls, values) -> GATModelConfig:

# Get sizes of all lists
list_lens = {}
for p in cls.LIST_PARAMS:
for p in self.LIST_PARAMS:
param_val = values[p]
if not isinstance(param_val, list):
# Shouldn't be possible at this point but just in case
@@ -577,14 +578,16 @@ def massage_into_lists(cls, values) -> GATModelConfig:
# If we just want a model with one layer, can return early since we've already
# converted everything into lists
if num_layers == 1:
return values
# update self with the new values
self.__dict__.update(values)


# Adjust any length 1 list to be the right length
for p, list_len in list_lens.items():
if list_len == 1:
values[p] = values[p] * num_layers

return values
return self.__dict__.update(values)
Collaborator:

I don't think update actually returns anything

Suggested change:
- return self.__dict__.update(values)
+ self.__dict__.update(values)
+ return self.__dict__

Contributor Author:

Now that you mention it we should return self

Contributor Author:

this is due to how the after mode validator works in pydantic 2.
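
For context, a minimal sketch (illustrative names, not code from this PR) of what the thread is converging on: a pydantic 2 after-mode validator runs on the constructed instance, can update fields through self.__dict__ to avoid re-triggering validate_assignment, and must return the instance itself rather than the result of dict.update (which is None):

# Sketch of an "after" model_validator that mutates and returns self.
from pydantic import BaseModel, ConfigDict, model_validator


class ListifyConfig(BaseModel):
    model_config = ConfigDict(validate_assignment=True)

    hidden_feats: int | list[int] = 32
    num_layers: int = 2

    @model_validator(mode="after")
    def expand_to_lists(self) -> "ListifyConfig":
        values = self.dict()  # pydantic 2 also offers model_dump()
        if not isinstance(values["hidden_feats"], list):
            values["hidden_feats"] = [values["hidden_feats"]] * values["num_layers"]
        # Update via __dict__ so validate_assignment isn't re-triggered,
        # then return self -- dict.update() returns None, not the model.
        self.__dict__.update(values)
        return self


# e.g. ListifyConfig(hidden_feats=16, num_layers=3).hidden_feats == [16, 16, 16]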


def _build(self, mtenn_params={}):
"""
@@ -681,7 +684,7 @@ class SchNetModelConfig(ModelConfigBase):
given in PyG.
"""

model_type: ModelType = Field(ModelType.schnet, const=True)
model_type: Literal[ModelType.schnet] = ModelType.schnet

hidden_channels: int = Field(128, description="Hidden embedding size.")
num_filters: int = Field(
@@ -738,13 +741,14 @@
),
)

@root_validator(pre=False)
@model_validator(mode="after")
@classmethod
def validate(cls, values):
# Make sure the grouped stuff is properly assigned
ModelConfigBase._check_grouped(values)

# Make sure atomref length is correct (this is required by PyG)
atomref = values["atomref"]
atomref = values.atomref
if (atomref is not None) and (len(atomref) != 100):
raise ValueError(f"atomref must be length 100 (got {len(atomref)})")

@@ -816,7 +820,7 @@ class E3NNModelConfig(ModelConfigBase):
Class for constructing an e3nn ML model.
"""

model_type: ModelType = Field(ModelType.e3nn, const=True)
model_type: Literal[ModelType.e3nn] = ModelType.e3nn

num_atom_types: int = Field(
100,
@@ -862,7 +866,8 @@ class E3NNModelConfig(ModelConfigBase):
num_neighbors: float = Field(25, description="Typical number of neighbor nodes.")
num_nodes: float = Field(4700, description="Typical number of nodes in a graph.")

@root_validator(pre=False)
@model_validator(mode="after")
@classmethod
def massage_irreps(cls, values):
"""
Check that the value given for ``irreps_hidden`` can be converted into an Irreps
@@ -874,7 +879,7 @@ def massage_irreps(cls, values):
ModelConfigBase._check_grouped(values)

# Now deal with irreps
irreps = values["irreps_hidden"]
irreps = values.irreps_hidden
# First see if this string should be converted into a dict
if isinstance(irreps, str):
if ":" in irreps:
@@ -923,7 +928,7 @@ def massage_irreps(cls, values):
except ValueError:
raise ValueError(f"Couldn't parse irreps dict: {orig_irreps}")

values["irreps_hidden"] = irreps
values.irreps_hidden = irreps
return values

def _build(self, mtenn_params={}):
@@ -994,7 +999,7 @@ class ViSNetModelConfig(ModelConfigBase):
given in PyG.
"""

model_type: ModelType = Field(ModelType.visnet, const=True)
model_type: Literal[ModelType.visnet] = ModelType.visnet
lmax: int = Field(1, description="The maximum degree of the spherical harmonics.")
vecnorm_type: str | None = Field(
None, description="The type of normalization to apply to the vectors."
@@ -1041,7 +1046,8 @@ class ViSNetModelConfig(ModelConfigBase):
),
)

@root_validator(pre=False)
@model_validator(mode="after")
@classmethod
def validate(cls, values):
"""
Check that ``atomref`` and ``max_z`` agree.
@@ -1050,10 +1056,10 @@ def validate(cls, values):
ModelConfigBase._check_grouped(values)

# Make sure atomref length is correct (this is required by PyG)
atomref = values["atomref"]
if (atomref is not None) and (len(atomref) != values["max_z"]):
atomref = values.atomref
if (atomref is not None) and (len(atomref) != values.max_z):
raise ValueError(
f"atomref length must match max_z. (Expected {values['max_z']}, got {len(atomref)})"
f"atomref length must match max_z. (Expected {values.max_z}, got {len(atomref)})"
)

return values