diff --git a/src/lifeblood/core_nodes/attribute_splitter.py b/src/lifeblood/core_nodes/attribute_splitter.py
index ecabb495..03800d4e 100644
--- a/src/lifeblood/core_nodes/attribute_splitter.py
+++ b/src/lifeblood/core_nodes/attribute_splitter.py
@@ -83,12 +83,17 @@ def _ranges_from_chunksize(start, end, inc, chunk_size):
             raise ProcessingError('chunk size must be positive')
         ranges = []
         cur = start
-        while cur < end:
-            count_till_end = math.ceil((end - cur) / inc)
+        while cur <= end:
+            count_till_end = math.ceil((end - cur + 1) / inc)
             if count_till_end == 0:
                 break
-            ranges.append((cur, min(end, cur + inc * min(count_till_end, chunk_size)), min(count_till_end, chunk_size)))
-            cur = ranges[-1][1]
+            print(count_till_end)
+            ranges.append((cur, cur + inc * (min(count_till_end, chunk_size) - 1), min(count_till_end, chunk_size)))
+            cur = ranges[-1][1] + inc
+
+        # just set "end" for the last range
+        if ranges:
+            ranges[-1] = tuple(end if i == 1 else x for i, x in enumerate(ranges[-1]))
         return ranges

     @staticmethod
@@ -97,13 +102,19 @@ def _ranges_from_chunkcount(start, end, inc, chunk_count):
             raise ProcessingError('chunk count must be positive')
         ranges = []
         print(f'- {start}, {end} :{inc}')
+        chunk_count = min(chunk_count, int((end-start)/inc)+1)
         for i in range(chunk_count):
             e1 = start + math.ceil(((end - start) / chunk_count * i) / inc) * inc
-            e2 = min(end, start + math.ceil(((end - start) / chunk_count * (i + 1)) / inc) * inc)
-            print(e1, e2)
-            if e1 >= e2:  # case where chunk_count is bigger than possible
-                continue
-            new_range = (e1, e2, math.ceil((e2-e1)/inc))
+            e2 = start + math.ceil(((end - start) / chunk_count * (i + 1)) / inc) * inc
+            # if e1 >= e2:  # case where chunk_count is bigger than possible
+            #     continue
+            adj = 0
+            if i == chunk_count - 1:
+                e2 = end
+            else:
+                adj = -inc
+            print(e1, e2 + adj, math.ceil((e2-e1)/inc))
+            new_range = (e1, e2 + adj, math.ceil((e2-e1)/inc))
             ranges.append(new_range)
         return ranges

@@ -120,8 +131,13 @@ def process_task(self, context):
                 raise ProcessingError(f'attribute "{attr_name}" must be a list')
             res = ProcessingResult()
             chunksize = context.param_value('chunk size')
+            if chunksize <= 0:
+                raise ProcessingError(f'chunk size cannot be less or equal to zero, got: {chunksize}')
-            split_into = 1 + (len(attr_value) - 1) // chunksize
+            if len(attr_value) <= 1:
+                split_into = 1
+            else:
+                split_into = 1 + (len(attr_value) - 1) // chunksize
             # yes we can split into 1 part. this should behave the same way as when splitting into multiple parts
             # in order to have consistent behaviour in the graph
             res.split_task(split_into)

@@ -170,30 +186,3 @@ def process_task(self, context):
             return res
         else:
             raise NotImplementedError(f'mode "{mode}" is not implemented')
-
-
-if __name__ == '__main__':
-    # some fast tests
-    r = FramerangeSplitter._ranges_from_chunkcount(10, 12, 3, 10)
-    assert r == [(10, 12, 1)], r
-    r = FramerangeSplitter._ranges_from_chunkcount(10, 15, 3, 10)
-    assert r == [(10, 13, 1), (13, 15, 1)], r
-    r = FramerangeSplitter._ranges_from_chunkcount(15, 16, 3, 10)
-    assert r == [(15, 16, 1)], r
-    r = FramerangeSplitter._ranges_from_chunkcount(15, 29, 3, 3)
-    assert r == [(15, 21, 2), (21, 27, 2), (27, 29, 1)], r
-    r = FramerangeSplitter._ranges_from_chunkcount(15, 29, 3, 4)
-    assert r == [(15, 21, 2), (21, 24, 1), (24, 27, 1), (27, 29, 1)], r
-
-    r = FramerangeSplitter._ranges_from_chunksize(10, 12, 3, 1)
-    assert r == [(10, 12, 1)], r
-    r = FramerangeSplitter._ranges_from_chunksize(10, 12, 3, 10)
-    assert r == [(10, 12, 1)], r
-    r = FramerangeSplitter._ranges_from_chunksize(10, 15, 3, 1)
-    assert r == [(10, 13, 1), (13, 15, 1)], r
-    r = FramerangeSplitter._ranges_from_chunksize(10, 15, 3, 10)
-    assert r == [(10, 15, 2)], r
-    r = FramerangeSplitter._ranges_from_chunksize(15, 16, 3, 10)
-    assert r == [(15, 16, 1)], r
-    r = FramerangeSplitter._ranges_from_chunksize(15, 38, 3, 3)
-    assert r == [(15, 24, 3), (24, 33, 3), (33, 38, 2)], r
diff --git a/src/lifeblood/core_nodes/split_waiter.py b/src/lifeblood/core_nodes/split_waiter.py
index 6f682600..356dd5be 100644
--- a/src/lifeblood/core_nodes/split_waiter.py
+++ b/src/lifeblood/core_nodes/split_waiter.py
@@ -59,6 +59,61 @@ def __init__(self, name: str):
             ui.add_parameter('sort_by', None, NodeParameterType.STRING, '_builtin_id')
             ui.add_parameter('reversed', 'reversed', NodeParameterType.BOOL, False)

