-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #106 from pedohorse/add-task-event-list-collapse
Add task event list collapse
- Loading branch information
Showing
3 changed files
with
299 additions
and
11 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
import copy | ||
|
||
from .ui_events import TaskFullState, TasksChanged, TasksRemoved, TasksUpdated, TaskEvent | ||
from .ui_protocol_data import TaskBatchData, DataNotSet, TaskData, TaskDelta | ||
|
||
from typing import Dict, List, Optional | ||
|
||
|
||
def collapse_task_event_list(event_list: List[TaskEvent]) -> Optional[TaskBatchData]: | ||
if len(event_list) == 0: | ||
return None | ||
collapsed_tasks: Dict[int, TaskData] = {} | ||
db_id = None | ||
event_id = None | ||
timestamp = None | ||
for event in event_list: | ||
if db_id is None: | ||
db_id = event.database_uid | ||
event_id = event.event_id | ||
timestamp = event.timestamp | ||
elif db_id != event.database_uid: | ||
raise RuntimeError('provided event list has events from different databases') | ||
event_id = max(event_id, event.event_id) | ||
timestamp = max(timestamp, event.timestamp) | ||
|
||
if isinstance(event, TaskFullState): | ||
collapsed_tasks = {k: copy.copy(v) for k, v in event.task_data.tasks.items()} | ||
elif isinstance(event, TasksRemoved): | ||
for task_id in event.task_ids: | ||
if task_id not in collapsed_tasks: | ||
raise RuntimeError(f'event list inconsistency: task id {task_id} is not in tasks, cannot remove') | ||
collapsed_tasks.pop(task_id) | ||
elif isinstance(event, TasksUpdated): | ||
for task_id, task_data in event.task_data.tasks.items(): | ||
collapsed_tasks[task_id] = copy.copy(task_data) | ||
elif isinstance(event, TasksChanged): | ||
for task_delta in event.task_deltas: | ||
task_id = task_delta.id | ||
if task_id not in collapsed_tasks: | ||
print(collapsed_tasks) | ||
raise RuntimeError(f'event list inconsistency: task id {task_id} is not in tasks, cannot apply delta') | ||
for field in TaskDelta.__annotations__.keys(): | ||
if (val := getattr(task_delta, field)) is not DataNotSet: | ||
if field == 'id': | ||
assert collapsed_tasks[task_id].id == val | ||
setattr(collapsed_tasks[task_id], field, val) | ||
else: | ||
raise NotImplementedError(f'handling of event type "{type(event)}" is not implemented') | ||
|
||
return TaskBatchData( | ||
db_id, | ||
collapsed_tasks | ||
) | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,223 @@ | ||
import copy | ||
import random | ||
from unittest import TestCase | ||
from lifeblood.ui_events_tools import collapse_task_event_list | ||
from lifeblood.ui_events import TaskFullState, TasksChanged, TasksRemoved, TasksUpdated | ||
from lifeblood.ui_protocol_data import TaskData, TaskBatchData, TaskDelta | ||
from lifeblood.enums import TaskState | ||
|
||
|
||
class TestUIEventsTools(TestCase): | ||
def test_collapse_task_event_list_trivial(self): | ||
self.assertIsNone(collapse_task_event_list([])) | ||
|
||
data = TaskBatchData( | ||
12345, | ||
{ | ||
2: TaskData( | ||
2, None, 4, 3, TaskState.WAITING, 'bla blaa', False, 55, 'maino', 'bleion', 'footask', 5, 6, 0.567, 7, 8, 9, {'qwe', 'asd', 'zxc'}, | ||
) | ||
} | ||
) | ||
fullstate = TaskFullState( | ||
12345, | ||
data, | ||
) | ||
|
||
self.assertEqual(data, collapse_task_event_list([fullstate])) | ||
|
||
def test_collapse_task_event_list_errors(self): | ||
self.assertRaises(RuntimeError, collapse_task_event_list, [ | ||
TaskFullState( | ||
12345, TaskBatchData(12345, {}) | ||
), | ||
TasksChanged( | ||
23456, [] | ||
), | ||
]) | ||
|
||
self.assertRaises(RuntimeError, collapse_task_event_list, [ | ||
TasksRemoved( | ||
12345, (1,) | ||
) | ||
]) | ||
|
||
self.assertRaises(RuntimeError, collapse_task_event_list, [ | ||
TasksChanged( | ||
12345, [ | ||
TaskDelta( | ||
2, | ||
children_count=123, | ||
), | ||
], | ||
) | ||
]) | ||
|
||
def test_collapse_task_event_list_common1(self): | ||
fullstate_init = TaskFullState( | ||
12345, | ||
TaskBatchData( | ||
12345, | ||
{ | ||
2: TaskData( | ||
2, None, 4, 3, TaskState.WAITING, 'bla blaa', False, 55, 'maino', 'bleion', 'footask', 5, 6, 0.567, 7, 8, 9, {'qwe', 'asd', 'zxc'}, | ||
) | ||
} | ||
), | ||
) | ||
fulldata_final = TaskBatchData( | ||
12345, | ||
{ | ||
2: TaskData( | ||
2, None, 4, 3, TaskState.WAITING, 'bla blaa', False, 55, 'maino', 'bleion', 'footask', 5, 6, 0.567, 7, 8, 9, {'qwe', 'asd', 'zxc'}, | ||
), | ||
22: TaskData( | ||
22, None, 44, 33, TaskState.POST_WAITING, 'beeba', True, 555, 'maino1', 'bleion1', 'bartask', 55, 66, 0.5678, 77, 88, 99, {'fgh'}, | ||
), | ||
} | ||
) | ||
event_list = [ | ||
fullstate_init, | ||
TasksUpdated( | ||
12345, | ||
TaskBatchData( | ||
12345, | ||
{ | ||
22: TaskData( | ||
22, None, 0, 33, TaskState.DONE, None, True, 555, 'maino1', '', 'bartask', 55, 66, 0, 77, 88, 99, set(), | ||
), | ||
} | ||
) | ||
), | ||
TasksChanged( | ||
12345, | ||
[ | ||
TaskDelta( | ||
22, | ||
children_count=44, | ||
state=TaskState.POST_WAITING, | ||
state_details='beeba', | ||
node_output_name='badbad', | ||
groups={'hhh'}, | ||
) | ||
] | ||
), | ||
TasksUpdated( | ||
12345, | ||
TaskBatchData( | ||
12345, | ||
{ | ||
33: TaskData( | ||
33, None, 0, 33, TaskState.DONE, None, True, 555, 'agageh', '', 'jgfjft', 55, 66, 0, 77, 88, 99, set(), | ||
), | ||
} | ||
) | ||
), | ||
TasksChanged( | ||
12345, | ||
[ | ||
TaskDelta( | ||
22, | ||
node_output_name='bleion1', | ||
progress=0.5678, | ||
groups={'fgh'}, | ||
) | ||
] | ||
), | ||
TasksRemoved( | ||
12345, | ||
(33,), | ||
) | ||
] | ||
|
||
self.assertEqual( | ||
fulldata_final, | ||
collapse_task_event_list(event_list) | ||
) | ||
|
||
def test_ensure_source_unmodified(self): | ||
update_event = TasksUpdated( | ||
12345, | ||
TaskBatchData( | ||
12345, | ||
{ | ||
123: TaskData(123, 234, 444, 333, TaskState.GENERATING, 'nope', True, 345, 'floo', 'flee', 'nonde', 456, 567, 0.51423, 678, 789, 890, {'karrr'},), | ||
} | ||
) | ||
) | ||
full_event = TaskFullState( | ||
12345, | ||
TaskBatchData( | ||
12345, | ||
{ | ||
123: TaskData(123, 234, 444, 333, TaskState.GENERATING, 'nope', True, 345, 'floo', 'flee', 'nonde', 456, 567, 0.51423, 678, 789, 890, {'karrr'}, ), | ||
} | ||
) | ||
) | ||
delta_event = TasksChanged( | ||
12345, | ||
[ | ||
TaskDelta(123, children_count=999, split_origin_task_id=888, name='foooooooo') | ||
] | ||
) | ||
|
||
update_event_control = copy.deepcopy(update_event) | ||
full_event_control = copy.deepcopy(full_event) | ||
|
||
collapsed_data = collapse_task_event_list([full_event, delta_event]) | ||
self.assertIsNotNone(collapsed_data) | ||
collapsed_data = collapse_task_event_list([update_event, delta_event]) | ||
self.assertIsNotNone(collapsed_data) | ||
|
||
self.assertEqual(update_event_control, update_event) | ||
self.assertEqual(full_event_control, full_event) | ||
|
||
def test_random_change(self): | ||
rng = random.Random(1827361) | ||
for _ in range(999): | ||
fields = list(TaskDelta.__annotations__.keys()) | ||
rng.shuffle(fields) | ||
delta = TaskDelta(123) | ||
attrs_set = {} | ||
for field in fields[:rng.randint(0, len(fields))]: | ||
if field == 'id': | ||
continue | ||
# NOTE: we ignore typing, which may cause test fails on correct implementations | ||
val = random.randint(0, 99999) | ||
setattr(delta, field, val) | ||
attrs_set[field] = val | ||
|
||
task_data_control = TaskData(123, 234, 444, 333, TaskState.GENERATING, 'nope', True, 345, 'floo', 'flee', 'nonde', 456, 567, 0.51423, 678, 789, 890, {'karrr'},) | ||
task_data = TaskData(123, 234, 444, 333, TaskState.GENERATING, 'nope', True, 345, 'floo', 'flee', 'nonde', 456, 567, 0.51423, 678, 789, 890, {'karrr'},) | ||
event_list = [ | ||
TasksUpdated( | ||
12345, | ||
TaskBatchData( | ||
12345, | ||
{ | ||
123: task_data, | ||
} | ||
) | ||
), | ||
TasksChanged( | ||
12345, | ||
[ | ||
delta | ||
] | ||
) | ||
] | ||
|
||
collapsed_data = collapse_task_event_list(event_list) | ||
self.assertIsNotNone(collapsed_data) | ||
|
||
# ensure that original event was not changed | ||
self.assertEqual(task_data_control, task_data) | ||
|
||
self.assertSetEqual({123}, set(collapsed_data.tasks.keys())) | ||
for field in TaskDelta.__annotations__.keys(): | ||
if field in attrs_set: | ||
expected_val = attrs_set[field] | ||
else: | ||
expected_val = getattr(task_data_control, field) | ||
|
||
self.assertEqual(expected_val, getattr(collapsed_data.tasks[123], field), f'fail in "{field}" field') |