Skip to content

Commit

Permalink
Merge pull request #49 from choderalab/fix-e3nn-ref-model
Browse files Browse the repository at this point in the history
Fix e3nn model building
  • Loading branch information
hmacdope authored Feb 15, 2024
2 parents 3b08d3e + 0d610d3 commit 6fcec40
Show file tree
Hide file tree
Showing 2 changed files with 90 additions and 17 deletions.
37 changes: 20 additions & 17 deletions mtenn/conversion_utils/e3nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,23 +17,26 @@ def __init__(self, *args, model=None, **kwargs):
super(E3NN, self).__init__(*args, **kwargs)
self.model_parameters = kwargs
else:
# this will need changing to include model features of e3nn
atomref = model.atomref.weight.detach().clone()
model_params = (
model.hidden_channels,
model.num_filters,
model.num_interactions,
model.num_gaussians,
model.cutoff,
model.max_num_neighbors,
model.readout,
model.dipole,
model.mean,
model.std,
atomref,
)
super(E3NN, self).__init__(*model_params)
self.model_parameters = model_params
model_kwargs = {
"irreps_in": model.irreps_in,
"irreps_hidden": model.irreps_hidden,
"irreps_out": model.irreps_out,
"irreps_node_attr": model.irreps_node_attr,
"irreps_edge_attr": model.irreps_edge_attr,
"layers": len(model.layers) - 1,
"max_radius": model.max_radius,
"number_of_basis": model.number_of_basis,
"num_nodes": model.num_nodes,
"reduce_output": model.reduce_output,
}
# These need a bit of work to get
# Use last layer bc guaranteed to be present and is just a Convolution
conv = model.layers[-1]
model_kwargs["radial_layers"] = len(conv.fc.hs) - 2
model_kwargs["radial_neurons"] = conv.fc.hs[1]
model_kwargs["num_neighbors"] = conv.num_neighbors
super(E3NN, self).__init__(**model_kwargs)
self.model_parameters = model_kwargs
self.load_state_dict(model.state_dict())

def forward(self, data):
Expand Down
70 changes: 70 additions & 0 deletions mtenn/tests/test_e3nn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
import pytest

from e3nn.nn.models.gate_points_2101 import Network
from e3nn.o3 import Irreps
from mtenn.conversion_utils.e3nn import E3NN


@pytest.fixture
def e3nn_kwargs():
return {
"irreps_in": "5x0e+2x1o",
"irreps_hidden": "10x0e+10x0o+1o+1e",
"irreps_out": "0e",
"irreps_node_attr": "0e",
"irreps_edge_attr": Irreps.spherical_harmonics(2),
"layers": 5,
"max_radius": 10,
"number_of_basis": 5,
"radial_layers": 5,
"radial_neurons": 32,
"num_neighbors": 10,
"num_nodes": 100,
"reduce_output": True,
}


def test_build_e3nn_directly_kwargs(e3nn_kwargs):
model = E3NN(**e3nn_kwargs)

# Directly stored parameters
assert model.irreps_in == Irreps(e3nn_kwargs["irreps_in"])
assert model.irreps_hidden == Irreps(e3nn_kwargs["irreps_hidden"])
assert model.irreps_out == Irreps(e3nn_kwargs["irreps_out"])
assert model.irreps_node_attr == Irreps(e3nn_kwargs["irreps_node_attr"])
assert model.irreps_edge_attr == Irreps(e3nn_kwargs["irreps_edge_attr"])
assert len(model.layers) == e3nn_kwargs["layers"] + 1
assert model.max_radius == e3nn_kwargs["max_radius"]
assert model.number_of_basis == e3nn_kwargs["number_of_basis"]
assert model.num_nodes == e3nn_kwargs["num_nodes"]
assert model.reduce_output == e3nn_kwargs["reduce_output"]

# Indirect ones
conv = model.layers[-1]
assert len(conv.fc.hs) - 2 == e3nn_kwargs["radial_layers"]
assert conv.fc.hs[1] == e3nn_kwargs["radial_neurons"]
assert conv.num_neighbors == e3nn_kwargs["num_neighbors"]


def test_build_e3nn_from_e3nn_network(e3nn_kwargs):
ref_model = Network(**e3nn_kwargs)
model = E3NN(model=ref_model)

# Directly stored parameters
assert model.irreps_in == ref_model.irreps_in
assert model.irreps_hidden == ref_model.irreps_hidden
assert model.irreps_out == ref_model.irreps_out
assert model.irreps_node_attr == ref_model.irreps_node_attr
assert model.irreps_edge_attr == ref_model.irreps_edge_attr
assert len(model.layers) == len(ref_model.layers)
assert model.max_radius == ref_model.max_radius
assert model.number_of_basis == ref_model.number_of_basis
assert model.num_nodes == ref_model.num_nodes
assert model.reduce_output == ref_model.reduce_output

# Indirect ones
ref_conv = ref_model.layers[-1]
conv = model.layers[-1]
assert len(conv.fc.hs) == len(ref_conv.fc.hs)
assert conv.fc.hs[1] == ref_conv.fc.hs[1]
assert conv.num_neighbors == ref_conv.num_neighbors

0 comments on commit 6fcec40

Please sign in to comment.