+    def __get_promote_attribs(self, context):
+        attribs_to_promote = {}
+        split_id = context.task_field('split_id')
+        num_attribs = context.param_value('transfer_attribs')
+        for i in range(num_attribs):
+            src_attr_name = context.param_value(f'src_attr_name_{i}')
+            transfer_type = context.param_value(f'transfer_type_{i}')
+            dst_attr_name = context.param_value(f'dst_attr_name_{i}')
+            sort_attr_name = context.param_value(f'sort_by_{i}')
+            sort_reversed = context.param_value(f'reversed_{i}')
+            if transfer_type == 'append':
+                gathered_values = []
+                for attribs in sorted(self.__cache[split_id]['arrived'].values(), key=lambda x: x.get(sort_attr_name, 0), reverse=sort_reversed):
+                    if src_attr_name not in attribs:
+                        continue
+
+                    attr_val = attribs[src_attr_name]
+                    gathered_values.append(attr_val)
+                attribs_to_promote[dst_attr_name] = gathered_values
+            elif transfer_type == 'extend':
+                gathered_values = []
+                for attribs in sorted(self.__cache[split_id]['arrived'].values(), key=lambda x: x.get(sort_attr_name, 0), reverse=sort_reversed):
+                    if src_attr_name not in attribs:
+                        continue
+
+                    attr_val = attribs[src_attr_name]
+                    if isinstance(attr_val, list):
+                        gathered_values.extend(attr_val)
+                    else:
+                        gathered_values.append(attr_val)
+                attribs_to_promote[dst_attr_name] = gathered_values
+            elif transfer_type == 'first':
+                _acd = self.__cache[split_id]['arrived']
+                if len(_acd) > 0:
+                    if sort_reversed:
+                        attribs = max(_acd.values(), key=lambda x: x.get(sort_attr_name, 0))
+                    else:
+                        attribs = min(_acd.values(), key=lambda x: x.get(sort_attr_name, 0))
+                    if src_attr_name in attribs:
+                        attribs_to_promote[dst_attr_name] = attribs[src_attr_name]
+            elif transfer_type == 'sum':
+                # we don't care about the order, assume sum is associative
+                gathered_values = None
+                for attribs in self.__cache[split_id]['arrived'].values():
+                    if src_attr_name not in attribs:
+                        continue
+                    if gathered_values is None:
+                        gathered_values = attribs[src_attr_name]
+                    else:
+                        gathered_values += attribs[src_attr_name]
+                attribs_to_promote[dst_attr_name] = gathered_values
+            else:
+                raise NotImplementedError(f'transfer type "{transfer_type}" is not implemented')
+        return attribs_to_promote
+
     def ready_to_process_task(self, task_dict) -> bool:
         context = ProcessingContext(self, task_dict)
         split_id = context.task_field('split_id')
@@ -96,59 +151,9 @@ def process_task(self, context) -> ProcessingResult: #TODO: not finished, attrib
             res = ProcessingResult()
             res.kill_task()
             self.__cache[split_id]['processed'].add(context.task_field('split_element'))
-            attribs_to_promote = {}
             if self.__cache[split_id]['first_to_arrive'] == task_id:
                 # transfer attributes
                 # TODO: delete cache for already processed splits
-                num_attribs = context.param_value('transfer_attribs')
-                for i in range(num_attribs):
-                    src_attr_name = context.param_value(f'src_attr_name_{i}')
-                    transfer_type = context.param_value(f'transfer_type_{i}')
-                    dst_attr_name = context.param_value(f'dst_attr_name_{i}')
-                    sort_attr_name = context.param_value(f'sort_by_{i}')
-                    sort_reversed = context.param_value(f'reversed_{i}')
-                    if transfer_type == 'append':
-                        gathered_values = []
-                        for attribs in sorted(self.__cache[split_id]['arrived'].values(), key=lambda x: x.get(sort_attr_name, 0), reverse=sort_reversed):
-                            if src_attr_name not in attribs:
-                                continue
-
-                            attr_val = attribs[src_attr_name]
-                            gathered_values.append(attr_val)
-                        attribs_to_promote[dst_attr_name] = gathered_values
-                    elif transfer_type == 'extend':
-                        gathered_values = []
-                        for attribs in sorted(self.__cache[split_id]['arrived'].values(), key=lambda x: x.get(sort_attr_name, 0), reverse=sort_reversed):
-                            if src_attr_name not in attribs:
-                                continue
-
-                            attr_val = attribs[src_attr_name]
-                            if isinstance(attr_val, list):
-                                gathered_values.extend(attr_val)
-                            else:
-                                gathered_values.append(attr_val)
-                        attribs_to_promote[dst_attr_name] = gathered_values
-                    elif transfer_type == 'first':
-                        _acd = self.__cache[split_id]['arrived']
-                        if len(_acd) > 0:
-                            if sort_reversed:
-                                attribs = max(_acd.values(), key=lambda x: x.get(sort_attr_name, 0))
-                            else:
-                                attribs = min(_acd.values(), key=lambda x: x.get(sort_attr_name, 0))
-                            if src_attr_name in attribs:
-                                attribs_to_promote[dst_attr_name] = attribs[src_attr_name]
-                    elif transfer_type == 'sum':
-                        # we don't care about the order, assume sum is associative
-                        gathered_values = None
-                        for attribs in self.__cache[split_id]['arrived'].values():
-                            if src_attr_name not in attribs:
-                                continue
-                            if gathered_values is None:
-                                gathered_values = attribs[src_attr_name]
-                            else:
-                                gathered_values += attribs[src_attr_name]
-                        attribs_to_promote[dst_attr_name] = gathered_values
-                    else:
-                        raise NotImplementedError(f'transfer type "{transfer_type}" is not implemented')
+                attribs_to_promote = self.__get_promote_attribs(context)
                 res.remove_split(attributes_to_set=attribs_to_promote)
                 changed = True
