Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add task event list collapse #106

Merged
merged 6 commits into from
Sep 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 54 additions & 0 deletions src/lifeblood/ui_events_tools.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
import copy

from .ui_events import TaskFullState, TasksChanged, TasksRemoved, TasksUpdated, TaskEvent
from .ui_protocol_data import TaskBatchData, DataNotSet, TaskData, TaskDelta

from typing import Dict, List, Optional


def collapse_task_event_list(event_list: List[TaskEvent]) -> Optional[TaskBatchData]:
if len(event_list) == 0:
return None
collapsed_tasks: Dict[int, TaskData] = {}
db_id = None
event_id = None
timestamp = None
for event in event_list:
if db_id is None:
db_id = event.database_uid
event_id = event.event_id
timestamp = event.timestamp
elif db_id != event.database_uid:
raise RuntimeError('provided event list has events from different databases')
event_id = max(event_id, event.event_id)
timestamp = max(timestamp, event.timestamp)

if isinstance(event, TaskFullState):
collapsed_tasks = {k: copy.copy(v) for k, v in event.task_data.tasks.items()}
elif isinstance(event, TasksRemoved):
for task_id in event.task_ids:
if task_id not in collapsed_tasks:
raise RuntimeError(f'event list inconsistency: task id {task_id} is not in tasks, cannot remove')
collapsed_tasks.pop(task_id)
elif isinstance(event, TasksUpdated):
for task_id, task_data in event.task_data.tasks.items():
collapsed_tasks[task_id] = copy.copy(task_data)
elif isinstance(event, TasksChanged):
for task_delta in event.task_deltas:
task_id = task_delta.id
if task_id not in collapsed_tasks:
print(collapsed_tasks)
raise RuntimeError(f'event list inconsistency: task id {task_id} is not in tasks, cannot apply delta')
for field in TaskDelta.__annotations__.keys():
if (val := getattr(task_delta, field)) is not DataNotSet:
if field == 'id':
assert collapsed_tasks[task_id].id == val
setattr(collapsed_tasks[task_id], field, val)
else:
raise NotImplementedError(f'handling of event type "{type(event)}" is not implemented')

return TaskBatchData(
db_id,
collapsed_tasks
)

33 changes: 22 additions & 11 deletions src/lifeblood_viewer/connection_worker.py
Original file line number Diff line number Diff line change
@@ -1,34 +1,30 @@
import asyncio
import socket
import struct
import json
import time
import pickle
from io import BytesIO

from lifeblood.uidata import NodeUi
from lifeblood.ui_protocol_data import UiData
from lifeblood.invocationjob import InvocationJob
from lifeblood.nethelpers import recv_exactly, address_to_ip_port, get_default_addr
from lifeblood.nethelpers import address_to_ip_port, get_default_addr
from lifeblood import logging
from lifeblood.enums import NodeParameterType, TaskState, TaskGroupArchivedState
from lifeblood.enums import TaskState, TaskGroupArchivedState
from lifeblood.broadcasting import await_broadcast
from lifeblood.config import get_config
from lifeblood.exceptions import UiClientOperationFailed
from lifeblood.uidata import Parameter
from lifeblood.node_type_metadata import NodeTypeMetadata
from lifeblood.taskspawn import NewTask
from lifeblood.snippets import NodeSnippetData, NodeSnippetDataPlaceholder
from lifeblood.snippets import NodeSnippetData
from lifeblood.defaults import ui_port
from lifeblood.environment_resolver import EnvironmentResolverArguments
from lifeblood.scheduler_ui_protocol import UIProtocolSocketClient
from lifeblood.ui_protocol_data import TaskBatchData
from lifeblood.ui_events import TaskFullState
from lifeblood.ui_events_tools import collapse_task_event_list

import PySide2
from PySide2.QtCore import Signal, Slot, QPointF, QThread
#from PySide2.QtGui import QPoin

from typing import Callable, Optional, Set, List, Union, Dict, Iterable
from typing import Callable, Optional, Set, List, Union, Iterable


logger = logging.get_logger('viewer')
Expand Down Expand Up @@ -398,8 +394,23 @@ def _check_tasks(self):
assert len(task_events) > 0 # on subscription there MUST be at least a single event

if len(task_events) > 0:
first_time_getting_events = self.__last_known_event_id < 0
first_time_receiving_events_for_this_filter = self.__last_known_event_id < 0
self.__last_known_event_id = task_events[-1].event_id
if first_time_receiving_events_for_this_filter:
collapsed_data: Optional[TaskBatchData] = None
try:
collapsed_data = collapse_task_event_list(task_events)
except RuntimeError:
logger.warning("failed to collapse event list, event list malformed!")
if collapsed_data is not None:
subst_event = TaskFullState(
collapsed_data.db_uid,
collapsed_data
)
subst_event.timestamp = task_events[-1].timestamp
subst_event.event_id = task_events[-1].event_id
task_events = [subst_event]

self.tasks_events_arrived.emit(task_events)
else:
tasks_state = self.__client.get_ui_tasks_state(self.__task_group_filter or [], not self.__skip_dead)
Expand Down
223 changes: 223 additions & 0 deletions tests/test_ui_events_tools.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,223 @@
import copy
import random
from unittest import TestCase
from lifeblood.ui_events_tools import collapse_task_event_list
from lifeblood.ui_events import TaskFullState, TasksChanged, TasksRemoved, TasksUpdated
from lifeblood.ui_protocol_data import TaskData, TaskBatchData, TaskDelta
from lifeblood.enums import TaskState


class TestUIEventsTools(TestCase):
def test_collapse_task_event_list_trivial(self):
self.assertIsNone(collapse_task_event_list([]))

