Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement ViSNet #41

Merged
merged 43 commits into from
Feb 13, 2024
Merged
Show file tree
Hide file tree
Changes from 28 commits
Commits
Show all changes
43 commits
Select commit Hold shift + click to select a range
de57eda
add visnet template
fyng Jan 23, 2024
eadb44c
ViSNet outputs vector embedding
fyng Jan 24, 2024
f61d2d1
add visnet to ModelType
fyng Jan 25, 2024
bdcadfb
standardize naming for visnet
fyng Jan 25, 2024
6a0d111
add visnet to init
fyng Jan 25, 2024
9c9d96a
fix typo
fyng Jan 25, 2024
14864db
fix last layer of model
fyng Jan 25, 2024
81f8818
install packages from pytorch-nightly
fyng Jan 26, 2024
be05591
pip install pytorch geometric nightly instead
fyng Jan 26, 2024
7be95b7
try importing visnet, implement has_visnet_flag
fyng Jan 26, 2024
f78ed8c
add test for visnet
fyng Jan 29, 2024
4f8975f
VisNet import error handling for older PyG version
fyng Jan 29, 2024
6c15f04
change variable name
fyng Jan 29, 2024
549e8f7
add import guard to VisNet test
fyng Jan 29, 2024
efad5c3
create two test environment files
fyng Jan 29, 2024
25be213
update CI to test stable and nightly builds
fyng Jan 29, 2024
73765c4
fix CI
fyng Jan 29, 2024
f92491c
add pytest import
fyng Jan 29, 2024
eb81626
fix typo and style
fyng Jan 29, 2024
176fccc
Comments for Issue 42
fyng Jan 29, 2024
a277bee
comments for issue #42
fyng Jan 29, 2024
c083277
fix a typo
fyng Jan 29, 2024
03d0f4d
Fix a typo
fyng Jan 30, 2024
6833c21
Visnet mean, std cannot be None
fyng Jan 30, 2024
017460d
VisNet accepts atomref = none
fyng Jan 31, 2024
2c9bb99
atomref should match max_z
fyng Jan 31, 2024
e8fe374
Merge branch 'add-conversion-utils' of github.com:choderalab/mtenn in…
fyng Jan 31, 2024
48721fb
clean up todo
fyng Jan 31, 2024
ffa571c
bring in prior_model from PyG visnet
fyng Jan 31, 2024
ec3c0c2
fix indentation
fyng Jan 31, 2024
34c8ace
add a visnet test
fyng Feb 1, 2024
ba91938
add import warning to visnet import guard
fyng Feb 5, 2024
a406012
fix typo
fyng Feb 5, 2024
5f161a3
remove redundant visnet set_config test
fyng Feb 5, 2024
d14f65d
fix typo in docstring
fyng Feb 7, 2024
60aa245
fix visnet instantiation from pyg test
fyng Feb 7, 2024
c85b995
Update mtenn/tests/test_model_config.py
fyng Feb 9, 2024
ec8b6c2
minor changes
fyng Feb 9, 2024
4a53919
update visnet instantiation of pyg and test
fyng Feb 9, 2024
3458514
fix mtenn visnet instantiation from pyg visnet
fyng Feb 9, 2024
cd2086b
guard import
fyng Feb 9, 2024
d5c1b46
update docstrings
fyng Feb 12, 2024
3924d3f
update doc strings
fyng Feb 12, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions .github/workflows/CI.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,14 @@ defaults:

jobs:
test:
name: Test on ${{ matrix.os }}, Python ${{ matrix.python-version }}
name: Test on ${{ matrix.os }}, Python ${{ matrix.python-version }}, Env ${{ matrix.deps }}
runs-on: ${{ matrix.os }}
strategy:
fail-fast: false
matrix:
os: [macOS-latest, ubuntu-latest]
python-version: ["3.10", "3.11"]
deps: ["devtools/conda-envs/test_env.yaml", "devtools/conda-envs/test_env-nightly.yaml"]

steps:
- name: Checkout Repository
Expand All @@ -45,7 +46,7 @@ jobs:
- name: Setup Micromamba
uses: mamba-org/setup-micromamba@v1
with:
environment-file: devtools/conda-envs/test_env.yaml
environment-file: ${{ matrix.deps }}
environment-name: test
create-args: >-
python==${{ matrix.python-version }}
Expand All @@ -68,5 +69,5 @@ jobs:
with:
file: ./coverage.xml
flags: unittests
name: codecov-${{ matrix.os }}-py${{ matrix.python-version }}
name: codecov-${{ matrix.os }}-py${{ matrix.python-version }}-env${{ matrix.deps }}
fail_ci_if_error: false
24 changes: 24 additions & 0 deletions devtools/conda-envs/test_env-nightly.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
name: test
channels:
- conda-forge
dependencies:
- pip
- pytorch
- pytorch_cluster
- pytorch_scatter
- pytorch_sparse
- numpy
- h5py
- e3nn
- dgllife
- dgl
- rdkit
- ase
# testing dependencies
- pytest
- pytest-cov
- codecov
- pydantic >=1.10.8,<2.0.0a0

