Skip to content

Commit

Permalink
most tests
Browse files Browse the repository at this point in the history
  • Loading branch information
mavaylon1 committed Mar 17, 2024
1 parent b9fc2a6 commit 0d2229c
Show file tree
Hide file tree
Showing 4 changed files with 65 additions and 50 deletions.
31 changes: 17 additions & 14 deletions src/hdmf/common/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,16 +104,16 @@ class VectorIndex(VectorData):
allow_positional=AllowPositional.WARNING)
def __init__(self, **kwargs):
target = popargs('target', kwargs)
kwargs['description'] = "Index for VectorData '%s'" % target.name
super().__init__(**kwargs)
self.target = target
self.__uint = np.uint8
self.__maxval = 255
if isinstance(self.data, (list, np.ndarray)):
if len(self.data) > 0:
self.__check_precision(len(self.target))
if isinstance(kwargs['data'], (list, np.ndarray)):
if len(kwargs['data']) > 0:
self.__check_precision(len(target))
# adjust precision for types that we can adjust precision for
self.__adjust_precision(self.__uint)
kwargs['data'] = self.__adjust_precision(self.__uint, kwargs['data'])
kwargs['description'] = "Index for VectorData '%s'" % target.name
super().__init__(**kwargs)
self.target = target

def add_vector(self, arg, **kwargs):
"""
Expand Down Expand Up @@ -143,22 +143,24 @@ def __check_precision(self, idx):
raise ValueError(msg)
self.__maxval = 2 ** nbits - 1
self.__uint = np.dtype('uint%d' % nbits).type
self.__adjust_precision(self.__uint)
# self.__adjust_precision(self.__uint) #TODO: Cannot adjust when wrapped with H5DataIO
return self.__uint(idx)

def __adjust_precision(self, uint):
def __adjust_precision(self, uint, data):
"""
Adjust precision of data to specified unsigned integer precision.
"""
if isinstance(self.data, list):
for i in range(len(self.data)):
self.data[i] = uint(self.data[i])
elif isinstance(self.data, np.ndarray):
if isinstance(data, list):
for i in range(len(data)):
data[i] = uint(data[i])
elif isinstance(data, np.ndarray):
# use self._Data__data to work around restriction on resetting self.data
self._Data__data = self.data.astype(uint)
data = data.astype(uint)
else:
raise ValueError("cannot adjust precision of type %s to %s", (type(self.data), uint))

return data

def add_row(self, arg, **kwargs):
"""
Convenience function. Same as :py:func:`add_vector`
Expand Down Expand Up @@ -1045,6 +1047,7 @@ def __get_selection_as_dict(self, arg, df, index, exclude=None, **kwargs):
# return indices (in list, array, etc.) for DTR and ragged DTR
ret[name] = col.get(arg, df=False, index=True, **kwargs)
else:
# breakpoint()
ret[name] = col.get(arg, df=df, index=index, **kwargs)
return ret
# if index is out of range, different errors can be generated depending on the dtype of the column
Expand Down
5 changes: 4 additions & 1 deletion src/hdmf/data_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1082,7 +1082,10 @@ def __getitem__(self, item):
"""Delegate slicing to the data object"""
if not self.valid:
raise InvalidDataIOError("Cannot get item from data. Data is not valid.")
return self.data[item]
if isinstance(item, (tuple, list, np.ndarray)):
return [self.data[i] for i in item]
else:
return self.data[item]

