diff --git a/mtenn/conversion_utils/visnet.py b/mtenn/conversion_utils/visnet.py index 67644d2..e1508aa 100644 --- a/mtenn/conversion_utils/visnet.py +++ b/mtenn/conversion_utils/visnet.py @@ -90,6 +90,9 @@ def forward(self, data): x = self.visnet.output_model.pre_reduce(x, v) x = x * self.visnet.std + if self.visnet.prior_model is not None: + x = self.visnet.prior_model(x, z) + return x def _get_representation(self):