From 2c9bb9980d8e5ebe0ec83df3981f8a3fd842a663 Mon Sep 17 00:00:00 2001 From: fyng Date: Wed, 31 Jan 2024 10:03:26 -0500 Subject: [PATCH] atomref should match max_z --- mtenn/config.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/mtenn/config.py b/mtenn/config.py index 0ef6f38..e55661e 100644 --- a/mtenn/config.py +++ b/mtenn/config.py @@ -786,8 +786,7 @@ class ViSNetModelConfig(ModelConfigBase): atomref: list[float] | None = Field( None, description=( - "Reference values for single-atom properties. Should have length of 100 to " - "match with PyG." + "Reference values for single-atom properties. Should have length max_z" ) ) reduce_op: str = Field( @@ -808,8 +807,8 @@ def validate(cls, values): # Make sure atomref length is correct (this is required by PyG) atomref = values["atomref"] - if (atomref is not None) and (len(atomref) != 100): - raise ValueError(f"atomref must be length 100 (got {len(atomref)})") + if (atomref is not None) and (len(atomref) != values["max_z"]): + raise ValueError(f"atomref length must match max_z. (Expected {values['max_z']}, got {len(atomref)})") return values