def __array__(self):
"""
Expand Down
17 changes: 12 additions & 5 deletions src/hdmf/testing/testcase.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from ..common import validate as common_validate, get_manager
from ..container import AbstractContainer, Container, Data
from ..utils import get_docval_macro
from ..data_utils import AbstractDataChunkIterator
from ..data_utils import AbstractDataChunkIterator, DataIO


class TestCase(unittest.TestCase):
Expand Down Expand Up @@ -148,10 +148,17 @@ def _assert_data_equal(self,
self.assertTrue(isinstance(data1, Data), message)
self.assertTrue(isinstance(data2, Data), message)
self.assertEqual(len(data1), len(data2), message)
self._assert_array_equal(data1.data, data2.data,
ignore_hdmf_attrs=ignore_hdmf_attrs,
ignore_string_to_byte=ignore_string_to_byte,
message=message)
# breakpoint()
if isinstance(data1.data, DataIO) and isinstance(data2.data, DataIO):
self._assert_array_equal(data1.data.data, data2.data.data,
ignore_hdmf_attrs=ignore_hdmf_attrs,
ignore_string_to_byte=ignore_string_to_byte,
message=message)
else:
self._assert_array_equal(data1.data, data2.data,
ignore_hdmf_attrs=ignore_hdmf_attrs,
ignore_string_to_byte=ignore_string_to_byte,
message=message)
self.assertContainerEqual(container1=data1,
container2=data2,
ignore_hdmf_attrs=ignore_hdmf_attrs,
Expand Down
62 changes: 32 additions & 30 deletions tests/unit/common/test_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,10 +84,11 @@ def test_constructor_spec(self):
self.check_empty_table(table)

def check_table(self, table):
# breakpoint()
self.assertEqual(len(table), 5)
self.assertEqual(table.columns[0].data, [1, 2, 3, 4, 5])
self.assertEqual(table.columns[1].data, [10.0, 20.0, 30.0, 40.0, 50.0])
self.assertEqual(table.columns[2].data, ['cat', 'dog', 'bird', 'fish', 'lizard'])
self.assertEqual(table.columns[0].data.data, [1, 2, 3, 4, 5])
self.assertEqual(table.columns[1].data.data, [10.0, 20.0, 30.0, 40.0, 50.0])
self.assertEqual(table.columns[2].data.data, ['cat', 'dog', 'bird', 'fish', 'lizard'])
self.assertEqual(table.id.data, [0, 1, 2, 3, 4])
self.assertTrue(hasattr(table, 'baz'))

Expand Down Expand Up @@ -369,11 +370,11 @@ def test_add_column_auto_index_int(self):
data=expected,
index=1)
self.assertListEqual(table['qux'][:], expected)
self.assertListEqual(table.qux_index.data, [3, 7])
self.assertListEqual(table.qux_index.data.data, [3, 7])
# Add more rows after we created the column
table.add_row(foo=5, bar=50.0, baz='lizard', qux=[10, 11, 12])
self.assertListEqual(table['qux'][:], expected + [[10, 11, 12], ])
self.assertListEqual(table.qux_index.data, [3, 7, 10])
self.assertListEqual(table.qux_index.data.data, [3, 7, 10])

def test_add_column_auto_index_bool(self):
"""
Expand All @@ -389,13 +390,13 @@ def test_add_column_auto_index_bool(self):
description='qux column',
data=expected,
index=True)

# breakpoint()
self.assertListEqual(table['qux'][:], expected)
self.assertListEqual(table.qux_index.data, [3, 7])
self.assertListEqual(table.qux_index.data.data, [3, 7])
# Add more rows after we created the column
table.add_row(foo=5, bar=50.0, baz='lizard', qux=[10, 11, 12])
self.assertListEqual(table['qux'][:], expected + [[10, 11, 12], ])
self.assertListEqual(table.qux_index.data, [3, 7, 10])
self.assertListEqual(table.qux_index.data.data, [3, 7, 10])

