Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Have VectorData expandable by default #1064

Closed
wants to merge 7 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 24 additions & 15 deletions src/hdmf/common/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from ..data_utils import DataIO, AbstractDataChunkIterator
from ..utils import docval, getargs, ExtenderMeta, popargs, pystr, AllowPositional, check_type, is_ragged
from ..term_set import TermSetWrapper
from ..backends.hdf5.h5_utils import H5DataIO


@register_class('VectorData')
Expand All @@ -39,9 +40,13 @@ class VectorData(Data):
{'name': 'description', 'type': str, 'doc': 'a description for this column'},
{'name': 'data', 'type': ('array_data', 'data'),
'doc': 'a dataset where the first dimension is a concatenation of multiple vectors', 'default': list()},
{'name': 'extendable', 'type': bool, 'default': True,
'doc': 'Bool to decide whether to wrap the data with H5DataIO to be extendable by default.'},
allow_positional=AllowPositional.WARNING)
def __init__(self, **kwargs):
description = popargs('description', kwargs)
description, extendable = popargs('description', 'extendable', kwargs)
if not isinstance(kwargs['data'], DataIO) and extendable:
kwargs['data'] = H5DataIO(data=kwargs['data'], maxshape=(None,))
super().__init__(**kwargs)
self.description = description

Expand Down Expand Up @@ -99,16 +104,16 @@ class VectorIndex(VectorData):
allow_positional=AllowPositional.WARNING)
def __init__(self, **kwargs):
target = popargs('target', kwargs)
kwargs['description'] = "Index for VectorData '%s'" % target.name
super().__init__(**kwargs)
self.target = target
self.__uint = np.uint8
self.__maxval = 255
if isinstance(self.data, (list, np.ndarray)):
if len(self.data) > 0:
self.__check_precision(len(self.target))
if isinstance(kwargs['data'], (list, np.ndarray)):
if len(kwargs['data']) > 0:
self.__check_precision(len(target))
# adjust precision for types that we can adjust precision for
self.__adjust_precision(self.__uint)
kwargs['data'] = self.__adjust_precision(self.__uint, kwargs['data'])
kwargs['description'] = "Index for VectorData '%s'" % target.name
super().__init__(**kwargs)
self.target = target

def add_vector(self, arg, **kwargs):
"""
Expand Down Expand Up @@ -138,22 +143,25 @@ def __check_precision(self, idx):
raise ValueError(msg)
self.__maxval = 2 ** nbits - 1
self.__uint = np.dtype('uint%d' % nbits).type
self.__adjust_precision(self.__uint)
if self.data is not None and not isinstance(self.data, DataIO):
self.__adjust_precision(self.__uint, self.data) #TODO: Cannot adjust when wrapped with H5DataIO
return self.__uint(idx)

def __adjust_precision(self, uint):
def __adjust_precision(self, uint, data):
"""
Adjust precision of data to specified unsigned integer precision.
"""
if isinstance(self.data, list):
for i in range(len(self.data)):
self.data[i] = uint(self.data[i])
elif isinstance(self.data, np.ndarray):
if isinstance(data, list):
for i in range(len(data)):
data[i] = uint(data[i])
elif isinstance(data, np.ndarray):
# use self._Data__data to work around restriction on resetting self.data
self._Data__data = self.data.astype(uint)
data = data.astype(uint)
else:
raise ValueError("cannot adjust precision of type %s to %s", (type(self.data), uint))

return data

def add_row(self, arg, **kwargs):
"""
Convenience function. Same as :py:func:`add_vector`
Expand Down Expand Up @@ -1060,6 +1068,7 @@ def __get_selection_as_dict(self, arg, df, index, exclude=None, **kwargs):
# return indices (in list, array, etc.) for DTR and ragged DTR
ret[name] = col.get(arg, df=False, index=True, **kwargs)
else:
# breakpoint()
ret[name] = col.get(arg, df=df, index=index, **kwargs)
return ret
# if index is out of range, different errors can be generated depending on the dtype of the column
Expand Down
5 changes: 4 additions & 1 deletion src/hdmf/data_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1085,7 +1085,10 @@ def __getitem__(self, item):
"""Delegate slicing to the data object"""
if not self.valid:
raise InvalidDataIOError("Cannot get item from data. Data is not valid.")
return self.data[item]
if isinstance(item, (tuple, list, np.ndarray)):
return [self.data[i] for i in item]
else:
return self.data[item]