@@ -176,6 +181,13 @@ def process_task(self, context) -> ProcessingResult: #TODO: not finished, attrib
     def postprocess_task(self, context) -> ProcessingResult:
         return ProcessingResult()

+    def _debug_has_internal_data_for_split(self, split_id: int) -> bool:
+        """
+        method for debug/testing purposes.
+        returns True if some internal data is present for a given split_id
+        """
+        return split_id in self.__cache
+
     def __getstate__(self):
         d = super(SplitAwaiterNode, self).__getstate__()
         assert '_SplitAwaiterNode__main_lock' in d
diff --git a/src/lifeblood/core_nodes/wedge.py b/src/lifeblood/core_nodes/wedge.py
index 12c9b6c0..2fb3199c 100644
--- a/src/lifeblood/core_nodes/wedge.py
+++ b/src/lifeblood/core_nodes/wedge.py
@@ -49,7 +49,10 @@ def process_task(self, context) -> ProcessingResult:
         for i in range(wedges_count):
             wtype = context.param_value(f'wtype_{i}')
             if wtype == 0:
-                wedge_ranges.append((0, context.param_value(f'attr_{i}'), context.param_value(f'from_{i}'), context.param_value(f'to_{i}'), context.param_value(f'count_{i}')))
+                count = context.param_value(f'count_{i}')
+                if count <= 0:
+                    raise ProcessingError('count cannot be less or equal to zero')
+                wedge_ranges.append((0, context.param_value(f'attr_{i}'), context.param_value(f'from_{i}'), context.param_value(f'to_{i}'), count))
             elif wtype == 1:
                 wedge_ranges.append((1, context.param_value(f'attr_{i}'), context.param_value(f'from_{i}'), context.param_value(f'max_{i}'), context.param_value(f'inc_{i}')))
             else:
@@ -63,13 +66,37 @@ def _do_iter(cur_vals, level=0):
                 return
             if wedge_ranges[level][0] == 0:
                 _, attr, fr, to, cnt = wedge_ranges[level]
+                # first check if we actually work with ints
+                if int(fr) == fr:
+                    fr = int(fr)
+                if int(to) == to:
+                    to = int(to)
+                inc = None
+                if cnt > 1:
+                    finc = (to - fr) / (cnt - 1)
+                    if int(finc) == finc:
+                        inc = int(finc)
+
+                # now create variations
                 for i in range(cnt):
                     new_vals = cur_vals.copy()
-                    t = i * 1.0 / (cnt-1)
-                    new_vals[attr] = fr*(1-t) + to*t
+                    if inc is None:
+                        t = i * 1.0 / (cnt-1)
+                        new_vals[attr] = fr*(1-t) + to*t
+                    else:
+                        new_vals[attr] = fr + i*inc
                     _do_iter(new_vals, level+1)
             elif wedge_ranges[level][0] == 1:
                 _, attr, fr, to, inc = wedge_ranges[level]
+                # first check if we actually work with ints
+                if int(inc) == inc:
+                    inc = int(inc)
+                if int(fr) == fr:
+                    fr = int(fr)
+                if int(to) == to:
+                    to = int(to)
+
+                # now create variations
                 if inc == 0:
                     raise ProcessingError('increment cannot be zero')
                 elif inc > 0:
@@ -88,6 +115,8 @@ def _do_iter(cur_vals, level=0):
         _do_iter({})

         res = ProcessingResult()
+        if len(all_wedges) == 0:
+            raise ProcessingError('unexpectedly no wedges were created')
         res.split_task(len(all_wedges))
         for i, attrs in enumerate(all_wedges):
             res.set_split_task_attribs(i, attrs)
diff --git a/tests/nodes/common.py b/tests/nodes/common.py
index d8705045..9bc816b9 100644
--- a/tests/nodes/common.py
+++ b/tests/nodes/common.py
@@ -41,7 +41,7 @@ def children_ids_for(self, task_id, active_only=False) -> List[int]:


 class PseudoTask:
-    def __init__(self, pool: PseudoTaskPool, task_id: int, attrs: dict, parent_id: Optional[int] = None, state: Optional[TaskState] = None):
+    def __init__(self, pool: PseudoTaskPool, task_id: int, attrs: dict, parent_id: Optional[int] = None, state: Optional[TaskState] = None, extra_fields: Optional[dict] = None):
         self.__id = task_id
         self.__parent_id = parent_id
         self.__state = state or TaskState.GENERATING
