Skip to content

Commit

Permalink
Add basic tests
Browse files Browse the repository at this point in the history
  • Loading branch information
thodson-usgs committed Feb 5, 2024
1 parent d91fe45 commit 9b813bf
Show file tree
Hide file tree
Showing 6 changed files with 112 additions and 1 deletion.
10 changes: 10 additions & 0 deletions docs/bitinfo.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
BitInfo
=======

.. automodule:: numcodecs.bitinfo

.. autoclass:: BitInfo

.. autoattribute:: codec_id
.. automethod:: encode
.. automethod:: decode
1 change: 1 addition & 0 deletions docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ Contents
delta
fixedscaleoffset
quantize
bitinfo
bitround
packbits
categorize
Expand Down
2 changes: 2 additions & 0 deletions docs/release.rst
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ Unreleased
Enhancements
~~~~~~~~~~~~

* Add BitInfo codec
By :user:`Tim Hodson <thodson-usgs>`.
* Use PyData theme for docs
By :user:`John Kirkham <jakirkham>`, :issue:`485`.

Expand Down
9 changes: 9 additions & 0 deletions numcodecs/bitinfo.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,8 @@ def exponent_mask(dtype):
mask = 0x7F80_0000
elif dtype == np.float64:
mask = 0x7FF0_0000_0000_0000
else:
raise ValueError(f"Unsupported dtype {dtype}")
return mask


Expand Down Expand Up @@ -175,6 +177,7 @@ def signed_exponent(A):
def bitpaircount_u1(a, b):
assert a.dtype == "u1"
assert b.dtype == "u1"

unpack_a = np.unpackbits(a.flatten()).astype("u1")
unpack_b = np.unpackbits(b.flatten()).astype("u1")

Expand All @@ -188,6 +191,7 @@ def bitpaircount_u1(a, b):
def bitpaircount(a, b):
assert a.dtype.kind == "u"
assert b.dtype.kind == "u"

nbytes = max(a.dtype.itemsize, b.dtype.itemsize)

a, b = np.broadcast_arrays(a, b)
Expand All @@ -203,6 +207,9 @@ def bitpaircount(a, b):
def mutual_information(a, b, base=2):
"""Calculate the mutual information between two arrays.
"""
assert a.dtype == b.dtype
assert a.dtype.kind == "u"

size = np.prod(np.broadcast_shapes(a.shape, b.shape))
counts = bitpaircount(a, b)

Expand All @@ -228,6 +235,8 @@ def bitinformation(a, axis=0):
-------
info_per_bit : array
"""
assert a.dtype.kind == "u"

sa = tuple(slice(0, -1) if i == axis else slice(None) for i in range(len(a.shape)))
sb = tuple(
slice(1, None) if i == axis else slice(None) for i in range(len(a.shape))
Expand Down
18 changes: 17 additions & 1 deletion numcodecs/bitround.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,23 @@ def decode(self, buf, out=None):
return ndarray_copy(data, out)

@staticmethod
def bitround(buf, keepbits, dtype):
def bitround(buf, keepbits: int, dtype):
"""Drop bits from the mantissa of a floating point array
Parameters
----------
buf: ndarray
The input array
keepbits: int
The number of bits to keep
dtype: dtype
The dtype of the input array
Returns
-------
ndarray
The bitrounded array transformed to an integer type
"""
bits = max_bits[str(dtype)]
a_int_dtype = np.dtype(buf.dtype.str.replace("f", "i"))
all_set = np.array(-1, dtype=a_int_dtype)
Expand Down
73 changes: 73 additions & 0 deletions numcodecs/tests/test_bitinfo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
import numpy as np

import pytest

from numcodecs.bitinfo import BitInfo, exponent_bias, mutual_information

def test_bitinfo_initialization():
    """Check how ``BitInfo`` stores and validates its constructor arguments."""
    # With no ``axes`` argument the attribute stays None.
    codec = BitInfo(0.5)
    assert codec.info_level == 0.5
    assert codec.axes is None

    # A single integer axis is stored as a one-element list.
    codec = BitInfo(0.5, axes=1)
    assert codec.axes == [1]

    # A list of integer axes is stored unchanged.
    codec = BitInfo(0.5, axes=[1, 2])
    assert codec.axes == [1, 2]

    # These info_level values are expected to be rejected.
    for bad_level in (-0.1, 1.1):
        with pytest.raises(ValueError):
            BitInfo(bad_level)

    # Non-integer axes, alone or inside a list, are expected to be rejected.
    for bad_axes in (1.5, [1, 1.5]):
        with pytest.raises(ValueError):
            BitInfo(0.5, axes=bad_axes)


def test_bitinfo_encode():
    """Round-trip a float32 array through ``BitInfo`` and check the dtype."""
    data = np.array([1.0, 2.0, 3.0], dtype="float32")
    codec = BitInfo(info_level=0.5)

    # Only the dtype of the decoded result is checked here, not its values.
    roundtrip = codec.decode(codec.encode(data))
    assert roundtrip.dtype == data.dtype


def test_bitinfo_encode_errors():
    """Check that ``BitInfo.encode`` raises ``TypeError`` for unsupported dtypes."""
    bitinfo = BitInfo(0.5)

    # Integer input is expected to be rejected.
    a = np.array([1, 2, 3], dtype="int32")
    with pytest.raises(TypeError):
        bitinfo.encode(a)

    # ``float128`` is a platform-dependent NumPy alias. Where it is missing,
    # building the array would itself fail before ``encode`` is exercised,
    # so this case only runs when the dtype exists.
    if hasattr(np, "float128"):
        a = np.array([1.0, 2.0, 3.0], dtype="float128")
        with pytest.raises(TypeError):
            bitinfo.encode(a)


def test_exponent_bias():
    """Check ``exponent_bias`` for the float dtypes and for a rejected dtype."""
    # Expected bias for each float dtype code exercised by this test.
    expected_bias = {"f2": 15, "f4": 127, "f8": 1023}
    for dtype_code, bias in expected_bias.items():
        assert exponent_bias(dtype_code) == bias

    # A non-float dtype is expected to raise ValueError.
    with pytest.raises(ValueError):
        exponent_bias("int32")


def test_mutual_information():
    """Regression test for ``mutual_information``.

    Pins the currently produced values so that behavior changes are
    detected; it does not establish the correctness of those reference
    values themselves.
    """
    a = np.arange(10.0, dtype="float32")
    b = a + 1000
    c = a[::-1].copy()

    # Reinterpret the float32 bit patterns as uint32 (same item size, no
    # value conversion) before passing them to mutual_information.
    dt = np.dtype("uint32")
    a, b, c = (x.view(dt) for x in (a, b, c))

    # The results are floating-point sums, so compare with a tight tolerance
    # rather than exact equality, which could break on a last-digit
    # difference between NumPy versions or platforms.
    np.testing.assert_allclose(
        mutual_information(a, a).sum(), 7.020411549771797, rtol=1e-12
    )
    # A relative tolerance is meaningless against zero; use an absolute one.
    np.testing.assert_allclose(
        mutual_information(a, b).sum(), 0.0, rtol=0, atol=1e-12
    )
    np.testing.assert_allclose(
        mutual_information(a, c).sum(), 0.6545015579460758, rtol=1e-12
    )

0 comments on commit 9b813bf

Please sign in to comment.