Skip to content

Commit

Permalink
Fix computing observables without gradients
Browse files Browse the repository at this point in the history
  • Loading branch information
SimonBoothroyd committed Mar 21, 2024
1 parent 88d1509 commit 9972471
Showing 1 changed file with 10 additions and 7 deletions.
17 changes: 10 additions & 7 deletions smee/mm/_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,13 +250,16 @@ def _compute_observables(
with torch.enable_grad():
potential = smee.compute_energy(system, force_field, coords, box_vectors)

du_d_theta_subset = torch.autograd.grad(
potential,
[theta[i] for i in needs_grad],
[smee.utils.ones_like(1, potential)],
retain_graph=False,
allow_unused=True,
)
du_d_theta_subset = []

if len(needs_grad) > 0:
du_d_theta_subset = torch.autograd.grad(
potential,
[theta[i] for i in needs_grad],
[smee.utils.ones_like(1, potential)],
retain_graph=False,
allow_unused=True,
)

for idx, i in enumerate(needs_grad):
du_d_theta[i].append(du_d_theta_subset[idx].float())
Expand Down

0 comments on commit 9972471

Please sign in to comment.