@@ -52,7 +52,8 @@ def __init__(self, pool: PseudoTaskPool, task_id: int, attrs: dict, parent_id: O
             'node_input_name': self.__input_name,
             'state': self.__state.value,
             'parent_id': parent_id,
-            'attributes': json.dumps(attrs)
+            'attributes': json.dumps(attrs),
+            **(extra_fields or {})
         }

     def id(self) -> int:
@@ -104,16 +105,32 @@ def __init__(self, scheduler: Scheduler):
         self.__scheduler = scheduler
         self.__last_node_id = 135  # cuz why starting with zero?
         self.__last_task_id = 468
+        self.__last_split_id = 765
         self.__tasks: Dict[int, PseudoTask] = {}

-    def create_pseudo_task_with_attrs(self, attrs: dict, task_id: Optional[int] = None, parent_id: Optional[int] = None, state: Optional[TaskState] = None) -> PseudoTask:
+    def create_pseudo_task_with_attrs(self, attrs: dict, task_id: Optional[int] = None, parent_id: Optional[int] = None, state: Optional[TaskState] = None, extra_fields: Optional[dict] = None) -> PseudoTask:
         if task_id is None:
             self.__last_task_id += 1
             task_id = self.__last_task_id
-        task = PseudoTask(self, task_id, attrs, parent_id, state)
+        task = PseudoTask(self, task_id, attrs, parent_id, state, extra_fields)
         self.__tasks[task_id] = task
         return task

+    def create_pseudo_split_of(self, pseudo_task: PseudoTask, attribs_list):
+        """
+        split count will be taken from attrib_list's len
+        """
+        tasks = []
+        self.__last_split_id += 1
+        for i, attribs in enumerate(attribs_list):
+            tasks.append(self.create_pseudo_task_with_attrs(attribs, task_id=None, parent_id=None, state=TaskState.GENERATING, extra_fields={
+                'split_id': self.__last_split_id,
+                'split_count': len(attribs_list),
+                'split_origin_task_id': pseudo_task.id(),
+                'split_element': i,
+            }))
+        return tasks
+
     def create_node(self, node_type: str, node_name: str) -> BaseNode:
         self.__last_node_id += 1
         return create_node(node_type, node_name, self.__scheduler, self.__last_node_id)
