From 9babbc6f6f7d640a7bf9a2e9aae6147977e14a8d Mon Sep 17 00:00:00 2001 From: Corran Webster Date: Wed, 4 Dec 2024 09:09:33 +0000 Subject: [PATCH] Fix `Count` and `Range` `__getitem__` methods. (#76) These were wrong for slices and negative indices. These are now fixed and tested. Fixes #75 --- src/tempe/data_view.py | 14 +++++---- tests/tempe/test_data_views.py | 54 +++++++++++++++++++++++++++++++++- 2 files changed, 61 insertions(+), 7 deletions(-) diff --git a/src/tempe/data_view.py b/src/tempe/data_view.py index 8eaef2a..cc665eb 100644 --- a/src/tempe/data_view.py +++ b/src/tempe/data_view.py @@ -278,8 +278,8 @@ def __len__(self): def __getitem__(self, index): if isinstance(index, slice): return [ - self.start + self.step * index - for index in range(slice.start, slice.stop, slice, step) + self.start + self.step * i + for i in range(index.start, index.stop, index.step) ] else: return self.start + self.step * index @@ -305,11 +305,13 @@ def __len__(self): def __getitem__(self, index): if isinstance(index, slice): return [ - self.start + self.step * self.index - for index in range(slice.start, slice.stop, slice, step) + self.start + self.step * i + for i in range(index.start, index.stop, index.step) ] - else: - return self.start + self.step * self.index + elif index < 0: + index = len(self) + index + + return self.start + self.step * index class Slice(DataView): diff --git a/tests/tempe/test_data_views.py b/tests/tempe/test_data_views.py index 5470461..11a5400 100644 --- a/tests/tempe/test_data_views.py +++ b/tests/tempe/test_data_views.py @@ -4,7 +4,7 @@ import unittest -from tempe.data_view import Count +from tempe.data_view import Count, Range class TestCount(unittest.TestCase): @@ -27,6 +27,15 @@ def test_count_getitem(self): self.assertEqual(count[0], 10) self.assertEqual(count[1], 15) self.assertEqual(count[2], 20) + # consistent, so leave it as behaviour + self.assertEqual(count[-1], 5) + + def test_count_getitem_slice(self): + """Test count getitem works with slices.""" + + count = Count(10, 5) + + self.assertEqual(count[2:5:2], [20, 30]) def test_count_default(self): """Test count default starts and steps as expected.""" @@ -48,6 +57,49 @@ def test_count_default_getitem(self): self.assertEqual(count[2], 2) +class TestRange(unittest.TestCase): + + def test_range(self): + """Test range starts and steps as expected.""" + + r = Range(10, 25, 5) + + self.assertEqual(list(r), [10, 15, 20]) + + def test_range_len(self): + """Test range length is correct.""" + + r = Range(10, 25, 5) + + self.assertEqual(len(r), 3) + + def test_range_getitem(self): + """Test range getitem works as expected.""" + + r = Range(10, 25, 5) + + self.assertEqual(r[0], 10) + self.assertEqual(r[1], 15) + self.assertEqual(r[2], 20) + self.assertEqual(r[-1], 20) + + def test_range_getitem_slice(self): + """Test range getitem works with slices.""" + + r = Range(10, 50, 5) + + self.assertEqual(r[2:5:2], [20, 30]) + + def test_range_defaults(self): + """Test range default starts and steps as expected.""" + + r1 = Range(10) + r2 = Range(10, 15) + + self.assertEqual(list(r1), list(range(10))) + self.assertEqual(list(r2), list(range(10, 15))) + + if __name__ == "__main__": result = unittest.main() if not result.wasSuccessful():