Skip to content

Commit

Permalink
Support residue information (#104)
Browse files Browse the repository at this point in the history
  • Loading branch information
SimonBoothroyd authored May 16, 2024
1 parent cf42531 commit 04e1072
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 1 deletion.
2 changes: 1 addition & 1 deletion devtools/envs/base.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ dependencies:
- codecov

# Docs
- mkdocs
- mkdocs <1.6
- mkdocs-material
- mkdocs-gen-files
- mkdocs-literate-nav
Expand Down
26 changes: 26 additions & 0 deletions smee/_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,19 @@ class TensorTopology:
"""Distance constraints that should be applied **during MD simulations**. These
will not be used outside of MD simulations."""

residue_idxs: list[int] | None = None
"""The index of the residue that each atom in the topology belongs to with
``length=n_atoms``."""
residue_ids: list[str] | None = None
"""The names of the residues that each atom belongs to with ``length=n_residues``.
"""

chain_idxs: list[int] | None = None
"""The index of the chain that each atom in the topology belongs to with
``length=n_atoms``."""
chain_ids: list[str] | None = None
"""The names of the chains that each atom belongs to with ``length=n_chains``."""

@property
def n_atoms(self) -> int:
"""The number of atoms in the topology."""
Expand All @@ -174,6 +187,16 @@ def n_bonds(self) -> int:
"""The number of bonds in the topology."""
return len(self.bond_idxs)

@property
def n_residues(self) -> int:
"""The number of residues in the topology"""
return 0 if self.residue_ids is None else len(self.residue_ids)

@property
def n_chains(self) -> int:
"""The number of chains in the topology"""
return 0 if self.chain_ids is None else len(self.chain_ids)

@property
def n_v_sites(self) -> int:
"""The number of v-sites in the topology."""
Expand All @@ -200,6 +223,9 @@ def to(
if self.constraints is None
else self.constraints.to(device, precision)
),
self.residue_idxs,
self.residue_ids,
self.chain_ids,
)


Expand Down
18 changes: 18 additions & 0 deletions smee/tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,24 @@ def test_n_bonds(self):
expected_n_bonds = 5
assert topology.n_bonds == expected_n_bonds

def test_n_residues(self):
topology = smee.tests.utils.topology_from_smiles("[Ar]")
topology.residue_ids = None
topology.residue_idxs = None
assert topology.n_residues == 0

topology.residue_ids = ["Ar"]
topology.residue_idxs = [0]
assert topology.n_residues == 1

def test_n_chains(self) -> int:
topology = smee.tests.utils.topology_from_smiles("[Ar]")
topology.residue_ids = [0]
topology.residue_idxs = ["UNK"]
topology.chain_idxs = [0]
topology.chain_ids = ["A"]
assert topology.n_chains == 1

def test_n_v_sites(self):
topology = smee.tests.utils.topology_from_smiles("CO")

Expand Down

0 comments on commit 04e1072

Please sign in to comment.