Enable builds without direct torch.cuda availability and support sm89 / sm90. #5

Open · wants to merge 2 commits into main
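The first half of the title comes down to TORCH_CUDA_ARCH_LIST: when that variable is set before torch.utils.cpp_extension resolves target architectures, the build never has to ask torch.cuda for a physical GPU. A minimal sketch of driving the patched setup.py that way (the architecture list and the setdefault call are illustrative assumptions, not part of the diff below):

import os

# Pin the CUDA architectures up front (example values). With this set,
# torch.utils.cpp_extension skips probing for a local GPU, which is what
# lets the extension build on GPU-less machines such as CI builders.
os.environ.setdefault("TORCH_CUDA_ARCH_LIST", "8.0;8.6;8.9;9.0")

Exporting the same variable in the shell before running setup.py has the identical effect, which is what the new warning text in the diff suggests for single-architecture cross-compiles.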
251 changes: 193 additions & 58 deletions csrc/flashfftconv/setup.py
@@ -1,75 +1,210 @@
+from __future__ import annotations
+
+import os
+import subprocess
+
 import torch
+
+from functools import cache
+from pathlib import Path
+from typing import Tuple, List
+
+from packaging.version import parse, Version
 from setuptools import setup
+
 from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CUDA_HOME
-import subprocess
 
-def get_last_arch_torch():
-    arch = torch.cuda.get_arch_list()[-1]
-    print(f"Found arch: {arch} from existing torch installation")
-    return arch
-
-def get_cuda_bare_metal_version(cuda_dir):
-    raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True)
-    output = raw_output.split()
-    release_idx = output.index("release") + 1
-    release = output[release_idx].split(".")
-    bare_metal_major = release[0]
-    bare_metal_minor = release[1][0]
-
-    return raw_output, bare_metal_major, bare_metal_minor
-
-def append_nvcc_threads(nvcc_extra_args):
-    _, bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(CUDA_HOME)
-    if int(bare_metal_major) >= 11 and int(bare_metal_minor) >= 2:
-        return nvcc_extra_args + ["--threads", "4"]

+CUDA_PATH: Path = Path(CUDA_HOME)
+TORCH_VERSION: Version = Version(torch.__version__)
+
+TORCH_MAJOR: int = TORCH_VERSION.major
+TORCH_MINOR: int = TORCH_VERSION.minor
+
+EXTENSION_NAME: str = 'monarch_cuda'
+EXTENDED_CAPABILITIES: Tuple[int, ...] = (89, 90)
+
+
+@cache
+def get_cuda_bare_metal_version(cuda_dir: Path) -> Tuple[str, Version]:
+
+    raw = (
+        subprocess.run(
+            [str(cuda_dir / 'bin' / 'nvcc'), '-V'],
+            capture_output=True,
+            check=True,
+            encoding='utf-8',
+        )
+        .stdout
+    )
+
+    output = raw.split()
+    version, _, _ = output[output.index('release') + 1].partition(',')
+
+    return raw, parse(version)
+
+
+def raise_if_cuda_home_none(global_option: str) -> None:
+
+    if CUDA_HOME is None:
+
+        raise RuntimeError(
+            f"{global_option} was requested, but nvcc was not found. Are you sure your "
+            "environment has nvcc available? If you're installing within a container from "
+            "https://hub.docker.com/r/pytorch/pytorch, only images whose names contain "
+            "'devel' will provide nvcc."
+        )
+
+
+def append_nvcc_threads(nvcc_extra_args: List[str]) -> List[str]:
+
+    _, version = get_cuda_bare_metal_version(CUDA_PATH)
+
+    if version >= Version("11.2"):
+
+        nvcc_extra_args.extend(("--threads", "4"))
+
+    return nvcc_extra_args

-arch = get_last_arch_torch()
-# [MP] make install more flexible here
-sm_num = arch[-2:]
-cc_flag = ['--generate-code=arch=compute_80,code=compute_80']

+def arch_flags(compute: int, ptx: bool = False) -> str:
+
+    build = 'compute' if ptx else 'sm'
+
+    return f'arch=compute_{compute},code={build}_{compute}'
+
+
+class CompilerFlags(List[str]):
+
+    def add_arch(self, compute: int, ptx: bool = True, sass: bool = False):
+
+        if ptx:
+
+            self.append("-gencode")
+            self.append(arch_flags(compute, True))
+
+        if sass:
+
+            self.append("-gencode")
+            self.append(arch_flags(compute, False))
+
+        return self
+
+
+def build_compiler_flags(
+    ptx: bool = True, sass: bool = False, multi_arch: bool = False
+) -> List[str]:
+
+    flags = (
+        CompilerFlags()
+        .add_arch(compute=80, ptx=ptx, sass=sass)
+    )
+
+    if multi_arch:
+
+        _, version = get_cuda_bare_metal_version(CUDA_PATH)
+
+        if version < Version("11.0"):
+
+            raise RuntimeError(f"{EXTENSION_NAME} is only supported on CUDA 11 and above")
+
+        elif version >= Version("11.8"):
+
+            for compute in EXTENDED_CAPABILITIES:
+
+                flags.add_arch(compute=compute, ptx=ptx, sass=sass)
+
+    return flags


+if not torch.cuda.is_available():
+
+    print(
+        "\nWarning: Torch did not find available GPUs on this system.\n",
+        "If your intention is to cross-compile, this is not an error.\n"
+        "By default, FlashFFTConv will cross-compile for Ampere (compute capabilities 8.0 and 8.6) "
+        "and if CUDA version >= 11.8, Ada (compute capability 8.9) and Hopper (compute capability 9.0).\n"
+        "If you wish to cross-compile for a single specific architecture,\n"
+        'export TORCH_CUDA_ARCH_LIST="compute capability" before running setup.py.\n',
+    )
+
+    if os.environ.get("TORCH_CUDA_ARCH_LIST", None) is None and CUDA_HOME is not None:
+
+        _, bare_metal_version = get_cuda_bare_metal_version(CUDA_PATH)
+
+        if bare_metal_version >= Version("11.8"):
+
+            os.environ["TORCH_CUDA_ARCH_LIST"] = "8.0;8.6;8.9;9.0"
+
+        elif bare_metal_version >= Version("11.1"):
+
+            os.environ["TORCH_CUDA_ARCH_LIST"] = "8.0;8.6"
+
+        elif bare_metal_version == Version("11.0"):
+
+            os.environ["TORCH_CUDA_ARCH_LIST"] = "8.0"
+
+        else:
+
+            raise RuntimeError(f"{EXTENSION_NAME} is only supported on CUDA 11 and above")


+# Log PyTorch Version.
+print(f"\n\ntorch.__version__ = {TORCH_VERSION}\n\n")
+
+
+# Verify that CUDA_HOME exists.
+raise_if_cuda_home_none(EXTENSION_NAME)


 setup(
     name='monarch_cuda',
     ext_modules=[
-        CUDAExtension('monarch_cuda', [
-            'monarch.cpp',
-            'monarch_cuda/monarch_cuda_interface_fwd.cu',
-            'monarch_cuda/monarch_cuda_interface_fwd_complex.cu',
-            'monarch_cuda/monarch_cuda_interface_fwd_bf16.cu',
-            'monarch_cuda/monarch_cuda_interface_fwd_bf16_complex.cu',
-            'monarch_cuda/monarch_cuda_interface_fwd_r2r.cu',
-            'monarch_cuda/monarch_cuda_interface_fwd_r2r_bf16.cu',
-            'monarch_cuda/monarch_cuda_interface_bwd.cu',
-            'monarch_cuda/monarch_cuda_interface_bwd_complex.cu',
-            'monarch_cuda/monarch_cuda_interface_bwd_bf16.cu',
-            'monarch_cuda/monarch_cuda_interface_bwd_bf16_complex.cu',
-            'monarch_cuda/monarch_cuda_interface_bwd_r2r.cu',
-            'monarch_cuda/monarch_cuda_interface_bwd_r2r_bf16.cu',
-            'butterfly/butterfly_cuda.cu',
-            'butterfly/butterfly_padded_cuda.cu',
-            'butterfly/butterfly_padded_cuda_bf16.cu',
-            'butterfly/butterfly_ifft_cuda.cu',
-            'butterfly/butterfly_cuda_bf16.cu',
-            'butterfly/butterfly_ifft_cuda_bf16.cu',
-            'butterfly/butterfly_padded_ifft_cuda.cu',
-            'butterfly/butterfly_padded_ifft_cuda_bf16.cu',
-            'conv1d/conv1d_bhl.cu',
-            'conv1d/conv1d_blh.cu',
-            'conv1d/conv1d_bwd_cuda_bhl.cu',
-            'conv1d/conv1d_bwd_cuda_blh.cu',
-        ],
-        extra_compile_args={'cxx': ['-O3'],
-                            'nvcc': append_nvcc_threads(['-O3', '-lineinfo', '--use_fast_math', '-std=c++17'] + cc_flag)
-        })
+        CUDAExtension(
+            name=EXTENSION_NAME,
+            sources=[
+                'monarch.cpp',
+                'monarch_cuda/monarch_cuda_interface_fwd.cu',
+                'monarch_cuda/monarch_cuda_interface_fwd_complex.cu',
+                'monarch_cuda/monarch_cuda_interface_fwd_bf16.cu',
+                'monarch_cuda/monarch_cuda_interface_fwd_bf16_complex.cu',
+                'monarch_cuda/monarch_cuda_interface_fwd_r2r.cu',
+                'monarch_cuda/monarch_cuda_interface_fwd_r2r_bf16.cu',
+                'monarch_cuda/monarch_cuda_interface_bwd.cu',
+                'monarch_cuda/monarch_cuda_interface_bwd_complex.cu',
+                'monarch_cuda/monarch_cuda_interface_bwd_bf16.cu',
+                'monarch_cuda/monarch_cuda_interface_bwd_bf16_complex.cu',
+                'monarch_cuda/monarch_cuda_interface_bwd_r2r.cu',
+                'monarch_cuda/monarch_cuda_interface_bwd_r2r_bf16.cu',
+                'butterfly/butterfly_cuda.cu',
+                'butterfly/butterfly_padded_cuda.cu',
+                'butterfly/butterfly_padded_cuda_bf16.cu',
+                'butterfly/butterfly_ifft_cuda.cu',
+                'butterfly/butterfly_cuda_bf16.cu',
+                'butterfly/butterfly_ifft_cuda_bf16.cu',
+                'butterfly/butterfly_padded_ifft_cuda.cu',
+                'butterfly/butterfly_padded_ifft_cuda_bf16.cu',
+                'conv1d/conv1d_bhl.cu',
+                'conv1d/conv1d_blh.cu',
+                'conv1d/conv1d_bwd_cuda_bhl.cu',
+                'conv1d/conv1d_bwd_cuda_blh.cu',
+            ],
+            extra_compile_args=(
+                {
+                    'cxx': ['-O3'],
+                    'nvcc': append_nvcc_threads(
+                        ['-O3', '-lineinfo', '--use_fast_math', '-std=c++17']
+                        + build_compiler_flags(ptx=True, sass=False, multi_arch=False)
+                    ),
+                }
+            ),
+        ),
     ],
-    cmdclass={
-        'build_ext': BuildExtension
-    },
+    cmdclass={'build_ext': BuildExtension},
     version='0.0.0',
     description='Fast FFT algorithms for convolutions',
     url='https://github.com/HazyResearch/flash-fft-conv',
     author='Dan Fu, Hermann Kumbong',
     author_email='danfu@cs.stanford.edu',
-    license='Apache 2.0')
+    license='Apache 2.0'
+)
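
To make the new flag helpers concrete, here is a small usage sketch of the -gencode arguments build_compiler_flags hands to nvcc. The expected values are written out by hand from the code above (code=compute_XX embeds forward-compatible PTX, code=sm_XX embeds native SASS), not captured from an actual build, and the multi-arch case assumes nvcc >= 11.8 is installed:

# Assumes the helpers defined in the patched csrc/flashfftconv/setup.py are importable.

# Default used by the extension: embed PTX for compute capability 8.0 only.
build_compiler_flags(ptx=True, sass=False, multi_arch=False)
# -> ['-gencode', 'arch=compute_80,code=compute_80']

# Multi-arch build with SASS as well; sm_89 (Ada) and sm_90 (Hopper) are added
# only when get_cuda_bare_metal_version reports CUDA >= 11.8.
build_compiler_flags(ptx=True, sass=True, multi_arch=True)
# -> ['-gencode', 'arch=compute_80,code=compute_80',
#     '-gencode', 'arch=compute_80,code=sm_80',
#     '-gencode', 'arch=compute_89,code=compute_89',
#     '-gencode', 'arch=compute_89,code=sm_89',
#     '-gencode', 'arch=compute_90,code=compute_90',
#     '-gencode', 'arch=compute_90,code=sm_90']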