Skip to content

Commit

Permalink
Fix how e3nn models are constructed when passed a ref model.
Browse files Browse the repository at this point in the history
  • Loading branch information
kaminow committed Feb 15, 2024
1 parent 3b08d3e commit 9dbfed6
Showing 1 changed file with 20 additions and 17 deletions.
37 changes: 20 additions & 17 deletions mtenn/conversion_utils/e3nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,23 +17,26 @@ def __init__(self, *args, model=None, **kwargs):
super(E3NN, self).__init__(*args, **kwargs)
self.model_parameters = kwargs
else:
# this will need changing to include model features of e3nn
atomref = model.atomref.weight.detach().clone()
model_params = (
model.hidden_channels,
model.num_filters,
model.num_interactions,
model.num_gaussians,
model.cutoff,
model.max_num_neighbors,
model.readout,
model.dipole,
model.mean,
model.std,
atomref,
)
super(E3NN, self).__init__(*model_params)
self.model_parameters = model_params
model_kwargs = {
"irreps_in": model.irreps_in,
"irreps_hidden": model.irreps_hidden,
"irreps_out": model.irreps_out,
"irreps_node_attr": model.irreps_node_attr,
"irreps_edge_attr": model.irreps_edge_attr,
"layers": len(model.layers) - 1,
"max_radius": model.max_radius,
"number_of_basis": model.number_of_basis,
"num_nodes": model.num_nodes,
"reduce_output": model.reduce_output,
}
# These need a bit of work to get
# Use last layer bc guaranteed to be present and is just a Convolution
conv = model.layers[-1]
model_kwargs["radial_layers"] = len(conv.fc.hs) - 2
model_kwargs["radial_neurons"] = conv.fc.hs[1]
model_kwargs["num_neighbors"] = conv.num_neighbors
super(E3NN, self).__init__(**model_kwargs)
self.model_parameters = model_kwargs
self.load_state_dict(model.state_dict())

def forward(self, data):
Expand Down

0 comments on commit 9dbfed6

Please sign in to comment.