def __array__(self):
"""
Expand Down
17 changes: 12 additions & 5 deletions src/hdmf/testing/testcase.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from ..common import validate as common_validate, get_manager
from ..container import AbstractContainer, Container, Data
from ..utils import get_docval_macro
from ..data_utils import AbstractDataChunkIterator
from ..data_utils import AbstractDataChunkIterator, DataIO


class TestCase(unittest.TestCase):
Expand Down Expand Up @@ -148,10 +148,17 @@ def _assert_data_equal(self,
self.assertTrue(isinstance(data1, Data), message)
self.assertTrue(isinstance(data2, Data), message)
self.assertEqual(len(data1), len(data2), message)
self._assert_array_equal(data1.data, data2.data,
ignore_hdmf_attrs=ignore_hdmf_attrs,
ignore_string_to_byte=ignore_string_to_byte,
message=message)
# breakpoint()
if isinstance(data1.data, DataIO) and isinstance(data2.data, DataIO):
self._assert_array_equal(data1.data.data, data2.data.data,
ignore_hdmf_attrs=ignore_hdmf_attrs,
ignore_string_to_byte=ignore_string_to_byte,
message=message)
else:
self._assert_array_equal(data1.data, data2.data,
ignore_hdmf_attrs=ignore_hdmf_attrs,
ignore_string_to_byte=ignore_string_to_byte,
message=message)
self.assertContainerEqual(container1=data1,
container2=data2,
ignore_hdmf_attrs=ignore_hdmf_attrs,
Expand Down
57 changes: 30 additions & 27 deletions tests/unit/common/test_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,10 +83,11 @@ def test_constructor_spec(self):
self.check_empty_table(table)

def check_table(self, table):
# breakpoint()
self.assertEqual(len(table), 5)
self.assertEqual(table.columns[0].data, [1, 2, 3, 4, 5])
self.assertEqual(table.columns[1].data, [10.0, 20.0, 30.0, 40.0, 50.0])
self.assertEqual(table.columns[2].data, ['cat', 'dog', 'bird', 'fish', 'lizard'])
self.assertEqual(table.columns[0].data.data, [1, 2, 3, 4, 5])
self.assertEqual(table.columns[1].data.data, [10.0, 20.0, 30.0, 40.0, 50.0])
self.assertEqual(table.columns[2].data.data, ['cat', 'dog', 'bird', 'fish', 'lizard'])
self.assertEqual(table.id.data, [0, 1, 2, 3, 4])
self.assertTrue(hasattr(table, 'baz'))

Expand Down Expand Up @@ -531,11 +532,11 @@ def test_add_column_auto_index_int(self):
data=expected,
index=1)
self.assertListEqual(table['qux'][:], expected)
self.assertListEqual(table.qux_index.data, [3, 7])
self.assertListEqual(table.qux_index.data.data, [3, 7])
# Add more rows after we created the column
table.add_row(foo=5, bar=50.0, baz='lizard', qux=[10, 11, 12])
self.assertListEqual(table['qux'][:], expected + [[10, 11, 12], ])
self.assertListEqual(table.qux_index.data, [3, 7, 10])
self.assertListEqual(table.qux_index.data.data, [3, 7, 10])

def test_add_column_auto_index_bool(self):
"""
Expand All @@ -551,12 +552,13 @@ def test_add_column_auto_index_bool(self):
description='qux column',
data=expected,
index=True)
# breakpoint()
self.assertListEqual(table['qux'][:], expected)
self.assertListEqual(table.qux_index.data, [3, 7])
self.assertListEqual(table.qux_index.data.data, [3, 7])
# Add more rows after we created the column
table.add_row(foo=5, bar=50.0, baz='lizard', qux=[10, 11, 12])
self.assertListEqual(table['qux'][:], expected + [[10, 11, 12], ])
self.assertListEqual(table.qux_index.data, [3, 7, 10])
self.assertListEqual(table.qux_index.data.data, [3, 7, 10])

