Skip to content

Commit

Permalink
Merge branch 'add-conversion-utils' of github.com:choderalab/mtenn in…
Browse files Browse the repository at this point in the history
…to add-conversion-utils
  • Loading branch information
fyng committed Jan 31, 2024
2 parents 2c9bb99 + 017460d commit e8fe374
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 7 deletions.
4 changes: 2 additions & 2 deletions mtenn/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -793,8 +793,8 @@ class ViSNetModelConfig(ModelConfigBase):
"sum",
description="The type of reduction operation to apply. ['sum', 'mean']"
)
mean: float | None = Field(0.0, description="The mean of the output distribution.")
std: float | None = Field(1.0, description="The standard deviation of the output distribution.")
mean: float = Field(0.0, description="The mean of the output distribution.")
std: float = Field(1.0, description="The standard deviation of the output distribution.")
derivative: bool = Field(
False,
description="Whether to compute the derivative of the output with respect to the positions."
Expand Down
7 changes: 2 additions & 5 deletions mtenn/conversion_utils/visnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,15 +37,12 @@ def forward(self, x):
class ViSNet(torch.nn.Module):
def __init__(self, *args, model=None, **kwargs):
super().__init__()
## If no model is passed, construct default SchNet model, otherwise copy
## If no model is passed, construct default ViSNet model, otherwise copy
## all parameters and weights over
if model is None:
self.visnet = PygVisNet(*args, **kwargs)
else:
try:
atomref = model.atomref.weight.detach().clone()
except AttributeError:
atomref = None
atomref = model.prior_model.atomref.weight.detach().clone()
model_params = (
model.lmax,
model.vecnorm_type,
Expand Down

0 comments on commit e8fe374

Please sign in to comment.