data = TaskBatchData(
12345,
{
2: TaskData(
2, None, 4, 3, TaskState.WAITING, 'bla blaa', False, 55, 'maino', 'bleion', 'footask', 5, 6, 0.567, 7, 8, 9, {'qwe', 'asd', 'zxc'},
)
}
)
fullstate = TaskFullState(
12345,
data,
)

self.assertEqual(data, collapse_task_event_list([fullstate]))

def test_collapse_task_event_list_errors(self):
self.assertRaises(RuntimeError, collapse_task_event_list, [
TaskFullState(
12345, TaskBatchData(12345, {})
),
TasksChanged(
23456, []
),
])

self.assertRaises(RuntimeError, collapse_task_event_list, [
TasksRemoved(
12345, (1,)
)
])

self.assertRaises(RuntimeError, collapse_task_event_list, [
TasksChanged(
12345, [
TaskDelta(
2,
children_count=123,
),
],
)
])

def test_collapse_task_event_list_common1(self):
fullstate_init = TaskFullState(
12345,
TaskBatchData(
12345,
{
2: TaskData(
2, None, 4, 3, TaskState.WAITING, 'bla blaa', False, 55, 'maino', 'bleion', 'footask', 5, 6, 0.567, 7, 8, 9, {'qwe', 'asd', 'zxc'},
)
}
),
)
fulldata_final = TaskBatchData(
12345,
{
2: TaskData(
2, None, 4, 3, TaskState.WAITING, 'bla blaa', False, 55, 'maino', 'bleion', 'footask', 5, 6, 0.567, 7, 8, 9, {'qwe', 'asd', 'zxc'},
),
22: TaskData(
22, None, 44, 33, TaskState.POST_WAITING, 'beeba', True, 555, 'maino1', 'bleion1', 'bartask', 55, 66, 0.5678, 77, 88, 99, {'fgh'},
),
}
)
event_list = [
fullstate_init,
TasksUpdated(
12345,
TaskBatchData(
12345,
{
22: TaskData(
22, None, 0, 33, TaskState.DONE, None, True, 555, 'maino1', '', 'bartask', 55, 66, 0, 77, 88, 99, set(),
),
}
)
),
TasksChanged(
12345,
[
TaskDelta(
22,
children_count=44,
state=TaskState.POST_WAITING,
state_details='beeba',
node_output_name='badbad',
groups={'hhh'},
)
]
),
TasksUpdated(
12345,
TaskBatchData(
12345,
{
33: TaskData(
33, None, 0, 33, TaskState.DONE, None, True, 555, 'agageh', '', 'jgfjft', 55, 66, 0, 77, 88, 99, set(),
),
}
)
),
TasksChanged(
12345,
[
TaskDelta(
22,
node_output_name='bleion1',
progress=0.5678,
groups={'fgh'},
)
]
),
TasksRemoved(
12345,
(33,),
)
]

self.assertEqual(
fulldata_final,
collapse_task_event_list(event_list)
)

def test_ensure_source_unmodified(self):
update_event = TasksUpdated(
12345,
TaskBatchData(
12345,
{
123: TaskData(123, 234, 444, 333, TaskState.GENERATING, 'nope', True, 345, 'floo', 'flee', 'nonde', 456, 567, 0.51423, 678, 789, 890, {'karrr'},),
}
)
)
full_event = TaskFullState(
12345,
TaskBatchData(
12345,
{
123: TaskData(123, 234, 444, 333, TaskState.GENERATING, 'nope', True, 345, 'floo', 'flee', 'nonde', 456, 567, 0.51423, 678, 789, 890, {'karrr'}, ),
}
)
)
delta_event = TasksChanged(
12345,
[
TaskDelta(123, children_count=999, split_origin_task_id=888, name='foooooooo')
]
)

update_event_control = copy.deepcopy(update_event)
full_event_control = copy.deepcopy(full_event)

collapsed_data = collapse_task_event_list([full_event, delta_event])
self.assertIsNotNone(collapsed_data)
collapsed_data = collapse_task_event_list([update_event, delta_event])
self.assertIsNotNone(collapsed_data)

self.assertEqual(update_event_control, update_event)
self.assertEqual(full_event_control, full_event)

def test_random_change(self):
rng = random.Random(1827361)
for _ in range(999):
fields = list(TaskDelta.__annotations__.keys())
rng.shuffle(fields)
delta = TaskDelta(123)
attrs_set = {}
for field in fields[:rng.randint(0, len(fields))]:
if field == 'id':
continue
# NOTE: we ignore typing, which may cause test fails on correct implementations
val = random.randint(0, 99999)
setattr(delta, field, val)
attrs_set[field] = val

task_data_control = TaskData(123, 234, 444, 333, TaskState.GENERATING, 'nope', True, 345, 'floo', 'flee', 'nonde', 456, 567, 0.51423, 678, 789, 890, {'karrr'},)
task_data = TaskData(123, 234, 444, 333, TaskState.GENERATING, 'nope', True, 345, 'floo', 'flee', 'nonde', 456, 567, 0.51423, 678, 789, 890, {'karrr'},)
event_list = [
TasksUpdated(
12345,
TaskBatchData(
12345,
{
123: task_data,
}
)
),
TasksChanged(
12345,
[
delta
]
)
]

collapsed_data = collapse_task_event_list(event_list)
self.assertIsNotNone(collapsed_data)

# ensure that original event was not changed
self.assertEqual(task_data_control, task_data)

self.assertSetEqual({123}, set(collapsed_data.tasks.keys()))
for field in TaskDelta.__annotations__.keys():
if field in attrs_set:
expected_val = attrs_set[field]
else:
expected_val = getattr(task_data_control, field)

self.assertEqual(expected_val, getattr(collapsed_data.tasks[123], field), f'fail in "{field}" field')
Loading