def test_add_column_auto_multi_index_int(self):
"""
Expand All @@ -412,13 +413,13 @@ def test_add_column_auto_multi_index_int(self):
data=expected,
index=2)
self.assertListEqual(table['qux'][:], expected)
self.assertListEqual(table.qux_index_index.data, [2, 4])
self.assertListEqual(table.qux_index.data, [3, 4, 8, 10])
self.assertListEqual(table.qux_index_index.data.data, [2, 4])
self.assertListEqual(table.qux_index.data.data, [3, 4, 8, 10])
# Add more rows after we created the column
table.add_row(foo=5, bar=50.0, baz='lizard', qux=[[10, 11, 12], ])
self.assertListEqual(table['qux'][:], expected + [[[10, 11, 12], ]])
self.assertListEqual(table.qux_index_index.data, [2, 4, 5])
self.assertListEqual(table.qux_index.data, [3, 4, 8, 10, 13])
self.assertListEqual(table.qux_index_index.data.data, [2, 4, 5])
self.assertListEqual(table.qux_index.data.data, [3, 4, 8, 10, 13])

def test_add_column_auto_multi_index_int_bad_index_levels(self):
"""
Expand Down Expand Up @@ -466,13 +467,13 @@ def test_add_column_auto_multi_index_int_with_empty_slots(self):
data=expected,
index=2)
self.assertListEqual(table['qux'][:], expected)
self.assertListEqual(table.qux_index_index.data, [2, 4])
self.assertListEqual(table.qux_index.data, [0, 0, 0, 0])
self.assertListEqual(table.qux_index_index.data.data, [2, 4])
self.assertListEqual(table.qux_index.data.data, [0, 0, 0, 0])
# Add more rows after we created the column
table.add_row(foo=5, bar=50.0, baz='lizard', qux=[[10, 11, 12], ])
self.assertListEqual(table['qux'][:], expected + [[[10, 11, 12], ]])
self.assertListEqual(table.qux_index_index.data, [2, 4, 5])
self.assertListEqual(table.qux_index.data, [0, 0, 0, 0, 3])
self.assertListEqual(table.qux_index_index.data.data, [2, 4, 5])
self.assertListEqual(table.qux_index.data.data, [0, 0, 0, 0, 3])

def test_auto_multi_index_required(self):

Expand Down Expand Up @@ -514,7 +515,7 @@ class TestTable(DynamicTable):
]
]
self.assertListEqual(table['qux'][:], expected)
self.assertEqual(table.qux_index_index_index.data, [1, 2])
self.assertEqual(table.qux_index_index_index.data.data, [1, 2])

def test_auto_multi_index(self):

Expand Down Expand Up @@ -556,7 +557,7 @@ class TestTable(DynamicTable):
]
]
self.assertListEqual(table['qux'][:], expected)
self.assertEqual(table.qux_index_index_index.data, [1, 2])
self.assertEqual(table.qux_index_index_index.data.data, [1, 2])

def test_getitem_row_num(self):
table = self.with_spec()
Expand Down Expand Up @@ -1409,7 +1410,7 @@ def test_dci_int_ok(self):
a = np.arange(30)
dci = DataChunkIterator(data=a, buffer_size=1)
e = ElementIdentifiers(name='ids', data=dci) # test that no error is raised
self.assertIs(e.data, dci)
self.assertIs(e.data.data, dci)

def test_dci_float_bad(self):
a = np.arange(30.0)
Expand All @@ -1422,7 +1423,7 @@ def test_dataio_dci_ok(self):
dci = DataChunkIterator(data=a, buffer_size=1)
dio = H5DataIO(dci)
e = ElementIdentifiers(name='ids', data=dio) # test that no error is raised
self.assertIs(e.data, dio)
self.assertIs(e.data.data, dio)


class SubTable(DynamicTable):
Expand Down Expand Up @@ -1606,7 +1607,7 @@ def test_add_opt_column_after_data(self):
table = SubTable(name='subtable', description='subtable description')
table.add_row(col1='a', col3='c', col5='e', col7='g')
table.add_column(name='col2', description='column #2', data=('b', ))
self.assertTupleEqual(table.col2.data, ('b', ))
self.assertTupleEqual(table.col2.data.data, ('b', ))