- pip:
- pyg-nightly
2 changes: 1 addition & 1 deletion devtools/conda-envs/test_env.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -18,4 +18,4 @@ dependencies:
- pytest
- pytest-cov
- codecov
- pydantic >=1.10.8,<2.0.0a0
- pydantic >=1.10.8,<2.0.0a0
126 changes: 124 additions & 2 deletions mtenn/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from pydantic import BaseModel, Field, root_validator
import random
from typing import Callable, ClassVar

import mtenn
import numpy as np
import torch
Expand Down Expand Up @@ -39,7 +38,7 @@ class ModelType(StringEnum):
schnet = "schnet"
e3nn = "e3nn"
INVALID = "INVALID"

visnet = "visnet"

class StrategyConfig(StringEnum):
"""
Expand Down Expand Up @@ -749,3 +748,126 @@ def _build(self, mtenn_params={}):
pred_readout=pred_readout,
comb_readout=comb_readout,
)


class ViSNetModelConfig(ModelConfigBase):
"""
Class for constructing a VisNet ML model. Default values here are the default values
given in PyG.
"""

model_type: ModelType = Field(ModelType.visnet, const=True)
lmax: int = Field(1, description="The maximum degree of the spherical harmonics.")
vecnorm_type: str | None = Field(
None, description="The type of normalization to apply to the vectors."
)
trainable_vecnorm: bool = Field(
False, description="Whether the normalization weights are trainable."
)
num_heads: int = Field(8, description="The number of attention heads.")
num_layers: int = Field(6, description="The number of layers in the network.")
hidden_channels: int = Field(
128, description="The number of hidden channels in the node embeddings."
)
num_rbf: int = Field(32, description="The number of radial basis functions.")
trainable_rbf: bool = Field(
False, description="Whether the radial basis function parameters are trainable."
)
max_z: int = Field(100, description="The maximum atomic numbers.")
cutoff: float = Field(5.0, description="The cutoff distance.")
max_num_neighbors: int = Field(
32,
description="The maximum number of neighbors considered for each atom."
)
vertex: bool = Field(
False,
description="Whether to use vertex geometric features."
)
atomref: list[float] | None = Field(
None,
description=(
"Reference values for single-atom properties. Should have length max_z"
)
)
reduce_op: str = Field(
"sum",
description="The type of reduction operation to apply. ['sum', 'mean']"
)
mean: float = Field(0.0, description="The mean of the output distribution.")
std: float = Field(1.0, description="The standard deviation of the output distribution.")
derivative: bool = Field(
False,
description="Whether to compute the derivative of the output with respect to the positions."
)

@root_validator(pre=False)
def validate(cls, values):
# Make sure the grouped stuff is properly assigned
ModelConfigBase._check_grouped(values)

# Make sure atomref length is correct (this is required by PyG)
atomref = values["atomref"]
if (atomref is not None) and (len(atomref) != values["max_z"]):
raise ValueError(f"atomref length must match max_z. (Expected {values['max_z']}, got {len(atomref)})")

return values



def _build(self, mtenn_params={}):
"""
Build an MTENN ViSNet Model from this config.

Parameters
----------
mtenn_params: dict
Dict giving the MTENN Readout. This will be passed by the `build` method in
the abstract base class

Returns
-------
mtenn.model.Model
MTENN ViSNet Model/GroupedModel
"""
# Create an MTENN ViSNet model from PyG ViSNet model

from mtenn.conversion_utils.visnet import HAS_VISNET
if HAS_VISNET:
from mtenn.conversion_utils import ViSNet

model = ViSNet(
lmax=self.lmax,
vecnorm_type=self.vecnorm_type,
trainable_vecnorm=self.trainable_vecnorm,
num_heads=self.num_heads,
num_layers=self.num_layers,
hidden_channels=self.hidden_channels,
num_rbf=self.num_rbf,
trainable_rbf=self.trainable_rbf,
max_z=self.max_z,
cutoff=self.cutoff,
max_num_neighbors=self.max_num_neighbors,
vertex=self.vertex,
reduce_op=self.reduce_op,
mean=self.mean,
std=self.std,
derivative=self.derivative,
atomref=self.atomref,
)
combination = mtenn_params.get("combination", None)
pred_readout = mtenn_params.get("pred_readout", None)
comb_readout = mtenn_params.get("comb_readout", None)

return ViSNet.get_model(
model=model,
grouped=self.grouped,
fix_device=True,
strategy=self.strategy,
combination=combination,
pred_readout=pred_readout,
comb_readout=comb_readout,
)

else:
raise ImportError("ViSNet not found. Is your PyG >=2.5.0? Refer to issue #42.")

7 changes: 6 additions & 1 deletion mtenn/conversion_utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,9 @@
from .gat import GAT
from .schnet import SchNet

__all__ = ["E3NN", "GAT", "SchNet"]
# refer to issue #42
from .visnet import HAS_VISNET
if HAS_VISNET:
from .visnet import ViSNet

__all__ = ["E3NN", "GAT", "SchNet", "ViSNet"]
Loading
Loading