Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Expose the out parameter to avoid malloc of output array #147

Merged
merged 2 commits into from
Sep 27, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 56 additions & 40 deletions bitshuffle/ext.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -121,15 +121,28 @@ def using_AVX512():
return False


def _setup_arr(arr):
def _make_array(shape, dtype, out=None):
    """Return an array of the given shape and dtype, reusing ``out`` if given.

    Parameters
    ----------
    shape : int or tuple of ints
        Shape of the requested array.
    dtype : data-type
        Data type of the requested array.
    out : numpy array, optional
        Already allocated array whose memory should be reused. When ``out``
        is a view of another array, the *whole* underlying base array is
        reused starting from its first byte (not only the region covered by
        the view), so a truncated result of an earlier call can be passed
        back in.

    Returns
    -------
    out : numpy array
        Freshly allocated array when ``out`` is None, otherwise a view of
        the reused buffer with shape *shape* and data type *dtype*.

    Raises
    ------
    ValueError
        If the reused buffer holds fewer bytes than the requested array.

    """
    if out is None:
        return np.empty(shape, dtype=dtype)

    dtype = np.dtype(dtype)
    # np.prod(()) is the float 1.0, which is not a valid slice bound; force
    # an integer element count so 0-d shapes work as well.
    size = int(np.prod(shape, dtype=np.intp))
    nbytes = size * dtype.itemsize

    # ``out.base`` may be a non-array buffer owner (e.g. the object behind
    # ``np.frombuffer``); only follow it when it is itself an array.
    base = out.base if isinstance(out.base, np.ndarray) else out

    # Work on raw bytes so the buffer length does not have to be a multiple
    # of the requested itemsize.
    # NOTE(review): ``ravel`` copies when ``base`` is not C-contiguous, in
    # which case the result would not alias ``out`` -- confirm callers only
    # pass contiguous buffers.
    flat = base.ravel().view(np.uint8)
    if flat.size < nbytes:
        msg = "Output array is too small: %d bytes required, %d available."
        raise ValueError(msg % (nbytes, flat.size))
    return flat[:nbytes].view(dtype).reshape(shape)


def _setup_arr(arr, out=None):
    """Validate the input array and prepare a matching output array.

    Parameters
    ----------
    arr : numpy array
        Input data; must be C-contiguous.
    out : numpy array, optional
        Already allocated array to put the results in. It is handed to
        ``_make_array`` together with the shape and data type of ``arr``.

    Returns
    -------
    out : numpy array
        Output array as returned by ``_make_array``.
    size : int
        Number of elements in ``arr``.
    itemsize : int
        Size in bytes of one element of ``arr``.

    Raises
    ------
    ValueError
        If ``arr`` is not C-contiguous.

    """
    if not arr.flags['C_CONTIGUOUS']:
        msg = "Input array must be C-contiguous."
        raise ValueError(msg)
    dtype = arr.dtype
    # Do not allocate here: an unconditional np.empty() would overwrite the
    # caller-supplied ``out`` before _make_array gets a chance to reuse it.
    out = _make_array(tuple(arr.shape), dtype, out)
    return out, arr.size, dtype.itemsize


Expand Down Expand Up @@ -295,7 +308,7 @@ def untrans_bit_elem(np.ndarray arr not None):

@cython.boundscheck(False)
@cython.wraparound(False)
def bitshuffle(np.ndarray arr not None, int block_size=0):
def bitshuffle(np.ndarray arr not None, int block_size=0, np.ndarray out=None):
"""Bitshuffle an array.

Output array is the same shape and data type as input array but underlying
Expand All @@ -304,11 +317,12 @@ def bitshuffle(np.ndarray arr not None, int block_size=0):
Parameters
----------
arr : numpy array
Data to ne processed.
Data to be processed.
block_size : positive integer
Block size in number of elements. By default, block size is chosen
automatically.

out : numpy array
Already allocated array to put the results in
Returns
-------
out : numpy array
Expand All @@ -318,8 +332,7 @@ def bitshuffle(np.ndarray arr not None, int block_size=0):
"""

cdef int ii, size, itemsize, count=0
cdef np.ndarray out
out, size, itemsize = _setup_arr(arr)
out, size, itemsize = _setup_arr(arr, out)

cdef np.ndarray[dtype=np.uint8_t, ndim=1, mode="c"] arr_flat
arr_flat = arr.view(np.uint8).ravel()
Expand All @@ -340,7 +353,7 @@ def bitshuffle(np.ndarray arr not None, int block_size=0):

@cython.boundscheck(False)
@cython.wraparound(False)
def bitunshuffle(np.ndarray arr not None, int block_size=0):
def bitunshuffle(np.ndarray arr not None, int block_size=0, np.ndarray out=None):
"""Bitshuffle an array.

Output array is the same shape and data type as input array but underlying
Expand All @@ -352,7 +365,9 @@ def bitunshuffle(np.ndarray arr not None, int block_size=0):
Data to be processed.
block_size : positive integer
Block size in number of elements. Must match value used for shuffling.

out : numpy array
Already allocated array to put the results in

Returns
-------
out : numpy array
Expand All @@ -362,8 +377,7 @@ def bitunshuffle(np.ndarray arr not None, int block_size=0):
"""

cdef int ii, size, itemsize, count=0
cdef np.ndarray out
out, size, itemsize = _setup_arr(arr)
out, size, itemsize = _setup_arr(arr, out=out)

cdef np.ndarray[dtype=np.uint8_t, ndim=1, mode="c"] arr_flat
arr_flat = arr.view(np.uint8).ravel()
Expand All @@ -384,7 +398,7 @@ def bitunshuffle(np.ndarray arr not None, int block_size=0):

