diff --git a/mtenn/conversion_utils/e3nn.py b/mtenn/conversion_utils/e3nn.py index 7f303fe..9f338ef 100644 --- a/mtenn/conversion_utils/e3nn.py +++ b/mtenn/conversion_utils/e3nn.py @@ -17,23 +17,26 @@ def __init__(self, *args, model=None, **kwargs): super(E3NN, self).__init__(*args, **kwargs) self.model_parameters = kwargs else: - # this will need changing to include model features of e3nn - atomref = model.atomref.weight.detach().clone() - model_params = ( - model.hidden_channels, - model.num_filters, - model.num_interactions, - model.num_gaussians, - model.cutoff, - model.max_num_neighbors, - model.readout, - model.dipole, - model.mean, - model.std, - atomref, - ) - super(E3NN, self).__init__(*model_params) - self.model_parameters = model_params + model_kwargs = { + "irreps_in": model.irreps_in, + "irreps_hidden": model.irreps_hidden, + "irreps_out": model.irreps_out, + "irreps_node_attr": model.irreps_node_attr, + "irreps_edge_attr": model.irreps_edge_attr, + "layers": len(model.layers) - 1, + "max_radius": model.max_radius, + "number_of_basis": model.number_of_basis, + "num_nodes": model.num_nodes, + "reduce_output": model.reduce_output, + } + # These need a bit of work to get + # Use last layer bc guaranteed to be present and is just a Convolution + conv = model.layers[-1] + model_kwargs["radial_layers"] = len(conv.fc.hs) - 2 + model_kwargs["radial_neurons"] = conv.fc.hs[1] + model_kwargs["num_neighbors"] = conv.num_neighbors + super(E3NN, self).__init__(**model_kwargs) + self.model_parameters = model_kwargs self.load_state_dict(model.state_dict()) def forward(self, data): diff --git a/mtenn/tests/test_e3nn.py b/mtenn/tests/test_e3nn.py new file mode 100644 index 0000000..1f927d8 --- /dev/null +++ b/mtenn/tests/test_e3nn.py @@ -0,0 +1,70 @@ +import pytest + +from e3nn.nn.models.gate_points_2101 import Network +from e3nn.o3 import Irreps +from mtenn.conversion_utils.e3nn import E3NN + + +@pytest.fixture +def e3nn_kwargs(): + return { + "irreps_in": "5x0e+2x1o", + "irreps_hidden": "10x0e+10x0o+1o+1e", + "irreps_out": "0e", + "irreps_node_attr": "0e", + "irreps_edge_attr": Irreps.spherical_harmonics(2), + "layers": 5, + "max_radius": 10, + "number_of_basis": 5, + "radial_layers": 5, + "radial_neurons": 32, + "num_neighbors": 10, + "num_nodes": 100, + "reduce_output": True, + } + + +def test_build_e3nn_directly_kwargs(e3nn_kwargs): + model = E3NN(**e3nn_kwargs) + + # Directly stored parameters + assert model.irreps_in == Irreps(e3nn_kwargs["irreps_in"]) + assert model.irreps_hidden == Irreps(e3nn_kwargs["irreps_hidden"]) + assert model.irreps_out == Irreps(e3nn_kwargs["irreps_out"]) + assert model.irreps_node_attr == Irreps(e3nn_kwargs["irreps_node_attr"]) + assert model.irreps_edge_attr == Irreps(e3nn_kwargs["irreps_edge_attr"]) + assert len(model.layers) == e3nn_kwargs["layers"] + 1 + assert model.max_radius == e3nn_kwargs["max_radius"] + assert model.number_of_basis == e3nn_kwargs["number_of_basis"] + assert model.num_nodes == e3nn_kwargs["num_nodes"] + assert model.reduce_output == e3nn_kwargs["reduce_output"] + + # Indirect ones + conv = model.layers[-1] + assert len(conv.fc.hs) - 2 == e3nn_kwargs["radial_layers"] + assert conv.fc.hs[1] == e3nn_kwargs["radial_neurons"] + assert conv.num_neighbors == e3nn_kwargs["num_neighbors"] + + +def test_build_e3nn_from_e3nn_network(e3nn_kwargs): + ref_model = Network(**e3nn_kwargs) + model = E3NN(model=ref_model) + + # Directly stored parameters + assert model.irreps_in == ref_model.irreps_in + assert model.irreps_hidden == ref_model.irreps_hidden + assert model.irreps_out == ref_model.irreps_out + assert model.irreps_node_attr == ref_model.irreps_node_attr + assert model.irreps_edge_attr == ref_model.irreps_edge_attr + assert len(model.layers) == len(ref_model.layers) + assert model.max_radius == ref_model.max_radius + assert model.number_of_basis == ref_model.number_of_basis + assert model.num_nodes == ref_model.num_nodes + assert model.reduce_output == ref_model.reduce_output + + # Indirect ones + ref_conv = ref_model.layers[-1] + conv = model.layers[-1] + assert len(conv.fc.hs) == len(ref_conv.fc.hs) + assert conv.fc.hs[1] == ref_conv.fc.hs[1] + assert conv.num_neighbors == ref_conv.num_neighbors