diff --git a/src/lifeblood/core_nodes/wedge.py b/src/lifeblood/core_nodes/wedge.py index 3bc12a49..30170f52 100644 --- a/src/lifeblood/core_nodes/wedge.py +++ b/src/lifeblood/core_nodes/wedge.py @@ -46,15 +46,24 @@ def process_task(self, context) -> ProcessingResult: if wedges_count <= 0: return ProcessingResult() wedge_ranges = [] + attribute_names = set() # to check for duplication for i in range(wedges_count): wtype = context.param_value(f'wtype_{i}') + attr_name = context.param_value(f'attr_{i}') + attr_name = attr_name.strip() + if not attr_name: + raise ProcessingError('wedged attribute must not be empty.') + if attr_name in attribute_names: + raise ProcessingError(f'Each attribute must only appear once in the list. Attribute named "{attr_name}" is duplicated') + attribute_names.add(attr_name) + if wtype == 0: count = context.param_value(f'count_{i}') if count <= 0: raise ProcessingError('count cannot be less or equal to zero') - wedge_ranges.append((0, context.param_value(f'attr_{i}'), context.param_value(f'from_{i}'), context.param_value(f'to_{i}'), count)) + wedge_ranges.append((0, attr_name, context.param_value(f'from_{i}'), context.param_value(f'to_{i}'), count)) elif wtype == 1: - wedge_ranges.append((1, context.param_value(f'attr_{i}'), context.param_value(f'from_{i}'), context.param_value(f'max_{i}'), context.param_value(f'inc_{i}'))) + wedge_ranges.append((1, attr_name, context.param_value(f'from_{i}'), context.param_value(f'max_{i}'), context.param_value(f'inc_{i}'))) else: raise ProcessingError('bad wedge type') @@ -100,12 +109,16 @@ def _do_iter(cur_vals, level=0): if inc == 0: raise ProcessingError('increment cannot be zero') elif inc > 0: + if to < fr: + raise ProcessingError('max value is less than min, while inc is greater than zero') while fr <= to: new_vals = cur_vals.copy() new_vals[attr] = fr _do_iter(new_vals, level+1) fr += inc - else: + else: # inc < 0 + if to > fr: + raise ProcessingError('max value is greater than min, while inc is less than zero') while fr >= to: new_vals = cur_vals.copy() new_vals[attr] = fr diff --git a/tests/nodes/test_wedge.py b/tests/nodes/test_wedge.py index 52cedd89..a4442e13 100644 --- a/tests/nodes/test_wedge.py +++ b/tests/nodes/test_wedge.py @@ -148,3 +148,191 @@ async def _logic(sched: Scheduler, workers: List[Worker], done_waiter: Event, co await self._helper_test_node_with_arg_update( _logic ) + + async def test_incorrect_count(self): + """ + Ensure error is generated when input parameters are invalid + """ + async def _logic(sched: Scheduler, workers: List[Worker], done_waiter: Event, context: PseudoContext): + for mode, invalid_val, invalid_max, invalid_inc, valid in ( + (0, -10, 10, 1, False), + (0, -1, 10, 1, False), + (0, 0, 10, 1, False), + (0, 1, 10, 1, True), + (0, -10, -10, 1, False), + (0, -1, -10, 1, False), + (0, 0, -10, 1, False), + (0, 1, -10, 1, True), + (0, -10, 10, -1, False), + (0, -1, 10, -1, False), + (0, 0, 10, -1, False), + (0, 1, 10, -1, True), + (1, -10, 10, 1, True), + (1, -1, 10, 1, True), + (1, 0, 10, 1, True), + (1, 1, 10, 1, True), + (1, -10, -10, 1, False), + (1, -1, -10, 1, False), + (1, 0, -10, 1, False), + (1, 1, -10, 1, False), + (1, -10, 10, -1, False), + (1, -1, 10, -1, False), + (1, 0, 10, -1, False), + (1, 1, 10, -1, False), + (1, -10, -10, -1, True), + (1, -1, -10, -1, True), + (1, 0, -10, -1, True), + (1, 1, -10, -1, True), + ): + # invalid_max, invalid_inc should NOT affect anything + task = context.create_pseudo_task_with_attrs({}) + + node = context.create_node('wedge', 'footest') + node.set_param_value('wedge count', 2) + + node.set_param_value('wtype_0', mode) # by count or by inc + node.set_param_value('attr_0', 'foo') + node.set_param_value('from_0', 1) + node.set_param_value('to_0', 10) + node.set_param_value('count_0', invalid_val) + node.set_param_value('max_0', invalid_max) + node.set_param_value('inc_0', invalid_inc) + + # second one to ensure no empty wedge set is triggered + node.set_param_value('wtype_1', 0) # by count or by inc + node.set_param_value('attr_1', 'bar') + node.set_param_value('from_1', 1) + node.set_param_value('to_1', 2) + node.set_param_value('count_1', 2) + node.set_param_value('max_1', 2) + node.set_param_value('inc_1', 1) + + if valid: + context.process_task(node, task) # all good + else: + print(f'try mode = {mode}, count = {invalid_val}, max = {invalid_max}, inc = {invalid_inc}, should be valid = {valid}') + self.assertRaises(ProcessingError, context.process_task, node, task) + + await self._helper_test_node_with_arg_update( + _logic + ) + + async def test_attribute_names(self): + """ + Ensure that leading and trailing spaces are removed from attribute names + this decision is driven only to prevent user confusion as trailing spaces can easily be overlooked + """ + async def _logic(sched: Scheduler, workers: List[Worker], done_waiter: Event, context: PseudoContext): + task = context.create_pseudo_task_with_attrs({}) + + node = context.create_node('wedge', 'footest') + node.set_param_value('wedge count', 3) + + node.set_param_value('wtype_0', 0) # by count or by inc + node.set_param_value('attr_0', ' foo') + node.set_param_value('from_0', 1) + node.set_param_value('to_0', 10) + node.set_param_value('count_0', 10) + node.set_param_value('max_0', 10) + node.set_param_value('inc_0', 1) + + node.set_param_value('wtype_1', 0) # by count or by inc + node.set_param_value('attr_1', 'bar ') + node.set_param_value('from_1', 1) + node.set_param_value('to_1', 10) + node.set_param_value('count_1', 10) + node.set_param_value('max_1', 10) + node.set_param_value('inc_1', 1) + + node.set_param_value('wtype_2', 0) # by count or by inc + node.set_param_value('attr_2', ' cat ') + node.set_param_value('from_2', 1) + node.set_param_value('to_2', 10) + node.set_param_value('count_2', 10) + node.set_param_value('max_2', 10) + node.set_param_value('inc_2', 1) + + res = context.process_task(node, task) + + self.assertIsNotNone(res) + for attrdict in res._split_attribs: + self.assertSetEqual({'foo', 'bar', 'cat'}, set(attrdict.keys())) + + await self._helper_test_node_with_arg_update( + _logic + ) + + async def test_empty_attribute_names(self): + """ + Ensure that leading and trailing spaces are removed from attribute names + this decision is driven only to prevent user confusion as trailing spaces can easily be overlooked + """ + async def _logic(sched: Scheduler, workers: List[Worker], done_waiter: Event, context: PseudoContext): + for attrs in [[''], ['foo', ''], [' '], [' ', 'fqfq'], ['qqwe', ' ', 'gggrg']]: + task = context.create_pseudo_task_with_attrs({}) + + node = context.create_node('wedge', 'footest') + node.set_param_value('wedge count', len(attrs)) + + for i, attr_name in enumerate(attrs): + node.set_param_value(f'wtype_{i}', 0) # by count or by inc + node.set_param_value(f'attr_{i}', attr_name) + node.set_param_value(f'from_{i}', 1) + node.set_param_value(f'to_{i}', 3) + node.set_param_value(f'count_{i}', 3) + node.set_param_value(f'max_{i}', 3) + node.set_param_value(f'inc_{i}', 1) + + self.assertRaises(ProcessingError, context.process_task, node, task) + + await self._helper_test_node_with_arg_update( + _logic + ) + + async def test_attribute_names_duplications(self): + """ + Ensure error is generated when one attribute name is duplicated in the list + """ + async def _logic(sched: Scheduler, workers: List[Worker], done_waiter: Event, context: PseudoContext): + task = context.create_pseudo_task_with_attrs({}) + + node = context.create_node('wedge', 'footest') + node.set_param_value('wedge count', 4) + + node.set_param_value('wtype_0', 0) # by count or by inc + node.set_param_value('attr_0', 'name one') + node.set_param_value('from_0', 1) + node.set_param_value('to_0', 3) + node.set_param_value('count_0', 3) + node.set_param_value('max_0', 3) + node.set_param_value('inc_0', 1) + + node.set_param_value('wtype_1', 0) # by count or by inc + node.set_param_value('attr_1', 'two') + node.set_param_value('from_1', 1) + node.set_param_value('to_1', 3) + node.set_param_value('count_1', 3) + node.set_param_value('max_1', 3) + node.set_param_value('inc_1', 1) + + node.set_param_value('wtype_2', 0) # by count or by inc + node.set_param_value('attr_2', ' fooooo') + node.set_param_value('from_2', 1) + node.set_param_value('to_2', 3) + node.set_param_value('count_2', 3) + node.set_param_value('max_2', 3) + node.set_param_value('inc_2', 1) + + node.set_param_value('wtype_3', 0) # by count or by inc + node.set_param_value('attr_3', 'two') + node.set_param_value('from_3', 1) + node.set_param_value('to_3', 3) + node.set_param_value('count_3', 3) + node.set_param_value('max_3', 3) + node.set_param_value('inc_3', 1) + + self.assertRaises(ProcessingError, context.process_task, node, task) + + await self._helper_test_node_with_arg_update( + _logic + )