From 03d0f4de3941da9daabbe15f93956d5c57fcd054 Mon Sep 17 00:00:00 2001 From: Feiyang Huang <69661474+fyng@users.noreply.github.com> Date: Tue, 30 Jan 2024 17:18:12 -0500 Subject: [PATCH 1/3] Fix a typo Co-authored-by: kaminow <51923685+kaminow@users.noreply.github.com> --- mtenn/conversion_utils/visnet.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mtenn/conversion_utils/visnet.py b/mtenn/conversion_utils/visnet.py index 1872174..6079611 100644 --- a/mtenn/conversion_utils/visnet.py +++ b/mtenn/conversion_utils/visnet.py @@ -37,7 +37,7 @@ def forward(self, x): class ViSNet(torch.nn.Module): def __init__(self, *args, model=None, **kwargs): super().__init__() - ## If no model is passed, construct default SchNet model, otherwise copy + ## If no model is passed, construct default ViSNet model, otherwise copy ## all parameters and weights over if model is None: self.visnet = PygVisNet(*args, **kwargs) From 6833c21f347754ac9ed734f54a8f6028ad00fa66 Mon Sep 17 00:00:00 2001 From: Feiyang Huang <69661474+fyng@users.noreply.github.com> Date: Tue, 30 Jan 2024 17:18:49 -0500 Subject: [PATCH 2/3] Visnet mean, std cannot be None Co-authored-by: kaminow <51923685+kaminow@users.noreply.github.com> --- mtenn/config.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mtenn/config.py b/mtenn/config.py index 0ef6f38..77d3994 100644 --- a/mtenn/config.py +++ b/mtenn/config.py @@ -794,8 +794,8 @@ class ViSNetModelConfig(ModelConfigBase): "sum", description="The type of reduction operation to apply. ['sum', 'mean']" ) - mean: float | None = Field(0.0, description="The mean of the output distribution.") - std: float | None = Field(1.0, description="The standard deviation of the output distribution.") + mean: float = Field(0.0, description="The mean of the output distribution.") + std: float = Field(1.0, description="The standard deviation of the output distribution.") derivative: bool = Field( False, description="Whether to compute the derivative of the output with respect to the positions." From 017460d0377d94af502b6a0d2f7b590dccd54fdc Mon Sep 17 00:00:00 2001 From: Feiyang Huang <69661474+fyng@users.noreply.github.com> Date: Wed, 31 Jan 2024 10:02:43 -0500 Subject: [PATCH 3/3] VisNet accepts atomref = none Co-authored-by: kaminow <51923685+kaminow@users.noreply.github.com> --- mtenn/conversion_utils/visnet.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/mtenn/conversion_utils/visnet.py b/mtenn/conversion_utils/visnet.py index 6079611..6b6df43 100644 --- a/mtenn/conversion_utils/visnet.py +++ b/mtenn/conversion_utils/visnet.py @@ -42,10 +42,7 @@ def __init__(self, *args, model=None, **kwargs): if model is None: self.visnet = PygVisNet(*args, **kwargs) else: - try: - atomref = model.atomref.weight.detach().clone() - except AttributeError: - atomref = None + atomref = model.prior_model.atomref.weight.detach().clone() model_params = ( model.lmax, model.vecnorm_type,