Skip to content

Commit

Permalink
Change model_type to a Field so it gets exported with the ModelConfig.
Browse files Browse the repository at this point in the history
  • Loading branch information
kaminow committed Dec 7, 2023
1 parent ca5ac11 commit 66713e0
Showing 1 changed file with 7 additions and 4 deletions.
11 changes: 7 additions & 4 deletions mtenn/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ class CombinationConfig(StringEnum):


class ModelConfigBase(BaseModel):
model_type: ClassVar[ModelType.INVALID] = ModelType.INVALID
model_type: ModelType = Field(ModelType.INVALID, const=True, allow_mutation=False)

# Shared parameters for MTENN
grouped: bool = Field(False, description="Model is a grouped (multi-pose) model.")
Expand Down Expand Up @@ -172,6 +172,9 @@ class ModelConfigBase(BaseModel):
),
)

class Config:
validate_assignment = True

@abc.abstractmethod
def _build(self, mtenn_params={}) -> mtenn.model.Model:
...
Expand Down Expand Up @@ -275,7 +278,7 @@ class GATModelConfig(ModelConfigBase):
"biases": bool,
}

model_type: ClassVar[ModelType.GAT] = ModelType.GAT
model_type: ModelType = Field(ModelType.GAT, const=True)

in_feats: int = Field(
CanonicalAtomFeaturizer().feat_size(),
Expand Down Expand Up @@ -479,7 +482,7 @@ class SchNetModelConfig(ModelConfigBase):
given in PyG.
"""

model_type: ClassVar[ModelType.schnet] = ModelType.schnet
model_type: ModelType = Field(ModelType.schnet, const=True)

hidden_channels: int = Field(128, description="Hidden embedding size.")
num_filters: int = Field(
Expand Down Expand Up @@ -601,7 +604,7 @@ class E3NNModelConfig(ModelConfigBase):
Class for constructing an e3nn ML model.
"""

model_type: ClassVar[ModelType.e3nn] = ModelType.e3nn
model_type: ModelType = Field(ModelType.e3nn, const=True)

num_atom_types: int = Field(
100,
Expand Down

0 comments on commit 66713e0

Please sign in to comment.