def test_add_column_auto_multi_index_int(self):
"""
Expand All @@ -573,13 +575,13 @@ def test_add_column_auto_multi_index_int(self):
data=expected,
index=2)
self.assertListEqual(table['qux'][:], expected)
self.assertListEqual(table.qux_index_index.data, [2, 4])
self.assertListEqual(table.qux_index.data, [3, 4, 8, 10])
self.assertListEqual(table.qux_index_index.data.data, [2, 4])
self.assertListEqual(table.qux_index.data.data, [3, 4, 8, 10])
# Add more rows after we created the column
table.add_row(foo=5, bar=50.0, baz='lizard', qux=[[10, 11, 12], ])
self.assertListEqual(table['qux'][:], expected + [[[10, 11, 12], ]])
self.assertListEqual(table.qux_index_index.data, [2, 4, 5])
self.assertListEqual(table.qux_index.data, [3, 4, 8, 10, 13])
self.assertListEqual(table.qux_index_index.data.data, [2, 4, 5])
self.assertListEqual(table.qux_index.data.data, [3, 4, 8, 10, 13])

def test_add_column_auto_multi_index_int_bad_index_levels(self):
"""
Expand Down Expand Up @@ -627,13 +629,13 @@ def test_add_column_auto_multi_index_int_with_empty_slots(self):
data=expected,
index=2)
self.assertListEqual(table['qux'][:], expected)
self.assertListEqual(table.qux_index_index.data, [2, 4])
self.assertListEqual(table.qux_index.data, [0, 0, 0, 0])
self.assertListEqual(table.qux_index_index.data.data, [2, 4])
self.assertListEqual(table.qux_index.data.data, [0, 0, 0, 0])
# Add more rows after we created the column
table.add_row(foo=5, bar=50.0, baz='lizard', qux=[[10, 11, 12], ])
self.assertListEqual(table['qux'][:], expected + [[[10, 11, 12], ]])
self.assertListEqual(table.qux_index_index.data, [2, 4, 5])
self.assertListEqual(table.qux_index.data, [0, 0, 0, 0, 3])
self.assertListEqual(table.qux_index_index.data.data, [2, 4, 5])
self.assertListEqual(table.qux_index.data.data, [0, 0, 0, 0, 3])

def test_auto_multi_index_required(self):

Expand Down Expand Up @@ -675,7 +677,7 @@ class TestTable(DynamicTable):
]
]
self.assertListEqual(table['qux'][:], expected)
self.assertEqual(table.qux_index_index_index.data, [1, 2])
self.assertEqual(table.qux_index_index_index.data.data, [1, 2])

def test_auto_multi_index(self):

Expand Down Expand Up @@ -717,7 +719,7 @@ class TestTable(DynamicTable):
]
]
self.assertListEqual(table['qux'][:], expected)
self.assertEqual(table.qux_index_index_index.data, [1, 2])
self.assertEqual(table.qux_index_index_index.data.data, [1, 2])

def test_getitem_row_num(self):
table = self.with_spec()
Expand Down Expand Up @@ -1767,7 +1769,7 @@ def test_add_opt_column_after_data(self):
table = SubTable(name='subtable', description='subtable description')
table.add_row(col1='a', col3='c', col5='e', col7='g')
table.add_column(name='col2', description='column #2', data=('b', ))
self.assertTupleEqual(table.col2.data, ('b', ))
self.assertTupleEqual(table.col2.data.data, ('b', ))

