From b568f4c9cff1e0aff66bba6b7f15b468ffc22e65 Mon Sep 17 00:00:00 2001 From: SimonBoothroyd Date: Thu, 31 Oct 2024 08:19:18 -0400 Subject: [PATCH] Add initial ops to compute / reweight Gsolv (#117) --- .github/workflows/ci.yaml | 4 +- devtools/envs/base.yaml | 4 + pyproject.toml | 8 +- smee/mm/__init__.py | 6 + smee/mm/_fe.py | 374 ++++++++++++++++++++++++++++++++++++++ smee/mm/_mm.py | 30 +-- smee/mm/_ops.py | 172 +++++++++++++++++- smee/mm/_utils.py | 29 +++ smee/tests/conftest.py | 6 + smee/tests/mm/test_fe.py | 72 ++++++++ smee/tests/mm/test_mm.py | 6 +- smee/tests/mm/test_ops.py | 71 +++++++- 12 files changed, 744 insertions(+), 38 deletions(-) create mode 100644 smee/mm/_fe.py create mode 100644 smee/mm/_utils.py create mode 100644 smee/tests/mm/test_fe.py diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 3a1d133..f07922c 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -27,7 +27,9 @@ jobs: make test-examples make docs - # TODO: Remove this line once pydantic 1.0 support is dropped + # TODO: Remove this once pydantic 1.0 support is dropped + # We remove absolv as femto needs pydantic >=2 + mamba remove --name smee --yes "absolv" mamba install --name smee --yes "pydantic <2" make test diff --git a/devtools/envs/base.yaml b/devtools/envs/base.yaml index 253bc9a..9f57540 100644 --- a/devtools/envs/base.yaml +++ b/devtools/envs/base.yaml @@ -31,6 +31,10 @@ dependencies: - numpy - msgpack-python + # FE simulations + - mdtraj + - absolv >=1.0.1 + # Examples - jupyter - nbconvert diff --git a/pyproject.toml b/pyproject.toml index 2e4b390..2624277 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,7 +28,7 @@ select = ["B","C","E","F","W","B9"] convention = "google" [tool.coverage.run] -omit = ["**/tests/*"] +omit = ["**/tests/*", "smee/mm/_fe.py"] [tool.coverage.report] exclude_lines = [ @@ -39,3 +39,9 @@ exclude_lines = [ "if TYPE_CHECKING:", "if typing.TYPE_CHECKING:", ] + +[tool.pytest.ini_options] +markers = [ + "fe: run free energy regression tests", +] +addopts = "-m 'not fe'" diff --git a/smee/mm/__init__.py b/smee/mm/__init__.py index 363e97b..24ead1d 100755 --- a/smee/mm/__init__.py +++ b/smee/mm/__init__.py @@ -1,17 +1,23 @@ """Compute differentiable ensemble averages using OpenMM and SMEE.""" from smee.mm._config import GenerateCoordsConfig, MinimizationConfig, SimulationConfig +from smee.mm._fe import generate_dg_solv_data from smee.mm._mm import generate_system_coords, simulate from smee.mm._ops import ( NotEnoughSamplesError, + compute_dg_solv, compute_ensemble_averages, + reweight_dg_solv, reweight_ensemble_averages, ) from smee.mm._reporters import TensorReporter, tensor_reporter, unpack_frames __all__ = [ + "compute_dg_solv", "compute_ensemble_averages", + "generate_dg_solv_data", "generate_system_coords", + "reweight_dg_solv", "reweight_ensemble_averages", "simulate", "GenerateCoordsConfig", diff --git a/smee/mm/_fe.py b/smee/mm/_fe.py new file mode 100644 index 0000000..f5b10d5 --- /dev/null +++ b/smee/mm/_fe.py @@ -0,0 +1,374 @@ +"""Compute ddG from the output of ``femto`` / ``absolv``""" + +import logging +import pathlib +import pickle +import typing + +import numpy +import openff.toolkit +import openmm.unit +import torch + +import smee +import smee.converters +import smee.mm._utils +import smee.utils + +if typing.TYPE_CHECKING: + import absolv.config + +_LOGGER = logging.getLogger(__name__) + +_NM_TO_ANGSTROM = 10.0 + + +def generate_dg_solv_data( + solute: smee.TensorTopology, + solvent: smee.TensorTopology, + 
force_field: smee.TensorForceField, + temperature: openmm.unit.Quantity = 298.15 * openmm.unit.kelvin, + pressure: openmm.unit.Quantity = 1.0 * openmm.unit.atmosphere, + vacuum_protocol: typing.Optional["absolv.config.EquilibriumProtocol"] = None, + solvent_protocol: typing.Optional["absolv.config.EquilibriumProtocol"] = None, + n_solvent: int = 216, + output_dir: pathlib.Path | None = None, +): + """Run a solvation free energy calculation using ``absolv``, and saves the output + such that a differentiable free energy can be computed. + + Args: + solute: The solute topology. + solvent: The solvent topology. + force_field: The force field to parameterize the system with. + temperature: The temperature to simulate at. + pressure: The pressure to simulate at. + vacuum_protocol: The protocol to use for the vacuum phase. + solvent_protocol: The protocol to use for the solvent phase. + n_solvent: The number of solvent molecules to use. + output_dir: The directory to write the output FEP data to. + """ + import absolv.config + import absolv.runner + import femto.md.config + + output_dir = pathlib.Path.cwd() if output_dir is None else output_dir + + if vacuum_protocol is None: + vacuum_protocol = absolv.config.EquilibriumProtocol( + production_protocol=absolv.config.HREMDProtocol( + n_steps_per_cycle=500, + n_cycles=2000, + integrator=femto.md.config.LangevinIntegrator( + timestep=1.0 * openmm.unit.femtosecond + ), + trajectory_interval=1, + ), + lambda_sterics=absolv.config.DEFAULT_LAMBDA_STERICS_VACUUM, + lambda_electrostatics=absolv.config.DEFAULT_LAMBDA_ELECTROSTATICS_VACUUM, + ) + if solvent_protocol is None: + solvent_protocol = absolv.config.EquilibriumProtocol( + production_protocol=absolv.config.HREMDProtocol( + n_steps_per_cycle=500, + n_cycles=1000, + integrator=femto.md.config.LangevinIntegrator( + timestep=4.0 * openmm.unit.femtosecond + ), + trajectory_interval=1, + trajectory_enforce_pbc=True, + ), + lambda_sterics=absolv.config.DEFAULT_LAMBDA_STERICS_SOLVENT, + lambda_electrostatics=absolv.config.DEFAULT_LAMBDA_ELECTROSTATICS_SOLVENT, + ) + + config = absolv.config.Config( + temperature=temperature, + pressure=pressure, + alchemical_protocol_a=vacuum_protocol, + alchemical_protocol_b=solvent_protocol, + ) + + solute_mol = openff.toolkit.Molecule.from_rdkit( + smee.mm._utils.topology_to_rdkit(solute), + allow_undefined_stereo=True, + ) + solvent_mol = openff.toolkit.Molecule.from_rdkit( + smee.mm._utils.topology_to_rdkit(solvent), + allow_undefined_stereo=True, + ) + + system_config = absolv.config.System( + solutes={solute_mol.to_smiles(mapped=True): 1}, + solvent_a=None, + solvent_b={solvent_mol.to_smiles(mapped=True): n_solvent}, + ) + + topologies = { + "solvent-a": smee.TensorSystem([solute], [1], is_periodic=False), + "solvent-b": smee.TensorSystem( + [solute, solvent], [1, n_solvent], is_periodic=True + ), + } + pressures = { + "solvent-a": None, + "solvent-b": pressure.value_in_unit(openmm.unit.atmosphere), + } + + for phase, topology in topologies.items(): + state = { + "system": topology, + "temperature": temperature.value_in_unit(openmm.unit.kelvin), + "pressure": pressures[phase], + } + + (output_dir / phase).mkdir(exist_ok=True, parents=True) + (output_dir / phase / "system.pkl").write_bytes(pickle.dumps(state)) + + def _parameterize( + top, coords, phase: typing.Literal["solvent-a", "solvent-b"] + ) -> openmm.System: + return smee.converters.convert_to_openmm_system(force_field, topologies[phase]) + + prepared_system_a, prepared_system_b = absolv.runner.setup( + 
system_config, config, _parameterize + ) + return absolv.runner.run_eq( + config, prepared_system_a, prepared_system_b, "CUDA", output_dir, parallel=True + ) + + +def _uncorrelated_frames(length: int, g: float) -> list[int]: + """Return the indices of frames that are un-correlated. + + Args: + length: The total number of correlated frames. + g: The statistical inefficiency of the data. + + Returns: + The indices of un-correlated frames. + """ + indices = [] + n = 0 + + while int(round(n * g)) < length: + t = int(round(n * g)) + if n == 0 or t != indices[n - 1]: + indices.append(t) + n += 1 + + return indices + + +def _load_trajectory( + trajectory_dir: pathlib.Path, + system: smee.TensorSystem, + replica_to_state_idx: numpy.ndarray, + state_idx: int = 0, +) -> tuple[numpy.ndarray, numpy.ndarray | None]: + import mdtraj + + n_states = len(list(trajectory_dir.glob("r*.dcd"))) + + topology_omm = smee.converters.convert_to_openmm_topology(system) + topology_md = mdtraj.Topology.from_openmm(topology_omm) + + trajectories = [ + mdtraj.load(str(trajectory_dir / f"r{i}.dcd"), top=topology_md) + for i in range(n_states) + ] + state_idxs = (replica_to_state_idx.reshape(-1, n_states).T == state_idx).argmax( + axis=0 + ) + + xyz = numpy.stack( + [ + trajectories[traj_idx].xyz[frame_idx] * _NM_TO_ANGSTROM + for frame_idx, traj_idx in enumerate(state_idxs) + ] + ) + + if trajectories[0].unitcell_vectors is None: + return xyz, None + + box = numpy.stack( + [ + trajectories[traj_idx].unitcell_vectors[frame_idx] * _NM_TO_ANGSTROM + for frame_idx, traj_idx in enumerate(state_idxs) + ] + ) + + return xyz, box + + +def _load_samples( + output_dir: pathlib.Path, device: str | torch.device, dtype: torch.dtype +) -> tuple[ + smee.TensorSystem, + float, + float | None, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor | None, +]: + import pyarrow + import pymbar.timeseries + + state = pickle.loads((output_dir / "state.pkl").read_bytes()) + system: smee.TensorSystem = state["system"] + + temperature = state["temperature"] * openmm.unit.kelvin + + beta = 1.0 / (openmm.unit.MOLAR_GAS_CONSTANT_R * temperature) + beta = beta.value_in_unit(openmm.unit.kilocalorie_per_mole**-1) + + pressure = None + + if state["pressure"] is not None: + pressure = state["pressure"] * openmm.unit.atmosphere + pressure = (pressure * openmm.unit.AVOGADRO_CONSTANT_NA).value_in_unit( + openmm.unit.kilocalorie_per_mole / openmm.unit.angstrom**3 + ) + + with pyarrow.OSFile(str(output_dir / "samples.arrow"), "rb") as file: + with pyarrow.RecordBatchStreamReader(file) as reader: + output_table = reader.read_all() + + replica_to_state_idx = numpy.hstack( + [numpy.array(x) for x in output_table["replica_to_state_idx"].to_pylist()] + ) + + # group the data along axis 1 so that data sampled in the same state is grouped. + # this will let us more easily de-correlate the data. 
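+ # i.e. u_kn_per_k[i] below will contain only the samples that were generated while a replica occupied state i.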
+ u_kn = numpy.hstack([numpy.array(x) for x in output_table["u_kn"].to_pylist()]) + u_kn_per_k = [u_kn[:, replica_to_state_idx == i] for i in range(len(u_kn))] + + xyz_0, box_0 = _load_trajectory( + output_dir / "trajectories", system, replica_to_state_idx + ) + + n_uncorrelated = u_kn.shape[1] // u_kn.shape[0] + + g = pymbar.timeseries.statistical_inefficiency_multiple( + [ + u_kn_per_k[i][i, i * n_uncorrelated : (i + 1) * n_uncorrelated] + for i in range(len(u_kn)) + ] + ) + uncorrelated_frames = _uncorrelated_frames(n_uncorrelated, g) + + xyz_0 = xyz_0[uncorrelated_frames] + box_0 = box_0[uncorrelated_frames] if box_0 is not None else None + + for state_idx, state_u_kn in enumerate(u_kn_per_k): + u_kn_per_k[state_idx] = state_u_kn[:, uncorrelated_frames] + + u_kn = numpy.hstack(u_kn_per_k) + n_k = numpy.array([len(uncorrelated_frames)] * u_kn.shape[0]) + + n_expected_frames = int(u_kn.shape[1] // len(u_kn)) + + assert len(xyz_0) == n_expected_frames + assert box_0 is None or len(box_0) == n_expected_frames + + return ( + system.to(device), + beta, + pressure, + torch.tensor(u_kn, device=device, dtype=dtype), + torch.tensor(n_k, device=device), + torch.tensor(xyz_0, device=device, dtype=dtype), + torch.tensor(box_0, device=device, dtype=dtype) if box_0 is not None else None, + ) + + +def _compute_energy( + system: smee.TensorSystem, + ff: smee.TensorForceField, + xyz_0: torch.Tensor, + box_0: torch.Tensor, +) -> torch.Tensor: + if system.is_periodic: + energy_per_frame = [ + smee.compute_energy(system, ff, c, b) + for c, b in zip(xyz_0, box_0, strict=True) + ] + energy = torch.concat(energy_per_frame) + else: + energy = smee.compute_energy(system, ff, xyz_0, box_0) + + return energy + + +def compute_dg_and_grads( + force_field: smee.TensorForceField, + theta: tuple[torch.Tensor, ...], + output_dir: pathlib.Path, +) -> tuple[torch.Tensor, tuple[torch.Tensor, ...]]: + import pymbar + + device = force_field.potentials[0].parameters.device + dtype = force_field.potentials[0].parameters.dtype + + system, beta, _, u_kn, n_k, xyz_0, box_0 = _load_samples(output_dir, device, dtype) + assert (box_0 is not None) == system.is_periodic + + mbar = pymbar.MBAR(u_kn.detach().cpu().numpy(), n_k.detach().cpu().numpy()) + + f_i = mbar.compute_free_energy_differences()["Delta_f"][0, :] + dg = (f_i[-1] - f_i[0]) / beta + + with torch.enable_grad(): + energy = _compute_energy(system, force_field, xyz_0, box_0) + grads = () + + if len(theta) > 0: + grads = torch.autograd.grad(energy.mean(), theta) + + return smee.utils.tensor_like(dg, energy), grads + + +def reweight_dg_and_grads( + force_field: smee.TensorForceField, + theta: tuple[torch.Tensor, ...], + output_dir: pathlib.Path, +) -> tuple[torch.Tensor, tuple[torch.Tensor, ...], float]: + import pymbar + + device = force_field.potentials[0].parameters.device + dtype = force_field.potentials[0].parameters.dtype + + system, beta, pressure, u_kn, n_k, xyz_0, box_0 = _load_samples( + output_dir, device, dtype + ) + assert (box_0 is not None) == system.is_periodic + assert (box_0 is not None) == (pressure is not None) + + u_0_old = u_kn[0, : n_k[0]] + + with torch.enable_grad(): + energy_0 = _compute_energy(system, force_field, xyz_0, box_0) + + u_0_new = energy_0.detach().clone() * beta + + if pressure is not None: + u_0_new += pressure * torch.det(box_0) * beta + + u_kn = numpy.stack([u_0_old.cpu().numpy(), u_0_new.cpu().numpy()]) + n_k = numpy.array([n_k[0].cpu(), 0]) + + mbar = pymbar.MBAR(u_kn, n_k) + + n_eff = 
mbar.compute_effective_sample_number().min().item() + + f_i = mbar.compute_free_energy_differences()["Delta_f"][0, :] + dg = (f_i[-1] - f_i[0]) / beta + + weights = smee.utils.tensor_like(mbar.W_nk[:, 1], energy_0) + grads = () + + if len(theta) > 0: + grads = torch.autograd.grad((energy_0 * weights).sum(), theta) + + return smee.utils.tensor_like(dg, energy_0), grads, n_eff diff --git a/smee/mm/_mm.py b/smee/mm/_mm.py index 98d7e50..0d26e3b 100644 --- a/smee/mm/_mm.py +++ b/smee/mm/_mm.py @@ -12,13 +12,12 @@ import openmm.app import openmm.unit import torch -from rdkit import Chem -from rdkit.Chem import AllChem import smee import smee.converters import smee.mm._config import smee.mm._reporters +import smee.mm._utils _LOGGER = logging.getLogger("smee.mm") @@ -73,36 +72,11 @@ def _apply_hmr( idx_offset += topology.n_particles -def _topology_to_rdkit(topology: smee.TensorTopology) -> Chem.Mol: - """Convert a topology to an RDKit molecule.""" - mol = Chem.RWMol() - - for atomic_num, formal_charge in zip( - topology.atomic_nums, topology.formal_charges, strict=True - ): - atom = Chem.Atom(int(atomic_num)) - atom.SetFormalCharge(int(formal_charge)) - mol.AddAtom(atom) - - for bond_idxs, bond_order in zip( - topology.bond_idxs, topology.bond_orders, strict=True - ): - idx_a, idx_b = int(bond_idxs[0]), int(bond_idxs[1]) - mol.AddBond(idx_a, idx_b, Chem.BondType(bond_order)) - - mol = Chem.Mol(mol) - mol.UpdatePropertyCache() - - AllChem.EmbedMolecule(mol) - - return mol - - def _topology_to_xyz( topology: smee.TensorTopology, force_field: smee.TensorForceField | None ) -> str: """Convert a topology to an RDKit molecule.""" - mol = _topology_to_rdkit(topology) + mol = smee.mm._utils.topology_to_rdkit(topology) elements = [atom.GetSymbol() for atom in mol.GetAtoms()] coords = torch.tensor(mol.GetConformer().GetPositions()) diff --git a/smee/mm/_ops.py b/smee/mm/_ops.py index 88ec422..9388aaf 100644 --- a/smee/mm/_ops.py +++ b/smee/mm/_ops.py @@ -180,11 +180,9 @@ def _compute_frame_observables( return values volume = torch.det(box_vectors) - values.update({"volume": volume, "volume^2": volume**2}) total_mass = _compute_mass(system) - values["density"] = total_mass / volume * _DENSITY_CONVERSION if pressure is not None: @@ -613,3 +611,173 @@ def reweight_ensemble_averages( *avg_outputs, columns = _ReweightAverageOp.apply(kwargs, *tensors) return {column: avg for avg, column in zip(avg_outputs, columns, strict=True)} + + +class _ComputeDGSolv(torch.autograd.Function): + @staticmethod + def forward(ctx, kwargs, *theta: torch.Tensor): + from smee.mm._fe import compute_dg_and_grads + + force_field = _unpack_force_field( + theta, + kwargs["parameter_lookup"], + kwargs["attribute_lookup"], + kwargs["has_v_sites"], + kwargs["force_field"], + ) + + needs_grad = [ + i for i, v in enumerate(theta) if v is not None and v.requires_grad + ] + theta_grad = tuple(theta[i] for i in needs_grad) + + dg_a, dg_d_theta_a = compute_dg_and_grads( + force_field, theta_grad, kwargs["fep_dir"] / "solvent-a" + ) + dg_b, dg_d_theta_b = compute_dg_and_grads( + force_field, theta_grad, kwargs["fep_dir"] / "solvent-b" + ) + + dg = dg_a - dg_b + dg_d_theta = [None] * len(theta) + + for grad_idx, orig_idx in enumerate(needs_grad): + dg_d_theta[orig_idx] = dg_d_theta_b[grad_idx] - dg_d_theta_a[grad_idx] + + ctx.save_for_backward(*dg_d_theta) + + return dg + + @staticmethod + def backward(ctx, *grad_outputs): + dg_d_theta_0 = ctx.saved_tensors + + grads = [None if v is None else v * grad_outputs[0] for v in dg_d_theta_0] + return 
tuple([None] + grads) + + +class _ReweightDGSolv(torch.autograd.Function): + @staticmethod + def forward(ctx, kwargs, *theta: torch.Tensor): + from smee.mm._fe import reweight_dg_and_grads + + force_field = _unpack_force_field( + theta, + kwargs["parameter_lookup"], + kwargs["attribute_lookup"], + kwargs["has_v_sites"], + kwargs["force_field"], + ) + + dg_0 = kwargs["dg_0"] + + needs_grad = [ + i for i, v in enumerate(theta) if v is not None and v.requires_grad + ] + theta_grad = tuple(theta[i] for i in needs_grad) + + # new FF G - old FF G + dg_a, dg_d_theta_a, n_effective_a = reweight_dg_and_grads( + force_field, theta_grad, kwargs["fep_dir"] / "solvent-a" + ) + dg_b, dg_d_theta_b, n_effective_b = reweight_dg_and_grads( + force_field, theta_grad, kwargs["fep_dir"] / "solvent-b" + ) + + dg = -dg_a + dg_0 + dg_b + dg_d_theta = [None] * len(theta) + + for grad_idx, orig_idx in enumerate(needs_grad): + dg_d_theta[orig_idx] = dg_d_theta_b[grad_idx] - dg_d_theta_a[grad_idx] + + ctx.save_for_backward(*dg_d_theta) + + return dg, min(n_effective_a, n_effective_b) + + @staticmethod + def backward(ctx, *grad_outputs): + dg_d_theta_0 = ctx.saved_tensors + + grads = [None if v is None else v * grad_outputs[0] for v in dg_d_theta_0] + return tuple([None] + grads) + + +def compute_dg_solv( + force_field: smee.TensorForceField, fep_dir: pathlib.Path +) -> torch.Tensor: + """Computes ∆G_solv from existing FEP data. + + Notes: + Currently the gradient of the pure solvent is not computed. This will mean the + gradient w.r.t. water parameters currently will be incorrect when using this + to compute hydration free energies. + + Args: + force_field: The force field used to generate the FEP data. + fep_dir: The directory containing the FEP data. + + Returns: + ∆G_solv [kcal/mol]. + """ + + tensors, parameter_lookup, attribute_lookup, has_v_sites = _pack_force_field( + force_field + ) + + kwargs = { + "force_field": force_field, + "parameter_lookup": parameter_lookup, + "attribute_lookup": attribute_lookup, + "has_v_sites": has_v_sites, + "fep_dir": fep_dir, + } + return _ComputeDGSolv.apply(kwargs, *tensors) + + +def reweight_dg_solv( + force_field: smee.TensorForceField, + fep_dir: pathlib.Path, + dg_0: torch.Tensor, + min_samples: int = 50, +) -> tuple[torch.Tensor, float]: + """Computes ∆G_solv by re-weighting existing FEP data. + + Notes: + Currently the gradient of the pure solvent is not computed. This will mean the + gradient w.r.t. water parameters currently will be incorrect when using this + to compute hydration free energies. + + Args: + force_field: The force field to reweight to. + fep_dir: The directory containing the FEP data. + dg_0: ∆G_solv [kcal/mol] computed with the force field used to generate the + FEP data. + min_samples: The minimum number of effective samples required to re-weight. + + Raises: + NotEnoughSamplesError: If the number of effective samples is less than + ``min_samples``. + + Returns: + The re-weighted ∆G_solv [kcal/mol], and the minimum number of effective samples + between the two phases. 
+ """ + tensors, parameter_lookup, attribute_lookup, has_v_sites = _pack_force_field( + force_field + ) + + kwargs = { + "force_field": force_field, + "parameter_lookup": parameter_lookup, + "attribute_lookup": attribute_lookup, + "has_v_sites": has_v_sites, + "fep_dir": fep_dir, + "dg_0": dg_0, + } + + dg, n_eff = _ReweightDGSolv.apply(kwargs, *tensors) + + if n_eff < min_samples: + raise NotEnoughSamplesError + + return dg, n_eff diff --git a/smee/mm/_utils.py b/smee/mm/_utils.py new file mode 100644 index 0000000..265e711 --- /dev/null +++ b/smee/mm/_utils.py @@ -0,0 +1,29 @@ +from rdkit import Chem +from rdkit.Chem import AllChem + +import smee + + +def topology_to_rdkit(topology: smee.TensorTopology) -> Chem.Mol: + """Convert a topology to an RDKit molecule.""" + mol = Chem.RWMol() + + for atomic_num, formal_charge in zip( + topology.atomic_nums, topology.formal_charges, strict=True + ): + atom = Chem.Atom(int(atomic_num)) + atom.SetFormalCharge(int(formal_charge)) + mol.AddAtom(atom) + + for bond_idxs, bond_order in zip( + topology.bond_idxs, topology.bond_orders, strict=True + ): + idx_a, idx_b = int(bond_idxs[0]), int(bond_idxs[1]) + mol.AddBond(idx_a, idx_b, Chem.BondType(bond_order)) + + mol = Chem.Mol(mol) + mol.UpdatePropertyCache() + + AllChem.EmbedMolecule(mol) + + return mol diff --git a/smee/tests/conftest.py b/smee/tests/conftest.py index 408dbd3..2a5276a 100644 --- a/smee/tests/conftest.py +++ b/smee/tests/conftest.py @@ -17,6 +17,12 @@ _E = openff.units.unit.elementary_charge +@pytest.fixture +def tmp_cwd(tmp_path, monkeypatch) -> pathlib.Path: + monkeypatch.chdir(tmp_path) + yield tmp_path + + @pytest.fixture def test_data_dir() -> pathlib.Path: return pathlib.Path(__file__).parent / "data" diff --git a/smee/tests/mm/test_fe.py b/smee/tests/mm/test_fe.py new file mode 100644 index 0000000..76eb764 --- /dev/null +++ b/smee/tests/mm/test_fe.py @@ -0,0 +1,72 @@ +import pathlib + +import openff.interchange +import openff.toolkit +import pytest +import torch + +import smee.converters +import smee.mm + + +def load_systems(solute: str, solvent: str): + ff_off = openff.toolkit.ForceField("openff-2.0.0.offxml") + + solute_inter = openff.interchange.Interchange.from_smirnoff( + ff_off, + openff.toolkit.Molecule.from_smiles(solute).to_topology(), + ) + solvent_inter = openff.interchange.Interchange.from_smirnoff( + ff_off, + openff.toolkit.Molecule.from_smiles(solvent).to_topology(), + ) + solvent_inter.to_openmm_system() + + ff, (top_solute, top_solvent) = smee.converters.convert_interchange( + [solute_inter, solvent_inter] + ) + + return top_solute, top_solvent, ff + + +@pytest.mark.fe +def test_fe_ops(tmp_cwd): + # taken from a run on commit 7915d1e323318d2314a8b0322e7f44968c660c21 + expected_dg = torch.tensor(-3.8262).double() + expected_dg_dtheta = torch.tensor( + [ + [1.0288e01], + [1.3976e01], + [2.6423e01], + [9.1453e00], + [9.0158e00], + [9.5534e00], + [1.0414e01], + [1.1257e01], + [-4.0618e00], + [5.0233e03], + [-1.3574e03], + ] + ).double() + + top_solute, top_solvent, ff = load_systems("CCO", "O") + + output_dir = pathlib.Path("CCO") + output_dir.mkdir(parents=True, exist_ok=True) + + smee.mm.generate_dg_solv_data(top_solute, top_solvent, ff, output_dir=output_dir) + + params = ff.potentials_by_type["Electrostatics"].parameters + params.requires_grad_(True) + + dg = smee.mm.compute_dg_solv(ff, output_dir) + dg_dtheta = torch.autograd.grad(dg, params)[0] + + assert dg == pytest.approx(expected_dg, abs=0.5) + assert dg_dtheta == pytest.approx(expected_dg_dtheta, rel=1.1) 
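+ # reweighting the stored samples back onto the force field that generated them should reproduce the original dG and its gradient to within statistical noise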
+ + dg, n_eff = smee.mm.reweight_dg_solv(ff, output_dir, dg) + dg_dtheta = torch.autograd.grad(dg, params)[0] + + assert dg == pytest.approx(expected_dg, abs=0.5) + assert dg_dtheta == pytest.approx(expected_dg_dtheta, rel=1.1) diff --git a/smee/tests/mm/test_mm.py b/smee/tests/mm/test_mm.py index 50012fa..af294ee 100644 --- a/smee/tests/mm/test_mm.py +++ b/smee/tests/mm/test_mm.py @@ -11,6 +11,7 @@ import smee import smee.converters import smee.mm +import smee.mm._utils import smee.tests.utils from smee.mm._mm import ( _apply_hmr, @@ -20,7 +21,6 @@ _get_platform, _get_state_log, _run_simulation, - _topology_to_rdkit, _topology_to_xyz, generate_system_coords, simulate, @@ -133,7 +133,7 @@ def test_topology_to_rdkit(): constraints=None, ) - mol = _topology_to_rdkit(topology) + mol = smee.mm._utils.topology_to_rdkit(topology) assert Chem.MolToSmiles(mol) == "[H]C([H])[O-].[H]O[H]" atomic_nums = [atom.GetAtomicNum() for atom in mol.GetAtoms()] @@ -153,7 +153,7 @@ def test_topology_to_xyz(mocker): mock_molecule.AddConformer(conformer) - mocker.patch("smee.mm._mm._topology_to_rdkit", return_value=mock_molecule) + mocker.patch("smee.mm._utils.topology_to_rdkit", return_value=mock_molecule) interchange = openff.interchange.Interchange.from_smirnoff( openff.toolkit.ForceField("tip4p_fb.offxml"), diff --git a/smee/tests/mm/test_ops.py b/smee/tests/mm/test_ops.py index 97ebdba..c4be74d 100644 --- a/smee/tests/mm/test_ops.py +++ b/smee/tests/mm/test_ops.py @@ -252,9 +252,9 @@ def test_compute_observables(tmp_path, mock_argon_tensors, mock_argon_params): beta = 2.0 expected_du_d_eps = 4.0 * ((sig / distances) ** 12 - (sig / distances) ** 6) - expected_du_d_sig = ( - eps * (sig**5) * (48.0 * (sig**6) - 24.0 * distances**6) - ) / (distances**12) + expected_du_d_sig = (eps * (sig**5) * (48.0 * (sig**6) - 24.0 * distances**6)) / ( + distances**12 + ) expected_potential = eps * expected_du_d_eps @@ -433,3 +433,68 @@ def test_reweight_ensemble_averages(mocker, tmp_path, mock_argon_tensors): assert reweight_grad.shape == ensemble_grad.shape assert torch.allclose(reweight_grad, ensemble_grad) + + +def test_compute_dg_solv(mocker, tmp_path, mock_argon_tensors): + tensor_ff, _ = mock_argon_tensors + + params = tensor_ff.potentials_by_type["vdW"].parameters + params.requires_grad = True + + mocker.patch( + "smee.mm._fe.compute_dg_and_grads", + side_effect=[ + (torch.tensor(1.0).double(), (torch.tensor([[2.0, 3.0]]).double(),)), + (torch.tensor(4.0).double(), (torch.tensor([[5.0, 6.0]]).double(),)), + ], + ) + + dg = smee.mm.compute_dg_solv(tensor_ff, tmp_path) + dg_dtheta = torch.autograd.grad(dg, params)[0] + + assert torch.isclose(dg, torch.tensor(-3.0).double()) + assert torch.allclose(dg_dtheta, torch.tensor([[3.0, 3.0]]).double()) + + +def test_reweight_dg_solv(mocker, tmp_path, mock_argon_tensors): + tensor_ff, _ = mock_argon_tensors + + params = tensor_ff.potentials_by_type["vdW"].parameters + params.requires_grad = True + + mocker.patch( + "smee.mm._fe.reweight_dg_and_grads", + side_effect=[ + (torch.tensor(1.0).double(), (torch.tensor([[2.0, 3.0]]).double(),), 4.0), + (torch.tensor(5.0).double(), (torch.tensor([[6.0, 7.0]]).double(),), 8.0), + ], + ) + + dg_0 = torch.tensor(-3.0).double() + + dg, n_eff = smee.mm.reweight_dg_solv(tensor_ff, tmp_path, dg_0, 3) + dg_dtheta = torch.autograd.grad(dg, params)[0] + + assert torch.isclose(dg, torch.tensor(1.0).double()) + assert torch.allclose(dg_dtheta, torch.tensor([[4.0, 4.0]]).double()) + + assert n_eff == 4.0 + + +def test_reweight_dg_solv_error(mocker, 
tmp_path, mock_argon_tensors): + tensor_ff, _ = mock_argon_tensors + + params = tensor_ff.potentials_by_type["vdW"].parameters + params.requires_grad = True + + mocker.patch( + "smee.mm._fe.reweight_dg_and_grads", + side_effect=[ + (torch.tensor(1.0).double(), (torch.tensor([[2.0, 3.0]]).double(),), 4.0), + (torch.tensor(5.0).double(), (torch.tensor([[6.0, 7.0]]).double(),), 8.0), + ], + ) + dg_0 = torch.tensor(-3.0).double() + + with pytest.raises(smee.mm.NotEnoughSamplesError): + smee.mm.reweight_dg_solv(tensor_ff, tmp_path, dg_0, 100)
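
A minimal usage sketch of the new API added by this patch, closely following smee/tests/mm/test_fe.py: generate_dg_solv_data runs the vacuum and solvent FEP legs with absolv/femto and stores the samples, compute_dg_solv turns those samples into a differentiable ∆G_solv, and reweight_dg_solv re-estimates ∆G_solv from the same samples, raising NotEnoughSamplesError when too few effective samples remain. The SMILES, force field file, and output directory below are illustrative, and a CUDA device plus the absolv/femto dependencies are assumed.

    import pathlib

    import openff.interchange
    import openff.toolkit
    import torch

    import smee.converters
    import smee.mm

    ff_off = openff.toolkit.ForceField("openff-2.0.0.offxml")
    solute = openff.interchange.Interchange.from_smirnoff(
        ff_off, openff.toolkit.Molecule.from_smiles("CCO").to_topology()
    )
    solvent = openff.interchange.Interchange.from_smirnoff(
        ff_off, openff.toolkit.Molecule.from_smiles("O").to_topology()
    )
    ff, (top_solute, top_solvent) = smee.converters.convert_interchange(
        [solute, solvent]
    )

    output_dir = pathlib.Path("CCO")
    output_dir.mkdir(parents=True, exist_ok=True)

    # run the vacuum and solvent FEP legs and store the samples on disk
    smee.mm.generate_dg_solv_data(top_solute, top_solvent, ff, output_dir=output_dir)

    # compute a dG_solv that carries gradients w.r.t. the selected parameters
    params = ff.potentials_by_type["Electrostatics"].parameters
    params.requires_grad_(True)

    dg = smee.mm.compute_dg_solv(ff, output_dir)
    dg_dtheta = torch.autograd.grad(dg, params)[0]

    # ... after perturbing the force field, re-estimate dG_solv by reweighting
    try:
        dg_new, n_eff = smee.mm.reweight_dg_solv(ff, output_dir, dg)
    except smee.mm.NotEnoughSamplesError:
        pass  # too few effective samples: regenerate the FEP data instead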