Add support for computing dG gradient w.r.t solvent params
SimonBoothroyd committed Nov 9, 2024
1 parent a6ea409 commit 392c7e3
Showing 2 changed files with 133 additions and 22 deletions.
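In short: the ∆G_solv estimators already carried gradients for the two solute legs, but the pure-solvent contribution was previously skipped, so gradients w.r.t. solvent (e.g. water) parameters were incomplete for hydration free energies. The new helpers estimate that term from the solvent-only frames written by the ``solvent-b`` leg. As a rough sketch of the estimators the added code appears to use (the notation is illustrative: U is the potential energy of the pure-solvent box, θ a force-field parameter, and W_n the MBAR weight of frame n in the target state):

    \frac{\partial G}{\partial \theta} = \Bigl\langle \frac{\partial U(\mathbf{x};\theta)}{\partial \theta} \Bigr\rangle_{\theta}   % sampled state, cf. compute_grads_solvent

    \frac{\partial G_{\mathrm{new}}}{\partial \theta} \approx \sum_{n} W_{n}\, \frac{\partial U_{\mathrm{new}}(\mathbf{x}_{n};\theta)}{\partial \theta}   % unsampled target state, cf. reweight_grads_solvent
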
124 changes: 111 additions & 13 deletions smee/mm/_fe.py
@@ -24,6 +24,30 @@
_NM_TO_ANGSTROM = 10.0


def _extract_pure_solvent(
force_field: smee.TensorForceField, output_dir: pathlib.Path
) -> tuple[smee.TensorSystem, torch.Tensor, torch.Tensor, float, float, torch.Tensor]:
device = force_field.potentials[0].parameters.device
dtype = force_field.potentials[0].parameters.dtype

system, beta, pressure, u_kn, n_k, xyz, box = _load_samples(
output_dir, device, dtype, coord_state_idx=-1
)

if len(system.topologies) != 2 or system.n_copies[0] != 1:
raise NotImplementedError("only single solute systems are supported.")

n_solute_atoms = system.topologies[0].n_atoms
xyz = xyz[:, n_solute_atoms:, :]

system = smee.TensorSystem(
[system.topologies[1]], [system.n_copies[1]], is_periodic=True
)
energy = _compute_energy(system, force_field, xyz, box)

return system, xyz, box, beta, pressure, energy


def generate_dg_solv_data(
solute: smee.TensorTopology,
solvent: smee.TensorTopology,
@@ -145,10 +169,15 @@ def _parameterize(
parmed.openmm.load_topology(prepared_system_a.topology.to_openmm()),
)

return absolv.runner.run_eq(
result = absolv.runner.run_eq(
config, prepared_system_a, prepared_system_b, "CUDA", output_dir, parallel=True
)

solvent_b_output = _extract_pure_solvent(force_field, output_dir / "solvent-b")
torch.save(solvent_b_output, output_dir / "solvent-b" / "pure.pt")

return result


def _uncorrelated_frames(length: int, g: float) -> list[int]:
"""Return the indices of frames that are un-correlated.
@@ -214,7 +243,10 @@ def _load_trajectory(


def _load_samples(
output_dir: pathlib.Path, device: str | torch.device, dtype: torch.dtype
output_dir: pathlib.Path,
device: str | torch.device,
dtype: torch.dtype,
coord_state_idx: int = 0,
) -> tuple[
smee.TensorSystem,
float,
Expand Down Expand Up @@ -256,10 +288,6 @@ def _load_samples(
u_kn = numpy.hstack([numpy.array(x) for x in output_table["u_kn"].to_pylist()])
u_kn_per_k = [u_kn[:, replica_to_state_idx == i] for i in range(len(u_kn))]

xyz_0, box_0 = _load_trajectory(
output_dir / "trajectories", system, replica_to_state_idx
)

n_uncorrelated = u_kn.shape[1] // u_kn.shape[0]

g = pymbar.timeseries.statistical_inefficiency_multiple(
@@ -270,9 +298,6 @@
)
uncorrelated_frames = _uncorrelated_frames(n_uncorrelated, g)

xyz_0 = xyz_0[uncorrelated_frames]
box_0 = box_0[uncorrelated_frames] if box_0 is not None else None

for state_idx, state_u_kn in enumerate(u_kn_per_k):
u_kn_per_k[state_idx] = state_u_kn[:, uncorrelated_frames]

@@ -281,17 +306,25 @@

n_expected_frames = int(u_kn.shape[1] // len(u_kn))

assert len(xyz_0) == n_expected_frames
assert box_0 is None or len(box_0) == n_expected_frames
if coord_state_idx < 0:
coord_state_idx = len(u_kn) + coord_state_idx

xyz_i, box_i = _load_trajectory(
output_dir / "trajectories", system, replica_to_state_idx, coord_state_idx
)
xyz_i = xyz_i[uncorrelated_frames]
box_i = box_i[uncorrelated_frames] if box_i is not None else None
assert len(xyz_i) == n_expected_frames
assert box_i is None or len(box_i) == n_expected_frames

return (
system.to(device),
beta,
pressure,
torch.tensor(u_kn, device=device, dtype=dtype),
torch.tensor(n_k, device=device),
torch.tensor(xyz_0, device=device, dtype=dtype),
torch.tensor(box_0, device=device, dtype=dtype) if box_0 is not None else None,
torch.tensor(xyz_i, device=device, dtype=dtype),
torch.tensor(box_i, device=device, dtype=dtype) if box_i is not None else None,
)


@@ -345,6 +378,26 @@ def compute_dg_and_grads(
return smee.utils.tensor_like(dg, force_field.potentials[0].parameters), grads


def compute_grads_solvent(
force_field: smee.TensorForceField,
theta: tuple[torch.Tensor, ...],
output_dir: pathlib.Path,
) -> tuple[torch.Tensor, ...]:
device = force_field.potentials[0].parameters.device

system, xyz, box, *_ = torch.load(output_dir / "pure.pt")
system.to(device)

grads = ()

if len(theta) > 0:
with torch.enable_grad():
energy = _compute_energy(system, force_field, xyz, box)
grads = torch.autograd.grad(energy.mean(), theta)

return grads


def reweight_dg_and_grads(
force_field: smee.TensorForceField,
theta: tuple[torch.Tensor, ...],
@@ -392,3 +445,48 @@ def reweight_dg_and_grads(
grads = torch.autograd.grad((energy_0 * weights).sum(), theta)

return smee.utils.tensor_like(dg, energy_0), grads, n_eff


def reweight_grads_solvent(
force_field: smee.TensorForceField,
theta: tuple[torch.Tensor, ...],
output_dir: pathlib.Path,
) -> tuple[tuple[torch.Tensor, ...], float]:
import pymbar

device = force_field.potentials[0].parameters.device

system, xyz, box, beta, pressure, energy_old = torch.load(output_dir / "pure.pt")
system.to(device)

u_old = energy_old.detach().clone() * beta

if pressure is not None:
u_old += pressure * torch.det(box) * beta

with torch.enable_grad():
energy_new = _compute_energy(system, force_field, xyz, box)

u_new = energy_new.detach().clone() * beta

if pressure is not None:
u_new += pressure * torch.det(box) * beta

u_kn = numpy.stack([u_old.cpu().numpy(), u_new.cpu().numpy()])
n_k = numpy.array([len(u_old), 0])

mbar = pymbar.MBAR(
u_kn,
n_k,
solver_protocol=[{"method": "adaptive", "options": {"min_sc_iter": 0}}],
)

n_eff = mbar.compute_effective_sample_number().min().item()

weights = smee.utils.tensor_like(mbar.W_nk[:, 1], energy_new)
grads = ()

if len(theta) > 0:
grads = torch.autograd.grad((energy_new * weights).sum(), theta)

return grads, n_eff
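
The two new helpers above share one autograd pattern: per-frame energies are recomputed with the force-field parameters attached to the graph, then differentiated either as a plain mean (compute_grads_solvent) or as an MBAR-weighted sum (reweight_grads_solvent). A minimal, runnable sketch of that pattern with a toy energy function — the names _toy_energy, sigma and epsilon are invented here and merely stand in for smee's _compute_energy and tensor parameters:

import torch


def _toy_energy(
    xyz: torch.Tensor, sigma: torch.Tensor, epsilon: torch.Tensor
) -> torch.Tensor:
    """Per-frame Lennard-Jones energy of a fictitious two-particle system."""
    r = torch.linalg.norm(xyz[:, 0, :] - xyz[:, 1, :], dim=-1)
    return 4.0 * epsilon * ((sigma / r) ** 12 - (sigma / r) ** 6)


xyz = torch.rand(100, 2, 3) * 3.0 + 1.0  # 100 "frames" of 2 particles
sigma = torch.tensor(1.0, requires_grad=True)
epsilon = torch.tensor(0.5, requires_grad=True)
theta = (sigma, epsilon)

with torch.enable_grad():
    energy = _toy_energy(xyz, *theta)

    # <dU/dtheta> over the sampled frames, as in compute_grads_solvent.
    grads = torch.autograd.grad(energy.mean(), theta, retain_graph=True)

    # Weighted variant, as in reweight_grads_solvent; here the weights are
    # uniform, whereas the real code takes them from MBAR's W_nk column for
    # the target (unsampled) state.
    weights = torch.full_like(energy, 1.0 / len(energy))
    grads_reweighted = torch.autograd.grad((energy * weights).sum(), theta)

print(grads, grads_reweighted)
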
31 changes: 22 additions & 9 deletions smee/mm/_ops.py
@@ -616,7 +616,7 @@ def reweight_ensemble_averages(
class _ComputeDGSolv(torch.autograd.Function):
@staticmethod
def forward(ctx, kwargs, *theta: torch.Tensor):
from smee.mm._fe import compute_dg_and_grads
from smee.mm._fe import compute_dg_and_grads, compute_grads_solvent

force_field = _unpack_force_field(
theta,
@@ -638,11 +638,19 @@ def forward(ctx, kwargs, *theta: torch.Tensor):
force_field, theta_grad, kwargs["fep_dir"] / "solvent-b"
)

if (kwargs["fep_dir"] / "solvent-a" / "pure.pt").exists():
raise NotImplementedError("solvent-a is expected to be vacuum")

dg_solv_b_d_theta = compute_grads_solvent(
force_field, theta_grad, kwargs["fep_dir"] / "solvent-b"
)

dg = dg_a - dg_b
dg_d_theta = [None] * len(theta)

for grad_idx, orig_idx in enumerate(needs_grad):
dg_d_theta[orig_idx] = dg_d_theta_b[grad_idx] - dg_d_theta_a[grad_idx]
dg_d_theta[orig_idx] += dg_solv_b_d_theta[grad_idx]

ctx.save_for_backward(*dg_d_theta)

Expand All @@ -659,7 +667,7 @@ def backward(ctx, *grad_outputs):
class _ReweightDGSolv(torch.autograd.Function):
@staticmethod
def forward(ctx, kwargs, *theta: torch.Tensor):
from smee.mm._fe import reweight_dg_and_grads
from smee.mm._fe import reweight_dg_and_grads, reweight_grads_solvent

force_field = _unpack_force_field(
theta,
@@ -684,15 +692,23 @@ def forward(ctx, kwargs, *theta: torch.Tensor):
force_field, theta_grad, kwargs["fep_dir"] / "solvent-b"
)

if (kwargs["fep_dir"] / "solvent-a" / "pure.pt").exists():
raise NotImplementedError("solvent-a is expected to be vacuum")

dg_solv_b_d_theta, n_effective_solv = reweight_grads_solvent(
force_field, theta_grad, kwargs["fep_dir"] / "solvent-b"
)

dg = -dg_a + dg_0 + dg_b
dg_d_theta = [None] * len(theta)

for grad_idx, orig_idx in enumerate(needs_grad):
dg_d_theta[orig_idx] = dg_d_theta_b[grad_idx] - dg_d_theta_a[grad_idx]
dg_d_theta[orig_idx] += dg_solv_b_d_theta[grad_idx]

ctx.save_for_backward(*dg_d_theta)

return dg, min(n_effective_a, n_effective_b)
return dg, min(n_effective_a, n_effective_b, n_effective_solv)

@staticmethod
def backward(ctx, *grad_outputs):
Expand All @@ -708,9 +724,8 @@ def compute_dg_solv(
"""Computes ∆G_solv from existing FEP data.
Notes:
Currently the gradient of the pure solvent is not computed. This will mean the
gradient w.r.t. water parameters currently will be incorrect when using this
to compute hydration free energies.
It is assumed that FEP data was generated using the same force field as
``force_field``, and using ``generate_dg_solv_data``.
Args:
force_field: The force field used to generate the FEP data.
@@ -743,9 +758,7 @@ def reweight_dg_solv(
"""Computes ∆G_solv by re-weighting existing FEP data.
Notes:
Currently the gradient of the pure solvent is not computed. This will mean the
gradient w.r.t. water parameters currently will be incorrect when using this
to compute hydration free energies.
It is assumed that FEP data was generated using ``generate_dg_solv_data``.
Args:
force_field: The force field to reweight to.
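
For context on the ``_ops.py`` side: ``_ComputeDGSolv`` and ``_ReweightDGSolv`` follow the usual ``torch.autograd.Function`` recipe of computing the value and its parameter gradients with non-differentiable code in ``forward``, stashing the gradients via ``ctx.save_for_backward``, and replaying them in ``backward``. A minimal toy sketch of that recipe — ``_PrecomputedGrad`` and its quadratic "pipeline" are invented for illustration, not taken from smee:

import torch


class _PrecomputedGrad(torch.autograd.Function):
    """Toy forward/backward pair mirroring the structure of _ComputeDGSolv."""

    @staticmethod
    def forward(ctx, *theta: torch.Tensor):
        # Pretend an external, non-differentiable pipeline produced both the
        # value and its gradients w.r.t. each theta tensor.
        value = sum((t**2).sum() for t in theta).detach()
        d_value_d_theta = [2.0 * t.detach() for t in theta]

        ctx.save_for_backward(*d_value_d_theta)
        return value

    @staticmethod
    def backward(ctx, grad_output):
        # Chain rule: scale each stored gradient by the upstream gradient.
        return tuple(grad_output * grad for grad in ctx.saved_tensors)


theta = (torch.randn(3, requires_grad=True), torch.randn(2, requires_grad=True))
value = _PrecomputedGrad.apply(*theta)
value.backward()
print([t.grad for t in theta])  # each entry equals 2 * theta[i]
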
