Reduces container size, fixes Tripy dependencies, bumps version (#62)
- Greatly reduces the size of the development container by using a
smaller base image and cutting out many unnecessary packages.
    
- Adds TensorRT as a direct dependency of Tripy, since `plugin.py` needs to
import it to determine the major version of `libnvinfer_plugin` to load.
    
- Removes CHANGELOG as we can automatically generate release notes via
GitHub releases.
    
- Removes JAX as a test dependency and the tests that use it. We were
previously using it to test interoperability using DLPack, but we
already test that with torch/cupy/numpy. Adding JAX provides no
additional coverage and unnecessarily bloats the container.

- Bumps version to 0.0.2
pranavm-nvidia authored Aug 15, 2024
1 parent 5329679 commit 09d5c95
Showing 11 changed files with 35 additions and 115 deletions.
6 changes: 0 additions & 6 deletions tripy/CHANGELOG.md

This file was deleted.

32 changes: 7 additions & 25 deletions tripy/Dockerfile
@@ -1,6 +1,6 @@
FROM nvcr.io/nvidia/cuda:12.2.2-devel-ubuntu22.04
FROM ubuntu:22.04

LABEL org.opencontainers.image.description Tripy development container
LABEL org.opencontainers.image.description="Tripy development container"

WORKDIR /tripy

@@ -11,42 +11,24 @@ ARG uid=1000
ARG gid=1000
ENV DEBIAN_FRONTEND=noninteractive

ENV LD_LIBRARY_PATH=/usr/local/cuda/lib64/:/usr/local/cuda/targets/x86_64-linux/lib/:/usr/lib/x86_64-linux-gnu/:$LD_LIBRARY_PATH

# MPI is currently required for MLIR-TRT
RUN groupadd -r -f -g ${gid} trtuser && \
useradd -o -r -l -u ${uid} -g ${gid} -ms /bin/bash trtuser && \
usermod -aG sudo trtuser && \
echo 'trtuser:nvidia' | chpasswd && \
mkdir -p /workspace && chown trtuser /workspace && \
apt-get update && \
apt-get install -y software-properties-common sudo fakeroot python3-pip gdb git wget libcudnn8 curl jq libopenmpi3 libopenmpi-dev && \
mkdir -p /workspace && chown trtuser /workspace

RUN apt-get update && \
apt-get install -y sudo python3 python3-pip gdb git wget curl && \
apt-get clean && \
python3 -m pip install --upgrade pip

# Copy your .lldbinit file into the home directory of the root user
COPY .lldbinit /root/

# Install the recommended version of TensorRT for development.
RUN cd /usr/lib/ && \
wget -q https://developer.nvidia.com/downloads/compute/machine-learning/tensorrt/10.1.0/tars/TensorRT-10.1.0.27.Linux.x86_64-gnu.cuda-12.4.tar.gz && \
tar -xzf TensorRT-10.1.0.27.Linux.x86_64-gnu.cuda-12.4.tar.gz && \
rm TensorRT-10.1.0.27.Linux.x86_64-gnu.cuda-12.4.tar.gz && \
rm -rf /usr/lib/TensorRT-10.1.0.27/data/ /usr/lib/TensorRT-10.1.0.27/doc/ /usr/lib/TensorRT-10.1.0.27/samples /usr/lib/TensorRT-10.1.0.27/bin /usr/lib/TensorRT-10.1.0.27/python
ENV LD_LIBRARY_PATH=/usr/lib/TensorRT-10.1.0.27/lib/:$LD_LIBRARY_PATH

COPY pyproject.toml /tripy/pyproject.toml

RUN pip install build .[docs,dev,test] \
-f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html \
-f https://nvidia.github.io/TensorRT-Incubator/packages.html \
--extra-index-url https://download.pytorch.org/whl


########################################
# Configure mlir-tensorrt packages
########################################

# Install lldb for debugging purposes in the Tripy container.
# The LLVM version should correspond to the LLVM_VERSION specified in https://github.com/NVIDIA/TensorRT-Incubator/blob/main/mlir-tensorrt/build_tools/docker/Dockerfile#L30.
ARG LLVM_VERSION=17
@@ -64,4 +46,4 @@ RUN echo "deb http://apt.llvm.org/jammy/ llvm-toolchain-jammy-$LLVM_VERSION main
ln -s /usr/bin/lldb-17 /usr/bin/lldb

# Export tripy into the PYTHONPATH so it doesn't need to be installed after making changes
ENV PYTHONPATH=/tripy:$PYTHONPATH
ENV PYTHONPATH=/tripy
20 changes: 12 additions & 8 deletions tripy/pyproject.toml
@@ -1,14 +1,15 @@
[project]
name = "tripy"
version = "0.0.1"
version = "0.0.2"
authors = [{name = "NVIDIA", email="svc_tensorrt@nvidia.com"}]
description = "Tripy: A Python Programming Model For TensorRT"
readme = "README.md"
requires-python = ">= 3.9"
license = {text = "Apache 2.0"}
dependencies = [
"mlir-tensorrt-compiler==0.1.29+cuda12.trt102",
"mlir-tensorrt-runtime==0.1.29+cuda12.trt102",
"tensorrt~=10.0",
"mlir-tensorrt-compiler==0.1.31+cuda12.trt102",
"mlir-tensorrt-runtime==0.1.31+cuda12.trt102",
"colored==2.2.3",
]

@@ -25,10 +26,14 @@ build-backend = "setuptools.build_meta"

[project.optional-dependencies]
dev = [
"jinja2==3.1.2",
"pre-commit==3.6.0",
]
doc_test_common = [
"torch==2.4.0+cu121",
"numpy==1.25.0",
# cupy requires NVRTC but does not specify it as a package dependency
"nvidia-cuda-nvrtc-cu12",
"cupy-cuda12x",
"pre-commit==3.6.0",
]
docs = [
"sphinx==7.2.6",
@@ -38,16 +43,15 @@ docs = [
"docutils==0.20.1",
"myst-parser==2.0.0",
"sphinxcontrib-mermaid==0.9.2",

"tripy[doc_test_common]",
]
test = [
"torch==2.0.0+cu118",
"pytest==7.1.3",
"pytest-virtualenv==1.7.0",
"pytest-cov==4.1.0",
"jax[cuda12_local]==0.4.23",
"coverage==7.4.1",
"vulture==2.11",
"tripy[doc_test_common]",
]

[tool.black]
9 changes: 0 additions & 9 deletions tripy/tests/common/test_array.py
@@ -15,11 +15,8 @@
# limitations under the License.
#

from typing import Any, List

import cupy as cp
import jax
import jax.numpy as jnp
import numpy as np
import pytest
import torch
@@ -51,12 +48,6 @@
# Extend the data list for Torch GPU tensors
data_list.extend([torch.tensor(data).to(torch.device("cuda")) for data in filter(torch_type_supported, np_data)])

# Extend the data list for Jax CPU arrays
data_list.extend([jax.device_put(jnp.array(data), jax.devices("cpu")[0]) for data in np_data])

# Extend the data list for Jax GPU arrays
data_list.extend([jax.device_put(jnp.array(data), jax.devices("cuda")[0]) for data in np_data])


class TestArray:
@pytest.mark.parametrize("input_data", data_list, ids=lambda data: f"{type(data).__qualname__}")
21 changes: 1 addition & 20 deletions tripy/tests/common/test_utils.py
@@ -21,8 +21,6 @@
from textwrap import dedent

import cupy as cp
import numpy as np
import jax.numpy as jnp
import torch

import tripy.common.datatype
@@ -89,18 +87,7 @@ def test_convert_frontend_dtype_to_tripy_dtype():
torch.float32: tripy.common.datatype.float32,
}

JAX_TO_TRIPY = {
jnp.bool_: tripy.common.datatype.bool,
jnp.int8: tripy.common.datatype.int8,
jnp.int32: tripy.common.datatype.int32,
jnp.int64: tripy.common.datatype.int64,
jnp.float8_e4m3fn: tripy.common.datatype.float8,
jnp.float16: tripy.common.datatype.float16,
jnp.bfloat16: tripy.common.datatype.bfloat16,
jnp.float32: tripy.common.datatype.float32,
}

FRONTEND_TO_TRIPY = dict(ChainMap(PYTHON_NATIVE_TO_TRIPY, NUMPY_TO_TRIPY, TORCH_TO_TRIPY, JAX_TO_TRIPY))
FRONTEND_TO_TRIPY = dict(ChainMap(PYTHON_NATIVE_TO_TRIPY, NUMPY_TO_TRIPY, TORCH_TO_TRIPY))

for frontend_type, tripy_type in FRONTEND_TO_TRIPY.items():
assert convert_frontend_dtype_to_tripy_dtype(frontend_type) == tripy_type
@@ -119,12 +106,6 @@ def test_convert_frontend_dtype_to_tripy_dtype():
cp.float64,
torch.int16,
torch.float64,
jnp.int4,
jnp.int16,
jnp.uint16,
jnp.uint32,
jnp.uint64,
jnp.float64,
]:
with helper.raises(
TripyException,
7 changes: 0 additions & 7 deletions tripy/tests/frontend/test_tensor.py
@@ -19,7 +19,6 @@
import sys

import cupy as cp
import jax
import numpy as np
import pytest
import torch
@@ -169,12 +168,6 @@ def test_dlpack_torch(self, kind):
b = torch.from_dlpack(a)
assert torch.equal(b.cpu(), torch.tensor([1, 2, 3]))

@pytest.mark.parametrize("kind", ["cpu", "gpu"])
def test_dlpack_jax(self, kind):
a = tp.Tensor([1, 2, 3], device=tp.device(kind))
b = jax.dlpack.from_dlpack(a)
assert jax.numpy.array_equal(b, jax.numpy.array([1, 2, 3]))

def test_stack_depth_sanity(self):
# Makes sure STACK_DEPTH_OF_BUILD is correct
a = tp.ones((2, 3))
1 change: 0 additions & 1 deletion tripy/tests/integration/test_dequantize.py
@@ -16,7 +16,6 @@
#

import cupy as cp
import jax.numpy as jnp
import numpy as np
import pytest
import torch
23 changes: 2 additions & 21 deletions tripy/tests/integration/test_functional.py
@@ -16,8 +16,6 @@
#

import cupy as cp
import jax
import jax.numpy as jnp
import numpy as np
import pytest
import torch
@@ -87,15 +85,8 @@ def _test_framework_interoperability(self, data, device):
else:
b = tp.Tensor(torch.tensor(data), device=device)

if device.kind == "gpu":
if isinstance(data, cp.ndarray):
data = data.get()
c = tp.Tensor(jax.device_put(jnp.array(data), jax.devices("gpu")[0]), device=device)
else:
c = tp.Tensor(jax.device_put(jnp.array(data), jax.devices("cpu")[0]), device=device)

out = a + b + c
assert (cp.from_dlpack(out).get() == np.array([3.0, 3.0], dtype=np.float32)).all()
out = a + b
assert (cp.from_dlpack(out).get() == np.array([2.0, 2.0], dtype=np.float32)).all()

def test_cpu_and_gpu_framework_interoperability(self):
self._test_framework_interoperability(np.ones(2, np.float32), device=tp.device("cpu"))
@@ -125,16 +116,6 @@ def _test_round_tripping(self, data, device):
# Below fails as we do allocate a new np array from Torch tensor data.
# assert torch_data_round_tripped.data_ptr == torch_data.data_ptr

# Assert round-tripping for Jax data
if device.kind == "gpu":
if isinstance(data, cp.ndarray):
data = data.get()
jax_orig = jax.device_put(jnp.array(data), jax.devices("gpu")[0])
jax_round_tripped = jnp.array(cp.from_dlpack(tp.Tensor(jax_orig, device=device)).get())
else:
jax_orig = jax.device_put(jnp.array(data), jax.devices("cpu")[0])
jax_round_tripped = jnp.array(np.from_dlpack(tp.Tensor(jax_orig, device=device)))
assert jnp.array_equal(jax_round_tripped, jax_orig)
# (39): Remove explicit CPU to GPU copies. Add memory pointer checks.
# Figure out how to compare two Jax data memory pointers.

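The torch/cupy/numpy DLPack coverage that replaces the removed JAX paths boils down to the remaining test above. A minimal standalone sketch of the same round trip follows; the host copy via `data.get()` before `torch.tensor` is an assumption made to keep the example self-contained:

```python
# Minimal sketch of the DLPack interoperability still under test: numpy,
# cupy, and torch all exchange buffers with tripy through the same protocol
# path that JAX exercised, so dropping JAX loses no coverage.
import cupy as cp
import numpy as np
import torch

import tripy as tp

device = tp.device("gpu")
data = cp.ones(2, dtype=cp.float32)

a = tp.Tensor(data, device=device)                      # from cupy
b = tp.Tensor(torch.tensor(data.get()), device=device)  # from torch, via a host copy

out = a + b
# Round-trip the result back out through DLPack and compare in numpy.
assert (cp.from_dlpack(out).get() == np.array([2.0, 2.0], dtype=np.float32)).all()
```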
19 changes: 4 additions & 15 deletions tripy/tests/integration/test_quantize.py
@@ -16,15 +16,12 @@
#

import cupy as cp
import jax.numpy as jnp
import jaxlib
import numpy as np
import pytest
import re
import torch

import tripy as tp
from tripy import TripyException
from tests.helper import raises, TORCH_DTYPES
from tests.conftest import skip_if_older_than_sm80, skip_if_older_than_sm89

@@ -68,13 +65,9 @@ def test_quantize_fp8_per_tensor(self, dtype):
expected = (input / scale).to(dtype=torch.float32)
with raises(
Exception,
match=re.escape(
"UNIMPLEMENTED: Invalid or unsupported DLPack float width: 8 bits"
),
match=re.escape("UNIMPLEMENTED: Invalid or unsupported DLPack float width: 8 bits"),
):
assert torch.equal(
expected, torch.from_dlpack(jnp.from_dlpack(quantized)).to(dtype=torch.float32).to("cpu")
)
assert torch.equal(expected, torch.from_dlpack(quantized).to(dtype=torch.float32).to("cpu"))
assert torch.equal(expected, torch.from_dlpack(tp.cast(quantized, dtype=tp.float32)).to("cpu"))

@pytest.mark.parametrize(
@@ -91,13 +84,9 @@ def test_quantize_fp8_per_channel(self, dtype):
expected = (input / scale.reshape(2, 1)).to(dtype=torch.float32)
with raises(
Exception,
match=re.escape(
"UNIMPLEMENTED: Invalid or unsupported DLPack float width: 8 bits"
),
match=re.escape("UNIMPLEMENTED: Invalid or unsupported DLPack float width: 8 bits"),
):
assert torch.equal(
expected, torch.from_dlpack(jnp.from_dlpack(quantized)).to(dtype=torch.float32).to("cpu")
)
assert torch.equal(expected, torch.from_dlpack(quantized).to(dtype=torch.float32).to("cpu"))
assert torch.equal(expected, torch.from_dlpack(tp.cast(quantized, dtype=tp.float32)).to("cpu"))

@pytest.mark.parametrize(
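The assertions above encode a DLPack limitation: 8-bit float widths cannot be exported, so the fp8 result is cast to float32 on the Tripy side before `torch.from_dlpack`. A minimal sketch of that workaround, assuming a `tp.quantize(input, scale, dtype)` call like the one these tests are built around:

```python
# Hedged sketch: exporting fp8 through DLPack fails with "Invalid or
# unsupported DLPack float width: 8 bits", so cast to float32 in tripy first.
import torch

import tripy as tp

input = torch.tensor([1.0, 2.0], dtype=torch.float32)
scale = 0.5

quantized = tp.quantize(tp.Tensor(input), scale, tp.float8)  # assumed signature
out = torch.from_dlpack(tp.cast(quantized, dtype=tp.float32)).to("cpu")
assert torch.equal(out, input / scale)
```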
5 changes: 4 additions & 1 deletion tripy/tripy/__init__.py
@@ -15,7 +15,10 @@
# limitations under the License.
#

__version__ = "0.0.1"
__version__ = "0.0.2"

# Import TensorRT to make sure all dependent libraries are loaded first.
import tensorrt

# export.public_api() will expose things here. To make sure that happens, we just need to
# import all the submodules so that the decorator is actually executed.
7 changes: 5 additions & 2 deletions tripy/tripy/flat_ir/ops/plugin.py
@@ -25,15 +25,18 @@
from mlir_tensorrt.compiler.dialects import tensorrt

from tripy import utils
from tripy.common.exception import raise_error
from tripy.flat_ir.ops.base import BaseFlatIROp
from tripy.utils import Result


@utils.call_once
def initialize_plugin_registry():
import tensorrt as trt

major_version, _, _ = trt.__version__.partition(".")

# TODO (#191): Make this work on Windows too
handle = ctypes.CDLL("libnvinfer_plugin.so")
handle = ctypes.CDLL(f"libnvinfer_plugin.so.{major_version}")
handle.initLibNvInferPlugins(None, "")


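Standalone, the new plugin-loading scheme reduces to the sketch below: import `tensorrt` to discover the installed major version, then dlopen the matching versioned soname (the unversioned `libnvinfer_plugin.so` name is not guaranteed to exist outside a full TensorRT development install):

```python
# Standalone form of the loading logic above: derive the plugin library's
# soname suffix from the installed TensorRT version, then register the
# built-in TensorRT plugins under the empty namespace.
import ctypes

import tensorrt as trt  # importing also ensures TRT's own libraries are loaded

major_version, _, _ = trt.__version__.partition(".")
handle = ctypes.CDLL(f"libnvinfer_plugin.so.{major_version}")
handle.initLibNvInferPlugins(None, "")
```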