diff --git a/src/lifeblood/ui_events_tools.py b/src/lifeblood/ui_events_tools.py new file mode 100644 index 00000000..5f0ae9f9 --- /dev/null +++ b/src/lifeblood/ui_events_tools.py @@ -0,0 +1,54 @@ +import copy + +from .ui_events import TaskFullState, TasksChanged, TasksRemoved, TasksUpdated, TaskEvent +from .ui_protocol_data import TaskBatchData, DataNotSet, TaskData, TaskDelta + +from typing import Dict, List, Optional + + +def collapse_task_event_list(event_list: List[TaskEvent]) -> Optional[TaskBatchData]: + if len(event_list) == 0: + return None + collapsed_tasks: Dict[int, TaskData] = {} + db_id = None + event_id = None + timestamp = None + for event in event_list: + if db_id is None: + db_id = event.database_uid + event_id = event.event_id + timestamp = event.timestamp + elif db_id != event.database_uid: + raise RuntimeError('provided event list has events from different databases') + event_id = max(event_id, event.event_id) + timestamp = max(timestamp, event.timestamp) + + if isinstance(event, TaskFullState): + collapsed_tasks = {k: copy.copy(v) for k, v in event.task_data.tasks.items()} + elif isinstance(event, TasksRemoved): + for task_id in event.task_ids: + if task_id not in collapsed_tasks: + raise RuntimeError(f'event list inconsistency: task id {task_id} is not in tasks, cannot remove') + collapsed_tasks.pop(task_id) + elif isinstance(event, TasksUpdated): + for task_id, task_data in event.task_data.tasks.items(): + collapsed_tasks[task_id] = copy.copy(task_data) + elif isinstance(event, TasksChanged): + for task_delta in event.task_deltas: + task_id = task_delta.id + if task_id not in collapsed_tasks: + print(collapsed_tasks) + raise RuntimeError(f'event list inconsistency: task id {task_id} is not in tasks, cannot apply delta') + for field in TaskDelta.__annotations__.keys(): + if (val := getattr(task_delta, field)) is not DataNotSet: + if field == 'id': + assert collapsed_tasks[task_id].id == val + setattr(collapsed_tasks[task_id], field, val) + else: + raise NotImplementedError(f'handling of event type "{type(event)}" is not implemented') + + return TaskBatchData( + db_id, + collapsed_tasks + ) + diff --git a/src/lifeblood_viewer/connection_worker.py b/src/lifeblood_viewer/connection_worker.py index c6a81141..44e073ef 100644 --- a/src/lifeblood_viewer/connection_worker.py +++ b/src/lifeblood_viewer/connection_worker.py @@ -1,34 +1,30 @@ import asyncio import socket -import struct import json import time -import pickle -from io import BytesIO from lifeblood.uidata import NodeUi -from lifeblood.ui_protocol_data import UiData from lifeblood.invocationjob import InvocationJob -from lifeblood.nethelpers import recv_exactly, address_to_ip_port, get_default_addr +from lifeblood.nethelpers import address_to_ip_port, get_default_addr from lifeblood import logging -from lifeblood.enums import NodeParameterType, TaskState, TaskGroupArchivedState +from lifeblood.enums import TaskState, TaskGroupArchivedState from lifeblood.broadcasting import await_broadcast from lifeblood.config import get_config from lifeblood.exceptions import UiClientOperationFailed from lifeblood.uidata import Parameter -from lifeblood.node_type_metadata import NodeTypeMetadata from lifeblood.taskspawn import NewTask -from lifeblood.snippets import NodeSnippetData, NodeSnippetDataPlaceholder +from lifeblood.snippets import NodeSnippetData from lifeblood.defaults import ui_port from lifeblood.environment_resolver import EnvironmentResolverArguments from lifeblood.scheduler_ui_protocol import UIProtocolSocketClient from lifeblood.ui_protocol_data import TaskBatchData +from lifeblood.ui_events import TaskFullState +from lifeblood.ui_events_tools import collapse_task_event_list import PySide2 from PySide2.QtCore import Signal, Slot, QPointF, QThread -#from PySide2.QtGui import QPoin -from typing import Callable, Optional, Set, List, Union, Dict, Iterable +from typing import Callable, Optional, Set, List, Union, Iterable logger = logging.get_logger('viewer') @@ -398,8 +394,23 @@ def _check_tasks(self): assert len(task_events) > 0 # on subscription there MUST be at least a single event if len(task_events) > 0: - first_time_getting_events = self.__last_known_event_id < 0 + first_time_receiving_events_for_this_filter = self.__last_known_event_id < 0 self.__last_known_event_id = task_events[-1].event_id + if first_time_receiving_events_for_this_filter: + collapsed_data: Optional[TaskBatchData] = None + try: + collapsed_data = collapse_task_event_list(task_events) + except RuntimeError: + logger.warning("failed to collapse event list, event list malformed!") + if collapsed_data is not None: + subst_event = TaskFullState( + collapsed_data.db_uid, + collapsed_data + ) + subst_event.timestamp = task_events[-1].timestamp + subst_event.event_id = task_events[-1].event_id + task_events = [subst_event] + self.tasks_events_arrived.emit(task_events) else: tasks_state = self.__client.get_ui_tasks_state(self.__task_group_filter or [], not self.__skip_dead) diff --git a/tests/test_ui_events_tools.py b/tests/test_ui_events_tools.py new file mode 100644 index 00000000..2d3afff4 --- /dev/null +++ b/tests/test_ui_events_tools.py @@ -0,0 +1,223 @@ +import copy +import random +from unittest import TestCase +from lifeblood.ui_events_tools import collapse_task_event_list +from lifeblood.ui_events import TaskFullState, TasksChanged, TasksRemoved, TasksUpdated +from lifeblood.ui_protocol_data import TaskData, TaskBatchData, TaskDelta +from lifeblood.enums import TaskState + + +class TestUIEventsTools(TestCase): + def test_collapse_task_event_list_trivial(self): + self.assertIsNone(collapse_task_event_list([])) + + data = TaskBatchData( + 12345, + { + 2: TaskData( + 2, None, 4, 3, TaskState.WAITING, 'bla blaa', False, 55, 'maino', 'bleion', 'footask', 5, 6, 0.567, 7, 8, 9, {'qwe', 'asd', 'zxc'}, + ) + } + ) + fullstate = TaskFullState( + 12345, + data, + ) + + self.assertEqual(data, collapse_task_event_list([fullstate])) + + def test_collapse_task_event_list_errors(self): + self.assertRaises(RuntimeError, collapse_task_event_list, [ + TaskFullState( + 12345, TaskBatchData(12345, {}) + ), + TasksChanged( + 23456, [] + ), + ]) + + self.assertRaises(RuntimeError, collapse_task_event_list, [ + TasksRemoved( + 12345, (1,) + ) + ]) + + self.assertRaises(RuntimeError, collapse_task_event_list, [ + TasksChanged( + 12345, [ + TaskDelta( + 2, + children_count=123, + ), + ], + ) + ]) + + def test_collapse_task_event_list_common1(self): + fullstate_init = TaskFullState( + 12345, + TaskBatchData( + 12345, + { + 2: TaskData( + 2, None, 4, 3, TaskState.WAITING, 'bla blaa', False, 55, 'maino', 'bleion', 'footask', 5, 6, 0.567, 7, 8, 9, {'qwe', 'asd', 'zxc'}, + ) + } + ), + ) + fulldata_final = TaskBatchData( + 12345, + { + 2: TaskData( + 2, None, 4, 3, TaskState.WAITING, 'bla blaa', False, 55, 'maino', 'bleion', 'footask', 5, 6, 0.567, 7, 8, 9, {'qwe', 'asd', 'zxc'}, + ), + 22: TaskData( + 22, None, 44, 33, TaskState.POST_WAITING, 'beeba', True, 555, 'maino1', 'bleion1', 'bartask', 55, 66, 0.5678, 77, 88, 99, {'fgh'}, + ), + } + ) + event_list = [ + fullstate_init, + TasksUpdated( + 12345, + TaskBatchData( + 12345, + { + 22: TaskData( + 22, None, 0, 33, TaskState.DONE, None, True, 555, 'maino1', '', 'bartask', 55, 66, 0, 77, 88, 99, set(), + ), + } + ) + ), + TasksChanged( + 12345, + [ + TaskDelta( + 22, + children_count=44, + state=TaskState.POST_WAITING, + state_details='beeba', + node_output_name='badbad', + groups={'hhh'}, + ) + ] + ), + TasksUpdated( + 12345, + TaskBatchData( + 12345, + { + 33: TaskData( + 33, None, 0, 33, TaskState.DONE, None, True, 555, 'agageh', '', 'jgfjft', 55, 66, 0, 77, 88, 99, set(), + ), + } + ) + ), + TasksChanged( + 12345, + [ + TaskDelta( + 22, + node_output_name='bleion1', + progress=0.5678, + groups={'fgh'}, + ) + ] + ), + TasksRemoved( + 12345, + (33,), + ) + ] + + self.assertEqual( + fulldata_final, + collapse_task_event_list(event_list) + ) + + def test_ensure_source_unmodified(self): + update_event = TasksUpdated( + 12345, + TaskBatchData( + 12345, + { + 123: TaskData(123, 234, 444, 333, TaskState.GENERATING, 'nope', True, 345, 'floo', 'flee', 'nonde', 456, 567, 0.51423, 678, 789, 890, {'karrr'},), + } + ) + ) + full_event = TaskFullState( + 12345, + TaskBatchData( + 12345, + { + 123: TaskData(123, 234, 444, 333, TaskState.GENERATING, 'nope', True, 345, 'floo', 'flee', 'nonde', 456, 567, 0.51423, 678, 789, 890, {'karrr'}, ), + } + ) + ) + delta_event = TasksChanged( + 12345, + [ + TaskDelta(123, children_count=999, split_origin_task_id=888, name='foooooooo') + ] + ) + + update_event_control = copy.deepcopy(update_event) + full_event_control = copy.deepcopy(full_event) + + collapsed_data = collapse_task_event_list([full_event, delta_event]) + self.assertIsNotNone(collapsed_data) + collapsed_data = collapse_task_event_list([update_event, delta_event]) + self.assertIsNotNone(collapsed_data) + + self.assertEqual(update_event_control, update_event) + self.assertEqual(full_event_control, full_event) + + def test_random_change(self): + rng = random.Random(1827361) + for _ in range(999): + fields = list(TaskDelta.__annotations__.keys()) + rng.shuffle(fields) + delta = TaskDelta(123) + attrs_set = {} + for field in fields[:rng.randint(0, len(fields))]: + if field == 'id': + continue + # NOTE: we ignore typing, which may cause test fails on correct implementations + val = random.randint(0, 99999) + setattr(delta, field, val) + attrs_set[field] = val + + task_data_control = TaskData(123, 234, 444, 333, TaskState.GENERATING, 'nope', True, 345, 'floo', 'flee', 'nonde', 456, 567, 0.51423, 678, 789, 890, {'karrr'},) + task_data = TaskData(123, 234, 444, 333, TaskState.GENERATING, 'nope', True, 345, 'floo', 'flee', 'nonde', 456, 567, 0.51423, 678, 789, 890, {'karrr'},) + event_list = [ + TasksUpdated( + 12345, + TaskBatchData( + 12345, + { + 123: task_data, + } + ) + ), + TasksChanged( + 12345, + [ + delta + ] + ) + ] + + collapsed_data = collapse_task_event_list(event_list) + self.assertIsNotNone(collapsed_data) + + # ensure that original event was not changed + self.assertEqual(task_data_control, task_data) + + self.assertSetEqual({123}, set(collapsed_data.tasks.keys())) + for field in TaskDelta.__annotations__.keys(): + if field in attrs_set: + expected_val = attrs_set[field] + else: + expected_val = getattr(task_data_control, field) + + self.assertEqual(expected_val, getattr(collapsed_data.tasks[123], field), f'fail in "{field}" field')