diff --git a/tests/nodes/test_attribute_splitter.py b/tests/nodes/test_attribute_splitter.py
new file mode 100644
index 00000000..263c0586
--- /dev/null
+++ b/tests/nodes/test_attribute_splitter.py
@@ -0,0 +1,192 @@
+import random
+from asyncio import Event
+from lifeblood.scheduler import Scheduler
+from lifeblood.worker import Worker
+from lifeblood.basenode import BaseNode, ProcessingError
+from lifeblood.exceptions import NodeNotReadyToProcess
+from .common import TestCaseBase, PseudoContext
+
+from typing import List
+
+
+class TestWedge(TestCaseBase):
+    async def test_basic_functions_list(self):
+        exp_pairs_list = [
+            (
+                ('foofattr', 4, []),
+                ([],)
+            ),
+            (
+                ('foofattr', 4, [5]),
+                ([5],)
+            ),
+            (
+                ('foofattr', 4, [1, 3, 5, 7]),
+                ([1, 3, 5, 7],)
+            ),
+            (
+                ('foofattr', 4, [1, 3, 5, 7, 9, 2, 4, 6, 8, 0]),
+                ([1, 3, 5, 7], [9, 2, 4, 6], [8, 0])
+            ),
+            (
+                ('foofattr', 4, [1, 3, 5, 7, 9, 2, 4, 6, 8]),
+                ([1, 3, 5, 7], [9, 2, 4, 6], [8])
+            ),
+            (
+                ('foofattr', 4, [1, 3, 5, 7, 9, 2, 4,]),
+                ([1, 3, 5, 7], [9, 2, 4])
+            ),
+            (
+                ('foofattr', 0, [1, 3, 5, 7, 9, 2, 4, ]),
+                None
+            ),
+            (
+                ('foofattr', -1, [1, 3, 5, 7, 9, 2, 4, ]),
+                None
+            ),
+        ]
+
+        async def _logic(sched: Scheduler, workers: List[Worker], done_waiter: Event, context: PseudoContext):
+            # rng = random.Random(9377151)
+
+            for exp_inputs, exp_results in exp_pairs_list:
+                attr_name, chunk_size, attr_val = exp_inputs
+
+                task = context.create_pseudo_task_with_attrs({attr_name: attr_val})
+                node = context.create_node('attribute_splitter', 'footest')
+                node.set_param_value('split type', 'list')
+                node.set_param_value('attribute name', attr_name)
+                node.set_param_value('chunk size', chunk_size)
+
+                if exp_results is None:  # expect error
+                    self.assertRaises(ProcessingError, context.process_task, node, task)
+                    continue
+
+                res = context.process_task(node, task)
+                self.assertIsNotNone(res)
+
+                self.assertEqual(len(exp_results), len(res._split_attribs), f'len mismatch! expected {exp_results}, actual {res._split_attribs}')
+                for exp, attr in zip(exp_results, res._split_attribs):
+                    self.assertIn(attr_name, attr)
+                    self.assertEqual(exp, attr[attr_name])
+
+        await self._helper_test_node_with_arg_update(
+            _logic
+        )
+
+    async def test_basic_functions_range(self):
+        exp_pairs_list = [
+            (
+                # chsz chcnt st end inc type spl-by mode
+                ('astart', 'aend', 'asize', 4, 0, 1, 10, 1, 0, 0, 0),
+                ((1, 4, None), (5, 8, None), (9, 10, None))
+            ),
+            (
+                # chsz chcnt st end inc type spl-by mode
+                ('astart', 'aend', 'asize', 3, 0, 10, 12, 1, 0, 0, 2),
+                ((10, 12, 3), )
+            ),
+            (
+                # chsz chcnt st end inc type spl-by mode
+                ('astart', 'aend', 'asize', 3, 0, 10, 13, 1, 0, 0, 2),
+                ((10, 12, 3), (13, 13, 1))
+            ),
+            (
+                # chsz chcnt st end inc type spl-by mode
+                ('astart', 'aend', 'asize', 3, 0, 10, 12, 10, 0, 0, 2),
+                ((10, 12, 1), )
+            ),
+            (
+                # chsz chcnt st end inc type spl-by mode
+                ('astart', 'aend', 'asize', 3, 0, 10, 15, 1, 0, 0, 2),
+                ((10, 12, 3), (13, 15, 3))
+            ),
+            (
+                # chsz chcnt st end inc type spl-by mode
+                ('astart', 'aend', 'asize', 3, 0, 10, 15, 10, 0, 0, 2),
+                ((10, 15, 1),)
+            ),
+            (
+                # chsz chcnt st end inc type spl-by mode
+                ('astart', 'aend', 'asize', 3, 0, 10, 16, 10, 0, 0, 2),
+                ((10, 16, 1), )
+            ),
+            (
+                # chsz chcnt st end inc type spl-by mode
+                ('astart', 'aend', 'asize', 3, 0, 15, 38, 3, 0, 0, 2),
+                ((15, 21, 3), (24, 30, 3), (33, 38, 2))
+            ),
+            #
+            # count
+            (
+                # chsz chcnt st end inc type spl-by mode
+                ('astart', 'aend', 'asize', 0, 10, 10, 12, 3, 0, 1, 0),
+                ((10, 12, None),)
+            ),
+            (
+                # chsz chcnt st end inc type spl-by mode
+                ('astart', 'aend', 'asize', 0, 10, 10, 15, 3, 0, 1, 0),
+                ((10, 10, None), (13, 15, None))
+            ),
+            (
+                # chsz chcnt st end inc type spl-by mode
+                ('astart', 'aend', 'asize', 0, 10, 15, 16, 3, 0, 1, 1),
+                ((15, None, 1),)
+            ),
+            (
+                # chsz chcnt st end inc type spl-by mode
+                ('astart', 'aend', 'asize', 0, 3, 15, 29, 3, 0, 1, 2),
+                ((15, 18, 2), (21, 24, 2), (27, 29, 1))
+            ),
+            (
+                # chsz chcnt st end inc type spl-by mode
+                ('astart', 'aend', 'asize', 0, 4, 15, 29, 3, 0, 1, 2),
+                ((15, 18, 2), (21, 21, 1), (24, 24, 1), (27, 29, 1))
+            ),
+        ]
+
+        async def _logic(sched: Scheduler, workers: List[Worker], done_waiter: Event, context: PseudoContext):
+            # rng = random.Random(9377151)
+
+            for exp_inputs, exp_results in exp_pairs_list:
+                start_name, end_name, size_name, chunk_size, chunk_count, start, end, inc, rtype, split_by, out_mode = exp_inputs
+                print(f'{start_name},{end_name},{size_name}: chunksize={chunk_size}, chunk_count={chunk_count}, {start}/{end}/{inc}, rtype={rtype}, split_by={split_by}, out_mode={out_mode}')
+
+                task = context.create_pseudo_task_with_attrs({})
+                node = context.create_node('attribute_splitter', 'footest')
+                node.set_param_value('split type', 'range')
+                node.set_param_value('out start name', start_name)
+                node.set_param_value('out end name', end_name)
+                node.set_param_value('out size name', size_name)
+                node.set_param_value('range chunk', chunk_size)
+                node.set_param_value('range count', chunk_count)
+                node.set_param_value('range start', start)
+                node.set_param_value('range end', end)
+                node.set_param_value('range inc', inc)
+                node.set_param_value('out range type', rtype)
+                node.set_param_value('range split by', split_by)
+                node.set_param_value('range mode', out_mode)
+
+                if exp_results is None:  # expect error
+                    self.assertRaises(ProcessingError, context.process_task, node, task)
+                    continue
+
+                res = context.process_task(node, task)
+                self.assertIsNotNone(res)
+
+                self.assertEqual(len(exp_results), len(res._split_attribs), f'len mismatch! expected {exp_results}, actual {res._split_attribs}')
+                print(res._split_attribs)
+                for (e_start, e_end, e_size), attr in zip(exp_results, res._split_attribs):
+                    self.assertEqual(e_start, attr[start_name])
+                    if e_end is not None:
+                        self.assertEqual(e_end, attr[end_name])
+                    else:
+                        self.assertNotIn(end_name, attr)
+                    if e_size is not None:
+                        self.assertEqual(e_size, attr[size_name])
+                    else:
+                        self.assertNotIn(size_name, attr)
+
+        await self._helper_test_node_with_arg_update(
+            _logic
+        )
diff --git a/tests/nodes/test_split_waiter.py b/tests/nodes/test_split_waiter.py
new file mode 100644
index 00000000..7e5ec194
--- /dev/null
+++ b/tests/nodes/test_split_waiter.py
@@ -0,0 +1,177 @@
+import random
+from asyncio import Event
+from lifeblood.scheduler import Scheduler
+from lifeblood.worker import Worker
+from lifeblood.basenode import BaseNode
+from lifeblood.exceptions import NodeNotReadyToProcess
+from .common import TestCaseBase, PseudoContext
+
+from typing import List
+
+
+class TestSplitWaiter(TestCaseBase):
+    async def test_basic_functions_wait(self):
+        await self._helper_test_basic_functions(0)
+
+    async def test_basic_functions_wait_with_garbage(self):
+        await self._helper_test_basic_functions(20)
+
+    async def test_basic_functions_nowait(self):
+        await self._helper_test_basic_functions(0, False)
+
+    async def test_basic_functions_nowait_with_garbage(self):
+        await self._helper_test_basic_functions(20, False)
+
+    async def _helper_test_basic_functions_wait(self, garbage_split_count):
+        async def _logic(sched: Scheduler, workers: List[Worker], done_waiter: Event, context: PseudoContext):
+            rng = random.Random(716394)
+            main_task = context.create_pseudo_task_with_attrs({'boo': 32})
+            tsplits = context.create_pseudo_split_of(main_task, [
+                {'foo': 5},
+                {'foo': 8},
+                {'foo': 1},
+                {'foo': 3},
+            ])
+            garbage_tasks = []
+            garbage_splits = []
+            for i in range(garbage_split_count):
+                gtask = context.create_pseudo_task_with_attrs({'boo': rng.randint(-111, 111)})
+                garbage_tasks.append(gtask)
+                garbage_splits += context.create_pseudo_split_of(gtask, [{} for _ in range(rng.randint(1, 12))])
+            garbage_tasks_all = garbage_tasks + garbage_splits
+            rng.shuffle(garbage_tasks_all)
+            def _process_garbage_tasks(node):
+                for _ in range(rng.randint(0, len(garbage_tasks_all) // 2)):
+                    context.process_task(node, rng.choice(garbage_tasks_all))
+
+            for _ in range(100):
+                node = context.create_node('split_waiter', 'footest')
+                node.set_param_value('wait for all', True)
+                node.set_param_value('transfer_attribs', 1)
+                node.set_param_value('transfer_type_0', 'extend')
+                node.set_param_value('src_attr_name_0', 'foo')
+                node.set_param_value('dst_attr_name_0', 'bar')
+                node.set_param_value('sort_by_0', '_builtin_id')
+
+                _process_garbage_tasks(node)
+
+                rng.shuffle(tsplits)
+                for task in tsplits[:-1]:
+                    self.assertIsNone(context.process_task(node, task))
+                    _process_garbage_tasks(node)
+
+                last_res = context.process_task(node, tsplits[-1])
+                self.assertIsNotNone(last_res)
+
+                _process_garbage_tasks(node)
+
+                ress = []
+                for task in tsplits[:-1]:
+                    res = context.process_task(node, task)
+                    self.assertIsNotNone(res)
+                    ress.append(res)
+                    _process_garbage_tasks(node)
+
+                ress.append(last_res)
+
+                # check attrs
+                self.assertSetEqual({'boo'}, set(main_task.attributes().keys()))
+                self.assertEqual(32, main_task.attributes()['boo'])
+                self.assertListEqual([5, 8, 1, 3], ress[0].split_attributes_to_set['bar'])
+                self.assertTrue(ress[0].do_split_remove)
+                for res in ress[1:]:
+                    self.assertFalse(res.do_split_remove)
+                for res in ress:
+                    self.assertTrue(res.do_kill_task)
+
+                # check cleanup
+                self.assertFalse(node._debug_has_internal_data_for_split(tsplits[0].task_dict()['split_id']))
+
+        await self._helper_test_node_with_arg_update(
+            _logic
+        )
+
+    async def _helper_test_basic_functions(self, garbage_split_count, do_wait=True):
+        async def _logic(sched: Scheduler, workers: List[Worker], done_waiter: Event, context: PseudoContext):
+            rng = random.Random(716394)
+            main_task = context.create_pseudo_task_with_attrs({'boo': 32})
+            tsplits = context.create_pseudo_split_of(main_task, [
+                {'foo': 5},
+                {'foo': 8},
+                {'foo': 1},
+                {'foo': 3},
+            ])
+            garbage_tasks = []
+            garbage_splits = []
+            for i in range(garbage_split_count):
+                gtask = context.create_pseudo_task_with_attrs({'boo': rng.randint(-111, 111)})
+                garbage_tasks.append(gtask)
+                garbage_splits += context.create_pseudo_split_of(gtask, [{} for _ in range(rng.randint(1, 12))])
+            garbage_tasks_all = garbage_tasks + garbage_splits
+            rng.shuffle(garbage_tasks_all)
+            def _process_garbage_tasks(node):
+                for _ in range(rng.randint(0, len(garbage_tasks_all) // 2)):
+                    context.process_task(node, rng.choice(garbage_tasks_all))
+
+            for _ in range(100):
+                node = context.create_node('split_waiter', 'footest')
+                node.set_param_value('wait for all', do_wait)
+                node.set_param_value('transfer_attribs', 1)
+                node.set_param_value('transfer_type_0', 'extend')
+                node.set_param_value('src_attr_name_0', 'foo')
+                node.set_param_value('dst_attr_name_0', 'bar')
+                node.set_param_value('sort_by_0', '_builtin_id')
+
+                _process_garbage_tasks(node)
+
+                rng.shuffle(tsplits)
+                if do_wait:
+                    for task in tsplits[:-1]:
+                        self.assertIsNone(context.process_task(node, task))
+                        _process_garbage_tasks(node)
+
+                    last_res = context.process_task(node, tsplits[-1])
+                    self.assertIsNotNone(last_res)
+
+                    _process_garbage_tasks(node)
+
+                    ress = []
+                    for task in tsplits[:-1]:
+                        res = context.process_task(node, task)
+                        self.assertIsNotNone(res)
+                        ress.append(res)
+                        _process_garbage_tasks(node)
+
+                    ress.append(last_res)
+
+                    # check attrs
+                    self.assertSetEqual({'boo'}, set(main_task.attributes().keys()))
+                    self.assertEqual(32, main_task.attributes()['boo'])
+                    self.assertListEqual([5, 8, 1, 3], ress[0].split_attributes_to_set['bar'])
+                    self.assertTrue(ress[0].do_split_remove)
+                    for res in ress[1:]:
+                        self.assertFalse(res.do_split_remove)
+                    for res in ress:
+                        self.assertTrue(res.do_kill_task)
+
+                else:  # no wait for all
+
+                    res = context.process_task(node, tsplits[0])
+                    self.assertIsNotNone(res)
+                    self.assertTrue(res.do_split_remove)
+                    self.assertTrue(res.do_kill_task)
+                    self.assertDictEqual({}, res.split_attributes_to_set)  # no attribs are set in non-wait mode
+                    for task in tsplits[1:]:
+                        res = context.process_task(node, task)
+                        self.assertIsNotNone(res)
+                        self.assertFalse(res.do_split_remove)
+                        self.assertTrue(res.do_kill_task)
+                        self.assertDictEqual({}, res.split_attributes_to_set)
+                    _process_garbage_tasks(node)
+
+                # check cleanup
+                self.assertFalse(node._debug_has_internal_data_for_split(tsplits[0].task_dict()['split_id']))
+
+        await self._helper_test_node_with_arg_update(
+            _logic
+        )
diff --git a/tests/nodes/test_wedge.py b/tests/nodes/test_wedge.py
new file mode 100644
index 00000000..ae1c8ea3
--- /dev/null
+++ b/tests/nodes/test_wedge.py
@@ -0,0 +1,147 @@
+import random
+from asyncio import Event
+from lifeblood.scheduler import Scheduler
+from lifeblood.worker import Worker
+from lifeblood.basenode import BaseNode, ProcessingError
+from lifeblood.exceptions import NodeNotReadyToProcess
+from .common import TestCaseBase, PseudoContext
+
+from typing import List
+
+
+class TestWedge(TestCaseBase):
+    async def test_basic_functions_wait(self):
+        async def _logic(sched: Scheduler, workers: List[Worker], done_waiter: Event, context: PseudoContext):
+            rng = random.Random(9377151)
+            for _ in range(200):
+                task = context.create_pseudo_task_with_attrs({'foo': rng.randint(-999, 999), 'bar': rng.uniform(-999, 999), 'googa': [rng.randint(-999, 999) for _ in range(rng.randint(1, 256))]})
+
+                rmode = 0 if rng.random() > 0.5 else 1
+                rstart = rng.randint(-9999, 9999) if rng.random() > 0.5 else rng.uniform(-9999, 9999)
+                rend = rng.randint(-9999, 9999) if rng.random() > 0.5 else rng.uniform(-9999, 9999)
+                rcount = rng.randint(-999, 9999)
+                if rmode == 1 and rcount < 0:
+                    rcount *= -1
+                rmax = rend
+                rinc = (rmax - rstart) / (rcount - 1) if rcount > 1 else 2*(rmax - rstart)
+                if int(rinc) == rinc:
+                    rinc = int(rinc)
+
+                node = context.create_node('wedge', 'footest')
+                node.set_param_value('wedge count', 1)
+                node.set_param_value('wtype_0', rmode)  # by count or by inc
+                node.set_param_value('attr_0', 'wooga')
+                node.set_param_value('from_0', rstart)
+                node.set_param_value('to_0', rend)
+                node.set_param_value('count_0', rcount)
+                node.set_param_value('max_0', rmax)
+                node.set_param_value('inc_0', rinc)
+
+                if rcount < 0 and rmode == 0:  # forbid negative count in count mode
+                    self.assertRaises(ProcessingError, context.process_task, node, task)
+                    continue
+                res = context.process_task(node, task)
+
+                self.assertIsNotNone(res)
+                if rmode == 1 and (isinstance(rinc, float) or isinstance(rstart, float) or isinstance(rend, float)):  # in inc case and float start/end we cannot guarantee
+                    self.assertTrue(rcount == len(res._split_attribs) or rcount == len(res._split_attribs)+1, msg=f'{rcount} not eq to {len(res._split_attribs)}')
+                for i, attrs in enumerate(res._split_attribs):
+                    if rcount == 1:
+                        t = 1.0
+                    else:
+                        t = i / (rcount - 1)
+                    self.assertAlmostEqual(rstart*(1-t) + rend*(t), attrs['wooga'], msg=f'failed with {rstart}..{rend}, i={i}, cnt={rcount}')
+
+        await self._helper_test_node_with_arg_update(
+            _logic
+        )
+
+    async def test_wedgings(self):
+        exp_pairs = [
+            (
+                (('att1', 0, 1, 10, 10, 0, 0),),
+                (10, [(1,), (2,), (3,), (4,), (5,), (6,), (7,), (8,), (9,), (10,)])
+            ),
+            (
+                (('att1', 1, 1, 0, 0, 10, 1),),
+                (10, [(1,), (2,), (3,), (4,), (5,), (6,), (7,), (8,), (9,), (10,)])
+            ),
+
+            (
+                (('att1', 0, 1.5, 10.5, 10, 0, 0),),
+                (10, [(1.5,), (2.5,), (3.5,), (4.5,), (5.5,), (6.5,), (7.5,), (8.5,), (9.5,), (10.5,)])
+            ),
+            (
+                (('att1', 1, 1.5, 0, 0, 10.5, 1),),
+                (10, [(1.5,), (2.5,), (3.5,), (4.5,), (5.5,), (6.5,), (7.5,), (8.5,), (9.5,), (10.5,)])
+            ),
+
+            (
+                (('att1', 0, 1/3, 1/3 + 9, 10, 0, 0),),
+                (10, [(1/3,), (1/3+1,), (1/3+2,), (1/3+3,), (1/3+4,), (1/3+5,), (1/3+6,), (1/3+7,), (1/3+8,), (1/3+9,)])
+            ),
+            (
+                (('att1', 1, 1/3, 0, 0, 1/3 + 9, 1),),
+                (10, [(1/3,), (1/3+1,), (1/3+2,), (1/3+3,), (1/3+4,), (1/3+5,), (1/3+6,), (1/3+7,), (1/3+8,), (1/3+9,)])
+            ),
+
+            (
+                (('att1', 0, 1.567, 2.345, 6, 0, 0),),
+                (6, [(1.567,), (1.7226,), (1.8782,), (2.0338,), (2.1894,), (2.345,)])
+            ),
+            (
+                (('att1', 1, 1.567, 0, 0, 2.345, 0.1556),),
+                (6, [(1.567,), (1.7226,), (1.8782,), (2.0338,), (2.1894,), (2.345,)])
+            ),
+
+            (
+                (('att1', 0, 1, 2.5, 4, 0, 0), ('att2', 0, 11, 15, 3, 0, 0)),
+                (12, [(1.0, 11), (1.0, 13), (1.0, 15), (1.5, 11), (1.5, 13), (1.5, 15), (2.0, 11), (2.0, 13), (2.0, 15), (2.5, 11), (2.5, 13), (2.5, 15)])
+            ),
+            (
+                (('att1', 1, 1, 0, 0, 2.5, 0.5), ('att2', 1, 11, 0, 0, 15, 2)),
+                (12, [(1.0, 11), (1.0, 13), (1.0, 15), (1.5, 11), (1.5, 13), (1.5, 15), (2.0, 11), (2.0, 13), (2.0, 15), (2.5, 11), (2.5, 13), (2.5, 15)])
+            ),
+        ]
+
+        async def _logic(sched: Scheduler, workers: List[Worker], done_waiter: Event, context: PseudoContext):
+
+            task = context.create_pseudo_task_with_attrs({})
+
+            for exp_inputs, exp_results in exp_pairs:
+
+                node = context.create_node('wedge', 'footest')
+                node.set_param_value('wedge count', len(exp_inputs))
+                attr_name_order = []
+                for i, (attr_name, rmode, rfrom, rto, rcount, rmax, rinc) in enumerate(exp_inputs):
+                    print(i, attr_name, rmode, rfrom, rto, rcount, rmax, rinc)
+                    attr_name_order.append(attr_name)
+                    node.set_param_value(f'wtype_{i}', rmode)  # by count or by inc
+                    node.set_param_value(f'attr_{i}', attr_name)
+                    node.set_param_value(f'from_{i}', rfrom)
+                    node.set_param_value(f'to_{i}', rto)
+                    node.set_param_value(f'count_{i}', rcount)
+                    node.set_param_value(f'max_{i}', rmax)
+                    node.set_param_value(f'inc_{i}', rinc)
+
+                res = context.process_task(node, task)
+
+                self.assertIsNotNone(res)
+                attr_tuples = [tuple(attrs[n] for n in attr_name_order) for attrs in res._split_attribs]
+                self.assertEqual(exp_results[0], len(attr_tuples))
+
+                if exp_results[1] is not None:
+                    # now some elaborate scheming to compare floats with AlmostEqual
+                    if len(exp_results[1]) and any(isinstance(x, float) for x in exp_results[1][0]):
+                        for ex, ac in zip(exp_results[1], attr_tuples):
+                            for xex, xac in zip(ex, ac):
+                                if isinstance(xex, float):
+                                    self.assertAlmostEqual(xex, xac, msg=f'float lists differ at: {ex} vs {ac}. from {exp_results} vs {attr_tuples}')
+                                else:
+                                    self.assertEqual(xex, xac, msg=f'float lists differ at: {ex} vs {ac}. from {exp_results} vs {attr_tuples}')
+                    else:
+                        self.assertListEqual(exp_results[1], attr_tuples)
+
+        await self._helper_test_node_with_arg_update(
+            _logic
+        )