From 66713e0b1ce848461c32a05891d47c85335d965d Mon Sep 17 00:00:00 2001 From: kaminow Date: Thu, 7 Dec 2023 14:09:01 -0500 Subject: [PATCH] Change model_type to a Field so it gets exported with the ModelConfig. --- mtenn/config.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/mtenn/config.py b/mtenn/config.py index 48c3577..0808a19 100644 --- a/mtenn/config.py +++ b/mtenn/config.py @@ -82,7 +82,7 @@ class CombinationConfig(StringEnum): class ModelConfigBase(BaseModel): - model_type: ClassVar[ModelType.INVALID] = ModelType.INVALID + model_type: ModelType = Field(ModelType.INVALID, const=True, allow_mutation=False) # Shared parameters for MTENN grouped: bool = Field(False, description="Model is a grouped (multi-pose) model.") @@ -172,6 +172,9 @@ class ModelConfigBase(BaseModel): ), ) + class Config: + validate_assignment = True + @abc.abstractmethod def _build(self, mtenn_params={}) -> mtenn.model.Model: ... @@ -275,7 +278,7 @@ class GATModelConfig(ModelConfigBase): "biases": bool, } - model_type: ClassVar[ModelType.GAT] = ModelType.GAT + model_type: ModelType = Field(ModelType.GAT, const=True) in_feats: int = Field( CanonicalAtomFeaturizer().feat_size(), @@ -479,7 +482,7 @@ class SchNetModelConfig(ModelConfigBase): given in PyG. """ - model_type: ClassVar[ModelType.schnet] = ModelType.schnet + model_type: ModelType = Field(ModelType.schnet, const=True) hidden_channels: int = Field(128, description="Hidden embedding size.") num_filters: int = Field( @@ -601,7 +604,7 @@ class E3NNModelConfig(ModelConfigBase): Class for constructing an e3nn ML model. """ - model_type: ClassVar[ModelType.e3nn] = ModelType.e3nn + model_type: ModelType = Field(ModelType.e3nn, const=True) num_atom_types: int = Field( 100,