@cython.boundscheck(False)
@cython.wraparound(False)
def compress_lz4(np.ndarray arr not None, int block_size=0):
def compress_lz4(np.ndarray arr not None, int block_size=0, np.ndarray out=None):
"""Bitshuffle then compress an array using LZ4.

Parameters
Expand All @@ -394,7 +408,9 @@ def compress_lz4(np.ndarray arr not None, int block_size=0):
block_size : positive integer
Block size in number of elements. By default, block size is chosen
automatically.

out : numpy array
Already allocated array to put the results in.

Returns
-------
out : array with np.uint8 data type
Expand All @@ -413,8 +429,7 @@ def compress_lz4(np.ndarray arr not None, int block_size=0):

max_out_size = bshuf_compress_lz4_bound(size, itemsize, block_size)

cdef np.ndarray out
out = np.empty(max_out_size, dtype=np.uint8)
out = _make_array((max_out_size,), dtype=np.uint8, out=out)

cdef np.ndarray[dtype=np.uint8_t, ndim=1, mode="c"] arr_flat
arr_flat = arr.view(np.uint8).ravel()
Expand All @@ -434,7 +449,7 @@ def compress_lz4(np.ndarray arr not None, int block_size=0):

@cython.boundscheck(False)
@cython.wraparound(False)
def decompress_lz4(np.ndarray arr not None, shape, dtype, int block_size=0):
def decompress_lz4(np.ndarray arr not None, shape, dtype, int block_size=0, np.ndarray out=None):
"""Decompress a buffer using LZ4 then bitunshuffle it yielding an array.

Parameters
Expand All @@ -450,7 +465,9 @@ def decompress_lz4(np.ndarray arr not None, shape, dtype, int block_size=0):
block_size : positive integer
Block size in number of elements. Must match value used for
compression.

out : numpy array
Already allocated array to put the results in

Returns
-------
out : numpy array with shape *shape* and data type *dtype*
Expand All @@ -464,9 +481,8 @@ def decompress_lz4(np.ndarray arr not None, shape, dtype, int block_size=0):
raise ValueError(msg)
size = np.prod(shape)
itemsize = dtype.itemsize

cdef np.ndarray out
out = np.empty(tuple(shape), dtype=dtype)

out = _make_array(tuple(shape), dtype=dtype, out=out)

cdef np.ndarray[dtype=np.uint8_t, ndim=1, mode="c"] arr_flat
arr_flat = arr.view(np.uint8).ravel()
Expand All @@ -492,9 +508,9 @@ def decompress_lz4(np.ndarray arr not None, shape, dtype, int block_size=0):
IF ZSTD_SUPPORT:
@cython.boundscheck(False)
@cython.wraparound(False)
def compress_zstd(np.ndarray arr not None, int block_size=0, int comp_lvl=1):
def compress_zstd(np.ndarray arr not None, int block_size=0, int comp_lvl=1, np.ndarray out=None):
"""Bitshuffle then compress an array using ZSTD.

Parameters
----------
arr : numpy array
Expand All @@ -504,14 +520,15 @@ IF ZSTD_SUPPORT:
automatically.
comp_lvl : positive integer
Compression level applied by ZSTD

out : numpy array
Already allocated array to put the results in
Returns
-------
out : array with np.uint8 data type
Buffer holding compressed data.

"""

cdef int ii, size, itemsize, count=0
shape = (arr.shape[i] for i in range(arr.ndim))
if not arr.flags['C_CONTIGUOUS']:
Expand All @@ -520,12 +537,11 @@ IF ZSTD_SUPPORT:
size = arr.size
dtype = arr.dtype
itemsize = dtype.itemsize

max_out_size = bshuf_compress_zstd_bound(size, itemsize, block_size)

cdef np.ndarray out
out = np.empty(max_out_size, dtype=np.uint8)


out = _make_array((max_out_size,), dtype=np.uint8, out=out)

cdef np.ndarray[dtype=np.uint8_t, ndim=1, mode="c"] arr_flat
arr_flat = arr.view(np.uint8).ravel()
cdef np.ndarray[dtype=np.uint8_t, ndim=1, mode="c"] out_flat
Expand All @@ -540,12 +556,12 @@ IF ZSTD_SUPPORT:
excp = RuntimeError(msg % count, count)
raise excp
return out[:count]

@cython.boundscheck(False)
@cython.wraparound(False)
def decompress_zstd(np.ndarray arr not None, shape, dtype, int block_size=0):
def decompress_zstd(np.ndarray arr not None, shape, dtype, int block_size=0, np.ndarray out=None):
"""Decompress a buffer using ZSTD then bitunshuffle it yielding an array.

Parameters
----------
arr : numpy array
Expand All @@ -559,24 +575,24 @@ IF ZSTD_SUPPORT:
block_size : positive integer
Block size in number of elements. Must match value used for
compression.

out : numpy array
Already allocated array to put the results in
Returns
-------
out : numpy array with shape *shape* and data type *dtype*
Decompressed data.

"""

cdef int ii, size, itemsize, count=0
if not arr.flags['C_CONTIGUOUS']:
msg = "Input array must be C-contiguous."
raise ValueError(msg)
size = np.prod(shape)
itemsize = dtype.itemsize

cdef np.ndarray out
out = np.empty(tuple(shape), dtype=dtype)


out = _make_array(tuple(shape), dtype=dtype, out=out)

cdef np.ndarray[dtype=np.uint8_t, ndim=1, mode="c"] arr_flat
arr_flat = arr.view(np.uint8).ravel()
cdef np.ndarray[dtype=np.uint8_t, ndim=1, mode="c"] out_flat
Expand Down
Loading