Skip to content

Commit

Permalink
Add support for computing dG gradient w.r.t solvent params (#121)
Browse files Browse the repository at this point in the history
  • Loading branch information
SimonBoothroyd authored Nov 10, 2024
1 parent a6ea409 commit c01d912
Show file tree
Hide file tree
Showing 4 changed files with 166 additions and 36 deletions.
124 changes: 111 additions & 13 deletions smee/mm/_fe.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,30 @@
_NM_TO_ANGSTROM = 10.0


def _extract_pure_solvent(
    force_field: smee.TensorForceField, output_dir: pathlib.Path
) -> tuple[smee.TensorSystem, torch.Tensor, torch.Tensor, float, float, torch.Tensor]:
    """Build a solvent-only system, coordinates and energies from FEP samples.

    The samples stored in ``output_dir`` are loaded with ``coord_state_idx=-1``,
    the atoms belonging to the first topology (treated here as the solute) are
    dropped from the coordinates, and the energy of what remains is evaluated
    with ``force_field``.

    Args:
        force_field: The force field used to evaluate the solvent energy. The
            device and dtype of its first potential's parameters are used when
            loading the samples.
        output_dir: The directory containing the FEP output to load.

    Returns:
        The solvent-only system, the solvent coordinates, the box vectors, beta,
        the pressure, and the energy of the solvent-only configurations.

    Raises:
        NotImplementedError: If the loaded system does not contain exactly two
            topologies, or the first topology is present more than once.
    """
    ref_params = force_field.potentials[0].parameters

    full_system, beta, pressure, _, _, full_xyz, box = _load_samples(
        output_dir, ref_params.device, ref_params.dtype, coord_state_idx=-1
    )

    # only a single copy of one solute plus one solvent topology is handled.
    if not (len(full_system.topologies) == 2 and full_system.n_copies[0] == 1):
        raise NotImplementedError("only single solute systems are supported.")

    solute, solvent = full_system.topologies

    # NOTE(review): assumes the solute atoms come first in each frame - confirm.
    solvent_xyz = full_xyz[:, solute.n_atoms :, :]

    solvent_system = smee.TensorSystem(
        [solvent], [full_system.n_copies[1]], is_periodic=True
    )
    solvent_energy = _compute_energy(solvent_system, force_field, solvent_xyz, box)

    return solvent_system, solvent_xyz, box, beta, pressure, solvent_energy


def generate_dg_solv_data(
solute: smee.TensorTopology,
solvent: smee.TensorTopology,
Expand Down Expand Up @@ -145,10 +169,15 @@ def _parameterize(
parmed.openmm.load_topology(prepared_system_a.topology.to_openmm()),
)

return absolv.runner.run_eq(
result = absolv.runner.run_eq(
config, prepared_system_a, prepared_system_b, "CUDA", output_dir, parallel=True
)

solvent_b_output = _extract_pure_solvent(force_field, output_dir / "solvent-b")
torch.save(solvent_b_output, output_dir / "solvent-b" / "pure.pt")

return result


def _uncorrelated_frames(length: int, g: float) -> list[int]:
"""Return the indices of frames that are un-correlated.
Expand Down Expand Up @@ -214,7 +243,10 @@ def _load_trajectory(


def _load_samples(
output_dir: pathlib.Path, device: str | torch.device, dtype: torch.dtype
output_dir: pathlib.Path,
device: str | torch.device,
dtype: torch.dtype,
coord_state_idx: int = 0,
) -> tuple[
smee.TensorSystem,
float,
Expand Down Expand Up @@ -256,10 +288,6 @@ def _load_samples(
u_kn = numpy.hstack([numpy.array(x) for x in output_table["u_kn"].to_pylist()])
u_kn_per_k = [u_kn[:, replica_to_state_idx == i] for i in range(len(u_kn))]

xyz_0, box_0 = _load_trajectory(
output_dir / "trajectories", system, replica_to_state_idx
)

n_uncorrelated = u_kn.shape[1] // u_kn.shape[0]

g = pymbar.timeseries.statistical_inefficiency_multiple(
Expand All @@ -270,9 +298,6 @@ def _load_samples(
)
uncorrelated_frames = _uncorrelated_frames(n_uncorrelated, g)

xyz_0 = xyz_0[uncorrelated_frames]
box_0 = box_0[uncorrelated_frames] if box_0 is not None else None

for state_idx, state_u_kn in enumerate(u_kn_per_k):
u_kn_per_k[state_idx] = state_u_kn[:, uncorrelated_frames]

Expand All @@ -281,17 +306,25 @@ def _load_samples(

n_expected_frames = int(u_kn.shape[1] // len(u_kn))

assert len(xyz_0) == n_expected_frames
assert box_0 is None or len(box_0) == n_expected_frames
if coord_state_idx < 0:
coord_state_idx = len(u_kn) + coord_state_idx

xyz_i, box_i = _load_trajectory(
output_dir / "trajectories", system, replica_to_state_idx, coord_state_idx
)
xyz_i = xyz_i[uncorrelated_frames]
box_i = box_i[uncorrelated_frames] if box_i is not None else None
assert len(xyz_i) == n_expected_frames
assert box_i is None or len(box_i) == n_expected_frames

return (
system.to(device),
beta,
pressure,
torch.tensor(u_kn, device=device, dtype=dtype),
torch.tensor(n_k, device=device),
torch.tensor(xyz_0, device=device, dtype=dtype),
torch.tensor(box_0, device=device, dtype=dtype) if box_0 is not None else None,
torch.tensor(xyz_i, device=device, dtype=dtype),
torch.tensor(box_i, device=device, dtype=dtype) if box_i is not None else None,
)


Expand Down Expand Up @@ -345,6 +378,26 @@ def compute_dg_and_grads(
return smee.utils.tensor_like(dg, force_field.potentials[0].parameters), grads


def compute_grads_solvent(
    force_field: smee.TensorForceField,
    theta: tuple[torch.Tensor, ...],
    output_dir: pathlib.Path,
) -> tuple[torch.Tensor, ...]:
    """Compute the gradient of the mean pure solvent energy w.r.t. ``theta``.

    The solvent-only system, coordinates and box vectors are read from the
    ``pure.pt`` file in ``output_dir``, the energy of each frame is re-evaluated
    with ``force_field``, and the mean energy is differentiated.

    Args:
        force_field: The force field used to evaluate the energies.
        theta: The tensors to differentiate with respect to.
        output_dir: The directory containing the ``pure.pt`` file.

    Returns:
        The gradient of the mean energy with respect to each tensor in
        ``theta``, or an empty tuple if ``theta`` is empty.
    """
    device = force_field.potentials[0].parameters.device

    # NOTE(review): ``torch.load`` unpickles arbitrary objects - only point this
    # at ``pure.pt`` files that were produced locally.
    system, xyz, box, *_ = torch.load(output_dir / "pure.pt")

    # ``to`` returns the converted object rather than modifying it in place, so
    # the result must be re-bound or the loaded data stays on its saved device.
    system = system.to(device)
    xyz = xyz.to(device)
    box = box.to(device) if box is not None else None

    grads = ()

    if len(theta) > 0:
        with torch.enable_grad():
            energy = _compute_energy(system, force_field, xyz, box)
            grads = torch.autograd.grad(energy.mean(), theta)

    return grads


def reweight_dg_and_grads(
force_field: smee.TensorForceField,
theta: tuple[torch.Tensor, ...],
Expand Down Expand Up @@ -392,3 +445,48 @@ def reweight_dg_and_grads(
grads = torch.autograd.grad((energy_0 * weights).sum(), theta)

return smee.utils.tensor_like(dg, energy_0), grads, n_eff


def reweight_grads_solvent(
    force_field: smee.TensorForceField,
    theta: tuple[torch.Tensor, ...],
    output_dir: pathlib.Path,
) -> tuple[tuple[torch.Tensor, ...], float]:
    """Compute the re-weighted gradient of the pure solvent energy w.r.t. ``theta``.

    The solvent-only frames and their original energies are read from the
    ``pure.pt`` file in ``output_dir``. The energies are re-evaluated with
    ``force_field``, MBAR weights are computed for the new state (which has no
    samples of its own), and the weighted sum of the new energies is
    differentiated.

    Args:
        force_field: The force field to re-weight to.
        theta: The tensors to differentiate with respect to.
        output_dir: The directory containing the ``pure.pt`` file.

    Returns:
        The gradient of the weighted energy with respect to each tensor in
        ``theta`` (an empty tuple if ``theta`` is empty), and the minimum
        effective sample number reported by MBAR.
    """
    import pymbar

    device = force_field.potentials[0].parameters.device

    # NOTE(review): ``torch.load`` unpickles arbitrary objects - only point this
    # at ``pure.pt`` files that were produced locally.
    system, xyz, box, beta, pressure, energy_old = torch.load(output_dir / "pure.pt")

    # ``to`` returns the converted object rather than modifying it in place, so
    # the result must be re-bound or the loaded data stays on its saved device.
    system = system.to(device)
    xyz = xyz.to(device)
    box = box.to(device) if box is not None else None
    energy_old = energy_old.to(device)

    with torch.enable_grad():
        energy_new = _compute_energy(system, force_field, xyz, box)

    # reduced potentials are detached: only ``energy_new`` carries the graph.
    u_old = energy_old.detach().clone() * beta
    u_new = energy_new.detach().clone() * beta

    if pressure is not None:
        # the pV contribution is the same for both states as the frames are shared.
        pv_reduced = pressure * torch.det(box) * beta

        u_old += pv_reduced
        u_new += pv_reduced

    u_kn = numpy.stack([u_old.cpu().numpy(), u_new.cpu().numpy()])
    # all frames were sampled from the old state; the new state has no samples.
    n_k = numpy.array([len(u_old), 0])

    mbar = pymbar.MBAR(
        u_kn,
        n_k,
        solver_protocol=[{"method": "adaptive", "options": {"min_sc_iter": 0}}],
    )

    n_eff = mbar.compute_effective_sample_number().min().item()

    # column 1 holds the per-frame weights of the new (un-sampled) state.
    weights = smee.utils.tensor_like(mbar.W_nk[:, 1], energy_new)
    grads = ()

    if len(theta) > 0:
        grads = torch.autograd.grad((energy_new * weights).sum(), theta)

    return grads, n_eff
31 changes: 22 additions & 9 deletions smee/mm/_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -616,7 +616,7 @@ def reweight_ensemble_averages(
class _ComputeDGSolv(torch.autograd.Function):
@staticmethod
def forward(ctx, kwargs, *theta: torch.Tensor):
from smee.mm._fe import compute_dg_and_grads
from smee.mm._fe import compute_dg_and_grads, compute_grads_solvent

force_field = _unpack_force_field(
theta,
Expand All @@ -638,11 +638,19 @@ def forward(ctx, kwargs, *theta: torch.Tensor):
force_field, theta_grad, kwargs["fep_dir"] / "solvent-b"
)

if (kwargs["fep_dir"] / "solvent-a" / "pure.pt").exists():
raise NotImplementedError("solvent-a is expected to be vacuum")

dg_solv_b_d_theta = compute_grads_solvent(
force_field, theta_grad, kwargs["fep_dir"] / "solvent-b"
)

dg = dg_a - dg_b
dg_d_theta = [None] * len(theta)

for grad_idx, orig_idx in enumerate(needs_grad):
dg_d_theta[orig_idx] = dg_d_theta_b[grad_idx] - dg_d_theta_a[grad_idx]
dg_d_theta[orig_idx] -= dg_solv_b_d_theta[grad_idx]

ctx.save_for_backward(*dg_d_theta)

Expand All @@ -659,7 +667,7 @@ def backward(ctx, *grad_outputs):
class _ReweightDGSolv(torch.autograd.Function):
@staticmethod
def forward(ctx, kwargs, *theta: torch.Tensor):
from smee.mm._fe import reweight_dg_and_grads
from smee.mm._fe import reweight_dg_and_grads, reweight_grads_solvent

force_field = _unpack_force_field(
theta,
Expand All @@ -684,15 +692,23 @@ def forward(ctx, kwargs, *theta: torch.Tensor):
force_field, theta_grad, kwargs["fep_dir"] / "solvent-b"
)

if (kwargs["fep_dir"] / "solvent-a" / "pure.pt").exists():
raise NotImplementedError("solvent-a is expected to be vacuum")

dg_solv_b_d_theta, n_effective_solv = reweight_grads_solvent(
force_field, theta_grad, kwargs["fep_dir"] / "solvent-b"
)

dg = -dg_a + dg_0 + dg_b
dg_d_theta = [None] * len(theta)

for grad_idx, orig_idx in enumerate(needs_grad):
dg_d_theta[orig_idx] = dg_d_theta_b[grad_idx] - dg_d_theta_a[grad_idx]
dg_d_theta[orig_idx] -= dg_solv_b_d_theta[grad_idx]

ctx.save_for_backward(*dg_d_theta)

return dg, min(n_effective_a, n_effective_b)
return dg, min(n_effective_a, n_effective_b, n_effective_solv)

@staticmethod
def backward(ctx, *grad_outputs):
Expand All @@ -708,9 +724,8 @@ def compute_dg_solv(
"""Computes ∆G_solv from existing FEP data.
Notes:
Currently the gradient of the pure solvent is not computed. This will mean the
gradient w.r.t. water parameters currently will be incorrect when using this
to compute hydration free energies.
It is assumed that FEP data was generated using the same force field as
``force_field``, and using ``generate_dg_solv_data``.
Args:
force_field: The force field used to generate the FEP data.
Expand Down Expand Up @@ -743,9 +758,7 @@ def reweight_dg_solv(
"""Computes ∆G_solv by re-weighting existing FEP data.
Notes:
Currently the gradient of the pure solvent is not computed. This will mean the
gradient w.r.t. water parameters currently will be incorrect when using this
to compute hydration free energies.
It is assumed that FEP data was generated using ``generate_dg_solv_data``.
Args:
force_field: The force field to reweight to.
Expand Down
30 changes: 18 additions & 12 deletions smee/tests/mm/test_fe.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,21 +31,21 @@ def load_systems(solute: str, solvent: str):

@pytest.mark.fe
def test_fe_ops(tmp_cwd):
# taken from a run on commit 7915d1e323318d2314a8b0322e7f44968c660c21
# taken from a run on commit ec3d272b466f761ed838e16a5ba7b97ceadc463b
expected_dg = torch.tensor(-3.8262).double()
expected_dg_dtheta = torch.tensor(
[
[1.0288e01],
[1.3976e01],
[2.6423e01],
[9.1453e00],
[9.0158e00],
[9.5534e00],
[1.0414e01],
[1.1257e01],
[-4.0618e00],
[5.0233e03],
[-1.3574e03],
[10.2679],
[13.3933],
[25.3670],
[9.3747],
[9.3279],
[9.1520],
[10.5614],
[9.6908],
[-4.4326],
[-17.3971],
[-38.5407],
]
).double()

Expand All @@ -62,11 +62,17 @@ def test_fe_ops(tmp_cwd):
dg = smee.mm.compute_dg_solv(ff, output_dir)
dg_dtheta = torch.autograd.grad(dg, params)[0]

print("dg COMP", dg, flush=True)
print("dg_dtheta COMP", dg_dtheta, dg, flush=True)

assert dg == pytest.approx(expected_dg, abs=0.5)
assert dg_dtheta == pytest.approx(expected_dg_dtheta, rel=1.1)

dg, n_eff = smee.mm.reweight_dg_solv(ff, output_dir, dg)
dg_dtheta = torch.autograd.grad(dg, params)[0]

print("dg REWEIGHT", dg, flush=True)
print("dg_dtheta REWEIGHT", dg_dtheta, dg, flush=True)

assert dg == pytest.approx(expected_dg, abs=0.5)
assert dg_dtheta == pytest.approx(expected_dg_dtheta, rel=1.1)
17 changes: 15 additions & 2 deletions smee/tests/mm/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -448,12 +448,16 @@ def test_compute_dg_solv(mocker, tmp_path, mock_argon_tensors):
(torch.tensor(4.0).double(), (torch.tensor([[5.0, 6.0]]).double(),)),
],
)
mocker.patch(
"smee.mm._fe.compute_grads_solvent",
side_effect=[(torch.tensor([[8.0, 9.0]]).double(),)],
)

dg = smee.mm.compute_dg_solv(tensor_ff, tmp_path)
dg_dtheta = torch.autograd.grad(dg, params)[0]

assert torch.isclose(dg, torch.tensor(-3.0).double())
assert torch.allclose(dg_dtheta, torch.tensor([[3.0, 3.0]]).double())
assert torch.allclose(dg_dtheta, torch.tensor([[-5.0, -6.0]]).double())


def test_reweight_dg_solv(mocker, tmp_path, mock_argon_tensors):
Expand All @@ -469,14 +473,18 @@ def test_reweight_dg_solv(mocker, tmp_path, mock_argon_tensors):
(torch.tensor(5.0).double(), (torch.tensor([[6.0, 7.0]]).double(),), 8.0),
],
)
mocker.patch(
"smee.mm._fe.reweight_grads_solvent",
side_effect=[((torch.tensor([[9.0, 10.0]]).double(),), 11.0)],
)

dg_0 = torch.tensor(-3.0).double()

dg, n_eff = smee.mm.reweight_dg_solv(tensor_ff, tmp_path, dg_0, 3)
dg_dtheta = torch.autograd.grad(dg, params)[0]

assert torch.isclose(dg, torch.tensor(1.0).double())
assert torch.allclose(dg_dtheta, torch.tensor([[4.0, 4.0]]).double())
assert torch.allclose(dg_dtheta, torch.tensor([[-5.0, -6.0]]).double())

assert n_eff == 4.0

Expand All @@ -494,6 +502,11 @@ def test_reweight_dg_solv_error(mocker, tmp_path, mock_argon_tensors):
(torch.tensor(5.0).double(), (torch.tensor([[6.0, 7.0]]).double(),), 8.0),
],
)
mocker.patch(
"smee.mm._fe.reweight_grads_solvent",
side_effect=[((torch.tensor([[9.0, 10.0]]).double(),), 11.0)],
)

dg_0 = torch.tensor(-3.0).double()

with pytest.raises(smee.mm.NotEnoughSamplesError):
Expand Down

0 comments on commit c01d912

Please sign in to comment.