Skip to content

Commit

Permalink
atomref should match max_z
Browse files Browse the repository at this point in the history
  • Loading branch information
fyng committed Jan 31, 2024
1 parent c083277 commit 2c9bb99
Showing 1 changed file with 3 additions and 4 deletions.
7 changes: 3 additions & 4 deletions mtenn/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -786,8 +786,7 @@ class ViSNetModelConfig(ModelConfigBase):
atomref: list[float] | None = Field(
None,
description=(
"Reference values for single-atom properties. Should have length of 100 to "
"match with PyG."
"Reference values for single-atom properties. Should have length max_z"
)
)
reduce_op: str = Field(
Expand All @@ -808,8 +807,8 @@ def validate(cls, values):

# Make sure atomref length is correct (this is required by PyG)
atomref = values["atomref"]
if (atomref is not None) and (len(atomref) != 100):
raise ValueError(f"atomref must be length 100 (got {len(atomref)})")
if (atomref is not None) and (len(atomref) != values["max_z"]):
raise ValueError(f"atomref length must match max_z. (Expected {values['max_z']}, got {len(atomref)})")

return values

Expand Down

0 comments on commit 2c9bb99

Please sign in to comment.