Merge pull request #18 from pedohorse/split_nodes-tests
Split nodes tests
pedohorse authored Nov 22, 2023
2 parents bc78946 + 82ca2f7 commit 2d11430
Showing 7 changed files with 658 additions and 95 deletions.
63 changes: 26 additions & 37 deletions src/lifeblood/core_nodes/attribute_splitter.py
@@ -83,12 +83,17 @@ def _ranges_from_chunksize(start, end, inc, chunk_size):
raise ProcessingError('chunk size must be positive')
ranges = []
cur = start
while cur < end:
count_till_end = math.ceil((end - cur) / inc)
while cur <= end:
count_till_end = math.ceil((end - cur + 1) / inc)
if count_till_end == 0:
break
ranges.append((cur, min(end, cur + inc * min(count_till_end, chunk_size)), min(count_till_end, chunk_size)))
cur = ranges[-1][1]
print(count_till_end)
ranges.append((cur, cur + inc * (min(count_till_end, chunk_size) - 1), min(count_till_end, chunk_size)))
cur = ranges[-1][1] + inc

# just set "end" for the last range
if ranges:
ranges[-1] = tuple(end if i == 1 else x for i, x in enumerate(ranges[-1]))
return ranges

@staticmethod
@@ -97,13 +102,19 @@ def _ranges_from_chunkcount(start, end, inc, chunk_count):
raise ProcessingError('chunk count must be positive')
ranges = []
print(f'- {start}, {end} :{inc}')
chunk_count = min(chunk_count, int((end-start)/inc)+1)
for i in range(chunk_count):
e1 = start + math.ceil(((end - start) / chunk_count * i) / inc) * inc
e2 = min(end, start + math.ceil(((end - start) / chunk_count * (i + 1)) / inc) * inc)
print(e1, e2)
if e1 >= e2: # case where chunk_count is bigger than possible
continue
new_range = (e1, e2, math.ceil((e2-e1)/inc))
e2 = start + math.ceil(((end - start) / chunk_count * (i + 1)) / inc) * inc
# if e1 >= e2: # case where chunk_count is bigger than possible
# continue
adj = 0
if i == chunk_count - 1:
e2 = end
else:
adj = -inc
print(e1, e2 + adj, math.ceil((e2-e1)/inc))
new_range = (e1, e2 + adj, math.ceil((e2-e1)/inc))
ranges.append(new_range)
return ranges

@@ -120,8 +131,13 @@ def process_task(self, context):
raise ProcessingError(f'attribute "{attr_name}" must be a list')
res = ProcessingResult()
chunksize = context.param_value('chunk size')
if chunksize <= 0:
raise ProcessingError(f'chunk size cannot be less or equal to zero, got: {chunksize}')

split_into = 1 + (len(attr_value) - 1) // chunksize
if len(attr_value) <= 1:
split_into = 1
else:
split_into = 1 + (len(attr_value) - 1) // chunksize
# yes we can split into 1 part. this should behave the same way as when splitting into multiple parts
# in order to have consistent behaviour in the graph
res.split_task(split_into)
@@ -170,30 +186,3 @@ def process_task(self, context):
return res
else:
raise NotImplementedError(f'mode "{mode}" is not implemented')


if __name__ == '__main__':
# some fast tests
r = FramerangeSplitter._ranges_from_chunkcount(10, 12, 3, 10)
assert r == [(10, 12, 1)], r
r = FramerangeSplitter._ranges_from_chunkcount(10, 15, 3, 10)
assert r == [(10, 13, 1), (13, 15, 1)], r
r = FramerangeSplitter._ranges_from_chunkcount(15, 16, 3, 10)
assert r == [(15, 16, 1)], r
r = FramerangeSplitter._ranges_from_chunkcount(15, 29, 3, 3)
assert r == [(15, 21, 2), (21, 27, 2), (27, 29, 1)], r
r = FramerangeSplitter._ranges_from_chunkcount(15, 29, 3, 4)
assert r == [(15, 21, 2), (21, 24, 1), (24, 27, 1), (27, 29, 1)], r

r = FramerangeSplitter._ranges_from_chunksize(10, 12, 3, 1)
assert r == [(10, 12, 1)], r
r = FramerangeSplitter._ranges_from_chunksize(10, 12, 3, 10)
assert r == [(10, 12, 1)], r
r = FramerangeSplitter._ranges_from_chunksize(10, 15, 3, 1)
assert r == [(10, 13, 1), (13, 15, 1)], r
r = FramerangeSplitter._ranges_from_chunksize(10, 15, 3, 10)
assert r == [(10, 15, 2)], r
r = FramerangeSplitter._ranges_from_chunksize(15, 16, 3, 10)
assert r == [(15, 16, 1)], r
r = FramerangeSplitter._ranges_from_chunksize(15, 38, 3, 3)
assert r == [(15, 24, 3), (24, 33, 3), (33, 38, 2)], r
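
The inline `__main__` checks removed above now live in the new test files; the reworked helpers treat `end` as inclusive, so every chunk ends on a frame it actually contains and only the final chunk is stretched to the original `end`. A minimal standalone sketch of that inclusive chunking idea (the helper name `chunk_inclusive_range` is made up for illustration and is not the node's exact code):

```python
import math

def chunk_inclusive_range(start, end, inc, chunk_size):
    """Split the inclusive frame range [start, end] with step `inc` into chunks
    of at most `chunk_size` frames; returns (chunk_start, chunk_end, frame_count)."""
    if chunk_size <= 0:
        raise ValueError('chunk size must be positive')
    ranges = []
    cur = start
    while cur <= end:
        frames_left = math.ceil((end - cur + 1) / inc)
        count = min(frames_left, chunk_size)
        # the chunk ends on its own last frame, not on the next chunk's first frame
        ranges.append((cur, cur + inc * (count - 1), count))
        cur = ranges[-1][1] + inc
    if ranges:
        # only the final chunk is stretched to report the original (possibly off-step) end
        ranges[-1] = (ranges[-1][0], end, ranges[-1][2])
    return ranges

print(chunk_inclusive_range(15, 38, 3, 3))
# -> [(15, 21, 3), (24, 30, 3), (33, 38, 2)]
```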
114 changes: 63 additions & 51 deletions src/lifeblood/core_nodes/split_waiter.py
@@ -59,6 +59,61 @@ def __init__(self, name: str):
ui.add_parameter('sort_by', None, NodeParameterType.STRING, '_builtin_id')
ui.add_parameter('reversed', 'reversed', NodeParameterType.BOOL, False)

def __get_promote_attribs(self, context):
attribs_to_promote = {}
split_id = context.task_field('split_id')
num_attribs = context.param_value('transfer_attribs')
for i in range(num_attribs):
src_attr_name = context.param_value(f'src_attr_name_{i}')
transfer_type = context.param_value(f'transfer_type_{i}')
dst_attr_name = context.param_value(f'dst_attr_name_{i}')
sort_attr_name = context.param_value(f'sort_by_{i}')
sort_reversed = context.param_value(f'reversed_{i}')
if transfer_type == 'append':
gathered_values = []
for attribs in sorted(self.__cache[split_id]['arrived'].values(), key=lambda x: x.get(sort_attr_name, 0), reverse=sort_reversed):
if src_attr_name not in attribs:
continue

attr_val = attribs[src_attr_name]
gathered_values.append(attr_val)
attribs_to_promote[dst_attr_name] = gathered_values
elif transfer_type == 'extend':
gathered_values = []
for attribs in sorted(self.__cache[split_id]['arrived'].values(), key=lambda x: x.get(sort_attr_name, 0), reverse=sort_reversed):
if src_attr_name not in attribs:
continue

attr_val = attribs[src_attr_name]
if isinstance(attr_val, list):
gathered_values.extend(attr_val)
else:
gathered_values.append(attr_val)
attribs_to_promote[dst_attr_name] = gathered_values
elif transfer_type == 'first':
_acd = self.__cache[split_id]['arrived']
if len(_acd) > 0:
if sort_reversed:
attribs = max(_acd.values(), key=lambda x: x.get(sort_attr_name, 0))
else:
attribs = min(_acd.values(), key=lambda x: x.get(sort_attr_name, 0))
if src_attr_name in attribs:
attribs_to_promote[dst_attr_name] = attribs[src_attr_name]
elif transfer_type == 'sum':
# we don't care about the order, assume sum is associative
gathered_values = None
for attribs in self.__cache[split_id]['arrived'].values():
if src_attr_name not in attribs:
continue
if gathered_values is None:
gathered_values = attribs[src_attr_name]
else:
gathered_values += attribs[src_attr_name]
attribs_to_promote[dst_attr_name] = gathered_values
else:
raise NotImplementedError(f'transfer type "{transfer_type}" is not implemented')
return attribs_to_promote

def ready_to_process_task(self, task_dict) -> bool:
context = ProcessingContext(self, task_dict)
split_id = context.task_field('split_id')
@@ -96,59 +151,9 @@ def process_task(self, context) -> ProcessingResult: #TODO: not finished, attrib
res = ProcessingResult()
res.kill_task()
self.__cache[split_id]['processed'].add(context.task_field('split_element'))
attribs_to_promote = {}
if self.__cache[split_id]['first_to_arrive'] == task_id:
# transfer attributes # TODO: delete cache for already processed splits
num_attribs = context.param_value('transfer_attribs')
for i in range(num_attribs):
src_attr_name = context.param_value(f'src_attr_name_{i}')
transfer_type = context.param_value(f'transfer_type_{i}')
dst_attr_name = context.param_value(f'dst_attr_name_{i}')
sort_attr_name = context.param_value(f'sort_by_{i}')
sort_reversed = context.param_value(f'reversed_{i}')
if transfer_type == 'append':
gathered_values = []
for attribs in sorted(self.__cache[split_id]['arrived'].values(), key=lambda x: x.get(sort_attr_name, 0), reverse=sort_reversed):
if src_attr_name not in attribs:
continue

attr_val = attribs[src_attr_name]
gathered_values.append(attr_val)
attribs_to_promote[dst_attr_name] = gathered_values
elif transfer_type == 'extend':
gathered_values = []
for attribs in sorted(self.__cache[split_id]['arrived'].values(), key=lambda x: x.get(sort_attr_name, 0), reverse=sort_reversed):
if src_attr_name not in attribs:
continue

attr_val = attribs[src_attr_name]
if isinstance(attr_val, list):
gathered_values.extend(attr_val)
else:
gathered_values.append(attr_val)
attribs_to_promote[dst_attr_name] = gathered_values
elif transfer_type == 'first':
_acd = self.__cache[split_id]['arrived']
if len(_acd) > 0:
if sort_reversed:
attribs = max(_acd.values(), key=lambda x: x.get(sort_attr_name, 0))
else:
attribs = min(_acd.values(), key=lambda x: x.get(sort_attr_name, 0))
if src_attr_name in attribs:
attribs_to_promote[dst_attr_name] = attribs[src_attr_name]
elif transfer_type == 'sum':
# we don't care about the order, assume sum is associative
gathered_values = None
for attribs in self.__cache[split_id]['arrived'].values():
if src_attr_name not in attribs:
continue
if gathered_values is None:
gathered_values = attribs[src_attr_name]
else:
gathered_values += attribs[src_attr_name]
attribs_to_promote[dst_attr_name] = gathered_values
else:
raise NotImplementedError(f'transfer type "{transfer_type}" is not implemented')
attribs_to_promote = self.__get_promote_attribs(context)

res.remove_split(attributes_to_set=attribs_to_promote)
changed = True
@@ -176,6 +181,13 @@ def process_task(self, context) -> ProcessingResult: #TODO: not finished, attrib
def postprocess_task(self, context) -> ProcessingResult:
return ProcessingResult()

def _debug_has_internal_data_for_split(self, split_id: int) -> bool:
"""
method for debug/testing purposes.
returns True if some internal data is present for a given split_id
"""
return split_id in self.__cache

def __getstate__(self):
d = super(SplitAwaiterNode, self).__getstate__()
assert '_SplitAwaiterNode__main_lock' in d
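
The split_waiter change is mostly a mechanical extraction: the attribute-promotion loop that used to sit inside `process_task` becomes the private helper `__get_promote_attribs`, plus a small `_debug_has_internal_data_for_split` hook for tests. For readers new to the four transfer types, here is a hedged standalone sketch of their aggregation semantics (the `promote_attribs` function and the sample data are illustrative, not the node's API):

```python
def promote_attribs(arrived, src, transfer_type, sort_by='_builtin_id', reverse=False):
    """Toy model of attribute promotion: `arrived` maps task id -> attribute dict."""
    ordered = sorted(arrived.values(), key=lambda a: a.get(sort_by, 0), reverse=reverse)
    present = [a[src] for a in ordered if src in a]
    if transfer_type == 'append':
        return present                            # one list element per arrived task
    if transfer_type == 'extend':
        out = []
        for val in present:
            if isinstance(val, list):
                out.extend(val)                   # flatten list-valued attributes
            else:
                out.append(val)
        return out
    if transfer_type == 'first':
        return present[0] if present else None    # value from the first task in sort order
    if transfer_type == 'sum':
        total = None
        for val in present:                       # order-independent, assumes `+` is associative
            total = val if total is None else total + val
        return total
    raise NotImplementedError(f'transfer type "{transfer_type}" is not implemented')

arrived = {11: {'frames': [1, 2], 'cost': 3}, 12: {'frames': [5]}, 13: {'cost': 4}}
print(promote_attribs(arrived, 'frames', 'extend'))  # [1, 2, 5]
print(promote_attribs(arrived, 'cost', 'sum'))       # 7
```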
35 changes: 32 additions & 3 deletions src/lifeblood/core_nodes/wedge.py
@@ -49,7 +49,10 @@ def process_task(self, context) -> ProcessingResult:
for i in range(wedges_count):
wtype = context.param_value(f'wtype_{i}')
if wtype == 0:
wedge_ranges.append((0, context.param_value(f'attr_{i}'), context.param_value(f'from_{i}'), context.param_value(f'to_{i}'), context.param_value(f'count_{i}')))
count = context.param_value(f'count_{i}')
if count <= 0:
raise ProcessingError('count cannot be less or equal to zero')
wedge_ranges.append((0, context.param_value(f'attr_{i}'), context.param_value(f'from_{i}'), context.param_value(f'to_{i}'), count))
elif wtype == 1:
wedge_ranges.append((1, context.param_value(f'attr_{i}'), context.param_value(f'from_{i}'), context.param_value(f'max_{i}'), context.param_value(f'inc_{i}')))
else:
@@ -63,13 +66,37 @@ def _do_iter(cur_vals, level=0):
return
if wedge_ranges[level][0] == 0:
_, attr, fr, to, cnt = wedge_ranges[level]
# first check if we actually work with ints
if int(fr) == fr:
fr = int(fr)
if int(to) == to:
to = int(to)
inc = None
if cnt > 1:
finc = (to - fr) / (cnt - 1)
if int(finc) == finc:
inc = int(finc)

# now create variations
for i in range(cnt):
new_vals = cur_vals.copy()
t = i * 1.0 / (cnt-1)
new_vals[attr] = fr*(1-t) + to*t
if inc is None:
t = i * 1.0 / (cnt-1)
new_vals[attr] = fr*(1-t) + to*t
else:
new_vals[attr] = fr + i*inc
_do_iter(new_vals, level+1)
elif wedge_ranges[level][0] == 1:
_, attr, fr, to, inc = wedge_ranges[level]
# first check if we actually work with ints
if int(inc) == inc:
inc = int(inc)
if int(fr) == fr:
fr = int(fr)
if int(to) == to:
to = int(to)

# now create variations
if inc == 0:
raise ProcessingError('increment cannot be zero')
elif inc > 0:
@@ -88,6 +115,8 @@ def _do_iter(cur_vals, level=0):
_do_iter({})

res = ProcessingResult()
if len(all_wedges) == 0:
raise ProcessingError('unexpectedly no wedges were created')
res.split_task(len(all_wedges))
for i, attrs in enumerate(all_wedges):
res.set_split_task_attribs(i, attrs)
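
Two behavioural tweaks stand out in the wedge node: a non-positive `count` is now rejected up front, and count-based wedges emit exact integers whenever the endpoints and the implied increment are integral, instead of always interpolating in floats. A rough standalone sketch of that value generation (the `wedge_values` helper is illustrative, not the node's API):

```python
def wedge_values(fr, to, cnt):
    """Produce `cnt` evenly spaced wedge values from `fr` to `to`, preferring
    exact integer steps when the endpoints and the implied increment are integral."""
    if cnt <= 0:
        raise ValueError('count cannot be less or equal to zero')
    # collapse float endpoints such as 10.0 down to ints, as the node now does
    fr = int(fr) if int(fr) == fr else fr
    to = int(to) if int(to) == to else to
    inc = None
    if cnt > 1:
        finc = (to - fr) / (cnt - 1)
        if int(finc) == finc:
            inc = int(finc)
    if inc is not None:
        return [fr + i * inc for i in range(cnt)]   # pure integer arithmetic, no float drift
    if cnt == 1:
        return [fr]
    return [fr * (1 - i / (cnt - 1)) + to * (i / (cnt - 1)) for i in range(cnt)]

print(wedge_values(0.0, 10.0, 6))  # [0, 2, 4, 6, 8, 10]
print(wedge_values(0.0, 1.0, 3))   # [0.0, 0.5, 1.0]
```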
25 changes: 21 additions & 4 deletions tests/nodes/common.py
@@ -41,7 +41,7 @@ def children_ids_for(self, task_id, active_only=False) -> List[int]:


class PseudoTask:
def __init__(self, pool: PseudoTaskPool, task_id: int, attrs: dict, parent_id: Optional[int] = None, state: Optional[TaskState] = None):
def __init__(self, pool: PseudoTaskPool, task_id: int, attrs: dict, parent_id: Optional[int] = None, state: Optional[TaskState] = None, extra_fields: Optional[dict] = None):
self.__id = task_id
self.__parent_id = parent_id
self.__state = state or TaskState.GENERATING
@@ -52,7 +52,8 @@ def __init__(self, pool: PseudoTaskPool, task_id: int, attrs: dict, parent_id: O
'node_input_name': self.__input_name,
'state': self.__state.value,
'parent_id': parent_id,
'attributes': json.dumps(attrs)
'attributes': json.dumps(attrs),
**(extra_fields or {})
}

def id(self) -> int:
@@ -104,16 +105,32 @@ def __init__(self, scheduler: Scheduler):
self.__scheduler = scheduler
self.__last_node_id = 135 # cuz why starting with zero?
self.__last_task_id = 468
self.__last_split_id = 765
self.__tasks: Dict[int, PseudoTask] = {}

def create_pseudo_task_with_attrs(self, attrs: dict, task_id: Optional[int] = None, parent_id: Optional[int] = None, state: Optional[TaskState] = None) -> PseudoTask:
def create_pseudo_task_with_attrs(self, attrs: dict, task_id: Optional[int] = None, parent_id: Optional[int] = None, state: Optional[TaskState] = None, extra_fields: Optional[dict] = None) -> PseudoTask:
if task_id is None:
self.__last_task_id += 1
task_id = self.__last_task_id
task = PseudoTask(self, task_id, attrs, parent_id, state)
task = PseudoTask(self, task_id, attrs, parent_id, state, extra_fields)
self.__tasks[task_id] = task
return task

def create_pseudo_split_of(self, pseudo_task: PseudoTask, attribs_list):
"""
split count will be taken from attrib_list's len
"""
tasks = []
self.__last_split_id += 1
for i, attribs in enumerate(attribs_list):
tasks.append(self.create_pseudo_task_with_attrs(attribs, task_id=None, parent_id=None, state=TaskState.GENERATING, extra_fields={
'split_id': self.__last_split_id,
'split_count': len(attribs_list),
'split_origin_task_id': pseudo_task.id(),
'split_element': i,
}))
return tasks

def create_node(self, node_type: str, node_name: str) -> BaseNode:
self.__last_node_id += 1
return create_node(node_type, node_name, self.__scheduler, self.__last_node_id)
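
The test-harness additions let node tests fabricate tasks that look like real splits: `create_pseudo_task_with_attrs` grows an `extra_fields` pass-through, and `create_pseudo_split_of` allocates one fresh split id and stamps the usual bookkeeping onto every part. A minimal self-contained sketch of just that bookkeeping (the `make_split_fields` helper and the counter seed are illustrative, not part of the real harness):

```python
from itertools import count

_split_ids = count(766)  # arbitrary illustrative seed; one fresh id per split

def make_split_fields(origin_task_id, attribs_list):
    """Build the extra task fields a split attaches to each of its parts."""
    split_id = next(_split_ids)
    return [
        {
            'split_id': split_id,                    # shared by every part of this split
            'split_count': len(attribs_list),        # how many parts the split has
            'split_origin_task_id': origin_task_id,  # the task that was split
            'split_element': i,                      # this part's index within the split
        }
        for i in range(len(attribs_list))
    ]

fields = make_split_fields(469, [{'f': [1]}, {'f': [2]}, {'f': [3]}])
print(fields[0]['split_id'] == fields[2]['split_id'])  # True: one id for the whole split
print([f['split_element'] for f in fields])            # [0, 1, 2]
```

Passed in through `extra_fields`, these are presumably the same keys that split-aware nodes such as SplitAwaiterNode read back via `task_field` in `ready_to_process_task` and `process_task`.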