def test_add_opt_ind_column_after_data(self):
"""Test that adding an optional, indexed column from __columns__ with data works."""
Expand All @@ -1623,8 +1624,8 @@ def test_add_row_opt_column(self):
self.assertTupleEqual(table.colnames, ('col1', 'col3', 'col5', 'col7', 'col2', 'col4'))
self.assertEqual(table.col2.description, 'optional column')
self.assertEqual(table.col4.description, 'optional, indexed column')
self.assertListEqual(table.col2.data, ['b', 'b2'])
# self.assertListEqual(table.col4.data, [('d1', 'd2'), ('d3', 'd4')]) # TODO this should work
self.assertListEqual(table.col2.data.data, ['b', 'b2'])
# self.assertListEqual(table.col4.data.data, [('d1', 'd2'), ('d3', 'd4')]) # TODO this should work

def test_add_row_opt_column_after_data(self):
"""Test that adding a row with an optional column after adding a row without the column raises an error."""
Expand Down Expand Up @@ -1820,14 +1821,14 @@ def test_add_row(self):
ed.add_row('b')
ed.add_row('a')
ed.add_row('c')
np.testing.assert_array_equal(ed.data, np.array([1, 0, 2], dtype=np.uint8))
np.testing.assert_array_equal(ed.data.data, np.array([1, 0, 2], dtype=np.uint8))

def test_add_row_index(self):
ed = EnumData(name='cv_data', description='a test EnumData', elements=['a', 'b', 'c'])
ed.add_row(1, index=True)
ed.add_row(0, index=True)
ed.add_row(2, index=True)
np.testing.assert_array_equal(ed.data, np.array([1, 0, 2], dtype=np.uint8))
np.testing.assert_array_equal(ed.data.data, np.array([1, 0, 2], dtype=np.uint8))


class TestIndexedEnumData(TestCase):
Expand Down Expand Up @@ -2295,12 +2296,12 @@ def test_init_empty(self):
self.assertEqual(foo_ind.name, 'foo_index')
self.assertEqual(foo_ind.description, "Index for VectorData 'foo'")
self.assertIs(foo_ind.target, foo)
self.assertListEqual(foo_ind.data, list())
self.assertListEqual(foo_ind.data.data, list())

def test_init_data(self):
foo = VectorData(name='foo', description='foo column', data=['a', 'b', 'c'])
foo_ind = VectorIndex(name='foo_index', target=foo, data=[2, 3])
self.assertListEqual(foo_ind.data, [2, 3])
self.assertListEqual(foo_ind.data.data, [2, 3])
self.assertListEqual(foo_ind[0], ['a', 'b'])
self.assertListEqual(foo_ind[1], ['c'])

Expand Down Expand Up @@ -2333,11 +2334,11 @@ def test_add_vector(self):

foo_ind_ind.add_vector([['c11', 'c12', 'c13'], ['c21', 'c22']])

self.assertListEqual(foo.data, ['a11', 'a12', 'a21', 'b11', 'c11', 'c12', 'c13', 'c21', 'c22'])
self.assertListEqual(foo_ind.data, [2, 3, 4, 7, 9])
self.assertListEqual(foo.data.data, ['a11', 'a12', 'a21', 'b11', 'c11', 'c12', 'c13', 'c21', 'c22'])
self.assertListEqual(foo_ind.data.data, [2, 3, 4, 7, 9])
self.assertListEqual(foo_ind[3], ['c11', 'c12', 'c13'])
self.assertListEqual(foo_ind[4], ['c21', 'c22'])
self.assertListEqual(foo_ind_ind.data, [2, 3, 5])
self.assertListEqual(foo_ind_ind.data.data, [2, 3, 5])
self.assertListEqual(foo_ind_ind[2], [['c11', 'c12', 'c13'], ['c21', 'c22']])


Expand Down Expand Up @@ -2706,6 +2707,7 @@ def set_up_list_index(self):
def test_array_inc_precision(self):
index = self.set_up_array_index()
index.add_vector(np.empty((255, )))
# breakpoint()
self.assertEqual(index.data[0], 255)
self.assertEqual(index.data.dtype, np.uint8)

Expand Down

0 comments on commit 0d2229c

Please sign in to comment.