From 09d5c9545a267cc0615ef8b410f0753b0d2ebdb4 Mon Sep 17 00:00:00 2001
From: pranavm-nvidia <49246958+pranavm-nvidia@users.noreply.github.com>
Date: Thu, 15 Aug 2024 00:36:16 -0400
Subject: [PATCH] Reduces container size, fixes Tripy dependencies, bumps version (#62)

- Greatly reduces the size of the development container by using a smaller
  base image and cutting out many unnecessary packages.
- Adds TRT as a direct dependency of Tripy, since `plugin.py` needs to load it
  to determine the major version of `libnvinfer_plugin` (a standalone sketch
  of this logic follows the patch).
- Removes CHANGELOG, as we can automatically generate release notes via
  GitHub releases.
- Removes JAX as a test dependency, along with the tests that used it. We
  previously used it to test DLPack interoperability, but we already cover
  that with torch/cupy/numpy (see the example after the patch), so JAX adds
  no coverage and unnecessarily bloats the container.
- Bumps version to 0.0.2
---
 tripy/CHANGELOG.md                         |  6 ----
 tripy/Dockerfile                           | 32 +++++-----------
 tripy/pyproject.toml                       | 20 ++++++++------
 tripy/tests/common/test_array.py           |  9 ------
 tripy/tests/common/test_utils.py           | 21 +-------------
 tripy/tests/frontend/test_tensor.py        |  7 -----
 tripy/tests/integration/test_dequantize.py |  1 -
 tripy/tests/integration/test_functional.py | 23 ++--------------
 tripy/tests/integration/test_quantize.py   | 19 +++----------
 tripy/tripy/__init__.py                    |  5 +++-
 tripy/tripy/flat_ir/ops/plugin.py          |  7 +++--
 11 files changed, 35 insertions(+), 115 deletions(-)
 delete mode 100644 tripy/CHANGELOG.md

diff --git a/tripy/CHANGELOG.md b/tripy/CHANGELOG.md
deleted file mode 100644
index f3b49d204..000000000
--- a/tripy/CHANGELOG.md
+++ /dev/null
@@ -1,6 +0,0 @@
-# Tripy Change Log
-
-Dates are in YYYY-MM-DD format.
-
-## 0.0.1 ()
-- Initial Release
diff --git a/tripy/Dockerfile b/tripy/Dockerfile
index 96bfcb5dc..863f5484c 100644
--- a/tripy/Dockerfile
+++ b/tripy/Dockerfile
@@ -1,6 +1,6 @@
-FROM nvcr.io/nvidia/cuda:12.2.2-devel-ubuntu22.04
+FROM ubuntu:22.04
 
-LABEL org.opencontainers.image.description Tripy development container
+LABEL org.opencontainers.image.description="Tripy development container"
 
 WORKDIR /tripy
 
@@ -11,42 +11,24 @@ ARG uid=1000
 ARG gid=1000
 
 ENV DEBIAN_FRONTEND=noninteractive
-ENV LD_LIBRARY_PATH=/usr/local/cuda/lib64/:/usr/local/cuda/targets/x86_64-linux/lib/:/usr/lib/x86_64-linux-gnu/:$LD_LIBRARY_PATH
-
-# MPI is currently required for MLIR-TRT
 RUN groupadd -r -f -g ${gid} trtuser && \
     useradd -o -r -l -u ${uid} -g ${gid} -ms /bin/bash trtuser && \
     usermod -aG sudo trtuser && \
     echo 'trtuser:nvidia' | chpasswd && \
-    mkdir -p /workspace && chown trtuser /workspace && \
-    apt-get update && \
-    apt-get install -y software-properties-common sudo fakeroot python3-pip gdb git wget libcudnn8 curl jq libopenmpi3 libopenmpi-dev && \
+    mkdir -p /workspace && chown trtuser /workspace
+
+RUN apt-get update && \
+    apt-get install -y sudo python3 python3-pip gdb git wget curl && \
     apt-get clean && \
     python3 -m pip install --upgrade pip
 
-# Copy your .lldbinit file into the home directory of the root user
 COPY .lldbinit /root/
-
-# Install the recommended version of TensorRT for development.
-RUN cd /usr/lib/ && \
-    wget -q https://developer.nvidia.com/downloads/compute/machine-learning/tensorrt/10.1.0/tars/TensorRT-10.1.0.27.Linux.x86_64-gnu.cuda-12.4.tar.gz && \
-    tar -xzf TensorRT-10.1.0.27.Linux.x86_64-gnu.cuda-12.4.tar.gz && \
-    rm TensorRT-10.1.0.27.Linux.x86_64-gnu.cuda-12.4.tar.gz && \
-    rm -rf /usr/lib/TensorRT-10.1.0.27/data/ /usr/lib/TensorRT-10.1.0.27/doc/ /usr/lib/TensorRT-10.1.0.27/samples /usr/lib/TensorRT-10.1.0.27/bin /usr/lib/TensorRT-10.1.0.27/python
-ENV LD_LIBRARY_PATH=/usr/lib/TensorRT-10.1.0.27/lib/:$LD_LIBRARY_PATH
-
 COPY pyproject.toml /tripy/pyproject.toml
 RUN pip install build .[docs,dev,test] \
-    -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html \
     -f https://nvidia.github.io/TensorRT-Incubator/packages.html \
     --extra-index-url https://download.pytorch.org/whl
-
-########################################
-# Configure mlir-tensorrt packages
-########################################
-
 # Install lldb for debugging purposes in Tripy container.
 # The LLVM version should correspond to the LLVM_VERSION specified in https://github.com/NVIDIA/TensorRT-Incubator/blob/main/mlir-tensorrt/build_tools/docker/Dockerfile#L30.
 ARG LLVM_VERSION=17
@@ -64,4 +46,4 @@ RUN echo "deb http://apt.llvm.org/jammy/ llvm-toolchain-jammy-$LLVM_VERSION main
     ln -s /usr/bin/lldb-17 /usr/bin/lldb
 
 # Export tripy into the PYTHONPATH so it doesn't need to be installed after making changes
-ENV PYTHONPATH=/tripy:$PYTHONPATH
+ENV PYTHONPATH=/tripy
diff --git a/tripy/pyproject.toml b/tripy/pyproject.toml
index 9b9207127..af413eeb8 100644
--- a/tripy/pyproject.toml
+++ b/tripy/pyproject.toml
@@ -1,14 +1,15 @@
 [project]
 name = "tripy"
-version = "0.0.1"
+version = "0.0.2"
 authors = [{name = "NVIDIA", email="svc_tensorrt@nvidia.com"}]
 description = "Tripy: A Python Programming Model For TensorRT"
 readme = "README.md"
 requires-python = ">= 3.9"
 license = {text = "Apache 2.0"}
 dependencies = [
-    "mlir-tensorrt-compiler==0.1.29+cuda12.trt102",
-    "mlir-tensorrt-runtime==0.1.29+cuda12.trt102",
+    "tensorrt~=10.0",
+    "mlir-tensorrt-compiler==0.1.31+cuda12.trt102",
+    "mlir-tensorrt-runtime==0.1.31+cuda12.trt102",
     "colored==2.2.3",
 ]
 
@@ -25,10 +26,14 @@ build-backend = "setuptools.build_meta"
 
 [project.optional-dependencies]
 dev = [
-    "jinja2==3.1.2",
+    "pre-commit==3.6.0",
+]
+doc_test_common = [
+    "torch==2.4.0+cu121",
     "numpy==1.25.0",
+    # cupy requires NVRTC but does not specify it as a package dependency
+    "nvidia-cuda-nvrtc-cu12",
     "cupy-cuda12x",
-    "pre-commit==3.6.0",
 ]
 docs = [
     "sphinx==7.2.6",
@@ -38,16 +43,15 @@ docs = [
     "docutils==0.20.1",
     "myst-parser==2.0.0",
     "sphinxcontrib-mermaid==0.9.2",
-
+    "tripy[doc_test_common]",
 ]
 test = [
-    "torch==2.0.0+cu118",
     "pytest==7.1.3",
     "pytest-virtualenv==1.7.0",
     "pytest-cov==4.1.0",
-    "jax[cuda12_local]==0.4.23",
     "coverage==7.4.1",
     "vulture==2.11",
+    "tripy[doc_test_common]",
 ]
 
 [tool.black]
diff --git a/tripy/tests/common/test_array.py b/tripy/tests/common/test_array.py
index b5b0b15d7..d0c4007d0 100644
--- a/tripy/tests/common/test_array.py
+++ b/tripy/tests/common/test_array.py
@@ -15,11 +15,8 @@
 # limitations under the License.
 #
 
-from typing import Any, List
 
 import cupy as cp
-import jax
-import jax.numpy as jnp
 import numpy as np
 import pytest
 import torch
@@ -51,12 +48,6 @@
 # Extend the data list for Torch GPU tensors
 data_list.extend([torch.tensor(data).to(torch.device("cuda")) for data in filter(torch_type_supported, np_data)])
 
-# Extend the data list for Jax CPU arrays
-data_list.extend([jax.device_put(jnp.array(data), jax.devices("cpu")[0]) for data in np_data])
-
-# Extend the data list for Jax GPU arrays
-data_list.extend([jax.device_put(jnp.array(data), jax.devices("cuda")[0]) for data in np_data])
-
 
 class TestArray:
     @pytest.mark.parametrize("input_data", data_list, ids=lambda data: f"{type(data).__qualname__}")
diff --git a/tripy/tests/common/test_utils.py b/tripy/tests/common/test_utils.py
index a80ca8dad..8e5fafd55 100644
--- a/tripy/tests/common/test_utils.py
+++ b/tripy/tests/common/test_utils.py
@@ -21,8 +21,6 @@
 from textwrap import dedent
 
 import cupy as cp
-import numpy as np
-import jax.numpy as jnp
 import torch
 
 import tripy.common.datatype
@@ -89,18 +87,7 @@ def test_convert_frontend_dtype_to_tripy_dtype():
         torch.float32: tripy.common.datatype.float32,
     }
 
-    JAX_TO_TRIPY = {
-        jnp.bool_: tripy.common.datatype.bool,
-        jnp.int8: tripy.common.datatype.int8,
-        jnp.int32: tripy.common.datatype.int32,
-        jnp.int64: tripy.common.datatype.int64,
-        jnp.float8_e4m3fn: tripy.common.datatype.float8,
-        jnp.float16: tripy.common.datatype.float16,
-        jnp.bfloat16: tripy.common.datatype.bfloat16,
-        jnp.float32: tripy.common.datatype.float32,
-    }
-
-    FRONTEND_TO_TRIPY = dict(ChainMap(PYTHON_NATIVE_TO_TRIPY, NUMPY_TO_TRIPY, TORCH_TO_TRIPY, JAX_TO_TRIPY))
+    FRONTEND_TO_TRIPY = dict(ChainMap(PYTHON_NATIVE_TO_TRIPY, NUMPY_TO_TRIPY, TORCH_TO_TRIPY))
 
     for frontend_type, tripy_type in FRONTEND_TO_TRIPY.items():
         assert convert_frontend_dtype_to_tripy_dtype(frontend_type) == tripy_type
@@ -119,12 +106,6 @@ def test_convert_frontend_dtype_to_tripy_dtype():
         cp.float64,
         torch.int16,
         torch.float64,
-        jnp.int4,
-        jnp.int16,
-        jnp.uint16,
-        jnp.uint32,
-        jnp.uint64,
-        jnp.float64,
     ]:
         with helper.raises(
             TripyException,
diff --git a/tripy/tests/frontend/test_tensor.py b/tripy/tests/frontend/test_tensor.py
index 007667062..d1edc726b 100644
--- a/tripy/tests/frontend/test_tensor.py
+++ b/tripy/tests/frontend/test_tensor.py
@@ -19,7 +19,6 @@
 import sys
 
 import cupy as cp
-import jax
 import numpy as np
 import pytest
 import torch
@@ -169,12 +168,6 @@ def test_dlpack_torch(self, kind):
         b = torch.from_dlpack(a)
         assert torch.equal(b.cpu(), torch.tensor([1, 2, 3]))
 
-    @pytest.mark.parametrize("kind", ["cpu", "gpu"])
-    def test_dlpack_jax(self, kind):
-        a = tp.Tensor([1, 2, 3], device=tp.device(kind))
-        b = jax.dlpack.from_dlpack(a)
-        assert jax.numpy.array_equal(b, jax.numpy.array([1, 2, 3]))
-
     def test_stack_depth_sanity(self):
         # Makes sure STACK_DEPTH_OF_BUILD is correct
         a = tp.ones((2, 3))
diff --git a/tripy/tests/integration/test_dequantize.py b/tripy/tests/integration/test_dequantize.py
index 7223e3429..7f95049dc 100644
--- a/tripy/tests/integration/test_dequantize.py
+++ b/tripy/tests/integration/test_dequantize.py
@@ -16,7 +16,6 @@
 #
 
 import cupy as cp
-import jax.numpy as jnp
 import numpy as np
 import pytest
 import torch
diff --git a/tripy/tests/integration/test_functional.py b/tripy/tests/integration/test_functional.py
index c199954c2..42877be57 100644
--- a/tripy/tests/integration/test_functional.py
+++ b/tripy/tests/integration/test_functional.py
@@ -16,8 +16,6 @@
 #
 
 import cupy as cp
-import jax
-import jax.numpy as jnp
 import numpy as np
 import pytest
 import torch
@@ -87,15 +85,8 @@ def _test_framework_interoperability(self, data, device):
         else:
             b = tp.Tensor(torch.tensor(data), device=device)
 
-        if device.kind == "gpu":
-            if isinstance(data, cp.ndarray):
-                data = data.get()
-            c = tp.Tensor(jax.device_put(jnp.array(data), jax.devices("gpu")[0]), device=device)
-        else:
-            c = tp.Tensor(jax.device_put(jnp.array(data), jax.devices("cpu")[0]), device=device)
-
-        out = a + b + c
-        assert (cp.from_dlpack(out).get() == np.array([3.0, 3.0], dtype=np.float32)).all()
+        out = a + b
+        assert (cp.from_dlpack(out).get() == np.array([2.0, 2.0], dtype=np.float32)).all()
 
     def test_cpu_and_gpu_framework_interoperability(self):
         self._test_framework_interoperability(np.ones(2, np.float32), device=tp.device("cpu"))
@@ -125,16 +116,6 @@ def _test_round_tripping(self, data, device):
         # Below fails as we do allocate a new np array from Torch tensor data.
         # assert torch_data_round_tripped.data_ptr == torch_data.data_ptr
 
-        # Assert round-tripping for Jax data
-        if device.kind == "gpu":
-            if isinstance(data, cp.ndarray):
-                data = data.get()
-            jax_orig = jax.device_put(jnp.array(data), jax.devices("gpu")[0])
-            jax_round_tripped = jnp.array(cp.from_dlpack(tp.Tensor(jax_orig, device=device)).get())
-        else:
-            jax_orig = jax.device_put(jnp.array(data), jax.devices("cpu")[0])
-            jax_round_tripped = jnp.array(np.from_dlpack(tp.Tensor(jax_orig, device=device)))
-        assert jnp.array_equal(jax_round_tripped, jax_orig)
 
         # (39): Remove explicit CPU to GPU copies. Add memory pointer checks.
         # Figure out how to compare two Jax data memory pointers.
diff --git a/tripy/tests/integration/test_quantize.py b/tripy/tests/integration/test_quantize.py
index 5fc4856ab..8a918dcea 100644
--- a/tripy/tests/integration/test_quantize.py
+++ b/tripy/tests/integration/test_quantize.py
@@ -16,15 +16,12 @@
 #
 
 import cupy as cp
-import jax.numpy as jnp
-import jaxlib
 import numpy as np
 import pytest
 import re
 import torch
 
 import tripy as tp
-from tripy import TripyException
 from tests.helper import raises, TORCH_DTYPES
 from tests.conftest import skip_if_older_than_sm80, skip_if_older_than_sm89
 
@@ -68,13 +65,9 @@ def test_quantize_fp8_per_tensor(self, dtype):
         expected = (input / scale).to(dtype=torch.float32)
         with raises(
             Exception,
-            match=re.escape(
-                "UNIMPLEMENTED: Invalid or unsupported DLPack float width: 8 bits"
-            ),
+            match=re.escape("UNIMPLEMENTED: Invalid or unsupported DLPack float width: 8 bits"),
         ):
-            assert torch.equal(
-                expected, torch.from_dlpack(jnp.from_dlpack(quantized)).to(dtype=torch.float32).to("cpu")
-            )
+            assert torch.equal(expected, torch.from_dlpack(quantized).to(dtype=torch.float32).to("cpu"))
         assert torch.equal(expected, torch.from_dlpack(tp.cast(quantized, dtype=tp.float32)).to("cpu"))
 
     @pytest.mark.parametrize(
@@ -91,13 +84,9 @@ def test_quantize_fp8_per_channel(self, dtype):
         expected = (input / scale.reshape(2, 1)).to(dtype=torch.float32)
         with raises(
             Exception,
-            match=re.escape(
-                "UNIMPLEMENTED: Invalid or unsupported DLPack float width: 8 bits"
-            ),
+            match=re.escape("UNIMPLEMENTED: Invalid or unsupported DLPack float width: 8 bits"),
         ):
-            assert torch.equal(
-                expected, torch.from_dlpack(jnp.from_dlpack(quantized)).to(dtype=torch.float32).to("cpu")
-            )
+            assert torch.equal(expected, torch.from_dlpack(quantized).to(dtype=torch.float32).to("cpu"))
         assert torch.equal(expected, torch.from_dlpack(tp.cast(quantized, dtype=tp.float32)).to("cpu"))
 
     @pytest.mark.parametrize(
diff --git a/tripy/tripy/__init__.py b/tripy/tripy/__init__.py
index 423088cca..de006d885 100644
--- a/tripy/tripy/__init__.py
+++ b/tripy/tripy/__init__.py
@@ -15,7 +15,10 @@
 # limitations under the License.
 #
 
-__version__ = "0.0.1"
+__version__ = "0.0.2"
+
+# Import TensorRT to make sure all dependent libraries are loaded first.
+import tensorrt
 
 # export.public_api() will expose things here. To make sure that happens, we just need to
 # import all the submodules so that the decorator is actually executed.
diff --git a/tripy/tripy/flat_ir/ops/plugin.py b/tripy/tripy/flat_ir/ops/plugin.py
index 5d89cb9f7..24241ad2a 100644
--- a/tripy/tripy/flat_ir/ops/plugin.py
+++ b/tripy/tripy/flat_ir/ops/plugin.py
@@ -25,15 +25,18 @@
 from mlir_tensorrt.compiler.dialects import tensorrt
 
 from tripy import utils
-from tripy.common.exception import raise_error
 from tripy.flat_ir.ops.base import BaseFlatIROp
 from tripy.utils import Result
 
 
 @utils.call_once
 def initialize_plugin_registry():
+    import tensorrt as trt
+
+    major_version, _, _ = trt.__version__.partition(".")
+
     # TODO (#191): Make this work on Windows too
-    handle = ctypes.CDLL("libnvinfer_plugin.so")
+    handle = ctypes.CDLL(f"libnvinfer_plugin.so.{major_version}")
     handle.initLibNvInferPlugins(None, "")
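
Note: the versioned plugin load that `initialize_plugin_registry` now performs
can be exercised standalone. The following is a minimal sketch mirroring the
patched code (not part of the patch itself), assuming a Linux machine with the
`tensorrt` wheel installed and `libnvinfer_plugin` discoverable by the dynamic
loader:

    import ctypes

    # Importing tensorrt first loads libnvinfer and its dependent libraries,
    # which is also why tripy/__init__.py now imports it up front.
    import tensorrt as trt

    # trt.__version__ is a string like "10.2.0"; str.partition(".") splits off
    # the major version ("10") that the plugin library's soname carries.
    major_version, _, _ = trt.__version__.partition(".")

    # Load e.g. libnvinfer_plugin.so.10 and register all TensorRT plugins
    # under the default (empty) namespace.
    handle = ctypes.CDLL(f"libnvinfer_plugin.so.{major_version}")
    handle.initLibNvInferPlugins(None, "")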
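
For reference, the DLPack interoperability that previously justified the JAX
dependency is still covered by the remaining torch/cupy/numpy tests, along the
lines of this sketch (assumes a CUDA-capable machine with those packages and
tripy installed):

    import cupy as cp
    import numpy as np
    import torch

    import tripy as tp

    data = np.ones(2, np.float32)

    # Construct Tripy tensors from two different frontend frameworks.
    a = tp.Tensor(data, device=tp.device("cpu"))
    b = tp.Tensor(torch.tensor(data), device=tp.device("cpu"))

    out = a + b

    # Round-trip the result into cupy via DLPack and compare on the host.
    assert (cp.from_dlpack(out).get() == np.array([2.0, 2.0], dtype=np.float32)).all()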