diff --git a/src/hdmf/common/table.py b/src/hdmf/common/table.py index 3b67ff19d..91cdc4673 100644 --- a/src/hdmf/common/table.py +++ b/src/hdmf/common/table.py @@ -17,6 +17,7 @@ from ..data_utils import DataIO, AbstractDataChunkIterator from ..utils import docval, getargs, ExtenderMeta, popargs, pystr, AllowPositional, check_type, is_ragged from ..term_set import TermSetWrapper +from ..backends.hdf5.h5_utils import H5DataIO @register_class('VectorData') @@ -39,9 +40,13 @@ class VectorData(Data): {'name': 'description', 'type': str, 'doc': 'a description for this column'}, {'name': 'data', 'type': ('array_data', 'data'), 'doc': 'a dataset where the first dimension is a concatenation of multiple vectors', 'default': list()}, + {'name': 'extendable', 'type': bool, 'default': True, + 'doc': 'Bool to decide whether to wrap the data with H5DataIO to be extendable by default.'}, allow_positional=AllowPositional.WARNING) def __init__(self, **kwargs): - description = popargs('description', kwargs) + description, extendable = popargs('description', 'extendable', kwargs) + if not isinstance(kwargs['data'], DataIO) and extendable: + kwargs['data'] = H5DataIO(data=kwargs['data'], maxshape=(None,)) super().__init__(**kwargs) self.description = description @@ -99,16 +104,16 @@ class VectorIndex(VectorData): allow_positional=AllowPositional.WARNING) def __init__(self, **kwargs): target = popargs('target', kwargs) - kwargs['description'] = "Index for VectorData '%s'" % target.name - super().__init__(**kwargs) - self.target = target self.__uint = np.uint8 self.__maxval = 255 - if isinstance(self.data, (list, np.ndarray)): - if len(self.data) > 0: - self.__check_precision(len(self.target)) + if isinstance(kwargs['data'], (list, np.ndarray)): + if len(kwargs['data']) > 0: + self.__check_precision(len(target)) # adjust precision for types that we can adjust precision for - self.__adjust_precision(self.__uint) + kwargs['data'] = self.__adjust_precision(self.__uint, kwargs['data']) + kwargs['description'] = "Index for VectorData '%s'" % target.name + super().__init__(**kwargs) + self.target = target def add_vector(self, arg, **kwargs): """ @@ -138,22 +143,25 @@ def __check_precision(self, idx): raise ValueError(msg) self.__maxval = 2 ** nbits - 1 self.__uint = np.dtype('uint%d' % nbits).type - self.__adjust_precision(self.__uint) + if self.data is not None and not isinstance(self.data, DataIO): + self.__adjust_precision(self.__uint, self.data) #TODO: Cannot adjust when wrapped with H5DataIO return self.__uint(idx) - def __adjust_precision(self, uint): + def __adjust_precision(self, uint, data): """ Adjust precision of data to specified unsigned integer precision. """ - if isinstance(self.data, list): - for i in range(len(self.data)): - self.data[i] = uint(self.data[i]) - elif isinstance(self.data, np.ndarray): + if isinstance(data, list): + for i in range(len(data)): + data[i] = uint(data[i]) + elif isinstance(data, np.ndarray): # use self._Data__data to work around restriction on resetting self.data - self._Data__data = self.data.astype(uint) + data = data.astype(uint) else: raise ValueError("cannot adjust precision of type %s to %s", (type(self.data), uint)) + return data + def add_row(self, arg, **kwargs): """ Convenience function. Same as :py:func:`add_vector` @@ -1060,6 +1068,7 @@ def __get_selection_as_dict(self, arg, df, index, exclude=None, **kwargs): # return indices (in list, array, etc.) for DTR and ragged DTR ret[name] = col.get(arg, df=False, index=True, **kwargs) else: + # breakpoint() ret[name] = col.get(arg, df=df, index=index, **kwargs) return ret # if index is out of range, different errors can be generated depending on the dtype of the column diff --git a/src/hdmf/data_utils.py b/src/hdmf/data_utils.py index 23f0b4019..b7c6241f7 100644 --- a/src/hdmf/data_utils.py +++ b/src/hdmf/data_utils.py @@ -1085,7 +1085,10 @@ def __getitem__(self, item): """Delegate slicing to the data object""" if not self.valid: raise InvalidDataIOError("Cannot get item from data. Data is not valid.") - return self.data[item] + if isinstance(item, (tuple, list, np.ndarray)): + return [self.data[i] for i in item] + else: + return self.data[item] def __array__(self): """ diff --git a/src/hdmf/testing/testcase.py b/src/hdmf/testing/testcase.py index 798df6fe4..5ff3c919a 100644 --- a/src/hdmf/testing/testcase.py +++ b/src/hdmf/testing/testcase.py @@ -10,7 +10,7 @@ from ..common import validate as common_validate, get_manager from ..container import AbstractContainer, Container, Data from ..utils import get_docval_macro -from ..data_utils import AbstractDataChunkIterator +from ..data_utils import AbstractDataChunkIterator, DataIO class TestCase(unittest.TestCase): @@ -148,10 +148,17 @@ def _assert_data_equal(self, self.assertTrue(isinstance(data1, Data), message) self.assertTrue(isinstance(data2, Data), message) self.assertEqual(len(data1), len(data2), message) - self._assert_array_equal(data1.data, data2.data, - ignore_hdmf_attrs=ignore_hdmf_attrs, - ignore_string_to_byte=ignore_string_to_byte, - message=message) + # breakpoint() + if isinstance(data1.data, DataIO) and isinstance(data2.data, DataIO): + self._assert_array_equal(data1.data.data, data2.data.data, + ignore_hdmf_attrs=ignore_hdmf_attrs, + ignore_string_to_byte=ignore_string_to_byte, + message=message) + else: + self._assert_array_equal(data1.data, data2.data, + ignore_hdmf_attrs=ignore_hdmf_attrs, + ignore_string_to_byte=ignore_string_to_byte, + message=message) self.assertContainerEqual(container1=data1, container2=data2, ignore_hdmf_attrs=ignore_hdmf_attrs, diff --git a/tests/unit/common/test_table.py b/tests/unit/common/test_table.py index 00b3c14a3..141158047 100644 --- a/tests/unit/common/test_table.py +++ b/tests/unit/common/test_table.py @@ -83,10 +83,11 @@ def test_constructor_spec(self): self.check_empty_table(table) def check_table(self, table): + # breakpoint() self.assertEqual(len(table), 5) - self.assertEqual(table.columns[0].data, [1, 2, 3, 4, 5]) - self.assertEqual(table.columns[1].data, [10.0, 20.0, 30.0, 40.0, 50.0]) - self.assertEqual(table.columns[2].data, ['cat', 'dog', 'bird', 'fish', 'lizard']) + self.assertEqual(table.columns[0].data.data, [1, 2, 3, 4, 5]) + self.assertEqual(table.columns[1].data.data, [10.0, 20.0, 30.0, 40.0, 50.0]) + self.assertEqual(table.columns[2].data.data, ['cat', 'dog', 'bird', 'fish', 'lizard']) self.assertEqual(table.id.data, [0, 1, 2, 3, 4]) self.assertTrue(hasattr(table, 'baz')) @@ -531,11 +532,11 @@ def test_add_column_auto_index_int(self): data=expected, index=1) self.assertListEqual(table['qux'][:], expected) - self.assertListEqual(table.qux_index.data, [3, 7]) + self.assertListEqual(table.qux_index.data.data, [3, 7]) # Add more rows after we created the column table.add_row(foo=5, bar=50.0, baz='lizard', qux=[10, 11, 12]) self.assertListEqual(table['qux'][:], expected + [[10, 11, 12], ]) - self.assertListEqual(table.qux_index.data, [3, 7, 10]) + self.assertListEqual(table.qux_index.data.data, [3, 7, 10]) def test_add_column_auto_index_bool(self): """ @@ -551,12 +552,13 @@ def test_add_column_auto_index_bool(self): description='qux column', data=expected, index=True) + # breakpoint() self.assertListEqual(table['qux'][:], expected) - self.assertListEqual(table.qux_index.data, [3, 7]) + self.assertListEqual(table.qux_index.data.data, [3, 7]) # Add more rows after we created the column table.add_row(foo=5, bar=50.0, baz='lizard', qux=[10, 11, 12]) self.assertListEqual(table['qux'][:], expected + [[10, 11, 12], ]) - self.assertListEqual(table.qux_index.data, [3, 7, 10]) + self.assertListEqual(table.qux_index.data.data, [3, 7, 10]) def test_add_column_auto_multi_index_int(self): """ @@ -573,13 +575,13 @@ def test_add_column_auto_multi_index_int(self): data=expected, index=2) self.assertListEqual(table['qux'][:], expected) - self.assertListEqual(table.qux_index_index.data, [2, 4]) - self.assertListEqual(table.qux_index.data, [3, 4, 8, 10]) + self.assertListEqual(table.qux_index_index.data.data, [2, 4]) + self.assertListEqual(table.qux_index.data.data, [3, 4, 8, 10]) # Add more rows after we created the column table.add_row(foo=5, bar=50.0, baz='lizard', qux=[[10, 11, 12], ]) self.assertListEqual(table['qux'][:], expected + [[[10, 11, 12], ]]) - self.assertListEqual(table.qux_index_index.data, [2, 4, 5]) - self.assertListEqual(table.qux_index.data, [3, 4, 8, 10, 13]) + self.assertListEqual(table.qux_index_index.data.data, [2, 4, 5]) + self.assertListEqual(table.qux_index.data.data, [3, 4, 8, 10, 13]) def test_add_column_auto_multi_index_int_bad_index_levels(self): """ @@ -627,13 +629,13 @@ def test_add_column_auto_multi_index_int_with_empty_slots(self): data=expected, index=2) self.assertListEqual(table['qux'][:], expected) - self.assertListEqual(table.qux_index_index.data, [2, 4]) - self.assertListEqual(table.qux_index.data, [0, 0, 0, 0]) + self.assertListEqual(table.qux_index_index.data.data, [2, 4]) + self.assertListEqual(table.qux_index.data.data, [0, 0, 0, 0]) # Add more rows after we created the column table.add_row(foo=5, bar=50.0, baz='lizard', qux=[[10, 11, 12], ]) self.assertListEqual(table['qux'][:], expected + [[[10, 11, 12], ]]) - self.assertListEqual(table.qux_index_index.data, [2, 4, 5]) - self.assertListEqual(table.qux_index.data, [0, 0, 0, 0, 3]) + self.assertListEqual(table.qux_index_index.data.data, [2, 4, 5]) + self.assertListEqual(table.qux_index.data.data, [0, 0, 0, 0, 3]) def test_auto_multi_index_required(self): @@ -675,7 +677,7 @@ class TestTable(DynamicTable): ] ] self.assertListEqual(table['qux'][:], expected) - self.assertEqual(table.qux_index_index_index.data, [1, 2]) + self.assertEqual(table.qux_index_index_index.data.data, [1, 2]) def test_auto_multi_index(self): @@ -717,7 +719,7 @@ class TestTable(DynamicTable): ] ] self.assertListEqual(table['qux'][:], expected) - self.assertEqual(table.qux_index_index_index.data, [1, 2]) + self.assertEqual(table.qux_index_index_index.data.data, [1, 2]) def test_getitem_row_num(self): table = self.with_spec() @@ -1767,7 +1769,7 @@ def test_add_opt_column_after_data(self): table = SubTable(name='subtable', description='subtable description') table.add_row(col1='a', col3='c', col5='e', col7='g') table.add_column(name='col2', description='column #2', data=('b', )) - self.assertTupleEqual(table.col2.data, ('b', )) + self.assertTupleEqual(table.col2.data.data, ('b', )) def test_add_opt_ind_column_after_data(self): """Test that adding an optional, indexed column from __columns__ with data works.""" @@ -1784,8 +1786,8 @@ def test_add_row_opt_column(self): self.assertTupleEqual(table.colnames, ('col1', 'col3', 'col5', 'col7', 'col2', 'col4')) self.assertEqual(table.col2.description, 'optional column') self.assertEqual(table.col4.description, 'optional, indexed column') - self.assertListEqual(table.col2.data, ['b', 'b2']) - # self.assertListEqual(table.col4.data, [('d1', 'd2'), ('d3', 'd4')]) # TODO this should work + self.assertListEqual(table.col2.data.data, ['b', 'b2']) + # self.assertListEqual(table.col4.data.data, [('d1', 'd2'), ('d3', 'd4')]) # TODO this should work def test_add_row_opt_column_after_data(self): """Test that adding a row with an optional column after adding a row without the column raises an error.""" @@ -1981,14 +1983,14 @@ def test_add_row(self): ed.add_row('b') ed.add_row('a') ed.add_row('c') - np.testing.assert_array_equal(ed.data, np.array([1, 0, 2], dtype=np.uint8)) + np.testing.assert_array_equal(ed.data.data, np.array([1, 0, 2], dtype=np.uint8)) def test_add_row_index(self): ed = EnumData(name='cv_data', description='a test EnumData', elements=['a', 'b', 'c']) ed.add_row(1, index=True) ed.add_row(0, index=True) ed.add_row(2, index=True) - np.testing.assert_array_equal(ed.data, np.array([1, 0, 2], dtype=np.uint8)) + np.testing.assert_array_equal(ed.data.data, np.array([1, 0, 2], dtype=np.uint8)) class TestIndexedEnumData(TestCase): @@ -2456,12 +2458,12 @@ def test_init_empty(self): self.assertEqual(foo_ind.name, 'foo_index') self.assertEqual(foo_ind.description, "Index for VectorData 'foo'") self.assertIs(foo_ind.target, foo) - self.assertListEqual(foo_ind.data, list()) + self.assertListEqual(foo_ind.data.data, list()) def test_init_data(self): foo = VectorData(name='foo', description='foo column', data=['a', 'b', 'c']) foo_ind = VectorIndex(name='foo_index', target=foo, data=[2, 3]) - self.assertListEqual(foo_ind.data, [2, 3]) + self.assertListEqual(foo_ind.data.data, [2, 3]) self.assertListEqual(foo_ind[0], ['a', 'b']) self.assertListEqual(foo_ind[1], ['c']) @@ -2494,11 +2496,11 @@ def test_add_vector(self): foo_ind_ind.add_vector([['c11', 'c12', 'c13'], ['c21', 'c22']]) - self.assertListEqual(foo.data, ['a11', 'a12', 'a21', 'b11', 'c11', 'c12', 'c13', 'c21', 'c22']) - self.assertListEqual(foo_ind.data, [2, 3, 4, 7, 9]) + self.assertListEqual(foo.data.data, ['a11', 'a12', 'a21', 'b11', 'c11', 'c12', 'c13', 'c21', 'c22']) + self.assertListEqual(foo_ind.data.data, [2, 3, 4, 7, 9]) self.assertListEqual(foo_ind[3], ['c11', 'c12', 'c13']) self.assertListEqual(foo_ind[4], ['c21', 'c22']) - self.assertListEqual(foo_ind_ind.data, [2, 3, 5]) + self.assertListEqual(foo_ind_ind.data.data, [2, 3, 5]) self.assertListEqual(foo_ind_ind[2], [['c11', 'c12', 'c13'], ['c21', 'c22']]) @@ -2867,6 +2869,7 @@ def set_up_list_index(self): def test_array_inc_precision(self): index = self.set_up_array_index() index.add_vector(np.empty((255, ))) + # breakpoint() self.assertEqual(index.data[0], 255) self.assertEqual(index.data.dtype, np.uint8)