def test_add_opt_ind_column_after_data(self):
"""Test that adding an optional, indexed column from __columns__ with data works."""
Expand All @@ -1784,8 +1786,8 @@ def test_add_row_opt_column(self):
self.assertTupleEqual(table.colnames, ('col1', 'col3', 'col5', 'col7', 'col2', 'col4'))
self.assertEqual(table.col2.description, 'optional column')
self.assertEqual(table.col4.description, 'optional, indexed column')
self.assertListEqual(table.col2.data, ['b', 'b2'])
# self.assertListEqual(table.col4.data, [('d1', 'd2'), ('d3', 'd4')]) # TODO this should work
self.assertListEqual(table.col2.data.data, ['b', 'b2'])
# self.assertListEqual(table.col4.data.data, [('d1', 'd2'), ('d3', 'd4')]) # TODO this should work

def test_add_row_opt_column_after_data(self):
"""Test that adding a row with an optional column after adding a row without the column raises an error."""
Expand Down Expand Up @@ -1981,14 +1983,14 @@ def test_add_row(self):
ed.add_row('b')
ed.add_row('a')
ed.add_row('c')
np.testing.assert_array_equal(ed.data, np.array([1, 0, 2], dtype=np.uint8))
np.testing.assert_array_equal(ed.data.data, np.array([1, 0, 2], dtype=np.uint8))

def test_add_row_index(self):
ed = EnumData(name='cv_data', description='a test EnumData', elements=['a', 'b', 'c'])
ed.add_row(1, index=True)
ed.add_row(0, index=True)
ed.add_row(2, index=True)
np.testing.assert_array_equal(ed.data, np.array([1, 0, 2], dtype=np.uint8))
np.testing.assert_array_equal(ed.data.data, np.array([1, 0, 2], dtype=np.uint8))


class TestIndexedEnumData(TestCase):
Expand Down Expand Up @@ -2456,12 +2458,12 @@ def test_init_empty(self):
self.assertEqual(foo_ind.name, 'foo_index')
self.assertEqual(foo_ind.description, "Index for VectorData 'foo'")
self.assertIs(foo_ind.target, foo)
self.assertListEqual(foo_ind.data, list())
self.assertListEqual(foo_ind.data.data, list())

def test_init_data(self):
foo = VectorData(name='foo', description='foo column', data=['a', 'b', 'c'])
foo_ind = VectorIndex(name='foo_index', target=foo, data=[2, 3])
self.assertListEqual(foo_ind.data, [2, 3])
self.assertListEqual(foo_ind.data.data, [2, 3])
self.assertListEqual(foo_ind[0], ['a', 'b'])
self.assertListEqual(foo_ind[1], ['c'])

Expand Down Expand Up @@ -2494,11 +2496,11 @@ def test_add_vector(self):

foo_ind_ind.add_vector([['c11', 'c12', 'c13'], ['c21', 'c22']])

self.assertListEqual(foo.data, ['a11', 'a12', 'a21', 'b11', 'c11', 'c12', 'c13', 'c21', 'c22'])
self.assertListEqual(foo_ind.data, [2, 3, 4, 7, 9])
self.assertListEqual(foo.data.data, ['a11', 'a12', 'a21', 'b11', 'c11', 'c12', 'c13', 'c21', 'c22'])
self.assertListEqual(foo_ind.data.data, [2, 3, 4, 7, 9])
self.assertListEqual(foo_ind[3], ['c11', 'c12', 'c13'])
self.assertListEqual(foo_ind[4], ['c21', 'c22'])
self.assertListEqual(foo_ind_ind.data, [2, 3, 5])
self.assertListEqual(foo_ind_ind.data.data, [2, 3, 5])
self.assertListEqual(foo_ind_ind[2], [['c11', 'c12', 'c13'], ['c21', 'c22']])


Expand Down Expand Up @@ -2867,6 +2869,7 @@ def set_up_list_index(self):
def test_array_inc_precision(self):
index = self.set_up_array_index()
index.add_vector(np.empty((255, )))
# breakpoint()
self.assertEqual(index.data[0], 255)
self.assertEqual(index.data.dtype, np.uint8)

Expand Down
Loading