-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathsetup.py
84 lines (82 loc) · 2.9 KB
/
setup.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
import os
from setuptools import setup, find_packages
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
# Fall back to g++ as the C++ compiler only when the caller has not already
# chosen one through the CXX environment variable; an existing CXX value
# (even an empty one) is left exactly as it is.
if os.environ.get('CXX') is None:
    os.environ['CXX'] = 'g++'
    print("Set C++ compiler: g++")
# Absolute path of the C++/CUDA source root. Every extension gets it as an
# include directory so sources can use csrc-relative #include paths.
# abspath() guards against __file__ being relative (dirname would then be '').
_CSRC_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "csrc")


def _nvcc_flags():
    """Return a fresh list of nvcc flags shared by all torchff CUDA kernels.

    By default the kernels are compiled for ``sm_80``. If TORCH_CUDA_ARCH_LIST
    is set (non-empty), no explicit ``-arch`` is passed. torch's build helper
    skips its own architecture selection whenever an ``arch`` flag is already
    present, so leaving it out lets that standard variable take effect.
    """
    flags = ['-O3']
    if not os.environ.get('TORCH_CUDA_ARCH_LIST'):
        flags.append('-arch=sm_80')
    return flags


def _cuda_extension(name, sources):
    """Create a CUDAExtension named *name* from *sources* with the common flags.

    New dict/list objects are built on every call on purpose: BuildExtension
    appends per-extension defines (e.g. TORCH_EXTENSION_NAME) to these lists,
    so sharing them between extensions would leak one module's defines into
    another.
    """
    return CUDAExtension(
        name=name,
        sources=sources,
        extra_compile_args={
            'cxx': ['-O3'],
            'nvcc': _nvcc_flags(),
        },
        include_dirs=[_CSRC_DIR],
    )


setup(
    name='torchff',
    classifiers=[
        # Must be a registered trove classifier; PyPI rejects unknown ones.
        'Development Status :: 4 - Beta',
        'Natural Language :: English',
        'Intended Audience :: Science/Research',
        'Programming Language :: Python :: 3.12',
    ],
    # exclude= matches exact package names, so the '.*' forms are needed to
    # keep subpackages of these directories out of the distribution as well.
    packages=find_packages(
        exclude=['csrc', 'csrc.*', 'tests', 'tests.*', 'docs', 'docs.*']
    ),
    ext_modules=[
        _cuda_extension(
            'torchff_harmonic_bond',
            [
                'csrc/bond/harmonic_bond_interface.cpp',
                'csrc/bond/harmonic_bond_cpu.cpp',
                'csrc/bond/harmonic_bond_cuda.cu',
            ],
        ),
        _cuda_extension(
            'torchff_harmonic_angle',
            [
                'csrc/angle/harmonic_angle_interface.cpp',
                'csrc/angle/harmonic_angle_cpu.cpp',
                'csrc/angle/harmonic_angle_cuda.cu',
            ],
        ),
        _cuda_extension(
            'torchff_coulomb',
            [
                'csrc/coulomb/coulomb_interface.cpp',
                'csrc/coulomb/coulomb_cuda.cu',
            ],
        ),
        _cuda_extension(
            'torchff_periodic_torsion',
            [
                'csrc/torsion/periodic_torsion_interface.cpp',
                'csrc/torsion/periodic_torsion_cuda.cu',
            ],
        ),
        _cuda_extension(
            'torchff_vdw',
            [
                'csrc/vdw/lennard_jones_interface.cpp',
                'csrc/vdw/lennard_jones_cuda.cu',
            ],
        ),
        _cuda_extension(
            'torchff_nblist',
            [
                'csrc/nblist/nblist_interface.cpp',
                'csrc/nblist/nblist_nsquared_cuda.cu',
            ],
        ),
    ],
    cmdclass={
        'build_ext': BuildExtension,
    },
)