diff --git a/mtenn/config.py b/mtenn/config.py index 48c3577..0808a19 100644 --- a/mtenn/config.py +++ b/mtenn/config.py @@ -82,7 +82,7 @@ class CombinationConfig(StringEnum): class ModelConfigBase(BaseModel): - model_type: ClassVar[ModelType.INVALID] = ModelType.INVALID + model_type: ModelType = Field(ModelType.INVALID, const=True, allow_mutation=False) # Shared parameters for MTENN grouped: bool = Field(False, description="Model is a grouped (multi-pose) model.") @@ -172,6 +172,9 @@ class ModelConfigBase(BaseModel): ), ) + class Config: + validate_assignment = True + @abc.abstractmethod def _build(self, mtenn_params={}) -> mtenn.model.Model: ... @@ -275,7 +278,7 @@ class GATModelConfig(ModelConfigBase): "biases": bool, } - model_type: ClassVar[ModelType.GAT] = ModelType.GAT + model_type: ModelType = Field(ModelType.GAT, const=True) in_feats: int = Field( CanonicalAtomFeaturizer().feat_size(), @@ -479,7 +482,7 @@ class SchNetModelConfig(ModelConfigBase): given in PyG. """ - model_type: ClassVar[ModelType.schnet] = ModelType.schnet + model_type: ModelType = Field(ModelType.schnet, const=True) hidden_channels: int = Field(128, description="Hidden embedding size.") num_filters: int = Field( @@ -601,7 +604,7 @@ class E3NNModelConfig(ModelConfigBase): Class for constructing an e3nn ML model. """ - model_type: ClassVar[ModelType.e3nn] = ModelType.e3nn + model_type: ModelType = Field(ModelType.e3nn, const=True) num_atom_types: int = Field( 100,