diff --git a/mtenn/conversion_utils/e3nn.py b/mtenn/conversion_utils/e3nn.py
index 7f303fe..9f338ef 100644
--- a/mtenn/conversion_utils/e3nn.py
+++ b/mtenn/conversion_utils/e3nn.py
@@ -17,23 +17,26 @@ def __init__(self, *args, model=None, **kwargs):
             super(E3NN, self).__init__(*args, **kwargs)
             self.model_parameters = kwargs
         else:
-            # this will need changing to include model features of e3nn
-            atomref = model.atomref.weight.detach().clone()
-            model_params = (
-                model.hidden_channels,
-                model.num_filters,
-                model.num_interactions,
-                model.num_gaussians,
-                model.cutoff,
-                model.max_num_neighbors,
-                model.readout,
-                model.dipole,
-                model.mean,
-                model.std,
-                atomref,
-            )
-            super(E3NN, self).__init__(*model_params)
-            self.model_parameters = model_params
+            model_kwargs = {
+                "irreps_in": model.irreps_in,
+                "irreps_hidden": model.irreps_hidden,
+                "irreps_out": model.irreps_out,
+                "irreps_node_attr": model.irreps_node_attr,
+                "irreps_edge_attr": model.irreps_edge_attr,
+                "layers": len(model.layers) - 1,
+                "max_radius": model.max_radius,
+                "number_of_basis": model.number_of_basis,
+                "num_nodes": model.num_nodes,
+                "reduce_output": model.reduce_output,
+            }
+            # These parameters take a bit more work to recover
+            # Use the last layer because it's guaranteed to be present and is just a Convolution
+            conv = model.layers[-1]
+            model_kwargs["radial_layers"] = len(conv.fc.hs) - 2
+            model_kwargs["radial_neurons"] = conv.fc.hs[1]
+            model_kwargs["num_neighbors"] = conv.num_neighbors
+            super(E3NN, self).__init__(**model_kwargs)
+            self.model_parameters = model_kwargs
             self.load_state_dict(model.state_dict())
 
     def forward(self, data):
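
For reference, a minimal usage sketch of the new copy-constructor path (not part of the patch). It assumes E3NN subclasses e3nn's gate_points_2101.Network (consistent with the attributes read above), that E3NN is importable from mtenn.conversion_utils.e3nn, and that the hyperparameter values below are arbitrary placeholders.

# Minimal usage sketch (assumptions: E3NN subclasses e3nn's
# gate_points_2101.Network and is importable from mtenn.conversion_utils.e3nn;
# all hyperparameter values are arbitrary placeholders).
from e3nn.nn.models.gate_points_2101 import Network

from mtenn.conversion_utils.e3nn import E3NN

# A pretrained (or freshly constructed) e3nn Network to copy from.
ref_model = Network(
    irreps_in="1x0e",
    irreps_hidden="10x0e + 10x1o + 10x2e",
    irreps_out="1x0e",
    irreps_node_attr="1x0e",
    irreps_edge_attr="1x0e + 1x1o + 1x2e",
    layers=3,
    max_radius=3.5,
    number_of_basis=10,
    radial_layers=1,
    radial_neurons=128,
    num_neighbors=25.0,
    num_nodes=100.0,
    reduce_output=True,
)

# New copy-constructor path from this diff: E3NN rebuilds the constructor
# kwargs from ref_model's attributes (recovering the radial-MLP shape from
# the final Convolution layer) and then copies the weights via load_state_dict.
wrapped = E3NN(model=ref_model)
assert wrapped.model_parameters["radial_neurons"] == 128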