New setup (#358)
* updated installation for cuda 12.6

* new setup.py pyproject.toml installation

* cleanup

* fix tests and deprecations

* most deps are added now by the pip package

* cleanup deps more

* better handling of CUDA builds

* correct building and installation

* reorganize CI

* install nnpops later

* fix nnpops installation on correct OSs

* nnpops is installing pytorch-cpu from conda which is messing with the torch used to compile from pip

* pytorch doesn't generate pip packages anymore for old osx machines

* be a bit more flexible with specific builds

* fix docs building

* fix python version

* delayed import of NNPops to fix issues in docs generation

* fixed paths to scripts module

* undo commenting out
stefdoerr authored Feb 12, 2025
1 parent d616c8a commit 1502491
Showing 17 changed files with 123 additions and 143 deletions.
45 changes: 25 additions & 20 deletions .github/workflows/CI.yml
@@ -17,12 +17,11 @@ jobs:
strategy:
fail-fast: false
matrix:
os:
[
os: [
"ubuntu-latest",
"ubuntu-22.04-arm",
"macos-latest",
"macos-13",
# "macos-13",
"windows-2022",
]
python-version: ["3.10"]
Expand Down Expand Up @@ -50,41 +49,47 @@ jobs:
conda-remove-defaults: "true"
if: matrix.os == 'macos-13'

- name: Install OS-specific conda dependencies
- name: Install OS-specific compilers
run: |
if [[ "${{ matrix.os }}" == "ubuntu-22.04-arm" ]]; then
conda install --file conda_deps_linux_aarch64.txt --channel conda-forge --override-channels
conda install gxx --channel conda-forge --override-channels
elif [[ "${{ runner.os }}" == "Linux" ]]; then
conda install --file conda_deps_linux.txt --channel conda-forge --override-channels
conda install gxx --channel conda-forge --override-channels
elif [[ "${{ runner.os }}" == "macOS" ]]; then
conda install --file conda_deps_osx.txt --channel conda-forge --override-channels
conda install clangxx llvm-openmp pybind11 --channel conda-forge --override-channels
elif [[ "${{ runner.os }}" == "Windows" ]]; then
conda install --file conda_deps_win.txt --channel conda-forge --override-channels
conda install vc vc14_runtime vs2015_runtime --channel conda-forge --override-channels
fi
- name: Install testing packages
run: conda install -y -c conda-forge flake8 pytest psutil

- name: List the conda environment
run: conda list

- name: Install testing packages
run: conda install -y -c conda-forge flake8 pytest psutil python-build

- name: Lint with flake8
run: |
# stop the build if there are Python syntax errors or undefined names
flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
# exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
- name: Build and install the package
run: |
if [[ "${{ runner.os }}" == "Windows" ]]; then
export LIB="C:/Miniconda/envs/test/Library/lib"
pip -vv install .
else
pip -vv install .
fi
python -m build
pip install dist/*.whl
env:
WITH_CUDA: "0"

- name: Lint with flake8
run: |
# stop the build if there are Python syntax errors or undefined names
flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
# exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
# - name: Install nnpops
# if: matrix.os == 'ubuntu-latest' || matrix.os == 'macos-latest'
# run: conda install nnpops --channel conda-forge --override-channels

- name: List the conda environment
run: conda list

- name: Run tests
run: pytest -v -s --durations=10
39 changes: 20 additions & 19 deletions .github/workflows/docs_build.yaml
@@ -8,29 +8,30 @@ on:
branches:
- "main"


jobs:
build-docs:
runs-on: ubuntu-latest

steps:
- uses: actions/checkout@v2
- uses: actions/checkout@v4

- uses: conda-incubator/setup-miniconda@v3
with:
python-version: "3.10"
channels: conda-forge
conda-remove-defaults: "true"

- name: Install compiler
run: conda install gxx --channel conda-forge --override-channels

- name: Set up Env
uses: mamba-org/setup-micromamba@v1
with:
environment-file: environment.yml
init-shell: bash
generate-run-shell: true
- name: Install docs dependencies
run: |
pip install -vv .
pip install -r docs/requirements.txt
shell: bash -el {0}

- name: Install docs dependencies
run: |
pip install -vv .
pip install -r docs/requirements.txt
shell: bash -el {0}

- name: Build Sphinx Documentation
run: |
cd docs
make html
shell: bash -el {0}
- name: Build Sphinx Documentation
run: |
cd docs
make html
shell: bash -el {0}
10 changes: 0 additions & 10 deletions conda_deps_linux.txt

This file was deleted.

9 changes: 0 additions & 9 deletions conda_deps_linux_aarch64.txt

This file was deleted.

12 changes: 0 additions & 12 deletions conda_deps_osx.txt

This file was deleted.

13 changes: 0 additions & 13 deletions conda_deps_win.txt

This file was deleted.

3 changes: 1 addition & 2 deletions docs/source/installation.rst
@@ -54,8 +54,7 @@ It is recommended to install the same version as the one used by torch.

.. code-block:: shell
conda install -c conda-forge cuda-nvcc cuda-libraries-dev cuda-version gxx pytorch=*=*cuda*
conda install -c conda-forge python=3.10 cuda-version=12.6 cuda-nvvm cuda-nvcc cuda-libraries-dev
* CUDA<12
2 changes: 1 addition & 1 deletion docs/source/models.rst
@@ -342,7 +342,7 @@ To implement a new architecture, you need to follow these steps:
**shared_args,
)
4. Add any new parameters required to initialize your module to scripts.train.get_args:
4. Add any new parameters required to initialize your module to torchmdnet.scripts.train.get_args:
.. code-block:: python
2 changes: 1 addition & 1 deletion docs/source/torchmd-train.rst
@@ -89,7 +89,7 @@ Command line interface
~~~~~~~~~~~~~~~~~~~~~~


.. autoprogram:: scripts.train:get_argparse()
.. autoprogram:: torchmdnet.scripts.train:get_argparse()
:prog: torchmd-train


38 changes: 38 additions & 0 deletions pyproject.toml
@@ -0,0 +1,38 @@
[project]
name = "torchmd-net"
description = "TorchMD-NET provides state-of-the-art neural networks potentials for biomolecular systems"
authors = [{ name = "Acellera", email = "info@acellera.com" }]
readme = "README.md"
requires-python = ">=3.8"
dynamic = ["version"]
classifiers = [
"Programming Language :: Python :: 3",
"Operating System :: POSIX :: Linux",
]
dependencies = [
"h5py",
# "nnpops",
"torch==2.5.1.*",
"torch_geometric",
"lightning",
"tqdm",
"pandas",
]

[project.urls]
"Homepage" = "https://github.com/torchmd/torchmd-net"
"Bug Tracker" = "https://github.com/torchmd/torchmd-net/issues"

[project.scripts]
torchmd-train = "torchmdnet.scripts.train:main"

[tool.setuptools_scm]

[tool.setuptools.packages.find]
where = [""]
include = ["torchmdnet*"]
namespaces = false

[build-system]
requires = ["setuptools>=64", "setuptools-scm>=8", "torch==2.5.1.*"]
build-backend = "setuptools.build_meta"
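
Note on the dynamic version: `pyproject.toml` now declares `version` as dynamic and leaves it to setuptools-scm at build time. Code that needs the version at runtime can read it from the installed distribution metadata instead of calling git. A minimal sketch, assuming the package is installed under the distribution name `torchmd-net`; the helper name `get_version` and the `"0"` fallback are illustrative and not part of this commit:

```python
from importlib.metadata import PackageNotFoundError, version


def get_version(dist_name: str = "torchmd-net", fallback: str = "0") -> str:
    """Return the installed version of dist_name, or fallback if it is not installed."""
    try:
        return version(dist_name)
    except PackageNotFoundError:
        # The distribution is not installed in this environment.
        return fallback


if __name__ == "__main__":
    print(get_version())
```

The fallback keeps the helper usable in a plain source checkout where the package has not been installed.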
52 changes: 10 additions & 42 deletions setup.py
@@ -2,33 +2,17 @@
# Distributed under the MIT License.
# (See accompanying file README.md file or copy at http://opensource.org/licenses/MIT)

import subprocess
from setuptools import setup, find_packages
from setuptools import setup
import torch
from torch.utils.cpp_extension import (
BuildExtension,
CUDAExtension,
include_paths,
CppExtension,
)
from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CppExtension
import os
import sys

is_windows = sys.platform == "win32"

try:
version = (
subprocess.check_output(["git", "describe", "--abbrev=0", "--tags"])
.strip()
.decode("utf-8")
)
except Exception:
print("Failed to retrieve the current version, defaulting to 0")
version = "0"

# If WITH_CUDA is defined
if os.environ.get("WITH_CUDA", "0") == "1":
use_cuda = True
elif os.environ.get("WITH_CUDA", "0") == "0":
use_cuda = False
else:
use_cuda = torch.cuda._is_compiled()

@@ -37,13 +21,12 @@ def set_torch_cuda_arch_list():
"""Set the CUDA arch list according to the architectures the current torch installation was compiled for.
This function is a no-op if the environment variable TORCH_CUDA_ARCH_LIST is already set or if torch was not compiled with CUDA support.
"""
if not os.environ.get("TORCH_CUDA_ARCH_LIST"):
if use_cuda:
arch_flags = torch._C._cuda_getArchFlags()
sm_versions = [x[3:] for x in arch_flags.split() if x.startswith("sm_")]
formatted_versions = ";".join([f"{y[0]}.{y[1]}" for y in sm_versions])
formatted_versions += "+PTX"
os.environ["TORCH_CUDA_ARCH_LIST"] = formatted_versions
if use_cuda and not os.environ.get("TORCH_CUDA_ARCH_LIST"):
arch_flags = torch._C._cuda_getArchFlags()
sm_versions = [x[3:] for x in arch_flags.split() if x.startswith("sm_")]
formatted_versions = ";".join([f"{y[0]}.{y[1]}" for y in sm_versions])
formatted_versions += "+PTX"
os.environ["TORCH_CUDA_ARCH_LIST"] = formatted_versions


set_torch_cuda_arch_list()
@@ -61,30 +44,15 @@ def set_torch_cuda_arch_list():
name="torchmdnet.extensions.torchmdnet_extensions",
sources=[os.path.join(extension_root, "torchmdnet_extensions.cpp")]
+ neighbor_sources,
include_dirs=include_paths(),
define_macros=[("WITH_CUDA", 1)] if use_cuda else [],
)

if __name__ == "__main__":
setup(
name="torchmd-net",
version=version,
packages=find_packages(),
ext_modules=[extensions],
cmdclass={
"build_ext": BuildExtension.with_options(
no_python_abi_suffix=True, use_ninja=False
)
},
include_package_data=True,
entry_points={
"console_scripts": ["torchmd-train = torchmdnet.scripts.train:main"]
},
package_data={
"torchmdnet": (
["extensions/torchmdnet_extensions.so"]
if not is_windows
else ["extensions/torchmdnet_extensions.dll"]
)
},
)
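
Note on `set_torch_cuda_arch_list`: the function turns the `sm_XY` flags reported by torch into the `X.Y;...+PTX` form expected in `TORCH_CUDA_ARCH_LIST`. The sketch below isolates that string conversion as a pure function so it can be checked without torch. The name `format_arch_list` is hypothetical, and it deliberately differs from the commit in one respect: it splits on the last digit (`y[:-1]`, `y[-1]`) rather than taking `y[0]` and `y[1]`, so that a three-digit flag such as `sm_100` would map to `10.0` instead of `1.0`. It assumes purely numeric suffixes.

```python
def format_arch_list(arch_flags: str) -> str:
    """Convert a string like 'sm_80 sm_86 compute_86' to '8.0;8.6+PTX'.

    Only entries starting with 'sm_' are used, as in setup.py.
    Assumes the part after 'sm_' is purely numeric (hypothetical helper).
    """
    sm_versions = [x[3:] for x in arch_flags.split() if x.startswith("sm_")]
    # Variation on the commit: everything but the last digit is the major version.
    formatted = ";".join(f"{y[:-1]}.{y[-1]}" for y in sm_versions)
    return formatted + "+PTX"


if __name__ == "__main__":
    print(format_arch_list("sm_80 sm_86 compute_86"))
```

For two-digit architectures this gives the same result as the expression used in the commit.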
17 changes: 12 additions & 5 deletions tests/test_calculator.py
@@ -3,7 +3,7 @@
# (See accompanying file README.md file or copy at http://opensource.org/licenses/MIT)

import torch
from torch.testing import assert_allclose
from torch.testing import assert_close
import pytest
from os.path import dirname, join
from torchmdnet.calculators import External
@@ -15,6 +15,8 @@
@pytest.mark.parametrize("box", [None, torch.eye(3)])
@pytest.mark.parametrize("use_cuda_graphs", [True, False])
def test_compare_forward(box, use_cuda_graphs):
from copy import deepcopy

if use_cuda_graphs and not torch.cuda.is_available():
pytest.skip("CUDA not available")
checkpoint = join(dirname(dirname(__file__)), "tests", "example.ckpt")
@@ -48,14 +50,19 @@ def test_compare_forward(box, use_cuda_graphs):
checkpoint, z.unsqueeze(0), use_cuda_graph=use_cuda_graphs, device=device
)
calc.model = model
# Patch the model
model = deepcopy(model)
model.representation_model.distance.check_errors = not use_cuda_graphs
model.representation_model.static_shapes = use_cuda_graphs
model.representation_model.distance.resize_to_fit = not use_cuda_graphs
calc_graph.model = model
if box is not None:
box = (box * 2 * args["cutoff_upper"]).unsqueeze(0)
for _ in range(10):
e_calc, f_calc = calc.calculate(pos, box)
e_pred, f_pred = calc_graph.calculate(pos, box)
assert_allclose(e_calc, e_pred)
assert_allclose(f_calc, f_pred)
assert_close(e_calc, e_pred)
assert_close(f_calc, f_pred)


def test_compare_forward_multiple():
@@ -72,5 +79,5 @@ def test_compare_forward_multiple():
torch.cat([torch.zeros(len(z1)), torch.ones(len(z2))]).long(),
)

assert_allclose(e_calc, e_pred)
assert_allclose(f_calc, f_pred.view(-1, len(z1), 3))
assert_close(e_calc, e_pred)
assert_close(f_calc, f_pred.view(-1, len(z1), 3), rtol=1e-4, atol=1e-5)
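
Note on the `deepcopy` in `test_compare_forward`: the test assigns `model` to `calc.model`, then patches a deep copy and assigns that to `calc_graph.model`. Presumably this keeps the CUDA-graph settings from leaking into the reference calculator, since patching the same object would change both. A minimal sketch of the difference, using `SimpleNamespace` objects as hypothetical stand-ins for the model and its nested `distance` module:

```python
from copy import deepcopy
from types import SimpleNamespace

# Hypothetical stand-in for a model with a nested settings object.
reference = SimpleNamespace(distance=SimpleNamespace(check_errors=True))

# Plain assignment only creates a second name for the same object.
alias = reference

# deepcopy duplicates the nested objects too, so patches stay local to the copy.
patched = deepcopy(reference)
patched.distance.check_errors = False

if __name__ == "__main__":
    print(reference.distance.check_errors)  # unchanged by the patch
    print(patched.distance.check_errors)
```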
4 changes: 2 additions & 2 deletions tests/test_equivariance.py
@@ -24,7 +24,7 @@ def test_scalar_invariance():

y = model(z, pos, batch)[0]
y_rot = model(z, pos @ rotate, batch)[0]
torch.testing.assert_allclose(y, y_rot)
torch.testing.assert_close(y, y_rot)


def test_vector_equivariance():
@@ -50,4 +50,4 @@ def test_vector_equivariance():

y = model(z, pos, batch)[0]
y_rot = model(z, pos @ rotate, batch)[0]
torch.testing.assert_allclose(y @ rotate, y_rot)
torch.testing.assert_close(y @ rotate, y_rot)
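
Note on the `assert_allclose` to `assert_close` migration: `torch.testing.assert_close` treats two values as close when `|actual - expected| <= atol + rtol * |expected|`, and its default tolerances depend on the dtype. That may be why one call in `tests/test_calculator.py` now passes `rtol=1e-4, atol=1e-5` explicitly. The sketch below is a pure-Python illustration of that criterion for scalars, not the torch implementation; `is_close` is a hypothetical helper name.

```python
def is_close(actual: float, expected: float, rtol: float, atol: float) -> bool:
    """Scalar version of the closeness check: |a - e| <= atol + rtol * |e|."""
    return abs(actual - expected) <= atol + rtol * abs(expected)


if __name__ == "__main__":
    # Same pair of values, judged under a loose and a tight relative tolerance.
    print(is_close(1.00005, 1.0, rtol=1e-4, atol=1e-5))
    print(is_close(1.00005, 1.0, rtol=1.3e-6, atol=1e-5))
```

A difference of about 5e-5 passes under the looser tolerances and fails under the tighter ones, so an explicit `rtol`/`atol` can decide whether a comparison holds.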