diff --git a/devtools/envs/base.yaml b/devtools/envs/base.yaml index 13a49d9..0524403 100644 --- a/devtools/envs/base.yaml +++ b/devtools/envs/base.yaml @@ -53,7 +53,7 @@ dependencies: - codecov # Docs - - mkdocs + - mkdocs <1.6 - mkdocs-material - mkdocs-gen-files - mkdocs-literate-nav diff --git a/smee/_models.py b/smee/_models.py index 668a9d9..3c762e3 100644 --- a/smee/_models.py +++ b/smee/_models.py @@ -164,6 +164,19 @@ class TensorTopology: """Distance constraints that should be applied **during MD simulations**. These will not be used outside of MD simulations.""" + residue_idxs: list[int] | None = None + """The index of the residue that each atom in the topology belongs to with + ``length=n_atoms``.""" + residue_ids: list[str] | None = None + """The names of the residues that each atom belongs to with ``length=n_residues``. + """ + + chain_idxs: list[int] | None = None + """The index of the chain that each atom in the topology belongs to with + ``length=n_atoms``.""" + chain_ids: list[str] | None = None + """The names of the chains that each atom belongs to with ``length=n_chains``.""" + @property def n_atoms(self) -> int: """The number of atoms in the topology.""" @@ -174,6 +187,16 @@ def n_bonds(self) -> int: """The number of bonds in the topology.""" return len(self.bond_idxs) + @property + def n_residues(self) -> int: + """The number of residues in the topology""" + return 0 if self.residue_ids is None else len(self.residue_ids) + + @property + def n_chains(self) -> int: + """The number of chains in the topology""" + return 0 if self.chain_ids is None else len(self.chain_ids) + @property def n_v_sites(self) -> int: """The number of v-sites in the topology.""" @@ -200,6 +223,9 @@ def to( if self.constraints is None else self.constraints.to(device, precision) ), + self.residue_idxs, + self.residue_ids, + self.chain_ids, ) diff --git a/smee/tests/test_models.py b/smee/tests/test_models.py index 82e1097..791a3f9 100644 --- a/smee/tests/test_models.py +++ b/smee/tests/test_models.py @@ -54,6 +54,24 @@ def test_n_bonds(self): expected_n_bonds = 5 assert topology.n_bonds == expected_n_bonds + def test_n_residues(self): + topology = smee.tests.utils.topology_from_smiles("[Ar]") + topology.residue_ids = None + topology.residue_idxs = None + assert topology.n_residues == 0 + + topology.residue_ids = ["Ar"] + topology.residue_idxs = [0] + assert topology.n_residues == 1 + + def test_n_chains(self) -> int: + topology = smee.tests.utils.topology_from_smiles("[Ar]") + topology.residue_ids = [0] + topology.residue_idxs = ["UNK"] + topology.chain_idxs = [0] + topology.chain_ids = ["A"] + assert topology.n_chains == 1 + def test_n_v_sites(self): topology = smee.tests.utils.topology_from_smiles("CO")