From 5fa5f6cc6f1166016d47152502b24ea06a9757ff Mon Sep 17 00:00:00 2001 From: SimonBoothroyd Date: Sat, 27 Jul 2024 14:16:13 -0400 Subject: [PATCH] Swap to `ruff` and `setuptools-scm` --- .pre-commit-config.yaml | 43 +- Makefile | 25 +- devtools/envs/base.yaml | 7 +- examples/parameter-gradients.ipynb | 4 +- pyproject.toml | 41 +- smee/__init__.py | 8 +- smee/_models.py | 4 +- smee/_version.py | 716 ------------------- smee/converters/openff/_openff.py | 2 +- smee/converters/openff/nonbonded.py | 4 +- smee/converters/openmm/_openmm.py | 18 +- smee/converters/openmm/nonbonded.py | 14 +- smee/converters/openmm/valence.py | 27 +- smee/geometry.py | 2 +- smee/mm/_config.py | 6 +- smee/mm/_mm.py | 16 +- smee/mm/_ops.py | 12 +- smee/potentials/_potentials.py | 8 +- smee/potentials/nonbonded.py | 12 +- smee/tests/convertors/openff/test_valence.py | 4 +- smee/tests/convertors/test_openmm.py | 6 +- smee/tests/mm/test_ops.py | 4 +- smee/tests/potentials/test_nonbonded.py | 20 +- smee/tests/test_utils.py | 8 +- smee/tests/utils.py | 8 + smee/utils.py | 1 - 26 files changed, 145 insertions(+), 875 deletions(-) delete mode 100644 smee/_version.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index b710de2..ac7c335 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,38 +1,15 @@ repos: - - repo: local +- repo: https://github.com/pre-commit/pre-commit-hooks + rev: v3.2.0 hooks: - - id: isort - name: "[Package] Import formatting" - language: system - entry: isort - files: \.py$ - - - id: black - name: "[Package] Code formatting" - language: system - entry: black - files: \.py$ + - id: trailing-whitespace + - id: end-of-file-fixer - - id: flake8 - name: "[Package] Linting" +- repo: local + hooks: + - id: ruff + name: Autoformat python code language: system - entry: flake8 + entry: ruff + args: [check] files: \.py$ - - - id: isort-examples - name: "[Examples] Import formatting" - language: system - entry: nbqa isort - files: examples/.+\.ipynb$ - - - id: black-examples - name: "[Examples] Code formatting" - language: system - entry: nbqa black - files: examples/.+\.ipynb$ - - - id: flake8-examples - name: "[Examples] Linting" - language: system - entry: nbqa flake8 --ignore=E402 - files: examples/.+\.ipynb$ \ No newline at end of file diff --git a/Makefile b/Makefile index ed9bc51..f3e046c 100644 --- a/Makefile +++ b/Makefile @@ -4,32 +4,23 @@ CONDA_ENV_RUN := conda run --no-capture-output --name $(PACKAGE_NAME) EXAMPLES_SKIP := examples/md-simulations.ipynb EXAMPLES := $(filter-out $(EXAMPLES_SKIP), $(wildcard examples/*.ipynb)) -.PHONY: pip-install env lint format test test-examples docs-build docs-deploy docs-insiders - -pip-install: - $(CONDA_ENV_RUN) pip install --no-build-isolation --no-deps -e . +.PHONY: env lint format test test-examples docs-build docs-deploy docs-insiders env: mamba create --name $(PACKAGE_NAME) mamba env update --name $(PACKAGE_NAME) --file devtools/envs/base.yaml - $(CONDA_ENV_RUN) pip install --no-build-isolation --no-deps -e . + $(CONDA_ENV_RUN) pip install --no-deps -e . 
$(CONDA_ENV_RUN) pre-commit install || true lint: - $(CONDA_ENV_RUN) isort --check-only $(PACKAGE_NAME) - $(CONDA_ENV_RUN) black --check $(PACKAGE_NAME) - $(CONDA_ENV_RUN) flake8 $(PACKAGE_NAME) - $(CONDA_ENV_RUN) nbqa isort --check-only examples - $(CONDA_ENV_RUN) nbqa black --check examples - $(CONDA_ENV_RUN) nbqa flake8 --ignore=E402 examples + $(CONDA_ENV_RUN) ruff check $(PACKAGE_NAME) + $(CONDA_ENV_RUN) ruff check examples format: - $(CONDA_ENV_RUN) isort $(PACKAGE_NAME) - $(CONDA_ENV_RUN) black $(PACKAGE_NAME) - $(CONDA_ENV_RUN) flake8 $(PACKAGE_NAME) - $(CONDA_ENV_RUN) nbqa isort examples - $(CONDA_ENV_RUN) nbqa black examples - $(CONDA_ENV_RUN) nbqa flake8 --ignore=E402 examples + $(CONDA_ENV_RUN) ruff format $(PACKAGE_NAME) + $(CONDA_ENV_RUN) ruff check --fix --select I $(PACKAGE_NAME) + $(CONDA_ENV_RUN) ruff format examples + $(CONDA_ENV_RUN) ruff check --fix --select I examples test: $(CONDA_ENV_RUN) pytest -v --cov=$(PACKAGE_NAME) --cov-report=xml --color=yes $(PACKAGE_NAME)/tests/ diff --git a/devtools/envs/base.yaml b/devtools/envs/base.yaml index ec7a9a1..92fef6d 100644 --- a/devtools/envs/base.yaml +++ b/devtools/envs/base.yaml @@ -40,13 +40,10 @@ dependencies: - scipy # test logsumexp implementation - smirnoff-plugins - - versioneer + - setuptools_scm >=8 - pre-commit - - isort - - black - - flake8 - - flake8-pyproject + - ruff - nbqa - pytest diff --git a/examples/parameter-gradients.ipynb b/examples/parameter-gradients.ipynb index a2167b8..f4f357d 100644 --- a/examples/parameter-gradients.ipynb +++ b/examples/parameter-gradients.ipynb @@ -123,13 +123,13 @@ "energy.backward()\n", "\n", "for parameter_key, gradient in zip(\n", - " vdw_potential.parameter_keys, vdw_potential.parameters.grad.numpy()\n", + " vdw_potential.parameter_keys, vdw_potential.parameters.grad.numpy(), strict=True\n", "):\n", " parameter_cols = vdw_potential.parameter_cols\n", "\n", " parameter_grads = \", \".join(\n", " f\"dU/d{parameter_col} = {parameter_grad: 8.3f}\"\n", - " for parameter_col, parameter_grad in zip(parameter_cols, gradient)\n", + " for parameter_col, parameter_grad in zip(parameter_cols, gradient, strict=True)\n", " )\n", " print(f\"{parameter_key.id.ljust(15)} - {parameter_grads}\")" ], diff --git a/pyproject.toml b/pyproject.toml index 6a21a12..2475530 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,5 +1,5 @@ [build-system] -requires = ["setuptools>=61.0", "wheel", "versioneer"] +requires = ["setuptools>=61.0", "setuptools_scm>=8", "wheel"] build-backend = "setuptools.build_meta" [project] @@ -12,35 +12,20 @@ readme = "README.md" requires-python = ">=3.10" classifiers = ["Programming Language :: Python :: 3"] -[tool.setuptools] -zip-safe = false -include-package-data = true +[tool.setuptools.packages.find] +include = ["smee*"] -[tool.setuptools.dynamic] -version = {attr = "smee.__version__"} +[tool.setuptools_scm] -[tool.setuptools.packages.find] -namespaces = true -where = ["."] - -[tool.versioneer] -VCS = "git" -style = "pep440" -versionfile_source = "smee/_version.py" -versionfile_build = "smee/_version.py" -tag_prefix = "" -parentdir_prefix = "smee-" - -[tool.black] -line-length = 88 - -[tool.isort] -profile = "black" - -[tool.flake8] -max-line-length = 88 -ignore = ["E203", "E266", "E501", "W503"] -select = ["B","C","E","F","W","T4","B9"] +[tool.ruff] +extend-include = ["*.ipynb"] + +[tool.ruff.lint] +ignore = ["C901","E402","E501"] +select = ["B","C","E","F","W","B9"] + +[tool.ruff.lint.pydocstyle] +convention = "google" [tool.coverage.run] omit = ["**/tests/*", 
"**/_version.py"] diff --git a/smee/__init__.py b/smee/__init__.py index 023badf..dfa127b 100755 --- a/smee/__init__.py +++ b/smee/__init__.py @@ -4,7 +4,8 @@ Differentiably evaluate energies of molecules using SMIRNOFF force fields """ -from . import _version +import importlib.metadata + from ._constants import CUTOFF_ATTRIBUTE, SWITCH_ATTRIBUTE, EnergyFn, PotentialType from ._models import ( NonbondedParameterMap, @@ -21,7 +22,10 @@ from .geometry import add_v_site_coords, compute_v_site_coords from .potentials import compute_energy, compute_energy_potential -__version__ = _version.get_versions()["version"] +try: + __version__ = importlib.metadata.version("smee") +except importlib.metadata.PackageNotFoundError: + __version__ = "0+unknown" __all__ = [ "CUTOFF_ATTRIBUTE", diff --git a/smee/_models.py b/smee/_models.py index 2cd136a..e0b0aac 100644 --- a/smee/_models.py +++ b/smee/_models.py @@ -246,7 +246,7 @@ def n_atoms(self) -> int: """The number of atoms in the system.""" return sum( topology.n_atoms * n_copies - for topology, n_copies in zip(self.topologies, self.n_copies) + for topology, n_copies in zip(self.topologies, self.n_copies, strict=True) ) @property @@ -254,7 +254,7 @@ def n_v_sites(self) -> int: """The number of v-sites in the system.""" return sum( topology.n_v_sites * n_copies - for topology, n_copies in zip(self.topologies, self.n_copies) + for topology, n_copies in zip(self.topologies, self.n_copies, strict=True) ) @property diff --git a/smee/_version.py b/smee/_version.py deleted file mode 100644 index 52b2fa5..0000000 --- a/smee/_version.py +++ /dev/null @@ -1,716 +0,0 @@ -# This file helps to compute a version number in source trees obtained from -# git-archive tarball (such as those provided by githubs download-from-tag -# feature). Distribution tarballs (built by setup.py sdist) and build -# directories (produced by setup.py build) will contain a much shorter file -# that just contains the computed version number. - -# This file is released into the public domain. -# Generated by versioneer-0.29 -# https://github.com/python-versioneer/python-versioneer - -"""Git implementation of _version.py.""" - -import errno -import functools -import os -import re -import subprocess -import sys -from typing import Any, Callable, Dict, List, Optional, Tuple - - -def get_keywords() -> Dict[str, str]: - """Get the keywords needed to look up the version information.""" - # these strings will be replaced by git during git-archive. - # setup.py/versioneer.py will grep for the variable names, so they must - # each be defined on a line of their own. _version.py will just call - # get_keywords(). 
- git_refnames = "$Format:%d$" - git_full = "$Format:%H$" - git_date = "$Format:%ci$" - keywords = {"refnames": git_refnames, "full": git_full, "date": git_date} - return keywords - - -class VersioneerConfig: - """Container for Versioneer configuration parameters.""" - - VCS: str - style: str - tag_prefix: str - parentdir_prefix: str - versionfile_source: str - verbose: bool - - -def get_config() -> VersioneerConfig: - """Create, populate and return the VersioneerConfig() object.""" - # these strings are filled in when 'setup.py versioneer' creates - # _version.py - cfg = VersioneerConfig() - cfg.VCS = "git" - cfg.style = "pep440" - cfg.tag_prefix = "" - cfg.parentdir_prefix = "smee-" - cfg.versionfile_source = "smee/_version.py" - cfg.verbose = False - return cfg - - -class NotThisMethod(Exception): - """Exception raised if a method is not valid for the current scenario.""" - - -LONG_VERSION_PY: Dict[str, str] = {} -HANDLERS: Dict[str, Dict[str, Callable]] = {} - - -def register_vcs_handler(vcs: str, method: str) -> Callable: # decorator - """Create decorator to mark a method as the handler of a VCS.""" - - def decorate(f: Callable) -> Callable: - """Store f in HANDLERS[vcs][method].""" - if vcs not in HANDLERS: - HANDLERS[vcs] = {} - HANDLERS[vcs][method] = f - return f - - return decorate - - -def run_command( - commands: List[str], - args: List[str], - cwd: Optional[str] = None, - verbose: bool = False, - hide_stderr: bool = False, - env: Optional[Dict[str, str]] = None, -) -> Tuple[Optional[str], Optional[int]]: - """Call the given command(s).""" - assert isinstance(commands, list) - process = None - - popen_kwargs: Dict[str, Any] = {} - if sys.platform == "win32": - # This hides the console window if pythonw.exe is used - startupinfo = subprocess.STARTUPINFO() - startupinfo.dwFlags |= subprocess.STARTF_USESHOWWINDOW - popen_kwargs["startupinfo"] = startupinfo - - for command in commands: - try: - dispcmd = str([command] + args) - # remember shell=False, so use git.cmd on windows, not just git - process = subprocess.Popen( - [command] + args, - cwd=cwd, - env=env, - stdout=subprocess.PIPE, - stderr=(subprocess.PIPE if hide_stderr else None), - **popen_kwargs, - ) - break - except OSError as e: - if e.errno == errno.ENOENT: - continue - if verbose: - print("unable to run %s" % dispcmd) - print(e) - return None, None - else: - if verbose: - print("unable to find command, tried %s" % (commands,)) - return None, None - stdout = process.communicate()[0].strip().decode() - if process.returncode != 0: - if verbose: - print("unable to run %s (error)" % dispcmd) - print("stdout was %s" % stdout) - return None, process.returncode - return stdout, process.returncode - - -def versions_from_parentdir( - parentdir_prefix: str, - root: str, - verbose: bool, -) -> Dict[str, Any]: - """Try to determine the version from the parent directory name. - - Source tarballs conventionally unpack into a directory that includes both - the project name and a version string. 
We will also support searching up - two directory levels for an appropriately named parent directory - """ - rootdirs = [] - - for _ in range(3): - dirname = os.path.basename(root) - if dirname.startswith(parentdir_prefix): - return { - "version": dirname[len(parentdir_prefix) :], - "full-revisionid": None, - "dirty": False, - "error": None, - "date": None, - } - rootdirs.append(root) - root = os.path.dirname(root) # up a level - - if verbose: - print( - "Tried directories %s but none started with prefix %s" - % (str(rootdirs), parentdir_prefix) - ) - raise NotThisMethod("rootdir doesn't start with parentdir_prefix") - - -@register_vcs_handler("git", "get_keywords") -def git_get_keywords(versionfile_abs: str) -> Dict[str, str]: - """Extract version information from the given file.""" - # the code embedded in _version.py can just fetch the value of these - # keywords. When used from setup.py, we don't want to import _version.py, - # so we do it with a regexp instead. This function is not used from - # _version.py. - keywords: Dict[str, str] = {} - try: - with open(versionfile_abs, "r") as fobj: - for line in fobj: - if line.strip().startswith("git_refnames ="): - mo = re.search(r'=\s*"(.*)"', line) - if mo: - keywords["refnames"] = mo.group(1) - if line.strip().startswith("git_full ="): - mo = re.search(r'=\s*"(.*)"', line) - if mo: - keywords["full"] = mo.group(1) - if line.strip().startswith("git_date ="): - mo = re.search(r'=\s*"(.*)"', line) - if mo: - keywords["date"] = mo.group(1) - except OSError: - pass - return keywords - - -@register_vcs_handler("git", "keywords") -def git_versions_from_keywords( - keywords: Dict[str, str], - tag_prefix: str, - verbose: bool, -) -> Dict[str, Any]: - """Get version information from git keywords.""" - if "refnames" not in keywords: - raise NotThisMethod("Short version file found") - date = keywords.get("date") - if date is not None: - # Use only the last line. Previous lines may contain GPG signature - # information. - date = date.splitlines()[-1] - - # git-2.2.0 added "%cI", which expands to an ISO-8601 -compliant - # datestamp. However we prefer "%ci" (which expands to an "ISO-8601 - # -like" string, which we must then edit to make compliant), because - # it's been around since git-1.5.3, and it's too difficult to - # discover which version we're using, or to work around using an - # older one. - date = date.strip().replace(" ", "T", 1).replace(" ", "", 1) - refnames = keywords["refnames"].strip() - if refnames.startswith("$Format"): - if verbose: - print("keywords are unexpanded, not using") - raise NotThisMethod("unexpanded keywords, not a git-archive tarball") - refs = {r.strip() for r in refnames.strip("()").split(",")} - # starting in git-1.8.3, tags are listed as "tag: foo-1.0" instead of - # just "foo-1.0". If we see a "tag: " prefix, prefer those. - TAG = "tag: " - tags = {r[len(TAG) :] for r in refs if r.startswith(TAG)} - if not tags: - # Either we're using git < 1.8.3, or there really are no tags. We use - # a heuristic: assume all version tags have a digit. The old git %d - # expansion behaves like git log --decorate=short and strips out the - # refs/heads/ and refs/tags/ prefixes that would let us distinguish - # between branches and tags. By ignoring refnames without digits, we - # filter out many common branch names like "release" and - # "stabilization", as well as "HEAD" and "master". 
- tags = {r for r in refs if re.search(r"\d", r)} - if verbose: - print("discarding '%s', no digits" % ",".join(refs - tags)) - if verbose: - print("likely tags: %s" % ",".join(sorted(tags))) - for ref in sorted(tags): - # sorting will prefer e.g. "2.0" over "2.0rc1" - if ref.startswith(tag_prefix): - r = ref[len(tag_prefix) :] - # Filter out refs that exactly match prefix or that don't start - # with a number once the prefix is stripped (mostly a concern - # when prefix is '') - if not re.match(r"\d", r): - continue - if verbose: - print("picking %s" % r) - return { - "version": r, - "full-revisionid": keywords["full"].strip(), - "dirty": False, - "error": None, - "date": date, - } - # no suitable tags, so version is "0+unknown", but full hex is still there - if verbose: - print("no suitable tags, using unknown + full revision id") - return { - "version": "0+unknown", - "full-revisionid": keywords["full"].strip(), - "dirty": False, - "error": "no suitable tags", - "date": None, - } - - -@register_vcs_handler("git", "pieces_from_vcs") -def git_pieces_from_vcs( - tag_prefix: str, root: str, verbose: bool, runner: Callable = run_command -) -> Dict[str, Any]: - """Get version from 'git describe' in the root of the source tree. - - This only gets called if the git-archive 'subst' keywords were *not* - expanded, and _version.py hasn't already been rewritten with a short - version string, meaning we're inside a checked out source tree. - """ - GITS = ["git"] - if sys.platform == "win32": - GITS = ["git.cmd", "git.exe"] - - # GIT_DIR can interfere with correct operation of Versioneer. - # It may be intended to be passed to the Versioneer-versioned project, - # but that should not change where we get our version from. - env = os.environ.copy() - env.pop("GIT_DIR", None) - runner = functools.partial(runner, env=env) - - _, rc = runner(GITS, ["rev-parse", "--git-dir"], cwd=root, hide_stderr=not verbose) - if rc != 0: - if verbose: - print("Directory %s not under git control" % root) - raise NotThisMethod("'git rev-parse --git-dir' returned error") - - # if there is a tag matching tag_prefix, this yields TAG-NUM-gHEX[-dirty] - # if there isn't one, this yields HEX[-dirty] (no NUM) - describe_out, rc = runner( - GITS, - [ - "describe", - "--tags", - "--dirty", - "--always", - "--long", - "--match", - f"{tag_prefix}[[:digit:]]*", - ], - cwd=root, - ) - # --long was added in git-1.5.5 - if describe_out is None: - raise NotThisMethod("'git describe' failed") - describe_out = describe_out.strip() - full_out, rc = runner(GITS, ["rev-parse", "HEAD"], cwd=root) - if full_out is None: - raise NotThisMethod("'git rev-parse' failed") - full_out = full_out.strip() - - pieces: Dict[str, Any] = {} - pieces["long"] = full_out - pieces["short"] = full_out[:7] # maybe improved later - pieces["error"] = None - - branch_name, rc = runner(GITS, ["rev-parse", "--abbrev-ref", "HEAD"], cwd=root) - # --abbrev-ref was added in git-1.6.3 - if rc != 0 or branch_name is None: - raise NotThisMethod("'git rev-parse --abbrev-ref' returned error") - branch_name = branch_name.strip() - - if branch_name == "HEAD": - # If we aren't exactly on a branch, pick a branch which represents - # the current commit. If all else fails, we are on a branchless - # commit. 
- branches, rc = runner(GITS, ["branch", "--contains"], cwd=root) - # --contains was added in git-1.5.4 - if rc != 0 or branches is None: - raise NotThisMethod("'git branch --contains' returned error") - branches = branches.split("\n") - - # Remove the first line if we're running detached - if "(" in branches[0]: - branches.pop(0) - - # Strip off the leading "* " from the list of branches. - branches = [branch[2:] for branch in branches] - if "master" in branches: - branch_name = "master" - elif not branches: - branch_name = None - else: - # Pick the first branch that is returned. Good or bad. - branch_name = branches[0] - - pieces["branch"] = branch_name - - # parse describe_out. It will be like TAG-NUM-gHEX[-dirty] or HEX[-dirty] - # TAG might have hyphens. - git_describe = describe_out - - # look for -dirty suffix - dirty = git_describe.endswith("-dirty") - pieces["dirty"] = dirty - if dirty: - git_describe = git_describe[: git_describe.rindex("-dirty")] - - # now we have TAG-NUM-gHEX or HEX - - if "-" in git_describe: - # TAG-NUM-gHEX - mo = re.search(r"^(.+)-(\d+)-g([0-9a-f]+)$", git_describe) - if not mo: - # unparsable. Maybe git-describe is misbehaving? - pieces["error"] = "unable to parse git-describe output: '%s'" % describe_out - return pieces - - # tag - full_tag = mo.group(1) - if not full_tag.startswith(tag_prefix): - if verbose: - fmt = "tag '%s' doesn't start with prefix '%s'" - print(fmt % (full_tag, tag_prefix)) - pieces["error"] = "tag '%s' doesn't start with prefix '%s'" % ( - full_tag, - tag_prefix, - ) - return pieces - pieces["closest-tag"] = full_tag[len(tag_prefix) :] - - # distance: number of commits since tag - pieces["distance"] = int(mo.group(2)) - - # commit: short hex revision ID - pieces["short"] = mo.group(3) - - else: - # HEX: no tags - pieces["closest-tag"] = None - out, rc = runner(GITS, ["rev-list", "HEAD", "--left-right"], cwd=root) - pieces["distance"] = len(out.split()) # total number of commits - - # commit date: see ISO-8601 comment in git_versions_from_keywords() - date = runner(GITS, ["show", "-s", "--format=%ci", "HEAD"], cwd=root)[0].strip() - # Use only the last line. Previous lines may contain GPG signature - # information. - date = date.splitlines()[-1] - pieces["date"] = date.strip().replace(" ", "T", 1).replace(" ", "", 1) - - return pieces - - -def plus_or_dot(pieces: Dict[str, Any]) -> str: - """Return a + if we don't already have one, else return a .""" - if "+" in pieces.get("closest-tag", ""): - return "." - return "+" - - -def render_pep440(pieces: Dict[str, Any]) -> str: - """Build up version string, with post-release "local version identifier". - - Our goal: TAG[+DISTANCE.gHEX[.dirty]] . Note that if you - get a tagged build and then dirty it, you'll get TAG+0.gHEX.dirty - - Exceptions: - 1: no tags. git_describe was just HEX. 0+untagged.DISTANCE.gHEX[.dirty] - """ - if pieces["closest-tag"]: - rendered = pieces["closest-tag"] - if pieces["distance"] or pieces["dirty"]: - rendered += plus_or_dot(pieces) - rendered += "%d.g%s" % (pieces["distance"], pieces["short"]) - if pieces["dirty"]: - rendered += ".dirty" - else: - # exception #1 - rendered = "0+untagged.%d.g%s" % (pieces["distance"], pieces["short"]) - if pieces["dirty"]: - rendered += ".dirty" - return rendered - - -def render_pep440_branch(pieces: Dict[str, Any]) -> str: - """TAG[[.dev0]+DISTANCE.gHEX[.dirty]] . - - The ".dev0" means not master branch. Note that .dev0 sorts backwards - (a feature branch will appear "older" than the master branch). 
- - Exceptions: - 1: no tags. 0[.dev0]+untagged.DISTANCE.gHEX[.dirty] - """ - if pieces["closest-tag"]: - rendered = pieces["closest-tag"] - if pieces["distance"] or pieces["dirty"]: - if pieces["branch"] != "master": - rendered += ".dev0" - rendered += plus_or_dot(pieces) - rendered += "%d.g%s" % (pieces["distance"], pieces["short"]) - if pieces["dirty"]: - rendered += ".dirty" - else: - # exception #1 - rendered = "0" - if pieces["branch"] != "master": - rendered += ".dev0" - rendered += "+untagged.%d.g%s" % (pieces["distance"], pieces["short"]) - if pieces["dirty"]: - rendered += ".dirty" - return rendered - - -def pep440_split_post(ver: str) -> Tuple[str, Optional[int]]: - """Split pep440 version string at the post-release segment. - - Returns the release segments before the post-release and the - post-release version number (or -1 if no post-release segment is present). - """ - vc = str.split(ver, ".post") - return vc[0], int(vc[1] or 0) if len(vc) == 2 else None - - -def render_pep440_pre(pieces: Dict[str, Any]) -> str: - """TAG[.postN.devDISTANCE] -- No -dirty. - - Exceptions: - 1: no tags. 0.post0.devDISTANCE - """ - if pieces["closest-tag"]: - if pieces["distance"]: - # update the post release segment - tag_version, post_version = pep440_split_post(pieces["closest-tag"]) - rendered = tag_version - if post_version is not None: - rendered += ".post%d.dev%d" % (post_version + 1, pieces["distance"]) - else: - rendered += ".post0.dev%d" % (pieces["distance"]) - else: - # no commits, use the tag as the version - rendered = pieces["closest-tag"] - else: - # exception #1 - rendered = "0.post0.dev%d" % pieces["distance"] - return rendered - - -def render_pep440_post(pieces: Dict[str, Any]) -> str: - """TAG[.postDISTANCE[.dev0]+gHEX] . - - The ".dev0" means dirty. Note that .dev0 sorts backwards - (a dirty tree will appear "older" than the corresponding clean one), - but you shouldn't be releasing software with -dirty anyways. - - Exceptions: - 1: no tags. 0.postDISTANCE[.dev0] - """ - if pieces["closest-tag"]: - rendered = pieces["closest-tag"] - if pieces["distance"] or pieces["dirty"]: - rendered += ".post%d" % pieces["distance"] - if pieces["dirty"]: - rendered += ".dev0" - rendered += plus_or_dot(pieces) - rendered += "g%s" % pieces["short"] - else: - # exception #1 - rendered = "0.post%d" % pieces["distance"] - if pieces["dirty"]: - rendered += ".dev0" - rendered += "+g%s" % pieces["short"] - return rendered - - -def render_pep440_post_branch(pieces: Dict[str, Any]) -> str: - """TAG[.postDISTANCE[.dev0]+gHEX[.dirty]] . - - The ".dev0" means not master branch. - - Exceptions: - 1: no tags. 0.postDISTANCE[.dev0]+gHEX[.dirty] - """ - if pieces["closest-tag"]: - rendered = pieces["closest-tag"] - if pieces["distance"] or pieces["dirty"]: - rendered += ".post%d" % pieces["distance"] - if pieces["branch"] != "master": - rendered += ".dev0" - rendered += plus_or_dot(pieces) - rendered += "g%s" % pieces["short"] - if pieces["dirty"]: - rendered += ".dirty" - else: - # exception #1 - rendered = "0.post%d" % pieces["distance"] - if pieces["branch"] != "master": - rendered += ".dev0" - rendered += "+g%s" % pieces["short"] - if pieces["dirty"]: - rendered += ".dirty" - return rendered - - -def render_pep440_old(pieces: Dict[str, Any]) -> str: - """TAG[.postDISTANCE[.dev0]] . - - The ".dev0" means dirty. - - Exceptions: - 1: no tags. 
0.postDISTANCE[.dev0] - """ - if pieces["closest-tag"]: - rendered = pieces["closest-tag"] - if pieces["distance"] or pieces["dirty"]: - rendered += ".post%d" % pieces["distance"] - if pieces["dirty"]: - rendered += ".dev0" - else: - # exception #1 - rendered = "0.post%d" % pieces["distance"] - if pieces["dirty"]: - rendered += ".dev0" - return rendered - - -def render_git_describe(pieces: Dict[str, Any]) -> str: - """TAG[-DISTANCE-gHEX][-dirty]. - - Like 'git describe --tags --dirty --always'. - - Exceptions: - 1: no tags. HEX[-dirty] (note: no 'g' prefix) - """ - if pieces["closest-tag"]: - rendered = pieces["closest-tag"] - if pieces["distance"]: - rendered += "-%d-g%s" % (pieces["distance"], pieces["short"]) - else: - # exception #1 - rendered = pieces["short"] - if pieces["dirty"]: - rendered += "-dirty" - return rendered - - -def render_git_describe_long(pieces: Dict[str, Any]) -> str: - """TAG-DISTANCE-gHEX[-dirty]. - - Like 'git describe --tags --dirty --always -long'. - The distance/hash is unconditional. - - Exceptions: - 1: no tags. HEX[-dirty] (note: no 'g' prefix) - """ - if pieces["closest-tag"]: - rendered = pieces["closest-tag"] - rendered += "-%d-g%s" % (pieces["distance"], pieces["short"]) - else: - # exception #1 - rendered = pieces["short"] - if pieces["dirty"]: - rendered += "-dirty" - return rendered - - -def render(pieces: Dict[str, Any], style: str) -> Dict[str, Any]: - """Render the given version pieces into the requested style.""" - if pieces["error"]: - return { - "version": "unknown", - "full-revisionid": pieces.get("long"), - "dirty": None, - "error": pieces["error"], - "date": None, - } - - if not style or style == "default": - style = "pep440" # the default - - if style == "pep440": - rendered = render_pep440(pieces) - elif style == "pep440-branch": - rendered = render_pep440_branch(pieces) - elif style == "pep440-pre": - rendered = render_pep440_pre(pieces) - elif style == "pep440-post": - rendered = render_pep440_post(pieces) - elif style == "pep440-post-branch": - rendered = render_pep440_post_branch(pieces) - elif style == "pep440-old": - rendered = render_pep440_old(pieces) - elif style == "git-describe": - rendered = render_git_describe(pieces) - elif style == "git-describe-long": - rendered = render_git_describe_long(pieces) - else: - raise ValueError("unknown style '%s'" % style) - - return { - "version": rendered, - "full-revisionid": pieces["long"], - "dirty": pieces["dirty"], - "error": None, - "date": pieces.get("date"), - } - - -def get_versions() -> Dict[str, Any]: - """Get version information or return default if unable to do so.""" - # I am in _version.py, which lives at ROOT/VERSIONFILE_SOURCE. If we have - # __file__, we can work backwards from there to the root. Some - # py2exe/bbfreeze/non-CPython implementations don't do __file__, in which - # case we can only use expanded keywords. - - cfg = get_config() - verbose = cfg.verbose - - try: - return git_versions_from_keywords(get_keywords(), cfg.tag_prefix, verbose) - except NotThisMethod: - pass - - try: - root = os.path.realpath(__file__) - # versionfile_source is the relative path from the top of the source - # tree (where the .git directory might live) to this file. Invert - # this to find the root from __file__. 
- for _ in cfg.versionfile_source.split("/"): - root = os.path.dirname(root) - except NameError: - return { - "version": "0+unknown", - "full-revisionid": None, - "dirty": None, - "error": "unable to find root of source tree", - "date": None, - } - - try: - pieces = git_pieces_from_vcs(cfg.tag_prefix, root, verbose) - return render(pieces, cfg.style) - except NotThisMethod: - pass - - try: - if cfg.parentdir_prefix: - return versions_from_parentdir(cfg.parentdir_prefix, root, verbose) - except NotThisMethod: - pass - - return { - "version": "0+unknown", - "full-revisionid": None, - "dirty": None, - "error": "unable to compute version", - "date": None, - } diff --git a/smee/converters/openff/_openff.py b/smee/converters/openff/_openff.py index a90554b..8bf845b 100644 --- a/smee/converters/openff/_openff.py +++ b/smee/converters/openff/_openff.py @@ -197,7 +197,7 @@ def _convert_v_sites( v_site_maps = [] - for topology, handler in zip(topologies, handlers): + for topology, handler in zip(topologies, handlers, strict=True): if handler is None: v_site_maps.append(None) continue diff --git a/smee/converters/openff/nonbonded.py b/smee/converters/openff/nonbonded.py index c119f0d..1948b48 100644 --- a/smee/converters/openff/nonbonded.py +++ b/smee/converters/openff/nonbonded.py @@ -80,7 +80,9 @@ def convert_nonbonded_handlers( parameter_maps = [] - for handler, topology, v_site_map in zip(handlers, topologies, v_site_maps): + for handler, topology, v_site_map in zip( + handlers, topologies, v_site_maps, strict=True + ): assignment_map = collections.defaultdict(lambda: collections.defaultdict(float)) n_particles = topology.n_atoms + ( diff --git a/smee/converters/openmm/_openmm.py b/smee/converters/openmm/_openmm.py index 267c5fe..e85d25f 100644 --- a/smee/converters/openmm/_openmm.py +++ b/smee/converters/openmm/_openmm.py @@ -81,7 +81,7 @@ def create_openmm_system( omm_system = openmm.System() - for topology, n_copies in zip(system.topologies, system.n_copies): + for topology, n_copies in zip(system.topologies, system.n_copies, strict=True): for _ in range(n_copies): start_idx = omm_system.getNumParticles() @@ -96,7 +96,7 @@ def create_openmm_system( omm_system.addParticle(0.0) for key, parameter_idx in zip( - topology.v_sites.keys, topology.v_sites.parameter_idxs + topology.v_sites.keys, topology.v_sites.parameter_idxs, strict=True ): system_idx = start_idx + topology.v_sites.key_to_idx[key] assert system_idx >= start_idx @@ -124,14 +124,16 @@ def create_openmm_system( def _apply_constraints(omm_system: openmm.System, system: smee.TensorSystem): idx_offset = 0 - for topology, n_copies in zip(system.topologies, system.n_copies): + for topology, n_copies in zip(system.topologies, system.n_copies, strict=True): if topology.constraints is None: continue for _ in range(n_copies): atom_idxs = topology.constraints.idxs + idx_offset - for (i, j), distance in zip(atom_idxs, topology.constraints.distances): + for (i, j), distance in zip( + atom_idxs, topology.constraints.distances, strict=True + ): omm_system.addConstraint(i, j, distance * _ANGSTROM) idx_offset += topology.n_particles @@ -249,7 +251,7 @@ def convert_to_openmm_topology( omm_topology = openmm.app.Topology() - for topology, n_copies in zip(system.topologies, system.n_copies): + for topology, n_copies in zip(system.topologies, system.n_copies, strict=True): chain = omm_topology.addChain() is_water = topology.n_atoms == 3 and sorted( @@ -275,12 +277,14 @@ def convert_to_openmm_topology( ) atoms[i] = omm_topology.addAtom(name, element, 
residue) - for i in range(topology.n_v_sites): + for _ in range(topology.n_v_sites): omm_topology.addAtom( "X", openmm.app.Element.getByAtomicNumber(82), residue ) - for bond_idxs, bond_order in zip(topology.bond_idxs, topology.bond_orders): + for bond_idxs, bond_order in zip( + topology.bond_idxs, topology.bond_orders, strict=True + ): idx_a, idx_b = int(bond_idxs[0]), int(bond_idxs[1]) bond_order = int(bond_order) diff --git a/smee/converters/openmm/nonbonded.py b/smee/converters/openmm/nonbonded.py index 784fcb2..4d3a238 100644 --- a/smee/converters/openmm/nonbonded.py +++ b/smee/converters/openmm/nonbonded.py @@ -132,7 +132,7 @@ def _build_vdw_lookup( for col, col_idx in parameter_col_to_idx.items() } - for col, col_idx in parameter_col_to_idx.items(): + for col in parameter_col_to_idx: parameter_lookup[col][i + j * n_params] = float( parameters[col] * unit_conversion[col] ) @@ -169,7 +169,7 @@ def _detect_parameters( assigned_vars.add(assigned_var.strip()) parsed_fn = symengine.sympify(line) - free_vars.update(set(str(x) for x in parsed_fn.free_symbols) - assigned_vars) + free_vars.update({str(x) for x in parsed_fn.free_symbols} - assigned_vars) for assigned_var, fn in mixing_fn.items(): fn = fn.strip().strip(";") @@ -180,7 +180,7 @@ def _detect_parameters( assigned_vars.add(assigned_var.strip()) parsed_fn = symengine.sympify(fn) - free_vars.update(set(str(x) for x in parsed_fn.free_symbols) - assigned_vars) + free_vars.update({str(x) for x in parsed_fn.free_symbols} - assigned_vars) free_vars -= assigned_vars @@ -257,7 +257,7 @@ def _add_parameters_to_vdw_without_lookup( idx_offset = 0 - for topology, n_copies in zip(system.topologies, system.n_copies): + for topology, n_copies in zip(system.topologies, system.n_copies, strict=True): parameter_map = topology.parameters[potential.type] parameters = parameter_map.assignment_matrix @ potential.parameters @@ -317,7 +317,7 @@ def _add_parameters_to_vdw_with_lookup( idx_offset = 0 - for topology, n_copies in zip(system.topologies, system.n_copies): + for topology, n_copies in zip(system.topologies, system.n_copies, strict=True): parameter_map = topology.parameters[potential.type] assignment_dense = parameter_map.assignment_matrix.to_dense() @@ -463,7 +463,7 @@ def convert_lj_potential( idx_offset = 0 - for topology, n_copies in zip(system.topologies, system.n_copies): + for topology, n_copies in zip(system.topologies, system.n_copies, strict=True): parameter_map = topology.parameters[potential.type] parameters = parameter_map.assignment_matrix @ potential.parameters @@ -534,7 +534,7 @@ def convert_coulomb_potential( idx_offset = 0 - for topology, n_copies in zip(system.topologies, system.n_copies): + for topology, n_copies in zip(system.topologies, system.n_copies, strict=True): parameter_map = topology.parameters[potential.type] parameters = parameter_map.assignment_matrix @ potential.parameters diff --git a/smee/converters/openmm/valence.py b/smee/converters/openmm/valence.py index 3690d0a..3465892 100644 --- a/smee/converters/openmm/valence.py +++ b/smee/converters/openmm/valence.py @@ -21,7 +21,7 @@ def convert_bond_potential( idx_offset = 0 - for topology, n_copies in zip(system.topologies, system.n_copies): + for topology, n_copies in zip(system.topologies, system.n_copies, strict=True): parameters = ( topology.parameters[potential.type].assignment_matrix @ potential.parameters ) @@ -29,7 +29,7 @@ def convert_bond_potential( for _ in range(n_copies): atom_idxs = topology.parameters[potential.type].particle_idxs + idx_offset - for (i, 
j), (constant, length) in zip(atom_idxs, parameters): + for (i, j), (constant, length) in zip(atom_idxs, parameters, strict=True): force.addBond( i, j, @@ -53,7 +53,7 @@ def _convert_angle_potential( idx_offset = 0 - for topology, n_copies in zip(system.topologies, system.n_copies): + for topology, n_copies in zip(system.topologies, system.n_copies, strict=True): parameters = ( topology.parameters[potential.type].assignment_matrix @ potential.parameters ) @@ -61,7 +61,7 @@ def _convert_angle_potential( for _ in range(n_copies): atom_idxs = topology.parameters[potential.type].particle_idxs + idx_offset - for (i, j, k), (constant, angle) in zip(atom_idxs, parameters): + for (i, j, k), (constant, angle) in zip(atom_idxs, parameters, strict=True): force.addAngle( i, j, @@ -89,7 +89,7 @@ def convert_torsion_potential( idx_offset = 0 - for topology, n_copies in zip(system.topologies, system.n_copies): + for topology, n_copies in zip(system.topologies, system.n_copies, strict=True): parameters = ( topology.parameters[potential.type].assignment_matrix @ potential.parameters ) @@ -97,14 +97,17 @@ def convert_torsion_potential( for _ in range(n_copies): atom_idxs = topology.parameters[potential.type].particle_idxs + idx_offset - for (i, j, k, l), (constant, periodicity, phase, idivf) in zip( - atom_idxs, parameters - ): + for (idx_i, idx_j, idx_k, idx_l), ( + constant, + periodicity, + phase, + idivf, + ) in zip(atom_idxs, parameters, strict=True): force.addTorsion( - i, - j, - k, - l, + idx_i, + idx_j, + idx_k, + idx_l, int(periodicity), phase * _RADIANS, constant / idivf * _KCAL_PER_MOL, diff --git a/smee/geometry.py b/smee/geometry.py index 7c22283..36f9034 100644 --- a/smee/geometry.py +++ b/smee/geometry.py @@ -190,7 +190,7 @@ def _build_v_site_coord_frames( stacked_frames = [[], [], [], []] - for key, weight in zip(v_sites.keys, weights): + for key, weight in zip(v_sites.keys, weights, strict=True): parent_coords = conformer[:, key.orientation_atom_indices, :] weighted_coords = torch.transpose( (torch.transpose(parent_coords, 1, 2) @ weight.T), 1, 2 diff --git a/smee/mm/_config.py b/smee/mm/_config.py index 5e3cad9..7a4fb7d 100644 --- a/smee/mm/_config.py +++ b/smee/mm/_config.py @@ -26,8 +26,10 @@ def _quantity_validator( try: return value.in_units_of(expected_units) - except TypeError: - raise ValueError(f"invalid units {value.unit} - expected {expected_units}") + except TypeError as e: + raise ValueError( + f"invalid units {value.unit} - expected {expected_units}" + ) from e def _quantity_serializer(value: openmm.unit.Quantity) -> str: diff --git a/smee/mm/_mm.py b/smee/mm/_mm.py index ca4b993..98d7e50 100644 --- a/smee/mm/_mm.py +++ b/smee/mm/_mm.py @@ -42,7 +42,9 @@ def _apply_hmr( idx_offset = 0 - for topology, n_copies in zip(system_smee.topologies, system_smee.n_copies): + for topology, n_copies in zip( + system_smee.topologies, system_smee.n_copies, strict=True + ): for _ in range(n_copies): for idx_a, idx_b in topology.bond_idxs: if topology.atomic_nums[idx_a] == 1: @@ -75,12 +77,16 @@ def _topology_to_rdkit(topology: smee.TensorTopology) -> Chem.Mol: """Convert a topology to an RDKit molecule.""" mol = Chem.RWMol() - for atomic_num, formal_charge in zip(topology.atomic_nums, topology.formal_charges): + for atomic_num, formal_charge in zip( + topology.atomic_nums, topology.formal_charges, strict=True + ): atom = Chem.Atom(int(atomic_num)) atom.SetFormalCharge(int(formal_charge)) mol.AddAtom(atom) - for bond_idxs, bond_order in zip(topology.bond_idxs, topology.bond_orders): + for 
bond_idxs, bond_order in zip( + topology.bond_idxs, topology.bond_orders, strict=True + ): idx_a, idx_b = int(bond_idxs[0]), int(bond_idxs[1]) mol.AddBond(idx_a, idx_b, Chem.BondType(bond_order)) @@ -114,7 +120,7 @@ def _topology_to_xyz( "", *[ f"{element} {x:f} {y:f} {z:f}" - for (element, (x, y, z)) in zip(elements, coords) + for (element, (x, y, z)) in zip(elements, coords, strict=True) ], ] ) @@ -143,7 +149,7 @@ def _approximate_box_size( for atomic_num in topology.atomic_nums ) * n_copies - for topology, n_copies in zip(system.topologies, system.n_copies) + for topology, n_copies in zip(system.topologies, system.n_copies, strict=True) ) volume = weight / openmm.unit.AVOGADRO_CONSTANT_NA / config.target_density diff --git a/smee/mm/_ops.py b/smee/mm/_ops.py index c4f7e59..88ec422 100644 --- a/smee/mm/_ops.py +++ b/smee/mm/_ops.py @@ -139,7 +139,7 @@ def _get_mass(v: int) -> float: return sum( sum(_get_mass(atomic_num) for atomic_num in topology.atomic_nums) * n_copies - for topology, n_copies in zip(system.topologies, system.n_copies) + for topology, n_copies in zip(system.topologies, system.n_copies, strict=True) ) @@ -315,7 +315,7 @@ def forward(ctx, kwargs: _EnsembleAverageKwargs, *theta: torch.Tensor): ctx.mark_non_differentiable(*avg_stds) - return tuple([*avg_values, *avg_stds, tuple(columns)]) + return *avg_values, *avg_stds, tuple(columns) @staticmethod def backward(ctx, *grad_outputs): @@ -436,7 +436,7 @@ def forward(ctx, kwargs: _ReweightAverageKwargs, *theta: torch.Tensor): ctx.columns = columns ctx.save_for_backward(*theta, *du_d_theta, delta, weights, values) - return tuple([*avg_values, columns]) + return *avg_values, columns @staticmethod def backward(ctx, *grad_outputs): @@ -555,8 +555,8 @@ def compute_ensemble_averages( avg_std = avg_outputs[len(avg_outputs) // 2 :] return ( - {column: avg for avg, column in zip(avg_values, columns)}, - {column: avg for avg, column in zip(avg_std, columns)}, + {column: avg for avg, column in zip(avg_values, columns, strict=True)}, + {column: avg for avg, column in zip(avg_std, columns, strict=True)}, ) @@ -612,4 +612,4 @@ def reweight_ensemble_averages( } *avg_outputs, columns = _ReweightAverageOp.apply(kwargs, *tensors) - return {column: avg for avg, column in zip(avg_outputs, columns)} + return {column: avg for avg, column in zip(avg_outputs, columns, strict=True)} diff --git a/smee/potentials/_potentials.py b/smee/potentials/_potentials.py index b8a1128..49a5e0c 100644 --- a/smee/potentials/_potentials.py +++ b/smee/potentials/_potentials.py @@ -50,7 +50,7 @@ def broadcast_parameters( parameters = [] - for topology, n_copies in zip(system.topologies, system.n_copies): + for topology, n_copies in zip(system.topologies, system.n_copies, strict=True): parameter_map = topology.parameters[potential.type] topology_parameters = parameter_map.assignment_matrix @ potential.parameters @@ -97,7 +97,7 @@ def broadcast_exceptions( parameter_idxs = [] - for topology, n_copies in zip(system.topologies, system.n_copies): + for topology, n_copies in zip(system.topologies, system.n_copies, strict=True): parameter_map = topology.parameters[potential.type] if isinstance(parameter_map, smee.ValenceParameterMap): @@ -131,7 +131,7 @@ def broadcast_exceptions( parameter_idxs_a = parameter_idxs[idxs_a] parameter_idxs_b = parameter_idxs[idxs_b] - if len(set((min(i, j), max(i, j)) for i, j in potential.exceptions)) != len( + if len({(min(i, j), max(i, j)) for i, j in potential.exceptions}) != len( potential.exceptions ): raise NotImplementedError("cannot 
define different exceptions for i-j and j-i") @@ -172,7 +172,7 @@ def broadcast_idxs( per_topology_idxs = [] - for topology, n_copies in zip(system.topologies, system.n_copies): + for topology, n_copies in zip(system.topologies, system.n_copies, strict=True): parameter_map = topology.parameters[potential.type] n_interacting_particles = parameter_map.particle_idxs.shape[-1] diff --git a/smee/potentials/nonbonded.py b/smee/potentials/nonbonded.py index 2f4f033..a8ac45a 100644 --- a/smee/potentials/nonbonded.py +++ b/smee/potentials/nonbonded.py @@ -56,7 +56,7 @@ def _broadcast_exclusions( per_topology_exclusion_idxs = [] per_topology_exclusion_scales = [] - for topology, n_copies in zip(system.topologies, system.n_copies): + for topology, n_copies in zip(system.topologies, system.n_copies, strict=True): exclusion_idxs = topology.parameters[potential.type].exclusions exclusion_offset = ( @@ -212,10 +212,10 @@ def prepare_lrc_types( """ n_by_type = collections.defaultdict(int) - for topology, n_copies in zip(system.topologies, system.n_copies): + for topology, n_copies in zip(system.topologies, system.n_copies, strict=True): parameter_counts = topology.parameters["vdW"].assignment_matrix.abs().sum(dim=0) - for key, count in zip(potential.parameter_keys, parameter_counts): + for key, count in zip(potential.parameter_keys, parameter_counts, strict=True): n_by_type[key] += count.item() * n_copies counts = smee.utils.tensor_like( @@ -815,8 +815,8 @@ def _compute_pme_exclusions( ] max_exclusions = 0 - for exclusions, topology, n_copies in zip( - exclusion_templates, system.topologies, system.n_copies + for exclusions, topology in zip( + exclusion_templates, system.topologies, strict=True ): for i, j in topology.parameters[potential.type].exclusions: exclusions[i].append(int(j)) @@ -830,7 +830,7 @@ def _compute_pme_exclusions( exclusions_per_type = [] for exclusions, topology, n_copies in zip( - exclusion_templates, system.topologies, system.n_copies + exclusion_templates, system.topologies, system.n_copies, strict=True ): for atom_exclusions in exclusions: n_padding = max_exclusions - len(atom_exclusions) diff --git a/smee/tests/convertors/openff/test_valence.py b/smee/tests/convertors/openff/test_valence.py index 10a8bb0..1aa4aa6 100644 --- a/smee/tests/convertors/openff/test_valence.py +++ b/smee/tests/convertors/openff/test_valence.py @@ -39,7 +39,7 @@ def test_convert_bonds(ethanol, ethanol_interchange): actual_parameters = { tuple(particle_idxs.tolist()): parameter_keys[parameter_idxs.nonzero()] for parameter_idxs, particle_idxs in zip( - assignment_matrix, parameter_map.particle_idxs + assignment_matrix, parameter_map.particle_idxs, strict=True ) } expected_parameters = { @@ -91,7 +91,7 @@ def test_convert_propers(ethanol, ethanol_interchange): potential.parameter_keys[parameter_idxs.nonzero()].mult, ) for parameter_idxs, particle_idxs in zip( - assignment_matrix, parameter_map.particle_idxs + assignment_matrix, parameter_map.particle_idxs, strict=True ) } expected_parameters = { diff --git a/smee/tests/convertors/test_openmm.py b/smee/tests/convertors/test_openmm.py index ccb392c..8f09952 100644 --- a/smee/tests/convertors/test_openmm.py +++ b/smee/tests/convertors/test_openmm.py @@ -122,7 +122,7 @@ def compare_vec3(a: openmm.Vec3, b: openmm.Vec3): ] for i, (v_site_interchange, v_site_smee) in enumerate( - zip(v_sites_interchange, v_sites_smee) + zip(v_sites_interchange, v_sites_smee, strict=True) ): assert v_site_smee.getNumParticles() == v_site_interchange.getNumParticles() @@ -171,7 
+171,7 @@ def test_convert_to_openmm_system_periodic(): n_copies_per_mol = [5, 5] # carbonic acid has impropers, 1-5 interactions so should test most convertors - for smiles, n_copies in zip(["OC(=O)O", "O"], n_copies_per_mol): + for smiles, n_copies in zip(["OC(=O)O", "O"], n_copies_per_mol, strict=True): mol = openff.toolkit.Molecule.from_smiles(smiles) mol.generate_conformers(n_conformers=1) @@ -251,7 +251,7 @@ def test_convert_to_openmm_system_dexp_periodic(test_data_dir): n_copies_per_mol = [5, 5] - for smiles, n_copies in zip(["OCCO", "O"], n_copies_per_mol): + for smiles, n_copies in zip(["OCCO", "O"], n_copies_per_mol, strict=True): mol = openff.toolkit.Molecule.from_smiles(smiles) mol.generate_conformers(n_conformers=1) diff --git a/smee/tests/mm/test_ops.py b/smee/tests/mm/test_ops.py index 90bd654..97ebdba 100644 --- a/smee/tests/mm/test_ops.py +++ b/smee/tests/mm/test_ops.py @@ -86,7 +86,7 @@ def test_pack_unpack_force_field(mocker): expected_v_site = updated_tensors[-1] for i, (original, unpacked) in enumerate( - zip(force_field.potentials, unpacked_force_field.potentials) + zip(force_field.potentials, unpacked_force_field.potentials, strict=True) ): assert original.type == unpacked.type assert original.fn == unpacked.fn @@ -238,7 +238,7 @@ def test_compute_observables(tmp_path, mock_argon_tensors, mock_argon_params): frames_path = tmp_path / ("frames.msgpack") with frames_path.open("wb") as file: - for coord, box_vector in zip(coords, box_vectors): + for coord, box_vector in zip(coords, box_vectors, strict=True): frame = ( torch.tensor(coord).float(), torch.tensor(box_vector).float(), diff --git a/smee/tests/potentials/test_nonbonded.py b/smee/tests/potentials/test_nonbonded.py index ae5caba..456c75f 100644 --- a/smee/tests/potentials/test_nonbonded.py +++ b/smee/tests/potentials/test_nonbonded.py @@ -83,11 +83,11 @@ def test_compute_pairwise_scales(): [0.01, 0.02, 0.02, 1.0, 0.02] + [1.0] * (system.n_particles - 5), [0.01, 0.02, 0.02, 0.02, 1.0] + [1.0] * (system.n_particles - 5), # - [1.0] * 5 + [1.0, 0.01, 0.01, 0.01, 0.01] + [1.0] * (system.n_particles - 10), - [1.0] * 5 + [0.01, 1.0, 0.02, 0.02, 0.02] + [1.0] * (system.n_particles - 10), - [1.0] * 5 + [0.01, 0.02, 1.0, 0.02, 0.02] + [1.0] * (system.n_particles - 10), - [1.0] * 5 + [0.01, 0.02, 0.02, 1.0, 0.02] + [1.0] * (system.n_particles - 10), - [1.0] * 5 + [0.01, 0.02, 0.02, 0.02, 1.0] + [1.0] * (system.n_particles - 10), + [1.0] * 5 + [1.0, 0.01, 0.01, 0.01, 0.01] + [1.0] * (system.n_particles - 10), # noqa: E501 + [1.0] * 5 + [0.01, 1.0, 0.02, 0.02, 0.02] + [1.0] * (system.n_particles - 10), # noqa: E501 + [1.0] * 5 + [0.01, 0.02, 1.0, 0.02, 0.02] + [1.0] * (system.n_particles - 10), # noqa: E501 + [1.0] * 5 + [0.01, 0.02, 0.02, 1.0, 0.02] + [1.0] * (system.n_particles - 10), # noqa: E501 + [1.0] * 5 + [0.01, 0.02, 0.02, 0.02, 1.0] + [1.0] * (system.n_particles - 10), # noqa: E501 # [1.0] * 10 + [1.0, 0.01, 0.01] + [1.0] * (system.n_particles - 13), [1.0] * 10 + [0.01, 1.0, 0.02] + [1.0] * (system.n_particles - 13), @@ -199,7 +199,7 @@ def test_compute_pairwise_non_periodic(with_batch): ) n_expected_pairs = len(expected_idxs) - expected_batch_size = tuple() if not with_batch else (1,) + expected_batch_size = () if not with_batch else (1,) assert pairwise.idxs.shape == (n_expected_pairs, 2) assert torch.allclose(pairwise.idxs, expected_idxs) @@ -259,6 +259,14 @@ def test_prepare_lrc_types(with_exceptions): idxs_i, idxs_j, n_ij_interactions = prepare_lrc_types(system, vdw_potential) + if with_exceptions: + 
subset_idxs = [0, 1, 2, 13, 14, 25, 91, 92, 93, 94] + + idxs_i = idxs_i[subset_idxs] + idxs_j = idxs_j[subset_idxs] + + n_ij_interactions = n_ij_interactions[subset_idxs] + assert idxs_i.shape == expected_idxs_i.shape assert torch.allclose(idxs_i, expected_idxs_i) diff --git a/smee/tests/test_utils.py b/smee/tests/test_utils.py index e602004..23a22d4 100644 --- a/smee/tests/test_utils.py +++ b/smee/tests/test_utils.py @@ -9,7 +9,7 @@ def test_find_exclusions_simple(): molecule = openff.toolkit.Molecule() - for i in range(6): + for _ in range(6): molecule.add_atom(6, 0, False) for i in range(5): molecule.add_bond(i, i + 1, 1, False) @@ -35,7 +35,7 @@ def test_find_exclusions_simple(): def test_find_exclusions_rings(): molecule = openff.toolkit.Molecule() - for i in range(8): + for _ in range(8): molecule.add_atom(6, 0, False) # para substituted 6-membered ring @@ -79,7 +79,7 @@ def test_find_exclusions_rings(): def test_find_exclusions_dimer(): molecule = openff.toolkit.Molecule() - for i in range(3): + for _ in range(3): molecule.add_atom(6, 0, False) molecule.add_bond(0, 1, 1, False) @@ -102,7 +102,7 @@ def test_find_exclusions_dimer(): def test_find_exclusions_v_sites(): molecule = openff.toolkit.Molecule() - for i in range(4): + for _ in range(4): molecule.add_atom(6, 0, False) for i in range(3): molecule.add_bond(i, i + 1, 1, False) diff --git a/smee/tests/utils.py b/smee/tests/utils.py index a62163f..a725bda 100644 --- a/smee/tests/utils.py +++ b/smee/tests/utils.py @@ -1,6 +1,7 @@ import typing import openff.interchange +import openff.interchange.models import openff.toolkit import openff.units import torch @@ -232,6 +233,13 @@ def add_explicit_lb_exceptions( potential.parameters = torch.vstack( [torch.zeros_like(potential.parameters), torch.stack([eps_ij, sig_ij], dim=-1)] ) + for i in range(len(eps_ij)): + potential.parameter_keys.append( + openff.interchange.models.PotentialKey( + id=f"exception-{i}", associated_handler="vdW" + ) + ) + potential.exceptions = { (int(idx_i), int(idx_j)): i + n_params for i, (idx_i, idx_j) in enumerate(zip(idxs_i, idxs_j, strict=True)) diff --git a/smee/utils.py b/smee/utils.py index 9bc8408..4b40540 100644 --- a/smee/utils.py +++ b/smee/utils.py @@ -41,7 +41,6 @@ def find_exclusions( ) if v_sites is not None: - for v_site_key in v_sites.keys: v_site_idx = v_sites.key_to_idx[v_site_key] parent_idx = v_site_key.orientation_atom_indices[0]
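
For context on the new version plumbing: setuptools-scm computes the version from the most recent git tag when the package is built or installed, and smee/__init__.py now simply reads that value back from the installed distribution's metadata at import time. The sketch below is a minimal, standalone repeat of that lookup pattern; _discover_version is a hypothetical helper introduced only for illustration, and the "smee" distribution name comes from the diff above.

    import importlib.metadata


    def _discover_version(dist_name: str) -> str:
        """Return the installed version of dist_name, or a sentinel if missing."""
        try:
            # Reads the version recorded in the installed distribution's metadata,
            # which setuptools-scm derives from the latest git tag at build time.
            return importlib.metadata.version(dist_name)
        except importlib.metadata.PackageNotFoundError:
            # e.g. running from a source checkout that was never pip-installed.
            return "0+unknown"


    print(_discover_version("smee"))

The PackageNotFoundError fallback mirrors the "0+unknown" default of the deleted versioneer file, so importing smee still works from an uninstalled source tree; running pip install -e . (as the updated env Makefile target does) is what makes the real tag-derived version visible.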