diff --git a/MDANSE/Src/MDANSE/Chemistry/ChemicalSystem.py b/MDANSE/Src/MDANSE/Chemistry/ChemicalSystem.py index 7cf476dde8..0f22ad5b75 100644 --- a/MDANSE/Src/MDANSE/Chemistry/ChemicalSystem.py +++ b/MDANSE/Src/MDANSE/Chemistry/ChemicalSystem.py @@ -15,7 +15,7 @@ # from __future__ import annotations -from typing import List, Tuple, Dict, Any +from typing import List, Tuple, Dict, Any, Set import copy from functools import reduce @@ -229,6 +229,11 @@ def number_of_atoms(self) -> int: """The number of non-ghost atoms in the ChemicalSystem.""" return self._total_number_of_atoms + @property + def all_indices(self) -> Set[int]: + """The number of non-ghost atoms in the ChemicalSystem.""" + return set(self._atom_indices) + @property def total_number_of_atoms(self) -> int: """The number of all atoms in the ChemicalSystem, including ghost ones.""" diff --git a/MDANSE/Src/MDANSE/Framework/AtomSelector/__init__.py b/MDANSE/Src/MDANSE/Framework/AtomSelector/__init__.py index 1ce9f3c792..268c9cc94f 100644 --- a/MDANSE/Src/MDANSE/Framework/AtomSelector/__init__.py +++ b/MDANSE/Src/MDANSE/Framework/AtomSelector/__init__.py @@ -13,4 +13,3 @@ # You should have received a copy of the GNU General Public License # along with this program. If not, see . # -from .selector import Selector diff --git a/MDANSE/Src/MDANSE/Framework/AtomSelector/all_selector.py b/MDANSE/Src/MDANSE/Framework/AtomSelector/all_selector.py deleted file mode 100644 index 8b1813e10c..0000000000 --- a/MDANSE/Src/MDANSE/Framework/AtomSelector/all_selector.py +++ /dev/null @@ -1,54 +0,0 @@ -# This file is part of MDANSE. -# -# MDANSE is free software: you can redistribute it and/or modify -# it under the terms of the GNU General Public License as published by -# the Free Software Foundation, either version 3 of the License, or -# (at your option) any later version. -# -# This program is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# GNU General Public License for more details. -# -# You should have received a copy of the GNU General Public License -# along with this program. If not, see . -# -from typing import Union -from MDANSE.MolecularDynamics.Trajectory import Trajectory - - -def select_all( - trajectory: Trajectory, check_exists: bool = False -) -> Union[set[int], bool]: - """Selects all atoms in the chemical system except for the dummy - atoms. - - Parameters - ---------- - system : ChemicalSystem - The MDANSE chemical system. - check_exists : bool, optional - Check if a match exists. - - Returns - ------- - Union[set[int], bool] - All atom indices except for dummy atoms or a bool if checking - match. - """ - system = trajectory.chemical_system - if check_exists: - return True - else: - dummy_list = [] - atom_list = system.atom_list - for atm in system._unique_elements: - if trajectory.get_atom_property(atm, "dummy"): - dummy_list.append(atm) - return set( - [ - index - for index in system._atom_indices - if atom_list[index] not in dummy_list - ] - ) diff --git a/MDANSE/Src/MDANSE/Framework/AtomSelector/atom_selection.py b/MDANSE/Src/MDANSE/Framework/AtomSelector/atom_selection.py new file mode 100644 index 0000000000..10280df9a5 --- /dev/null +++ b/MDANSE/Src/MDANSE/Framework/AtomSelector/atom_selection.py @@ -0,0 +1,79 @@ +# This file is part of MDANSE. +# +# MDANSE is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . +# + +from collections.abc import Sequence +from typing import Optional + +from MDANSE.MolecularDynamics.Trajectory import Trajectory + + +def select_atoms( + trajectory: Trajectory, + *, + index_list: Optional[Sequence[int]] = None, + index_range: Optional[Sequence[int]] = None, + index_slice: Optional[Sequence[int]] = None, + atom_types: Sequence[str] = (), + atom_names: Sequence[str] = (), + **_kwargs: str, +) -> set[int]: + """Select specific atoms in the trajectory. + + Atoms can be selected based + on indices, atom type or trajectory-specific atom name. + The atom type is normally the chemical element, while + the atom name can be more specific and depend on the + force field used. + + Parameters + ---------- + trajectory : Trajectory + A trajectory instance to which the selection is applied + index_list : Sequence[int] + a list of indices to be selected + index_range : Sequence[int] + a pair of (first, last+1) indices defining a range + index_slice : Sequence[int] + a sequence of (first, last+1, step) indices defining a slice + atom_types : Sequence[str] + a list of atom types (i.e. chemical elements) to be selected, given as string + atom_names : Sequence[str] + a list of atom names (as used by the MD engine, force field, etc.) to be selected + + Returns + ------- + set[int] + A set of indices which have been selected + + """ + selection = set() + system = trajectory.chemical_system + element_list = system.atom_list + name_list = system.name_list + indices = system.all_indices + if index_list is not None: + selection |= indices & set(index_list) + if index_range is not None: + selection |= indices & set(range(*index_range)) + if index_slice is not None: + selection |= indices & set(range(*index_slice)) + if atom_types: + new_indices = {index for index in indices if element_list[index] in atom_types} + selection |= new_indices + if atom_names: + new_indices = {index for index in indices if name_list[index] in atom_names} + selection |= new_indices + return selection diff --git a/MDANSE/Src/MDANSE/Framework/AtomSelector/atom_selectors.py b/MDANSE/Src/MDANSE/Framework/AtomSelector/atom_selectors.py deleted file mode 100644 index 1b7a054818..0000000000 --- a/MDANSE/Src/MDANSE/Framework/AtomSelector/atom_selectors.py +++ /dev/null @@ -1,236 +0,0 @@ -# This file is part of MDANSE. -# -# MDANSE is free software: you can redistribute it and/or modify -# it under the terms of the GNU General Public License as published by -# the Free Software Foundation, either version 3 of the License, or -# (at your option) any later version. -# -# This program is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# GNU General Public License for more details. -# -# You should have received a copy of the GNU General Public License -# along with this program. If not, see . -# -from typing import Union -from MDANSE.MolecularDynamics.Trajectory import Trajectory - - -__all__ = [ - "select_element", - "select_dummy", - "select_atom_name", - "select_atom_fullname", - "select_hs_on_element", - "select_hs_on_heteroatom", - "select_index", -] - - -def select_element( - trajectory: Trajectory, symbol: str, check_exists: bool = False -) -> Union[set[int], bool]: - """Selects all atoms for the input element. - - Parameters - ---------- - system : ChemicalSystem - The MDANSE chemical system. - symbol : str - Symbol of the element. - check_exists : bool, optional - Check if a match exists. - - Returns - ------- - Union[set[int], bool] - The atom indices of the matched atoms. - """ - system = trajectory.chemical_system - pattern = f"[#{trajectory.get_atom_property(symbol, 'atomic_number')}]" - if check_exists: - return system.has_substructure_match(pattern) - else: - return system.get_substructure_matches(pattern) - - -def select_dummy( - trajectory: Trajectory, check_exists: bool = False -) -> Union[set[int], bool]: - """Selects all dummy atoms in the chemical system. - - Parameters - ---------- - system : ChemicalSystem - The MDANSE chemical system. - check_exists : bool, optional - Check if a match exists. - - Returns - ------- - Union[set[int], bool] - All dummy atom indices or a bool if checking match. - """ - system = trajectory.chemical_system - dummy_list = ["Du", "dummy"] - if check_exists: - for atm in system.atom_list: - if atm in dummy_list: - return True - elif trajectory.get_atom_property(atm, "dummy"): - return True - return False - else: - for atm in system._unique_elements: - if trajectory.get_atom_property(atm, "dummy"): - dummy_list.append(atm) - return set( - [ - index - for index, element in enumerate(system.atom_list) - if element in dummy_list - ] - ) - - -def select_atom_name( - trajectory: Trajectory, name: str, check_exists: bool = False -) -> Union[set[int], bool]: - """Selects all atoms with the input name in the chemical system. - - Parameters - ---------- - system : ChemicalSystem - The MDANSE chemical system. - name : str - The name of the atom to match. - check_exists : bool, optional - Check if a match exists. - - Returns - ------- - Union[set[int], bool] - All atom indices or a bool if checking match. - """ - system = trajectory.chemical_system - if check_exists: - if name in system.atom_list: - return True - return False - else: - return set( - [index for index, element in enumerate(system.atom_list) if element == name] - ) - - -def select_atom_fullname( - trajectory: Trajectory, fullname: str, check_exists: bool = False -) -> Union[set[int], bool]: - """Selects all atoms with the input fullname in the chemical system. - - Parameters - ---------- - system : ChemicalSystem - The MDANSE chemical system. - fullname : str - The fullname of the atom to match. - check_exists : bool, optional - Check if a match exists. - - Returns - ------- - Union[set[int], bool] - All atom indices or a bool if checking match. - """ - system = trajectory.chemical_system - if check_exists: - if fullname in system.name_list: - return True - return False - else: - return set( - [index for index, name in enumerate(system.name_list) if name == fullname] - ) - - -def select_hs_on_element( - trajectory: Trajectory, symbol: str, check_exists: bool = False -) -> Union[set[int], bool]: - """Selects all H atoms bonded to the input element. - - Parameters - ---------- - system : ChemicalSystem - The MDANSE chemical system. - symbol : str - Symbol of the element that the H atoms are bonded to. - check_exists : bool, optional - Check if a match exists. - - Returns - ------- - Union[set[int], bool] - The atom indices of the matched atoms. - """ - system = trajectory.chemical_system - num = trajectory.get_atom_property(symbol, "atomic_number") - if check_exists: - return system.has_substructure_match(f"[#{num}]~[H]") - else: - xh_matches = system.get_substructure_matches(f"[#{num}]~[H]") - x_matches = system.get_substructure_matches(f"[#{num}]") - return xh_matches - x_matches - - -def select_hs_on_heteroatom( - trajectory: Trajectory, check_exists: bool = False -) -> Union[set[int], bool]: - """Selects all H atoms bonded to any atom except carbon and - hydrogen. - - Parameters - ---------- - system : ChemicalSystem - The MDANSE chemical system. - check_exists : bool, optional - Check if a match exists. - - Returns - ------- - Union[set[int], bool] - The atom indices of the matched atoms. - """ - system = trajectory.chemical_system - if check_exists: - return system.has_substructure_match("[!#6&!#1]~[H]") - else: - xh_matches = system.get_substructure_matches("[!#6&!#1]~[H]") - x_matches = system.get_substructure_matches("[!#6&!#1]") - return xh_matches - x_matches - - -def select_index( - trajectory: Trajectory, index: Union[int, str], check_exists: bool = False -) -> Union[set[int], bool]: - """Selects atom with index - just returns the set with the - index in it. - - Parameters - ---------- - system : ChemicalSystem - The MDANSE chemical system. - index : int or str - The index to select. - check_exists : bool, optional - Check if a match exists. - - Returns - ------- - Union[set[int], bool] - The index in a set or a bool if checking match. - """ - if check_exists: - return True - else: - return {int(index)} diff --git a/MDANSE/Src/MDANSE/Framework/AtomSelector/general_selection.py b/MDANSE/Src/MDANSE/Framework/AtomSelector/general_selection.py new file mode 100644 index 0000000000..ea5b5ff18b --- /dev/null +++ b/MDANSE/Src/MDANSE/Framework/AtomSelector/general_selection.py @@ -0,0 +1,78 @@ +# This file is part of MDANSE. +# +# MDANSE is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . +# + +from MDANSE.MolecularDynamics.Trajectory import Trajectory + + +def select_all(trajectory: Trajectory, **_kwargs: str) -> set[int]: + """Select all the atoms in the trajectory. + + Parameters + ---------- + trajectory : Trajectory + A trajectory instance to which the selection is applied + + Returns + ------- + Set[int] + Set of all the atom indices + + """ + return trajectory.chemical_system.all_indices + + +def select_none(_trajectory: Trajectory, **_kwargs: str) -> set[int]: + """Return an empty selection. + + Parameters + ---------- + _trajectory : Trajectory + A trajectory instance, ignored in this selection + + Returns + ------- + Set[int] + An empty set. + + """ + return set() + + +def invert_selection( + trajectory: Trajectory, + selection: set[int], + **_kwargs: str, +) -> set[int]: + """Invert the current selection for the input trajectory. + + Return a set of all the indices that are present in the trajectory + and were not included in the input selection. + + Parameters + ---------- + trajectory : Trajectory + a trajectory containing atoms to be selected + selection : Set[int] + set of indices to be excluded from the set of all indices + + Returns + ------- + Set[int] + set of all the indices in the trajectory which were not in the input selection + + """ + all_indices = select_all(trajectory) + return all_indices - selection diff --git a/MDANSE/Src/MDANSE/Framework/AtomSelector/group_selection.py b/MDANSE/Src/MDANSE/Framework/AtomSelector/group_selection.py new file mode 100644 index 0000000000..2bb0a6ad9c --- /dev/null +++ b/MDANSE/Src/MDANSE/Framework/AtomSelector/group_selection.py @@ -0,0 +1,76 @@ +# This file is part of MDANSE. +# +# MDANSE is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . +# + +from collections.abc import Sequence + +from MDANSE.MolecularDynamics.Trajectory import Trajectory + + +def select_labels( + trajectory: Trajectory, + atom_labels: Sequence[str] = (), + **_kwargs: str, +) -> set[int]: + """Select atoms with a specific label in the trajectory. + + A residue name can be used as a label by MDANSE. + + Parameters + ---------- + trajectory : Trajectory + A trajectory instance to which the selection is applied + atom_labels : Sequence[str] + a list of string labels (e.g. residue names) by which to select atoms + + Returns + ------- + Set[int] + Set of atom indices corresponding to the selected labels + + """ + system = trajectory.chemical_system + return {system._labels[label] for label in atom_labels if label in system._labels} + + +def select_pattern( + trajectory: Trajectory, + *, + rdkit_pattern: str, + **_kwargs: str, +) -> set[int]: + """Select atoms according to the SMARTS string given as input. + + This will only work if molecules and bonds have been detected in the system. + If the bond information was not read from the input trajectory on conversion, + it can still be determined in a TrajectoryEditor run. + + Parameters + ---------- + trajectory : Trajectory + A trajectory instance to which the selection is applied + rdkit_pattern : str + a SMARTS string to be matched + + Returns + ------- + Set[int] + Set of atom indices matched by rdkit + + """ + selection = set() + system = trajectory.chemical_system + selection = system.get_substructure_matches(rdkit_pattern) + return selection diff --git a/MDANSE/Src/MDANSE/Framework/AtomSelector/group_selectors.py b/MDANSE/Src/MDANSE/Framework/AtomSelector/group_selectors.py deleted file mode 100644 index 0bc1e85672..0000000000 --- a/MDANSE/Src/MDANSE/Framework/AtomSelector/group_selectors.py +++ /dev/null @@ -1,177 +0,0 @@ -# This file is part of MDANSE. -# -# MDANSE is free software: you can redistribute it and/or modify -# it under the terms of the GNU General Public License as published by -# the Free Software Foundation, either version 3 of the License, or -# (at your option) any later version. -# -# This program is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# GNU General Public License for more details. -# -# You should have received a copy of the GNU General Public License -# along with this program. If not, see . -# -from typing import Union -from MDANSE.MolecularDynamics.Trajectory import Trajectory - - -__all__ = [ - "select_primary_amine", - "select_hydroxy", - "select_methyl", - "select_phosphate", - "select_sulphate", - "select_thiol", -] - - -def select_primary_amine( - trajectory: Trajectory, check_exists: bool = False -) -> Union[set[int], bool]: - """Selects the N and H atoms of all primary amines. - - Parameters - ---------- - system : ChemicalSystem - The MDANSE chemical system. - check_exists : bool, optional - Check if a match exists. - - Returns - ------- - Union[set[int], bool] - The atom indices of the matched atoms or a bool if checking match. - """ - system = trajectory.chemical_system - pattern = "[#7X3;H2;!$([#7][#6X3][!#6]);!$([#7][#6X2][!#6])](~[H])~[H]" - if check_exists: - return system.has_substructure_match(pattern) - else: - return system.get_substructure_matches(pattern) - - -def select_hydroxy( - trajectory: Trajectory, check_exists: bool = False -) -> Union[set[int], bool]: - """Selects the O and H atoms of all hydroxy groups including water. - - Parameters - ---------- - system : ChemicalSystem - The MDANSE chemical system. - check_exists : bool, optional - Check if a match exists. - - Returns - ------- - Union[set[int], bool] - The atom indices of the matched atoms or a bool if checking match. - """ - system = trajectory.chemical_system - pattern = "[#8;H1,H2]~[H]" - if check_exists: - return system.has_substructure_match(pattern) - else: - return system.get_substructure_matches(pattern) - - -def select_methyl( - trajectory: Trajectory, check_exists: bool = False -) -> Union[set[int], bool]: - """Selects the C and H atoms of all methyl groups. - - Parameters - ---------- - system : ChemicalSystem - The MDANSE chemical system. - check_exists : bool, optional - Check if a match exists. - - Returns - ------- - Union[set[int], bool] - The atom indices of the matched atoms or a bool if checking match. - """ - system = trajectory.chemical_system - pattern = "[#6;H3](~[H])(~[H])~[H]" - if check_exists: - return system.has_substructure_match(pattern) - else: - return system.get_substructure_matches(pattern) - - -def select_phosphate( - trajectory: Trajectory, check_exists: bool = False -) -> Union[set[int], bool]: - """Selects the P and O atoms of all phosphate groups. - - Parameters - ---------- - system : ChemicalSystem - The MDANSE chemical system. - check_exists : bool, optional - Check if a match exists. - - Returns - ------- - set[int] - The atom indices of the matched atoms or a bool if checking match. - """ - system = trajectory.chemical_system - pattern = "[#15X4](~[#8])(~[#8])(~[#8])~[#8]" - if check_exists: - return system.has_substructure_match(pattern) - else: - return system.get_substructure_matches(pattern) - - -def select_sulphate( - trajectory: Trajectory, check_exists: bool = False -) -> Union[set[int], bool]: - """Selects the S and O atoms of all sulphate groups. - - Parameters - ---------- - system : ChemicalSystem - The MDANSE chemical system. - check_exists : bool, optional - Check if a match exists. - - Returns - ------- - Union[set[int], bool] - The atom indices of the matched atoms or a bool if checking match. - """ - system = trajectory.chemical_system - pattern = "[#16X4](~[#8])(~[#8])(~[#8])~[#8]" - if check_exists: - return system.has_substructure_match(pattern) - else: - return system.get_substructure_matches(pattern) - - -def select_thiol( - trajectory: Trajectory, check_exists: bool = False -) -> Union[set[int], bool]: - """Selects the S and H atoms of all thiol groups. - - Parameters - ---------- - system : ChemicalSystem - The MDANSE chemical system. - check_exists : bool, optional - Check if a match exists. - - Returns - ------- - Union[set[int], bool] - The atom indices of the matched atoms or a bool if checking match. - """ - system = trajectory.chemical_system - pattern = "[#16X2;H1]~[H]" - if check_exists: - return system.has_substructure_match(pattern) - else: - return system.get_substructure_matches(pattern) diff --git a/MDANSE/Src/MDANSE/Framework/AtomSelector/molecule_selectors.py b/MDANSE/Src/MDANSE/Framework/AtomSelector/molecule_selection.py similarity index 54% rename from MDANSE/Src/MDANSE/Framework/AtomSelector/molecule_selectors.py rename to MDANSE/Src/MDANSE/Framework/AtomSelector/molecule_selection.py index 366f599833..4d8c388b9b 100644 --- a/MDANSE/Src/MDANSE/Framework/AtomSelector/molecule_selectors.py +++ b/MDANSE/Src/MDANSE/Framework/AtomSelector/molecule_selection.py @@ -13,35 +13,35 @@ # You should have received a copy of the GNU General Public License # along with this program. If not, see . # -from typing import Union -from MDANSE.MolecularDynamics.Trajectory import Trajectory +from typing import Set, Sequence -__all__ = [ - "select_water", -] +from MDANSE.MolecularDynamics.Trajectory import Trajectory -def select_water( - trajectory: Trajectory, check_exists: bool = False -) -> Union[set[int], bool]: - """Selects the O and H atoms of all water molecules. +def select_molecules( + trajectory: Trajectory, molecule_names: Sequence[str] = (), **_kwargs: str +) -> Set[int]: + """Selects all the atoms belonging to the specified molecule types. Parameters ---------- - system : ChemicalSystem - The MDANSE chemical system. - check_exists : bool, optional - Check if a match exists. + trajectory : Trajectory + A trajectory instance to which the selection is applied + molecule_names : Sequence[str] + a list of molecule names (str) which are keys of ChemicalSystem._clusters Returns ------- - Union[set[int], bool] - The atom indices of the matched atoms or a bool if checking match. + Set[int] + Set of indices of atoms belonging to molecules from molecule_names """ + selection = set() system = trajectory.chemical_system - pattern = "[#8X2;H2](~[H])~[H]" - if check_exists: - return system.has_substructure_match(pattern) - else: - return system.get_substructure_matches(pattern) + selection = { + index + for molecule in molecule_names + for cluster in system._clusters.get(molecule, ()) + for index in cluster + } + return selection diff --git a/MDANSE/Src/MDANSE/Framework/AtomSelector/selector.py b/MDANSE/Src/MDANSE/Framework/AtomSelector/selector.py index b721f7db32..534a2d4ffe 100644 --- a/MDANSE/Src/MDANSE/Framework/AtomSelector/selector.py +++ b/MDANSE/Src/MDANSE/Framework/AtomSelector/selector.py @@ -13,362 +13,223 @@ # You should have received a copy of the GNU General Public License # along with this program. If not, see . # -import copy + import json -from typing import Union +from typing import Any, Optional -from MDANSE.Chemistry.ChemicalSystem import ChemicalSystem -from MDANSE.MolecularDynamics.Trajectory import Trajectory -from MDANSE.Framework.AtomSelector.all_selector import select_all -from MDANSE.Framework.AtomSelector.atom_selectors import ( - select_atom_fullname, - select_atom_name, - select_dummy, - select_element, - select_hs_on_element, - select_hs_on_heteroatom, - select_index, +from MDANSE.Framework.AtomSelector.atom_selection import select_atoms +from MDANSE.Framework.AtomSelector.general_selection import ( + invert_selection, + select_all, + select_none, ) -from MDANSE.Framework.AtomSelector.group_selectors import ( - select_hydroxy, - select_methyl, - select_phosphate, - select_primary_amine, - select_sulphate, - select_thiol, +from MDANSE.Framework.AtomSelector.group_selection import select_labels, select_pattern +from MDANSE.Framework.AtomSelector.molecule_selection import select_molecules +from MDANSE.Framework.AtomSelector.spatial_selection import ( + select_positions, + select_sphere, ) -from MDANSE.Framework.AtomSelector.molecule_selectors import select_water - +from MDANSE.MolecularDynamics.Trajectory import Trajectory -class Selector: - """Used to get the indices of a subset of atoms of a chemical system. +function_lookup = { + function.__name__: function + for function in [ + select_all, + select_none, + invert_selection, + select_atoms, + select_molecules, + select_labels, + select_pattern, + select_positions, + select_sphere, + ] +} + + +class ReusableSelection: + """Stores an applies atom selection operations. + + A reusable sequence of operations which, when applied + to a trajectory, returns a set of atom indices based + on the specified criteria. - Attributes - ---------- - _default : dict[str, bool | dict] - The default settings. - _funcs : dict[str, Callable] - A dictionary of the functions. - _kwarg_keys : dict[str, str] - A dictionary of the function arg keys. """ - _default = { - "all": True, - "dummy": False, - "hs_on_heteroatom": False, - "primary_amine": False, - "hydroxy": False, - "methyl": False, - "phosphate": False, - "sulphate": False, - "thiol": False, - "water": False, - # e.g. {"S": True} - "hs_on_element": {}, - "element": {}, - "name": {}, - "fullname": {}, - # e.g. {1: True} - "index": {}, - } - - _funcs = { - "all": select_all, - "dummy": select_dummy, - "hs_on_heteroatom": select_hs_on_heteroatom, - "primary_amine": select_primary_amine, - "hydroxy": select_hydroxy, - "methyl": select_methyl, - "phosphate": select_phosphate, - "sulphate": select_sulphate, - "thiol": select_thiol, - "water": select_water, - "hs_on_element": select_hs_on_element, - "element": select_element, - "name": select_atom_name, - "fullname": select_atom_fullname, - "index": select_index, - } - - _kwarg_keys = { - "hs_on_element": "symbol", - "element": "symbol", - "name": "name", - "fullname": "fullname", - "index": "index", - } - - def __init__(self, trajectory: Trajectory) -> None: - """ + def __init__(self) -> None: + """Create an empty selection. + Parameters ---------- trajectory: Trajectory The chemical system to apply the selection to. - """ - system = trajectory.chemical_system - self.system = system - self.trajectory = trajectory - self.all_idxs = set(system._atom_indices) - self.settings = copy.deepcopy(self._default) - - symbols = set(system.atom_list) - # all possible values for the system - self._kwarg_vals = { - "element": symbols, - "hs_on_element": set( - [ - symbol - for symbol in symbols - if select_hs_on_element(trajectory, symbol, check_exists=True) - ] - ), - "name": set(system.atom_list), - "fullname": set(system.name_list), - "index": self.all_idxs, - } - - # figure out if a match exists for the selector function - self.match_exists = self.create_default_settings() - for k0, v0 in self.match_exists.items(): - if isinstance(v0, dict): - for k1 in v0.keys(): - self.match_exists[k0][k1] = True - else: - self.match_exists[k0] = self._funcs[k0]( - self.trajectory, check_exists=True - ) - self.settings = self.create_default_settings() - - def create_default_settings(self) -> dict[str, Union[bool, dict]]: - """Create a new settings dictionary with default settings. - - Returns - ------- - dict[str, Union[bool, dict]] - A settings dictionary. """ - settings = copy.deepcopy(self._default) - for k, vs in self._kwarg_vals.items(): - for v in sorted(vs): - settings[k][v] = False - return settings - - def reset_settings(self) -> None: - """Resets the settings back to the defaults.""" - self.settings = self.create_default_settings() - - def update_settings( - self, settings: dict[str, Union[bool, dict]], reset_first: bool = False - ) -> None: - """Updates the selection settings. + self.reset() + + def reset(self): + """Initialise the attributes to an empty list of operations.""" + self.system = None + self.trajectory = None + self.all_idxs = set() + self.operations = {} + + def set_selection( + self, + *, + number: Optional[int] = None, + function_parameters: dict[str, Any], + ): + """Append a new selection operation, or overwrite an existing one. Parameters ---------- - settings : dict[str, bool | dict] - The selection settings. - reset_first : bool, optional - Resets the settings to the default before loading. - - Raises - ------ - ValueError - Raises a ValueError if the inputted settings are not valid. + number : Union[int, None], optional + the position of the new selection in the sequence of operations + function_parameters : Dict[str, Any], optional + the dictionary of keyword arguments defining a selection operation + """ - if not self.check_valid_setting(settings): - raise ValueError( - f"Settings are not valid for the given chemical system - {settings}." - ) + number = int(number) if number is not None else len(self.operations) + self.operations[number] = function_parameters - if reset_first: - self.reset_settings() + def apply_single_selection( + self, + function_parameters: dict[str, Any], + trajectory: Trajectory, + selection: set[int], + ) -> set[int]: + """Modify the input selection based on input parameters. - for k0, v0 in settings.items(): - if isinstance(self.settings[k0], dict): - for k1, v1 in v0.items(): - self.settings[k0][k1] = v1 - else: - self.settings[k0] = v0 + This method applied a single selection operation to + an already exising selection for a specific trajectory. - def get_idxs(self) -> set[int]: - """The atom indices after applying the selection to the system. + Parameters + ---------- + function_parameters : dict[str, Any] + All the inputs needed to call an atom selection function + trajectory : Trajectory + Instance of the trajectory in which we are selecting atoms + selection : set[int] + indices of atoms that resulted from previous steps Returns ------- set[int] - The atoms indices. + indices of selected atoms from all operations so far """ - idxs = set([]) - - for k, v in self.settings.items(): - if isinstance(v, dict): - args = [{self._kwarg_keys[k]: i} for i in v.keys()] - switches = v.values() + function_name = function_parameters.get("function_name", "select_all") + if function_name == "invert_selection": + new_selection = self.all_idxs.difference(selection) + else: + operation_type = function_parameters.get("operation_type", "union") + function = function_lookup[function_name] + temp_selection = function(trajectory, **function_parameters) + if operation_type == "union": + new_selection = selection | temp_selection + elif operation_type == "intersection": + new_selection = selection & temp_selection + elif operation_type == "difference": + new_selection = selection - temp_selection else: - args = [{}] - switches = [v] - - for arg, switch in zip(args, switches): - if not switch: - continue - - idxs.update(self._funcs[k](self.trajectory, **arg)) + new_selection = temp_selection + return new_selection - return idxs + def validate_selection_string( + self, + json_string: str, + trajectory: Trajectory, + current_selection: set[int], + ) -> bool: + """Check if the new selection string changes the current selection. - def update_with_idxs(self, idxs: set[int]) -> None: - """Using the inputted idxs change the selection setting so - that it would return the same idxs with get_idxs. It will - switch off the setting if idxs is not a superset of the - selection for that setting. + Checks if the selection operation encoded in the input JSON string + will add any new atoms to the current selection on the given trajectory. Parameters ---------- - idxs : set[int] - With the indices of the atom selection. - """ - new_settings = self.create_default_settings() - new_settings["all"] = False - - added = set([]) - for k, v in self.settings.items(): - if k == "index": - continue - - if isinstance(v, dict): - args = [{self._kwarg_keys[k]: i} for i in v.keys()] - switches = v.values() - else: - args = [{}] - switches = [v] - - for arg, switch in zip(args, switches): - if not switch: - continue - - selection = self._funcs[k](self.trajectory, **arg) - if not idxs.issuperset(selection): - continue - - added.update(selection) - if isinstance(v, dict): - new_settings[k][arg[self._kwarg_keys[k]]] = True - else: - new_settings[k] = True - - for idx in idxs - added: - new_settings["index"][idx] = True - - self.settings = new_settings - - def settings_to_json(self) -> str: - """Return the minimal json string required to achieve the same - settings with the settings_from_json method. + json_string : str + new selection operation in a JSON string + trajectory : Trajectory + a trajectory instance for which current_selection is defined + current_selection : Set[int] + set of currently selected atom indices Returns ------- - str - A JSON string. + bool + True if the operation changes selection, False otherwise + """ - minimal_dict = {} - for k0, v0 in self.settings.items(): - if isinstance(v0, bool) and (k0 == "all" or k0 != "all" and v0): - minimal_dict[k0] = v0 - elif isinstance(v0, dict): - sub_list = [] - for k1, v1 in v0.items(): - if v1: - sub_list.append(k1) - if sub_list: - minimal_dict[k0] = sorted(sub_list) - return json.dumps(minimal_dict) - - def json_to_settings(self, json_string: str) -> dict[str, Union[bool, dict]]: - """Loads the json string and converts to a settings. + function_parameters = json.loads(json_string) + if not self.operations: + return True + operation_type = function_parameters.get("operation_type", "union") + selection = self.apply_single_selection( + function_parameters, trajectory, current_selection + ) + return ((selection - current_selection) and operation_type == "union") or ( + (current_selection - selection) and operation_type != "union" + ) + + def select_in_trajectory(self, trajectory: Trajectory) -> set[int]: + """Select atoms in the input trajectory. + + Applies all the selection operations in sequence to the + input trajectory, and returns the resulting set of indices. Parameters ---------- - json_string : str - The JSON string of settings. + trajectory : Trajectory + trajectory object in which the atoms will be selected Returns ------- - dict[str, Union[bool, dict]] - The selection settings. - """ - json_setting = json.loads(json_string) - settings = {} - for k0, v0 in json_setting.items(): - if isinstance(v0, bool): - settings[k0] = v0 - elif isinstance(v0, list): - sub_dict = {} - for k1 in v0: - sub_dict[k1] = True - if sub_dict: - settings[k0] = sub_dict - return settings - - def load_from_json(self, json_string: str) -> None: - """Load the selection settings from a JSON string. + set[int] + set of atom indices that have been selected in the input trajectory - Parameters - ---------- - json_string : str - The JSON string of settings. """ - self.update_settings(self.json_to_settings(json_string), reset_first=True) + selection = set() + self.all_idxs = trajectory.chemical_system.all_indices + sequence = sorted(map(int, self.operations)) + if not sequence: + return self.all_idxs + for number in sequence: + function_parameters = self.operations[number] + selection = self.apply_single_selection( + function_parameters, trajectory, selection + ) + return selection - def check_valid_setting(self, settings: dict[str, Union[bool, dict]]) -> bool: - """Checks that the input settings are valid. + def convert_to_json(self) -> str: + """Output all the operations as a JSON string. - Parameters - ---------- - settings : dict[str, bool | dict] - The selection settings. + For the purpose of storing the selection independent of the + trajectory it is acting on, this method encodes the sequence + of selection operations as a string. Returns ------- - bool - True if settings are valid. + str + All the operations of this selection, encoded as string + """ - setting_keys = self._default.keys() - dict_setting_keys = self._kwarg_keys.keys() - for k0, v0 in settings.items(): - if k0 not in setting_keys: - return False - - if k0 not in dict_setting_keys: - if not isinstance(v0, bool): - return False - - if k0 in dict_setting_keys: - if not isinstance(v0, dict): - return False - for k1, v1 in v0.items(): - if k1 not in self._kwarg_vals[k0]: - return False - if not isinstance(v1, bool): - return False - - return True - - def check_valid_json_settings(self, json_string: str) -> bool: - """Checks that the input JSON setting string is valid. + return json.dumps(self.operations) + + def load_from_json(self, json_string: str): + """Populate the operations sequence from the input string. + + Loads the atom selection operations from a JSON string. + Adds the operations to the selection sequence. Parameters ---------- json_string : str - The JSON string of settings. + A sequence of selection operations, encoded as a JSON string - Returns - ------- - bool - True if settings are valid. """ - try: - settings = self.json_to_settings(json_string) - except ValueError: - return False - return self.check_valid_setting(settings) + json_setting = json.loads(json_string) + for k0, v0 in json_setting.items(): + if not isinstance(v0, dict): + raise TypeError(f"Selection {v0} is not a dictionary.") + self.set_selection(number=k0, function_parameters=v0) diff --git a/MDANSE/Src/MDANSE/Framework/AtomSelector/spatial_selection.py b/MDANSE/Src/MDANSE/Framework/AtomSelector/spatial_selection.py new file mode 100644 index 0000000000..00a3c04634 --- /dev/null +++ b/MDANSE/Src/MDANSE/Framework/AtomSelector/spatial_selection.py @@ -0,0 +1,105 @@ +# This file is part of MDANSE. +# +# MDANSE is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . +# + +from collections.abc import Sequence +from typing import Optional + +import numpy as np +from scipy.spatial import KDTree + +from MDANSE.MolecularDynamics.Trajectory import Trajectory + + +def select_positions( + trajectory: Trajectory, + *, + frame_number: int = 0, + position_minimum: Optional[Sequence[float]] = None, + position_maximum: Optional[Sequence[float]] = None, + **_kwargs: str, +) -> set[int]: + """Select atoms based on their positions at a specified frame number. + + Lower and upper limits of x, y and z coordinates can be given as input. + + Parameters + ---------- + trajectory : Trajectory + a trajectory instance in which the atoms are being selected + frame_number : int, optional + trajectory frame at which to check the coordinates, by default 0 + position_minimum : Sequence[float], optional + (x, y, z) lower limits of coordinates to be selected, by default None + position_maximum : Sequence[float], optional + (x, y, z) upper limits of coordinates to be selected, by default None + + Returns + ------- + set[int] + indicies of atoms with coordinates within limits + + """ + coordinates = trajectory.coordinates(frame_number) + lower_limits = ( + np.array(position_minimum) + if position_minimum is not None + else np.array([-np.inf] * 3) + ) + upper_limits = ( + np.array(position_maximum) + if position_maximum is not None + else np.array([np.inf] * 3) + ) + valid = np.where( + ((coordinates > lower_limits) & (coordinates < upper_limits)).all(axis=1) + ) + return set(valid[0]) + + +def select_sphere( + trajectory: Trajectory, + *, + frame_number: int = 0, + sphere_centre: Sequence[float], + sphere_radius: float, + **_kwargs: str, +) -> set[int]: + """Select atoms within a sphere. + + Selects atoms at a given distance from a fixed point in space, + based on coordinates at a specific frame number. + + Parameters + ---------- + trajectory : Trajectory + A trajectory instance to which the selection is applied + frame_number : int, optional + trajectory frame at which to check the coordinates, by default 0 + sphere_centre : Sequence[float] + (x, y, z) coordinates of the centre of the selection + sphere_radius : float + distance from the centre within which to select atoms + + Returns + ------- + set[int] + set of indices of atoms inside the sphere + + """ + coordinates = trajectory.coordinates(frame_number) + kdtree = KDTree(coordinates) + indices = kdtree.query_ball_point(sphere_centre, sphere_radius) + return set(indices) diff --git a/MDANSE/Src/MDANSE/Framework/Configurators/AtomSelectionConfigurator.py b/MDANSE/Src/MDANSE/Framework/Configurators/AtomSelectionConfigurator.py index 5ed11f7e87..0602042a4e 100644 --- a/MDANSE/Src/MDANSE/Framework/Configurators/AtomSelectionConfigurator.py +++ b/MDANSE/Src/MDANSE/Framework/Configurators/AtomSelectionConfigurator.py @@ -14,12 +14,17 @@ # along with this program. If not, see . # +from collections import Counter, defaultdict +from json import JSONDecodeError + +from MDANSE.Framework.AtomSelector.selector import ReusableSelection from MDANSE.Framework.Configurators.IConfigurator import IConfigurator -from MDANSE.Framework.AtomSelector import Selector class AtomSelectionConfigurator(IConfigurator): - """This configurator allows the selection of a specific set of + """Selects atoms in trajectory based on the input string. + + This configurator allows the selection of a specific set of atoms on which the analysis will be performed. The defaults setting selects all atoms. @@ -27,9 +32,10 @@ class AtomSelectionConfigurator(IConfigurator): ---------- _default : str The defaults selection setting. + """ - _default = '{"all": true}' + _default = "{}" def configure(self, value: str) -> None: """Configure an input value. @@ -38,8 +44,12 @@ def configure(self, value: str) -> None: ---------- value : str The selection setting in a json readable format. + """ + self._original_input = value + trajConfig = self._configurable[self._dependencies["trajectory"]] + self.selector = ReusableSelection() if value is None: value = self._default @@ -48,28 +58,28 @@ def configure(self, value: str) -> None: self.error_status = "Invalid input value." return - selector = Selector(trajConfig["instance"]) - if not selector.check_valid_json_settings(value): + try: + self.selector.load_from_json(value) + except JSONDecodeError: self.error_status = "Invalid JSON string." return self["value"] = value - selector.load_from_json(value) - indices = selector.get_idxs() - - self["flatten_indices"] = sorted(list(indices)) + self.selector.load_from_json(value) + indices = self.selector.select_in_trajectory(trajConfig["instance"]) - trajConfig = self._configurable[self._dependencies["trajectory"]] + self["flatten_indices"] = sorted(indices) atoms = trajConfig["instance"].chemical_system.atom_list + self["total_number_of_atoms"] = len(atoms) selectedAtoms = [atoms[idx] for idx in self["flatten_indices"]] self["selection_length"] = len(self["flatten_indices"]) self["indices"] = [[idx] for idx in self["flatten_indices"]] self["elements"] = [[at] for at in selectedAtoms] - self["names"] = [at for at in selectedAtoms] + self["names"] = list(selectedAtoms) self["unique_names"] = sorted(set(self["names"])) self["masses"] = [ [trajConfig["instance"].get_atom_property(n, "atomic_weight")] @@ -78,50 +88,53 @@ def configure(self, value: str) -> None: if self["selection_length"] == 0: self.error_status = "The atom selection is empty." return - else: - self.error_status = "OK" + self.error_status = "OK" def get_natoms(self) -> dict[str, int]: - """ + """Count the selected atoms, per element. + Returns ------- dict A dictionary of the number of atom per element. - """ - nAtomsPerElement = {} - for v in self["names"]: - if v in nAtomsPerElement: - nAtomsPerElement[v] += 1 - else: - nAtomsPerElement[v] = 1 - return nAtomsPerElement + """ + return Counter(self["names"]) def get_total_natoms(self) -> int: - """ + """Count all the selected atoms. + Returns ------- int The total number of atoms selected. + """ return len(self["names"]) - def get_indices(self): - indicesPerElement = {} + def get_indices(self) -> dict[str, list[int]]: + """Group atom indices per chemical element. + + Returns + ------- + dict[str, list[int]] + For each atom type, a list of indices of selected atoms + + """ + indicesPerElement = defaultdict(list) for i, v in enumerate(self["names"]): - if v in indicesPerElement: - indicesPerElement[v].extend(self["indices"][i]) - else: - indicesPerElement[v] = self["indices"][i][:] + indicesPerElement[v].extend(self["indices"][i]) return indicesPerElement def get_information(self) -> str: - """ + """Create a text summary of the selection. + Returns ------- str - Some information on the atom selection. + Human-readable information on the atom selection. + """ if "selection_length" not in self: return "Not configured yet\n" @@ -131,15 +144,3 @@ def get_information(self) -> str: info.append(f"Selected elements:{self['unique_names']}") return "\n".join(info) + "\n" - - def get_selector(self) -> Selector: - """ - Returns - ------- - Selector - The atom selector object initialised with the trajectories - chemical system. - """ - traj_config = self._configurable[self._dependencies["trajectory"]] - selector = Selector(traj_config["instance"]) - return selector diff --git a/MDANSE/Src/MDANSE/Framework/Configurators/AtomTransmutationConfigurator.py b/MDANSE/Src/MDANSE/Framework/Configurators/AtomTransmutationConfigurator.py index af8eee2abe..20c6f8d016 100644 --- a/MDANSE/Src/MDANSE/Framework/Configurators/AtomTransmutationConfigurator.py +++ b/MDANSE/Src/MDANSE/Framework/Configurators/AtomTransmutationConfigurator.py @@ -18,8 +18,8 @@ from MDANSE.Framework.Configurators.IConfigurator import IConfigurator from MDANSE.Chemistry import ATOMS_DATABASE +from MDANSE.Framework.AtomSelector.selector import ReusableSelection from MDANSE.MolecularDynamics.Trajectory import Trajectory -from MDANSE.Framework.AtomSelector import Selector class AtomTransmuter: @@ -34,32 +34,30 @@ def __init__(self, trajectory: Trajectory) -> None: system : ChemicalSystem The chemical system object. """ - self.selector = Selector(trajectory) + self.selector = ReusableSelection() self._original_map = {} for number, element in enumerate(trajectory.chemical_system.atom_list): self._original_map[number] = element self._new_map = {} + self._current_trajectory = trajectory - def apply_transmutation( - self, selection_dict: dict[str, Union[bool, dict]], symbol: str - ) -> None: + def apply_transmutation(self, selection_string: str, symbol: str) -> None: """With the selection dictionary update selector and then update the transmutation map. Parameters ---------- - selection_dict: dict[str, Union[bool, dict]] - The selection setting to get the indices to map the inputted - symbol. + selection_string: str + the JSON string of the selection operation to use. symbol: str The element to map the selected atoms to. """ if symbol not in ATOMS_DATABASE: raise ValueError(f"{symbol} not found in the atom database.") - self.selector.update_settings(selection_dict, reset_first=True) - for idx in self.selector.get_idxs(): - self._new_map[idx] = symbol + self.selector.load_from_json(selection_string) + indices = self.selector.select_in_trajectory(self._current_trajectory) + self._new_map.update(dict.fromkeys(indices, symbol)) def get_setting(self) -> dict[int, str]: """ @@ -88,6 +86,7 @@ def get_json_setting(self) -> str: def reset_setting(self) -> None: """Resets the transmutation setting.""" self._new_map = {} + self.selector.reset() class AtomTransmutationConfigurator(IConfigurator): diff --git a/MDANSE/Src/MDANSE/Framework/Configurators/PartialChargeConfigurator.py b/MDANSE/Src/MDANSE/Framework/Configurators/PartialChargeConfigurator.py index a85bd46a9e..3f66c7f897 100644 --- a/MDANSE/Src/MDANSE/Framework/Configurators/PartialChargeConfigurator.py +++ b/MDANSE/Src/MDANSE/Framework/Configurators/PartialChargeConfigurator.py @@ -17,7 +17,7 @@ import json from MDANSE.Framework.Configurators.IConfigurator import IConfigurator -from MDANSE.Framework.AtomSelector import Selector +from MDANSE.Framework.AtomSelector.selector import ReusableSelection from MDANSE.MolecularDynamics.Trajectory import Trajectory @@ -35,7 +35,7 @@ def __init__(self, trajectory: Trajectory) -> None: """ system = trajectory.chemical_system charges = trajectory.charges(0) - self.selector = Selector(trajectory) + self._current_trajectory = trajectory self._original_map = {} for at_num, at in enumerate(system.atom_list): try: @@ -44,9 +44,7 @@ def __init__(self, trajectory: Trajectory) -> None: self._original_map[at_num] = 0.0 self._new_map = {} - def update_charges( - self, selection_dict: dict[str, Union[bool, dict]], charge: float - ) -> None: + def update_charges(self, selection_string: str, charge: float) -> None: """With the selection dictionary update the selector and then update the partial charge map. @@ -58,8 +56,10 @@ def update_charges( charge: float The partial charge to map the selected atoms to. """ - self.selector.update_settings(selection_dict, reset_first=True) - for idx in self.selector.get_idxs(): + selector = ReusableSelection() + selector.load_from_json(selection_string) + indices = selector.select_in_trajectory(self._current_trajectory) + for idx in indices: self._new_map[idx] = charge def get_full_setting(self) -> dict[int, float]: diff --git a/MDANSE/Src/MDANSE/Framework/Jobs/IJob.py b/MDANSE/Src/MDANSE/Framework/Jobs/IJob.py index e39881e448..edb27ce832 100644 --- a/MDANSE/Src/MDANSE/Framework/Jobs/IJob.py +++ b/MDANSE/Src/MDANSE/Framework/Jobs/IJob.py @@ -169,6 +169,20 @@ def initialize(self): ) except KeyError: LOG.error("IJob did not find 'write_logs' in output_files") + if selection := self.configuration.get("atom_selection"): + try: + array_length = selection["total_number_of_atoms"] + except KeyError: + LOG.warning( + "Job could not find total number of atoms in atom selection." + ) + else: + valid_indices = selection["flatten_indices"] + self._outputData.add( + "selected_atoms", + "LineOutputVariable", + [index in valid_indices for index in range(array_length)], + ) @abc.abstractmethod def run_step(self, index): diff --git a/MDANSE/Tests/UnitTests/AtomSelector/__init__.py b/MDANSE/Tests/UnitTests/AtomSelector/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/MDANSE/Tests/UnitTests/AtomSelector/test_all_selector.py b/MDANSE/Tests/UnitTests/AtomSelector/test_all_selector.py deleted file mode 100644 index 22eed92b8d..0000000000 --- a/MDANSE/Tests/UnitTests/AtomSelector/test_all_selector.py +++ /dev/null @@ -1,20 +0,0 @@ -import os -import pytest -from MDANSE.Framework.InputData.HDFTrajectoryInputData import HDFTrajectoryInputData -from MDANSE.Framework.AtomSelector.all_selector import select_all - - -traj_2vb1 = os.path.join( - os.path.dirname(os.path.realpath(__file__)), "..", "Converted", "2vb1.mdt" -) - - -@pytest.fixture(scope="module") -def protein_trajectory(): - protein_trajectory = HDFTrajectoryInputData(traj_2vb1) - return protein_trajectory.trajectory - - -def test_select_all_returns_correct_number_of_atoms_matches(protein_trajectory): - selection = select_all(protein_trajectory) - assert len(selection) == 30714 diff --git a/MDANSE/Tests/UnitTests/AtomSelector/test_atom_selectors.py b/MDANSE/Tests/UnitTests/AtomSelector/test_atom_selectors.py deleted file mode 100644 index 6903bbd3a0..0000000000 --- a/MDANSE/Tests/UnitTests/AtomSelector/test_atom_selectors.py +++ /dev/null @@ -1,90 +0,0 @@ -import os -import pytest -from MDANSE.Framework.InputData.HDFTrajectoryInputData import HDFTrajectoryInputData -from MDANSE.Framework.AtomSelector.atom_selectors import ( - select_element, - select_hs_on_heteroatom, - select_hs_on_element, - select_dummy, -) - - -traj_2vb1 = os.path.join( - os.path.dirname(os.path.realpath(__file__)), "..", "Converted", "2vb1.mdt" -) - - -@pytest.fixture(scope="module") -def protein_trajectory(): - protein_trajectory = HDFTrajectoryInputData(traj_2vb1) - return protein_trajectory.trajectory - - -def test_select_element_returns_true_as_match_exist( - protein_trajectory, -): - exists = select_element(protein_trajectory, "S", check_exists=True) - assert exists - - -def test_select_element_returns_false_as_match_does_not_exist( - protein_trajectory, -): - exists = select_element(protein_trajectory, "Si", check_exists=True) - assert not exists - - -def test_select_element_returns_correct_number_of_atom_matches( - protein_trajectory, -): - selection = select_element(protein_trajectory, "S") - assert len(selection) == 10 - - -def test_select_hs_on_carbon_returns_correct_number_of_atom_matches( - protein_trajectory, -): - selection = select_hs_on_element(protein_trajectory, "C") - assert len(selection) == 696 - - -def test_select_hs_on_nitrogen_returns_correct_number_of_atom_matches( - protein_trajectory, -): - selection = select_hs_on_element(protein_trajectory, "N") - assert len(selection) == 243 - - -def test_select_hs_on_oxygen_returns_correct_number_of_atom_matches( - protein_trajectory, -): - selection = select_hs_on_element(protein_trajectory, "O") - assert len(selection) == 19184 - - -def test_select_hs_on_sulfur_returns_correct_number_of_atom_matches( - protein_trajectory, -): - selection = select_hs_on_element(protein_trajectory, "S") - assert len(selection) == 0 - - -def test_select_hs_on_silicon_returns_correct_number_of_atom_matches( - protein_trajectory, -): - selection = select_hs_on_element(protein_trajectory, "Si") - assert len(selection) == 0 - - -def test_select_hs_on_heteroatom_returns_correct_number_of_atom_matches( - protein_trajectory, -): - selection = select_hs_on_heteroatom(protein_trajectory) - assert len(selection) == 19427 - - -def test_select_dummy_returns_correct_number_of_atom_matches( - protein_trajectory, -): - selection = select_dummy(protein_trajectory) - assert len(selection) == 0 diff --git a/MDANSE/Tests/UnitTests/AtomSelector/test_group_selectors.py b/MDANSE/Tests/UnitTests/AtomSelector/test_group_selectors.py deleted file mode 100644 index 322697cc32..0000000000 --- a/MDANSE/Tests/UnitTests/AtomSelector/test_group_selectors.py +++ /dev/null @@ -1,83 +0,0 @@ -import os -import pytest -from MDANSE.Framework.InputData.HDFTrajectoryInputData import HDFTrajectoryInputData -from MDANSE.Framework.AtomSelector.group_selectors import ( - select_primary_amine, - select_hydroxy, - select_methyl, - select_phosphate, - select_sulphate, - select_thiol, -) - - -traj_2vb1 = os.path.join( - os.path.dirname(os.path.realpath(__file__)), "..", "Converted", "2vb1.mdt" -) -traj_1gip = os.path.join( - os.path.dirname(os.path.realpath(__file__)), "..", "Converted", "1gip.mdt" -) - - -@pytest.fixture(scope="module") -def protein_trajectory(): - protein_trajectory = HDFTrajectoryInputData(traj_2vb1) - return protein_trajectory.trajectory - - -@pytest.fixture(scope="module") -def nucleic_acid_chemical_system(): - protein_trajectory = HDFTrajectoryInputData(traj_1gip) - return protein_trajectory.trajectory - - -def test_select_primary_amine_returns_true_as_match_exists( - protein_trajectory, -): - exists = select_primary_amine(protein_trajectory, check_exists=True) - assert exists - - -def test_select_sulphate_returns_false_as_match_does_not_exist( - nucleic_acid_chemical_system, -): - exists = select_sulphate(nucleic_acid_chemical_system, check_exists=True) - assert not exists - - -def test_select_primary_amine_returns_correct_number_of_atom_matches( - protein_trajectory, -): - selection = select_primary_amine(protein_trajectory) - assert len(selection) == 117 - - -def test_select_hydroxy_returns_correct_number_of_atom_matches( - protein_trajectory, -): - selection = select_hydroxy(protein_trajectory) - assert len(selection) == 28786 - - -def test_select_methyl_returns_correct_number_of_atom_matches(protein_trajectory): - selection = select_methyl(protein_trajectory) - assert len(selection) == 244 - - -def test_select_phosphate_returns_correct_number_of_atom_matches( - nucleic_acid_chemical_system, -): - selection = select_phosphate(nucleic_acid_chemical_system) - assert len(selection) == 110 - - -def test_select_sulphate_returns_correct_number_of_atom_matches( - nucleic_acid_chemical_system, -): - selection = select_sulphate(nucleic_acid_chemical_system) - assert len(selection) == 0 - - -def test_select_thiol_returns_correct_number_of_atoms_matches(protein_trajectory): - selection = select_thiol(protein_trajectory) - assert len(selection) == 0 diff --git a/MDANSE/Tests/UnitTests/AtomSelector/test_molecule_selectors.py b/MDANSE/Tests/UnitTests/AtomSelector/test_molecule_selectors.py deleted file mode 100644 index d540fa6fd0..0000000000 --- a/MDANSE/Tests/UnitTests/AtomSelector/test_molecule_selectors.py +++ /dev/null @@ -1,29 +0,0 @@ -import os -import pytest -from MDANSE.Framework.InputData.HDFTrajectoryInputData import HDFTrajectoryInputData -from MDANSE.Framework.AtomSelector.molecule_selectors import select_water - - -traj_2vb1 = os.path.join( - os.path.dirname(os.path.realpath(__file__)), "..", "Converted", "2vb1.mdt" -) - - -@pytest.fixture(scope="module") -def protein_trajectory(): - protein_trajectory = HDFTrajectoryInputData(traj_2vb1) - return protein_trajectory.trajectory - - -def test_select_water_returns_true_as_match_exists( - protein_trajectory, -): - exists = select_water(protein_trajectory, check_exists=True) - assert exists - - -def test_select_water_returns_correct_number_of_atom_matches( - protein_trajectory, -): - selection = select_water(protein_trajectory) - assert len(selection) == 28746 diff --git a/MDANSE/Tests/UnitTests/AtomSelector/test_selector.py b/MDANSE/Tests/UnitTests/AtomSelector/test_selector.py deleted file mode 100644 index a853a036f3..0000000000 --- a/MDANSE/Tests/UnitTests/AtomSelector/test_selector.py +++ /dev/null @@ -1,325 +0,0 @@ -import os -import pytest -from MDANSE.Framework.InputData.HDFTrajectoryInputData import HDFTrajectoryInputData -from MDANSE.Framework.AtomSelector.selector import Selector - - -traj_2vb1 = os.path.join( - os.path.dirname(os.path.realpath(__file__)), "..", "Converted", "2vb1.mdt" -) - - -@pytest.fixture(scope="module") -def protein_trajectory(): - protein_trajectory = HDFTrajectoryInputData(traj_2vb1) - return protein_trajectory.trajectory - - -def test_selector_returns_all_atom_idxs(protein_trajectory): - selector = Selector(protein_trajectory) - atm_idxs = selector.get_idxs() - assert len(atm_idxs) == 30714 - - -def test_selector_returns_all_atom_idxs_with_all_and_sulfurs_selected( - protein_trajectory, -): - selector = Selector(protein_trajectory) - selector.settings["all"] = True - selector.settings["element"] = {"S": True} - atm_idxs = selector.get_idxs() - assert len(atm_idxs) == 30714 - - -def test_selector_returns_correct_number_of_atom_idxs_when_sulfur_atoms_are_selected( - protein_trajectory, -): - selector = Selector(protein_trajectory) - selector.settings["all"] = False - selector.settings["element"] = {"S": True} - atm_idxs = selector.get_idxs() - assert len(atm_idxs) == 10 - - -def test_selector_returns_correct_number_of_atom_idxs_when_sulfur_atoms_are_selected_when_get_idxs_is_called_twice( - protein_trajectory, -): - selector = Selector(protein_trajectory) - selector.settings["all"] = False - selector.settings["element"] = {"S": True} - atm_idxs = selector.get_idxs() - assert len(atm_idxs) == 10 - atm_idxs = selector.get_idxs() - assert len(atm_idxs) == 10 - - -def test_selector_returns_correct_number_of_atom_idxs_when_waters_are_selected( - protein_trajectory, -): - selector = Selector(protein_trajectory) - selector.settings["all"] = False - selector.settings["water"] = True - atm_idxs = selector.get_idxs() - assert len(atm_idxs) == 28746 - - -def test_selector_returns_correct_number_of_atom_idxs_when_water_is_turned_on_and_off( - protein_trajectory, -): - selector = Selector(protein_trajectory) - selector.settings["all"] = False - selector.settings["water"] = True - atm_idxs = selector.get_idxs() - assert len(atm_idxs) == 28746 - selector.settings["water"] = False - atm_idxs = selector.get_idxs() - assert len(atm_idxs) == 0 - - -def test_selector_returns_correct_number_of_atom_idxs_when_waters_and_sulfurs_are_selected( - protein_trajectory, -): - selector = Selector(protein_trajectory) - selector.settings["all"] = False - selector.settings["water"] = True - selector.settings["element"] = {"S": True} - atm_idxs = selector.get_idxs() - assert len(atm_idxs) == 28746 + 10 - - -def test_selector_returns_correct_number_of_atom_idxs_when_waters_and_sulfurs_are_selected_with_settings_loaded_as_a_dict( - protein_trajectory, -): - selector = Selector(protein_trajectory) - selector.update_settings({"all": False, "element": {"S": True}, "water": True}) - atm_idxs = selector.get_idxs() - assert len(atm_idxs) == 28746 + 10 - - -def test_selector_json_dump_0(protein_trajectory): - selector = Selector(protein_trajectory) - selector.update_settings({"all": False, "element": {"S": True}}) - json_dump = selector.settings_to_json() - assert json_dump == '{"all": false, "element": ["S"]}' - - -def test_selector_json_dump_1(protein_trajectory): - selector = Selector(protein_trajectory) - selector.update_settings({"all": False, "element": {"S": True}, "water": True}) - json_dump = selector.settings_to_json() - assert json_dump == '{"all": false, "water": true, "element": ["S"]}' - - -def test_selector_json_dump_2(protein_trajectory): - selector = Selector(protein_trajectory) - selector.update_settings({"all": False, "water": True}) - json_dump = selector.settings_to_json() - assert json_dump == '{"all": false, "water": true}' - - -def test_selector_json_dump_3(protein_trajectory): - selector = Selector(protein_trajectory) - selector.update_settings( - {"all": False, "element": {"S": True, "H": True}, "water": True} - ) - json_dump = selector.settings_to_json() - assert json_dump == '{"all": false, "water": true, "element": ["H", "S"]}' - - -def test_selector_json_dump_4(protein_trajectory): - selector = Selector(protein_trajectory) - selector.update_settings( - { - "all": False, - "element": {"S": True, "H": True}, - "water": True, - "index": {0: True, 1: True}, - } - ) - json_dump = selector.settings_to_json() - assert ( - json_dump - == '{"all": false, "water": true, "element": ["H", "S"], "index": [0, 1]}' - ) - - -def test_selector_json_dump_with_second_update(protein_trajectory): - selector = Selector(protein_trajectory) - selector.update_settings({"all": False}) - selector.update_settings({"element": {"S": True, "O": True}, "water": True}) - json_dump = selector.settings_to_json() - assert json_dump == '{"all": false, "water": true, "element": ["O", "S"]}' - - -def test_selector_json_dump_with_third_update(protein_trajectory): - selector = Selector(protein_trajectory) - selector.update_settings({"all": False}) - selector.update_settings({"element": {"S": True, "O": True}, "water": True}) - selector.update_settings({"element": {"S": False}}) - json_dump = selector.settings_to_json() - assert json_dump == '{"all": false, "water": true, "element": ["O"]}' - - -def test_selector_json_dump_with_fourth_update(protein_trajectory): - selector = Selector(protein_trajectory) - selector.update_settings({"all": False}) - selector.update_settings({"element": {"S": True, "O": True}, "water": True}) - selector.update_settings({"element": {"S": False}}) - selector.update_settings({"water": False}) - json_dump = selector.settings_to_json() - assert json_dump == '{"all": false, "element": ["O"]}' - - -def test_selector_returns_correct_number_of_atom_idxs_after_setting_settings_again_with_reset_first( - protein_trajectory, -): - selector = Selector(protein_trajectory) - selector.update_settings({"all": False, "element": {"S": True}, "water": True}) - atm_idxs = selector.get_idxs() - assert len(atm_idxs) == 28746 + 10 - - selector.update_settings( - { - "all": False, - "element": {"S": True}, - }, - reset_first=True, - ) - atm_idxs = selector.get_idxs() - assert len(atm_idxs) == 10 - - -def test_selector_json_dump_and_load_0(protein_trajectory): - selector = Selector(protein_trajectory) - selector.update_settings({"all": False, "index": {0: True, 1: True}}) - json_dump = selector.settings_to_json() - assert json_dump == '{"all": false, "index": [0, 1]}' - selector.load_from_json(json_dump) - atm_idxs = selector.get_idxs() - assert len(atm_idxs) == 2 - - -def test_selector_json_dump_and_load_1(protein_trajectory): - selector = Selector(protein_trajectory) - selector.update_settings({"all": False, "element": {"S": True}, "water": True}) - json_dump = selector.settings_to_json() - assert json_dump == '{"all": false, "water": true, "element": ["S"]}' - selector.load_from_json(json_dump) - atm_idxs = selector.get_idxs() - assert len(atm_idxs) == 28746 + 10 - - -def test_selector_returns_correct_number_of_atom_idxs_when_indexes_0_and_1_are_selected( - protein_trajectory, -): - selector = Selector(protein_trajectory) - selector.update_settings( - { - "all": False, - "index": {0: True, 1: True}, - } - ) - atm_idxs = selector.get_idxs() - assert len(atm_idxs) == 2 - - -def test_selector_returns_true_with_correct_setting_check(protein_trajectory): - selector = Selector(protein_trajectory) - assert selector.check_valid_setting( - { - "all": False, - "index": {0: True, 1: True}, - } - ) - - -def test_selector_returns_false_with_incorrect_setting_check_0(protein_trajectory): - selector = Selector(protein_trajectory) - assert not selector.check_valid_setting( - { - "alle": False, - "index": {0: True, 1: True}, - } - ) - - -def test_selector_returns_false_with_incorrect_setting_check_1(protein_trajectory): - selector = Selector(protein_trajectory) - assert not selector.check_valid_setting( - { - "all": False, - "index": {-1: True, 1: True}, - } - ) - - -def test_selector_returns_false_with_incorrect_setting_check_2(protein_trajectory): - selector = Selector(protein_trajectory) - assert not selector.check_valid_setting( - { - "all": False, - "index": {0: True, 1: True}, - "element": {"Ss": True}, - } - ) - - -def test_selector_returns_true_with_correct_json_setting_0(protein_trajectory): - selector = Selector(protein_trajectory) - assert selector.check_valid_json_settings( - '{"all": false, "water": true, "element": {"S": true}}' - ) - - -def test_selector_returns_true_with_correct_json_setting_1(protein_trajectory): - selector = Selector(protein_trajectory) - assert selector.check_valid_json_settings('{"all": false, "index": [0, 1]}') - - -def test_selector_returns_false_with_incorrect_json_setting_0(protein_trajectory): - selector = Selector(protein_trajectory) - assert not selector.check_valid_json_settings( - '{all: false, "water": true, "element": {"S": true}}' - ) - - -def test_selector_returns_false_with_incorrect_json_setting_1(protein_trajectory): - selector = Selector(protein_trajectory) - assert not selector.check_valid_json_settings('{"all": false, "index": [0, "1"]}') - - -def test_selector_returns_false_with_incorrect_json_setting_2(protein_trajectory): - selector = Selector(protein_trajectory) - assert not selector.check_valid_json_settings('{"all": False, "index": ["0", "1"]}') - - -@pytest.mark.xfail(reason="see docstring") -def test_selector_with_atom_fullname(protein_trajectory): - """For the moment, full names of atoms are not implemented - in the ChemicalSystem.""" - selector = Selector(protein_trajectory) - selector.update_settings( - { - "all": False, - "fullname": {"...LYS1.N": True, "...VAL2.O": True}, - } - ) - atm_idxs = selector.get_idxs() - assert len(atm_idxs) == 2 - - -@pytest.mark.xfail(reason="see docstring") -def test_selector_with_atom_name(protein_trajectory): - """At the moment the oxygen in water has the same - atom name as the oxygen in the protein. - We will have to decide if this is acceptable. - """ - selector = Selector(protein_trajectory) - selector.update_settings( - { - "all": False, - "name": {"N": True, "O": True}, - } - ) - atm_idxs = selector.get_idxs() - assert len(atm_idxs) == 258 diff --git a/MDANSE/Tests/UnitTests/AtomTransmutation/test_transmutation.py b/MDANSE/Tests/UnitTests/AtomTransmutation/test_transmutation.py index 60e92cc42a..eb3eaa6712 100644 --- a/MDANSE/Tests/UnitTests/AtomTransmutation/test_transmutation.py +++ b/MDANSE/Tests/UnitTests/AtomTransmutation/test_transmutation.py @@ -29,7 +29,7 @@ def test_atom_transmutation_return_dict_with_transmutations_with_incorrect_eleme atm_transmuter = AtomTransmuter(protein_trajectory) with pytest.raises(ValueError): atm_transmuter.apply_transmutation( - {"all": False, "element": {"S": True}}, "CCC" + '{"0": {"function_name": "select_atoms", "atom_types": ["S"]}}', "CCC" ) @@ -37,7 +37,7 @@ def test_atom_transmutation_return_dict_with_transmutations_with_s_element_trans protein_trajectory, ): atm_transmuter = AtomTransmuter(protein_trajectory) - atm_transmuter.apply_transmutation({"all": False, "element": {"S": True}}, "C") + atm_transmuter.apply_transmutation('{"0": {"function_name": "select_atoms", "atom_types": ["S"]}}', "C") mapping = atm_transmuter.get_setting() assert mapping == { 98: "C", @@ -57,8 +57,8 @@ def test_atom_transmutation_return_dict_with_transmutations_with_s_element_trans protein_trajectory, ): atm_transmuter = AtomTransmuter(protein_trajectory) - atm_transmuter.apply_transmutation({"all": False, "element": {"S": True}}, "C") - atm_transmuter.apply_transmutation({"all": False, "index": {98: True}}, "N") + atm_transmuter.apply_transmutation('{"0": {"function_name": "select_atoms", "atom_types": ["S"]}}', "C") + atm_transmuter.apply_transmutation('{"0": {"function_name": "select_atoms", "index_list": [98]}}', "N") mapping = atm_transmuter.get_setting() assert mapping == { 98: "N", @@ -78,8 +78,8 @@ def test_atom_transmutation_return_dict_with_transmutations_with_s_element_trans protein_trajectory, ): atm_transmuter = AtomTransmuter(protein_trajectory) - atm_transmuter.apply_transmutation({"all": False, "element": {"S": True}}, "C") - atm_transmuter.apply_transmutation({"all": False, "index": {98: True}}, "S") + atm_transmuter.apply_transmutation('{"0": {"function_name": "select_atoms", "atom_types": ["S"]}}', "C") + atm_transmuter.apply_transmutation('{"0": {"function_name": "select_atoms", "index_list": [98]}}', "S") mapping = atm_transmuter.get_setting() assert mapping == { 175: "C", @@ -98,10 +98,8 @@ def test_atom_transmutation_return_dict_with_transmutations_with_s_element_trans protein_trajectory, ): atm_transmuter = AtomTransmuter(protein_trajectory) - atm_transmuter.apply_transmutation({"all": False, "element": {"S": True}}, "C") - atm_transmuter.apply_transmutation( - {"all": False, "index": {98: True, 99: True}}, "S" - ) + atm_transmuter.apply_transmutation('{"0": {"function_name": "select_atoms", "atom_types": ["S"]}}', "C") + atm_transmuter.apply_transmutation('{"0": {"function_name": "select_atoms", "index_list": [98, 99]}}', "S") mapping = atm_transmuter.get_setting() assert mapping == { 99: "S", @@ -119,10 +117,8 @@ def test_atom_transmutation_return_dict_with_transmutations_with_s_element_trans def test_atom_transmutation_return_empty_dict_after_reset(protein_trajectory): atm_transmuter = AtomTransmuter(protein_trajectory) - atm_transmuter.apply_transmutation({"all": False, "element": {"S": True}}, "C") - atm_transmuter.apply_transmutation( - {"all": False, "index": {98: True, 99: True}}, "S" - ) + atm_transmuter.apply_transmutation('{"0": {"function_name": "select_atoms", "atom_types": ["S"]}}', "C") + atm_transmuter.apply_transmutation('{"0": {"function_name": "select_atoms", "index_list": [98, 99]}}', "S") atm_transmuter.reset_setting() mapping = atm_transmuter.get_setting() assert mapping == {} diff --git a/MDANSE/Tests/UnitTests/TrajectoryEditor/test_editor.py b/MDANSE/Tests/UnitTests/TrajectoryEditor/test_editor.py index c2ef04713c..40d980a780 100644 --- a/MDANSE/Tests/UnitTests/TrajectoryEditor/test_editor.py +++ b/MDANSE/Tests/UnitTests/TrajectoryEditor/test_editor.py @@ -119,7 +119,7 @@ def test_editor_atoms(): parameters["output_files"] = (temp_name, 64, 128, "gzip", "INFO") parameters["trajectory"] = short_traj parameters["frames"] = (0, 501, 1) - parameters["atom_selection"] = '{"all": false, "element": ["H"]}' + parameters["atom_selection"] = '{"0": {"function_name": "select_atoms", "atom_types": ["H"]}}' temp = IJob.create("TrajectoryEditor") temp.run(parameters, status=True) assert path.exists(temp_name + ".mdt") diff --git a/MDANSE/Tests/UnitTests/test_reusable_selection.py b/MDANSE/Tests/UnitTests/test_reusable_selection.py new file mode 100644 index 0000000000..0e61fd70e8 --- /dev/null +++ b/MDANSE/Tests/UnitTests/test_reusable_selection.py @@ -0,0 +1,201 @@ + +import os + +import pytest + +from MDANSE.Framework.AtomSelector.selector import ReusableSelection +from MDANSE.Framework.InputData.HDFTrajectoryInputData import HDFTrajectoryInputData + + +short_traj = os.path.join( + os.path.dirname(os.path.realpath(__file__)), + "Converted", + "short_trajectory_after_changes.mdt", +) +mdmc_traj = os.path.join( + os.path.dirname(os.path.realpath(__file__)), + "Converted", + "Ar_mdmc_h5md.h5", +) +com_traj = os.path.join( + os.path.dirname(os.path.realpath(__file__)), + "Converted", + "com_trajectory.mdt", +) +traj_2vb1 = os.path.join( + os.path.dirname(os.path.realpath(__file__)), + "Converted", + "2vb1.mdt" +) + +@pytest.fixture(scope='module') +def trajectory(request): + return HDFTrajectoryInputData(request.param) + + +@pytest.mark.parametrize("trajectory", [short_traj, mdmc_traj, com_traj], indirect=True) +def test_select_all(trajectory): + n_atoms = len(trajectory.chemical_system.atom_list) + reusable_selection = ReusableSelection() + reusable_selection.set_selection(number=None, function_parameters={'function_name': 'select_all'}) + selection = reusable_selection.select_in_trajectory(trajectory.trajectory) + assert len(selection) == n_atoms + + +@pytest.mark.parametrize("trajectory", [short_traj, mdmc_traj, com_traj], indirect=True) +def test_empty_json_string_selects_all(trajectory): + n_atoms = len(trajectory.chemical_system.atom_list) + reusable_selection = ReusableSelection() + reusable_selection.load_from_json('{}') + selection = reusable_selection.select_in_trajectory(trajectory.trajectory) + assert len(selection) == n_atoms + + +@pytest.mark.parametrize("trajectory", [short_traj, mdmc_traj, com_traj], indirect=True) +def test_select_none(trajectory): + reusable_selection = ReusableSelection() + reusable_selection.set_selection(number=None, function_parameters={'function_name': 'select_none'}) + selection = reusable_selection.select_in_trajectory(trajectory.trajectory) + assert len(selection) == 0 + + +@pytest.mark.parametrize("trajectory", [short_traj, mdmc_traj, com_traj], indirect=True) +def test_inverted_all_is_none(trajectory): + reusable_selection = ReusableSelection() + reusable_selection.set_selection(number=None, function_parameters={'function_name': 'select_all'}) + reusable_selection.set_selection(number=None, function_parameters={'function_name': 'invert_selection'}) + selection = reusable_selection.select_in_trajectory(trajectory.trajectory) + assert len(selection) == 0 + + +def test_json_saving_is_reversible(): + reusable_selection = ReusableSelection() + reusable_selection.set_selection(number=None, function_parameters={'function_name': 'select_all'}) + reusable_selection.set_selection(number=None, function_parameters={'function_name': 'invert_selection'}) + json_string = reusable_selection.convert_to_json() + another_selection = ReusableSelection() + another_selection.load_from_json(json_string) + json_string_2 = another_selection.convert_to_json() + assert json_string == json_string_2 + + +@pytest.mark.parametrize("trajectory", [short_traj, mdmc_traj, com_traj], indirect=True) +def test_selection_from_json_is_the_same_as_from_runtime(trajectory): + reusable_selection = ReusableSelection() + reusable_selection.set_selection(number=None, function_parameters={'function_name': 'select_all'}) + reusable_selection.set_selection(number=None, function_parameters={'function_name': 'invert_selection'}) + selection = reusable_selection.select_in_trajectory(trajectory.trajectory) + json_string = reusable_selection.convert_to_json() + another_selection = ReusableSelection() + another_selection.load_from_json(json_string) + selection2 = another_selection.select_in_trajectory(trajectory.trajectory) + print(f"original: {reusable_selection.operations}") + print(f"another: {another_selection.operations}") + assert selection == selection2 + + +@pytest.mark.parametrize("trajectory", [short_traj], indirect=True) +@pytest.mark.parametrize("element, expected", ( + (["Cu"], 208), + (["S"], 208), + (["Sb"], 64), + (["S", "Sb", "Cu"], 480) +)) +def test_select_atoms_selects_by_element(trajectory, element, expected): + reusable_selection = ReusableSelection() + reusable_selection.set_selection(number=None, function_parameters={"function_name": "select_atoms", + "atom_types": element}) + selection = reusable_selection.select_in_trajectory(trajectory) + assert len(selection) == expected + json_string = reusable_selection.convert_to_json() + another_selection = ReusableSelection() + another_selection.load_from_json(json_string) + loaded_selection = another_selection.select_in_trajectory(trajectory) + assert len(loaded_selection) == expected + + +@pytest.mark.parametrize("trajectory", [short_traj], indirect=True) +@pytest.mark.parametrize("index_range, expected", ( + ([15,35], 20), + ([470,510], 10), +)) +def test_select_atoms_selects_by_range(trajectory, index_range, expected): + reusable_selection = ReusableSelection() + reusable_selection.set_selection(number=None, function_parameters={'function_name': 'select_atoms', 'index_range': index_range}) + range_selection = reusable_selection.select_in_trajectory(trajectory.trajectory) + assert len(range_selection) == expected + + +@pytest.mark.parametrize("trajectory", [short_traj], indirect=True) +@pytest.mark.parametrize("index_slice, expected", ( + ([150,350,10], 20), + ([470,510,5], 2), +)) +def test_select_atoms_selects_by_slice(trajectory, index_slice, expected): + reusable_selection = ReusableSelection() + reusable_selection.set_selection(number=None, function_parameters={'function_name': 'select_atoms', 'index_slice': index_slice}) + range_selection = reusable_selection.select_in_trajectory(trajectory.trajectory) + assert len(range_selection) == expected + json_string = reusable_selection.convert_to_json() + another_selection = ReusableSelection() + another_selection.load_from_json(json_string) + overshoot_selection = another_selection.select_in_trajectory(trajectory.trajectory) + assert len(overshoot_selection) == expected + + +@pytest.mark.parametrize("trajectory", [traj_2vb1], indirect=True) +@pytest.mark.parametrize("molecule_names, expected", ( + (['H2 O1'], 28746), + (['C613 H959 N193 O185 S10'], 1960), +)) +def test_select_molecules_selects_water(trajectory, molecule_names, expected): + reusable_selection = ReusableSelection() + reusable_selection.set_selection(number=None, function_parameters={'function_name': 'select_molecules', 'molecule_names': molecule_names}) + first_selection = reusable_selection.select_in_trajectory(trajectory.trajectory) + assert len(first_selection) == expected + json_string = reusable_selection.convert_to_json() + another_selection = ReusableSelection() + another_selection.load_from_json(json_string) + second_selection = another_selection.select_in_trajectory(trajectory.trajectory) + assert len(second_selection) == expected + + +@pytest.mark.parametrize("trajectory", [traj_2vb1], indirect=True) +def test_select_molecules_inverted_selects_ions(trajectory): + reusable_selection = ReusableSelection() + reusable_selection.set_selection(number=None, function_parameters={'function_name': 'select_molecules', 'molecule_names': ['C613 H959 N193 O185 S10', 'H2 O1']}) + reusable_selection.set_selection(number=None, function_parameters={'function_name': 'invert_selection'}) + json_string = reusable_selection.convert_to_json() + another_selection = ReusableSelection() + another_selection.load_from_json(json_string) + non_molecules_selection = another_selection.select_in_trajectory(trajectory.trajectory) + assert all([trajectory.chemical_system.atom_list[index] in ['Na', 'Cl'] for index in non_molecules_selection]) + + +@pytest.mark.parametrize("trajectory", [traj_2vb1], indirect=True) +def test_select_pattern_selects_water(trajectory): + reusable_selection = ReusableSelection() + reusable_selection.set_selection(number=None, function_parameters={'function_name': 'select_pattern', 'rdkit_pattern': "[#8X2;H2](~[H])~[H]"}) + json_string = reusable_selection.convert_to_json() + another_selection = ReusableSelection() + another_selection.load_from_json(json_string) + water_selection = another_selection.select_in_trajectory(trajectory.trajectory) + assert len(water_selection) == 28746 + + +@pytest.mark.parametrize("trajectory", [traj_2vb1], indirect=True) +def test_selection_with_multiple_steps(trajectory): + """This tests if the ReusableSelection can select oxygen only in + the water molecules. It combines two steps: + 1. water is selected using rdkit pattern matching + 2. oxygen is selected using simple atom type matching; intersection of the selections is applied + The selection is then saved to a JSON string, loaded from the string and applied to the trajectory. + """ + reusable_selection = ReusableSelection() + reusable_selection.set_selection(number=None, function_parameters={'function_name': 'select_pattern', 'rdkit_pattern': "[#8X2;H2](~[H])~[H]"}) + reusable_selection.set_selection(number=None, function_parameters={'function_name': 'select_atoms', 'atom_types': ['O'], 'operation_type': 'intersection'}) + json_string = reusable_selection.convert_to_json() + another_selection = ReusableSelection() + another_selection.load_from_json(json_string) + water_oxygen_selection = another_selection.select_in_trajectory(trajectory.trajectory) + assert len(water_oxygen_selection) == int(28746/3) diff --git a/MDANSE/Tests/UnitTests/test_selection.py b/MDANSE/Tests/UnitTests/test_selection.py new file mode 100644 index 0000000000..de1cb9ba10 --- /dev/null +++ b/MDANSE/Tests/UnitTests/test_selection.py @@ -0,0 +1,150 @@ + +import os + +import pytest +import numpy as np + +from MDANSE.Framework.AtomSelector.general_selection import select_all, select_none, invert_selection +from MDANSE.Framework.AtomSelector.atom_selection import select_atoms +from MDANSE.Framework.AtomSelector.molecule_selection import select_molecules +from MDANSE.Framework.AtomSelector.group_selection import select_labels, select_pattern +from MDANSE.Framework.AtomSelector.spatial_selection import select_positions, select_sphere +from MDANSE.Framework.InputData.HDFTrajectoryInputData import HDFTrajectoryInputData + + +short_traj = os.path.join( + os.path.dirname(os.path.realpath(__file__)), + "Converted", + "short_trajectory_after_changes.mdt", +) +mdmc_traj = os.path.join( + os.path.dirname(os.path.realpath(__file__)), + "Converted", + "Ar_mdmc_h5md.h5", +) +com_traj = os.path.join( + os.path.dirname(os.path.realpath(__file__)), + "Converted", + "com_trajectory.mdt", +) +traj_2vb1 = os.path.join( + os.path.dirname(os.path.realpath(__file__)), + "Converted", + "2vb1.mdt" +) + + +@pytest.fixture(scope='module') +def trajectory(request): + return HDFTrajectoryInputData(request.param) + + +@pytest.fixture(scope="module") +def short_trajectory(): + traj_object = HDFTrajectoryInputData(short_traj) + yield traj_object.trajectory + traj_object.close() + + +@pytest.fixture(scope="module") +def gromacs_trajectory(): + traj_object = HDFTrajectoryInputData(traj_2vb1) + yield traj_object.trajectory + traj_object.close() + + +@pytest.mark.parametrize("trajectory", [short_traj, mdmc_traj, com_traj], indirect=True) +def test_select_all(trajectory): + n_atoms = len(trajectory.chemical_system.atom_list) + selection = select_all(trajectory.trajectory) + assert len(selection) == n_atoms + + +@pytest.mark.parametrize("trajectory", [short_traj, mdmc_traj, com_traj], indirect=True) +def test_select_none(trajectory): + selection = select_none(trajectory.trajectory) + assert len(selection) == 0 + + +@pytest.mark.parametrize("trajectory", [short_traj, mdmc_traj, com_traj], indirect=True) +def test_inverted_none_is_all(trajectory): + none_selection = select_none(trajectory.trajectory) + all_selection = select_all(trajectory.trajectory) + inverted_none = invert_selection(trajectory.trajectory, none_selection) + assert all_selection == inverted_none + + +@pytest.mark.parametrize("trajectory", [short_traj], indirect=True) +@pytest.mark.parametrize("element, expected", ( + (["Cu"], 208), + (["S"], 208), + (["Sb"], 64), + (["S", "Sb", "Cu"], 480) +)) +def test_select_atoms_selects_by_element(trajectory, element, expected): + s_selection = select_atoms(trajectory.trajectory, atom_types=element) + assert len(s_selection) == expected + + +@pytest.mark.parametrize("trajectory", [short_traj], indirect=True) +@pytest.mark.parametrize("index_range, expected", ( + ([15,35], 20), + ([470,510], 10), +)) +def test_select_atoms_selects_by_range(trajectory, index_range, expected): + range_selection = select_atoms(trajectory.trajectory, index_range=index_range) + assert len(range_selection) == expected + + +@pytest.mark.parametrize("trajectory", [short_traj], indirect=True) +@pytest.mark.parametrize("index_slice, expected", ( + ([150,350,10], 20), + ([470,510,5], 2), +)) +def test_select_atoms_selects_by_slice(trajectory, index_slice, expected): + range_selection = select_atoms(trajectory.trajectory, index_slice=index_slice) + assert len(range_selection) == expected + + +@pytest.mark.parametrize("trajectory", [traj_2vb1], indirect=True) +@pytest.mark.parametrize("molecule_names, expected", ( + (['H2 O1'], 28746), + (['C613 H959 N193 O185 S10'], 1960), +)) +def test_select_molecules(trajectory, molecule_names, expected): + water_selection = select_molecules(trajectory.trajectory, molecule_names = molecule_names) + assert len(water_selection) == expected + + +@pytest.mark.parametrize("trajectory", [traj_2vb1], indirect=True) +def test_select_molecules_inverted_selects_ions(trajectory): + all_molecules_selection = select_molecules(trajectory.trajectory, molecule_names = ['C613 H959 N193 O185 S10', 'H2 O1']) + non_molecules_selection = invert_selection(trajectory.trajectory, all_molecules_selection) + assert all([trajectory.chemical_system.atom_list[index] in ['Na', 'Cl'] for index in non_molecules_selection]) + + +@pytest.mark.parametrize("trajectory", [traj_2vb1], indirect=True) +def test_select_pattern_selects_water(trajectory): + water_selection = select_pattern(trajectory.trajectory, rdkit_pattern="[#8X2;H2](~[H])~[H]") + assert len(water_selection) == 28746 + + +@pytest.mark.parametrize("trajectory", [short_traj], indirect=True) +@pytest.mark.parametrize("lower_limit, upper_limit, expected", ( + ([0.5, 0.5, 0.5], [0.7, 0.7, 0.7], {59}), + (None, [0.2, 0.2, 0.2], {46}), + ([1.8, 1.8, 1.8], None, {453}), +)) +def test_select_positions(trajectory, lower_limit, upper_limit, expected): + cube_selection = select_positions(trajectory.trajectory, + position_minimum = lower_limit, + position_maximum = upper_limit) + assert cube_selection == expected + + +@pytest.mark.parametrize("trajectory", [short_traj], indirect=True) +def test_select_sphere(trajectory): + sphere_selection = select_sphere(trajectory.trajectory, + sphere_centre = 0.6*np.ones(3), + sphere_radius = 0.2) + assert sphere_selection == {19, 17, 59, 15} diff --git a/MDANSE_GUI/Src/MDANSE_GUI/InputWidgets/AtomSelectionWidget.py b/MDANSE_GUI/Src/MDANSE_GUI/InputWidgets/AtomSelectionWidget.py index 64285f7197..2220c0c81c 100644 --- a/MDANSE_GUI/Src/MDANSE_GUI/InputWidgets/AtomSelectionWidget.py +++ b/MDANSE_GUI/Src/MDANSE_GUI/InputWidgets/AtomSelectionWidget.py @@ -13,25 +13,145 @@ # You should have received a copy of the GNU General Public License # along with this program. If not, see . # -from qtpy.QtCore import Qt, Slot + +import json +from enum import StrEnum + +from MDANSE.Framework.AtomSelector.selector import ReusableSelection +from MDANSE.Framework.InputData.HDFTrajectoryInputData import HDFTrajectoryInputData +from qtpy.QtCore import Signal, Slot +from qtpy.QtGui import QStandardItem, QStandardItemModel from qtpy.QtWidgets import ( - QLineEdit, - QPushButton, + QAbstractItemView, QDialog, - QCheckBox, - QVBoxLayout, - QHBoxLayout, QGroupBox, + QHBoxLayout, QLabel, + QLineEdit, + QListView, QPlainTextEdit, + QPushButton, + QScrollArea, + QVBoxLayout, QWidget, ) -from MDANSE.Framework.AtomSelector import Selector + from MDANSE_GUI.InputWidgets.WidgetBase import WidgetBase -from MDANSE_GUI.Tabs.Visualisers.View3D import View3D from MDANSE_GUI.MolecularViewer.MolecularViewer import MolecularViewerWithPicking -from MDANSE.Framework.InputData.HDFTrajectoryInputData import HDFTrajectoryInputData -from .CheckableComboBox import CheckableComboBox +from MDANSE_GUI.Tabs.Visualisers.View3D import View3D +from MDANSE_GUI.Widgets.SelectionWidgets import ( + AllAtomSelection, + AtomSelection, + IndexSelection, + LabelSelection, + MoleculeSelection, + PatternSelection, + PositionSelection, + SphereSelection, +) + + +class SelectionValidity(StrEnum): + """Strings for selection check results.""" + + VALID_SELECTION = "Valid selection" + USELESS_SELECTION = "Selection did not change. This operation is not needed." + MALFORMED_SELECTION = "This is not a valid selection string." + + +class SelectionModel(QStandardItemModel): + """Stores the selection operations in the GUI view.""" + + selection_changed = Signal() + + def __init__(self, trajectory): + """Assign the current trajectory to the model.""" + super().__init__(None) + self._trajectory = trajectory + self._selection = ReusableSelection() + self._current_selection = set() + + def rebuild_selection(self, last_operation: str) -> SelectionValidity: + """Update the current selection based on the text in the GUI. + + Parameters + ---------- + last_operation : str + Additional selection operation input by the user. + + Returns + ------- + SelectionValidity + result of the check on last_operation + + """ + self._selection = ReusableSelection() + self._current_selection = set() + total_dict = {} + for row in range(self.rowCount()): + index = self.index(row, 0) + item = self.itemFromIndex(index) + json_string = item.text() + total_dict[row] = json.loads(json_string) + self._selection.load_from_json(json.dumps(total_dict)) + self._current_selection = self._selection.select_in_trajectory(self._trajectory) + if last_operation: + try: + valid = self._selection.validate_selection_string( + last_operation, + self._trajectory, + self._current_selection, + ) + except json.JSONDecodeError: + return SelectionValidity.MALFORMED_SELECTION + if valid: + self._selection.load_from_json(json_string) + return SelectionValidity.VALID_SELECTION + return SelectionValidity.USELESS_SELECTION + return None + + def current_selection(self, last_operation: str = "") -> set[int]: + """Return the selected atom indices. + + Parameters + ---------- + last_operation : str, optional + Extra selection operation typed by the user, by default "" + + Returns + ------- + set[int] + indices of all the selected atoms + + """ + self.rebuild_selection(last_operation) + return self._selection.select_in_trajectory(self._trajectory) + + def current_steps(self) -> str: + """Return selection operations as a JSON string. + + Returns + ------- + str + one string with all the selection operations in sequence + + """ + result = {} + for row in range(self.rowCount()): + index = self.index(row, 0) + item = self.itemFromIndex(index) + json_string = item.text() + python_object = json.loads(json_string) + result[row] = python_object + return json.dumps(result) + + @Slot(str) + def accept_from_widget(self, json_string: str): + """Add a selection operation sent from a selection widget.""" + new_item = QStandardItem(json_string) + new_item.setEditable(False) + self.appendRow(new_item) + self.selection_changed.emit() class SelectionHelper(QDialog): @@ -41,58 +161,40 @@ class SelectionHelper(QDialog): ---------- _helper_title : str The title of the helper dialog window. - _cbox_text : dict - The dictionary that maps the selector settings to text used in - the helper dialog. + """ _helper_title = "Atom selection helper" - _cbox_text = { - "all": "All atoms (excl. dummy atoms):", - "dummy": "All dummy atoms:", - "hs_on_heteroatom": "Hs on heteroatoms:", - "primary_amine": "Primary amine groups:", - "hydroxy": "Hydroxy groups:", - "methyl": "Methyl groups:", - "phosphate": "Phosphate groups:", - "sulphate": "Sulphate groups:", - "thiol": "Thiol groups:", - "water": "Water molecules:", - "hs_on_element": "Hs on elements:", - "element": "Elements:", - "name": "Atom name:", - "fullname": "Atom fullname:", - "index": "Indexes:", - } def __init__( self, - selector: Selector, traj_data: tuple[str, HDFTrajectoryInputData], field: QLineEdit, parent, *args, **kwargs, ): - """ + """Create the selection dialog. + Parameters ---------- - selector : Selector - The MDANSE selector initialized with the current chemical - system. traj_data : tuple[str, HDFTrajectoryInputData] A tuple of the trajectory data used to load the 3D viewer. field : QLineEdit The QLineEdit field that will need to be updated when applying the setting. + """ super().__init__(parent, *args, **kwargs) self.setWindowTitle(self._helper_title) - self.selector = selector + self.trajectory = traj_data[1].trajectory + self.system = self.trajectory.chemical_system + self.selection_model = SelectionModel(self.trajectory) self._field = field - self.settings = self.selector.settings - self.atm_full_names = self.selector.system.name_list + self.atm_full_names = self.system.name_list + self.molecule_names = self.system.unique_molecules() + self.labels = list(map(str, self.system._labels)) self.selection_textbox = QPlainTextEdit() self.selection_textbox.setReadOnly(True) @@ -108,32 +210,38 @@ def __init__( for button in self.create_buttons(): bottom.addWidget(button) - layouts[-1].addLayout(bottom) - helper_layout = QHBoxLayout() - for layout in layouts: - helper_layout.addLayout(layout) + sub_layout = QVBoxLayout() + helper_layout.addLayout(layouts[0]) + helper_layout.addLayout(sub_layout) + for layout in layouts[1:]: + sub_layout.addLayout(layout) + sub_layout.addLayout(bottom) self.setLayout(helper_layout) - self.update_others() self.all_selection = True - self.selected = set([]) + self.selected = set() + self.reset() def closeEvent(self, a0): - """Hide the window instead of closing. Some issues occur in the + """Hide the window instead of closing. + + Some issues occur in the 3D viewer when it is closed and then reopened. """ a0.ignore() self.hide() def create_buttons(self) -> list[QPushButton]: - """ + """Add buttons to the dialog layout. + Returns ------- list[QPushButton] List of push buttons to add to the last layout from create_layouts. + """ apply = QPushButton("Use Setting") reset = QPushButton("Reset") @@ -144,11 +252,13 @@ def create_buttons(self) -> list[QPushButton]: return [apply, reset, close] def create_layouts(self) -> list[QVBoxLayout]: - """ + """Call functions creating other widgets. + Returns ------- list[QVBoxLayout] List of QVBoxLayout to add to the helper layout. + """ layout_3d = QVBoxLayout() layout_3d.addWidget(self.view_3d) @@ -157,156 +267,112 @@ def create_layouts(self) -> list[QVBoxLayout]: for widget in self.left_widgets(): left.addWidget(widget) - right = QVBoxLayout() + right = QHBoxLayout() for widget in self.right_widgets(): right.addWidget(widget) return [layout_3d, left, right] def right_widgets(self) -> list[QWidget]: - """ + """Create widgets visualising the selection results. + Returns ------- list[QWidget] List of QWidgets to add to the right layout from create_layouts. + """ - return [self.selection_textbox] + return [self.selection_operations_view, self.selection_textbox] def left_widgets(self) -> list[QWidget]: - """ + """Create widgets for defining the selection. + Returns ------- list[QWidget] List of QWidgets to add to the left layout from create_layouts. - """ - match_exists = self.selector.match_exists + """ select = QGroupBox("selection") select_layout = QVBoxLayout() + scroll_area = QScrollArea() + scroll_area.setWidget(select) + scroll_area.setWidgetResizable(True) + + self.selection_widgets = [ + AllAtomSelection(self), + AtomSelection(self, self.trajectory), + IndexSelection(self), + MoleculeSelection(self, self.trajectory), + PatternSelection(self), + LabelSelection(self, self.trajectory), + PositionSelection(self, self.trajectory, self.view_3d._viewer), + SphereSelection(self, self.trajectory, self.view_3d._viewer), + ] - self.check_boxes = [] - self.combo_boxes = [] - - for k, v in self.settings.items(): - if isinstance(v, bool): - check_layout = QHBoxLayout() - checkbox = QCheckBox() - checkbox.setChecked(v) - checkbox.setLayoutDirection(Qt.RightToLeft) - label = QLabel(self._cbox_text[k]) - checkbox.setObjectName(k) - checkbox.stateChanged.connect(self.update_others) - if not match_exists[k]: - checkbox.setEnabled(False) - label.setStyleSheet("color: grey;") - self.check_boxes.append(checkbox) - check_layout.addWidget(label) - check_layout.addWidget(checkbox) - select_layout.addLayout(check_layout) - - elif isinstance(v, dict): - combo_layout = QHBoxLayout() - combo = CheckableComboBox() - items = [str(i) for i in v.keys() if match_exists[k][i]] - # we blocksignals here as there can be some - # performance issues with a large number of items - combo.model().blockSignals(True) - combo.addItems(items) - combo.model().blockSignals(False) - combo.setObjectName(k) - combo.model().dataChanged.connect(self.update_others) - label = QLabel(self._cbox_text[k]) - if len(items) == 0: - combo.setEnabled(False) - label.setStyleSheet("color: grey;") - self.combo_boxes.append(combo) - combo_layout.addWidget(label) - combo_layout.addWidget(combo) - select_layout.addLayout(combo_layout) + for widget in self.selection_widgets: + select_layout.addWidget(widget) + widget.new_selection.connect(self.selection_model.accept_from_widget) invert_layout = QHBoxLayout() - label = QLabel("Invert selection:") + label = QLabel("Current selection:") + self.selection_line = QLineEdit("", self) apply = QPushButton("Apply") - apply.clicked.connect(self.invert_selection) + apply.clicked.connect(self.append_selection) invert_layout.addWidget(label) + invert_layout.addWidget(self.selection_line) invert_layout.addWidget(apply) select_layout.addLayout(invert_layout) select.setLayout(select_layout) - return [select] - def update_others(self) -> None: - """Using the checkbox and combobox widgets: update the settings, - get the selection and update the textedit box with details of - the current selection and the 3d view to match the selection. - """ - for check_box in self.check_boxes: - self.settings[check_box.objectName()] = check_box.isChecked() - for combo_box in self.combo_boxes: - for i in range(combo_box.n_items): - txt = combo_box.text[i] - if combo_box.objectName() == "index": - key = int(txt) - else: - key = txt - self.settings[combo_box.objectName()][key] = combo_box.checked[i] - - self.selector.update_settings(self.settings) - self.selected = self.selector.get_idxs() + self.selection_operations_view = QListView(self) + self.selection_operations_view.setModel(self.selection_model) + self.selection_model.selection_changed.connect(self.recalculate_selection) + return [scroll_area] + + @Slot() + def recalculate_selection(self): + """Update atom indices after selection change.""" + self.selected = self.selection_model.current_selection() self.view_3d._viewer.change_picked(self.selected) self.update_selection_textbox() - def update_from_3d_view(self, selection: set[int]) -> None: - """A selection/deselection was made in the 3d view, update the + def update_from_3d_view(self, _selection: set[int]) -> None: + """Update atom indices after an atom has been clicked. + + A selection/deselection was made in the 3d view, update the check_boxes, combo_boxes and textbox. Parameters ---------- selection : set[int] Selection indexes from the 3d view. - """ - self.selector.update_with_idxs(selection) - self.settings = self.selector.settings - self.update_selection_widgets() - self.selected = self.selector.get_idxs() - self.update_selection_textbox() - def invert_selection(self): - """Inverts the selection.""" - self.selected = self.selector.all_idxs - self.selected - self.selector.update_with_idxs(self.selected) - self.settings = self.selector.settings - self.update_selection_widgets() - self.view_3d._viewer.change_picked(self.selected) + """ self.update_selection_textbox() - def update_selection_widgets(self) -> None: - """Updates the selection widgets so that it matches the full - setting. - """ - for check_box in self.check_boxes: - check_box.blockSignals(True) - if self.settings[check_box.objectName()]: - check_box.setCheckState(Qt.Checked) - else: - check_box.setCheckState(Qt.Unchecked) - check_box.blockSignals(False) - for combo_box in self.combo_boxes: - combo_box.model().blockSignals(True) - for i in range(combo_box.n_items): - txt = combo_box.text[i] - if combo_box.objectName() == "index": - key = int(txt) - else: - key = txt - combo_box.set_item_checked_state( - i, self.settings[combo_box.objectName()][key] - ) - combo_box.update_all_selected() - combo_box.update_line_edit() - combo_box.model().blockSignals(False) + @Slot() + def append_selection(self): + """Add a selection operation from the text input field.""" + self.selection_line.setStyleSheet("") + self.selection_line.setToolTip("") + selection_text = self.selection_line.text() + validation = self.selection_model.rebuild_selection(selection_text) + if validation in ( + SelectionValidity.MALFORMED_SELECTION, + SelectionValidity.USELESS_SELECTION, + ): + self.selection_line.setStyleSheet( + "QWidget#InputWidget { background-color:rgb(180,20,180); font-weight: bold }" + ) + self.selection_line.setToolTip(validation) + elif validation == SelectionValidity.VALID_SELECTION: + self.selection_model.appendRow(QStandardItem(selection_text)) + self.view_3d._viewer.change_picked(self.selected) + self.update_selection_textbox() def update_selection_textbox(self) -> None: """Update the selection textbox.""" @@ -317,31 +383,27 @@ def update_selection_textbox(self) -> None: self.selection_textbox.setPlainText("".join(text)) def apply(self) -> None: - """Set the field of the AtomSelectionWidget to the currently - chosen setting in this widget. - """ - self.selector.update_settings(self.settings) - self._field.setText(self.selector.settings_to_json()) + """Send the selection from the dialog to the main widget.""" + self._field.setText(self.selection_model.current_steps()) def reset(self) -> None: - """Resets the helper to the default state.""" - self.selector.reset_settings() - self.selector.settings["all"] = self.all_selection - self.settings = self.selector.settings - self.update_selection_widgets() - self.selected = self.selector.get_idxs() - self.view_3d._viewer.change_picked(self.selected) - self.update_selection_textbox() + """Reset the helper to the default state.""" + self.selection_model.clear() + self.selection_model.accept_from_widget( + '{"function_name": "select_all", "operation_type": "union"}' + ) + self.recalculate_selection() class AtomSelectionWidget(WidgetBase): """The atoms selection widget.""" _push_button_text = "Atom selection helper" - _default_value = '{"all": true}' + _default_value = "{}" _tooltip_text = "Specify which atoms will be used in the analysis. The input is a JSON string, and can be created using the helper dialog." def __init__(self, *args, **kwargs): + """Create the main widget for atom selection.""" super().__init__(*args, **kwargs) self._value = self._default_value self._field = QLineEdit(self._default_value, self._base) @@ -363,9 +425,14 @@ def __init__(self, *args, **kwargs): self._field.setToolTip(self._tooltip_text) def create_helper( - self, traj_data: tuple[str, HDFTrajectoryInputData] + self, + traj_data: tuple[str, HDFTrajectoryInputData], ) -> SelectionHelper: - """ + """Create the selection dialog. + + It will be populated with selection widget which can be used + to create the complete atom selection string. + Parameters ---------- traj_data : tuple[str, HDFTrajectoryInputData] @@ -375,29 +442,30 @@ def create_helper( ------- SelectionHelper Create and return the selection helper QDialog. + """ - selector = self._configurator.get_selector() - return SelectionHelper(selector, traj_data, self._field, self._base) + return SelectionHelper(traj_data, self._field, self._base) @Slot() def helper_dialog(self) -> None: - """Opens the helper dialog.""" + """Open the helper dialog.""" if self.helper.isVisible(): self.helper.close() else: self.helper.show() def get_widget_value(self) -> str: - """ + """Return the current text in the input field. + Returns ------- str The JSON selector setting. + """ selection_string = self._field.text() if len(selection_string) < 1: self._empty = True return self._default_value - else: - self._empty = False + self._empty = False return selection_string diff --git a/MDANSE_GUI/Src/MDANSE_GUI/InputWidgets/AtomTransmutationWidget.py b/MDANSE_GUI/Src/MDANSE_GUI/InputWidgets/AtomTransmutationWidget.py index 901c846a50..486dabe90b 100644 --- a/MDANSE_GUI/Src/MDANSE_GUI/InputWidgets/AtomTransmutationWidget.py +++ b/MDANSE_GUI/Src/MDANSE_GUI/InputWidgets/AtomTransmutationWidget.py @@ -69,16 +69,13 @@ def __init__( self.transmutation_textbox.setReadOnly(True) self.transmutation_combo = QComboBox() self.transmutation_combo.addItems(ATOMS_DATABASE.atoms) - self.transmuter.selector.settings["all"] = False super().__init__( - transmuter.selector, traj_data, field, parent, *args, **kwargs, ) - self.all_selection = False self.update_transmutation_textbox() def right_widgets(self) -> list[QWidget]: @@ -125,8 +122,9 @@ def apply_transmutation(self) -> None: transmutation and update the transmutation textbox with the new transmutation setting. """ + selection_string = self.selection_model.current_steps() self.transmuter.apply_transmutation( - self.settings, self.transmutation_combo.currentText() + selection_string, self.transmutation_combo.currentText() ) self.update_transmutation_textbox() @@ -134,12 +132,13 @@ def update_transmutation_textbox(self) -> None: """Update the transmutation textbox with the current transmuter setting information. """ - map = self.transmuter.get_setting() + substitutions = self.transmuter.get_setting() - text = [f"Number of atoms transmuted:\n{len(map)}\n\nTransmuted atoms:\n"] - atoms = self.selector.system.atom_list - for idx, symbol in map.items(): - text.append(f"{idx} {atoms[idx]} -> {symbol}\n") + text = [ + f"Number of atoms transmuted:\n{len(substitutions)}\n\nTransmuted atoms:\n" + ] + for idx, symbol in substitutions.items(): + text.append(f"{idx} {self.atm_full_names[idx]} -> {symbol}\n") self.transmutation_textbox.setText("".join(text)) diff --git a/MDANSE_GUI/Src/MDANSE_GUI/InputWidgets/CheckableComboBox.py b/MDANSE_GUI/Src/MDANSE_GUI/InputWidgets/CheckableComboBox.py index 327fca0b3f..9a621d33a3 100644 --- a/MDANSE_GUI/Src/MDANSE_GUI/InputWidgets/CheckableComboBox.py +++ b/MDANSE_GUI/Src/MDANSE_GUI/InputWidgets/CheckableComboBox.py @@ -42,6 +42,16 @@ def __init__(self, *args, **kwargs): self.addItem("select all", underline=True) self.lineEdit().setText("") + def clear(self): + result = super().clear() + self.items = [] + self.checked = [] + self.text = [] + self.select_all_item = None + self.addItem("select all", underline=True) + self.lineEdit().setText("") + return result + def eventFilter(self, a0: Union[QObject, None], a1: Union[QEvent, None]) -> bool: """Updates the check state of the items and the lineEdit. diff --git a/MDANSE_GUI/Src/MDANSE_GUI/InputWidgets/PartialChargeWidget.py b/MDANSE_GUI/Src/MDANSE_GUI/InputWidgets/PartialChargeWidget.py index 4d357189aa..2f6d977119 100644 --- a/MDANSE_GUI/Src/MDANSE_GUI/InputWidgets/PartialChargeWidget.py +++ b/MDANSE_GUI/Src/MDANSE_GUI/InputWidgets/PartialChargeWidget.py @@ -68,8 +68,7 @@ def __init__( self.charge_textbox.setReadOnly(True) self.charge_qline = QLineEdit() self.charge_qline.setValidator(QDoubleValidator()) - self.mapper.selector.settings["all"] = False - super().__init__(mapper.selector, traj_data, field, parent, *args, **kwargs) + super().__init__(traj_data, field, parent, *args, **kwargs) self.all_selection = False self.update_charge_textbox() @@ -122,7 +121,8 @@ def apply_charges(self) -> None: except ValueError: # probably an empty QLineEdit box return - self.mapper.update_charges(self.settings, charge) + selection_string = self.selection_model.current_steps() + self.mapper.update_charges(selection_string, charge) self.update_charge_textbox() def update_charge_textbox(self) -> None: @@ -132,9 +132,8 @@ def update_charge_textbox(self) -> None: map = self.mapper.get_full_setting() text = ["Partial charge mapping:\n"] - atoms = self.selector.system.atom_list for idx, charge in map.items(): - text.append(f"{idx} ({atoms[idx]}) -> {charge}\n") + text.append(f"{idx} ({self.atm_full_names[idx]}) -> {charge}\n") self.charge_textbox.setText("".join(text)) diff --git a/MDANSE_GUI/Src/MDANSE_GUI/Widgets/SelectionWidgets.py b/MDANSE_GUI/Src/MDANSE_GUI/Widgets/SelectionWidgets.py new file mode 100644 index 0000000000..09cfd1014b --- /dev/null +++ b/MDANSE_GUI/Src/MDANSE_GUI/Widgets/SelectionWidgets.py @@ -0,0 +1,623 @@ +# This file is part of MDANSE_GUI. +# +# MDANSE_GUI is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . +# + +import json +from enum import StrEnum +from typing import TYPE_CHECKING, Any + +import numpy as np +from MDANSE.MolecularDynamics.Trajectory import Trajectory +from qtpy.QtCore import Signal, Slot +from qtpy.QtGui import QDoubleValidator, QValidator +from qtpy.QtWidgets import ( + QComboBox, + QGroupBox, + QHBoxLayout, + QLabel, + QLineEdit, + QPushButton, +) +from rdkit.Chem import MolFromSmarts + +from MDANSE_GUI.InputWidgets.CheckableComboBox import CheckableComboBox + +if TYPE_CHECKING: + from MDANSE_GUI.MolecularViewer.MolecularViewer import MolecularViewer + + +class IndexSelectionMode(StrEnum): + """Valid atom selection modes for select_atoms.""" + + LIST = "list" + RANGE = "range" + SLICE = "slice" + + +class XYZValidator(QValidator): + """A custom validator for a QLineEdit. + + It is intended to limit the input to a string + of 3 comma-separated float numbers. + + Additional checks are necessary later in the code, + since the validator cannot exclude the cases of + 1 or 2 comma-separated values, since they are + a preliminary step when typing in 3 numbers. + """ + + PARAMETERS_NEEDED = 3 + + def validate(self, input_string: str, position: int) -> tuple[int, str]: + """Check the input string from a widget. + + Implementation of the virtual method of QValidator. + It takes in the string from a QLineEdit and the cursor position, + and an enum value of the validator state. Widgets will reject + inputs which change the state to Invalid. + + Parameters + ---------- + input_string : str + current contents of a text input field + position : int + position of the cursor in the text input field + + Returns + ------- + int + Validator state. + str + Original input string. + int + Cursor position. + + """ + state = QValidator.State.Intermediate + comma_count = input_string.count(",") + if input_string: + try: + values = [float(x) for x in input_string.split(",")] + except (TypeError, ValueError): + if input_string.endswith(",") and comma_count < self.PARAMETERS_NEEDED: + state = QValidator.State.Intermediate + else: + state = QValidator.State.Invalid + else: + if len(values) > self.PARAMETERS_NEEDED: + state = QValidator.State.Invalid + elif len(values) == self.PARAMETERS_NEEDED: + state = QValidator.State.Acceptable + else: + state = QValidator.State.Intermediate + return state, input_string, position + + +class BasicSelectionWidget(QGroupBox): + """Base class for atom selection widgets.""" + + new_selection = Signal(str) + + def __init__(self, parent=None, widget_label="Atom selection widget"): + """Create subwidgets common to atom selection. + + Parameters + ---------- + parent : QWidget, optional + parent in the Qt hierarchy, by default None + widget_label : str, optional + Text shown above the widget, by default "Atom selection widget" + + """ + super().__init__(parent) + layout = QHBoxLayout() + self.setLayout(layout) + self.setTitle(widget_label) + self.add_specific_widgets() + self.add_standard_widgets() + + def parameter_dictionary(self) -> dict[str, Any]: + """Collect and return selection function parameters.""" + return {} + + def add_specific_widgets(self): + """Add additional widgets to layout, depending on the selection function.""" + return + + def add_standard_widgets(self): + """Create widgets needed by all atom selection types. + + This creates a combo box for the set operation type, + and a button for making the selection. + """ + self.mode_box = QComboBox(self) + self.mode_box.setEditable(False) + self.mode_box.addItems( + ["Add (union)", "Filter (intersection)", "Remove (difference)"], + ) + self._mode_box_values = ["union", "intersection", "difference"] + self.commit_button = QPushButton("Apply", self) + layout = self.layout() + layout.addWidget(self.mode_box) + layout.addWidget(self.commit_button) + self.commit_button.clicked.connect(self.create_selection) + + def get_mode(self) -> str: + """Get the current set operation type from the combo box.""" + return self._mode_box_values[self.mode_box.currentIndex()] + + def create_selection(self): + """Collect the input values and emit them in a signal.""" + funtion_parameters = self.parameter_dictionary() + funtion_parameters["operation_type"] = self.get_mode() + self.new_selection.emit(json.dumps(funtion_parameters)) + + +class AllAtomSelection(BasicSelectionWidget): + """Widget for global atom selection, e.g. all atoms, no atoms.""" + + def __init__(self, parent=None, widget_label="ALL ATOMS"): + """Pass inputs to the parent class init. + + Parameters + ---------- + parent : QWidget, optional + parent in the Qt hierarchy, by default None + widget_label : str, optional + Text over the widget, by default "ALL ATOMS" + + """ + super().__init__(parent, widget_label) + + def add_specific_widgets(self): + """Add the INVERT button.""" + layout = self.layout() + inversion_button = QPushButton("INVERT selection", self) + inversion_button.clicked.connect(self.invert_selection) + layout.addWidget(inversion_button) + layout.addWidget(QLabel("Add/remove ALL atoms")) + + def invert_selection(self): + """Emit the string for inverting the selection.""" + self.new_selection.emit(json.dumps({"function_name": "invert_selection"})) + + def parameter_dictionary(self): + """Collect and return selection function parameters.""" + return {"function_name": "select_all"} + + +class AtomSelection(BasicSelectionWidget): + """GUI frontend for select_atoms.""" + + def __init__( + self, + parent=None, + trajectory: Trajectory = None, + widget_label="Select atoms", + ): + """Create the widgets for select_atoms. + + Parameters + ---------- + parent : QWidget, optional + parent from the Qt object hierarchy, by default None + trajectory : Trajectory, optional + The current trajectory object, by default None + widget_label : str, optional + Text shown over the widget, by default "Select atoms" + + """ + self.atom_types = [] + self.atom_names = [] + if trajectory: + self.atom_types = list(np.unique(trajectory.chemical_system.atom_list)) + if trajectory.chemical_system.name_list: + self.atom_names = list(np.unique(trajectory.chemical_system.name_list)) + self.selection_types = [] + self.selection_keyword = "" + if self.atom_types: + self.selection_types += ["type"] + if self.atom_names: + self.selection_types += ["name"] + super().__init__(parent, widget_label) + + def add_specific_widgets(self): + """Create selection combo boxes.""" + layout = self.layout() + layout.addWidget(QLabel("Select atoms by atom")) + self.selection_type_combo = QComboBox(self) + self.selection_type_combo.addItems(self.selection_types) + self.selection_type_combo.setEditable(False) + layout.addWidget(self.selection_type_combo) + self.selection_field = CheckableComboBox(self) + layout.addWidget(self.selection_field) + self.selection_type_combo.currentTextChanged.connect(self.switch_mode) + self.selection_type_combo.setCurrentText(self.selection_types[0]) + self.switch_mode(self.selection_types[0]) + + @Slot(str) + def switch_mode(self, new_mode: str): + """Change the contents of the second combo box.""" + self.selection_field.clear() + if new_mode == "type": + self.selection_field.addItems(self.atom_types) + self.selection_keyword = "atom_types" + elif new_mode == "name": + self.selection_field.addItems(self.atom_names) + self.selection_keyword = "atom_names" + + def parameter_dictionary(self): + """Collect and return selection function parameters.""" + function_parameters = {"function_name": "select_atoms"} + selection = self.selection_field.checked_values() + function_parameters[self.selection_keyword] = selection + return function_parameters + + +class IndexSelection(BasicSelectionWidget): + """GUI frontend for select_atoms.""" + + def __init__(self, parent=None, widget_label="Index selection"): + """Create all the widgets. + + Parameters + ---------- + parent : QWidget, optional + parent in the Qt object hierarchy, by default None + widget_label : str, optional + Text shown above the widget, by default "Index selection" + + """ + super().__init__(parent, widget_label) + self.selection_keyword = "index_list" + + def add_specific_widgets(self): + """Create the combo box and text input field.""" + layout = self.layout() + layout.addWidget(QLabel("Select atoms by index")) + self.selection_type_combo = QComboBox(self) + self.selection_type_combo.addItems(str(mode) for mode in IndexSelectionMode) + self.selection_type_combo.setEditable(False) + layout.addWidget(self.selection_type_combo) + self.selection_field = QLineEdit(self) + layout.addWidget(self.selection_field) + self.selection_type_combo.currentTextChanged.connect(self.switch_mode) + + @Slot(str) + def switch_mode(self, new_mode: str): + """Change the meaning of the text input field.""" + self.selection_field.setText("") + if new_mode == IndexSelectionMode.LIST: + self.selection_field.setPlaceholderText("0,1,2") + self.selection_keyword = "index_list" + self.selection_separator = "," + elif new_mode == IndexSelectionMode.RANGE: + self.selection_field.setPlaceholderText("0-20") + self.selection_keyword = "index_range" + self.selection_separator = "-" + elif new_mode == IndexSelectionMode.SLICE: + self.selection_field.setPlaceholderText("first:last:step") + self.selection_keyword = "index_slice" + self.selection_separator = ":" + + def parameter_dictionary(self): + """Collect and return selection function parameters.""" + function_parameters = {"function_name": "select_atoms"} + selection = self.selection_field.text() + function_parameters[self.selection_keyword] = [ + int(x) for x in selection.split(self.selection_separator) + ] + return function_parameters + + +class MoleculeSelection(BasicSelectionWidget): + """GUI frontend for select_molecule.""" + + def __init__( + self, + parent=None, + trajectory: Trajectory = None, + widget_label="Select molecules", + ): + """Create the widgets for select_atoms. + + Parameters + ---------- + parent : QWidget, optional + parent from the Qt object hierarchy, by default None + trajectory : Trajectory, optional + The current trajectory object, by default None + widget_label : str, optional + Text shown over the widget, by default "Select atoms" + + """ + self.molecule_names = [] + if trajectory: + self.molecule_names = trajectory.chemical_system.unique_molecules() + super().__init__(parent, widget_label) + + def add_specific_widgets(self): + """Create the combo box for molecule names.""" + layout = self.layout() + layout.addWidget(QLabel("Select molecules named: ")) + self.selection_field = CheckableComboBox(self) + layout.addWidget(self.selection_field) + self.selection_field.addItems(self.molecule_names) + + def parameter_dictionary(self): + """Collect and return selection function parameters.""" + function_parameters = {"function_name": "select_molecules"} + selection = self.selection_field.checked_values() + function_parameters["molecule_names"] = selection + return function_parameters + + +class LabelSelection(BasicSelectionWidget): + """GUI frontend for select_label.""" + + def __init__( + self, + parent=None, + trajectory: Trajectory = None, + widget_label="Select by label", + ): + """Create the widgets for select_atoms. + + Parameters + ---------- + parent : QWidget, optional + parent from the Qt object hierarchy, by default None + trajectory : Trajectory, optional + The current trajectory object, by default None + widget_label : str, optional + Text shown over the widget, by default "Select atoms" + + """ + self.labels = [] + if trajectory: + self.labels = list(trajectory.chemical_system._labels.keys()) + super().__init__(parent, widget_label) + + def add_specific_widgets(self): + """Create the combo box for atom labels.""" + layout = self.layout() + layout.addWidget(QLabel("Select atoms with label: ")) + self.selection_field = CheckableComboBox(self) + layout.addWidget(self.selection_field) + self.selection_field.addItems(self.labels) + + def parameter_dictionary(self): + """Collect and return selection function parameters.""" + function_parameters = {"function_name": "select_labels"} + selection = self.selection_field.checked_values() + function_parameters["atom_labels"] = selection + return function_parameters + + +class PatternSelection(BasicSelectionWidget): + """GUI frontend for select_pattern.""" + + def __init__( + self, + parent=None, + widget_label="SMARTS pattern matching", + ): + """Create the widgets for select_atoms. + + Parameters + ---------- + parent : QWidget, optional + parent from the Qt object hierarchy, by default None + widget_label : str, optional + Text shown over the widget, by default "Select atoms" + + """ + self.pattern_dictionary = { + "primary amine": "[#7X3;H2;!$([#7][#6X3][!#6]);!$([#7][#6X2][!#6])](~[H])~[H]", + "hydroxy": "[#8;H1,H2]~[H]", + "methyl": "[#6;H3](~[H])(~[H])~[H]", + "phosphate": "[#15X4](~[#8])(~[#8])(~[#8])~[#8]", + "sulphate": "[#16X4](~[#8])(~[#8])(~[#8])~[#8]", + "thiol": "[#16X2;H1]~[H]", + } + super().__init__(parent, widget_label) + + def add_specific_widgets(self): + """Create the pattern text field.""" + layout = self.layout() + layout.addWidget(QLabel("Pick a group")) + self.selection_field = QComboBox(self) + layout.addWidget(self.selection_field) + self.selection_field.addItems(self.pattern_dictionary.keys()) + layout.addWidget(QLabel("pattern:")) + self.input_field = QLineEdit("", self) + self.input_field.setPlaceholderText("can be edited") + layout.addWidget(self.input_field) + self.selection_field.currentTextChanged.connect(self.update_string) + self.input_field.textChanged.connect(self.check_inputs) + + @Slot() + def check_inputs(self): + """Disable selection of invalid or incomplete input.""" + enable = True + smarts_string = self.input_field.text() + temp_molecule = MolFromSmarts(smarts_string) + if temp_molecule is None: + enable = False + self.commit_button.setEnabled(enable) + + @Slot(str) + def update_string(self, key_string: str): + """Fill the input field with pre-defined text.""" + if key_string in self.pattern_dictionary: + self.input_field.setText(self.pattern_dictionary[key_string]) + + def parameter_dictionary(self): + """Collect and return selection function parameters.""" + function_parameters = {"function_name": "select_pattern"} + selection = self.input_field.text() + function_parameters["rdkit_pattern"] = selection + return function_parameters + + +class PositionSelection(BasicSelectionWidget): + """GUI frontend for select_positions.""" + + def __init__( + self, + parent=None, + trajectory: Trajectory = None, + molecular_viewer: "MolecularViewer" = None, + widget_label="Select by position", + ): + """Create the widgets for select_atoms. + + Parameters + ---------- + parent : QWidget, optional + parent from the Qt object hierarchy, by default None + trajectory : Trajectory, optional + The current trajectory object, by default None + molecular_viewer : MolecularViewer, optional + instance of the 3D viewer showing the current simulation frame + widget_label : str, optional + Text shown over the widget, by default "Select atoms" + + """ + self._viewer = molecular_viewer + self._lower_limit = np.zeros(3) + self._upper_limit = np.linalg.norm(trajectory.unit_cell(0)._unit_cell, axis=1) + self._current_lower_limit = self._lower_limit.copy() + self._current_upper_limit = self._upper_limit.copy() + super().__init__(parent, widget_label) + + def add_specific_widgets(self): + """Create text input fields with validators.""" + layout = self.layout() + layout.addWidget(QLabel("Lower limits")) + self._lower_limit_input = QLineEdit( + ",".join([str(round(x, 3)) for x in self._lower_limit]), + ) + layout.addWidget(self._lower_limit_input) + layout.addWidget(QLabel("Upper limits")) + self._upper_limit_input = QLineEdit( + ",".join([str(round(x, 3)) for x in self._upper_limit]), + ) + layout.addWidget(self._upper_limit_input) + for field in [self._lower_limit_input, self._upper_limit_input]: + field.setValidator(XYZValidator(self)) + field.textChanged.connect(self.check_inputs) + + @Slot() + def check_inputs(self): + """Disable selection of invalid or incomplete input.""" + enable = True + try: + self._current_lower_limit = [ + float(x) for x in self._lower_limit_input.text().split(",") + ] + self._current_upper_limit = [ + float(x) for x in self._upper_limit_input.text().split(",") + ] + except (TypeError, ValueError): + enable = False + else: + if ( + len(self._current_lower_limit) != 3 + or len(self._current_upper_limit) != 3 + ): + enable = False + self.commit_button.setEnabled(enable) + + def parameter_dictionary(self): + """Collect and return selection function parameters.""" + return { + "function_name": "select_positions", + "frame_number": self._viewer._current_frame, + "position_minimum": list(self._current_lower_limit), + "position_maximum": list(self._current_upper_limit), + } + + +class SphereSelection(BasicSelectionWidget): + """GUI frontend for select_sphere.""" + + def __init__( + self, + parent=None, + trajectory: Trajectory = None, + molecular_viewer: "MolecularViewer" = None, + widget_label="Select in a sphere", + ): + """Create the widgets for select_atoms. + + Parameters + ---------- + parent : QWidget, optional + parent from the Qt object hierarchy, by default None + trajectory : Trajectory, optional + The current trajectory object, by default None + molecular_viewer : MolecularViewer, optional + instance of the 3D viewer showing the current simulation frame + widget_label : str, optional + Text shown over the widget, by default "Select atoms" + + """ + self._viewer = molecular_viewer + self._current_sphere_centre = np.diag(trajectory.unit_cell(0)._unit_cell) * 0.5 + self._current_sphere_radius = np.min(self._current_sphere_centre) + super().__init__(parent, widget_label) + + def add_specific_widgets(self): + """Create the text input fields for sphere radius and centre.""" + layout = self.layout() + layout.addWidget(QLabel("Sphere centre")) + self._sphere_centre_input = QLineEdit( + ",".join([str(round(x, 3)) for x in self._current_sphere_centre]), + self, + ) + layout.addWidget(self._sphere_centre_input) + layout.addWidget(QLabel("Sphere radius (nm)")) + self._sphere_radius_input = QLineEdit("0.5", self) + layout.addWidget(self._sphere_radius_input) + self._sphere_centre_input.setValidator(XYZValidator()) + self._sphere_centre_input.textChanged.connect(self.check_inputs) + self._sphere_radius_input.setValidator(QDoubleValidator()) + self._sphere_radius_input.textChanged.connect(self.check_inputs) + + @Slot() + def check_inputs(self): + """Disable selection on invalid or incomplete input.""" + enable = True + try: + self._current_sphere_centre = [ + float(x) for x in self._sphere_centre_input.text().split(",") + ] + self._current_sphere_radius = float(self._sphere_radius_input.text()) + except (TypeError, ValueError): + enable = False + else: + if len(self._current_sphere_centre) != 3: + enable = False + self.commit_button.setEnabled(enable) + + def parameter_dictionary(self): + """Collect and return selection function parameters.""" + return { + "function_name": "select_sphere", + "frame_number": self._viewer._current_frame, + "sphere_centre": list(self._current_sphere_centre), + "sphere_radius": self._current_sphere_radius, + }