From 82adbb4f120e61aa52e86e378bbcd3a11b9e97ef Mon Sep 17 00:00:00 2001
From: pedohorse <13556996+pedohorse@users.noreply.github.com>
Date: Mon, 28 Oct 2024 09:53:10 +0100
Subject: [PATCH 01/10] split scheduler into core and main

---
 src/lifeblood/scheduler/scheduler.py      | 1865 +--------------------
 src/lifeblood/scheduler/scheduler_core.py | 1863 ++++++++++++++++++++
 2 files changed, 1877 insertions(+), 1851 deletions(-)
 create mode 100644 src/lifeblood/scheduler/scheduler_core.py

diff --git a/src/lifeblood/scheduler/scheduler.py b/src/lifeblood/scheduler/scheduler.py
index 4357d703..f9934bb4 100644
--- a/src/lifeblood/scheduler/scheduler.py
+++ b/src/lifeblood/scheduler/scheduler.py
@@ -1,1864 +1,27 @@
-import os
-from pathlib import Path
-import time
-from datetime import datetime
-import json
-import itertools
-import asyncio
-import aiosqlite
-import aiofiles
-from aiorwlock import RWLock
-from contextlib import asynccontextmanager
-
-from .. import logging
-from ..nodegraph_holder_base import NodeGraphHolderBase
-from ..attribute_serialization import serialize_attributes, deserialize_attributes
-from ..worker_messsage_processor import WorkerControlClient
-from ..scheduler_task_protocol import SchedulerTaskProtocol, SpawnStatus
-from ..scheduler_ui_protocol import SchedulerUiProtocol
-from ..hardware_resources import HardwareResources
-from ..invocationjob import Invocation, InvocationJob, Requirements
-from ..environment_resolver import EnvironmentResolverArguments
-from ..broadcasting import create_broadcaster
-from ..simple_worker_pool import WorkerPool
-from ..nethelpers import get_broadcast_addr_for, all_interfaces
-from ..worker_metadata import WorkerMetadata
-from ..taskspawn import TaskSpawn
-from ..basenode import BaseNode
-from ..exceptions import *
 from ..node_dataprovider_base import NodeDataProvider
-from ..basenode_serialization import NodeSerializerBase, IncompatibleDeserializationMethod, FailedToDeserialize
-from ..enums import WorkerState, WorkerPingState, TaskState, InvocationState, WorkerType, \
-    SchedulerMode, TaskGroupArchivedState
-from .. import aiosqlite_overlay
-from ..ui_protocol_data import TaskData, TaskDelta, IncompleteInvocationLogData, InvocationLogData
+from ..basenode_serialization import NodeSerializerBase
-from ..net_messages.address import DirectAddress, AddressChain
-from ..scheduler_message_processor import SchedulerMessageProcessor
 from ..scheduler_config_provider_base import SchedulerConfigProviderBase
+from ..scheduler_task_protocol import SchedulerTaskProtocol
+from ..scheduler_ui_protocol import SchedulerUiProtocol
+from ..scheduler_message_processor import SchedulerMessageProcessor
-from .data_access import DataAccess
-from .scheduler_component_base import SchedulerComponentBase
-from .pinger import Pinger
-from .task_processor import TaskProcessor
-from .ui_state_accessor import UIStateAccessor
+from .scheduler_core import SchedulerCore
 
-from typing import Optional, Any, Tuple, List, Iterable, Union, Dict, Set
+from typing import List
 
 
-class Scheduler(NodeGraphHolderBase):
+class Scheduler(SchedulerCore):
     def __init__(self, *,
                  scheduler_config_provider: SchedulerConfigProviderBase,
                  node_data_provider: NodeDataProvider,
                  node_serializers: List[NodeSerializerBase],
                  ):
-        """
-        TODO: add a docstring
-
-        :param scheduler_config_provider:
-        """
-        self.__node_data_provider: NodeDataProvider = node_data_provider
-        if len(node_serializers) < 1:
-            raise ValueError('at least one serializer must be provided!')
-        self.__node_serializers = list(node_serializers)
-        self.__logger = logging.get_logger('scheduler')
-        self.__logger.info('loading core plugins')
-        self.__node_objects: Dict[int, BaseNode] = {}
-        self.__node_objects_locks: Dict[int, RWLock] = {}
-        self.__node_objects_creation_locks: Dict[int, asyncio.Lock] = {}
-        self.__config_provider: SchedulerConfigProviderBase = scheduler_config_provider
-
-        # this lock will prevent tasks from being reported cancelled and done at the same exact time should that ever happen
-        # this lock is overkill already, but we can make it even more overkill by using set of locks for each invoc id
-        # which would be completely useless now cuz sqlite locks DB as a whole, not even a single table, especially not just parts of table
-        self.__invocation_reporting_lock = asyncio.Lock()
-
-        self.__all_components = None
-        self.__started_event = asyncio.Event()
-
-        self.__db_path = scheduler_config_provider.main_database_location()
-        if not self.__db_path.startswith('file:'):  # if schema is used - we do not modify the db uri in any way
-            self.__db_path = os.path.realpath(os.path.expanduser(self.__db_path))
-        self.__logger.debug(f'starting scheduler with database: {self.__db_path}')
-        self.data_access: DataAccess = DataAccess(
-            config_provider=self.__config_provider,
+        super().__init__(
+            scheduler_config_provider=scheduler_config_provider,
+            node_data_provider=node_data_provider,
+            node_serializers=node_serializers,
+            message_processor_factory=SchedulerMessageProcessor,
+            legacy_task_protocol_factory=SchedulerTaskProtocol,
+            ui_protocol_factory=SchedulerUiProtocol,
         )
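The point of the split is visible in the `+` lines above: Scheduler is now a thin assembly over SchedulerCore, with the protocol and message-processor classes injected as factories. A minimal sketch of a custom assembly this enables, assuming only the keyword parameters shown in the diff (the subclass names here are hypothetical):

    # sketch: assemble a scheduler variant with a custom UI protocol;
    # MyUiProtocol is hypothetical, the factory kwargs are taken from the diff above
    from lifeblood.scheduler.scheduler_core import SchedulerCore
    from lifeblood.scheduler_message_processor import SchedulerMessageProcessor
    from lifeblood.scheduler_task_protocol import SchedulerTaskProtocol
    from lifeblood.scheduler_ui_protocol import SchedulerUiProtocol

    class MyUiProtocol(SchedulerUiProtocol):  # hypothetical override point
        pass

    class MyScheduler(SchedulerCore):
        def __init__(self, *, scheduler_config_provider, node_data_provider, node_serializers):
            super().__init__(
                scheduler_config_provider=scheduler_config_provider,
                node_data_provider=node_data_provider,
                node_serializers=node_serializers,
                message_processor_factory=SchedulerMessageProcessor,
                legacy_task_protocol_factory=SchedulerTaskProtocol,
                ui_protocol_factory=MyUiProtocol,
            )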
-        ##
-
-        self.__use_external_log = self.__config_provider.external_log_location() is not None
-        self.__external_log_location: Optional[Path] = self.__config_provider.external_log_location()
-        if self.__use_external_log:
-            external_log_path = Path(self.__use_external_log)
-            if external_log_path.exists() and external_log_path.is_file():
-                external_log_path.unlink()
-            if not external_log_path.exists():
-                external_log_path.mkdir(parents=True)
-            if not os.access(self.__external_log_location, os.X_OK | os.W_OK):
-                raise RuntimeError('cannot write to external log location provided')
-
-        self.__pinger: Pinger = Pinger(self)
-        self.task_processor: TaskProcessor = TaskProcessor(self)
-        self.ui_state_access: UIStateAccessor = UIStateAccessor(self)
-
-        self.__message_processor_addresses = []
-        self.__ui_address = None
-        self.__legacy_command_server_address = None
-
-        legacy_server_ip, legacy_server_port = self.__config_provider.legacy_server_address()  # TODO: this CAN be None
-        for message_server_ip, message_server_port in self.__config_provider.server_message_addresses():
-            self.__message_processor_addresses.append(DirectAddress.from_host_port(message_server_ip, message_server_port))
-        self.__legacy_command_server_address = (legacy_server_ip, legacy_server_port)
-
-        self.__ui_address = self.__config_provider.server_ui_address()
-
-        self.__stop_event = asyncio.Event()
-        self.__server_closing_task = None
-        self.__cleanup_tasks = None
-
-        self.__legacy_command_server = None
-        self.__message_processor: Optional[SchedulerMessageProcessor] = None
-        self.__ui_server = None
-        self.__ui_server_coro_args = {'protocol_factory': self._ui_protocol_factory, 'host': self.__ui_address[0], 'port': self.__ui_address[1], 'backlog': 16}
-        self.__legacy_server_coro_args = {'protocol_factory': self._scheduler_protocol_factory, 'host': legacy_server_ip, 'port': legacy_server_port, 'backlog': 16}
-
-        self.__do_broadcasting = self.__config_provider.broadcast_interval() is not None
-        self.__broadcasting_interval = self.__config_provider.broadcast_interval() or 0
-        self.__broadcasting_servers = []
-
-        self.__worker_pool = None
-        self.__worker_pool_helpers_minimal_idle_to_ensure = self.__config_provider.scheduler_helpers_minimal()
-
-        self.__event_loop = asyncio.get_running_loop()
-        assert self.__event_loop is not None, 'Scheduler MUST be created within working event loop, in the main thread'
-
-    @property
-    def config_provider(self) -> SchedulerConfigProviderBase:
-        return self.__config_provider
-
-    def get_event_loop(self):
-        return self.__event_loop
-
-    def node_data_provider(self) -> NodeDataProvider:
-        return self.__node_data_provider
-
-    def _scheduler_protocol_factory(self):
-        return SchedulerTaskProtocol(self)
-
-    def _ui_protocol_factory(self):
-        return SchedulerUiProtocol(self)
-
-    def db_uid(self) -> int:
-        """
-        unique id that was generated on creation for the DB currently in use
-
-        :return: 64 bit unsigned int
-        """
-        return self.data_access.db_uid
-
-    def wake(self):
-        """
-        scheduler may go into DORMANT mode when he things there's nothing to do
-        in that case wake() call exits DORMANT mode immediately
-        if wake is not called on some change- eventually scheduler will check it's shit and will decide to exit DORMANT mode on it's own, it will just waste some time first
-        if currently not in DORMANT mode - nothing will happen
-
-        :return:
-        """
-        self.task_processor.wake()
-        self.__pinger.wake()
-
-    def poke_task_processor(self):
-        """
-        kick that lazy ass to stop it's waitings and immediately perform another processing iteration
-        this is not connected to wake, __sleep and DORMANT mode,
-        this is just one-time kick
-        good to perform when task was changed somewhere async, outside of task_processor
-
-        :return:
-        """
-        self.task_processor.poke()
-
-    def _component_changed_mode(self, component: SchedulerComponentBase, mode: SchedulerMode):
-        if component == self.task_processor and mode == SchedulerMode.DORMANT:
-            self.__logger.info('task processor switched to DORMANT mode')
-            self.__pinger.sleep()
-
-    def message_processor(self) -> SchedulerMessageProcessor:
-        """
-        get scheduler's main message processor
-        """
-        return self.__message_processor
-
-    async def get_node_type_and_name_by_id(self, node_id: int) -> (str, str):
-        async with self.data_access.data_connection() as con:
-            con.row_factory = aiosqlite.Row
-            async with con.execute('SELECT "type", "name" FROM "nodes" WHERE "id" = ?', (node_id,)) as nodecur:
-                node_row = await nodecur.fetchone()
-            if node_row is None:
-                raise RuntimeError(f'node with given id {node_id} does not exist')
-            return node_row['type'], node_row['name']
-
-    @asynccontextmanager
-    async def node_object_by_id_for_reading(self, node_id: int):
-        async with self.get_node_lock_by_id(node_id).reader_lock:
-            yield await self._get_node_object_by_id(node_id)
-
-    @asynccontextmanager
-    async def node_object_by_id_for_writing(self, node_id: int):
-        async with self.get_node_lock_by_id(node_id).writer_lock:
-            yield await self._get_node_object_by_id(node_id)
-
-    async def _get_node_object_by_id(self, node_id: int) -> BaseNode:
-        """
-        When accessing node this way - be aware that you SHOULD ensure your access happens within a lock
-        returned by get_node_lock_by_id.
-        If you don't want to deal with that - use scheduler's wrappers to access nodes in a safe way
-        (lol, wrappers are not implemented)
-
-        :param node_id:
-        :return:
-        """
-        if node_id in self.__node_objects:
-            return self.__node_objects[node_id]
-        async with self.__get_node_creation_lock_by_id(node_id):
-            # in case by the time we got here and the node was already created
-            if node_id in self.__node_objects:
-                return self.__node_objects[node_id]
-            # if no - need to create one after all
-            async with self.data_access.data_connection() as con:
-                con.row_factory = aiosqlite.Row
-                async with con.execute('SELECT * FROM "nodes" WHERE "id" = ?', (node_id,)) as nodecur:
-                    node_row = await nodecur.fetchone()
-                if node_row is None:
-                    raise RuntimeError('node id is invalid')
-
-                node_type = node_row['type']
-                if not self.__node_data_provider.has_node_factory(node_type):
-                    raise RuntimeError('node type is unsupported')
-
-                if node_row['node_object'] is not None:
-                    try:
-                        for serializer in self.__node_serializers:
-                            try:
-                                node_object = await serializer.deserialize_async(self.__node_data_provider, node_row['node_object'], node_row['node_object_state'])
-                                break
-                            except IncompatibleDeserializationMethod as e:
-                                self.__logger.warning(f'deserialization method failed with {e} ({serializer})')
-                                continue
-                        else:
-                            raise FailedToDeserialize(f'node entry {node_id} has unknown serialization method')
-                        node_object.set_parent(self, node_id)
-                        self.__node_objects[node_id] = node_object
-                        return self.__node_objects[node_id]
-                    except FailedToDeserialize:
-                        if self.__config_provider.ignore_node_deserialization_failures():
-                            pass  # ignore errors, recreate node
-                        else:
-                            raise
-
-                newnode = self.__node_data_provider.node_factory(node_type)(node_row['name'])
-                newnode.set_parent(self, node_id)
-
-                self.__node_objects[node_id] = newnode
-                node_data, state_data = await self.__node_serializers[0].serialize_async(newnode)
-                await con.execute('UPDATE "nodes" SET node_object = ?, node_object_state = ? WHERE "id" = ?',
-                                  (node_data, state_data, node_id))
-                await con.commit()
-
-                return newnode
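The lock discipline that `_get_node_object_by_id`'s docstring warns about is exactly what the two context managers above wrap up: take the per-node reader or writer lock, then hand out the node object. A minimal usage sketch (the accessors called on the node object are illustrative, not from this diff):

    # read access: shared, many readers may hold it concurrently
    async with scheduler.node_object_by_id_for_reading(node_id) as node:
        label = node.name()  # illustrative accessor

    # write access: exclusive, for anything that mutates the node
    async with scheduler.node_object_by_id_for_writing(node_id) as node:
        node.set_name('renamed')  # illustrative mutator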
-
-    def get_node_lock_by_id(self, node_id: int) -> RWLock:
-        """
-        All read/write operations for a node should be locked within a per node rw lock that scheduler maintains.
-        Usually you do NOT have to be concerned with this.
-        But in cases you get the node object with functions like get_node_object_by_id.
-        it is your responsibility to ensure data is locked when accessed.
-        Lock is not part of the node itself.
-
-        :param node_id: node id to get lock to
-        :return: rw lock for the node
-        """
-        if node_id not in self.__node_objects_locks:
-            self.__node_objects_locks[node_id] = RWLock(fast=True)  # read about fast on github. the points is if we have awaits inside critical section - it's safe to use fast
-        return self.__node_objects_locks[node_id]
-
-    def __get_node_creation_lock_by_id(self, node_id: int) -> asyncio.Lock:
-        """
-        This lock is for node creation/deserialization sections ONLY
-        """
-        if node_id not in self.__node_objects_creation_locks:
-            self.__node_objects_creation_locks[node_id] = asyncio.Lock()
-        return self.__node_objects_creation_locks[node_id]
-
-    async def get_task_attributes(self, task_id: int) -> Tuple[Dict[str, Any], Optional[EnvironmentResolverArguments]]:
-        """
-        get tasks, atributes and it's enviroment resolver's attributes
-
-        :param task_id:
-        :return:
-        """
-        async with self.data_access.data_connection() as con:
-            con.row_factory = aiosqlite.Row
-            async with con.execute('SELECT attributes, environment_resolver_data FROM tasks WHERE "id" = ?', (task_id,)) as cur:
-                res = await cur.fetchone()
-            if res is None:
-                raise RuntimeError('task with specified id was not found')
-            env_res_args = None
-            if res['environment_resolver_data'] is not None:
-                env_res_args = await EnvironmentResolverArguments.deserialize_async(res['environment_resolver_data'])
-            return await deserialize_attributes(res['attributes']), env_res_args
-
-    async def get_task_fields(self, task_id: int) -> Dict[str, Any]:
-        """
-        returns information about the given task, excluding thicc fields like attributes or env resolver
-        for those - use get_task_attributes
-
-        :param task_id:
-        :return:
-        """
-        async with self.data_access.data_connection() as con:
-            con.row_factory = aiosqlite.Row
-            async with con.execute('SELECT "id", "name", parent_id, children_count, active_children_count, "state", paused, '
-                                   '"node_id", split_level, priority, "dead" FROM tasks WHERE "id" == ?', (task_id,)) as cur:
-                res = await cur.fetchone()
-            if res is None:
-                raise RuntimeError('task with specified id was not found')
-            return dict(res)
-
-    async def task_name_to_id(self, name: str) -> List[int]:
-        """
-        get the list of task ids that have specified name
-
-        :param name:
-        :return:
-        """
-        async with self.data_access.data_connection() as con:
-            async with con.execute('SELECT "id" FROM "tasks" WHERE "name" = ?', (name,)) as cur:
-                return list(x[0] for x in await cur.fetchall())
-
-    async def get_task_invocation_serialized(self, task_id: int) -> Optional[bytes]:
-        async with self.data_access.data_connection() as con:
-            con.row_factory = aiosqlite.Row
-            async with con.execute('SELECT work_data FROM tasks WHERE "id" = ?', (task_id,)) as cur:
-                res = await cur.fetchone()
-            if res is None:
-                raise RuntimeError('task with specified id was not found')
-            return res[0]
-
-    async def worker_id_from_address(self, addr: str) -> Optional[int]:
-        async with self.data_access.data_connection() as con:
-            async with con.execute('SELECT "id" FROM workers WHERE last_address = ?', (addr,)) as cur:
-                ret = await cur.fetchone()
-        if ret is None:
-            return None
-        return ret[0]
-
-    async def get_worker_state(self, wid: int, con: Optional[aiosqlite.Connection] = None) -> WorkerState:
-        if con is None:
-            async with self.data_access.data_connection() as con:
-                async with con.execute('SELECT "state" FROM "workers" WHERE "id" = ?', (wid,)) as cur:
-                    res = await cur.fetchone()
-        else:
-            async with con.execute('SELECT "state" FROM "workers" WHERE "id" = ?', (wid,)) as cur:
-                res = await cur.fetchone()
-        if res is None:
-            raise ValueError(f'worker with given wid={wid} was not found')
-        return WorkerState(res[0])
-
-    async def get_task_invocation(self, task_id: int):
-        data = await self.get_task_invocation_serialized(task_id)
-        if data is None:
-            return None
-        return await InvocationJob.deserialize_async(data)
-
-    async def get_invocation_worker(self, invocation_id: int) -> Optional[AddressChain]:
-        async with self.data_access.data_connection() as con:
-            async with con.execute(
-                    'SELECT workers.last_address '
-                    'FROM invocations LEFT JOIN workers '
-                    'ON invocations.worker_id == workers.id '
-                    'WHERE invocations.id == ?', (invocation_id,)) as cur:
-                res = await cur.fetchone()
-        if res is None:
-            return None
-        return AddressChain(res[0])
-
-    async def get_invocation_state(self, invocation_id: int) -> Optional[InvocationState]:
-        async with self.data_access.data_connection() as con:
-            async with con.execute(
-                    'SELECT state FROM invocations WHERE id == ?', (invocation_id,)) as cur:
-                res = await cur.fetchone()
-        if res is None:
-            return None
-        return InvocationState(res[0])
-
-    def stop(self):
-        async def _server_closer():
-            # for server in self.__broadcasting_servers:
-            #     server.wait_closed()
-            # ensure all components stop first
-            await self.__pinger.wait_till_stops()
-            await self.task_processor.wait_till_stops()
-            await self.__worker_pool.wait_till_stops()
-            await self.__ui_server.wait_closed()
-            if self.__legacy_command_server is not None:
-                self.__legacy_command_server.close()
-                await self.__legacy_command_server.wait_closed()
-            self.__logger.debug('stopping message processor...')
-            self.__message_processor.stop()
-            await self.__message_processor.wait_till_stops()
-            self.__logger.debug('message processor stopped')
-
-        async def _db_cache_writeback():
-            await self.__pinger.wait_till_stops()
-            await self.task_processor.wait_till_stops()
-            await self.__server_closing_task
-            await self._save_all_cached_nodes_to_db()
-            await self.data_access.write_back_cache()
-
-        if self.__stop_event.is_set():
-            self.__logger.error('cannot double stop!')
-            return  # no double stopping
-        if not self.__started_event.is_set():
-            self.__logger.error('cannot stop what is not started!')
-            return
-        self.__logger.info('STOPPING SCHEDULER')
-        # for server in self.__broadcasting_servers:
-        #     server.close()
-        self.__stop_event.set()  # this will stop things including task_processor
-        self.__pinger.stop()
-        self.task_processor.stop()
-        self.ui_state_access.stop()
-        self.__worker_pool.stop()
-        self.__server_closing_task = asyncio.create_task(_server_closer())  # we ensure worker pool stops BEFORE server, so workers have chance to report back
-        self.__cleanup_tasks = [asyncio.create_task(_db_cache_writeback())]
-        if self.__ui_server is not None:
-            self.__ui_server.close()
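Note the split of responsibilities in `stop()` above: it only signals and schedules `_server_closer`/`_db_cache_writeback`; the actual teardown completes later. A caller that wants a clean shutdown therefore pairs it with `wait_till_stops()`:

    # graceful shutdown sketch, using only methods defined in this class
    scheduler.stop()                   # signal components, schedule server close and writeback
    await scheduler.wait_till_stops()  # join components; DB cache writeback included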
-
-    def _stop_event_wait(self):  # TODO: this is currently being used by ui proto to stop long connections, but not used in task proto, but what if it'll also get long living connections?
-        return self.__stop_event.wait()
-
-    async def start(self):
-        # prepare
-        async with self.data_access.data_connection() as con:
-            # we play it the safest for now:
-            # all workers set to UNKNOWN state, all active invocations are reset, all tasks in the middle of processing are reset to closest waiting state
-            con.row_factory = aiosqlite.Row
-            await con.execute('UPDATE "tasks" SET "state" = ? WHERE "state" IN (?, ?)',
-                              (TaskState.READY.value, TaskState.IN_PROGRESS.value, TaskState.INVOKING.value))
-            await con.execute('UPDATE "tasks" SET "state" = ? WHERE "state" = ?',
-                              (TaskState.WAITING.value, TaskState.GENERATING.value))
-            await con.execute('UPDATE "tasks" SET "state" = ? WHERE "state" = ?',
-                              (TaskState.WAITING.value, TaskState.WAITING_BLOCKED.value))
-            await con.execute('UPDATE "tasks" SET "state" = ? WHERE "state" = ?',
-                              (TaskState.POST_WAITING.value, TaskState.POST_GENERATING.value))
-            await con.execute('UPDATE "tasks" SET "state" = ? WHERE "state" = ?',
-                              (TaskState.POST_WAITING.value, TaskState.POST_WAITING_BLOCKED.value))
-            await con.execute('UPDATE "invocations" SET "state" = ? WHERE "state" = ?', (InvocationState.FINISHED.value, InvocationState.IN_PROGRESS.value))
-            # for now invoking invocation are invalidated by deletion (here and in task_processor)
-            await con.execute('DELETE FROM invocations WHERE "state" = ?', (InvocationState.INVOKING.value,))
-            await con.execute('UPDATE workers SET "ping_state" = ?', (WorkerPingState.UNKNOWN.value,))
-            await con.execute('UPDATE "workers" SET "state" = ?', (WorkerState.UNKNOWN.value,))
-            await con.commit()
-
-            # update volatile mem cache:
-            async with con.execute('SELECT "id", last_seen, last_checked, ping_state FROM workers') as worcur:
-                async for row in worcur:
-                    self.data_access.mem_cache_workers_state[row['id']] = {k: row[k] for k in dict(row)}
-
-        # start
-        loop = asyncio.get_event_loop()
-        self.__legacy_command_server = await loop.create_server(**self.__legacy_server_coro_args)
-        self.__ui_server = await loop.create_server(**self.__ui_server_coro_args)
-        # start message processor
-
-        self.__message_processor = SchedulerMessageProcessor(self, self.__message_processor_addresses)
-        await self.__message_processor.start()
-        worker_pool_message_proxy_address = (self.__message_processor_addresses[0].split(':', 1)[0], None)  # use same ip as scheduler's message processor, but default port
-        self.__worker_pool = WorkerPool(WorkerType.SCHEDULER_HELPER,
-                                        minimal_idle_to_ensure=self.__worker_pool_helpers_minimal_idle_to_ensure,
-                                        scheduler_address=self.server_message_address(DirectAddress(worker_pool_message_proxy_address[0])),
-                                        message_proxy_address=worker_pool_message_proxy_address,
-                                        )
-        await self.__worker_pool.start()
-        #
-        # broadcasting
-        if self.__do_broadcasting:
-            # need to start a broadcaster for each interface from union of message and ui addresses
-            for iface_addr in all_interfaces()[1:]:  # skipping first, as first is localhost
-                broadcast_address = get_broadcast_addr_for(iface_addr)
-                if broadcast_address is None:  # broadcast not supported
-                    continue
-                broadcast_data = {}
-                if direct_address := {x.split(':', 1)[0]: x for x in self.__message_processor_addresses}.get(iface_addr):
-                    broadcast_data['message_address'] = str(direct_address)
-                if iface_addr == self.__ui_address[0] or self.__ui_address[0] == '0.0.0.0':
-                    broadcast_data['ui'] = ':'.join(str(x) for x in (iface_addr, self.__ui_address[1]))
-                if iface_addr == self.__legacy_command_server_address[0] or self.__legacy_command_server_address[0] == '0.0.0.0':
-                    broadcast_data['worker'] = ':'.join(str(x) for x in (iface_addr, self.__legacy_command_server_address[1]))
-                self.__broadcasting_servers.append(
-                    (
-                        broadcast_address,
-                        await create_broadcaster(
-                            'lifeblood_scheduler',
-                            json.dumps(broadcast_data),
-                            ip=broadcast_address,
-                            broadcast_interval=self.__broadcasting_interval
-                        )
-                    )
-                )
-
-        await self.task_processor.start()
-        await self.__pinger.start()
-        await self.ui_state_access.start()
-        # run
-        self.__all_components = \
-            asyncio.gather(self.task_processor.wait_till_stops(),
-                           self.__pinger.wait_till_stops(),
-                           self.ui_state_access.wait_till_stops(),
-                           self.__legacy_command_server.wait_closed(),  # TODO: shit being waited here below is very unnecessary
-                           self.__ui_server.wait_closed(),
-                           self.__worker_pool.wait_till_stops())
-
-        self.__started_event.set()
-        # print information
-        self.__logger.info('scheduler started')
-        self.__logger.info(
-            'scheduler listening on:\n'
-            '  message processors:\n'
-            + '\n'.join((f'    {addr}' for addr in self.__message_processor_addresses)) +
-            '\n'
-            '  ui servers:\n'
-            f'    {":".join(str(x) for x in self.__ui_address)}\n'
-            '  legacy command servers:\n'
-            f'    {":".join(str(x) for x in self.__legacy_command_server_address)}'
-        )
-        self.__logger.info(
-            'broadcasting enabled for:\n'
-            + '\n'.join((f'  {info[0]}' for info in self.__broadcasting_servers))
-        )
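For reference, the payload assembled in the broadcasting loop above is one small JSON object per interface; for a scheduler on 192.168.0.5 (address and ports illustrative) it would look roughly like:

    {"message_address": "192.168.0.5:1384", "ui": "192.168.0.5:1386", "worker": "192.168.0.5:1385"}

where each key is present only if the corresponding server actually listens on that interface (or on 0.0.0.0).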
-
-    async def wait_till_starts(self):
-        return await self.__started_event.wait()
-
-    async def wait_till_stops(self):
-        await self.__started_event.wait()
-        assert self.__all_components is not None
-        await self.__all_components
-        await self.__server_closing_task
-        for task in self.__cleanup_tasks:
-            await task
-
-    async def _save_all_cached_nodes_to_db(self):
-        self.__logger.info('saving nodes to db')
-        for node_id in self.__node_objects:
-            await self.save_node_to_database(node_id)
-            self.__logger.debug(f'node {node_id} saved to db')
-
-    def is_started(self):
-        return self.__started_event.is_set()
-
-    def is_stopping(self) -> bool:
-        """
-        True if stopped or in process of stopping
-        """
-        return self.__stop_event.is_set()
-
-    #
-    # helper functions
-    #
-
-    async def reset_invocations_for_worker(self, worker_id: int, con: aiosqlite_overlay.ConnectionWithCallbacks, also_update_resources=True) -> bool:
-        """
-
-        :param worker_id:
-        :param con:
-        :param also_update_resources:
-        :return: need commit?
-        """
-        async with con.execute('SELECT * FROM invocations WHERE "worker_id" = ? AND "state" == ?',
-                               (worker_id, InvocationState.IN_PROGRESS.value)) as incur:
-            all_invoc_rows = await incur.fetchall()  # we don't really want to update db while reading it
-        need_commit = False
-        for invoc_row in all_invoc_rows:  # mark all (probably single one) invocations
-            need_commit = True
-            self.__logger.debug("fixing dangling invocation %d" % (invoc_row['id'],))
-            await con.execute('UPDATE invocations SET "state" = ? WHERE "id" = ?',
-                              (InvocationState.FINISHED.value, invoc_row['id']))
-            await con.execute('UPDATE tasks SET "state" = ? WHERE "id" = ?',
-                              (TaskState.READY.value, invoc_row['task_id']))
-            con.add_after_commit_callback(self.ui_state_access.scheduler_reports_task_updated, invoc_row['task_id'])  # ui event
-        if also_update_resources:
-            need_commit = need_commit or await self._update_worker_resouce_usage(worker_id, connection=con)
-        return need_commit
-
-    #
-    # invocation consistency checker
-    async def invocation_consistency_checker(self):
-        """
-        both scheduler and woker might crash at any time. so we need to check that
-        worker may crash working on a task (
-        :return:
-        """
-        pass
-
-    #
-    # callbacks
-
-    #
-    # worker reports done task
-    async def task_done_reported(self, task: Invocation, stdout: str, stderr: str):
-        """
-        scheduler comm protocols should call this when a task is done
-        TODO: this is almost the same code as for task_cancel_reported, maybe unify?
-        """
-        for attempt in range(120):  # TODO: this should be configurable
-            # if invocation is super fast - this may happen even before submission is completed,
-            # so we might need to wait a bit
-            try:
-                return await self.__task_done_reported_inner(task, stdout, stderr)
-            except NeedToRetryLater:
-                self.__logger.debug('attempt %d to report invocation %d done notified it needs to wait', attempt, task.invocation_id())
-                await asyncio.sleep(0.5)  # TODO: this should be configurable
-                continue
-        else:
-            self.__logger.error(f'out of attempts trying to report done invocation {task.invocation_id()}, probably something is not right with the state of the database')
-
-    async def __task_done_reported_inner(self, task: Invocation, stdout: str, stderr: str):
-        """
-
-        """
-        async with self.__invocation_reporting_lock, \
-                self.data_access.data_connection() as con:
-            con.row_factory = aiosqlite.Row
-            self.__logger.debug('task finished reported %s code %s', repr(task), task.exit_code())
-            # sanity check
-            async with con.execute('SELECT "state" FROM invocations WHERE "id" = ?', (task.invocation_id(),)) as cur:
-                invoc = await cur.fetchone()
-            if invoc is None:
-                self.__logger.error('reported task has non existing invocation id %d' % task.invocation_id())
-                return
-            if invoc['state'] == InvocationState.INVOKING.value:  # means _submitter has not yet finished, we should wait
-                raise NeedToRetryLater()
-            elif invoc['state'] != InvocationState.IN_PROGRESS.value:
-                self.__logger.warning(f'reported task for a finished invocation. assuming that worker failed to cancel task previously and ignoring invocation results. (state={invoc["state"]})')
-                return
-            await con.execute('UPDATE invocations SET "state" = ?, "return_code" = ?, "runtime" = ? WHERE "id" = ?',
-                              (InvocationState.FINISHED.value, task.exit_code(), task.running_time(), task.invocation_id()))
-            async with con.execute('SELECT * FROM invocations WHERE "id" = ?', (task.invocation_id(),)) as incur:
-                invocation = await incur.fetchone()
-            assert invocation is not None
-
-            await con.execute('UPDATE workers SET "state" = ? WHERE "id" = ?',
-                              (WorkerState.IDLE.value, invocation['worker_id']))
-            await self._update_worker_resouce_usage(invocation['worker_id'], connection=con)  # remove resource usage info
-            tasks_to_wait = []
-            if not self.__use_external_log:
-                await con.execute('UPDATE invocations SET "stdout" = ?, "stderr" = ? WHERE "id" = ?',
-                                  (stdout, stderr, task.invocation_id()))
-            else:
-                await con.execute('UPDATE invocations SET "log_external" = 1 WHERE "id" = ?',
-                                  (task.invocation_id(),))
-                tasks_to_wait.append(asyncio.create_task(self._save_external_logs(task.invocation_id(), stdout, stderr)))
-
-            self.data_access.clear_invocation_progress(task.invocation_id())
-
-            ui_task_delta = TaskDelta(invocation['task_id'])  # for ui event
-            if task.finished_needs_retry():  # max retry count will be checked by task processor
-                await con.execute('UPDATE tasks SET "state" = ?, "work_data_invocation_attempt" = "work_data_invocation_attempt" + 1 WHERE "id" = ?',
-                                  (TaskState.READY.value, invocation['task_id']))
-                ui_task_delta.state = TaskState.READY  # for ui event
-            elif task.finished_with_error():
-                state_details = json.dumps({'message': f'see invocation #{invocation["id"]} log for details',
-                                            'happened_at': TaskState.IN_PROGRESS.value,
-                                            'type': 'invocation'})
-                await con.execute('UPDATE tasks SET "state" = ?, "state_details" = ? WHERE "id" = ?',
-                                  (TaskState.ERROR.value,
-                                   state_details,
-                                   invocation['task_id']))
-                ui_task_delta.state = TaskState.ERROR  # for ui event
-                ui_task_delta.state_details = state_details  # for ui event
-            else:
-                await con.execute('UPDATE tasks SET "state" = ? WHERE "id" = ?',
-                                  (TaskState.POST_WAITING.value, invocation['task_id']))
-                ui_task_delta.state = TaskState.POST_WAITING  # for ui event
-
-            con.add_after_commit_callback(self.ui_state_access.scheduler_reports_task_updated, ui_task_delta)  # ui event
-            await con.commit()
-        if len(tasks_to_wait) > 0:
-            await asyncio.wait(tasks_to_wait)
-        self.wake()
-        self.poke_task_processor()
-
-    async def _save_external_logs(self, invocation_id, stdout, stderr):
-        logbasedir = self.__external_log_location / 'invocations' / f'{invocation_id}'
-        try:
-            if not logbasedir.exists():
-                logbasedir.mkdir(exist_ok=True)
-            async with aiofiles.open(logbasedir / 'stdout.log', 'w') as fstdout, \
-                    aiofiles.open(logbasedir / 'stderr.log', 'w') as fstderr:
-                await asyncio.gather(fstdout.write(stdout),
-                                     fstderr.write(stderr))
-        except OSError:
-            self.__logger.exception('error happened saving external logs! Ignoring this error')
-
-    #
-    # worker reports canceled task
-    async def task_cancel_reported(self, task: Invocation, stdout: str, stderr: str):
-        """
-        scheduler comm protocols should call this when a task is cancelled
-        """
-        for attempt in range(120):  # TODO: this should be configurable
-            # if invocation is super fast - this may happen even before submission is completed,
-            # so we might need to wait a bit
-            try:
-                return await self.__task_cancel_reported_inner(task, stdout, stderr)
-            except NeedToRetryLater:
-                self.__logger.debug('attempt %d to report invocation %d cancelled notified it needs to wait', attempt, task.invocation_id())
-                await asyncio.sleep(0.5)  # TODO: this should be configurable
-                continue
-        else:
-            self.__logger.error(f'out of attempts trying to report cancel invocation {task.invocation_id()}, probably something is not right with the state of the database')
-
-    async def __task_cancel_reported_inner(self, task: Invocation, stdout: str, stderr: str):
-        async with self.__invocation_reporting_lock, \
-                self.data_access.data_connection() as con:
-            con.row_factory = aiosqlite.Row
-            self.__logger.debug('task cancelled reported %s', repr(task))
-            # sanity check
-            async with con.execute('SELECT "state" FROM invocations WHERE "id" = ?', (task.invocation_id(),)) as cur:
-                invoc = await cur.fetchone()
-            if invoc is None:
-                self.__logger.error('reported task has non existing invocation id %d' % task.invocation_id())
-                return
-            if invoc['state'] == InvocationState.INVOKING.value:  # means _submitter has not yet finished, we should wait
-                raise NeedToRetryLater()
-            elif invoc['state'] != InvocationState.IN_PROGRESS.value:
-                self.__logger.warning(f'reported task for a finished invocation. assuming that worker failed to cancel task previously and ignoring invocation results. (state={invoc["state"]})')
-                return
-            await con.execute('UPDATE invocations SET "state" = ?, "runtime" = ? WHERE "id" = ?',
-                              (InvocationState.FINISHED.value, task.running_time(), task.invocation_id()))
-            async with con.execute('SELECT * FROM invocations WHERE "id" = ?', (task.invocation_id(),)) as incur:
-                invocation = await incur.fetchone()
-            assert invocation is not None
-
-            self.data_access.clear_invocation_progress(task.invocation_id())
-
-            await con.execute('UPDATE workers SET "state" = ? WHERE "id" = ?',
-                              (WorkerState.IDLE.value, invocation['worker_id']))
-            await self._update_worker_resouce_usage(invocation['worker_id'], connection=con)  # remove resource usage info
-            tasks_to_wait = []
-            if not self.__use_external_log:
-                await con.execute('UPDATE invocations SET "stdout" = ?, "stderr" = ? WHERE "id" = ?',
-                                  (stdout, stderr, task.invocation_id()))
-            else:
-                await con.execute('UPDATE invocations SET "log_external" = 1, "stdout" = null, "stderr" = null WHERE "id" = ?',
-                                  (task.invocation_id(),))
-                tasks_to_wait.append(asyncio.create_task(self._save_external_logs(task.invocation_id(), stdout, stderr)))
-            await con.execute('UPDATE tasks SET "state" = ? WHERE "id" = ?',
-                              (TaskState.READY.value, invocation['task_id']))
-            con.add_after_commit_callback(self.ui_state_access.scheduler_reports_task_updated, TaskDelta(invocation['task_id'], state=TaskState.READY))  # ui event
-            await con.commit()
-        if len(tasks_to_wait) > 0:
-            await asyncio.wait(tasks_to_wait)
-        self.__logger.debug(f'cancelling task done {repr(task)}')
-        self.wake()
-        self.poke_task_processor()
-
-    #
-    # add new worker to db
-    async def add_worker(
-            # TODO: WorkerResources (de)serialization
-            # TODO: Worker actually passing new WorkerResources on hello
-            self, addr: str, worker_type: WorkerType, worker_resources: HardwareResources,  # TODO: all resource should also go here
-            *,
-            assume_active: bool = True,
-            worker_metadata: WorkerMetadata):
-        """
-        this is called by network protocol handler when worker reports being up to the scheduler
-        """
-        self.__logger.debug(f'worker reported added: {addr}')
-        async with self.data_access.data_connection() as con:
-            con.row_factory = aiosqlite.Row
-            await con.execute('BEGIN IMMEDIATE')  # important to have locked DB during all this state change
-            # logic for now:
-            #  - search for same last_address, same hwid
-            #  - if no - search for first entry (OFF or UNKNOWN) with same hwid, ignore address
-            #    - in this case also delete addr from DB if exists
-            async with con.execute('SELECT "id", state FROM "workers" WHERE "last_address" == ? AND hwid == ?', (addr, worker_resources.hwid)) as worcur:
-                worker_row = await worcur.fetchone()
-            if worker_row is None:
-                # first ensure that there is no entry with the same address
-                await con.execute('UPDATE "workers" SET "last_address" = ? WHERE "last_address" == ?', (None, addr))
-                async with con.execute('SELECT "id", state FROM "workers" WHERE hwid == ? AND '
-                                       '(state == ? OR state == ?)', (worker_resources.hwid,
-                                                                      WorkerState.OFF.value, WorkerState.UNKNOWN.value)) as worcur:
-                    worker_row = await worcur.fetchone()
-            if assume_active:
-                ping_state = WorkerPingState.WORKING.value
-                state = WorkerState.IDLE.value
-            else:
-                ping_state = WorkerPingState.OFF.value
-                state = WorkerState.OFF.value
-
-            tstamp = int(time.time())
-            if worker_row is not None:
-                if worker_row['state'] == WorkerState.INVOKING.value:  # so we are in the middle of sumbission
-                    state = WorkerState.INVOKING.value  # then we preserve INVOKING state
-                await self.reset_invocations_for_worker(worker_row['id'], con=con, also_update_resources=False)  # we update later
-                await con.execute('UPDATE "workers" SET '
-                                  'hwid=?, '
-                                  'last_seen=?, ping_state=?, state=?, worker_type=?, '
-                                  'last_address=? '
-                                  'WHERE "id"=?',
-                                  (worker_resources.hwid,
-                                   tstamp, ping_state, state, worker_type.value,
-                                   addr,
-                                   worker_row['id']))
-                # async with con.execute('SELECT "id" FROM "workers" WHERE last_address=?', (addr,)) as worcur:
-                #     worker_id = (await worcur.fetchone())['id']
-                worker_id = worker_row['id']
-                self.data_access.mem_cache_workers_state[worker_id].update({'last_seen': tstamp,
-                                                                            'last_checked': tstamp,
-                                                                            'ping_state': ping_state,
-                                                                            'worker_id': worker_id})
-                # await con.execute('UPDATE tmpdb.tmp_workers_states SET '
-                #                   'last_seen=?, ping_state=? '
-                #                   'WHERE worker_id=?',
-                #                   (tstamp, ping_state, worker_id))
-            else:
-                async with con.execute('INSERT INTO "workers" '
-                                       '(hwid, '
-                                       'last_address, last_seen, ping_state, state, worker_type) '
-                                       'VALUES '
-                                       '(?, ?, ?, ?, ?, ?)',
-                                       (worker_resources.hwid, addr, tstamp, ping_state, state, worker_type.value)) as insworcur:
-                    worker_id = insworcur.lastrowid
-                self.data_access.mem_cache_workers_state[worker_id] = {'last_seen': tstamp,
-                                                                       'last_checked': tstamp,
-                                                                       'ping_state': ping_state,
-                                                                       'worker_id': worker_id}
-                # await con.execute('INSERT INTO tmpdb.tmp_workers_states '
-                #                   '(worker_id, last_seen, ping_state) '
-                #                   'VALUES '
-                #                   '(?, ?, ?)',
-                #                   (worker_id, tstamp, ping_state))
-
-            resource_fields: Tuple[str, ...] = tuple(x.name for x in self.__config_provider.hardware_resource_definitions())
-            # device_type_names = tuple(x.name for x in self.__config_provider.hardware_device_type_definitions())
-            device_type_resource_fields: Dict[str, Tuple[str, ...]] = {x.name: tuple(r.name for r in x.resources) for x in self.__config_provider.hardware_device_type_definitions()}
-            # in case worker_resources contain dev_types not known to config - they will be ignored
-            devices_to_register = []
-            # checks
-            for field in resource_fields:
-                if field not in worker_resources:
-                    self.__logger.warning(f'worker (hwid:{worker_resources.hwid}) does not declare expected resource "{field}", assume value=0')
-            for res_name, _ in worker_resources.items():
-                if res_name not in resource_fields:
-                    self.__logger.warning(f'worker (hwid:{worker_resources.hwid}) declares resource "{res_name}" unknown to the scheduler, ignoring')
-            for dev_type, dev_name, dev_res in worker_resources.devices():
-                if dev_type not in device_type_resource_fields:
-                    self.__logger.warning(f'worker (hwid:{worker_resources.hwid}) declares device type "{dev_type}" unknown to the scheduler, ignoring')
-                    continue
-                devices_to_register.append((dev_type, dev_name, {res_name: res_val for res_name, res_val in dev_res.items() if res_name in device_type_resource_fields[dev_type]}))
-
-            # TODO: note that below sql breaks if there are no resource_fields (which is an unlikely config, but not impossible)
-            await con.execute('INSERT INTO resources '
-                              '(hwid, ' +
-                              ', '.join(f'{field}, total_{field}' for field in resource_fields) +
-                              ') '
-                              'VALUES (?' + ', ?'*(2*len(resource_fields)) + ') '
-                              'ON CONFLICT(hwid) DO UPDATE SET ' +
-                              ', '.join(f'"{field}"=excluded.{field}, "total_{field}"=excluded.total_{field}' for field in resource_fields)
-                              ,
-                              (worker_resources.hwid,
-                               *(x for field in resource_fields for x in (
-                                   (worker_resources[field].value, worker_resources[field].value) if field in worker_resources else (0, 0))  # TODO: do NOT invent defaults here, only set known fields, like in dev code below
-                                 ))
-                              )
-
-            for dev_type, dev_name, dev_res in sorted(devices_to_register, key=lambda x: (x[0], x[1])):  # sort by (deva_type, dev_name) to ensure some consistent order
-
-                dev_type_table_name = f'hardware_device_type__{dev_type}'
-                if dev_res:
-                    await con.execute(
-                        f'INSERT INTO "{dev_type_table_name}" '
-                        f'(hwid, hw_dev_name, ' +
-                        ', '.join(f'res__{field}' for field, _ in dev_res.items()) +
-                        ') '
-                        'VALUES (?, ?' + ', ?'*(len(dev_res)) + ') '
-                        'ON CONFLICT(hwid,"hw_dev_name") DO UPDATE SET ' +
-                        ', '.join(f'"res__{field}"=excluded.res__{field}' for field in dev_res)
-                        ,
-                        (worker_resources.hwid, dev_name,
-                         *(res_val.value for _, res_val in dev_res.items())
-                         )
-                    )
-                else:
-                    await con.execute(
-                        f'INSERT INTO "{dev_type_table_name}" '
-                        f'(hwid, hw_dev_name) ' +
-                        'VALUES (?, ?) '
-                        'ON CONFLICT(hwid,"hw_dev_name") DO NOTHING'
-                        ,
-                        (worker_resources.hwid, dev_name)
-                    )
-
-            await self._update_worker_resouce_usage(worker_id, hwid=worker_resources.hwid, connection=con)  # used resources are inited to none
-            self.data_access.set_worker_metadata(worker_resources.hwid, worker_metadata)
-            await con.commit()
-        self.__logger.debug(f'finished worker reported added: {addr}')
-        self.poke_task_processor()
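To make the dynamically built upsert above concrete: with two hypothetical resource fields `cpu_count` and `cpu_mem`, the resources statement expands to

    INSERT INTO resources (hwid, cpu_count, total_cpu_count, cpu_mem, total_cpu_mem)
    VALUES (?, ?, ?, ?, ?)
    ON CONFLICT(hwid) DO UPDATE SET
        "cpu_count"=excluded.cpu_count, "total_cpu_count"=excluded.total_cpu_count,
        "cpu_mem"=excluded.cpu_mem, "total_cpu_mem"=excluded.total_cpu_mem

with one (value, value) pair bound per declared field, and (0, 0) for fields the worker did not declare.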
-
-    # TODO: add decorator that locks method from reentry or smth
-    #  potentially a worker may report done while this works,
-    #  or when scheduler picked worker and about to run this, which will lead to inconsistency warning
-    #  NOTE!: so far it's always called from a STARTED transaction, so there should not be reentry possible
-    #  But that is not enforced right now, easy to make mistake
-    async def _update_worker_resouce_usage(self, worker_id: int, resources: Optional[Requirements] = None, *, hwid=None, connection: aiosqlite.Connection) -> bool:
-        """
-        updates resource information based on new worker resources usage
-        as part of ongoing transaction
-        Note: con SHOULD HAVE STARTED TRANSACTION, otherwise it might be not safe to call this
-
-        :param worker_id:
-        :param hwid: if hwid of worker_id is already known - provide it here to skip extra db query. but be SURE it's correct!
-        :param connection: opened db connection. expected to have Row as row factory
-        :return: if commit is needed on connection (if db set operation happened)
-        """
-        assert connection.in_transaction, 'expectation failure'
-
-        resource_fields = tuple(x.name for x in self.__config_provider.hardware_resource_definitions())
-        device_type_names = tuple(x.name for x in self.__config_provider.hardware_device_type_definitions())
-
-        workers_resources = self.data_access.mem_cache_workers_resources
-        if hwid is None:
-            async with connection.execute('SELECT "hwid" FROM "workers" WHERE "id" == ?', (worker_id,)) as worcur:
-                hwid = (await worcur.fetchone())['hwid']
-
-        # calculate available resources NOT counting current worker_id
-        async with connection.execute(f'SELECT '
-                                      f'{", ".join(resource_fields)}, '
-                                      f'{", ".join("total_"+x for x in resource_fields)} '
-                                      f'FROM resources WHERE hwid == ?', (hwid,)) as rescur:
-            available_res = dict(await rescur.fetchone())
-        available_dev_type_to_ids: Dict[str, Dict[int, Dict[str, Union[int, float, str]]]] = {}
-        current_available_dev_type_to_ids: Dict[str, Set[int]] = {}
-        for dev_type in device_type_names:
-            dev_type_table_name = f'hardware_device_type__{dev_type}'
-            async with connection.execute(
-                    f'SELECT * FROM "{dev_type_table_name}" WHERE hwid == ?',
-                    (hwid,)) as rescur:
-                all_dev_rows = [dict(x) for x in await rescur.fetchall()]
-            available_dev_type_to_ids[dev_type] = {
-                x['dev_id']: {k[len('res__'):]: v for k, v in x.items() if k.startswith('res__')}  # resource cols start with res__
-                for x in all_dev_rows
-            }  # note, these are ALL devices, with "available" 0 and 1 values, we don't *trust* "available", we recalc them below, just like with non-total res
-            current_available_dev_type_to_ids[dev_type] = {x['dev_id'] for x in all_dev_rows if x['available']}  # now this counts available to check later if anything changed
-        current_available = {k: v for k, v in available_res.items() if not k.startswith('total_')}
-        available_res = {k[len('total_'):]: v for k, v in available_res.items() if k.startswith('total_')}  # start with full total res
-
-        for wid, res in workers_resources.items():
-            if wid == worker_id:
-                continue  # SKIP worker_id currently being set
-            if res.get('hwid') != hwid:
-                continue
-            # recalc actual available resources based on cached worker_resources
-            for field in resource_fields:
-                if field not in res.get('res', {}):
-                    continue
-                available_res[field] -= res['res'][field]
-            # recalc actual available devices based on cached worker_resources
-            for dev_type in device_type_names:
-                if dev_type not in res.get('dev', {}):
-                    continue
-                for dev_id in res['dev'][dev_type]:
-                    available_dev_type_to_ids[dev_type].pop(dev_id)
-        ##
-
-        # now choose proper amount of resources to pick
-        if resources is None:
-            workers_resources[worker_id] = {'hwid': hwid}  # remove resource usage info
-        else:
-            workers_resources[worker_id] = {'res': {}, 'dev': {}}
-            for field in resource_fields:
-                if field not in resources.resources:
-                    continue
-                if available_res[field] < resources.resources[field].min:
-                    raise NotEnoughResources(f'{field}: {resources.resources[field].min} out of {available_res[field]}')
-                # so we take preferred amount of resources (or minimum if pref not set), but no more than available
-                # if preferred is lower than min - it's ignored
-                workers_resources[worker_id]['res'][field] = min(available_res[field],
-                                                                 max(resources.resources[field].pref, resources.resources[field].min))
-                available_res[field] -= workers_resources[worker_id]['res'][field]
-
-            selected_devs: Dict[str, List[int]] = {}  # dev_type to list of dev_ids of that type that are picked
-            for dev_type, dev_reqs in resources.devices.items():
-                if dev_reqs.min == 0 and dev_reqs.pref == 0:  # trivial check
-                    continue
-                if dev_type not in available_dev_type_to_ids:
-                    if dev_reqs.min > 0:
-                        raise NotEnoughResources(f'device "{dev_type}" missing')  # this shouldn't happen - this whole func is only called when resources are checked
-                    else:
-                        continue
-                for dev_id, dev_res in available_dev_type_to_ids[dev_type].items():
-                    # now we check if dev fits requirements
-                    is_good = True
-                    for req_name, req_val in dev_reqs.resources.items():  # we ignore pref in current logic - devices are always taken full
-                        if req_name not in dev_res:
-                            raise NotEnoughResources(f'device "{dev_type}" does not have requested resource "{req_name}"')  # this also should not happen
-                        if dev_res[req_name] < req_val.min:
-                            is_good = False
-                            break
-                    if is_good:
-                        selected_devs.setdefault(dev_type, []).append(dev_id)
-                        if len(selected_devs[dev_type]) >= max(dev_reqs.min, dev_reqs.pref):
-                            # we selected enough devices of this type
-                            break
-                # now remove selected from available
-                for dev_id in selected_devs.get(dev_type, []):
-                    available_dev_type_to_ids[dev_type].pop(dev_id)
-                # sanity check
-                if dev_reqs.min > 0 and len(selected_devs[dev_type]) < dev_reqs.min:
-                    raise NotEnoughResources(f'device "{dev_type}: cannot select {dev_reqs.min} out of {len(selected_devs[dev_type])}')
-            workers_resources[worker_id]['dev'] = selected_devs
-
-            workers_resources[worker_id]['hwid'] = hwid  # just to ensure it was not overriden
-
-        self.__logger.debug(f'updating resources {hwid} with {available_res} against {current_available}')
-        self.__logger.debug(workers_resources)
-
-        available_res_didnt_change = available_res == current_available
-        available_devs_didnt_change = all(set(available_dev_type_to_ids[dev_type].keys()) == current_available_dev_type_to_ids[dev_type] for dev_type in device_type_names)
-        if available_res == current_available and available_devs_didnt_change:  # nothing needs to be updated
-            return False
-
-        if not available_res_didnt_change:
-            await connection.execute(f'UPDATE resources SET {", ".join(f"{k}={v}" for k, v in available_res.items())} WHERE hwid == ?', (hwid,))
-        if not available_devs_didnt_change:
-            for dev_type in device_type_names:  # TODO: only update affected tables
-                dev_type_table_name = f'hardware_device_type__{dev_type}'
-                await connection.execute(f'UPDATE "{dev_type_table_name}" SET "available"=0 WHERE hwid==?', (hwid,))
-                await connection.executemany(f'UPDATE "{dev_type_table_name}" SET "available"=1 WHERE dev_id==?', ((x,) for x in available_dev_type_to_ids[dev_type].keys()))
-        return True
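The `mem_cache_workers_resources` entries this method maintains have the following shape, inferred from the reads and writes above (ids and numbers illustrative):

    # one entry per worker id; 'res'/'dev' are present only while something is assigned
    {
        12: {'hwid': 101, 'res': {'cpu_count': 16, 'cpu_mem': 8_000_000_000},
             'dev': {'gpu': [0, 2]}},      # worker 12 currently runs an invocation
        13: {'hwid': 102},                 # worker 13 currently holds nothing
    }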
-
-    #
-    #
-    async def update_invocation_progress(self, invocation_id: int, progress: float):
-        """
-        report progress update on invocation that is being worked on
-        there are not too many checks here, as progress report is considered non-vital information,
-        so if such message comes after invocation is finished - it's not big deal
-        """
-        prev_progress = self.data_access.get_invocation_progress(invocation_id)
-        self.data_access.set_invocation_progress(invocation_id, progress)
-        if prev_progress != progress:
-            task_id = None
-            async with self.data_access.data_connection() as con:
-                con.row_factory = aiosqlite.Row
-                async with con.execute('SELECT task_id FROM invocations WHERE "state" == ? AND "id" == ?',
-                                       (InvocationState.IN_PROGRESS.value, invocation_id,)) as cur:
-                    task_id_row = await cur.fetchone()
-                if task_id_row is not None:
-                    task_id = task_id_row['task_id']
-            if task_id is not None:
-                self.ui_state_access.scheduler_reports_task_updated(TaskDelta(task_id, progress=progress))
-
-    #
-    # worker reports it being stopped
-    async def worker_stopped(self, addr: str):
-        """
-
-        :param addr:
-        :return:
-        """
-        self.__logger.debug(f'worker reported stopped: {addr}')
-        async with self.data_access.data_connection() as con:
-            con.row_factory = aiosqlite.Row
-            await con.execute('BEGIN IMMEDIATE')
-            async with con.execute('SELECT id, hwid from "workers" WHERE "last_address" = ?', (addr,)) as worcur:
-                worker_row = await worcur.fetchone()
-            if worker_row is None:
-                self.__logger.warning(f'unregistered worker reported "stopped": {addr}, ignoring')
-                await con.rollback()
-                return
-            wid = worker_row['id']
-            hwid = worker_row['hwid']
-            # print(wid)
-
-            # we ensure there are no invocations running with this worker
-            async with con.execute('SELECT "id", task_id FROM invocations WHERE worker_id = ? AND ("state" = ? OR "state" = ?)',
-                                   (wid, InvocationState.IN_PROGRESS.value, InvocationState.INVOKING.value)) as invcur:
-                invocations = await invcur.fetchall()
-
-            await con.execute('UPDATE workers SET "state" = ? WHERE "id" = ?', (WorkerState.OFF.value, wid))
-            await con.executemany('UPDATE invocations SET state = ? WHERE "id" = ?', ((InvocationState.FINISHED.value, x["id"]) for x in invocations))
-            await con.executemany('UPDATE tasks SET state = ? WHERE "id" = ?', ((TaskState.READY.value, x["task_id"]) for x in invocations))
-            await self._update_worker_resouce_usage(wid, hwid=hwid, connection=con)  # oh wait, it happens right here, still an assert won't hurt
-            del self.data_access.mem_cache_workers_resources[wid]  # remove from cache  # TODO: ENSURE resources were already unset for this wid
-            if len(invocations) > 0:
-                con.add_after_commit_callback(self.ui_state_access.scheduler_reports_tasks_updated, [TaskDelta(x["task_id"], state=TaskState.READY) for x in invocations])  # ui event
-            await con.commit()
-        self.__logger.debug(f'finished worker reported stopped: {addr}')
-
-    #
-    # protocol related commands
-    #
-
-    #
-    # cancel invocation
-    async def cancel_invocation(self, invocation_id: str):
-        self.__logger.debug(f'canceling invocation {invocation_id}')
-        async with self.data_access.data_connection() as con:
-            con.row_factory = aiosqlite.Row
-            async with con.execute('SELECT * FROM "invocations" WHERE "id" = ?', (invocation_id,)) as cur:
-                invoc = await cur.fetchone()
-            if invoc is None or invoc['state'] != InvocationState.IN_PROGRESS.value:
-                return
-            async with con.execute('SELECT "last_address" FROM "workers" WHERE "id" = ?', (invoc['worker_id'],)) as cur:
-                worker = await cur.fetchone()
-        if worker is None:
-            self.__logger.error('inconsistent worker ids? how?')
-            return
-        addr = AddressChain(worker['last_address'])
-
-        # the logic is:
-        # - we send the worker a signal to cancel invocation
-        # - later worker sends task_cancel_reported, and we are happy
-        # - but worker might be overloaded, broken or whatever and may never send it. and it can even finish task and send task_done_reported, witch we need to treat
-        with WorkerControlClient.get_worker_control_client(addr, self.message_processor()) as client:  # type: WorkerControlClient
-            await client.cancel_task()
-
-        # oh no, we don't do that, we wait for worker to report task canceled. await con.execute('UPDATE invocations SET "state" = ? WHERE "id" = ?', (InvocationState.FINISHED.value, invocation_id))
-
-    #
-    #
-    async def cancel_invocation_for_task(self, task_id: int):
-        self.__logger.debug(f'canceling invocation for task {task_id}')
-        async with self.data_access.data_connection() as con:
-            con.row_factory = aiosqlite.Row
-            async with con.execute('SELECT "id" FROM "invocations" WHERE "task_id" = ? AND state = ?', (task_id, InvocationState.IN_PROGRESS.value)) as cur:
-                invoc = await cur.fetchone()
-        if invoc is None:
-            return
-        return await self.cancel_invocation(invoc['id'])
-
-    #
-    #
-    async def cancel_invocation_for_worker(self, worker_id: int):
-        self.__logger.debug(f'canceling invocation for worker {worker_id}')
-        async with self.data_access.data_connection() as con:
-            con.row_factory = aiosqlite.Row
-            async with con.execute('SELECT "id" FROM "invocations" WHERE "worker_id" == ? AND state == ?', (worker_id, InvocationState.IN_PROGRESS.value)) as cur:
-                invoc = await cur.fetchone()
-        if invoc is None:
-            return
-        return await self.cancel_invocation(invoc['id'])
-
-    #
-    #
-    async def force_set_node_task(self, task_id: int, node_id: int):
-        self.__logger.debug(f'forcing task {task_id} to node {node_id}')
-        try:
-            async with self.data_access.data_connection() as con:
-                con.row_factory = aiosqlite.Row
-                await con.execute('BEGIN IMMEDIATE')
-                await con.execute('PRAGMA FOREIGN_KEYS = on')
-                async with con.execute('SELECT "state" FROM tasks WHERE "id" == ?', (task_id,)) as cur:
-                    row = await cur.fetchone()
-                if row is None:
-                    self.__logger.warning(f'failed to force task node: task {task_id} not found')
-                    await con.rollback()
-                    return
-
-                state = TaskState(row['state'])
-                new_state = None
-                if state in (TaskState.WAITING, TaskState.READY, TaskState.POST_WAITING):
-                    new_state = TaskState.WAITING
-                elif state == TaskState.DONE:
-                    new_state = TaskState.DONE
-                # if new_state was not set - means state was invalid
-                if new_state is None:
-                    self.__logger.warning(f'changing node of a task in state {state.name} is not allowed')
-                    await con.rollback()
-                    raise ValueError(f'changing node of a task in state {state.name} is not allowed')
-
-                await con.execute('UPDATE tasks SET "node_id" = ?, "state" = ? WHERE "id" = ?', (node_id, new_state.value, task_id))
-                con.add_after_commit_callback(self.ui_state_access.scheduler_reports_task_updated, TaskDelta(task_id, node_id=node_id))  # ui event
-                # reset blocking too
-                await self.data_access.reset_task_blocking(task_id, con=con)
-                await con.commit()
-        except aiosqlite.IntegrityError:
-            self.__logger.error(f'could not set task {task_id} to node {node_id} because of database integrity check')
-            raise DataIntegrityError() from None
-        else:
-            self.wake()
-            self.poke_task_processor()
-
-    #
-    # force change task state
-    async def force_change_task_state(self, task_ids: Union[int, Iterable[int]], state: TaskState):
-        """
-        forces task into given state.
-        obviously a task cannot be forced into certain states, like IN_PROGRESS, GENERATING, POST_GENERATING
-
-        :param task_ids:
-        :param state:
-        :return:
-        """
-        if state in (TaskState.IN_PROGRESS, TaskState.GENERATING, TaskState.POST_GENERATING):
-            self.__logger.error(f'cannot force task {task_ids} into state {state}')
-            return
-        if isinstance(task_ids, int):
-            task_ids = [task_ids]
-        query = 'UPDATE tasks SET "state" = %d WHERE "id" = ?' % state.value
-        #print('beep')
-        async with self.data_access.data_connection() as con:
-            for task_id in task_ids:
-                await con.execute('BEGIN IMMEDIATE')
-                async with con.execute('SELECT "state" FROM tasks WHERE "id" = ?', (task_id,)) as cur:
-                    cur_state = await cur.fetchone()
-                if cur_state is None:
-                    await con.rollback()
-                    continue
-                cur_state = TaskState(cur_state[0])
-                if cur_state in (TaskState.IN_PROGRESS, TaskState.GENERATING, TaskState.POST_GENERATING):
-                    self.__logger.warning(f'forcing task out of state {cur_state} is not allowed')
-                    await con.rollback()
-                    continue
-
-                await con.execute(query, (task_id,))
-                # just in case we also reset blocking
-                await self.data_access.reset_task_blocking(task_id, con=con)
-                con.add_after_commit_callback(self.ui_state_access.scheduler_reports_task_updated, TaskDelta(task_id, state=state))  # ui event
-                await con.commit()  # TODO: this can be optimized into a single transaction
-        #print('boop')
-        self.wake()
-        self.poke_task_processor()
-
-    #
-    # change task's paused state
-    async def set_task_paused(self, task_ids_or_group: Union[int, Iterable[int], str], paused: bool):
-        if isinstance(task_ids_or_group, str):
-            async with self.data_access.data_connection() as con:
-                await con.execute('UPDATE tasks SET "paused" = ? WHERE "id" IN (SELECT "task_id" FROM task_groups WHERE "group" = ?)',
-                                  (int(paused), task_ids_or_group))
-                ui_task_ids = await self.ui_state_access._get_group_tasks(task_ids_or_group)  # ui event
-                con.add_after_commit_callback(self.ui_state_access.scheduler_reports_tasks_updated, [TaskDelta(ui_task_id, paused=paused) for ui_task_id in ui_task_ids])  # ui event
-                await con.commit()
-            self.wake()
-            self.poke_task_processor()
-            return
-        if isinstance(task_ids_or_group, int):
-            task_ids_or_group = [task_ids_or_group]
-        query = 'UPDATE tasks SET "paused" = %d WHERE "id" = ?' % int(paused)
-        async with self.data_access.data_connection() as con:
-            await con.executemany(query, ((x,) for x in task_ids_or_group))
-            con.add_after_commit_callback(self.ui_state_access.scheduler_reports_tasks_updated, [TaskDelta(ui_task_id, paused=paused) for ui_task_id in task_ids_or_group])  # ui event
-            await con.commit()
-        self.wake()
-        self.poke_task_processor()
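Both call forms accepted by `set_task_paused` above, for reference (task ids and group name illustrative):

    await scheduler.set_task_paused([101, 102], True)    # pause explicit task ids
    await scheduler.set_task_paused('wedge_v1', False)   # resume every task in a group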
-                                   (task_group_name, InvocationState.INVOKING.value)) as cur:
-                invoking_invoc_ids = set(x['id'] for x in await cur.fetchall())
-            async with con.execute('SELECT "id" FROM invocations '
-                                   'INNER JOIN task_groups ON task_groups.task_id == invocations.task_id '
-                                   'WHERE task_groups."group" == ? AND invocations.state == ?',
-                                   (task_group_name, InvocationState.IN_PROGRESS.value)) as cur:
-                active_invoc_ids = tuple(x['id'] for x in await cur.fetchall())
-            # i sure use a lot of fetchall where it would be much more natural to iterate the cursor -
-            # that is because of a fear of db locking i got BEFORE switching to WAL, when iterating a connection was randomly crashing other connections, not taking timeout into account at all.
-
-        # note that we might have some invoking_invoc_ids here, but by now some of them
-        # might already have been set to in-progress and even gotten into the active_invoc_ids list
-
-        # first - cancel all in-progress invocations
-        for inv_id in active_invoc_ids:
-            await self.cancel_invocation(inv_id)
-
-        # now since we don't have the ability to safely cancel a running _submitter task - we will just wait till
-        # invoking invocations change state.
-        # sure, it's a bit bruteforce,
-        # but a working solution for now
-        if len(invoking_invoc_ids) == 0:
-            return
-        async with self.data_access.data_connection() as con:
-            while len(invoking_invoc_ids) > 0:
-                # TODO: this forever while doesn't seem right
-                #  in the average case it should basically never happen at all
-                #  only in case of really bad buggy network connections can an invocation get stuck on INVOKING
-                #  but there are natural timeouts in _submitter that will switch it from INVOKING eventually
-                #  the only question is - do we want to just stay in this function until it's resolved? UI's client is single threaded, so it will get stuck waiting
-                con.row_factory = aiosqlite.Row
-                async with con.execute('SELECT "id",state FROM invocations WHERE state!={} AND "id" IN ({})'.format(
-                        InvocationState.IN_PROGRESS.value,
-                        ','.join(str(x) for x in invoking_invoc_ids))) as cur:
-                    changed_state_ones = await cur.fetchall()
-
-                for oid, ostate in ((x['id'], x['state']) for x in changed_state_ones):
-                    if ostate == InvocationState.IN_PROGRESS.value:
-                        await self.cancel_invocation(oid)
-                    assert oid in invoking_invoc_ids
-                    invoking_invoc_ids.remove(oid)
-                await asyncio.sleep(0.5)
-
-    #
-    # set task name
-    async def set_task_name(self, task_id: int, new_name: str):
-        async with self.data_access.data_connection() as con:
-            await con.execute('UPDATE tasks SET "name" = ? WHERE "id" = ?', (new_name, task_id))
WHERE "id" = ?', (new_name, task_id)) - con.add_after_commit_callback(self.ui_state_access.scheduler_reports_task_updated, TaskDelta(task_id, name=new_name)) # ui event - await con.commit() - - # - # set task groups - async def set_task_groups(self, task_id: int, group_names: Iterable[str]): - async with self.data_access.data_connection() as con: - con.row_factory = aiosqlite.Row - await con.execute('BEGIN IMMEDIATE') - async with con.execute('SELECT "group" FROM task_groups WHERE "task_id" = ?', (task_id,)) as cur: - all_groups = set(x['group'] for x in await cur.fetchall()) - group_names = set(group_names) - groups_to_set = group_names - all_groups - groups_to_del = all_groups - group_names - print(task_id, groups_to_set, groups_to_del, all_groups, group_names) - - for group_name in groups_to_set: - await con.execute('INSERT INTO task_groups (task_id, "group") VALUES (?, ?)', (task_id, group_name)) - await con.execute('INSERT OR IGNORE INTO task_group_attributes ("group", "ctime") VALUES (?, ?)', (group_name, int(datetime.utcnow().timestamp()))) - for group_name in groups_to_del: - await con.execute('DELETE FROM task_groups WHERE task_id = ? AND "group" = ?', (task_id, group_name)) - con.add_after_commit_callback(self.ui_state_access.scheduler_reports_tasks_removed_from_group, [task_id], groups_to_del) # ui event - con.add_after_commit_callback(self.ui_state_access.scheduler_reports_task_groups_changed, groups_to_set) # ui event - # - # ui event - if len(groups_to_set) > 0: - async with con.execute( - 'SELECT tasks.id, tasks.parent_id, tasks.children_count, tasks.active_children_count, tasks.state, tasks.state_details, tasks.paused, tasks.node_id, ' - 'tasks.node_input_name, tasks.node_output_name, tasks.name, tasks.split_level, tasks.work_data_invocation_attempt, ' - 'task_splits.origin_task_id, task_splits.split_id, invocations."id" as invoc_id ' - 'FROM "tasks" ' - 'LEFT JOIN "task_splits" ON tasks.id=task_splits.task_id ' - 'LEFT JOIN "invocations" ON tasks.id=invocations.task_id AND invocations.state = ? ' - 'WHERE tasks."id" == ?', - (InvocationState.IN_PROGRESS.value, task_id)) as cur: - task_row = await cur.fetchone() - if task_row is not None: - progress = self.data_access.get_invocation_progress(task_row['invoc_id']) - con.add_after_commit_callback( - self.ui_state_access.scheduler_reports_task_added, - TaskData(task_id, task_row['parent_id'], task_row['children_count'], task_row['active_children_count'], TaskState(task_row['state']), - task_row['state_details'], bool(task_row['paused']), task_row['node_id'], task_row['node_input_name'], task_row['node_output_name'], - task_row['name'], task_row['split_level'], task_row['work_data_invocation_attempt'], progress, - task_row['origin_task_id'], task_row['split_id'], task_row['invoc_id'], group_names), - groups_to_set - ) # ui event - con.add_after_commit_callback(self.ui_state_access.scheduler_reports_task_updated, TaskDelta(task_id, groups=group_names)) # ui event - # - # - await con.commit() - - # - # update task attributes - async def update_task_attributes(self, task_id: int, attributes_to_update: dict, attributes_to_delete: set): - async with self.data_access.data_connection() as con: - con.row_factory = aiosqlite.Row - await con.execute('BEGIN IMMEDIATE') - async with con.execute('SELECT "attributes" FROM tasks WHERE "id" = ?', (task_id,)) as cur: - row = await cur.fetchone() - if row is None: - self.__logger.warning(f'update task attributes for {task_id} failed. 
-                    await con.commit()
-                    return
-                attributes = await deserialize_attributes(row['attributes'])
-            attributes.update(attributes_to_update)
-            for name in attributes_to_delete:
-                if name in attributes:
-                    del attributes[name]
-            await con.execute('UPDATE tasks SET "attributes" = ? WHERE "id" = ?', (await serialize_attributes(attributes),
-                                                                                   task_id))
-            await con.commit()
-
-    #
-    # set environment resolver
-    async def set_task_environment_resolver_arguments(self, task_id: int, env_res: Optional[EnvironmentResolverArguments]):
-        async with self.data_access.data_connection() as con:
-            con.row_factory = aiosqlite.Row
-            await con.execute('UPDATE tasks SET "environment_resolver_data" = ? WHERE "id" = ?',
-                              (await env_res.serialize_async() if env_res is not None else None,
-                               task_id))
-            await con.commit()
-
-    #
-    # node stuff
-    async def set_node_name(self, node_id: int, node_name: str) -> str:
-        """
-        rename node. node_name may undergo validation and change. final node name that was set is returned
-        :param node_id: node id
-        :param node_name: proposed node name
-        :return: actual node name set
-        """
-        async with self.data_access.data_connection() as con:
-            await con.execute('UPDATE "nodes" SET "name" = ? WHERE "id" = ?', (node_name, node_id))
-            if node_id in self.__node_objects:
-                self.__node_objects[node_id].set_name(node_name)
-            await con.commit()
-        self.ui_state_access.bump_graph_update_id()
-        return node_name
-
-    #
-    # reset node's stored state
-    async def wipe_node_state(self, node_id):
-        async with self.data_access.data_connection() as con:
-            await con.execute('UPDATE "nodes" SET node_object = NULL WHERE "id" = ?', (node_id,))
-            if node_id in self.__node_objects:
-                # TODO: this below may not be safe (at least not proven to be safe yet). check
-                del self.__node_objects[node_id]  # it's here to "protect" operation within db transaction. TODO: but a proper __node_object lock should be in place instead
-            await con.commit()
-        self.ui_state_access.bump_graph_update_id()  # not sure if needed - even the number of inputs/outputs is not part of the graph description
-        self.wake()
-
-    #
-    # copy nodes
-    async def duplicate_nodes(self, node_ids: Iterable[int]) -> Dict[int, int]:
-        """
-        copies given nodes, including connections between given nodes,
-        and returns mapping from given node_ids to respective new copies
-
-        :param node_ids:
-        :return:
-        """
-        old_to_new = {}
-        for nid in node_ids:
-            async with self.node_object_by_id_for_reading(nid) as node_obj:
-                node_type, node_name = await self.get_node_type_and_name_by_id(nid)
-                new_id = await self.add_node(node_type, f'{node_name} copy')
-                async with self.node_object_by_id_for_writing(new_id) as new_node_obj:
-                    node_obj.copy_ui_to(new_node_obj)
-                old_to_new[nid] = new_id
-
-        # now copy connections
-        async with self.data_access.data_connection() as con:
-            con.row_factory = aiosqlite.Row
-            node_ids_str = f'({",".join(str(x) for x in node_ids)})'
-            async with con.execute(f'SELECT * FROM node_connections WHERE node_id_in IN {node_ids_str} AND node_id_out IN {node_ids_str}') as cur:
-                all_cons = await cur.fetchall()
-            for nodecon in all_cons:
-                assert nodecon['node_id_in'] in old_to_new
-                assert nodecon['node_id_out'] in old_to_new
-                await self.add_node_connection(old_to_new[nodecon['node_id_out']], nodecon['out_name'], old_to_new[nodecon['node_id_in']], nodecon['in_name'])
-        return old_to_new
-        # TODO: NotImplementedError("recheck and needs testing")
-
-    #
-    #
-    # node reports its interface was changed. not sure why this exists
-    async def node_reports_changes_needs_saving(self, node_id):
-        assert node_id in self.__node_objects, 'this may be caused by race condition with node deletion'
-        await self.save_node_to_database(node_id)
-
-    #
-    # save node to database.
-    async def save_node_to_database(self, node_id):
-        """
-        save node with given node_id to database
-        if node is not in our list of nodes - we assume it was not touched, not changed, so no saving needed
-
-        :param node_id:
-        :return:
-        """
-        # TODO: introduce __node_objects lock? or otherwise secure access
-        #  why? this happens on ui_update, which can happen cuz of request from viewer.
-        #  while node processing happens in a different thread, so this CAN happen at the same time with this
-        #  AND THIS IS BAD! (potentially) if a node has changing internal state - this can save some inconsistent snapshot of node state!
-        #  this works now only cuz scheduler_ui_protocol does the locking for param settings
-        node_object = self.__node_objects[node_id]
-        if node_object is None:
-            self.__logger.error('node_object is None while saving node to database')
-            return
-        node_data, state_data = await self.__node_serializers[0].serialize_async(node_object)
-        async with self.data_access.data_connection() as con:
-            await con.execute('UPDATE "nodes" SET node_object = ?, node_object_state = ? WHERE "id" = ?',
-                              (node_data, state_data, node_id))
-            await con.commit()
-
-    #
-    # set worker groups
-    async def set_worker_groups(self, worker_hwid: int, groups: List[str]):
-        groups = set(groups)
-        async with self.data_access.data_connection() as con:
-            await con.execute('BEGIN IMMEDIATE')  # start transaction straight away
-            async with con.execute('SELECT "group" FROM worker_groups WHERE worker_hwid == ?', (worker_hwid,)) as cur:
-                existing_groups = set(x[0] for x in await cur.fetchall())
-            to_delete = existing_groups - groups
-            to_add = groups - existing_groups
-            if len(to_delete):
-                await con.execute(f'DELETE FROM worker_groups WHERE worker_hwid == ? AND "group" IN ({",".join(("?",)*len(to_delete))})', (worker_hwid, *to_delete))
AND "group" IN ({",".join(("?",)*len(to_delete))})', (worker_hwid, *to_delete)) - if len(to_add): - await con.executemany(f'INSERT INTO worker_groups (worker_hwid, "group") VALUES (?, ?)', - ((worker_hwid, x) for x in to_add)) - await con.commit() - - # - # change node connection callback - async def change_node_connection(self, node_connection_id: int, new_out_node_id: Optional[int], new_out_name: Optional[str], - new_in_node_id: Optional[int], new_in_name: Optional[str]): - parts = [] - vals = [] - if new_out_node_id is not None: - parts.append('node_id_out = ?') - vals.append(new_out_node_id) - if new_out_name is not None: - parts.append('out_name = ?') - vals.append(new_out_name) - if new_in_node_id is not None: - parts.append('node_id_in = ?') - vals.append(new_in_node_id) - if new_in_name is not None: - parts.append('in_name = ?') - vals.append(new_in_name) - if len(vals) == 0: # nothing to do - return - async with self.data_access.data_connection() as con: - con.row_factory = aiosqlite.Row - vals.append(node_connection_id) - await con.execute(f'UPDATE node_connections SET {", ".join(parts)} WHERE "id" = ?', vals) - await con.commit() - self.wake() - self.ui_state_access.bump_graph_update_id() - - # - # add node connection callback - async def add_node_connection(self, out_node_id: int, out_name: str, in_node_id: int, in_name: str) -> int: - async with self.data_access.data_connection() as con: - con.row_factory = aiosqlite.Row - async with con.execute('INSERT OR REPLACE INTO node_connections (node_id_out, out_name, node_id_in, in_name) VALUES (?,?,?,?)', # INSERT OR REPLACE here (and not OR ABORT or smth) to ensure lastrowid is set - (out_node_id, out_name, in_node_id, in_name)) as cur: - ret = cur.lastrowid - await con.commit() - self.wake() - self.ui_state_access.bump_graph_update_id() - return ret - - # - # remove node connection callback - async def remove_node_connection(self, node_connection_id: int): - try: - async with self.data_access.data_connection() as con: - con.row_factory = aiosqlite.Row - await con.execute('PRAGMA FOREIGN_KEYS = on') - await con.execute('DELETE FROM node_connections WHERE "id" = ?', (node_connection_id,)) - await con.commit() - self.ui_state_access.bump_graph_update_id() - except aiosqlite.IntegrityError as e: - self.__logger.error(f'could not remove node connection {node_connection_id} because of database integrity check') - raise DataIntegrityError() from None - - # - # add node - async def add_node(self, node_type: str, node_name: str) -> int: - if not self.__node_data_provider.has_node_factory(node_type): # preliminary check - raise RuntimeError(f'unknown node type: "{node_type}"') - - async with self.data_access.data_connection() as con: - con.row_factory = aiosqlite.Row - async with con.execute('INSERT INTO "nodes" ("type", "name") VALUES (?,?)', - (node_type, node_name)) as cur: - ret = cur.lastrowid - await con.commit() - self.ui_state_access.bump_graph_update_id() - return ret - - async def apply_node_settings(self, node_id: int, settings_name: str): - async with self.node_object_by_id_for_writing(node_id) as node_object: - settings = self.__node_data_provider.node_settings(node_object.type_name(), settings_name) - async with self.node_object_by_id_for_writing(node_id) as node: # type: BaseNode - await asyncio.get_event_loop().run_in_executor(None, node.apply_settings, settings) - - async def remove_node(self, node_id: int): - try: - async with self.data_access.data_connection() as con: - con.row_factory = aiosqlite.Row - await 
con.execute('PRAGMA FOREIGN_KEYS = on') - await con.execute('DELETE FROM "nodes" WHERE "id" = ?', (node_id,)) - await con.commit() - self.ui_state_access.bump_graph_update_id() - except aiosqlite.IntegrityError as e: - self.__logger.error(f'could not remove node {node_id} because of database integrity check') - raise DataIntegrityError('There are invocations (maybe achieved ones) referencing this node') from None - - # - # query connections - async def get_node_input_connections(self, node_id: int, input_name: Optional[str] = None): - return await self.get_node_connections(node_id, True, input_name) - - async def get_node_output_connections(self, node_id: int, output_name: Optional[str] = None): - return await self.get_node_connections(node_id, False, output_name) - - async def get_node_connections(self, node_id: int, query_input: bool = True, name: Optional[str] = None): - if query_input: - nodecol = 'node_id_in' - namecol = 'in_name' - else: - nodecol = 'node_id_out' - namecol = 'out_name' - async with self.data_access.data_connection() as con: - con.row_factory = aiosqlite.Row - if name is None: - async with con.execute('SELECT * FROM node_connections WHERE "%s" = ?' % (nodecol,), - (node_id,)) as cur: - return [dict(x) for x in await cur.fetchall()] - else: - async with con.execute('SELECT * FROM node_connections WHERE "%s" = ? AND "%s" = ?' % (nodecol, namecol), - (node_id, name)) as cur: - return [dict(x) for x in await cur.fetchall()] - - # - # spawning new task callback - async def spawn_tasks(self, newtasks: Union[Iterable[TaskSpawn], TaskSpawn], con: Optional[aiosqlite_overlay.ConnectionWithCallbacks] = None) -> Union[Tuple[SpawnStatus, Optional[int]], Tuple[Tuple[SpawnStatus, Optional[int]], ...]]: - """ - - :param newtasks: - :param con: - :return: - """ - - async def _inner_shit() -> Tuple[Tuple[SpawnStatus, Optional[int]], ...]: - result = [] - new_tasks = [] - current_timestamp = int(datetime.utcnow().timestamp()) - assert len(newtasks) > 0, 'expectations failure' - if not con.in_transaction: # IF this is called from multiple async tasks with THE SAME con - this may cause race conditions - await con.execute('BEGIN IMMEDIATE') - for newtask in newtasks: - if newtask.source_invocation_id() is not None: - async with con.execute('SELECT node_id, task_id FROM invocations WHERE "id" = ?', - (newtask.source_invocation_id(),)) as incur: - invocrow = await incur.fetchone() - assert invocrow is not None - node_id: int = invocrow['node_id'] - parent_task_id: int = invocrow['task_id'] - elif newtask.forced_node_task_id() is not None: - node_id, parent_task_id = newtask.forced_node_task_id() - else: - self.__logger.error('ERROR CREATING SPAWN TASK: Malformed source') - result.append((SpawnStatus.FAILED, None)) - continue - - async with con.execute('INSERT INTO tasks ("name", "attributes", "parent_id", "state", "node_id", "node_output_name", "environment_resolver_data") VALUES (?, ?, ?, ?, ?, ?, ?)', - (newtask.name(), await serialize_attributes(newtask._attributes()), parent_task_id, # TODO: run dumps in executor - TaskState.SPAWNED.value if newtask.create_as_spawned() else TaskState.WAITING.value, - node_id, newtask.node_output_name(), - newtask.environment_arguments().serialize() if newtask.environment_arguments() is not None else None)) as newcur: - new_id = newcur.lastrowid - - all_groups = set() - if parent_task_id is not None: # inherit all parent's groups - # check and inherit parent's environment wrapper arguments - if newtask.environment_arguments() is None: - await 
con.execute('UPDATE tasks SET environment_resolver_data = (SELECT environment_resolver_data FROM tasks WHERE "id" == ?) WHERE "id" == ?', - (parent_task_id, new_id)) - - # inc children count happens in db trigger - # inherit groups - async with con.execute('SELECT "group" FROM task_groups WHERE "task_id" = ?', (parent_task_id,)) as gcur: - groups = [x['group'] for x in await gcur.fetchall()] - all_groups.update(groups) - if len(groups) > 0: - con.add_after_commit_callback(self.ui_state_access.scheduler_reports_task_groups_changed, groups) # ui event - await con.executemany('INSERT INTO task_groups ("task_id", "group") VALUES (?, ?)', - zip(itertools.repeat(new_id, len(groups)), groups)) - else: # parent_task_id is None - # in this case we create a default group for the task. - # task should not be left without groups at all - otherwise it will be impossible to find in UI - new_group = '{name}#{id:d}'.format(name=newtask.name(), id=new_id) - all_groups.add(new_group) - await con.execute('INSERT INTO task_groups ("task_id", "group") VALUES (?, ?)', - (new_id, new_group)) - await con.execute('INSERT OR REPLACE INTO task_group_attributes ("group", "ctime") VALUES (?, ?)', - (new_group, current_timestamp)) - if newtask.default_priority() is not None: - await con.execute('UPDATE task_group_attributes SET "priority" = ? WHERE "group" = ?', - (newtask.default_priority(), new_group)) - con.add_after_commit_callback(self.ui_state_access.scheduler_reports_task_groups_changed, (new_group,)) # ui event - # - if newtask.extra_group_names(): - groups = newtask.extra_group_names() - all_groups.update(groups) - await con.executemany('INSERT INTO task_groups ("task_id", "group") VALUES (?, ?)', - zip(itertools.repeat(new_id, len(groups)), groups)) - con.add_after_commit_callback(self.ui_state_access.scheduler_reports_task_groups_changed, groups) # ui event - for group in groups: - async with con.execute('SELECT "group" FROM task_group_attributes WHERE "group" == ?', (group,)) as gcur: - need_create = await gcur.fetchone() is None - if not need_create: - continue - await con.execute('INSERT INTO task_group_attributes ("group", "ctime") VALUES (?, ?)', - (group, current_timestamp)) - # TODO: task_groups.group should be a foreign key to task_group_attributes.group - # but then we need to insert those guys in correct order (first in attributes table, then groups) - # then smth like FOREIGN KEY("group") REFERENCES "task_group_attributes"("group") ON UPDATE CASCADE ON DELETE CASCADE - result.append((SpawnStatus.SUCCEEDED, new_id)) - new_tasks.append(TaskData(new_id, parent_task_id, 0, 0, - TaskState.SPAWNED if newtask.create_as_spawned() else TaskState.WAITING, '', - False, node_id, 'main', newtask.node_output_name(), newtask.name(), 0, 0, None, None, None, None, - all_groups)) - - # callbacks for ui events - con.add_after_commit_callback(self.ui_state_access.scheduler_reports_tasks_added, new_tasks) - return tuple(result) - - return_single = False - if isinstance(newtasks, TaskSpawn): - newtasks = (newtasks,) - return_single = True - if len(newtasks) == 0: - return () - if con is not None: - stuff = await _inner_shit() - else: - async with self.data_access.data_connection() as con: - con.row_factory = aiosqlite.Row - stuff = await _inner_shit() - await con.commit() - self.wake() - self.poke_task_processor() - return stuff[0] if return_single else stuff - - # - async def node_name_to_id(self, name: str) -> List[int]: - """ - get the list of node ids that have specified name - :param name: - :return: - """ - async 
with self.data_access.data_connection() as con:
-            async with con.execute('SELECT "id" FROM "nodes" WHERE "name" = ?', (name,)) as cur:
-                return list(x[0] for x in await cur.fetchall())
-
-    #
-    async def get_invocation_metadata(self, task_id: int) -> Dict[int, List[IncompleteInvocationLogData]]:
-        """
-        get task's log metadata - meaning which nodes it ran on and how
-        :param task_id:
-        :return: dict[node_id -> list[IncompleteInvocationLogData]]
-        """
-        async with self.data_access.data_connection() as con:
-            con.row_factory = aiosqlite.Row
-            logs = {}
-            self.__logger.debug(f'fetching log metadata for {task_id}')
-            async with con.execute('SELECT "id", node_id, runtime, worker_id, state, return_code from "invocations" WHERE "state" != ? AND "task_id" == ?',
-                                   (InvocationState.INVOKING.value, task_id)) as cur:
-                async for entry in cur:
-                    node_id = entry['node_id']
-                    logs.setdefault(node_id, []).append(IncompleteInvocationLogData(
-                        entry['id'],
-                        entry['worker_id'],
-                        entry['runtime'],  # TODO: this should be set to active run time if invocation is running
-                        InvocationState(entry['state']),
-                        entry['return_code']
-                    ))
-        return logs
-
-    async def get_log(self, invocation_id: int) -> Optional[InvocationLogData]:
-        """
-        get the log data for the given invocation id,
-        or None if no invocation with that id exists
-
-        :param invocation_id:
-        :return:
-        """
-        async with self.data_access.data_connection() as con:
-            con.row_factory = aiosqlite.Row
-            self.__logger.debug(f"fetching for {invocation_id}")
-            async with con.execute('SELECT "id", task_id, worker_id, node_id, state, return_code, log_external, runtime, stdout, stderr '
-                                   'FROM "invocations" WHERE "id" = ?',
-                                   (invocation_id,)) as cur:
-                rawentry = await cur.fetchone()  # should be exactly 1 or 0
-            if rawentry is None:
-                return None
-
-            entry: InvocationLogData = InvocationLogData(rawentry['id'],
-                                                         rawentry['worker_id'],
-                                                         rawentry['runtime'],
-                                                         rawentry['task_id'],
-                                                         rawentry['node_id'],
-                                                         InvocationState(rawentry['state']),
-                                                         rawentry['return_code'],
-                                                         rawentry['stdout'] or '',
-                                                         rawentry['stderr'] or '')
-            if entry.invocation_state == InvocationState.IN_PROGRESS:
-                async with con.execute('SELECT last_address FROM workers WHERE "id" = ?', (entry.worker_id,)) as worcur:
-                    workrow = await worcur.fetchone()
-                if workrow is None:
-                    self.__logger.error('Worker not found during log fetch! this is not supposed to happen! Database inconsistent?')
-                else:
-                    try:
-                        with WorkerControlClient.get_worker_control_client(AddressChain(workrow['last_address']), self.message_processor()) as client:  # type: WorkerControlClient
-                            stdout, stderr = await client.get_log(invocation_id)
-                        if not self.__use_external_log:
-                            await con.execute('UPDATE "invocations" SET stdout = ?, stderr = ? WHERE "id" = ?',  # TODO: is this really needed? if it's never really read
-                                              (stdout, stderr, invocation_id))
-                            await con.commit()
-                        # TODO: maybe add else case? save partial log to file?
-                    except ConnectionError:
-                        self.__logger.warning('could not connect to worker to get freshest logs')
-                    else:
-                        entry.stdout = stdout
-                        entry.stderr = stderr
-
-            elif entry.invocation_state == InvocationState.FINISHED and rawentry['log_external'] == 1:
-                logbasedir = self.__external_log_location / 'invocations' / f'{invocation_id}'
-                stdout_path = logbasedir / 'stdout.log'
-                stderr_path = logbasedir / 'stderr.log'
-                try:
-                    if stdout_path.exists():
-                        async with aiofiles.open(stdout_path, 'r') as fstdout:
-                            entry.stdout = await fstdout.read()
-                except IOError:
-                    self.__logger.exception(f'could not read external stdout log for {invocation_id}')
-                try:
-                    if stderr_path.exists():
-                        async with aiofiles.open(stderr_path, 'r') as fstderr:
-                            entry.stderr = await fstderr.read()
-                except IOError:
-                    self.__logger.exception(f'could not read external stderr log for {invocation_id}')
-
-            return entry
-
-    def server_address(self) -> Tuple[str, int]:
-        if self.__legacy_command_server_address is None:
-            raise RuntimeError('cannot get listening address of a non started server')
-        return self.__legacy_command_server_address
-
-    def server_message_address(self, to: AddressChain) -> AddressChain:
-        if self.__message_processor is None:
-            raise RuntimeError('cannot get listening address of a non started server')
-
-        return self.message_processor().listening_address(to)
-
-    def server_message_addresses(self) -> Tuple[AddressChain]:
-        if self.__message_processor is None:
-            raise RuntimeError('cannot get listening address of a non started server')
-
-        return self.message_processor().listening_addresses()
diff --git a/src/lifeblood/scheduler/scheduler_core.py b/src/lifeblood/scheduler/scheduler_core.py
new file mode 100644
index 00000000..7d9387cd
--- /dev/null
+++ b/src/lifeblood/scheduler/scheduler_core.py
@@ -0,0 +1,1863 @@
+import os
+from pathlib import Path
+import time
+from datetime import datetime
+import json
+import itertools
+import asyncio
+import aiosqlite
+import aiofiles
+from aiorwlock import RWLock
+from contextlib import asynccontextmanager
+
+from .. import logging
+from ..nodegraph_holder_base import NodeGraphHolderBase
+from ..attribute_serialization import serialize_attributes, deserialize_attributes
+from ..worker_message_processor_client import WorkerControlClient
+from ..hardware_resources import HardwareResources
+from ..invocationjob import Invocation, InvocationJob, Requirements
+from ..environment_resolver import EnvironmentResolverArguments
+from ..broadcasting import create_broadcaster
+from ..simple_worker_pool import WorkerPool
+from ..nethelpers import get_broadcast_addr_for, all_interfaces
+from ..worker_metadata import WorkerMetadata
+from ..taskspawn import TaskSpawn
+from ..basenode import BaseNode
+from ..exceptions import *
+from ..node_dataprovider_base import NodeDataProvider
+from ..basenode_serialization import NodeSerializerBase, IncompatibleDeserializationMethod, FailedToDeserialize
+from ..enums import WorkerState, WorkerPingState, TaskState, InvocationState, WorkerType, \
+    SchedulerMode, TaskGroupArchivedState, SpawnStatus
+from .. import aiosqlite_overlay
+from ..ui_protocol_data import TaskData, TaskDelta, IncompleteInvocationLogData, InvocationLogData
+
+from ..net_messages.address import DirectAddress, AddressChain
+from ..net_messages.message_processor import MessageProcessorBase
+from ..scheduler_config_provider_base import SchedulerConfigProviderBase
+from ..worker_pool_message_processor import WorkerPoolMessageProcessor
+
+from .data_access import DataAccess
+from .scheduler_component_base import SchedulerComponentBase
+from .pinger import Pinger
+from .task_processor import TaskProcessor
+from .ui_state_accessor import UIStateAccessor
+
+from typing import Optional, Any, Callable, Tuple, List, Iterable, Union, Dict, Set
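+# a minimal wiring sketch (illustrative only - the "My*" names are assumptions, the real
+# bindings live in scheduler.py): SchedulerCore is not tied to concrete protocol
+# implementations, it receives them as factory callables instead:
+#
+#   core = SchedulerCore(
+#       scheduler_config_provider=my_config_provider,
+#       node_data_provider=my_node_data_provider,
+#       node_serializers=[my_serializer],
+#       message_processor_factory=MyMessageProcessor,   # called as factory(core, addresses)
+#       legacy_task_protocol_factory=MyTaskProtocol,    # called as factory(core)
+#       ui_protocol_factory=MyUiProtocol,               # called as factory(core)
+#   )
+#   await core.start()
+#
+# the factory call signatures follow the Callable type hints of __init__ below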
+
+
+class SchedulerCore(NodeGraphHolderBase):
+    def __init__(self, *,
+                 scheduler_config_provider: SchedulerConfigProviderBase,
+                 node_data_provider: NodeDataProvider,
+                 node_serializers: List[NodeSerializerBase],
+                 message_processor_factory: Callable[["SchedulerCore", List[DirectAddress]], MessageProcessorBase],
+                 legacy_task_protocol_factory: Callable[["SchedulerCore"], asyncio.StreamReaderProtocol],
+                 ui_protocol_factory: Callable[["SchedulerCore"], asyncio.StreamReaderProtocol],
+                 ):
+        """
+        TODO: add a docstring
+
+        :param scheduler_config_provider:
+        """
+        self.__node_data_provider: NodeDataProvider = node_data_provider
+        if len(node_serializers) < 1:
+            raise ValueError('at least one serializer must be provided!')
+        self.__node_serializers = list(node_serializers)
+        self.__logger = logging.get_logger('scheduler')
+        self.__logger.info('loading core plugins')
+        self.__node_objects: Dict[int, BaseNode] = {}
+        self.__node_objects_locks: Dict[int, RWLock] = {}
+        self.__node_objects_creation_locks: Dict[int, asyncio.Lock] = {}
+        self.__config_provider: SchedulerConfigProviderBase = scheduler_config_provider
+
+        # this lock will prevent tasks from being reported cancelled and done at the same exact time should that ever happen
+        # this lock is overkill already, but we can make it even more overkill by using set of locks for each invoc id
+        # which would be completely useless now cuz sqlite locks DB as a whole, not even a single table, especially not just parts of table
+        self.__invocation_reporting_lock = asyncio.Lock()
+
+        self.__all_components = None
+        self.__started_event = asyncio.Event()
+
+        self.__db_path = scheduler_config_provider.main_database_location()
+        if not self.__db_path.startswith('file:'):  # if schema is used - we do not modify the db uri in any way
+            self.__db_path = os.path.realpath(os.path.expanduser(self.__db_path))
+        self.__logger.debug(f'starting scheduler with database: {self.__db_path}')
+        self.data_access: DataAccess = DataAccess(
+            config_provider=self.__config_provider,
+        )
+        ##
+
+        self.__use_external_log = self.__config_provider.external_log_location() is not None
+        self.__external_log_location: Optional[Path] = self.__config_provider.external_log_location()
+        if self.__use_external_log:
+            external_log_path = Path(self.__external_log_location)
+            if external_log_path.exists() and external_log_path.is_file():
+                external_log_path.unlink()
+            if not external_log_path.exists():
+                external_log_path.mkdir(parents=True)
+            if not os.access(self.__external_log_location, os.X_OK | os.W_OK):
+                raise RuntimeError('cannot write to external log location provided')
+
+        self.__pinger: Pinger = Pinger(self)
+        self.task_processor: TaskProcessor = TaskProcessor(self)
+        self.ui_state_access: UIStateAccessor = UIStateAccessor(self)
+
+        self.__message_processor_addresses = []
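+        # all listening addresses come from the config provider: the scheduler may accept
+        # messages on several interfaces at once, while the ui and legacy worker protocols
+        # each bind a single (host, port) pair. illustrative values (assumed, not defaults):
+        #   server_message_addresses() -> [('192.168.0.10', 1384)]
+        #   server_ui_address()        -> ('0.0.0.0', 1386)
+        #   legacy_server_address()    -> ('127.0.0.1', 1383)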
+        self.__ui_address = None
+        self.__legacy_command_server_address = None
+
+        legacy_server_ip, legacy_server_port = self.__config_provider.legacy_server_address()  # TODO: this CAN be None
+        for message_server_ip, message_server_port in self.__config_provider.server_message_addresses():
+            self.__message_processor_addresses.append(DirectAddress.from_host_port(message_server_ip, message_server_port))
+        self.__legacy_command_server_address = (legacy_server_ip, legacy_server_port)
+
+        self.__ui_address = self.__config_provider.server_ui_address()
+
+        self.__stop_event = asyncio.Event()
+        self.__server_closing_task = None
+        self.__cleanup_tasks = None
+
+        self.__legacy_command_server = None
+        self.__message_processor: Optional[MessageProcessorBase] = None
+        self.__ui_server = None
+        self.__ui_server_coro_args = {'protocol_factory': lambda: ui_protocol_factory(self), 'host': self.__ui_address[0], 'port': self.__ui_address[1], 'backlog': 16}
+        self.__legacy_server_coro_args = {'protocol_factory': lambda: legacy_task_protocol_factory(self), 'host': legacy_server_ip, 'port': legacy_server_port, 'backlog': 16}
+        self.__message_processor_factory = message_processor_factory
+
+        self.__do_broadcasting = self.__config_provider.broadcast_interval() is not None
+        self.__broadcasting_interval = self.__config_provider.broadcast_interval() or 0
+        self.__broadcasting_servers = []
+
+        self.__worker_pool = None
+        self.__worker_pool_helpers_minimal_idle_to_ensure = self.__config_provider.scheduler_helpers_minimal()
+
+        self.__event_loop = asyncio.get_running_loop()
+        assert self.__event_loop is not None, 'Scheduler MUST be created within a working event loop, in the main thread'
+
+    @property
+    def config_provider(self) -> SchedulerConfigProviderBase:
+        return self.__config_provider
+
+    def get_event_loop(self):
+        return self.__event_loop
+
+    def node_data_provider(self) -> NodeDataProvider:
+        return self.__node_data_provider
+
+    def db_uid(self) -> int:
+        """
+        unique id that was generated on creation for the DB currently in use
+
+        :return: 64 bit unsigned int
+        """
+        return self.data_access.db_uid
+
+    def wake(self):
+        """
+        scheduler may go into DORMANT mode when it thinks there's nothing to do.
+        in that case a wake() call exits DORMANT mode immediately.
+        if wake is not called on some change - the scheduler will eventually recheck its state and exit DORMANT mode on its own, it will just waste some time first.
+        if currently not in DORMANT mode - nothing will happen
+
+        :return:
+        """
+        self.task_processor.wake()
+        self.__pinger.wake()
+
+    def poke_task_processor(self):
+        """
+        make the task processor stop waiting and immediately perform another processing iteration.
+        this is not connected to wake, __sleep and DORMANT mode,
+        this is just a one-time kick.
+        good to perform when a task was changed somewhere async, outside of task_processor
+
+        :return:
+        """
+        self.task_processor.poke()
+
+    def _component_changed_mode(self, component: SchedulerComponentBase, mode: SchedulerMode):
+        if component == self.task_processor and mode == SchedulerMode.DORMANT:
+            self.__logger.info('task processor switched to DORMANT mode')
+            self.__pinger.sleep()
+
+    def message_processor(self) -> MessageProcessorBase:
+        """
+        get scheduler's main message processor
+        """
+        return self.__message_processor
+
+    async def get_node_type_and_name_by_id(self, node_id: int) -> Tuple[str, str]:
+        async with self.data_access.data_connection() as con:
+            con.row_factory = aiosqlite.Row
+            async with con.execute('SELECT "type", "name" FROM "nodes" WHERE 
"id" = ?', (node_id,)) as nodecur: + node_row = await nodecur.fetchone() + if node_row is None: + raise RuntimeError(f'node with given id {node_id} does not exist') + return node_row['type'], node_row['name'] + + @asynccontextmanager + async def node_object_by_id_for_reading(self, node_id: int): + async with self.get_node_lock_by_id(node_id).reader_lock: + yield await self._get_node_object_by_id(node_id) + + @asynccontextmanager + async def node_object_by_id_for_writing(self, node_id: int): + async with self.get_node_lock_by_id(node_id).writer_lock: + yield await self._get_node_object_by_id(node_id) + + async def _get_node_object_by_id(self, node_id: int) -> BaseNode: + """ + When accessing node this way - be aware that you SHOULD ensure your access happens within a lock + returned by get_node_lock_by_id. + If you don't want to deal with that - use scheduler's wrappers to access nodes in a safe way + (lol, wrappers are not implemented) + + :param node_id: + :return: + """ + if node_id in self.__node_objects: + return self.__node_objects[node_id] + async with self.__get_node_creation_lock_by_id(node_id): + # in case by the time we got here and the node was already created + if node_id in self.__node_objects: + return self.__node_objects[node_id] + # if no - need to create one after all + async with self.data_access.data_connection() as con: + con.row_factory = aiosqlite.Row + async with con.execute('SELECT * FROM "nodes" WHERE "id" = ?', (node_id,)) as nodecur: + node_row = await nodecur.fetchone() + if node_row is None: + raise RuntimeError('node id is invalid') + + node_type = node_row['type'] + if not self.__node_data_provider.has_node_factory(node_type): + raise RuntimeError('node type is unsupported') + + if node_row['node_object'] is not None: + try: + for serializer in self.__node_serializers: + try: + node_object = await serializer.deserialize_async(self.__node_data_provider, node_row['node_object'], node_row['node_object_state']) + break + except IncompatibleDeserializationMethod as e: + self.__logger.warning(f'deserialization method failed with {e} ({serializer})') + continue + else: + raise FailedToDeserialize(f'node entry {node_id} has unknown serialization method') + node_object.set_parent(self, node_id) + self.__node_objects[node_id] = node_object + return self.__node_objects[node_id] + except FailedToDeserialize: + if self.__config_provider.ignore_node_deserialization_failures(): + pass # ignore errors, recreate node + else: + raise + + newnode = self.__node_data_provider.node_factory(node_type)(node_row['name']) + newnode.set_parent(self, node_id) + + self.__node_objects[node_id] = newnode + node_data, state_data = await self.__node_serializers[0].serialize_async(newnode) + await con.execute('UPDATE "nodes" SET node_object = ?, node_object_state = ? WHERE "id" = ?', + (node_data, state_data, node_id)) + await con.commit() + + return newnode + + def get_node_lock_by_id(self, node_id: int) -> RWLock: + """ + All read/write operations for a node should be locked within a per node rw lock that scheduler maintains. + Usually you do NOT have to be concerned with this. + But in cases you get the node object with functions like get_node_object_by_id. + it is your responsibility to ensure data is locked when accessed. + Lock is not part of the node itself. + + :param node_id: node id to get lock to + :return: rw lock for the node + """ + if node_id not in self.__node_objects_locks: + self.__node_objects_locks[node_id] = RWLock(fast=True) # read about fast on github. 
the point is: if we have awaits inside a critical section - it's safe to use fast
+        return self.__node_objects_locks[node_id]
+
+    def __get_node_creation_lock_by_id(self, node_id: int) -> asyncio.Lock:
+        """
+        This lock is for node creation/deserialization sections ONLY
+        """
+        if node_id not in self.__node_objects_creation_locks:
+            self.__node_objects_creation_locks[node_id] = asyncio.Lock()
+        return self.__node_objects_creation_locks[node_id]
+
+    async def get_task_attributes(self, task_id: int) -> Tuple[Dict[str, Any], Optional[EnvironmentResolverArguments]]:
+        """
+        get task's attributes and its environment resolver's attributes
+
+        :param task_id:
+        :return:
+        """
+        async with self.data_access.data_connection() as con:
+            con.row_factory = aiosqlite.Row
+            async with con.execute('SELECT attributes, environment_resolver_data FROM tasks WHERE "id" = ?', (task_id,)) as cur:
+                res = await cur.fetchone()
+                if res is None:
+                    raise RuntimeError('task with specified id was not found')
+                env_res_args = None
+                if res['environment_resolver_data'] is not None:
+                    env_res_args = await EnvironmentResolverArguments.deserialize_async(res['environment_resolver_data'])
+                return await deserialize_attributes(res['attributes']), env_res_args
+
+    async def get_task_fields(self, task_id: int) -> Dict[str, Any]:
+        """
+        returns information about the given task, excluding heavy fields like attributes or env resolver data -
+        for those use get_task_attributes
+
+        :param task_id:
+        :return:
+        """
+        async with self.data_access.data_connection() as con:
+            con.row_factory = aiosqlite.Row
+            async with con.execute('SELECT "id", "name", parent_id, children_count, active_children_count, "state", paused, '
+                                   '"node_id", split_level, priority, "dead" FROM tasks WHERE "id" == ?', (task_id,)) as cur:
+                res = await cur.fetchone()
+                if res is None:
+                    raise RuntimeError('task with specified id was not found')
+                return dict(res)
+
+    async def task_name_to_id(self, name: str) -> List[int]:
+        """
+        get the list of task ids that have the specified name
+
+        :param name:
+        :return:
+        """
+        async with self.data_access.data_connection() as con:
+            async with con.execute('SELECT "id" FROM "tasks" WHERE "name" = ?', (name,)) as cur:
+                return list(x[0] for x in await cur.fetchall())
+
+    async def get_task_invocation_serialized(self, task_id: int) -> Optional[bytes]:
+        async with self.data_access.data_connection() as con:
+            con.row_factory = aiosqlite.Row
+            async with con.execute('SELECT work_data FROM tasks WHERE "id" = ?', (task_id,)) as cur:
+                res = await cur.fetchone()
+                if res is None:
+                    raise RuntimeError('task with specified id was not found')
+                return res[0]
+
+    async def worker_id_from_address(self, addr: str) -> Optional[int]:
+        async with self.data_access.data_connection() as con:
+            async with con.execute('SELECT "id" FROM workers WHERE last_address = ?', (addr,)) as cur:
+                ret = await cur.fetchone()
+        if ret is None:
+            return None
+        return ret[0]
+
+    async def get_worker_state(self, wid: int, con: Optional[aiosqlite.Connection] = None) -> WorkerState:
+        if con is None:
+            async with self.data_access.data_connection() as con:
+                async with con.execute('SELECT "state" FROM "workers" WHERE "id" = ?', (wid,)) as cur:
+                    res = await cur.fetchone()
+        else:
+            async with con.execute('SELECT "state" FROM "workers" WHERE "id" = ?', (wid,)) as cur:
+                res = await cur.fetchone()
+        if res is None:
+            raise ValueError(f'worker with given wid={wid} was not found')
+        return WorkerState(res[0])
+
+    async def get_task_invocation(self, task_id: int):
+        data = 
await self.get_task_invocation_serialized(task_id) + if data is None: + return None + return await InvocationJob.deserialize_async(data) + + async def get_invocation_worker(self, invocation_id: int) -> Optional[AddressChain]: + async with self.data_access.data_connection() as con: + async with con.execute( + 'SELECT workers.last_address ' + 'FROM invocations LEFT JOIN workers ' + 'ON invocations.worker_id == workers.id ' + 'WHERE invocations.id == ?', (invocation_id,)) as cur: + res = await cur.fetchone() + if res is None: + return None + return AddressChain(res[0]) + + async def get_invocation_state(self, invocation_id: int) -> Optional[InvocationState]: + async with self.data_access.data_connection() as con: + async with con.execute( + 'SELECT state FROM invocations WHERE id == ?', (invocation_id,)) as cur: + res = await cur.fetchone() + if res is None: + return None + return InvocationState(res[0]) + + def stop(self): + async def _server_closer(): + # for server in self.__broadcasting_servers: + # server.wait_closed() + # ensure all components stop first + await self.__pinger.wait_till_stops() + await self.task_processor.wait_till_stops() + await self.__worker_pool.wait_till_stops() + await self.__ui_server.wait_closed() + if self.__legacy_command_server is not None: + self.__legacy_command_server.close() + await self.__legacy_command_server.wait_closed() + self.__logger.debug('stopping message processor...') + self.__message_processor.stop() + await self.__message_processor.wait_till_stops() + self.__logger.debug('message processor stopped') + + async def _db_cache_writeback(): + await self.__pinger.wait_till_stops() + await self.task_processor.wait_till_stops() + await self.__server_closing_task + await self._save_all_cached_nodes_to_db() + await self.data_access.write_back_cache() + + if self.__stop_event.is_set(): + self.__logger.error('cannot double stop!') + return # no double stopping + if not self.__started_event.is_set(): + self.__logger.error('cannot stop what is not started!') + return + self.__logger.info('STOPPING SCHEDULER') + # for server in self.__broadcasting_servers: + # server.close() + self.__stop_event.set() # this will stop things including task_processor + self.__pinger.stop() + self.task_processor.stop() + self.ui_state_access.stop() + self.__worker_pool.stop() + self.__server_closing_task = asyncio.create_task(_server_closer()) # we ensure worker pool stops BEFORE server, so workers have chance to report back + self.__cleanup_tasks = [asyncio.create_task(_db_cache_writeback())] + if self.__ui_server is not None: + self.__ui_server.close() + + def _stop_event_wait(self): # TODO: this is currently being used by ui proto to stop long connections, but not used in task proto, but what if it'll also get long living connections? + return self.__stop_event.wait() + + async def start(self): + # prepare + async with self.data_access.data_connection() as con: + # we play it the safest for now: + # all workers set to UNKNOWN state, all active invocations are reset, all tasks in the middle of processing are reset to closest waiting state + con.row_factory = aiosqlite.Row + await con.execute('UPDATE "tasks" SET "state" = ? WHERE "state" IN (?, ?)', + (TaskState.READY.value, TaskState.IN_PROGRESS.value, TaskState.INVOKING.value)) + await con.execute('UPDATE "tasks" SET "state" = ? WHERE "state" = ?', + (TaskState.WAITING.value, TaskState.GENERATING.value)) + await con.execute('UPDATE "tasks" SET "state" = ? 
WHERE "state" = ?', + (TaskState.WAITING.value, TaskState.WAITING_BLOCKED.value)) + await con.execute('UPDATE "tasks" SET "state" = ? WHERE "state" = ?', + (TaskState.POST_WAITING.value, TaskState.POST_GENERATING.value)) + await con.execute('UPDATE "tasks" SET "state" = ? WHERE "state" = ?', + (TaskState.POST_WAITING.value, TaskState.POST_WAITING_BLOCKED.value)) + await con.execute('UPDATE "invocations" SET "state" = ? WHERE "state" = ?', (InvocationState.FINISHED.value, InvocationState.IN_PROGRESS.value)) + # for now invoking invocation are invalidated by deletion (here and in task_processor) + await con.execute('DELETE FROM invocations WHERE "state" = ?', (InvocationState.INVOKING.value,)) + await con.execute('UPDATE workers SET "ping_state" = ?', (WorkerPingState.UNKNOWN.value,)) + await con.execute('UPDATE "workers" SET "state" = ?', (WorkerState.UNKNOWN.value,)) + await con.commit() + + # update volatile mem cache: + async with con.execute('SELECT "id", last_seen, last_checked, ping_state FROM workers') as worcur: + async for row in worcur: + self.data_access.mem_cache_workers_state[row['id']] = {k: row[k] for k in dict(row)} + + # start + loop = asyncio.get_event_loop() + self.__legacy_command_server = await loop.create_server(**self.__legacy_server_coro_args) + self.__ui_server = await loop.create_server(**self.__ui_server_coro_args) + # start message processor + + self.__message_processor = self.__message_processor_factory(self, self.__message_processor_addresses) + await self.__message_processor.start() + worker_pool_message_proxy_address = (self.__message_processor_addresses[0].split(':', 1)[0], None) # use same ip as scheduler's message processor, but default port + self.__worker_pool = WorkerPool(WorkerType.SCHEDULER_HELPER, + minimal_idle_to_ensure=self.__worker_pool_helpers_minimal_idle_to_ensure, + scheduler_address=self.server_message_address(DirectAddress(worker_pool_message_proxy_address[0])), + message_proxy_address=worker_pool_message_proxy_address, + message_processor_factory=WorkerPoolMessageProcessor, + ) + await self.__worker_pool.start() + # + # broadcasting + if self.__do_broadcasting: + # need to start a broadcaster for each interface from union of message and ui addresses + for iface_addr in all_interfaces()[1:]: # skipping first, as first is localhost + broadcast_address = get_broadcast_addr_for(iface_addr) + if broadcast_address is None: # broadcast not supported + continue + broadcast_data = {} + if direct_address := {x.split(':', 1)[0]: x for x in self.__message_processor_addresses}.get(iface_addr): + broadcast_data['message_address'] = str(direct_address) + if iface_addr == self.__ui_address[0] or self.__ui_address[0] == '0.0.0.0': + broadcast_data['ui'] = ':'.join(str(x) for x in (iface_addr, self.__ui_address[1])) + if iface_addr == self.__legacy_command_server_address[0] or self.__legacy_command_server_address[0] == '0.0.0.0': + broadcast_data['worker'] = ':'.join(str(x) for x in (iface_addr, self.__legacy_command_server_address[1])) + self.__broadcasting_servers.append( + ( + broadcast_address, + await create_broadcaster( + 'lifeblood_scheduler', + json.dumps(broadcast_data), + ip=broadcast_address, + broadcast_interval=self.__broadcasting_interval + ) + ) + ) + + await self.task_processor.start() + await self.__pinger.start() + await self.ui_state_access.start() + # run + self.__all_components = \ + asyncio.gather(self.task_processor.wait_till_stops(), + self.__pinger.wait_till_stops(), + self.ui_state_access.wait_till_stops(), + 
self.__legacy_command_server.wait_closed(),  # TODO: much of what is being waited on here is unnecessary
+                           self.__ui_server.wait_closed(),
+                           self.__worker_pool.wait_till_stops())
+
+        self.__started_event.set()
+        # print information
+        self.__logger.info('scheduler started')
+        self.__logger.info(
+            'scheduler listening on:\n'
+            '    message processors:\n'
+            + '\n'.join((f'        {addr}' for addr in self.__message_processor_addresses)) +
+            '\n'
+            '    ui servers:\n'
+            f'        {":".join(str(x) for x in self.__ui_address)}\n'
+            '    legacy command servers:\n'
+            f'        {":".join(str(x) for x in self.__legacy_command_server_address)}'
+        )
+        self.__logger.info(
+            'broadcasting enabled for:\n'
+            + '\n'.join((f'    {info[0]}' for info in self.__broadcasting_servers))
+        )
+
+    async def wait_till_starts(self):
+        return await self.__started_event.wait()
+
+    async def wait_till_stops(self):
+        await self.__started_event.wait()
+        assert self.__all_components is not None
+        await self.__all_components
+        await self.__server_closing_task
+        for task in self.__cleanup_tasks:
+            await task
+
+    async def _save_all_cached_nodes_to_db(self):
+        self.__logger.info('saving nodes to db')
+        for node_id in self.__node_objects:
+            await self.save_node_to_database(node_id)
+            self.__logger.debug(f'node {node_id} saved to db')
+
+    def is_started(self):
+        return self.__started_event.is_set()
+
+    def is_stopping(self) -> bool:
+        """
+        True if stopped or in process of stopping
+        """
+        return self.__stop_event.is_set()
+
+    #
+    # helper functions
+    #
+
+    async def reset_invocations_for_worker(self, worker_id: int, con: aiosqlite_overlay.ConnectionWithCallbacks, also_update_resources=True) -> bool:
+        """
+
+        :param worker_id:
+        :param con:
+        :param also_update_resources:
+        :return: need commit?
+        """
+        async with con.execute('SELECT * FROM invocations WHERE "worker_id" = ? AND "state" == ?',
+                               (worker_id, InvocationState.IN_PROGRESS.value)) as incur:
+            all_invoc_rows = await incur.fetchall()  # we don't really want to update the db while reading it
+        need_commit = False
+        for invoc_row in all_invoc_rows:  # mark all (probably just a single one) invocations
+            need_commit = True
+            self.__logger.debug("fixing dangling invocation %d" % (invoc_row['id'],))
+            await con.execute('UPDATE invocations SET "state" = ? WHERE "id" = ?',
+                              (InvocationState.FINISHED.value, invoc_row['id']))
+            await con.execute('UPDATE tasks SET "state" = ? WHERE "id" = ?',
+                              (TaskState.READY.value, invoc_row['task_id']))
+            con.add_after_commit_callback(self.ui_state_access.scheduler_reports_task_updated, invoc_row['task_id'])  # ui event
+        if also_update_resources:
+            need_commit = need_commit or await self._update_worker_resouce_usage(worker_id, connection=con)
+        return need_commit
+
+    #
+    # invocation consistency checker
+    async def invocation_consistency_checker(self):
+        """
+        both the scheduler and a worker might crash at any time, so we need to check that
+        invocation states stay consistent (e.g. a worker may crash while working on a task)
+        :return:
+        """
+        pass
+
+    #
+    # callbacks
+
+    #
+    # worker reports done task
+    async def task_done_reported(self, task: Invocation, stdout: str, stderr: str):
+        """
+        scheduler comm protocols should call this when a task is done
+        TODO: this is almost the same code as for task_cancel_reported, maybe unify?
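+
+        note: a very fast invocation may be reported done before the submitting coroutine
+        has finished switching it from INVOKING to IN_PROGRESS; in that case the inner call
+        raises NeedToRetryLater and we retry, as the loop below does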
+        """
+        for attempt in range(120):  # TODO: this should be configurable
+            # if invocation is super fast - this may happen even before submission is completed,
+            # so we might need to wait a bit
+            try:
+                return await self.__task_done_reported_inner(task, stdout, stderr)
+            except NeedToRetryLater:
+                self.__logger.debug('attempt %d to report invocation %d done notified it needs to wait', attempt, task.invocation_id())
+                await asyncio.sleep(0.5)  # TODO: this should be configurable
+                continue
+        else:
+            self.__logger.error(f'out of attempts trying to report done invocation {task.invocation_id()}, probably something is not right with the state of the database')
+
+    async def __task_done_reported_inner(self, task: Invocation, stdout: str, stderr: str):
+        """
+
+        """
+        async with self.__invocation_reporting_lock, \
+                self.data_access.data_connection() as con:
+            con.row_factory = aiosqlite.Row
+            self.__logger.debug('task finished reported %s code %s', repr(task), task.exit_code())
+            # sanity check
+            async with con.execute('SELECT "state" FROM invocations WHERE "id" = ?', (task.invocation_id(),)) as cur:
+                invoc = await cur.fetchone()
+            if invoc is None:
+                self.__logger.error('reported task has non-existent invocation id %d' % task.invocation_id())
+                return
+            if invoc['state'] == InvocationState.INVOKING.value:  # means _submitter has not yet finished, we should wait
+                raise NeedToRetryLater()
+            elif invoc['state'] != InvocationState.IN_PROGRESS.value:
+                self.__logger.warning(f'reported task for a finished invocation. assuming that worker failed to cancel the task previously and ignoring invocation results. (state={invoc["state"]})')
+                return
+            await con.execute('UPDATE invocations SET "state" = ?, "return_code" = ?, "runtime" = ? WHERE "id" = ?',
+                              (InvocationState.FINISHED.value, task.exit_code(), task.running_time(), task.invocation_id()))
+            async with con.execute('SELECT * FROM invocations WHERE "id" = ?', (task.invocation_id(),)) as incur:
+                invocation = await incur.fetchone()
+            assert invocation is not None
+
+            await con.execute('UPDATE workers SET "state" = ? WHERE "id" = ?',
+                              (WorkerState.IDLE.value, invocation['worker_id']))
+            await self._update_worker_resouce_usage(invocation['worker_id'], connection=con)  # remove resource usage info
+            tasks_to_wait = []
+            if not self.__use_external_log:
+                await con.execute('UPDATE invocations SET "stdout" = ?, "stderr" = ? WHERE "id" = ?',
+                                  (stdout, stderr, task.invocation_id()))
+            else:
+                await con.execute('UPDATE invocations SET "log_external" = 1 WHERE "id" = ?',
+                                  (task.invocation_id(),))
+                tasks_to_wait.append(asyncio.create_task(self._save_external_logs(task.invocation_id(), stdout, stderr)))
+
+            self.data_access.clear_invocation_progress(task.invocation_id())
+
+            ui_task_delta = TaskDelta(invocation['task_id'])  # for ui event
+            if task.finished_needs_retry():  # max retry count will be checked by task processor
+                await con.execute('UPDATE tasks SET "state" = ?, "work_data_invocation_attempt" = "work_data_invocation_attempt" + 1 WHERE "id" = ?',
+                                  (TaskState.READY.value, invocation['task_id']))
+                ui_task_delta.state = TaskState.READY  # for ui event
+            elif task.finished_with_error():
+                state_details = json.dumps({'message': f'see invocation #{invocation["id"]} log for details',
+                                            'happened_at': TaskState.IN_PROGRESS.value,
+                                            'type': 'invocation'})
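+                # state_details is a small json blob shown by the UI to explain the error;
+                # 'happened_at' records the state the task was in when the error occurred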
WHERE "id" = ?', + (TaskState.ERROR.value, + state_details, + invocation['task_id'])) + ui_task_delta.state = TaskState.ERROR # for ui event + ui_task_delta.state_details = state_details # for ui event + else: + await con.execute('UPDATE tasks SET "state" = ? WHERE "id" = ?', + (TaskState.POST_WAITING.value, invocation['task_id'])) + ui_task_delta.state = TaskState.POST_WAITING # for ui event + + con.add_after_commit_callback(self.ui_state_access.scheduler_reports_task_updated, ui_task_delta) # ui event + await con.commit() + if len(tasks_to_wait) > 0: + await asyncio.wait(tasks_to_wait) + self.wake() + self.poke_task_processor() + + async def _save_external_logs(self, invocation_id, stdout, stderr): + logbasedir = self.__external_log_location / 'invocations' / f'{invocation_id}' + try: + if not logbasedir.exists(): + logbasedir.mkdir(exist_ok=True) + async with aiofiles.open(logbasedir / 'stdout.log', 'w') as fstdout, \ + aiofiles.open(logbasedir / 'stderr.log', 'w') as fstderr: + await asyncio.gather(fstdout.write(stdout), + fstderr.write(stderr)) + except OSError: + self.__logger.exception('error happened saving external logs! Ignoring this error') + + # + # worker reports canceled task + async def task_cancel_reported(self, task: Invocation, stdout: str, stderr: str): + """ + scheduler comm protocols should call this when a task is cancelled + """ + for attempt in range(120): # TODO: this should be configurable + # if invocation is super fast - this may happen even before submission is completed, + # so we might need to wait a bit + try: + return await self.__task_cancel_reported_inner(task, stdout, stderr) + except NeedToRetryLater: + self.__logger.debug('attempt %d to report invocation %d cancelled notified it needs to wait', attempt, task.invocation_id()) + await asyncio.sleep(0.5) # TODO: this should be configurable + continue + else: + self.__logger.error(f'out of attempts trying to report cancel invocation {task.invocation_id()}, probably something is not right with the state of the database') + + async def __task_cancel_reported_inner(self, task: Invocation, stdout: str, stderr: str): + async with self.__invocation_reporting_lock, \ + self.data_access.data_connection() as con: + con.row_factory = aiosqlite.Row + self.__logger.debug('task cancelled reported %s', repr(task)) + # sanity check + async with con.execute('SELECT "state" FROM invocations WHERE "id" = ?', (task.invocation_id(),)) as cur: + invoc = await cur.fetchone() + if invoc is None: + self.__logger.error('reported task has non existing invocation id %d' % task.invocation_id()) + return + if invoc['state'] == InvocationState.INVOKING.value: # means _submitter has not yet finished, we should wait + raise NeedToRetryLater() + elif invoc['state'] != InvocationState.IN_PROGRESS.value: + self.__logger.warning(f'reported task for a finished invocation. assuming that worker failed to cancel task previously and ignoring invocation results. (state={invoc["state"]})') + return + await con.execute('UPDATE invocations SET "state" = ?, "runtime" = ? WHERE "id" = ?', + (InvocationState.FINISHED.value, task.running_time(), task.invocation_id())) + async with con.execute('SELECT * FROM invocations WHERE "id" = ?', (task.invocation_id(),)) as incur: + invocation = await incur.fetchone() + assert invocation is not None + + self.data_access.clear_invocation_progress(task.invocation_id()) + + await con.execute('UPDATE workers SET "state" = ? 
WHERE "id" = ?', + (WorkerState.IDLE.value, invocation['worker_id'])) + await self._update_worker_resouce_usage(invocation['worker_id'], connection=con) # remove resource usage info + tasks_to_wait = [] + if not self.__use_external_log: + await con.execute('UPDATE invocations SET "stdout" = ?, "stderr" = ? WHERE "id" = ?', + (stdout, stderr, task.invocation_id())) + else: + await con.execute('UPDATE invocations SET "log_external" = 1, "stdout" = null, "stderr" = null WHERE "id" = ?', + (task.invocation_id(),)) + tasks_to_wait.append(asyncio.create_task(self._save_external_logs(task.invocation_id(), stdout, stderr))) + await con.execute('UPDATE tasks SET "state" = ? WHERE "id" = ?', + (TaskState.READY.value, invocation['task_id'])) + con.add_after_commit_callback(self.ui_state_access.scheduler_reports_task_updated, TaskDelta(invocation['task_id'], state=TaskState.READY)) # ui event + await con.commit() + if len(tasks_to_wait) > 0: + await asyncio.wait(tasks_to_wait) + self.__logger.debug(f'cancelling task done {repr(task)}') + self.wake() + self.poke_task_processor() + + # + # add new worker to db + async def add_worker( + # TODO: WorkerResources (de)serialization + # TODO: Worker actually passing new WorkerResources on hello + self, addr: str, worker_type: WorkerType, worker_resources: HardwareResources, # TODO: all resource should also go here + *, + assume_active: bool = True, + worker_metadata: WorkerMetadata): + """ + this is called by network protocol handler when worker reports being up to the scheduler + """ + self.__logger.debug(f'worker reported added: {addr}') + async with self.data_access.data_connection() as con: + con.row_factory = aiosqlite.Row + await con.execute('BEGIN IMMEDIATE') # important to have locked DB during all this state change + # logic for now: + # - search for same last_address, same hwid + # - if no - search for first entry (OFF or UNKNOWN) with same hwid, ignore address + # - in this case also delete addr from DB if exists + async with con.execute('SELECT "id", state FROM "workers" WHERE "last_address" == ? AND hwid == ?', (addr, worker_resources.hwid)) as worcur: + worker_row = await worcur.fetchone() + if worker_row is None: + # first ensure that there is no entry with the same address + await con.execute('UPDATE "workers" SET "last_address" = ? WHERE "last_address" == ?', (None, addr)) + async with con.execute('SELECT "id", state FROM "workers" WHERE hwid == ? AND ' + '(state == ? OR state == ?)', (worker_resources.hwid, + WorkerState.OFF.value, WorkerState.UNKNOWN.value)) as worcur: + worker_row = await worcur.fetchone() + if assume_active: + ping_state = WorkerPingState.WORKING.value + state = WorkerState.IDLE.value + else: + ping_state = WorkerPingState.OFF.value + state = WorkerState.OFF.value + + tstamp = int(time.time()) + if worker_row is not None: + if worker_row['state'] == WorkerState.INVOKING.value: # so we are in the middle of sumbission + state = WorkerState.INVOKING.value # then we preserve INVOKING state + await self.reset_invocations_for_worker(worker_row['id'], con=con, also_update_resources=False) # we update later + await con.execute('UPDATE "workers" SET ' + 'hwid=?, ' + 'last_seen=?, ping_state=?, state=?, worker_type=?, ' + 'last_address=? 
' + 'WHERE "id"=?', + (worker_resources.hwid, + tstamp, ping_state, state, worker_type.value, + addr, + worker_row['id'])) + # async with con.execute('SELECT "id" FROM "workers" WHERE last_address=?', (addr,)) as worcur: + # worker_id = (await worcur.fetchone())['id'] + worker_id = worker_row['id'] + self.data_access.mem_cache_workers_state[worker_id].update({'last_seen': tstamp, + 'last_checked': tstamp, + 'ping_state': ping_state, + 'worker_id': worker_id}) + # await con.execute('UPDATE tmpdb.tmp_workers_states SET ' + # 'last_seen=?, ping_state=? ' + # 'WHERE worker_id=?', + # (tstamp, ping_state, worker_id)) + else: + async with con.execute('INSERT INTO "workers" ' + '(hwid, ' + 'last_address, last_seen, ping_state, state, worker_type) ' + 'VALUES ' + '(?, ?, ?, ?, ?, ?)', + (worker_resources.hwid, addr, tstamp, ping_state, state, worker_type.value)) as insworcur: + worker_id = insworcur.lastrowid + self.data_access.mem_cache_workers_state[worker_id] = {'last_seen': tstamp, + 'last_checked': tstamp, + 'ping_state': ping_state, + 'worker_id': worker_id} + # await con.execute('INSERT INTO tmpdb.tmp_workers_states ' + # '(worker_id, last_seen, ping_state) ' + # 'VALUES ' + # '(?, ?, ?)', + # (worker_id, tstamp, ping_state)) + + resource_fields: Tuple[str, ...] = tuple(x.name for x in self.__config_provider.hardware_resource_definitions()) + # device_type_names = tuple(x.name for x in self.__config_provider.hardware_device_type_definitions()) + device_type_resource_fields: Dict[str, Tuple[str, ...]] = {x.name: tuple(r.name for r in x.resources) for x in self.__config_provider.hardware_device_type_definitions()} + # in case worker_resources contain dev_types not known to config - they will be ignored + devices_to_register = [] + # checks + for field in resource_fields: + if field not in worker_resources: + self.__logger.warning(f'worker (hwid:{worker_resources.hwid}) does not declare expected resource "{field}", assume value=0') + for res_name, _ in worker_resources.items(): + if res_name not in resource_fields: + self.__logger.warning(f'worker (hwid:{worker_resources.hwid}) declares resource "{res_name}" unknown to the scheduler, ignoring') + for dev_type, dev_name, dev_res in worker_resources.devices(): + if dev_type not in device_type_resource_fields: + self.__logger.warning(f'worker (hwid:{worker_resources.hwid}) declares device type "{dev_type}" unknown to the scheduler, ignoring') + continue + devices_to_register.append((dev_type, dev_name, {res_name: res_val for res_name, res_val in dev_res.items() if res_name in device_type_resource_fields[dev_type]})) + + # TODO: note that below sql breaks if there are no resource_fields (which is an unlikely config, but not impossible) + await con.execute('INSERT INTO resources ' + '(hwid, ' + + ', '.join(f'{field}, total_{field}' for field in resource_fields) + + ') ' + 'VALUES (?' + ', ?' 
* (2 * len(resource_fields)) + ') ' + 'ON CONFLICT(hwid) DO UPDATE SET ' + + ', '.join(f'"{field}"=excluded.{field}, "total_{field}"=excluded.total_{field}' for field in resource_fields) + , + (worker_resources.hwid, + *(x for field in resource_fields for x in ( + (worker_resources[field].value, worker_resources[field].value) if field in worker_resources else (0, 0)) # TODO: do NOT invent defaults here, only set known fields, like in dev code below + )) + ) + + for dev_type, dev_name, dev_res in sorted(devices_to_register, key=lambda x: (x[0], x[1])): # sort by (deva_type, dev_name) to ensure some consistent order + + dev_type_table_name = f'hardware_device_type__{dev_type}' + if dev_res: + await con.execute( + f'INSERT INTO "{dev_type_table_name}" ' + f'(hwid, hw_dev_name, ' + + ', '.join(f'res__{field}' for field, _ in dev_res.items()) + + ') ' + 'VALUES (?, ?' + ', ?' * (len(dev_res)) + ') ' + 'ON CONFLICT(hwid,"hw_dev_name") DO UPDATE SET ' + + ', '.join(f'"res__{field}"=excluded.res__{field}' for field in dev_res) + , + (worker_resources.hwid, dev_name, + *(res_val.value for _, res_val in dev_res.items()) + ) + ) + else: + await con.execute( + f'INSERT INTO "{dev_type_table_name}" ' + f'(hwid, hw_dev_name) ' + + 'VALUES (?, ?) ' + 'ON CONFLICT(hwid,"hw_dev_name") DO NOTHING' + , + (worker_resources.hwid, dev_name) + ) + + await self._update_worker_resouce_usage(worker_id, hwid=worker_resources.hwid, connection=con) # used resources are inited to none + self.data_access.set_worker_metadata(worker_resources.hwid, worker_metadata) + await con.commit() + self.__logger.debug(f'finished worker reported added: {addr}') + self.poke_task_processor() + + # TODO: add decorator that locks method from reentry or smth + # potentially a worker may report done while this works, + # or when scheduler picked worker and about to run this, which will lead to inconsistency warning + # NOTE!: so far it's always called from a STARTED transaction, so there should not be reentry possible + # But that is not enforced right now, easy to make mistake + async def _update_worker_resouce_usage(self, worker_id: int, resources: Optional[Requirements] = None, *, hwid=None, connection: aiosqlite.Connection) -> bool: + """ + updates resource information based on new worker resources usage + as part of ongoing transaction + Note: con SHOULD HAVE STARTED TRANSACTION, otherwise it might be not safe to call this + + :param worker_id: + :param hwid: if hwid of worker_id is already known - provide it here to skip extra db query. but be SURE it's correct! + :param connection: opened db connection. 
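For illustration, assuming a hypothetical configuration that defines just two resources, cpu_count and mem, the dynamically built resources upsert above expands to:

    INSERT INTO resources (hwid, cpu_count, total_cpu_count, mem, total_mem)
    VALUES (?, ?, ?, ?, ?)
    ON CONFLICT(hwid) DO UPDATE SET
        "cpu_count"=excluded.cpu_count, "total_cpu_count"=excluded.total_cpu_count,
        "mem"=excluded.mem, "total_mem"=excluded.total_mem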
+    # TODO: add decorator that locks method from reentry or smth
+    #  potentially a worker may report done while this works,
+    #  or when scheduler picked worker and about to run this, which will lead to inconsistency warning
+    # NOTE!: so far it's always called from a STARTED transaction, so there should not be reentry possible
+    #  But that is not enforced right now, easy to make a mistake
+    async def _update_worker_resouce_usage(self, worker_id: int, resources: Optional[Requirements] = None, *, hwid=None, connection: aiosqlite.Connection) -> bool:
+        """
+        updates resource information based on new worker resources usage
+        as part of ongoing transaction
+        Note: con SHOULD HAVE A STARTED TRANSACTION, otherwise it might not be safe to call this
+
+        :param worker_id:
+        :param hwid: if hwid of worker_id is already known - provide it here to skip extra db query. but be SURE it's correct!
+        :param connection: opened db connection. expected to have Row as row factory
+        :return: whether a commit is needed on connection (i.e. whether a db set operation happened)
+        """
+        assert connection.in_transaction, 'expectation failure'
+
+        resource_fields = tuple(x.name for x in self.__config_provider.hardware_resource_definitions())
+        device_type_names = tuple(x.name for x in self.__config_provider.hardware_device_type_definitions())
+
+        workers_resources = self.data_access.mem_cache_workers_resources
+        if hwid is None:
+            async with connection.execute('SELECT "hwid" FROM "workers" WHERE "id" == ?', (worker_id,)) as worcur:
+                hwid = (await worcur.fetchone())['hwid']
+
+        # calculate available resources NOT counting current worker_id
+        async with connection.execute(f'SELECT '
+                                      f'{", ".join(resource_fields)}, '
+                                      f'{", ".join("total_" + x for x in resource_fields)} '
+                                      f'FROM resources WHERE hwid == ?', (hwid,)) as rescur:
+            available_res = dict(await rescur.fetchone())
+        available_dev_type_to_ids: Dict[str, Dict[int, Dict[str, Union[int, float, str]]]] = {}
+        current_available_dev_type_to_ids: Dict[str, Set[int]] = {}
+        for dev_type in device_type_names:
+            dev_type_table_name = f'hardware_device_type__{dev_type}'
+            async with connection.execute(
+                    f'SELECT * FROM "{dev_type_table_name}" WHERE hwid == ?',
+                    (hwid,)) as rescur:
+                all_dev_rows = [dict(x) for x in await rescur.fetchall()]
+            available_dev_type_to_ids[dev_type] = {
+                x['dev_id']: {k[len('res__'):]: v for k, v in x.items() if k.startswith('res__')}  # resource cols start with res__
+                for x in all_dev_rows
+            }  # note, these are ALL devices, with "available" 0 and 1 values, we don't *trust* "available", we recalc them below, just like with non-total res
+            current_available_dev_type_to_ids[dev_type] = {x['dev_id'] for x in all_dev_rows if x['available']}  # now this counts available to check later if anything changed
+        current_available = {k: v for k, v in available_res.items() if not k.startswith('total_')}
+        available_res = {k[len('total_'):]: v for k, v in available_res.items() if k.startswith('total_')}  # start with full total res
+
+        for wid, res in workers_resources.items():
+            if wid == worker_id:
+                continue  # SKIP worker_id currently being set
+            if res.get('hwid') != hwid:
+                continue
+            # recalc actual available resources based on cached worker_resources
+            for field in resource_fields:
+                if field not in res.get('res', {}):
+                    continue
+                available_res[field] -= res['res'][field]
+            # recalc actual available devices based on cached worker_resources
+            for dev_type in device_type_names:
+                if dev_type not in res.get('dev', {}):
+                    continue
+                for dev_id in res['dev'][dev_type]:
+                    available_dev_type_to_ids[dev_type].pop(dev_id)
+        ##
+
+        # now choose proper amount of resources to pick
+        if resources is None:
+            workers_resources[worker_id] = {'hwid': hwid}  # remove resource usage info
+        else:
+            workers_resources[worker_id] = {'res': {}, 'dev': {}}
+            for field in resource_fields:
+                if field not in resources.resources:
+                    continue
+                if available_res[field] < resources.resources[field].min:
+                    raise NotEnoughResources(f'{field}: {resources.resources[field].min} out of {available_res[field]}')
+                # so we take preferred amount of resources (or minimum if pref not set), but no more than available
+                # if preferred is lower than min - it's ignored
+                workers_resources[worker_id]['res'][field] = min(available_res[field],
+                                                                 max(resources.resources[field].pref, resources.resources[field].min))
+                available_res[field] -= workers_resources[worker_id]['res'][field]
+
+            selected_devs: Dict[str, List[int]] = {}  # dev_type to list of dev_ids of that type that are picked
+            for dev_type, dev_reqs in resources.devices.items():
+                if dev_reqs.min == 0 and dev_reqs.pref == 0:  # trivial check
+                    continue
+                if dev_type not in available_dev_type_to_ids:
+                    if dev_reqs.min > 0:
+                        raise NotEnoughResources(f'device "{dev_type}" missing')  # this shouldn't happen - this whole func is only called when resources are checked
+                    else:
+                        continue
+                for dev_id, dev_res in available_dev_type_to_ids[dev_type].items():
+                    # now we check if dev fits requirements
+                    is_good = True
+                    for req_name, req_val in dev_reqs.resources.items():  # we ignore pref in current logic - devices are always taken full
+                        if req_name not in dev_res:
+                            raise NotEnoughResources(f'device "{dev_type}" does not have requested resource "{req_name}"')  # this also should not happen
+                        if dev_res[req_name] < req_val.min:
+                            is_good = False
+                            break
+                    if is_good:
+                        selected_devs.setdefault(dev_type, []).append(dev_id)
+                        if len(selected_devs[dev_type]) >= max(dev_reqs.min, dev_reqs.pref):
+                            # we selected enough devices of this type
+                            break
+                # now remove selected from available
+                for dev_id in selected_devs.get(dev_type, []):
+                    available_dev_type_to_ids[dev_type].pop(dev_id)
+                # sanity check (note: dev_type may be absent from selected_devs if no device fit)
+                if dev_reqs.min > 0 and len(selected_devs.get(dev_type, ())) < dev_reqs.min:
+                    raise NotEnoughResources(f'device "{dev_type}": cannot select {dev_reqs.min} out of {len(selected_devs.get(dev_type, ()))}')
+            workers_resources[worker_id]['dev'] = selected_devs
+
+        workers_resources[worker_id]['hwid'] = hwid  # just to ensure it was not overridden
+
+        self.__logger.debug(f'updating resources {hwid} with {available_res} against {current_available}')
+        self.__logger.debug(workers_resources)
+
+        available_res_didnt_change = available_res == current_available
+        available_devs_didnt_change = all(set(available_dev_type_to_ids[dev_type].keys()) == current_available_dev_type_to_ids[dev_type] for dev_type in device_type_names)
+        if available_res_didnt_change and available_devs_didnt_change:  # nothing needs to be updated
+            return False
+
+        if not available_res_didnt_change:
+            await connection.execute(f'UPDATE resources SET {", ".join(f"{k}={v}" for k, v in available_res.items())} WHERE hwid == ?', (hwid,))
+        if not available_devs_didnt_change:
+            for dev_type in device_type_names:  # TODO: only update affected tables
+                dev_type_table_name = f'hardware_device_type__{dev_type}'
+                await connection.execute(f'UPDATE "{dev_type_table_name}" SET "available"=0 WHERE hwid==?', (hwid,))
+                await connection.executemany(f'UPDATE "{dev_type_table_name}" SET "available"=1 WHERE dev_id==?', ((x,) for x in available_dev_type_to_ids[dev_type].keys()))
+        return True
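A short worked example of the clamping rule above (numbers are made up): with 6 units of a resource available, a requirement of min=2/pref=8 gets min(6, max(8, 2)) = 6 units; a requirement of min=2/pref=0 gets min(6, max(0, 2)) = 2 units; a requirement of min=8 raises NotEnoughResources before any picking happens.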
AND "id" == ?', + (InvocationState.IN_PROGRESS.value, invocation_id,)) as cur: + task_id_row = await cur.fetchone() + if task_id_row is not None: + task_id = task_id_row['task_id'] + if task_id is not None: + self.ui_state_access.scheduler_reports_task_updated(TaskDelta(task_id, progress=progress)) + + # + # worker reports it being stopped + async def worker_stopped(self, addr: str): + """ + + :param addr: + :return: + """ + self.__logger.debug(f'worker reported stopped: {addr}') + async with self.data_access.data_connection() as con: + con.row_factory = aiosqlite.Row + await con.execute('BEGIN IMMEDIATE') + async with con.execute('SELECT id, hwid from "workers" WHERE "last_address" = ?', (addr,)) as worcur: + worker_row = await worcur.fetchone() + if worker_row is None: + self.__logger.warning(f'unregistered worker reported "stopped": {addr}, ignoring') + await con.rollback() + return + wid = worker_row['id'] + hwid = worker_row['hwid'] + # print(wid) + + # we ensure there are no invocations running with this worker + async with con.execute('SELECT "id", task_id FROM invocations WHERE worker_id = ? AND ("state" = ? OR "state" = ?)', + (wid, InvocationState.IN_PROGRESS.value, InvocationState.INVOKING.value)) as invcur: + invocations = await invcur.fetchall() + + await con.execute('UPDATE workers SET "state" = ? WHERE "id" = ?', (WorkerState.OFF.value, wid)) + await con.executemany('UPDATE invocations SET state = ? WHERE "id" = ?', ((InvocationState.FINISHED.value, x["id"]) for x in invocations)) + await con.executemany('UPDATE tasks SET state = ? WHERE "id" = ?', ((TaskState.READY.value, x["task_id"]) for x in invocations)) + await self._update_worker_resouce_usage(wid, hwid=hwid, connection=con) # oh wait, it happens right here, still an assert won't hurt + del self.data_access.mem_cache_workers_resources[wid] # remove from cache # TODO: ENSURE resources were already unset for this wid + if len(invocations) > 0: + con.add_after_commit_callback(self.ui_state_access.scheduler_reports_tasks_updated, [TaskDelta(x["task_id"], state=TaskState.READY) for x in invocations]) # ui event + await con.commit() + self.__logger.debug(f'finished worker reported stopped: {addr}') + + # + # protocol related commands + # + # + # cancel invocation + async def cancel_invocation(self, invocation_id: str): + self.__logger.debug(f'canceling invocation {invocation_id}') + async with self.data_access.data_connection() as con: + con.row_factory = aiosqlite.Row + async with con.execute('SELECT * FROM "invocations" WHERE "id" = ?', (invocation_id,)) as cur: + invoc = await cur.fetchone() + if invoc is None or invoc['state'] != InvocationState.IN_PROGRESS.value: + return + async with con.execute('SELECT "last_address" FROM "workers" WHERE "id" = ?', (invoc['worker_id'],)) as cur: + worker = await cur.fetchone() + if worker is None: + self.__logger.error('inconsistent worker ids? how?') + return + addr = AddressChain(worker['last_address']) + + # the logic is: + # - we send the worker a signal to cancel invocation + # - later worker sends task_cancel_reported, and we are happy + # - but worker might be overloaded, broken or whatever and may never send it. and it can even finish task and send task_done_reported, witch we need to treat + with WorkerControlClient.get_worker_control_client(addr, self.message_processor()) as client: # type: WorkerControlClient + await client.cancel_task() + + # oh no, we don't do that, we wait for worker to report task canceled. await con.execute('UPDATE invocations SET "state" = ? 
WHERE "id" = ?', (InvocationState.FINISHED.value, invocation_id)) + + # + # + async def cancel_invocation_for_task(self, task_id: int): + self.__logger.debug(f'canceling invocation for task {task_id}') + async with self.data_access.data_connection() as con: + con.row_factory = aiosqlite.Row + async with con.execute('SELECT "id" FROM "invocations" WHERE "task_id" = ? AND state = ?', (task_id, InvocationState.IN_PROGRESS.value)) as cur: + invoc = await cur.fetchone() + if invoc is None: + return + return await self.cancel_invocation(invoc['id']) + + # + # + async def cancel_invocation_for_worker(self, worker_id: int): + self.__logger.debug(f'canceling invocation for worker {worker_id}') + async with self.data_access.data_connection() as con: + con.row_factory = aiosqlite.Row + async with con.execute('SELECT "id" FROM "invocations" WHERE "worker_id" == ? AND state == ?', (worker_id, InvocationState.IN_PROGRESS.value)) as cur: + invoc = await cur.fetchone() + if invoc is None: + return + return await self.cancel_invocation(invoc['id']) + + # + # + async def force_set_node_task(self, task_id: int, node_id: int): + self.__logger.debug(f'forcing task {task_id} to node {node_id}') + try: + async with self.data_access.data_connection() as con: + con.row_factory = aiosqlite.Row + await con.execute('BEGIN IMMEDIATE') + await con.execute('PRAGMA FOREIGN_KEYS = on') + async with con.execute('SELECT "state" FROM tasks WHERE "id" == ?', (task_id,)) as cur: + row = await cur.fetchone() + if row is None: + self.__logger.warning(f'failed to force task node: task {task_id} not found') + await con.rollback() + return + + state = TaskState(row['state']) + new_state = None + if state in (TaskState.WAITING, TaskState.READY, TaskState.POST_WAITING): + new_state = TaskState.WAITING + elif state == TaskState.DONE: + new_state = TaskState.DONE + # if new_state was not set - means state was invalid + if new_state is None: + self.__logger.warning(f'changing node of a task in state {state.name} is not allowed') + await con.rollback() + raise ValueError(f'changing node of a task in state {state.name} is not allowed') + + await con.execute('UPDATE tasks SET "node_id" = ?, "state" = ? WHERE "id" = ?', (node_id, new_state.value, task_id)) + con.add_after_commit_callback(self.ui_state_access.scheduler_reports_task_updated, TaskDelta(task_id, node_id=node_id)) # ui event + # reset blocking too + await self.data_access.reset_task_blocking(task_id, con=con) + await con.commit() + except aiosqlite.IntegrityError: + self.__logger.error(f'could not set task {task_id} to node {node_id} because of database integrity check') + raise DataIntegrityError() from None + else: + self.wake() + self.poke_task_processor() + + # + # force change task state + async def force_change_task_state(self, task_ids: Union[int, Iterable[int]], state: TaskState): + """ + forces task into given state. + obviously a task cannot be forced into certain states, like IN_PROGRESS, GENERATING, POST_GENERATING + :param task_ids: + :param state: + :return: + """ + if state in (TaskState.IN_PROGRESS, TaskState.GENERATING, TaskState.POST_GENERATING): + self.__logger.error(f'cannot force task {task_ids} into state {state}') + return + if isinstance(task_ids, int): + task_ids = [task_ids] + query = 'UPDATE tasks SET "state" = %d WHERE "id" = ?' 
+    #
+    # force change task state
+    async def force_change_task_state(self, task_ids: Union[int, Iterable[int]], state: TaskState):
+        """
+        forces task into given state.
+        obviously a task cannot be forced into certain states, like IN_PROGRESS, GENERATING, POST_GENERATING
+
+        :param task_ids:
+        :param state:
+        :return:
+        """
+        if state in (TaskState.IN_PROGRESS, TaskState.GENERATING, TaskState.POST_GENERATING):
+            self.__logger.error(f'cannot force task {task_ids} into state {state}')
+            return
+        if isinstance(task_ids, int):
+            task_ids = [task_ids]
+        query = 'UPDATE tasks SET "state" = %d WHERE "id" = ?' % state.value
+        async with self.data_access.data_connection() as con:
+            for task_id in task_ids:
+                await con.execute('BEGIN IMMEDIATE')
+                async with con.execute('SELECT "state" FROM tasks WHERE "id" = ?', (task_id,)) as cur:
+                    cur_state = await cur.fetchone()
+                if cur_state is None:
+                    await con.rollback()
+                    continue
+                cur_state = TaskState(cur_state[0])
+                if cur_state in (TaskState.IN_PROGRESS, TaskState.GENERATING, TaskState.POST_GENERATING):
+                    self.__logger.warning(f'forcing task out of state {cur_state} is not allowed')
+                    await con.rollback()
+                    continue
+
+                await con.execute(query, (task_id,))
+                # just in case we also reset blocking
+                await self.data_access.reset_task_blocking(task_id, con=con)
+                con.add_after_commit_callback(self.ui_state_access.scheduler_reports_task_updated, TaskDelta(task_id, state=state))  # ui event
+                await con.commit()  # TODO: this can be optimized into a single transaction
+        self.wake()
+        self.poke_task_processor()
+
+    #
+    # change task's paused state
+    async def set_task_paused(self, task_ids_or_group: Union[int, Iterable[int], str], paused: bool):
+        if isinstance(task_ids_or_group, str):
+            async with self.data_access.data_connection() as con:
+                await con.execute('UPDATE tasks SET "paused" = ? WHERE "id" IN (SELECT "task_id" FROM task_groups WHERE "group" = ?)',
+                                  (int(paused), task_ids_or_group))
+                ui_task_ids = await self.ui_state_access._get_group_tasks(task_ids_or_group)  # ui event
+                con.add_after_commit_callback(self.ui_state_access.scheduler_reports_tasks_updated, [TaskDelta(ui_task_id, paused=paused) for ui_task_id in ui_task_ids])  # ui event
+                await con.commit()
+            self.wake()
+            self.poke_task_processor()
+            return
+        if isinstance(task_ids_or_group, int):
+            task_ids_or_group = [task_ids_or_group]
+        query = 'UPDATE tasks SET "paused" = %d WHERE "id" = ?' % int(paused)
+        async with self.data_access.data_connection() as con:
+            await con.executemany(query, ((x,) for x in task_ids_or_group))
+            con.add_after_commit_callback(self.ui_state_access.scheduler_reports_tasks_updated, [TaskDelta(ui_task_id, paused=paused) for ui_task_id in task_ids_or_group])  # ui event
+            await con.commit()
+        self.wake()
+        self.poke_task_processor()
+
+    #
+    # change task group archived state
+    async def set_task_group_archived(self, task_group_name: str, state: TaskGroupArchivedState = TaskGroupArchivedState.ARCHIVED) -> None:
+        async with self.data_access.data_connection() as con:
+            con.row_factory = aiosqlite.Row
+            await con.execute('UPDATE task_group_attributes SET state=? WHERE "group"==?', (state.value, task_group_name))  # this triggers all task deadness | 2, so potentially it can be long, beware
+            # task's dead field's 2nd bit is set, but we currently do not track it
+            # so no event needed
+            await con.commit()
+            if state == TaskGroupArchivedState.NOT_ARCHIVED:
+                self.poke_task_processor()  # unarchived, so kick task processor, just in case
+                return
+            # otherwise - it's archived
+            # now all tasks belonging to that group should be set to dead|2
+            # we need to make sure to cancel all running invocations for those tasks
+            # at this point tasks are archived and won't be processed,
+            # so we only expect concurrent changes due to already running _submitters and _awaiters,
+            # like INVOKING->IN_PROGRESS
+            async with con.execute('SELECT "id" FROM invocations '
+                                   'INNER JOIN task_groups ON task_groups.task_id == invocations.task_id '
+                                   'WHERE task_groups."group" == ? AND invocations.state == ?',
+                                   (task_group_name, InvocationState.INVOKING.value)) as cur:
+                invoking_invoc_ids = set(x['id'] for x in await cur.fetchall())
+            async with con.execute('SELECT "id" FROM invocations '
+                                   'INNER JOIN task_groups ON task_groups.task_id == invocations.task_id '
+                                   'WHERE task_groups."group" == ? AND invocations.state == ?',
+                                   (task_group_name, InvocationState.IN_PROGRESS.value)) as cur:
+                active_invoc_ids = tuple(x['id'] for x in await cur.fetchall())
+            # i sure use a lot of fetchall where it's much more natural to iterate the cursor.
+            # that is because of a fear of db locking i got BEFORE switching to WAL, when iterating a connection
+            # was randomly crashing other connections, not taking timeout into account at all.
+
+        # note, at this point we might have some invoking_invoc_ids, but some of them
+        # might already have been set to in-progress and even got into the active_invoc_ids list
+
+        # first - cancel all in-progress invocations
+        for inv_id in active_invoc_ids:
+            await self.cancel_invocation(inv_id)
+
+        # now since we don't have the ability to safely cancel a running _submitter task - we will just wait till
+        # invoking invocations change state
+        # sure it's a bit bruteforce
+        # but a working solution for now
+        if len(invoking_invoc_ids) == 0:
+            return
+        async with self.data_access.data_connection() as con:
+            while len(invoking_invoc_ids) > 0:
+                # TODO: this forever while doesn't seem right
+                #  in the average case it should basically never happen at all
+                #  only in case of really bad buggy network connections an invocation can get stuck on INVOKING
+                #  but there are natural timeouts in _submitter that will switch it from INVOKING eventually
+                #  the only question is - do we want to just stay in this function until it's resolved? UI's client is single threaded, so it will get stuck waiting
+                con.row_factory = aiosqlite.Row
+                async with con.execute('SELECT "id",state FROM invocations WHERE state!={} AND "id" IN ({})'.format(
+                        InvocationState.INVOKING.value,  # find those that are no longer INVOKING
+                        ','.join(str(x) for x in invoking_invoc_ids))) as cur:
+                    changed_state_ones = await cur.fetchall()
+
+                for oid, ostate in ((x['id'], x['state']) for x in changed_state_ones):
+                    if ostate == InvocationState.IN_PROGRESS.value:
+                        await self.cancel_invocation(oid)
+                    assert oid in invoking_invoc_ids
+                    invoking_invoc_ids.remove(oid)
+                await asyncio.sleep(0.5)
WHERE "id" = ?', (new_name, task_id)) + con.add_after_commit_callback(self.ui_state_access.scheduler_reports_task_updated, TaskDelta(task_id, name=new_name)) # ui event + await con.commit() + + # + # set task groups + async def set_task_groups(self, task_id: int, group_names: Iterable[str]): + async with self.data_access.data_connection() as con: + con.row_factory = aiosqlite.Row + await con.execute('BEGIN IMMEDIATE') + async with con.execute('SELECT "group" FROM task_groups WHERE "task_id" = ?', (task_id,)) as cur: + all_groups = set(x['group'] for x in await cur.fetchall()) + group_names = set(group_names) + groups_to_set = group_names - all_groups + groups_to_del = all_groups - group_names + print(task_id, groups_to_set, groups_to_del, all_groups, group_names) + + for group_name in groups_to_set: + await con.execute('INSERT INTO task_groups (task_id, "group") VALUES (?, ?)', (task_id, group_name)) + await con.execute('INSERT OR IGNORE INTO task_group_attributes ("group", "ctime") VALUES (?, ?)', (group_name, int(datetime.utcnow().timestamp()))) + for group_name in groups_to_del: + await con.execute('DELETE FROM task_groups WHERE task_id = ? AND "group" = ?', (task_id, group_name)) + con.add_after_commit_callback(self.ui_state_access.scheduler_reports_tasks_removed_from_group, [task_id], groups_to_del) # ui event + con.add_after_commit_callback(self.ui_state_access.scheduler_reports_task_groups_changed, groups_to_set) # ui event + # + # ui event + if len(groups_to_set) > 0: + async with con.execute( + 'SELECT tasks.id, tasks.parent_id, tasks.children_count, tasks.active_children_count, tasks.state, tasks.state_details, tasks.paused, tasks.node_id, ' + 'tasks.node_input_name, tasks.node_output_name, tasks.name, tasks.split_level, tasks.work_data_invocation_attempt, ' + 'task_splits.origin_task_id, task_splits.split_id, invocations."id" as invoc_id ' + 'FROM "tasks" ' + 'LEFT JOIN "task_splits" ON tasks.id=task_splits.task_id ' + 'LEFT JOIN "invocations" ON tasks.id=invocations.task_id AND invocations.state = ? ' + 'WHERE tasks."id" == ?', + (InvocationState.IN_PROGRESS.value, task_id)) as cur: + task_row = await cur.fetchone() + if task_row is not None: + progress = self.data_access.get_invocation_progress(task_row['invoc_id']) + con.add_after_commit_callback( + self.ui_state_access.scheduler_reports_task_added, + TaskData(task_id, task_row['parent_id'], task_row['children_count'], task_row['active_children_count'], TaskState(task_row['state']), + task_row['state_details'], bool(task_row['paused']), task_row['node_id'], task_row['node_input_name'], task_row['node_output_name'], + task_row['name'], task_row['split_level'], task_row['work_data_invocation_attempt'], progress, + task_row['origin_task_id'], task_row['split_id'], task_row['invoc_id'], group_names), + groups_to_set + ) # ui event + con.add_after_commit_callback(self.ui_state_access.scheduler_reports_task_updated, TaskDelta(task_id, groups=group_names)) # ui event + # + # + await con.commit() + + # + # update task attributes + async def update_task_attributes(self, task_id: int, attributes_to_update: dict, attributes_to_delete: set): + async with self.data_access.data_connection() as con: + con.row_factory = aiosqlite.Row + await con.execute('BEGIN IMMEDIATE') + async with con.execute('SELECT "attributes" FROM tasks WHERE "id" = ?', (task_id,)) as cur: + row = await cur.fetchone() + if row is None: + self.__logger.warning(f'update task attributes for {task_id} failed. 
+    #
+    # update task attributes
+    async def update_task_attributes(self, task_id: int, attributes_to_update: dict, attributes_to_delete: set):
+        async with self.data_access.data_connection() as con:
+            con.row_factory = aiosqlite.Row
+            await con.execute('BEGIN IMMEDIATE')
+            async with con.execute('SELECT "attributes" FROM tasks WHERE "id" = ?', (task_id,)) as cur:
+                row = await cur.fetchone()
+            if row is None:
+                self.__logger.warning(f'update task attributes for {task_id} failed. task id not found.')
+                await con.commit()
+                return
+            attributes = await deserialize_attributes(row['attributes'])
+            attributes.update(attributes_to_update)
+            for name in attributes_to_delete:
+                if name in attributes:
+                    del attributes[name]
+            await con.execute('UPDATE tasks SET "attributes" = ? WHERE "id" = ?', (await serialize_attributes(attributes),
+                                                                                   task_id))
+            await con.commit()
+
+    #
+    # set environment resolver
+    async def set_task_environment_resolver_arguments(self, task_id: int, env_res: Optional[EnvironmentResolverArguments]):
+        async with self.data_access.data_connection() as con:
+            con.row_factory = aiosqlite.Row
+            await con.execute('UPDATE tasks SET "environment_resolver_data" = ? WHERE "id" = ?',
+                              (await env_res.serialize_async() if env_res is not None else None,
+                               task_id))
+            await con.commit()
+
+    #
+    # node stuff
+    async def set_node_name(self, node_id: int, node_name: str) -> str:
+        """
+        rename node. node_name may undergo validation and change. final node name that was set is returned
+
+        :param node_id: node id
+        :param node_name: proposed node name
+        :return: actual node name set
+        """
+        async with self.data_access.data_connection() as con:
+            await con.execute('UPDATE "nodes" SET "name" = ? WHERE "id" = ?', (node_name, node_id))
+            if node_id in self.__node_objects:
+                self.__node_objects[node_id].set_name(node_name)
+            await con.commit()
+        self.ui_state_access.bump_graph_update_id()
+        return node_name
+
+    #
+    # reset node's stored state
+    async def wipe_node_state(self, node_id):
+        async with self.data_access.data_connection() as con:
+            await con.execute('UPDATE "nodes" SET node_object = NULL WHERE "id" = ?', (node_id,))
+            if node_id in self.__node_objects:
+                # TODO: this below may be not safe (at least not proven to be safe yet, but maybe). check
+                del self.__node_objects[node_id]  # it's here to "protect" operation within db transaction. TODO: but a proper __node_object lock should be in place instead
+            await con.commit()
+        self.ui_state_access.bump_graph_update_id()  # not sure if needed - even number of inputs/outputs is not part of graph description
+        self.wake()
+
+    #
+    # copy nodes
+    async def duplicate_nodes(self, node_ids: Iterable[int]) -> Dict[int, int]:
+        """
+        copies given nodes, including connections between given nodes,
+        and returns mapping from given node_ids to respective new copies
+
+        :param node_ids:
+        :return:
+        """
+        old_to_new = {}
+        for nid in node_ids:
+            async with self.node_object_by_id_for_reading(nid) as node_obj:
+                node_obj: BaseNode
+                node_type, node_name = await self.get_node_type_and_name_by_id(nid)
+                new_id = await self.add_node(node_type, f'{node_name} copy')
+                async with self.node_object_by_id_for_writing(new_id) as new_node_obj:
+                    node_obj.copy_ui_to(new_node_obj)
+                old_to_new[nid] = new_id
+
+        # now copy connections
+        async with self.data_access.data_connection() as con:
+            con.row_factory = aiosqlite.Row
+            node_ids_str = f'({",".join(str(x) for x in node_ids)})'
+            async with con.execute(f'SELECT * FROM node_connections WHERE node_id_in IN {node_ids_str} AND node_id_out IN {node_ids_str}') as cur:
+                all_cons = await cur.fetchall()
+        for nodecon in all_cons:
+            assert nodecon['node_id_in'] in old_to_new
+            assert nodecon['node_id_out'] in old_to_new
+            await self.add_node_connection(old_to_new[nodecon['node_id_out']], nodecon['out_name'], old_to_new[nodecon['node_id_in']], nodecon['in_name'])
+        return old_to_new
+        # TODO: NotImplementedError("recheck and needs testing")
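A minimal usage sketch of duplicate_nodes (the ids are made up for illustration):

    # duplicate nodes 3 and 4; connections between them are copied too
    mapping = await scheduler.duplicate_nodes([3, 4])
    # mapping is e.g. {3: 17, 4: 18} - original node id to the id of its new copy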
+    #
+    #
+    # node reports its interface was changed. not sure why it exists
+    async def node_reports_changes_needs_saving(self, node_id):
+        assert node_id in self.__node_objects, 'this may be caused by race condition with node deletion'
+        await self.save_node_to_database(node_id)
+
+    #
+    # save node to database.
+    async def save_node_to_database(self, node_id):
+        """
+        save node with given node_id to database
+        if node is not in our list of nodes - we assume it was not touched, not changed, so no saving needed
+
+        :param node_id:
+        :return:
+        """
+        # TODO: introduce __node_objects lock? or otherwise secure access
+        #  why? this happens on ui_update, which can happen cuz of request from viewer.
+        #  while node processing happens in a different thread, so this CAN happen at the same time with this
+        #  AND THIS IS BAD! (potentially) if a node has changing internal state - this can save some inconsistent snapshot of node state!
+        #  this works now only cuz scheduler_ui_protocol does the locking for param settings
+        node_object = self.__node_objects[node_id]
+        if node_object is None:
+            self.__logger.error('node_object is None while trying to save node to database')
+            return
+        node_data, state_data = await self.__node_serializers[0].serialize_async(node_object)
+        async with self.data_access.data_connection() as con:
+            await con.execute('UPDATE "nodes" SET node_object = ?, node_object_state = ? WHERE "id" = ?',
+                              (node_data, state_data, node_id))
+            await con.commit()
+
+    #
+    # set worker groups
+    async def set_worker_groups(self, worker_hwid: int, groups: List[str]):
+        groups = set(groups)
+        async with self.data_access.data_connection() as con:
+            await con.execute('BEGIN IMMEDIATE')  # start transaction straight away
+            async with con.execute('SELECT "group" FROM worker_groups WHERE worker_hwid == ?', (worker_hwid,)) as cur:
+                existing_groups = set(x[0] for x in await cur.fetchall())
+            to_delete = existing_groups - groups
+            to_add = groups - existing_groups
+            if len(to_delete):
+                await con.execute(f'DELETE FROM worker_groups WHERE worker_hwid == ? AND "group" IN ({",".join(("?",) * len(to_delete))})', (worker_hwid, *to_delete))
+            if len(to_add):
+                await con.executemany(f'INSERT INTO worker_groups (worker_hwid, "group") VALUES (?, ?)',
+                                      ((worker_hwid, x) for x in to_add))
+            await con.commit()
+
+    #
+    # change node connection callback
+    async def change_node_connection(self, node_connection_id: int, new_out_node_id: Optional[int], new_out_name: Optional[str],
+                                     new_in_node_id: Optional[int], new_in_name: Optional[str]):
+        parts = []
+        vals = []
+        if new_out_node_id is not None:
+            parts.append('node_id_out = ?')
+            vals.append(new_out_node_id)
+        if new_out_name is not None:
+            parts.append('out_name = ?')
+            vals.append(new_out_name)
+        if new_in_node_id is not None:
+            parts.append('node_id_in = ?')
+            vals.append(new_in_node_id)
+        if new_in_name is not None:
+            parts.append('in_name = ?')
+            vals.append(new_in_name)
+        if len(vals) == 0:  # nothing to do
+            return
+        async with self.data_access.data_connection() as con:
+            con.row_factory = aiosqlite.Row
+            vals.append(node_connection_id)
+            await con.execute(f'UPDATE node_connections SET {", ".join(parts)} WHERE "id" = ?', vals)
+            await con.commit()
+        self.wake()
+        self.ui_state_access.bump_graph_update_id()
+    #
+    # add node connection callback
+    async def add_node_connection(self, out_node_id: int, out_name: str, in_node_id: int, in_name: str) -> int:
+        async with self.data_access.data_connection() as con:
+            con.row_factory = aiosqlite.Row
+            async with con.execute('INSERT OR REPLACE INTO node_connections (node_id_out, out_name, node_id_in, in_name) VALUES (?,?,?,?)',  # INSERT OR REPLACE here (and not OR ABORT or smth) to ensure lastrowid is set
+                                   (out_node_id, out_name, in_node_id, in_name)) as cur:
+                ret = cur.lastrowid
+            await con.commit()
+        self.wake()
+        self.ui_state_access.bump_graph_update_id()
+        return ret
+
+    #
+    # remove node connection callback
+    async def remove_node_connection(self, node_connection_id: int):
+        try:
+            async with self.data_access.data_connection() as con:
+                con.row_factory = aiosqlite.Row
+                await con.execute('PRAGMA FOREIGN_KEYS = on')
+                await con.execute('DELETE FROM node_connections WHERE "id" = ?', (node_connection_id,))
+                await con.commit()
+            self.ui_state_access.bump_graph_update_id()
+        except aiosqlite.IntegrityError:
+            self.__logger.error(f'could not remove node connection {node_connection_id} because of database integrity check')
+            raise DataIntegrityError() from None
+
+    #
+    # add node
+    async def add_node(self, node_type: str, node_name: str) -> int:
+        if not self.__node_data_provider.has_node_factory(node_type):  # preliminary check
+            raise RuntimeError(f'unknown node type: "{node_type}"')
+
+        async with self.data_access.data_connection() as con:
+            con.row_factory = aiosqlite.Row
+            async with con.execute('INSERT INTO "nodes" ("type", "name") VALUES (?,?)',
+                                   (node_type, node_name)) as cur:
+                ret = cur.lastrowid
+            await con.commit()
+        self.ui_state_access.bump_graph_update_id()
+        return ret
+
+    async def apply_node_settings(self, node_id: int, settings_name: str):
+        # note: acquire the write lock only once - acquiring it twice for the same node would be redundant at best
+        async with self.node_object_by_id_for_writing(node_id) as node:  # type: BaseNode
+            settings = self.__node_data_provider.node_settings(node.type_name(), settings_name)
+            await asyncio.get_event_loop().run_in_executor(None, node.apply_settings, settings)
+
+    async def remove_node(self, node_id: int):
+        try:
+            async with self.data_access.data_connection() as con:
+                con.row_factory = aiosqlite.Row
+                await con.execute('PRAGMA FOREIGN_KEYS = on')
+                await con.execute('DELETE FROM "nodes" WHERE "id" = ?', (node_id,))
+                await con.commit()
+            self.ui_state_access.bump_graph_update_id()
+        except aiosqlite.IntegrityError:
+            self.__logger.error(f'could not remove node {node_id} because of database integrity check')
+            raise DataIntegrityError('There are invocations (maybe archived ones) referencing this node') from None
+
+    #
+    # query connections
+    async def get_node_input_connections(self, node_id: int, input_name: Optional[str] = None):
+        return await self.get_node_connections(node_id, True, input_name)
+
+    async def get_node_output_connections(self, node_id: int, output_name: Optional[str] = None):
+        return await self.get_node_connections(node_id, False, output_name)
+
+    async def get_node_connections(self, node_id: int, query_input: bool = True, name: Optional[str] = None):
+        if query_input:
+            nodecol = 'node_id_in'
+            namecol = 'in_name'
+        else:
+            nodecol = 'node_id_out'
+            namecol = 'out_name'
+        async with self.data_access.data_connection() as con:
+            con.row_factory = aiosqlite.Row
+            if name is None:
+                async with con.execute('SELECT * FROM node_connections WHERE "%s" = ?' % (nodecol,),
+                                       (node_id,)) as cur:
+                    return [dict(x) for x in await cur.fetchall()]
+            else:
+                async with con.execute('SELECT * FROM node_connections WHERE "%s" = ? AND "%s" = ?' % (nodecol, namecol),
+                                       (node_id, name)) as cur:
+                    return [dict(x) for x in await cur.fetchall()]
+
+    #
+    # spawning new task callback
+    async def spawn_tasks(self, newtasks: Union[Iterable[TaskSpawn], TaskSpawn], con: Optional[aiosqlite_overlay.ConnectionWithCallbacks] = None) -> Union[Tuple[SpawnStatus, Optional[int]], Tuple[Tuple[SpawnStatus, Optional[int]], ...]]:
+        """
+        spawn new tasks from given TaskSpawn objects
+
+        :param newtasks: a single TaskSpawn, or an iterable of them
+        :param con: optional already opened connection with an ongoing transaction to reuse
+        :return: a single (SpawnStatus, task_id) pair if a single TaskSpawn was given, a tuple of such pairs otherwise
+        """
+
+        async def _inner_shit() -> Tuple[Tuple[SpawnStatus, Optional[int]], ...]:
+            result = []
+            new_tasks = []
+            current_timestamp = int(datetime.utcnow().timestamp())
+            assert len(newtasks) > 0, 'expectations failure'
+            if not con.in_transaction:  # IF this is called from multiple async tasks with THE SAME con - this may cause race conditions
+                await con.execute('BEGIN IMMEDIATE')
+            for newtask in newtasks:
+                if newtask.source_invocation_id() is not None:
+                    async with con.execute('SELECT node_id, task_id FROM invocations WHERE "id" = ?',
+                                           (newtask.source_invocation_id(),)) as incur:
+                        invocrow = await incur.fetchone()
+                        assert invocrow is not None
+                        node_id: int = invocrow['node_id']
+                        parent_task_id: int = invocrow['task_id']
+                elif newtask.forced_node_task_id() is not None:
+                    node_id, parent_task_id = newtask.forced_node_task_id()
+                else:
+                    self.__logger.error('ERROR CREATING SPAWN TASK: Malformed source')
+                    result.append((SpawnStatus.FAILED, None))
+                    continue
+
+                async with con.execute('INSERT INTO tasks ("name", "attributes", "parent_id", "state", "node_id", "node_output_name", "environment_resolver_data") VALUES (?, ?, ?, ?, ?, ?, ?)',
+                                       (newtask.name(), await serialize_attributes(newtask._attributes()), parent_task_id,  # TODO: run dumps in executor
+                                        TaskState.SPAWNED.value if newtask.create_as_spawned() else TaskState.WAITING.value,
+                                        node_id, newtask.node_output_name(),
+                                        newtask.environment_arguments().serialize() if newtask.environment_arguments() is not None else None)) as newcur:
+                    new_id = newcur.lastrowid
+
+                all_groups = set()
+                if parent_task_id is not None:  # inherit all parent's groups
+                    # check and inherit parent's environment wrapper arguments
+                    if newtask.environment_arguments() is None:
+                        await con.execute('UPDATE tasks SET environment_resolver_data = (SELECT environment_resolver_data FROM tasks WHERE "id" == ?) WHERE "id" == ?',
+                                          (parent_task_id, new_id))
+
+                    # inc children count happens in db trigger
+                    # inherit groups
+                    async with con.execute('SELECT "group" FROM task_groups WHERE "task_id" = ?', (parent_task_id,)) as gcur:
+                        groups = [x['group'] for x in await gcur.fetchall()]
+                    all_groups.update(groups)
+                    if len(groups) > 0:
+                        con.add_after_commit_callback(self.ui_state_access.scheduler_reports_task_groups_changed, groups)  # ui event
+                        await con.executemany('INSERT INTO task_groups ("task_id", "group") VALUES (?, ?)',
+                                              zip(itertools.repeat(new_id, len(groups)), groups))
+                else:  # parent_task_id is None
+                    # in this case we create a default group for the task.
+                    # task should not be left without groups at all - otherwise it will be impossible to find in UI
+                    new_group = '{name}#{id:d}'.format(name=newtask.name(), id=new_id)
+                    all_groups.add(new_group)
+                    await con.execute('INSERT INTO task_groups ("task_id", "group") VALUES (?, ?)',
+                                      (new_id, new_group))
+                    await con.execute('INSERT OR REPLACE INTO task_group_attributes ("group", "ctime") VALUES (?, ?)',
+                                      (new_group, current_timestamp))
+                    if newtask.default_priority() is not None:
+                        await con.execute('UPDATE task_group_attributes SET "priority" = ? WHERE "group" = ?',
+                                          (newtask.default_priority(), new_group))
+                    con.add_after_commit_callback(self.ui_state_access.scheduler_reports_task_groups_changed, (new_group,))  # ui event
+                #
+                if newtask.extra_group_names():
+                    groups = newtask.extra_group_names()
+                    all_groups.update(groups)
+                    await con.executemany('INSERT INTO task_groups ("task_id", "group") VALUES (?, ?)',
+                                          zip(itertools.repeat(new_id, len(groups)), groups))
+                    con.add_after_commit_callback(self.ui_state_access.scheduler_reports_task_groups_changed, groups)  # ui event
+                    for group in groups:
+                        async with con.execute('SELECT "group" FROM task_group_attributes WHERE "group" == ?', (group,)) as gcur:
+                            need_create = await gcur.fetchone() is None
+                        if not need_create:
+                            continue
+                        await con.execute('INSERT INTO task_group_attributes ("group", "ctime") VALUES (?, ?)',
+                                          (group, current_timestamp))
+                        # TODO: task_groups.group should be a foreign key to task_group_attributes.group
+                        #  but then we need to insert those guys in correct order (first in attributes table, then groups)
+                        #  then smth like FOREIGN KEY("group") REFERENCES "task_group_attributes"("group") ON UPDATE CASCADE ON DELETE CASCADE
+                result.append((SpawnStatus.SUCCEEDED, new_id))
+                new_tasks.append(TaskData(new_id, parent_task_id, 0, 0,
+                                          TaskState.SPAWNED if newtask.create_as_spawned() else TaskState.WAITING, '',
+                                          False, node_id, 'main', newtask.node_output_name(), newtask.name(), 0, 0, None, None, None, None,
+                                          all_groups))
+
+            # callbacks for ui events
+            con.add_after_commit_callback(self.ui_state_access.scheduler_reports_tasks_added, new_tasks)
+            return tuple(result)
+
+        return_single = False
+        if isinstance(newtasks, TaskSpawn):
+            newtasks = (newtasks,)
+            return_single = True
+        if len(newtasks) == 0:
+            return ()
+        if con is not None:
+            stuff = await _inner_shit()
+        else:
+            async with self.data_access.data_connection() as con:
+                con.row_factory = aiosqlite.Row
+                stuff = await _inner_shit()
+                await con.commit()
+            self.wake()
+            self.poke_task_processor()
+        return stuff[0] if return_single else stuff
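A minimal usage sketch (the TaskSpawn construction is elided, as its interface is defined elsewhere):

    # a single spawn returns a single pair
    status, new_task_id = await scheduler.spawn_tasks(a_task_spawn)
    # an iterable returns a tuple of pairs, in the same order
    results = await scheduler.spawn_tasks([spawn_a, spawn_b])
    assert all(status == SpawnStatus.SUCCEEDED for status, _ in results)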
+    #
+    async def node_name_to_id(self, name: str) -> List[int]:
+        """
+        get the list of node ids that have specified name
+
+        :param name:
+        :return:
+        """
+        async with self.data_access.data_connection() as con:
+            async with con.execute('SELECT "id" FROM "nodes" WHERE "name" = ?', (name,)) as cur:
+                return list(x[0] for x in await cur.fetchall())
+
+    #
+    async def get_invocation_metadata(self, task_id: int) -> Dict[int, List[IncompleteInvocationLogData]]:
+        """
+        get task's log metadata - meaning which nodes it ran on and how
+
+        :param task_id:
+        :return: dict[node_id -> list[IncompleteInvocationLogData]]
+        """
+        async with self.data_access.data_connection() as con:
+            con.row_factory = aiosqlite.Row
+            logs = {}
+            self.__logger.debug(f'fetching log metadata for {task_id}')
+            async with con.execute('SELECT "id", node_id, runtime, worker_id, state, return_code from "invocations" WHERE "state" != ? AND "task_id" == ?',
+                                   (InvocationState.INVOKING.value, task_id)) as cur:
+                async for entry in cur:
+                    node_id = entry['node_id']
+                    logs.setdefault(node_id, []).append(IncompleteInvocationLogData(
+                        entry['id'],
+                        entry['worker_id'],
+                        entry['runtime'],  # TODO: this should be set to active run time if invocation is running
+                        InvocationState(entry['state']),
+                        entry['return_code']
+                    ))
+            return logs
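The returned mapping groups invocation metadata by node; constructor arguments follow the positional order used above (all values here are made up):

    {
        12: [IncompleteInvocationLogData(101, 3, 12.5, InvocationState.FINISHED, 0)],
        15: [IncompleteInvocationLogData(102, 3, None, InvocationState.IN_PROGRESS, None)],
    }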
+    async def get_log(self, invocation_id: int) -> Optional[InvocationLogData]:
+        """
+        get log for the given invocation id
+
+        :param invocation_id:
+        :return: InvocationLogData for the invocation, or None if no such invocation exists
+        """
+        async with self.data_access.data_connection() as con:
+            con.row_factory = aiosqlite.Row
+            self.__logger.debug(f"fetching for {invocation_id}")
+            async with con.execute('SELECT "id", task_id, worker_id, node_id, state, return_code, log_external, runtime, stdout, stderr '
+                                   'FROM "invocations" WHERE "id" = ?',
+                                   (invocation_id,)) as cur:
+                rawentry = await cur.fetchone()  # should be exactly 1 or 0
+            if rawentry is None:
+                return None
+
+            entry: InvocationLogData = InvocationLogData(rawentry['id'],
+                                                         rawentry['worker_id'],
+                                                         rawentry['runtime'],
+                                                         rawentry['task_id'],
+                                                         rawentry['node_id'],
+                                                         InvocationState(rawentry['state']),
+                                                         rawentry['return_code'],
+                                                         rawentry['stdout'] or '',
+                                                         rawentry['stderr'] or '')
+            if entry.invocation_state == InvocationState.IN_PROGRESS:
+                async with con.execute('SELECT last_address FROM workers WHERE "id" = ?', (entry.worker_id,)) as worcur:
+                    workrow = await worcur.fetchone()
+                if workrow is None:
+                    self.__logger.error('Worker not found during log fetch! this is not supposed to happen! Database inconsistent?')
+                else:
+                    try:
+                        with WorkerControlClient.get_worker_control_client(AddressChain(workrow['last_address']), self.message_processor()) as client:  # type: WorkerControlClient
+                            stdout, stderr = await client.get_log(invocation_id)
+                        if not self.__use_external_log:
+                            await con.execute('UPDATE "invocations" SET stdout = ?, stderr = ? WHERE "id" = ?',  # TODO: is this really needed? if it's never really read
+                                              (stdout, stderr, invocation_id))
+                            await con.commit()
+                        # TODO: maybe add else case? save partial log to file?
+                    except ConnectionError:
+                        self.__logger.warning('could not connect to worker to get freshest logs')
+                    else:
+                        entry.stdout = stdout
+                        entry.stderr = stderr
+            elif entry.invocation_state == InvocationState.FINISHED and rawentry['log_external'] == 1:
+                logbasedir = self.__external_log_location / 'invocations' / f'{invocation_id}'
+                stdout_path = logbasedir / 'stdout.log'
+                stderr_path = logbasedir / 'stderr.log'
+                try:
+                    if stdout_path.exists():
+                        async with aiofiles.open(stdout_path, 'r') as fstdout:
+                            entry.stdout = await fstdout.read()
+                except IOError:
+                    self.__logger.exception(f'could not read external stdout log for {invocation_id}')
+                try:
+                    if stderr_path.exists():
+                        async with aiofiles.open(stderr_path, 'r') as fstderr:
+                            entry.stderr = await fstderr.read()
+                except IOError:
+                    self.__logger.exception(f'could not read external stderr log for {invocation_id}')
+
+        return entry
+
+    def server_address(self) -> Tuple[str, int]:
+        if self.__legacy_command_server_address is None:
+            raise RuntimeError('cannot get listening address of a non started server')
+        return self.__legacy_command_server_address
+
+    def server_message_address(self, to: AddressChain) -> AddressChain:
+        if self.__message_processor is None:
+            raise RuntimeError('cannot get listening address of a non started server')
+
+        return self.message_processor().listening_address(to)
+
+    def server_message_addresses(self) -> Tuple[AddressChain]:
+        if self.__message_processor is None:
+            raise RuntimeError('cannot get listening address of a non started server')
+
+        return self.message_processor().listening_addresses()

From 1c33b3edc78cb37ccb4c03f5f774c523f69ee9ea Mon Sep 17 00:00:00 2001
From: pedohorse <13556996+pedohorse@users.noreply.github.com>
Date: Mon, 28 Oct 2024 12:09:25 +0100
Subject: [PATCH 02/10] refactor: break BaseNode<->NodeUI<->Parameter<->ProcessingContext dep cycle

---
 src/lifeblood/basenode.py                      |  19 +-
 .../expression_locals_provider_base.py         |   6 +
 src/lifeblood/node_parameters.py               | 959 ++++++++++++++++++
 src/lifeblood/node_ui.py                       | 492 +++++++++
 .../node_ui_callback_receiver_base.py          |   3 +
 src/lifeblood/processingcontext.py             |  28 +-
 6 files changed, 1484 insertions(+), 23 deletions(-)
 create mode 100644 src/lifeblood/expression_locals_provider_base.py
 create mode 100644 src/lifeblood/node_parameters.py
 create mode 100644 src/lifeblood/node_ui.py
 create mode 100644 src/lifeblood/node_ui_callback_receiver_base.py

diff --git a/src/lifeblood/basenode.py b/src/lifeblood/basenode.py
index d717b581..8b16650f 100644
--- a/src/lifeblood/basenode.py
+++ b/src/lifeblood/basenode.py
@@ -1,23 +1,23 @@
 import asyncio
 from copy import deepcopy
 from typing import Dict, Optional, Any
+from logging import Logger
 from .nodethings import ProcessingResult
-from .uidata import NodeUi, ParameterNotFound, Parameter
+from .node_ui import NodeUi
+from .node_parameters import ParameterNotFound, Parameter
 from .processingcontext import ProcessingContext
 from .logging import get_logger
 from .plugin_info import PluginInfo, empty_plugin_info
 from .nodegraph_holder_base import NodeGraphHolderBase
+from .node_ui_callback_receiver_base import NodeUiCallbackReceiverBase

 # reexport
 from .nodethings import ProcessingError

-from typing import TYPE_CHECKING, Iterable
+from typing import Iterable

-if TYPE_CHECKING:
-    from logging import Logger

-class BaseNode:
+class BaseNode(NodeUiCallbackReceiverBase):
     _plugin_data = None  # To be set on module level by loader, set to empty_plugin_info by default

     @classmethod
     def description(cls) -> str:
         return 'this node type does not have a description'

     def __init__(self, name: str):
+        super().__init__()
         if BaseNode._plugin_data is None:
             BaseNode._plugin_data = empty_plugin_info
         self.__parent: NodeGraphHolderBase = None
@@ -172,7 +173,7 @@ def _process_task_wrapper(self, task_dict, node_config) -> ProcessingResult:
         # with self.get_ui().lock_interface_readonly():  # TODO: this is bad, RETHINK!
         # TODO: , in case threads do l1---r1 - release2 WILL leave lock in locked state forever, as it remembered it at l2
         # TODO: l2---r2
-        return self.process_task(ProcessingContext(self, task_dict, node_config))
+        return self.process_task(ProcessingContext(self.name(), self.label(), self.get_ui(), task_dict, node_config))

     def process_task(self, context: ProcessingContext) -> ProcessingResult:
         """
@@ -185,7 +186,7 @@ def process_task(self, context: ProcessingContext) -> ProcessingResult:
     def _postprocess_task_wrapper(self, task_dict, node_config) -> ProcessingResult:
         # with self.get_ui().lock_interface_readonly():  # TODO: read comment for _process_task_wrapper
-        return self.postprocess_task(ProcessingContext(self, task_dict, node_config))
+        return self.postprocess_task(ProcessingContext(self.name(), self.label(), self.get_ui(), task_dict, node_config))

     def postprocess_task(self, context: ProcessingContext) -> ProcessingResult:
         """
@@ -199,7 +200,7 @@ def postprocess_task(self, context: ProcessingContext) -> ProcessingResult:
     def copy_ui_to(self, to_node: "BaseNode"):
         newui = deepcopy(self._parameters)  # nodeUI redefines deepcopy to detach new copy from node
         to_node._parameters = newui
-        newui.attach_to_node(to_node)
+        newui.set_ui_change_callback_receiver(to_node)

     def apply_settings(self, settings: Dict[str, Dict[str, Any]]) -> None:
         with self.get_ui().postpone_ui_callbacks():
diff --git a/src/lifeblood/expression_locals_provider_base.py b/src/lifeblood/expression_locals_provider_base.py
new file mode 100644
index 00000000..c99fadca
--- /dev/null
+++ b/src/lifeblood/expression_locals_provider_base.py
@@ -0,0 +1,6 @@
+from typing import Any, Dict
+
+
+class ExpressionLocalsProviderBase:
+    def locals(self) -> Dict[str, Any]:
+        raise NotImplementedError()
diff --git a/src/lifeblood/node_parameters.py b/src/lifeblood/node_parameters.py
new file mode 100644
index 00000000..628da50b
--- /dev/null
+++ b/src/lifeblood/node_parameters.py
@@ -0,0 +1,959 @@
+from dataclasses import dataclass
+import os
+import pathlib
+import math
+from copy import deepcopy
+from .enums import NodeParameterType
+from .expression_locals_provider_base import ExpressionLocalsProviderBase
+import re
+
+from typing import Dict, Any, List, Set, Optional, Tuple, Union, Iterable, FrozenSet, Type
+
+
+class ParameterExpressionError(Exception):
+    def __init__(self, inner_exception):
+        self.__inner_exception = inner_exception
+
+    def __str__(self):
+        return f'ParameterExpressionError: {str(self.__inner_exception)}'
+
+    def inner_expection(self):
+        return self.__inner_exception
+
+
+class ParameterExpressionCastError(ParameterExpressionError):
+    """
+    represents error with type casting of the expression result
+    """
+    pass
+
+
+class LayoutError(RuntimeError):
+    pass
+
+
+class LayoutReadonlyError(LayoutError):
+    pass
self.__parent == item: + return + if self.__parent is not None: + assert self in self.__parent.__children + self.__parent._child_about_to_be_removed(self) + self.__parent.__children.remove(self) + self.__parent = item + if self.__parent is not None: + self.__parent.__children.add(self) + self.__parent._child_added(self) + + def _child_about_to_be_removed(self, child: "ParameterHierarchyItem"): + """ + callback for just before a child is removed + :param child: + :return: + """ + pass + + def _child_added(self, child: "ParameterHierarchyItem"): + """ + callback for just after child is added + :param child: + :return: + """ + pass + + def children(self) -> FrozenSet["ParameterHierarchyItem"]: + return frozenset(self.__children) + + def _children_definition_changed(self, children: Iterable["ParameterHierarchyItem"]): + if self.__parent is not None: + self.__parent._children_definition_changed([self]) + + def _children_appearance_changed(self, children: Iterable["ParameterHierarchyItem"]): + if self.__parent is not None: + self.__parent._children_appearance_changed([self]) + + def _children_value_changed(self, children: Iterable["ParameterHierarchyItem"]): + if self.__parent is not None: + self.__parent._children_value_changed([self]) + + def visible(self) -> bool: + return False + + +class ParameterHierarchyLeaf(ParameterHierarchyItem): + def _children_definition_changed(self, children: Iterable["ParameterHierarchyItem"]): + return + + def _children_value_changed(self, children: Iterable["ParameterHierarchyItem"]): + return + + def _children_appearance_changed(self, children: Iterable["ParameterHierarchyItem"]): + return + + def _child_added(self, child: "ParameterHierarchyItem"): + raise RuntimeError('cannot add children to ParameterHierarchyLeaf') + + def _child_about_to_be_removed(self, child: "ParameterHierarchyItem"): + raise RuntimeError('cannot remove children from ParameterHierarchyLeaf') + + +def evaluate_expression(expression, context: Optional[ExpressionLocalsProviderBase]): + try: + return eval(expression, + {'os': os, 're': re, 'pathlib': pathlib, 'Path': pathlib.Path, **{k: getattr(math, k) for k in dir(math) if not k.startswith('_')}}, + context.locals() if context is not None else {}) + except Exception as e: + raise ParameterExpressionError(e) from None + + +class Separator(ParameterHierarchyLeaf): + pass + + +class Parameter(ParameterHierarchyLeaf): + __re_expand_pattern = None + __re_escape_backticks_pattern = None + + class DontChange: + pass + + def __init__(self, param_name: str, param_label: Optional[str], param_type: NodeParameterType, param_val: Any, can_have_expression: bool = True, readonly: bool = False, default_value=None): + super(Parameter, self).__init__() + self.__name = param_name + self.__label = param_label + self.__type = param_type + self.__value = None + self.__menu_items: Optional[Dict[str, str]] = None + self.__menu_items_order: List[str] = [] + self.__vis_when = [] + self.__force_hidden = False + self.__is_readonly = False # set it False until the end of constructor + self.__locked = False # same as readonly, but is settable by user + + self.__expression = None + self.__can_have_expressions = can_have_expression + + if Parameter.__re_expand_pattern is None: + Parameter.__re_expand_pattern = re.compile(r'((? 
str: + return self.__name + + def _set_name(self, name: str): + """ + this should only be called by layout classes + """ + self.__name = name + if self.parent() is not None: + self.parent()._children_definition_changed([self]) + + def label(self) -> Optional[str]: + return self.__label + + def type(self) -> NodeParameterType: + return self.__type + + def unexpanded_value(self, context: Optional[ExpressionLocalsProviderBase] = None): # TODO: why context parameter here? + return self.__value + + def default_value(self): + """ + note that this value will be unexpanded + + :return: + """ + return self.__default_value + + def value(self, context: Optional[ExpressionLocalsProviderBase] = None) -> Any: + """ + returns value of this parameter + :param context: optional dict like locals, for expression evaluations + """ + + if self.__expression is not None: + result = evaluate_expression(self.__expression, context) + # check type and cast + try: + if self.__type == NodeParameterType.INT: + result = int(result) + elif self.__type == NodeParameterType.FLOAT: + result = float(result) + elif self.__type == NodeParameterType.STRING and not isinstance(result, str): + result = str(result) + elif self.__type == NodeParameterType.BOOL: + result = bool(result) + except ValueError: + raise ParameterExpressionCastError(f'could not cast {result} to {self.__type.name}') from None + # check limits + if self.__type in (NodeParameterType.INT, NodeParameterType.FLOAT): + if self.__hard_borders[0] is not None and result < self.__hard_borders[0]: + result = self.__hard_borders[0] + if self.__hard_borders[1] is not None and result > self.__hard_borders[1]: + result = self.__hard_borders[1] + return result + + if self.__type != NodeParameterType.STRING: + return self.__value + + # for string parameters we expand expressions in ``, kinda like bash + parts = self.__re_expand_pattern.split(self.__value) + for i, part in enumerate(parts): + if part.startswith('`') and part.endswith('`'): # expression + parts[i] = str(evaluate_expression(self.__re_escape_backticks_pattern.sub('`', part[1:-1]), context)) + else: + parts[i] = self.__re_escape_backticks_pattern.sub('`', part) + return ''.join(parts) + # return self.__re_expand_pattern.sub(lambda m: str(evaluate_expression(m.group(1), context)), self.__value) + + def set_slider_visualization(self, value_min=DontChange, value_max=DontChange): # type: (Union[int, float], Union[int, float]) -> Parameter + """ + set a visual slider's minimum and maximum + this does nothing to the parameter itself, and it's up to parameter renderer to interpret this data + + :return: self to be chained + """ + if self.__type not in (NodeParameterType.INT, NodeParameterType.FLOAT): + raise ParameterDefinitionError('cannot set limits for parameters of types other than INT and FLOAT') + + if self.__type == NodeParameterType.INT: + value_min = int(value_min) + elif self.__type == NodeParameterType.FLOAT: + value_min = float(value_min) + + if self.__type == NodeParameterType.INT: + value_max = int(value_max) + elif self.__type == NodeParameterType.FLOAT: + value_max = float(value_max) + + self.__display_borders = (value_min, value_max) + return self + + def set_value_limits(self, value_min=DontChange, value_max=DontChange): # type: (Union[int, float, None, Type[DontChange]], Union[int, float, None, Type[DontChange]]) -> Parameter + """ + set minimum and maximum values that parameter will enforce + None means no limit (unset limit) + + :return: self to be chained + """ + if self.__type not in 
(NodeParameterType.INT, NodeParameterType.FLOAT): + raise ParameterDefinitionError('cannot set limits for parameters of types other than INT and FLOAT') + if value_min == self.DontChange: + value_min = self.__hard_borders[0] + elif value_min is not None: + if self.__type == NodeParameterType.INT: + value_min = int(value_min) + elif self.__type == NodeParameterType.FLOAT: + value_min = float(value_min) + if value_max == self.DontChange: + value_max = self.__hard_borders[1] + elif value_max is not None: + if self.__type == NodeParameterType.INT: + value_max = int(value_max) + elif self.__type == NodeParameterType.FLOAT: + value_max = float(value_max) + assert value_min != self.DontChange + assert value_max != self.DontChange + + self.__hard_borders = (value_min, value_max) + if value_min is not None and self.__value < value_min: + self.__value = value_min + if value_max is not None and self.__value > value_max: + self.__value = value_max + return self + + def set_text_multiline(self, syntax_hint: Optional[str] = None): + if self.__type != NodeParameterType.STRING: + raise ParameterDefinitionError('multiline can be only set for string parameters') + self.__string_multiline = True + self.__string_multiline_syntax_hint = syntax_hint + return self + + def is_text_multiline(self): + return self.__string_multiline + + def syntax_hint(self) -> Optional[str]: + """ + may hint an arbitrary string hint to the renderer + it's up to renderer to decide what to do. + common conception is to use language name lowercase, like: python + None means no hint + """ + return self.__string_multiline_syntax_hint + + def display_value_limits(self) -> Tuple[Union[int, float, None], Union[int, float, None]]: + """ + returns a tuple of limits for display purposes. + parameter itself ignores this totally. + it's up to parameter renderer to interpret this info + """ + return self.__display_borders + + def value_limits(self) -> Tuple[Union[int, float, None], Union[int, float, None]]: + """ + returns a tuple of hard limits. 
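+        for instance, a rough sketch of the intended behaviour (parameter names here are purely illustrative):
+
+            p = Parameter('samples', 'Samples', NodeParameterType.INT, 4)
+            p.set_value_limits(1, 128)            # hard limits, enforced on every set_value()
+            p.set_slider_visualization(1, 32)     # display hint only, nothing is clamped
+            p.set_value(999)
+            p.value()          # -> 128, clamped to the hard maximum
+            p.value_limits()   # -> (1, 128)
+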
+ these limits are enforced by the parameter itself + """ + return self.__hard_borders + + def is_readonly(self): + return self.__is_readonly + + def is_locked(self): + return self.__locked + + def set_locked(self, locked: bool): + if locked == self.__locked: + return + self.__locked = locked + if self.parent() is not None: + self.parent()._children_definition_changed([self]) + + def set_value(self, value: Any): + if self.__is_readonly: + raise ParameterReadonly() + if self.__locked: + raise ParameterLocked() + if self.__type == NodeParameterType.FLOAT: + param_value = float(value) + if self.__hard_borders[0] is not None: + param_value = max(param_value, self.__hard_borders[0]) + if self.__hard_borders[1] is not None: + param_value = min(param_value, self.__hard_borders[1]) + elif self.__type == NodeParameterType.INT: + param_value = int(value) + if self.__hard_borders[0] is not None: + param_value = max(param_value, self.__hard_borders[0]) + if self.__hard_borders[1] is not None: + param_value = min(param_value, self.__hard_borders[1]) + elif self.__type == NodeParameterType.BOOL: + param_value = bool(value) + elif self.__type == NodeParameterType.STRING: + param_value = str(value) + else: + raise NotImplementedError() + self.__value = param_value + for other_param in self.__params_referencing_me: + other_param._referencing_param_value_changed(self) + + if self.parent() is not None: + self.parent()._children_value_changed([self]) + + def can_have_expressions(self): + return self.__can_have_expressions + + def has_expression(self): + return self.__expression is not None + + def expression(self): + return self.__expression + + def set_expression(self, expression: Union[str, None]): + """ + sets or removes expression from a parameter + :param expression: either expression code or None means removing expression + :return: + """ + if self.__is_readonly: + raise ParameterReadonly() + if self.__locked: + raise ParameterLocked() + if not self.__can_have_expressions: + raise ParameterCannotHaveExpressions() + if expression != self.__expression: + self.__expression = expression + if self.parent() is not None: + self.parent()._children_definition_changed([self]) + + def remove_expression(self): + self.set_expression(None) + + @classmethod + def python_from_expandable_string(cls, expandable_string, context: Optional[ExpressionLocalsProviderBase] = None) -> str: + """ + given string value that may contain backtick expressions return python equivalent + """ + expression_parts = [] + parts = cls.__re_expand_pattern.split(expandable_string) + for i, part in enumerate(parts): + if part.startswith('`') and part.endswith('`'): # expression + maybe_expr = f'({cls.__re_escape_backticks_pattern.sub("`", part[1:-1])})' + try: + val = evaluate_expression(maybe_expr, context) + if not isinstance(val, str): + maybe_expr = f'str{maybe_expr}' # note, maybe_expr is already enclosed in parentheses + except ParameterExpressionError as e: + # we just catch syntax errors, other runtime errors are allowed as real context is set per task + if isinstance(e.inner_expection(), SyntaxError): + maybe_expr = '""' + expression_parts.append(maybe_expr) + else: + val = cls.__re_escape_backticks_pattern.sub('`', part) + if not val: + continue + expression_parts.append(repr(val)) + + return ' + '.join(expression_parts) + + def _referencing_param_value_changed(self, other_parameter): + """ + when a parameter that we are referencing changes - it will report here + :param other_parameter: + """ + # TODO: this now only works with 
referencing param in visibility condition
+        # TODO: but we want general references, including from parameter expressions
+        # TODO: OOOORR will i need references for expressions at all?
+        # TODO: references between nodes bring SOOOO much pain when serializing them separately
+        if self.__vis_when:
+            self.__vis_cache = None
+        if self.parent() is not None and isinstance(self.parent(), ParametersLayoutBase):
+            self.parent()._children_appearance_changed([self])
+
+    def set_hidden(self, hidden):
+        self.__force_hidden = hidden
+
+    def visible(self) -> bool:
+        if self.__force_hidden:
+            return False
+        if self.__vis_cache is not None:
+            return self.__vis_cache
+        if self.__vis_when:
+            for other_param, op, value in self.__vis_when:
+                if op == '==' and other_param.value() != value \
+                        or op == '!=' and other_param.value() == value \
+                        or op == '>' and other_param.value() <= value \
+                        or op == '>=' and other_param.value() < value \
+                        or op == '<' and other_param.value() >= value \
+                        or op == '<=' and other_param.value() > value \
+                        or op == 'in' and other_param.value() not in value \
+                        or op == 'not in' and other_param.value() in value:
+                    self.__vis_cache = False
+                    return False
+        self.__vis_cache = True
+        return True
+
+    def _add_referencing_me(self, other_parameter: "Parameter"):
+        """
+        other_parameter MUST belong to the same node to avoid cross-node references
+        :param other_parameter:
+        :return:
+        """
+        assert self.has_same_parent(other_parameter), 'references MUST belong to the same node'
+        self.__params_referencing_me.add(other_parameter)
+
+    def _remove_referencing_me(self, other_parameter: "Parameter"):
+        assert other_parameter in self.__params_referencing_me
+        self.__params_referencing_me.remove(other_parameter)
+
+    def references(self) -> Tuple["Parameter", ...]:
+        """
+        returns a tuple of parameters referenced by this parameter's definition
+        static/dynamic references from expressions ARE NOT INCLUDED - they are not part of the parameter's DEFINITION
+        currently the only thing that can be a reference is a parameter from visibility conditions
+        """
+        return tuple(x[0] for x in self.__vis_when)
+
+    def visibility_conditions(self) -> Tuple[Tuple["Parameter", str, Union[bool, int, float, str, tuple]], ...]:
+        return tuple(self.__vis_when)
+
+    def append_visibility_condition(self, other_param: "Parameter", condition: str, value: Union[bool, int, float, str, tuple]) -> "Parameter":
+        """
+        condition can currently only be one of the simple comparisons listed in allowed_conditions below
+        :param other_param:
+        :param condition:
+        :param value:
+        :return: self to allow easy chaining
+        """
+        allowed_conditions = ('==', '!=', '>=', '<=', '<', '>', 'in', 'not in')
+        if condition not in allowed_conditions:
+            raise ParameterDefinitionError(f'condition must be one of: {", ".join(x for x in allowed_conditions)}')
+        if condition in ('in', 'not in') and not isinstance(value, tuple):
+            raise ParameterDefinitionError('for in/not in conditions value must be a tuple of possible values')
+        elif condition not in ('in', 'not in') and isinstance(value, tuple):
+            raise ParameterDefinitionError('value can be tuple only for in/not in conditions')
+
+        otype = other_param.type()
+        if otype == NodeParameterType.INT:
+            if not isinstance(value, tuple):
+                value = int(value)
+        elif otype == NodeParameterType.BOOL:
+            if not isinstance(value, tuple):
+                value = bool(value)
+        elif otype == NodeParameterType.FLOAT:
+            if not isinstance(value, tuple):
+                value = float(value)
+        elif otype == NodeParameterType.STRING:
+            if not isinstance(value, tuple):
+                value = str(value)
+        else:  # for future
+            raise ParameterDefinitionError(f'cannot add visibility condition check based on this type of parameters: {otype}')
+        self.__vis_when.append((other_param, condition, value))
+        other_param._add_referencing_me(self)
+        self.__vis_cache = None
+
+        self.parent()._children_definition_changed([self])
+        return self
+
+    def add_menu(self, menu_items_pairs) -> "Parameter":
+        """
+        adds a UI menu to this parameter
+        :param menu_items_pairs: sequence of (label, value) pairs for the parameter menu. type of each value MUST match this parameter's type. type of each label MUST be string
+        :return: self to allow easy chaining
+        """
+        # sanity check and regroup
+        my_type = self.type()
+        menu_items = {}
+        menu_order = []
+        for key, value in menu_items_pairs:
+            menu_items[key] = value
+            menu_order.append(key)
+            if not isinstance(key, str):
+                raise ParameterDefinitionError('menu label type must be string')
+            if my_type == NodeParameterType.INT and not isinstance(value, int):
+                raise ParameterDefinitionError(f'wrong menu value for int parameter "{self.name()}"')
+            elif my_type == NodeParameterType.BOOL and not isinstance(value, bool):
+                raise ParameterDefinitionError(f'wrong menu value for bool parameter "{self.name()}"')
+            elif my_type == NodeParameterType.FLOAT and not isinstance(value, float):
+                raise ParameterDefinitionError(f'wrong menu value for float parameter "{self.name()}"')
+            elif my_type == NodeParameterType.STRING and not isinstance(value, str):
+                raise ParameterDefinitionError(f'wrong menu value for string parameter "{self.name()}"')
+
+        self.__menu_items = menu_items
+        self.__menu_items_order = menu_order
+        self.parent()._children_definition_changed([self])
+        return self
+
+    def has_menu(self):
+        return self.__menu_items is not None
+
+    def get_menu_items(self):
+        return self.__menu_items_order, self.__menu_items
+
+    def has_same_parent(self, other_parameter: "Parameter") -> bool:
+        """
+        checks if self and other_parameter share a common ancestor somewhere up the hierarchy
+        """
+        my_ancestry_line = set()
+        ancestor = self
+        while ancestor is not None:
+            my_ancestry_line.add(ancestor)
+            ancestor = ancestor.parent()
+
+        ancestor = other_parameter
+        while ancestor is not None:
+            if ancestor in my_ancestry_line:
+                return True
+            ancestor = ancestor.parent()
+        return False
+
+    def __setstate__(self, state):
+        """
+        overridden for easier parameter class iterations during active development.
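+        the upgrade path, spelled out as a sketch (this restates the method body below, no extra behaviour is implied):
+
+            state = {...}                                            # pickled by an OLDER revision of this class
+            self.__init__('', '', NodeParameterType.INT, 0, False)   # first create every current attribute with a default
+            self.__dict__.update(state)                              # then overwrite the attributes the old pickle knew about
+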
+ otherwise all node ui data should be recreated from zero in DB every time a change is made + """ + # this init here only to init new shit when unpickling old parameters without resetting DB all the times + self.__init__('', '', NodeParameterType.INT, 0, False) + self.__dict__.update(state) + + +class ParameterError(RuntimeError): + pass + + +class ParameterDefinitionError(ParameterError): + pass + + +class ParameterNotFound(ParameterError): + pass + + +class ParameterNameCollisionError(ParameterError): + pass + + +class ParameterReadonly(ParameterError): + pass + + +class ParameterLocked(ParameterError): + pass + + +class ParameterCannotHaveExpressions(ParameterError): + pass + + +class ParametersLayoutBase(ParameterHierarchyItem): + def __init__(self): + super(ParametersLayoutBase, self).__init__() + self.__parameters: Dict[str, Parameter] = {} # just for quicker access + self.__layouts: Set[ParametersLayoutBase] = set() + self.__block_ui_callbacks = False + + def initializing_interface_lock(self): + return self.block_ui_callbacks() + + def block_ui_callbacks(self): + class _iiLock: + def __init__(self, lockable): + self.__nui = lockable + self.__prev_state = False + + def __enter__(self): + self.__prev_state = self.__nui._ParametersLayoutBase__block_ui_callbacks + self.__nui._ParametersLayoutBase__block_ui_callbacks = True + + def __exit__(self, exc_type, exc_val, exc_tb): + self.__nui._ParametersLayoutBase__block_ui_callbacks = self.__prev_state + + return _iiLock(self) + + def _is_initialize_lock_set(self): + return self.__block_ui_callbacks + + def add_parameter(self, new_parameter: Parameter): + self.add_generic_leaf(new_parameter) + + def add_generic_leaf(self, item: ParameterHierarchyLeaf): + if not self._is_initialize_lock_set(): + raise LayoutError('initializing interface not inside initializing_interface_lock') + item.set_parent(self) + + def add_layout(self, new_layout: "ParametersLayoutBase"): + if not self._is_initialize_lock_set(): + raise LayoutError('initializing interface not inside initializing_interface_lock') + new_layout.set_parent(self) + + def items(self, recursive=False) -> Iterable["ParameterHierarchyItem"]: + for child in self.children(): + yield child + if not recursive: + continue + elif isinstance(child, ParametersLayoutBase): + for child_param in child.parameters(recursive=recursive): + yield child_param + + def parameters(self, recursive=False) -> Iterable[Parameter]: + for item in self.items(recursive=recursive): + if isinstance(item, Parameter): + yield item + + def parameter(self, name: str) -> Parameter: + if name in self.__parameters: + return self.__parameters[name] + for layout in self.__layouts: + try: + return layout.parameter(name) + except ParameterNotFound: + continue + raise ParameterNotFound(f'parameter "{name}" not found in layout hierarchy') + + def visible(self) -> bool: + return len(self.children()) != 0 and any(x.visible() for x in self.items()) + + def _child_added(self, child: "ParameterHierarchyItem"): + super(ParametersLayoutBase, self)._child_added(child) + if isinstance(child, Parameter): + # check global parameter name uniqueness + rootparent = self + while isinstance(rootparent.parent(), ParametersLayoutBase): + rootparent = rootparent.parent() + if child.name() in (x.name() for x in rootparent.parameters(recursive=True) if x != child): + raise ParameterNameCollisionError('cannot add parameters with the same name to the same layout hierarchy') + self.__parameters[child.name()] = child + elif isinstance(child, 
ParametersLayoutBase): + self.__layouts.add(child) + # check global parameter name uniqueness + rootparent = self + while isinstance(rootparent.parent(), ParametersLayoutBase): + rootparent = rootparent.parent() + new_params = list(child.parameters(recursive=True)) + existing_params = set(x.name() for x in rootparent.parameters(recursive=True) if x not in new_params) + for new_param in new_params: + if new_param.name() in existing_params: + raise ParameterNameCollisionError('cannot add parameters with the same name to the same layout hierarchy') + + def _child_about_to_be_removed(self, child: "ParameterHierarchyItem"): + if isinstance(child, Parameter): + del self.__parameters[child.name()] + elif isinstance(child, ParametersLayoutBase): + self.__layouts.remove(child) + super(ParametersLayoutBase, self)._child_about_to_be_removed(child) + + def _children_definition_changed(self, changed_children: Iterable["ParameterHierarchyItem"]): + """ + :param changed_children: + :return: + """ + super(ParametersLayoutBase, self)._children_definition_changed(changed_children) + # check self.__parameters consistency + reversed_parameters: Dict[Parameter, str] = {v: k for k, v in self.__parameters.items()} + for child in changed_children: + if not isinstance(child, Parameter): + continue + if child in reversed_parameters: + del self.__parameters[reversed_parameters[child]] + self.__parameters[child.name()] = child + + def _children_value_changed(self, children: Iterable["ParameterHierarchyItem"]): + """ + :param children: + :return: + """ + super(ParametersLayoutBase, self)._children_value_changed(children) + + def _children_appearance_changed(self, children: Iterable["ParameterHierarchyItem"]): + super(ParametersLayoutBase, self)._children_appearance_changed(children) + + def relative_size_for_child(self, child: ParameterHierarchyItem) -> Tuple[float, float]: + """ + get relative size of a child in this layout + the exact interpretation of size is up to subclass to decide + :param child: + :return: + """ + raise NotImplementedError() + + +class OrderedParametersLayout(ParametersLayoutBase): + def __init__(self): + super(OrderedParametersLayout, self).__init__() + self.__parameter_order: List[ParameterHierarchyItem] = [] + + def _child_added(self, child: "ParameterHierarchyItem"): + super(OrderedParametersLayout, self)._child_added(child) + self.__parameter_order.append(child) + + def _child_about_to_be_removed(self, child: "ParameterHierarchyItem"): + self.__parameter_order.remove(child) + super(OrderedParametersLayout, self)._child_about_to_be_removed(child) + + def items(self, recursive=False): + """ + unlike base method, we need to return parameters in order + :param recursive: + :return: + """ + for child in self.__parameter_order: + yield child + if not recursive: + continue + elif isinstance(child, ParametersLayoutBase): + for child_param in child.items(recursive=recursive): + yield child_param + + def relative_size_for_child(self, child: ParameterHierarchyItem) -> Tuple[float, float]: + """ + get relative size of a child in this layout + the exact interpretation of size is up to subclass to decide + :param child: + :return: + """ + assert child in self.children() + return 1.0, 1.0 + + +class VerticalParametersLayout(OrderedParametersLayout): + """ + simple vertical parameter layout. 
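+
+    a minimal composition sketch (hypothetical parameter names):
+
+        layout = VerticalParametersLayout()
+        with layout.initializing_interface_lock():
+            layout.add_parameter(Parameter('frames', 'Frames', NodeParameterType.INT, 1))
+            layout.add_parameter(Parameter('out_path', 'Output Path', NodeParameterType.STRING, ''))
+        layout.parameter('frames').set_value(101)  # children are addressable by name once added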
+    """
+    pass
+
+class CollapsableVerticalGroup(VerticalParametersLayout):
+    """
+    a vertical parameter layout to be drawn as a collapsable block
+    """
+
+    def __init__(self, group_name, group_label):
+        super(CollapsableVerticalGroup, self).__init__()
+
+        # for now it's here just to ensure name uniqueness. in future - maybe store collapsed state
+        self.__unused_param = Parameter(group_name, group_name, NodeParameterType.BOOL, True)
+
+        self.__group_name = group_name
+        self.__group_label = group_label
+
+    def is_collapsed(self):
+        return True
+
+    def name(self):
+        return self.__group_name
+
+    def label(self):
+        return self.__group_label
+
+class OneLineParametersLayout(OrderedParametersLayout):
+    """
+    horizontal parameter layout.
+    unlike vertical, this one has to keep track of what portion of the line its parameters take
+    parameters of this group should be rendered in one line
+    """
+
+    def __init__(self):
+        super(OneLineParametersLayout, self).__init__()
+        self.__hsizes = {}
+
+    def _children_appearance_changed(self, children: Iterable["ParameterHierarchyItem"]):
+        super(ParametersLayoutBase, self)._children_appearance_changed(children)
+        self.__hsizes = {}
+
+    def _children_definition_changed(self, children: Iterable["ParameterHierarchyItem"]):
+        super(OneLineParametersLayout, self)._children_definition_changed(children)
+        self.__hsizes = {}
+
+    def relative_size_for_child(self, child: ParameterHierarchyItem) -> Tuple[float, float]:
+        assert child in self.children()
+        if child not in self.__hsizes:
+            self._update_hsizes()
+        assert child in self.__hsizes
+        return self.__hsizes[child], 1.0
+
+    def _update_hsizes(self):
+        self.__hsizes = {}
+        totalitems = 0
+        for item in self.items():
+            if item.visible():
+                totalitems += 1
+        if totalitems == 0:
+            uniform_size = 1.0
+        else:
+            uniform_size = 1.0 / float(totalitems)
+        for item in self.items():
+            self.__hsizes[item] = uniform_size
+
+class MultiGroupLayout(OrderedParametersLayout):
+    """
+    this group can dynamically spawn more parameters according to its template
+    spawning more parameters does NOT count as a definition change
+    """
+
+    def __init__(self, name, label=None):
+        super(MultiGroupLayout, self).__init__()
+        self.__template: Union[ParametersLayoutBase, Parameter, None] = None
+        if label is None:
+            label = 'count'
+        self.__count_param = Parameter(name, label, NodeParameterType.INT, 0, can_have_expression=False)
+        self.__count_param.set_parent(self)
+        self.__last_count = 0
+        self.__nested_indices = []
+
+    def nested_indices(self):
+        """
+        if a multiparam is inside other multiparams - those multiparams should add their indices
+        to this one, so that this multiparam will be able to uniquely and predictably name its parameters
+        """
+        return tuple(self.__nested_indices)
+
+    def __append_nested_index(self, index: int):
+        """
+        this should be called only when a multiparam is instanced by another multiparam
+        """
+        self.__nested_indices.append(index)
+
+    def set_spawning_template(self, layout: ParametersLayoutBase):
+        self.__template = deepcopy(layout)
+
+    def add_layout(self, new_layout: "ParametersLayoutBase"):
+        """
+        this function is unavailable cuz of the nature of this layout
+        """
+        raise LayoutError('NO')
+
+    def add_parameter(self, new_parameter: Parameter):
+        """
+        this function is unavailable cuz of the nature of this layout
+        """
+        raise LayoutError('NO')
+
+    def add_template_instance(self):
+        self.__count_param.set_value(self.__count_param.value() + 1)
+
+    def _children_value_changed(self, children:
Iterable["ParameterHierarchyItem"]): + + for child in children: + if child == self.__count_param: + break + else: + super(MultiGroupLayout, self)._children_value_changed(children) + return + if self.__count_param.value() < 0: + self.__count_param.set_value(0) + super(MultiGroupLayout, self)._children_value_changed(children) + return + + new_count = self.__count_param.value() + if self.__last_count < new_count: + if self.__template is None: + raise LayoutError('template is not set') + for _ in range(new_count - self.__last_count): + # note: the check below is good, but it's not needed currently, cuz visibility condition on append checks common parent + # and nodes not from template do not share parents with template, so that prevents external references + for param in self.__template.parameters(recursive=True): + # sanity check - for now we only support references within the same template block only + for ref_param in param.references(): + if not ref_param.has_same_parent(param): + raise ParameterDefinitionError('Parameters within MultiGroupLayout\'s template currently cannot reference outer parameters') + ## + new_layout = deepcopy(self.__template) + i = len(self.children()) - 1 + for param in new_layout.parameters(recursive=True): + param._set_name(param.name() + '_' + '.'.join(str(x) for x in (*self.nested_indices(), i))) + parent = param.parent() + if isinstance(parent, MultiGroupLayout): + for idx in self.nested_indices(): + parent.__append_nested_index(idx) + parent.__append_nested_index(i) + new_layout.set_parent(self) + elif self.__last_count > self.__count_param.value(): + for _ in range(self.__last_count - new_count): + instances = list(self.items(recursive=False)) + assert len(instances) > 1 + instances[-1].set_parent(None) + self.__last_count = new_count + super(MultiGroupLayout, self)._children_value_changed(children) + + def _child_added(self, child: "ParameterHierarchyItem"): + super(MultiGroupLayout, self)._child_added(child) + + def _child_about_to_be_removed(self, child: "ParameterHierarchyItem"): + super(MultiGroupLayout, self)._child_about_to_be_removed(child) + + +@dataclass +class ParameterFullValue: + unexpanded_value: Union[int, float, str, bool] + expression: Optional[str] + diff --git a/src/lifeblood/node_ui.py b/src/lifeblood/node_ui.py new file mode 100644 index 00000000..732eca95 --- /dev/null +++ b/src/lifeblood/node_ui.py @@ -0,0 +1,492 @@ +import asyncio +import pickle +from copy import deepcopy +from .enums import NodeParameterType +from .node_visualization_classes import NodeColorScheme +from .node_ui_callback_receiver_base import NodeUiCallbackReceiverBase +from .node_parameters import CollapsableVerticalGroup, LayoutError, LayoutReadonlyError, MultiGroupLayout, OneLineParametersLayout, Parameter, ParameterError, ParameterFullValue, ParameterHierarchyItem, ParameterLocked, ParameterReadonly, ParametersLayoutBase, Separator, VerticalParametersLayout +from .logging import get_logger + +from typing import Dict, Any, Optional, Tuple, Iterable, Callable + + +class NodeUiError(RuntimeError): + pass + + +class NodeUiDefinitionError(RuntimeError): + pass + + +class _SpecialOutputCountChangingLayout(VerticalParametersLayout): + def __init__(self, nodeui: "NodeUi", parameter_name, parameter_label): + super(_SpecialOutputCountChangingLayout, self).__init__() + self.__my_nodeui = nodeui + newparam = Parameter(parameter_name, parameter_label, NodeParameterType.INT, 2, can_have_expression=False) + newparam.set_value_limits(2) + with self.initializing_interface_lock(): + 
self.add_parameter(newparam) + + def add_layout(self, new_layout: "ParametersLayoutBase"): + """ + this function is unavailable cuz of the nature of this layout + """ + raise LayoutError('NO') + + def add_parameter(self, new_parameter: Parameter): + """ + this function is unavailable cuz of the nature of this layout + """ + if len(list(self.parameters())) > 0: + raise LayoutError('NO') + super(_SpecialOutputCountChangingLayout, self).add_parameter(new_parameter) + + def _children_value_changed(self, children: Iterable["ParameterHierarchyItem"]): + # we expect this special layout to have only one single specific child + child = None + for child in children: + break + if child is None: + return + assert isinstance(child, Parameter) + new_num_outputs = child.value() + num_outputs = len(self.__my_nodeui.outputs_names()) + if num_outputs == new_num_outputs: + return + + if num_outputs < new_num_outputs: + for i in range(num_outputs, new_num_outputs): + self.__my_nodeui._add_output_unsafe(f'output{i}') + else: # num_outputs > new_num_outputs + for _ in range(new_num_outputs, num_outputs): + self.__my_nodeui._remove_last_output_unsafe() + self.__my_nodeui._outputs_definition_changed() + + +class NodeUi(ParameterHierarchyItem): + def __init__(self, change_callback_receiver: NodeUiCallbackReceiverBase): + super(NodeUi, self).__init__() + self.__logger = get_logger('scheduler.nodeUI') + self.__parameter_layout = VerticalParametersLayout() + self.__parameter_layout.set_parent(self) + self.__change_callback_receiver: Optional[NodeUiCallbackReceiverBase] = change_callback_receiver + self.__block_ui_callbacks = False + self.__lock_ui_readonly = False + self.__postpone_ui_callbacks = False + self.__postponed_callbacks = None + self.__inputs_names = ('main',) + self.__outputs_names = ('main',) + + self.__groups_stack = [] + + self.__have_output_parameter_set: bool = False + + # default colorscheme + self.__color_scheme = NodeColorScheme() + self.__color_scheme.set_main_color(0.1882, 0.2510, 0.1882) # dark-greenish + + def set_ui_change_callback_receiver(self, callback_receiver: NodeUiCallbackReceiverBase): + self.__change_callback_receiver = callback_receiver + + def color_scheme(self): + return self.__color_scheme + + def main_parameter_layout(self): + return self.__parameter_layout + + def parent(self) -> Optional["ParameterHierarchyItem"]: + return None + + def set_parent(self, item: Optional["ParameterHierarchyItem"]): + if item is not None: + raise RuntimeError('NodeUi class is supposed to be tree root') + + def initializing_interface_lock(self): + return self.block_ui_callbacks() + + def block_ui_callbacks(self): + class _iiLock: + def __init__(self, lockable): + self.__nui = lockable + self.__prev_state = None + + def __enter__(self): + self.__prev_state = self.__nui._NodeUi__block_ui_callbacks + self.__nui._NodeUi__block_ui_callbacks = True + + def __exit__(self, exc_type, exc_val, exc_tb): + self.__nui._NodeUi__block_ui_callbacks = self.__prev_state + + if self.__lock_ui_readonly: + raise LayoutReadonlyError() + return _iiLock(self) + + def lock_interface_readonly(self): + raise NotImplementedError("read trello task, read TODO. 
this does NOT work multithreaded, leads to permalocks, needs rethinking")
+        class _roLock:
+            def __init__(self, lockable):
+                self.__nui = lockable
+                self.__prev_state = None
+
+            def __enter__(self):
+                self.__prev_state = self.__nui._NodeUi__lock_ui_readonly
+                self.__nui._NodeUi__lock_ui_readonly = True
+
+            def __exit__(self, exc_type, exc_val, exc_tb):
+                self.__nui._NodeUi__lock_ui_readonly = self.__prev_state
+
+        return _roLock(self)
+
+    def postpone_ui_callbacks(self):
+        """
+        use this in a with-statement
+        for mass change of parameters it may be more efficient to perform changes in batches
+        """
+        class _iiPostpone:
+            def __init__(self, nodeui):
+                self.__nui = nodeui
+                self.__val = None
+
+            def __enter__(self):
+                if not self.__nui._NodeUi__postpone_ui_callbacks:
+                    assert self.__nui._NodeUi__postponed_callbacks is None
+                    self.__val = self.__nui._NodeUi__postpone_ui_callbacks
+                    self.__nui._NodeUi__postpone_ui_callbacks = True
+                # otherwise: already blocked - we are in a nested block, ignore
+
+            def __exit__(self, exc_type, exc_val, exc_tb):
+                if self.__val is None:
+                    return
+                assert not self.__val  # nested block should do nothing
+                self.__nui._NodeUi__postpone_ui_callbacks = self.__val
+                if self.__nui._NodeUi__postponed_callbacks is not None:
+                    self.__nui._NodeUi__ui_callback(self.__nui._NodeUi__postponed_callbacks)
+                    self.__nui._NodeUi__postponed_callbacks = None
+
+        return _iiPostpone(self)
+
+    class _slwrapper:
+        def __init__(self, ui: "NodeUi", layout_creator, layout_creator_kwargs=None):
+            self.__ui = ui
+            self.__layout_creator = layout_creator
+            self.__layout_creator_kwargs = layout_creator_kwargs or {}
+
+        def __enter__(self):
+            new_layout = self.__layout_creator(**self.__layout_creator_kwargs)
+            self.__ui._NodeUi__groups_stack.append(new_layout)
+            with self.__ui._NodeUi__parameter_layout.initializing_interface_lock():
+                self.__ui._NodeUi__parameter_layout.add_layout(new_layout)
+
+        def __exit__(self, exc_type, exc_val, exc_tb):
+            layout = self.__ui._NodeUi__groups_stack.pop()
+            self.__ui._add_layout(layout)
+
+    def parameters_on_same_line_block(self):
+        """
+        use it in a with statement
+        :return:
+        """
+        return self.parameter_layout_block(OneLineParametersLayout)
+
+    def parameter_layout_block(self, parameter_layout_producer: Callable[[], ParametersLayoutBase]):
+        """
+        arbitrary simple parameter layout block
+        use it in a with statement
+        :return:
+        """
+        if self.__lock_ui_readonly:
+            raise LayoutReadonlyError()
+        if not self.__block_ui_callbacks:
+            raise NodeUiDefinitionError('initializing NodeUi interface not inside initializing_interface_lock')
+        return NodeUi._slwrapper(self, parameter_layout_producer)
+
+    def add_parameter_to_control_output_count(self, parameter_name: str, parameter_label: str):
+        """
+        a very special function for a very special case when you want the number of outputs to be controlled
+        by a parameter
+
+        from now on output names will be: 'main', 'output1', 'output2', ...
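+
+        e.g., a sketch, assuming ui is a NodeUi being defined:
+
+            with ui.initializing_interface_lock():
+                count_param = ui.add_parameter_to_control_output_count('out count', 'outputs')
+            # ui.outputs_names() is now ('main', 'output1') since the parameter starts at 2;
+            # raising count_param to 3 appends 'output2', lowering it trims outputs from the end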
+ + :return: + """ + if self.__lock_ui_readonly: + raise LayoutReadonlyError() + if not self.__block_ui_callbacks: + raise NodeUiDefinitionError('initializing NodeUi interface not inside initializing_interface_lock') + if self.__have_output_parameter_set: + raise NodeUiDefinitionError('there can only be one parameter to control output count') + self.__have_output_parameter_set = True + self.__outputs_names = ('main', 'output1') + + with self.parameter_layout_block(lambda: _SpecialOutputCountChangingLayout(self, parameter_name, parameter_label)): + # no need to do anything, with block will add that layout to stack, and parameter is created in that layout's constructor + layout = self.current_layout() + # this layout should always have exactly one parameter + assert len(list(layout.parameters())) == 1, f'oh no, {len(list(layout.parameters()))}' + return layout.parameter(parameter_name) + + def multigroup_parameter_block(self, name: str, label: Optional[str] = None): + """ + use it in with statement + creates a block like multiparameter block in houdini + any parameters added will be actually added to template to be instanced later as needed + :return: + """ + class _slwrapper_multi: + def __init__(self, ui: "NodeUi", name: str, label: Optional[str] = None): + self.__ui = ui + self.__new_layout = None + self.__name = name + self.__label = label + + def __enter__(self): + self.__new_layout = VerticalParametersLayout() + self.__ui._NodeUi__groups_stack.append(self.__new_layout) + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + assert self.__ui._NodeUi__groups_stack.pop() == self.__new_layout + with self.__ui._NodeUi__parameter_layout.initializing_interface_lock(): + multi_layout = MultiGroupLayout(self.__name, self.__label) + with multi_layout.initializing_interface_lock(): + multi_layout.set_spawning_template(self.__new_layout) + self.__ui._add_layout(multi_layout) + + def multigroup(self) -> VerticalParametersLayout: + return self.__new_layout + + if self.__lock_ui_readonly: + raise LayoutReadonlyError() + if not self.__block_ui_callbacks: + raise NodeUiDefinitionError('initializing NodeUi interface not inside initializing_interface_lock') + return _slwrapper_multi(self, name, label) + + def current_layout(self): + """ + get current layout to which add_parameter would add parameter + this can be main nodeUI's layout, but can be something else, if we are in some with block, + like for ex: collapsable_group_block or parameters_on_same_line_block + + :return: + """ + if not self.__block_ui_callbacks: + raise NodeUiDefinitionError('initializing NodeUi interface not inside initializing_interface_lock') + layout = self.__parameter_layout + if len(self.__groups_stack) != 0: + layout = self.__groups_stack[-1] + return layout + + def collapsable_group_block(self, group_name: str, group_label: str = ''): + """ + use it in with statement + creates a visually distinct group of parameters that renderer should draw as a collapsable block + + :return: + """ + if self.__lock_ui_readonly: + raise LayoutReadonlyError() + if not self.__block_ui_callbacks: + raise NodeUiDefinitionError('initializing NodeUi interface not inside initializing_interface_lock') + return NodeUi._slwrapper(self, CollapsableVerticalGroup, {'group_name': group_name, 'group_label': group_label}) + + def _add_layout(self, new_layout): + if not self.__block_ui_callbacks: + raise NodeUiDefinitionError('initializing NodeUi interface not inside initializing_interface_lock') + layout = self.__parameter_layout + if 
len(self.__groups_stack) != 0:
+            layout = self.__groups_stack[-1]
+        with layout.initializing_interface_lock():
+            layout.add_layout(new_layout)
+
+    def add_parameter(self, param_name: str, param_label: Optional[str], param_type: NodeParameterType, param_val: Any, can_have_expressions: bool = True, readonly: bool = False) -> Parameter:
+        if self.__lock_ui_readonly:
+            raise LayoutReadonlyError()
+        if not self.__block_ui_callbacks:
+            raise NodeUiDefinitionError('initializing NodeUi interface not inside initializing_interface_lock')
+        layout = self.__parameter_layout
+        if len(self.__groups_stack) != 0:
+            layout = self.__groups_stack[-1]
+        with layout.initializing_interface_lock():
+            newparam = Parameter(param_name, param_label, param_type, param_val, can_have_expressions, readonly)
+            layout.add_parameter(newparam)
+        return newparam
+
+    def add_separator(self):
+        if self.__lock_ui_readonly:
+            raise LayoutReadonlyError()
+        if not self.__block_ui_callbacks:
+            raise NodeUiDefinitionError('initializing NodeUi interface not inside initializing_interface_lock')
+        layout = self.__parameter_layout
+        if len(self.__groups_stack) != 0:
+            layout = self.__groups_stack[-1]
+        with layout.initializing_interface_lock():
+            newsep = Separator()
+            layout.add_generic_leaf(newsep)
+        return newsep
+
+    def add_input(self, input_name):
+        if self.__lock_ui_readonly:
+            raise LayoutReadonlyError()
+        if not self.__block_ui_callbacks:
+            raise NodeUiDefinitionError('initializing NodeUi interface not inside initializing_interface_lock')
+        if input_name not in self.__inputs_names:
+            self.__inputs_names += (input_name,)
+
+    def _add_output_unsafe(self, output_name):
+        if output_name not in self.__outputs_names:
+            self.__outputs_names += (output_name,)
+
+    def add_output(self, output_name):
+        if self.__lock_ui_readonly:
+            raise LayoutReadonlyError()
+        if not self.__block_ui_callbacks:
+            raise NodeUiDefinitionError('initializing NodeUi interface not inside initializing_interface_lock')
+        if self.__have_output_parameter_set:
+            raise NodeUiDefinitionError('cannot add outputs when output count is controlled by a parameter')
+        return self._add_output_unsafe(output_name)
+
+    def _remove_last_output_unsafe(self):
+        if len(self.__outputs_names) < 2:
+            return
+        self.__outputs_names = self.__outputs_names[:-1]
+
+    def remove_last_output(self):
+        if self.__lock_ui_readonly:
+            raise LayoutReadonlyError()
+        if not self.__block_ui_callbacks:
+            raise NodeUiDefinitionError('initializing NodeUi interface not inside initializing_interface_lock')
+        if self.__have_output_parameter_set:
+            raise NodeUiDefinitionError('cannot remove outputs when output count is controlled by a parameter')
+        return self._remove_last_output_unsafe()
+
+    def add_output_for_spawned_tasks(self):
+        return self.add_output('spawned')
+
+    def _children_definition_changed(self, children: Iterable["ParameterHierarchyItem"]):
+        self.__ui_callback(definition_changed=True)
+
+    def _children_value_changed(self, children: Iterable["ParameterHierarchyItem"]):
+        self.__ui_callback(definition_changed=False)
+
+    def _outputs_definition_changed(self):  # TODO: not entirely sure how safe this is right now
+        self.__ui_callback(definition_changed=True)
+
+    def __ui_callback(self, definition_changed=False):
+        if self.__lock_ui_readonly:
+            raise LayoutReadonlyError()
+        if self.__postpone_ui_callbacks:
+            # so we save definition_changed to __postponed_callbacks
+            self.__postponed_callbacks = self.__postponed_callbacks or definition_changed
+            return
+
+        if self.__change_callback_receiver is not None
and not self.__block_ui_callbacks: + self.__change_callback_receiver._ui_changed(definition_changed) + + def inputs_names(self) -> Tuple[str]: + return self.__inputs_names + + def outputs_names(self) -> Tuple[str]: + return self.__outputs_names + + def parameter(self, param_name: str) -> Parameter: + return self.__parameter_layout.parameter(param_name) + + def parameters(self) -> Iterable[Parameter]: + return self.__parameter_layout.parameters(recursive=True) + + def items(self, recursive=False) -> Iterable[ParameterHierarchyItem]: + return self.__parameter_layout.items(recursive=recursive) + + def set_parameters_batch(self, parameters: Dict[str, ParameterFullValue]): + """ + If signal blocking is needed - caller can do it + + for now it's implemented the stupid way + """ + names_to_set = list(parameters.keys()) + names_to_set.append(None) + something_set_this_iteration = False + parameters_were_postponed = False + for param_name in names_to_set: + if param_name is None: + if parameters_were_postponed: + if not something_set_this_iteration: + self.__logger.warning(f'failed to set all parameters!') + break + names_to_set.append(None) + something_set_this_iteration = False + continue + assert isinstance(param_name, str) + param = self.parameter(param_name) + if param is None: + parameters_were_postponed = True + continue + param_value = parameters[param_name] + try: + param.set_value(param_value.unexpanded_value) + except (ParameterReadonly, ParameterLocked): + # if value is already correct - just skip + if param.unexpanded_value() != param_value.unexpanded_value: + self.__logger.error(f'unable to set value for "{param_name}"') + # shall we just ignore the error? + except ParameterError as e: + self.__logger.error(f'failed to set value for "{param_name}" because {repr(e)}') + if param.can_have_expressions(): + try: + param.set_expression(param_value.expression) + except (ParameterReadonly, ParameterLocked): + # if value is already correct - just skip + if param.expression() != param_value.expression: + self.__logger.error(f'unable to set expression for "{param_name}"') + # shall we just ignore the error? 
+ except ParameterError as e: + self.__logger.error(f'failed to set expression for "{param_name}" because {repr(e)}') + elif param_value.expression is not None: + self.__logger.error(f'parameter "{param_name}" cannot have expressions, yet expression is stored for it') + + something_set_this_iteration = True + + def __deepcopy__(self, memo): + cls = self.__class__ + crap = cls.__new__(cls) + newdict = self.__dict__.copy() + newdict['_NodeUi__change_callback_receiver'] = None + newdict['_NodeUi__lock_ui_readonly'] = False + assert id(self) not in memo + memo[id(self)] = crap # to avoid recursion, though manual tells us to treat memo as opaque object + for k, v in newdict.items(): + crap.__dict__[k] = deepcopy(v, memo) + return crap + + def __setstate__(self, state): + ensure_attribs = { # this exists only for the ease of upgrading NodeUi classes during development + '_NodeUi__lock_ui_readonly': False, + '_NodeUi__postpone_ui_callbacks': False + } + self.__dict__.update(state) + for attrname, default_value in ensure_attribs.items(): + if not hasattr(self, attrname): + setattr(self, attrname, default_value) + + def serialize(self) -> bytes: + """ + note - this serialization disconnects the node to which this UI is connected + :return: + """ + obj = deepcopy(self) + assert obj.__change_callback_receiver is None + return pickle.dumps(obj) + + async def serialize_async(self) -> bytes: + return await asyncio.get_event_loop().run_in_executor(None, self.serialize) + + def __repr__(self): + return 'NodeUi: ' + ', '.join(('%s: %s' % (x.name() if isinstance(x, Parameter) else '-layout-', x) for x in self.__parameter_layout.items())) + + @classmethod + def deserialize(cls, data: bytes) -> "NodeUi": + return pickle.loads(data) + + @classmethod + async def deserialize_async(cls, data: bytes) -> "NodeUi": + return await asyncio.get_event_loop().run_in_executor(None, cls.deserialize, data) diff --git a/src/lifeblood/node_ui_callback_receiver_base.py b/src/lifeblood/node_ui_callback_receiver_base.py new file mode 100644 index 00000000..7b876621 --- /dev/null +++ b/src/lifeblood/node_ui_callback_receiver_base.py @@ -0,0 +1,3 @@ +class NodeUiCallbackReceiverBase: + def _ui_changed(self, definition_changed: bool = False): + raise NotImplementedError() diff --git a/src/lifeblood/processingcontext.py b/src/lifeblood/processingcontext.py index 313c9bf3..99f36cd0 100644 --- a/src/lifeblood/processingcontext.py +++ b/src/lifeblood/processingcontext.py @@ -2,17 +2,17 @@ from .attribute_serialization import deserialize_attributes_core from .environment_resolver import EnvironmentResolverArguments +from .node_ui import NodeUi +from .node_parameters import Parameter +from .expression_locals_provider_base import ExpressionLocalsProviderBase -from typing import Dict, Optional, TYPE_CHECKING, Union +from typing import Any, Dict, Optional, Union -if TYPE_CHECKING: - from .basenode import BaseNode - from .uidata import Parameter - -class ProcessingContext: +class ProcessingContext(ExpressionLocalsProviderBase): class TaskWrapper: def __init__(self, task_dict: dict): + super().__init__() self.__attributes = deserialize_attributes_core(task_dict.get('attributes', '{}')) self.__stuff = task_dict @@ -28,9 +28,9 @@ def get(self, item, default): return self.__attributes.get(item, default) class NodeWrapper: - def __init__(self, node: "BaseNode", context: "ProcessingContext"): - self.__parameters: Dict[str, "Parameter"] = {x.name(): x for x in node.get_ui().parameters()} - self.__attrs = {'name': node.name(), 'label': node.label()} 
+ def __init__(self, node_name: str, node_label: str, node_ui: NodeUi, context: "ProcessingContext"): + self.__parameters: Dict[str, Parameter] = {x.name(): x for x in node_ui.parameters()} + self.__attrs = {'name': node_name, 'label': node_label} self.__context = context def __getitem__(self, item): @@ -51,7 +51,7 @@ def get(self, key, default=None): def __getitem__(self, item): return self.get(item) - def __init__(self, node: "BaseNode", task_dict: dict, node_config: Dict[str, Union[str, int, float, list, dict]]): + def __init__(self, node_name: str, node_label: str, node_ui: NodeUi, task_dict: dict, node_config: Dict[str, Union[str, int, float, list, dict]]): """ All information node can access during processing. This is read-only. @@ -63,15 +63,15 @@ def __init__(self, node: "BaseNode", task_dict: dict, node_config: Dict[str, Uni self.__task_attributes = deserialize_attributes_core(task_dict.get('attributes', '{}')) self.__task_dict = task_dict self.__task_wrapper = ProcessingContext.TaskWrapper(task_dict) - self.__node_wrapper = ProcessingContext.NodeWrapper(node, self) + self.__node_wrapper = ProcessingContext.NodeWrapper(node_name, node_label, node_ui, self) self.__env_args = EnvironmentResolverArguments.deserialize(task_dict.get('environment_resolver_data')) if task_dict.get('environment_resolver_data') is not None else None self.__conf_wrapper = ProcessingContext.ConfigWrapper(node_config) - self.__node = node + self.__node_ui = node_ui def param_value(self, param_name: str): - return self.__node.get_ui().parameter(param_name).value(self) + return self.__node_ui.parameter(param_name).value(self) - def locals(self): + def locals(self) -> Dict[str, Any]: """ locals to be available during expression evaluation node - represents current node From b30871930913057ed99e796a3856ed430d1a2b37 Mon Sep 17 00:00:00 2001 From: pedohorse <13556996+pedohorse@users.noreply.github.com> Date: Mon, 28 Oct 2024 12:11:33 +0100 Subject: [PATCH 03/10] refactor: break worker<->protocol/message processor dep cycle --- src/lifeblood/worker.py | 905 +----------------- src/lifeblood/worker_core.py | 901 +++++++++++++++++ src/lifeblood/worker_invocation_protocol.py | 11 +- ...ocessor.py => worker_message_processor.py} | 100 +- .../worker_message_processor_client.py | 82 ++ 5 files changed, 1022 insertions(+), 977 deletions(-) create mode 100644 src/lifeblood/worker_core.py rename src/lifeblood/{worker_messsage_processor.py => worker_message_processor.py} (70%) create mode 100644 src/lifeblood/worker_message_processor_client.py diff --git a/src/lifeblood/worker.py b/src/lifeblood/worker.py index bdeb6341..7362de87 100644 --- a/src/lifeblood/worker.py +++ b/src/lifeblood/worker.py @@ -1,48 +1,14 @@ -import random -import sys -import os -import copy -import errno -import shutil -import threading -import asyncio -import aiofiles -import psutil -import datetime -import time -import tempfile -from . 
import logging -from .nethelpers import get_addr_to, get_localhost, get_hostname -from .hardware_resources import HardwareResources -from .worker_metadata import WorkerMetadata -from .exceptions import WorkerNotAvailable, AlreadyRunning, \ - InvocationMessageWrongInvocationId, InvocationMessageAddresseeTimeout, InvocationCancelled -from .worker_messsage_processor import WorkerMessageProcessor -from .scheduler_message_processor import SchedulerWorkerControlClient +from .config import Config +from .enums import WorkerType, ProcessPriorityAdjustment +from .net_messages.address import AddressChain +from .worker_core import WorkerCore +from .worker_message_processor import WorkerMessageProcessor from .worker_invocation_protocol import WorkerInvocationProtocolHandlerV10, WorkerInvocationServerProtocol -from .worker_pool_message_processor import WorkerPoolControlClient -from .invocationjob import Invocation, InvocationEnvironment -from .config import get_config, Config -from .misc import get_unique_machine_id -from . import environment_resolver -from .enums import WorkerType, WorkerState, ProcessPriorityAdjustment -from .paths import log_path -from .process_utils import kill_process_tree -from .misc import event_set_context -from .net_messages.address import AddressChain, DirectAddress -from .net_messages.exceptions import MessageTransferError -from .defaults import worker_start_port as default_worker_start_port -from .worker_runtime_pythonpath import lifeblood_connection -import inspect +from typing import Optional -from typing import Dict, Optional, Tuple - -is_posix = not sys.platform.startswith('win') - - -class Worker: +class Worker(WorkerCore): def __init__(self, scheduler_addr: AddressChain, *, child_priority_adjustment: ProcessPriorityAdjustment = ProcessPriorityAdjustment.NO_CHANGE, worker_type: WorkerType = WorkerType.STANDARD, @@ -51,847 +17,18 @@ def __init__(self, scheduler_addr: AddressChain, *, scheduler_ping_interval: float = 10, scheduler_ping_miss_threshold: int = 6, worker_id: Optional[int] = None, - pool_address: Optional[AddressChain] = None): - """ - - :param scheduler_addr: - :param worker_type: - :param singleshot: - """ - self.__config = config or get_config('worker') - self.__logger = logging.get_logger('worker') - self.log_root_path: str = '' - for self.log_root_path in (os.path.expandvars(self.__config.get_option_noasync('worker.logpath', log_path('invocations', 'worker', ensure_path_exists=False))), - os.path.join(tempfile.gettempdir(), 'lifeblood', 'worker_logs')): - logs_ok = True - try: - if not os.path.exists(self.log_root_path): - os.makedirs(self.log_root_path, exist_ok=True) - except PermissionError: - logs_ok = False - except OSError as e: - if e.errno == errno.EACCES: - logs_ok = False - logs_ok = logs_ok and os.access(self.log_root_path, os.W_OK) - if logs_ok: - break - self.__logger.warning(f'could not use location {self.log_root_path} for logs, trying another...') - else: - raise RuntimeError('could not initialize logs directory') - self.__logger.info(f'using {self.log_root_path} for invocation logs') - - self.__status = {} - self.__scheduler_db_uid: int = 0 # unsigned 64bit int - self.__running_process: Optional[asyncio.subprocess.Process] = None - self.__running_process_start_time: float = 0 - self.__running_task: Optional[Invocation] = None - self.__running_task_progress: Optional[float] = None - self.__running_awaiter = None - self.__previous_notrunning_awaiter = None # here we will temporarily save running_awaiter before it is set to None again when 
task canceled or finished, to avoid task being GCd while in work - self.__message_processor: Optional[WorkerMessageProcessor] = None - self.__local_invocation_server: Optional[asyncio.Server] = None - self.__local_invocation_server_address_string: str = '' - - self.__local_shared_dir = self.__config.get_option_noasync("local_shared_dir_path", os.path.join(tempfile.gettempdir(), 'lifeblood_worker', 'shared')) - - # resources - config_resources = self.__config.get_option_noasync('resources') - config_devices = self.__config.get_option_noasync('devices') - if config_resources is not None: # schema check - if not isinstance(config_resources, dict): - raise RuntimeError('resources config section must be a mapping') - else: - config_resources = {} - if config_devices is not None: # schema check - if not isinstance(config_devices, dict): - raise RuntimeError('devices config section must be a mapping') - for key, val in config_devices.items(): - if not isinstance(key, str): - raise RuntimeError('devices config (keys) must be strings') - if not isinstance(val, dict): - raise RuntimeError('devices config .<> (values) config section must be a mapping') - for key_name, val_devname_data in val.items(): - if not isinstance(key_name, str): - raise RuntimeError('devices config . (keys) must be strings') - if not isinstance(val_devname_data, dict): - raise RuntimeError('devices config ..<> (values) section must be a mapping') - if 'resources' in val_devname_data: - for key_resname, val_resval in val_devname_data['resources'].items(): - if not isinstance(key_resname, str): - raise RuntimeError('devices config ..resources. (keys) must be strings') - if not isinstance(val_resval, (int, float, str)): - raise RuntimeError('devices config ..resources..<> (values) must be ints, floats or special strings like "32G"') - if 'tags' in val_devname_data: - for key_tagname, val_tagval in val_devname_data['tags'].items(): - if not isinstance(key_tagname, str): - raise RuntimeError('devices config ..tags. 
(keys) must be strings') - if not isinstance(val_tagval, (int, float, str)): - raise RuntimeError('devices config ..tags..<> (values) must be ints, floats or strings') - else: - config_devices = {} - self.__my_resources = HardwareResources( - hwid=get_unique_machine_id() if self.__config.get_option_noasync('worker.override_hwid') is None else self.__config.get_option_noasync('worker.override_hwid'), - resources={ - 'cpu_count': psutil.cpu_count(), - 'cpu_mem': psutil.virtual_memory().total, - **config_resources - }, - devices=[(dev_type, dev_name, dev_stuff.get('resources', {})) for dev_type, dev_dev in config_devices.items() for dev_name, dev_stuff in dev_dev.items()], + pool_address: Optional[AddressChain] = None, + ): + super().__init__( + scheduler_addr=scheduler_addr, + child_priority_adjustment=child_priority_adjustment, + worker_type=worker_type, + config=config, + singleshot=singleshot, + scheduler_ping_interval=scheduler_ping_interval, + scheduler_ping_miss_threshold=scheduler_ping_miss_threshold, + worker_id=worker_id, + pool_address=pool_address, + message_processor_factory=WorkerMessageProcessor, + worker_invocation_protocol_factory=lambda worker: WorkerInvocationServerProtocol(worker, [WorkerInvocationProtocolHandlerV10(worker)]), ) - self.__my_device_tags = { - dev_type: { - dev_name: {tag_name: tag_val for tag_name, tag_val in dev_stuff.get('tags', {}).items()} for dev_name, dev_stuff in dev_dev.items() - } for dev_type, dev_dev in config_devices.items() - } - - self.__task_changing_state_lock = asyncio.Lock() - self.__task_switching_event = asyncio.Event() # this will signal invocation message waiters to cancel what they are doing - self.__stop_lock = threading.Lock() - self.__start_lock = asyncio.Lock() # cant use threading lock in async methods - it can yeild out, and deadlock on itself - self.__where_to_report: Optional[AddressChain] = None - self.__ping_interval = scheduler_ping_interval - self.__ping_missed_threshold = scheduler_ping_miss_threshold - self.__ping_missed = 0 - self.__scheduler_addr = scheduler_addr - self.__scheduler_pinger = None - self.__components_stop_event = asyncio.Event() - self.__extra_files_base_dir = None - self.__my_addr_for_scheduler: Optional[AddressChain] = None - self.__message_address: Optional[DirectAddress] = None - self.__worker_id = worker_id - self.__pool_address: Optional[AddressChain] = pool_address - if self.__worker_id is None and self.__pool_address is not None \ - or self.__worker_id is not None and self.__pool_address is None: - raise RuntimeError('pool_address must be given together with worker_id') - - self.__worker_task_comm_queues: Dict[str, asyncio.Queue] = {} - - self.__worker_type: WorkerType = worker_type - self.__singleshot: bool = singleshot or worker_type == WorkerType.SCHEDULER_HELPER - - self.__child_priority_adjustment = child_priority_adjustment - # this below is a placeholder solution. 
the easiest way to implement priority lowering without testing on different platrofms - if self.__child_priority_adjustment == ProcessPriorityAdjustment.LOWER: - if sys.platform.startswith('win'): - assert hasattr(psutil, 'BELOW_NORMAL_PRIORITY_CLASS') - psutil.Process().nice(psutil.BELOW_NORMAL_PRIORITY_CLASS) - else: - psutil.Process().nice(10) - - # deploy a copy of runtime module somewhere in temp - rtmodule_code = inspect.getsource(lifeblood_connection) - - filepath = os.path.join(tempfile.gettempdir(), 'lifeblood', 'lifeblood_runtime', 'lifeblood_connection.py') - os.makedirs(os.path.dirname(filepath), exist_ok=True) - existing_code = None - if os.path.exists(filepath): - with open(filepath, 'r') as f: - existing_code = f.read() - - if existing_code != rtmodule_code: - with open(filepath, 'w') as f: - f.write(rtmodule_code) - self.__rt_module_dir = os.path.dirname(filepath) - - self.__stopping_waiters = [] - self.__finished = asyncio.Event() - self.__started = False - self.__started_event = asyncio.Event() - self.__stopped = False - - def message_processor(self) -> WorkerMessageProcessor: - return self.__message_processor - - def scheduler_message_address(self) -> AddressChain: - return self.__scheduler_addr - - async def start(self): - if self.__started: - return - if self.__stopped: - raise RuntimeError('already stopped, cannot start again') - - async with self.__start_lock: - abort_start = False - - # start local server for invocation api connections - loop = asyncio.get_event_loop() - localhost = get_localhost() - localport_start = 10101 - localport_end = 11111 - localport = None - for _ in range(localport_end - localport_start): # big but finite - localport = random.randint(localport_start, localport_end) - try: - self.__local_invocation_server = await loop.create_server( - lambda: WorkerInvocationServerProtocol(self, [WorkerInvocationProtocolHandlerV10(self)]), - localhost, - localport - ) - break - except OSError as e: - if e.errno != errno.EADDRINUSE: - raise - continue - else: - raise RuntimeError('could not find an opened port!') - self.__local_invocation_server_address_string = f'{localhost}:{localport}' - - # start message processor - my_ip = get_addr_to(self.__scheduler_addr.split_address()[0]) - my_port = default_worker_start_port() - for i in range(1024): # big but finite - try: - self.__message_processor = WorkerMessageProcessor(self, (my_ip, my_port)) - await self.__message_processor.start() - break - except OSError as e: - if e.errno != errno.EADDRINUSE: - raise - my_port += 1 - continue - else: - raise RuntimeError('could not find an opened port!') - - self.__message_address = DirectAddress.from_host_port(my_ip, my_port) - - # now report our address to the scheduler - metadata = WorkerMetadata(get_hostname()) - try: - with SchedulerWorkerControlClient.get_scheduler_control_client(self.__scheduler_addr, self.__message_processor) as client: # type: SchedulerWorkerControlClient - # re-normalize addresses - self.__scheduler_addr, self.__my_addr_for_scheduler = await client.get_normalized_addresses() - self.__scheduler_db_uid = await client.say_hello(self.__my_addr_for_scheduler, self.__worker_type, self.__my_resources, metadata) - except MessageTransferError as e: - self.__logger.error('error connecting to scheduler during start') - abort_start = True - # - # and report to the pool - try: - if self.__worker_id is not None: - assert self.__pool_address is not None - with WorkerPoolControlClient.get_worker_pool_control_client(self.__pool_address, 
self.__message_processor) as wpclient: # type: WorkerPoolControlClient - await wpclient.report_state(self.__worker_id, WorkerState.IDLE) - except ConnectionError as e: - self.__logger.error('error connecting to worker pool during start') - abort_start = True - - self.__scheduler_pinger = asyncio.create_task(self.scheduler_pinger()) - self.__started = True - self.__started_event.set() - if abort_start: - self.__logger.error('error during stating worker, aborting!') - self.stop() - else: - self.__logger.info('worker started') - - def is_started(self): - return self.__started - - def wait_till_starts(self): # we can await this function cuz it returns a future... - return self.__started_event.wait() - - def stop(self): - async def _send_byebye(): - try: - self.__logger.debug('saying bye to scheduler') - with SchedulerWorkerControlClient.get_scheduler_control_client(self.__scheduler_addr, self.__message_processor) as client: # type: SchedulerWorkerControlClient - await client.say_bye(self.__my_addr_for_scheduler) - except MessageTransferError: # if scheduler or route is down - self.__logger.info('couldn\'t say bye to scheduler as it seem to be down') - except Exception: - self.__logger.exception('couldn\'t say bye to scheduler for unknown reason') - - if not self.__started or self.__stopped: - return - with self.__stop_lock: # NOTE: there is literally no threading in worker, so this is excessive - self.__logger.info('STOPPING WORKER') - self.__components_stop_event.set() - - async def _finalizer(): - await self.__scheduler_pinger # to ensure pinger stops and won't try to contact scheduler any more - await self.cancel_task() # then we cancel task, here we still can report it to the scheduler. - # no new tasks will be picked up cuz __stopped is already set - self.__local_invocation_server.close() - await _send_byebye() # saying bye, don't bother us. 
(some delayed comms may still come through to the __server - await self.__local_invocation_server.wait_closed() - self.__message_processor.stop() - await self.__message_processor.wait_till_stops() - self.__logger.info('message processor stopped') - - self.__stopping_waiters.append(asyncio.create_task(_finalizer())) - self.__finished.set() - self.__stopped = True - - async def wait_till_stops(self): - # if self.__scheduler_pinger is not None: - # #try: - # await self.__scheduler_pinger - # #except asyncio.CancelledError: - # # self.__logger.debug('wait_to_finished: scheduler_pinger was cancelled') - # # #raise - # self.__scheduler_pinger = None - # await self.__server.wait_closed() - await self.__finished.wait() - self.__logger.info('server closed') - await self.__scheduler_pinger - self.__logger.info('pinger closed') - for waiter in self.__stopping_waiters: - await waiter - - def get_log_filepath(self, level, invocation_id: int = None): # TODO: think of a better, more generator-style way of returning logs - if self.__running_task is None and invocation_id is None: - return os.path.join(self.log_root_path, f'db_{self.__scheduler_db_uid:016x}', 'common', level) - else: - return os.path.join(self.log_root_path, f'db_{self.__scheduler_db_uid:016x}', 'invocations', str(invocation_id or self.__running_task.invocation_id()), level) - - async def delete_logs(self, invocation_id: int): - self.__logger.debug(f'removing logs for {invocation_id}') - path = os.path.join(self.log_root_path, f'db_{self.__scheduler_db_uid:016x}', 'invocations', str(invocation_id or self.__running_task.invocation_id())) - await asyncio.get_event_loop().run_in_executor(None, shutil.rmtree, path) # assume that deletion MAY take time, so allow util tasks to be processed while we wait - - async def run_task(self, task: Invocation, report_to: AddressChain): - if self.__stopped: - raise WorkerNotAvailable() - self.__logger.debug(f'locks are {self.__task_changing_state_lock.locked()}') - async with self.__task_changing_state_lock: - self.__logger.debug('run_task: task_change_state locks acquired') - # we must ensure picking up and finishing tasks is in critical section - assert len(task.job_definition().args()) > 0 - if self.__running_process is not None: - raise AlreadyRunning('Task already in progress') - - # prepare logging - self.__logger.info(f'running task {task}') - - # save external files - self.__extra_files_base_dir = None - extra_files_map: Dict[str, str] = {} - if len(task.job_definition().extra_files()) > 0: - self.__extra_files_base_dir = tempfile.mkdtemp(prefix='lifeblood_efs_') # TODO: add base temp dir to config - self.__logger.debug(f'creating extra file temporary dir at {self.__extra_files_base_dir}') - for exfilepath, exfiledata in task.job_definition().extra_files().items(): - self.__logger.info(f'saving extra job file {exfilepath}') - exfilepath_parts = exfilepath.split('/') - tmpfilepath = os.path.join(self.__extra_files_base_dir, *exfilepath_parts) - os.makedirs(os.path.dirname(tmpfilepath), exist_ok=True) - with open(tmpfilepath, 'w' if isinstance(exfiledata, str) else 'wb') as f: - f.write(exfiledata) - extra_files_map[exfilepath] = tmpfilepath - - # check args for extra file references - if len(task.job_definition().extra_files()) > 0: - args = [] - for arg in task.job_definition().args(): - if isinstance(arg, str) and arg.startswith(':/') and arg[2:] in task.job_definition().extra_files(): - args.append(extra_files_map[arg[2:]]) - else: - args.append(arg) - else: - args = task.job_definition().args() - 
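The try-block below (in the code being moved out of worker.py) selects the environment resolver: task-supplied resolver arguments win, otherwise the worker falls back to the default named in its config. A minimal sketch of that selection rule, with plain dicts standing in for lifeblood's config and resolver objects (pick_resolver is an illustrative name, not part of the patch):

    from typing import Any, Mapping, Optional, Tuple

    def pick_resolver(task_resolver_args: Optional[Mapping[str, Any]],
                      config: Mapping[str, Any]) -> Tuple[str, Mapping[str, Any]]:
        # task-provided environment resolver arguments take precedence;
        # otherwise fall back to the worker config's default_env_wrapper
        if task_resolver_args is None:
            return (config.get('default_env_wrapper.name', 'TrivialEnvironmentResolver'),
                    config.get('default_env_wrapper.arguments', {}))
        return task_resolver_args['name'], task_resolver_args.get('arguments', {})

    assert pick_resolver(None, {}) == ('TrivialEnvironmentResolver', {})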
- try: - if task.job_definition().environment_resolver_arguments() is None: - resolver = environment_resolver.get_resolver(self.__config.get_option_noasync('default_env_wrapper.name', 'TrivialEnvironmentResolver')) - resolver_arguments = self.__config.get_option_noasync('default_env_wrapper.arguments', {}) - else: - env_res_args = task.job_definition().environment_resolver_arguments() - resolver = env_res_args.get_resolver() - resolver_arguments = env_res_args.arguments() - except environment_resolver.ResolutionImpossibleError as e: - self.__logger.error(f'cannot run the task: Unable to resolve environment: {str(e)}') - raise - - # TODO: resolver args get_environment() acually does resolution so should be renamed to like resolve_environment() - # Environment's resolve() actually just expands and merges everything, so naming it "resolve" is misleading next to EnvironmentResolver - - env = copy.deepcopy(task.job_definition().env() or InvocationEnvironment()) - - env.prepend('PYTHONPATH', self.__rt_module_dir) - env['LIFEBLOOD_RUNTIME_IID'] = task.invocation_id() - env['LIFEBLOOD_RUNTIME_TID'] = task.task_id() - env['LIFEBLOOD_RUNTIME_SCHEDULER_ADDR'] = self.__local_invocation_server_address_string - - env['LBDEV_TYPES'] = ','.join({dev_type for dev_type, _, _ in self.__my_resources.devices()}) - for dev_type, dev_name_list in task.resources_to_use().devices.items(): - for i, dev_name in enumerate(dev_name_list): - env[f'LBDEV_TYPE{i}'] = dev_type - env[f'LBDEV_NAME{i}'] = dev_name - env[f'LBDEV_TAGS{i}'] = ','.join(f'{tag_name}={tag_val}' for tag_name, tag_val in self.__my_device_tags.get(dev_type, {}).get(dev_name, {}).items()) - - # we do NOT set all attribs to env - just a frame list can easily hit proc env size limit - for aname, aval in task.job_definition().attributes().items(): - if aname.startswith('_'): # skip attributes starting with _ - continue - # TODO: THINK OF A BETTER LOGIC ! - if isinstance(aval, (str, int, float)): - env[f'LBATTR_{aname}'] = str(aval) - - if self.__extra_files_base_dir is not None: - env['LB_EF_ROOT'] = self.__extra_files_base_dir - try: - #with open(self.get_log_filepath('output', task.invocation_id()), 'a') as stdout: - # with open(self.get_log_filepath('error', task.invocation_id()), 'a') as stderr: - # TODO: proper child process priority adjustment should be done, for now it's implemented in constructor. - self.__running_process_start_time = time.time() - - self.__running_process: asyncio.subprocess.Process = await resolver.create_process( - resolver_arguments, - args, - extra_env=env, - resources_to_use=task.resources_to_use(), - ) - except Exception as e: - self.__logger.exception('task creation failed with error: %s' % (repr(e),)) - raise - - self.__running_task = task - self.__running_awaiter = asyncio.create_task(self._awaiter()) - self.__running_task_progress = 0 - if self.__worker_id is not None: # TODO: gracefully handle connection fails here \/ - assert self.__pool_address is not None - self.__where_to_report = AddressChain.join_address((self.__pool_address, report_to)) - with WorkerPoolControlClient.get_worker_pool_control_client(self.__pool_address, self.__message_processor) as wpclient: # type: WorkerPoolControlClient - await wpclient.report_state(self.__worker_id, WorkerState.BUSY) - else: - self.__where_to_report = report_to - - # TODO: we must keep track of _awaiter, that it's not dead. 
- # Either make a global watchdog task - # Or wrap the whole _awaiter in try and catch errors within itself - - # callback awaiter - async def _awaiter(self): - stdout_path = self.get_log_filepath('output', self.__running_task.invocation_id()) - stderr_path = self.get_log_filepath('error', self.__running_task.invocation_id()) - os.makedirs(os.path.dirname(stdout_path), exist_ok=True) - os.makedirs(os.path.dirname(stderr_path), exist_ok=True) - async with aiofiles.open(stdout_path, 'wb') as stdout: - async with aiofiles.open(stderr_path, 'wb') as stderr: - async def _flush(): - await asyncio.sleep(1) # ensure to flush every 1 second - await stdout.flush() - await stderr.flush() - - await stdout.write(datetime.datetime.now().strftime('[SYS][%d.%m.%y %H:%M:%S] task initialized\n').encode('UTF-8')) - - progress_reporting_task = None - minimum_progress_reporting_interval = await self.__config.get_option('minimum_progress_reporting_interval', 1.0) - last_progress_reported_timestamp = time.monotonic() - minimum_progress_reporting_interval - last_progress_reported = None - last_progress_attempted_to_report = None - - tasks_to_wait = {} - try: - rout_task = asyncio.create_task(self.__running_process.stdout.readline()) - rerr_task = asyncio.create_task(self.__running_process.stderr.readline()) - done_task = asyncio.create_task(self.__running_process.wait()) - flush_task = asyncio.create_task(_flush()) - tasks_to_wait = {rout_task, rerr_task, done_task, flush_task} - while len(tasks_to_wait) != 0: - done, tasks_to_wait = await asyncio.wait(tasks_to_wait, return_when=asyncio.FIRST_COMPLETED) - if rout_task in done: - buff_line = rout_task.result() - progress = self.__running_task.job_definition().match_stdout_progress(buff_line) - if progress is not None: - self.__running_task_progress = progress - if buff_line != b'': # this can only happen at eof - await stdout.write(datetime.datetime.now().strftime('[OUT][%H:%M:%S] ').encode('UTF-8') + buff_line) - rout_task = asyncio.create_task(self.__running_process.stdout.readline()) - tasks_to_wait.add(rout_task) - if rerr_task in done: - buff_line = rerr_task.result() - progress = self.__running_task.job_definition().match_stderr_progress(buff_line) - if progress is not None: - self.__running_task_progress = progress - if buff_line != b'': # this can only happen at eof - message = datetime.datetime.now().strftime('[ERR][%H:%M:%S] ').encode('UTF-8') + buff_line - await asyncio.gather( - stderr.write(message), - stdout.write(message) - ) - rerr_task = asyncio.create_task(self.__running_process.stderr.readline()) - tasks_to_wait.add(rerr_task) - - # check if previous progress reporting task finished - if progress_reporting_task is not None and progress_reporting_task.done(): - try: - await progress_reporting_task - except MessageTransferError as e: - self.__logger.warning('failed report invocation progress cuz of: %s', e) - except Exception as e: - self.__logger.warning('failed report invocation progress, unexpected error: %s', e) - else: - last_progress_reported = last_progress_attempted_to_report - progress_reporting_task = None - last_progress_reported_timestamp = time.monotonic() - - # report progress if can - if last_progress_reported != self.__running_task_progress \ - and progress_reporting_task is None \ - and time.monotonic() - last_progress_reported_timestamp > minimum_progress_reporting_interval: - progress_reporting_task = asyncio.create_task(self.__helper_report_progress(self.running_invocation().invocation_id(), self.__running_task_progress)) - 
last_progress_attempted_to_report = self.__running_task_progress - - if flush_task in done and not done_task.done(): - flush_task = asyncio.create_task(_flush()) - tasks_to_wait.add(flush_task) - await stdout.write(datetime.datetime.now().strftime('[SYS][%d.%m.%y %H:%M:%S] task finished\n').encode('UTF-8')) - except asyncio.CancelledError: - self.__logger.debug('task awaiter was cancelled') - for task in tasks_to_wait: - task.cancel() - raise - finally: - # safer to wait for existing progress reporting task than cancel it - # as cancelling may disrupt network protocol and cause timeout waiting on scheduler side - if progress_reporting_task is not None: - try: - await progress_reporting_task - except MessageTransferError as e: - self.__logger.warning('failed report invocation progress cuz of: %s', e) - except Exception as e: - self.__logger.warning('failed report invocation progress, unexpected error: %s', e) - progress_reporting_task = None - # report to the pool - if self.__worker_id is not None: - try: - assert self.__pool_address is not None - with WorkerPoolControlClient.get_worker_pool_control_client(self.__pool_address, self.__message_processor) as wpclient: # type: WorkerPoolControlClient - await wpclient.report_state(self.__worker_id, WorkerState.IDLE) - except (Exception, asyncio.CancelledError): - self.__logger.error('failed to report task cancellation to worker pool. stopping worker') - self.stop() - - await self.__running_process.wait() - await self.task_finished() - - async def __helper_report_progress(self, invocation_id: int, progress: float): - with SchedulerWorkerControlClient.get_scheduler_control_client(self.__where_to_report, self.__message_processor) as client: # type: SchedulerWorkerControlClient - await client.report_invocation_progress(invocation_id, progress) - - def is_task_running(self) -> bool: - return self.__running_task is not None - - def running_invocation(self) -> Optional[Invocation]: - return self.__running_task - - async def deliver_invocation_message(self, destination_invocation_id: int, destination_addressee: str, source_invocation_id: Optional[int], message_body: bytes, addressee_timeout: float = 90.0): - """ - deliver message to task - - the idea is to deliver ONLY when message is waited for. - so queues are added/removed by receiver, not by this deliver method - current impl is NOT thread safe, it relies on async to separate important regions - """ - while True: - # while we wait - invocation MAY change. 
- running_invocation = self.running_invocation() - if running_invocation is None or destination_invocation_id != running_invocation.invocation_id(): - raise InvocationMessageWrongInvocationId() - - while destination_addressee not in self.__worker_task_comm_queues: - wait_start_timestamp = time.time() - await asyncio.sleep(0.05) # we MOST LIKELY are already waiting for this, so timeout occurs - addressee_timeout -= time.time() - wait_start_timestamp - if addressee_timeout <= 0: - raise InvocationMessageAddresseeTimeout() - # important to keep checking if invocation was changed, - # and important to have no awaits (no interruptions) between check and enqueueing - running_invocation = self.running_invocation() - if running_invocation is None or destination_invocation_id != running_invocation.invocation_id(): - raise InvocationMessageWrongInvocationId() - - queue = self.__worker_task_comm_queues[destination_addressee] - - if not queue.empty(): - # need to return control to loop in case 2 deliver_invocation_message calls happen to happen at the same time, - # and one is stuck in the loop of upper while being satisfied, but queue not empty already - await asyncio.sleep(0.01) - continue - queue.put_nowait((source_invocation_id, message_body)) - queue.put_nowait(()) - break - - async def worker_task_addressee_wait(self, addressee: str, timeout: float = 30) -> Tuple[int, bytes]: - """ - wait for a data message to addressee to be delivered - - :returns: sender invocation id, message body - """ - if self.__task_switching_event.is_set(): - self.__logger.warning('cannot wait for invocation message when task is being cancelled') - raise InvocationCancelled() - - # get ref to queues, so if it's replaced under us we stay consistent - queues = self.__worker_task_comm_queues - # TODO: (j) need tests for multiple waits on SAME addressee at the same time - if addressee not in queues: - queues[addressee] = asyncio.Queue() - - cancel_event_waiter = asyncio.create_task(self.__task_switching_event.wait()) - queue_getter = asyncio.create_task(queues[addressee].get()) - value = None - try: - done, pend = await asyncio.wait([queue_getter, cancel_event_waiter], timeout=timeout, return_when=asyncio.FIRST_COMPLETED) - if queue_getter in done: - value = queue_getter.result() - else: - queue_getter.cancel() - if cancel_event_waiter in done: - # note - at this point both tasks are done or cancelled - raise InvocationCancelled() - else: - cancel_event_waiter.cancel() - # check for timeout - if len(done) == 0: - raise asyncio.TimeoutError() - - assert value is not None, 'internal logic error, value cannot be None here' - - # value = await asyncio.wait_for(queues[addressee].get(), timeout=timeout) - assert queues[addressee].get_nowait() == () - # this way above we ensure one single deliver_task deliver to one single addressee_wait - finally: - if queues[addressee].empty(): # TODO: see TODO (j) above - queues.pop(addressee) - - return value - - def is_stopping(self) -> bool: - """ - returns True is stop was called on worker, - so worker is closed or in the process of closing - """ - return self.__stopped - - async def cancel_task(self): - async with self.__task_changing_state_lock, event_set_context(self.__task_switching_event): - self.__logger.debug('cancel_task: task_change_state locks acquired') - if self.__running_process is None: - return - self.__logger.info('cancelling running task') - self.__running_awaiter.cancel() - cancelling_awaiter = self.__running_awaiter - self.__running_awaiter = None - - await 
kill_process_tree(self.__running_process) - self.__running_task.finish(None, time.time() - self.__running_process_start_time) - - self.__running_process._transport.close() # sometimes not closed straight away transport ON EXIT may cause exceptions in __del__ that event loop is closed - - # report to scheduler that cancel was a success - self.__logger.info(f'reporting cancel back to {self.__where_to_report}') - - proc_stdout_filepath = self.get_log_filepath('output', self.__running_task.invocation_id()) - proc_stderr_filepath = self.get_log_filepath('error', self.__running_task.invocation_id()) - - # we want to append worker's message that job was killed - try: - message = datetime.datetime.now().strftime('\n[WORKER][%d.%m.%y %H:%M:%S] ').encode('UTF-8') + b'killed by worker.\n' - async with aiofiles.open(proc_stdout_filepath, 'ab') as stdout, \ - aiofiles.open(proc_stderr_filepath, 'ab') as stderr: - await asyncio.gather( - stderr.write(message), - stdout.write(message) - ) - except Exception as e: - self.__logger.warning("failed to append worker message to the logs") - - try: - with SchedulerWorkerControlClient.get_scheduler_control_client(self.__where_to_report, self.__message_processor) as client: # type: SchedulerWorkerControlClient - await client.report_task_canceled(self.__running_task, - proc_stdout_filepath, - proc_stderr_filepath) - except Exception as e: - self.__logger.exception(f'could not report cuz of {e}') - except: - self.__logger.exception('could not report cuz i have no idea') - # end reporting - - try: - await self.delete_logs(self.__running_task.invocation_id()) - except OSError: - self.__logger.exception("failed to delete logs, ignoring") - - self.__running_task = None - self.__worker_task_comm_queues = {} - self.__running_process = None - self.__where_to_report = None - self.__running_task_progress = None - await self._cleanup_extra_files() - - await asyncio.wait((cancelling_awaiter,)) # ensure everything is done before we proceed - - # stop ourselves if we are a small task helper - if self.__singleshot: - self.stop() - - def task_status(self) -> Optional[float]: - return self.__running_task_progress - - async def task_finished(self): - """ - is called when current process finishes - :return: - """ - async with self.__task_changing_state_lock, event_set_context(self.__task_switching_event): - self.__logger.debug('task_finished: task_change_state locks acquired') - if self.__running_process is None: - self.__logger.warning('task_finished called, but there is no running task. 
This can only normally happen if a task_cancel happened the same moment as finish.') - return - self.__logger.info('task finished') - process_exit_code = await self.__running_process.wait() - self.__running_task.finish(process_exit_code, time.time() - self.__running_process_start_time) - - # report to scheduler - self.__logger.info(f'reporting done back to {self.__where_to_report}') - try: - with SchedulerWorkerControlClient.get_scheduler_control_client(self.__where_to_report, self.__message_processor) as client: # type: SchedulerWorkerControlClient - await client.report_task_done(self.__running_task, - self.get_log_filepath('output', self.__running_task.invocation_id()), - self.get_log_filepath('error', self.__running_task.invocation_id())) - except Exception as e: - self.__logger.exception(f'could not report cuz of {e}') - except: - self.__logger.exception('could not report cuz i have no idea') - # end reporting - self.__logger.debug(f'done reporting done back to {self.__where_to_report}') - - try: - await self.delete_logs(self.__running_task.invocation_id()) - except OSError: - self.__logger.exception("failed to delete logs, ignoring") - - self.__where_to_report = None - self.__running_task = None - self.__worker_task_comm_queues = {} - self.__running_process = None - self.__previous_notrunning_awaiter = self.__running_awaiter # this is JUST so task is not GCd - self.__running_awaiter = None # TODO: lol, this function can be called from awaiter, and if we hand below - awaiter can be gcd, and it's all fucked - self.__running_task_progress = None - await self._cleanup_extra_files() - - # stop ourselves if we are a small task helper - if self.__singleshot: - self.stop() - - async def _cleanup_extra_files(self): - """ - cleanup extra files transfered with the task - :return: - """ - if self.__extra_files_base_dir is None: - return - try: - shutil.rmtree(self.__extra_files_base_dir) - except: - self.__logger.exception('could not cleanup extra files') - - # - # simply ping scheduler once in a while - async def scheduler_pinger(self): - """ - ping scheduler once in a while. if it misses too many pings - close worker and wait for new broadcasts - :return: - """ - - async def _reintroduce_ourself(): - for attempt in range(5): - self.__logger.debug(f'trying to reintroduce myself, attempt: {attempt + 1}') - metadata = WorkerMetadata(get_hostname()) - try: - with SchedulerWorkerControlClient.get_scheduler_control_client(self.__scheduler_addr, self.__message_processor) as client: # type: SchedulerWorkerControlClient - assert self.__my_addr_for_scheduler is not None - addr = self.__my_addr_for_scheduler - self.__logger.debug('saying bye') - await client.say_bye(addr) - self.__logger.debug('cancelling task') - await self.cancel_task() - self.__logger.debug('saying hello') - self.__scheduler_db_uid = await client.say_hello(addr, self.__worker_type, self.__my_resources, metadata) - self.__logger.debug('reintroduce done') - break - except Exception: - self.__logger.exception('failed to reintroduce myself. sleeping a bit and retrying') - await asyncio.sleep(10) - else: # failed to reintroduce. consider that something is wrong with the network, stop - self.__logger.error('failed to reintroduce myself. 
assuming network problems, exiting') - self.stop() - - exit_wait = asyncio.create_task(self.__components_stop_event.wait()) - while True: - done, pend = await asyncio.wait((exit_wait, ), timeout=self.__ping_interval, return_when=asyncio.FIRST_COMPLETED) - if exit_wait in done: - await exit_wait - break - #await asyncio.sleep(self.__ping_interval) - if self.__ping_missed_threshold == 0: - continue - # Here we are locking to prevent unexpected task state changes while checking for state inconsistencies - async with self.__task_changing_state_lock: - self.__logger.debug('pinger: task_change_state locks acquired') - try: - self.__logger.debug('pinging scheduler') - with SchedulerWorkerControlClient.get_scheduler_control_client(self.__scheduler_addr, self.__message_processor) as client: # type: SchedulerWorkerControlClient - result = await client.ping(self.__my_addr_for_scheduler) - self.__logger.debug(f'scheduler pinged: sees me as {result}') - except MessageTransferError as mte: - self.__logger.error('ping message delivery failed') - result = None - except Exception as e: - self.__logger.exception('unexpected exception happened') - result = None - task_running = self.is_task_running() - - if result is None: # this means EOF - self.__ping_missed += 1 - self.__logger.info(f'server ping missed. total misses: {self.__ping_missed}') - if self.__ping_missed >= self.__ping_missed_threshold: - # assume scheruler down, drop everything and look for another scheruler - self.stop() - return - - if result in (WorkerState.OFF, WorkerState.UNKNOWN): - # something is wrong, lets try to reintroduce ourselves. - # Note that we can be sure that there cannot be race conditions here: - # pinger starts working always AFTER hello, OR it saz hello itself. - # and scheduler will immediately switch worker state on hello, so ping coming after confirmed hello will ALWAYS get newer state - self.__logger.warning(f'scheduler replied it thinks i\'m {result.name}. canceling tasks if any and reintroducing myself') - await _reintroduce_ourself() - elif result == WorkerState.BUSY and not task_running: - # Note: the order is: - # - sched sets worker to INVOKING - # - shced sends "task" - # - worker receives task, sets is_task_running - # - worker answers to sched - # - sched sets worker to BUSY - # and when finished: - # - worker reports done | - # - sched sets worker to IDLE | under __task_changing_state_lock - # - worker unsets is_task_running | - # so there is no way it can be not task_running AND sched state busy. - # if it is - it must be an error - self.__logger.warning(f'scheduler replied it thinks i\'m BUSY, but i\'m free, so something is inconsistent. resolving by reintroducing myself') - await _reintroduce_ourself() - elif result == WorkerState.IDLE and task_running: - # Note from scheme above - this is not possible, - # the only period where scheduler can think IDLE while is_task_running set is in __task_changing_state_lock-ed area - # but we aquired sched state and our is_task_running above inside that __task_changing_state_lock - self.__logger.warning(f'scheduler replied it thinks i\'m IDLE, but i\'m doing a task, so something is inconsistent. resolving by reintroducing myself') - await _reintroduce_ourself() - elif result == WorkerState.ERROR: - # currently the only way it can be error is because of shitty network - # ideally here we would check ourselves - # but there's nothing to check right now - self.__logger.warning('scheduler replied it thinks i\'m ERROR, but i\'m doing fine. 
probably something is wrong with the network. waiting for scheduler to resolve the problem') - # no we don't reintroduce - error state on scheduler side just means he won't give us tasks for now - # and since error is most probably due to network - it will either resolve itself, or there is no point reintroducing if connection cannot be established anyway - elif result is not None: - self.__ping_missed = 0 - - def worker_message_address(self) -> DirectAddress: - if self.__message_address is None: - raise RuntimeError('cannot get listening address of a non started worker') - - return self.__message_address diff --git a/src/lifeblood/worker_core.py b/src/lifeblood/worker_core.py new file mode 100644 index 00000000..d2bc4165 --- /dev/null +++ b/src/lifeblood/worker_core.py @@ -0,0 +1,901 @@ +import random +import sys +import os +import copy +import errno +import shutil +import threading +import asyncio +import aiofiles +import psutil +import datetime +import time +import tempfile +from . import logging +from .nethelpers import get_addr_to, get_localhost, get_hostname +from .hardware_resources import HardwareResources +from .worker_metadata import WorkerMetadata +from .exceptions import WorkerNotAvailable, AlreadyRunning, \ + InvocationMessageWrongInvocationId, InvocationMessageAddresseeTimeout, InvocationCancelled +from .scheduler_message_processor_client import SchedulerWorkerControlClient +from .worker_pool_message_processor_client import WorkerPoolControlClient +from .invocationjob import Invocation, InvocationEnvironment +from .config import get_config, Config +from .misc import get_unique_machine_id +from . import environment_resolver +from .enums import WorkerType, WorkerState, ProcessPriorityAdjustment +from .paths import log_path +from .process_utils import kill_process_tree +from .misc import event_set_context +from .net_messages.impl.tcp_simple_command_message_processor import TcpCommandMessageProcessor +from .net_messages.address import AddressChain, DirectAddress +from .net_messages.exceptions import MessageTransferError +from .defaults import worker_start_port as default_worker_start_port + +from .worker_runtime_pythonpath import lifeblood_connection +import inspect + +from typing import Callable, Dict, Optional, Tuple + + +is_posix = not sys.platform.startswith('win') + + +class WorkerCore: + def __init__(self, scheduler_addr: AddressChain, *, + child_priority_adjustment: ProcessPriorityAdjustment = ProcessPriorityAdjustment.NO_CHANGE, + worker_type: WorkerType = WorkerType.STANDARD, + config: Optional[Config] = None, # TODO: this should be replaced with config provider with a fixed interface + singleshot: bool = False, + scheduler_ping_interval: float = 10, + scheduler_ping_miss_threshold: int = 6, + worker_id: Optional[int] = None, + pool_address: Optional[AddressChain] = None, + message_processor_factory: Callable[["WorkerCore", Tuple[str, int]], TcpCommandMessageProcessor], + worker_invocation_protocol_factory: Callable[["WorkerCore"], asyncio.StreamReaderProtocol], + ): + """ + + :param scheduler_addr: + :param worker_type: + :param singleshot: + """ + self.__config = config or get_config('worker') + self.__logger = logging.get_logger('worker') + self.log_root_path: str = '' + for self.log_root_path in (os.path.expandvars(self.__config.get_option_noasync('worker.logpath', log_path('invocations', 'worker', ensure_path_exists=False))), + os.path.join(tempfile.gettempdir(), 'lifeblood', 'worker_logs')): + logs_ok = True + try: + if not os.path.exists(self.log_root_path): + 
os.makedirs(self.log_root_path, exist_ok=True) + except PermissionError: + logs_ok = False + except OSError as e: + if e.errno == errno.EACCES: + logs_ok = False + logs_ok = logs_ok and os.access(self.log_root_path, os.W_OK) + if logs_ok: + break + self.__logger.warning(f'could not use location {self.log_root_path} for logs, trying another...') + else: + raise RuntimeError('could not initialize logs directory') + self.__logger.info(f'using {self.log_root_path} for invocation logs') + + self.__status = {} + self.__scheduler_db_uid: int = 0 # unsigned 64bit int + self.__running_process: Optional[asyncio.subprocess.Process] = None + self.__running_process_start_time: float = 0 + self.__running_task: Optional[Invocation] = None + self.__running_task_progress: Optional[float] = None + self.__running_awaiter = None + self.__previous_notrunning_awaiter = None # here we will temporarily save running_awaiter before it is set to None again when task canceled or finished, to avoid task being GCd while in work + self.__message_processor: Optional[TcpCommandMessageProcessor] = None + self.__message_processor_factory = message_processor_factory + self.__worker_invocation_protocol_factory = worker_invocation_protocol_factory + self.__local_invocation_server: Optional[asyncio.Server] = None + self.__local_invocation_server_address_string: str = '' + + self.__local_shared_dir = self.__config.get_option_noasync("local_shared_dir_path", os.path.join(tempfile.gettempdir(), 'lifeblood_worker', 'shared')) + + # resources + config_resources = self.__config.get_option_noasync('resources') + config_devices = self.__config.get_option_noasync('devices') + if config_resources is not None: # schema check + if not isinstance(config_resources, dict): + raise RuntimeError('resources config section must be a mapping') + else: + config_resources = {} + if config_devices is not None: # schema check + if not isinstance(config_devices, dict): + raise RuntimeError('devices config section must be a mapping') + for key, val in config_devices.items(): + if not isinstance(key, str): + raise RuntimeError('devices config (keys) must be strings') + if not isinstance(val, dict): + raise RuntimeError('devices config .<> (values) config section must be a mapping') + for key_name, val_devname_data in val.items(): + if not isinstance(key_name, str): + raise RuntimeError('devices config . (keys) must be strings') + if not isinstance(val_devname_data, dict): + raise RuntimeError('devices config ..<> (values) section must be a mapping') + if 'resources' in val_devname_data: + for key_resname, val_resval in val_devname_data['resources'].items(): + if not isinstance(key_resname, str): + raise RuntimeError('devices config ..resources. (keys) must be strings') + if not isinstance(val_resval, (int, float, str)): + raise RuntimeError('devices config ..resources..<> (values) must be ints, floats or special strings like "32G"') + if 'tags' in val_devname_data: + for key_tagname, val_tagval in val_devname_data['tags'].items(): + if not isinstance(key_tagname, str): + raise RuntimeError('devices config ..tags. 
(keys) must be strings') + if not isinstance(val_tagval, (int, float, str)): + raise RuntimeError('devices config ..tags..<> (values) must be ints, floats or strings') + else: + config_devices = {} + self.__my_resources = HardwareResources( + hwid=get_unique_machine_id() if self.__config.get_option_noasync('worker.override_hwid') is None else self.__config.get_option_noasync('worker.override_hwid'), + resources={ + 'cpu_count': psutil.cpu_count(), + 'cpu_mem': psutil.virtual_memory().total, + **config_resources + }, + devices=[(dev_type, dev_name, dev_stuff.get('resources', {})) for dev_type, dev_dev in config_devices.items() for dev_name, dev_stuff in dev_dev.items()], + ) + self.__my_device_tags = { + dev_type: { + dev_name: {tag_name: tag_val for tag_name, tag_val in dev_stuff.get('tags', {}).items()} for dev_name, dev_stuff in dev_dev.items() + } for dev_type, dev_dev in config_devices.items() + } + + self.__task_changing_state_lock = asyncio.Lock() + self.__task_switching_event = asyncio.Event() # this will signal invocation message waiters to cancel what they are doing + self.__stop_lock = threading.Lock() + self.__start_lock = asyncio.Lock() # can't use a threading lock in async methods - it can yield out and deadlock on itself + self.__where_to_report: Optional[AddressChain] = None + self.__ping_interval = scheduler_ping_interval + self.__ping_missed_threshold = scheduler_ping_miss_threshold + self.__ping_missed = 0 + self.__scheduler_addr = scheduler_addr + self.__scheduler_pinger = None + self.__components_stop_event = asyncio.Event() + self.__extra_files_base_dir = None + self.__my_addr_for_scheduler: Optional[AddressChain] = None + self.__message_address: Optional[DirectAddress] = None + self.__worker_id = worker_id + self.__pool_address: Optional[AddressChain] = pool_address + if self.__worker_id is None and self.__pool_address is not None \ + or self.__worker_id is not None and self.__pool_address is None: + raise RuntimeError('pool_address must be given together with worker_id') + + self.__worker_task_comm_queues: Dict[str, asyncio.Queue] = {} + + self.__worker_type: WorkerType = worker_type + self.__singleshot: bool = singleshot or worker_type == WorkerType.SCHEDULER_HELPER + + self.__child_priority_adjustment = child_priority_adjustment
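The factories stored above are the point of this refactor: WorkerCore keeps only a callable's shape, while the concrete message processor and invocation protocol classes, whose modules are free to import worker machinery, are injected by the Worker subclass. A toy sketch of the same dependency-inversion shape (Core and Processor are illustrative names, not the patch's classes):

    from typing import Callable, Tuple

    class Core:
        def __init__(self, *, processor_factory: Callable[['Core', Tuple[str, int]], 'Processor']):
            # the core never imports the concrete processor class,
            # it only holds a callable that can build one
            self._processor_factory = processor_factory

        def start(self, addr: Tuple[str, int]) -> 'Processor':
            return self._processor_factory(self, addr)

    class Processor:
        # in the real code this lives in a module that may import Core freely
        def __init__(self, core: Core, listen_addr: Tuple[str, int]):
            self.core, self.listen_addr = core, listen_addr

    core = Core(processor_factory=Processor)  # the subclass wires the concrete type
    assert isinstance(core.start(('127.0.0.1', 1384)), Processor)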
+ # this below is a placeholder solution. the easiest way to implement priority lowering without testing on different platforms + if self.__child_priority_adjustment == ProcessPriorityAdjustment.LOWER: + if sys.platform.startswith('win'): + assert hasattr(psutil, 'BELOW_NORMAL_PRIORITY_CLASS') + psutil.Process().nice(psutil.BELOW_NORMAL_PRIORITY_CLASS) + else: + psutil.Process().nice(10) + + # deploy a copy of runtime module somewhere in temp + rtmodule_code = inspect.getsource(lifeblood_connection) + + filepath = os.path.join(tempfile.gettempdir(), 'lifeblood', 'lifeblood_runtime', 'lifeblood_connection.py') + os.makedirs(os.path.dirname(filepath), exist_ok=True) + existing_code = None + if os.path.exists(filepath): + with open(filepath, 'r') as f: + existing_code = f.read() + + if existing_code != rtmodule_code: + with open(filepath, 'w') as f: + f.write(rtmodule_code) + self.__rt_module_dir = os.path.dirname(filepath) + + self.__stopping_waiters = [] + self.__finished = asyncio.Event() + self.__started = False + self.__started_event = asyncio.Event() + self.__stopped = False + + def message_processor(self) -> TcpCommandMessageProcessor: + return self.__message_processor + + def scheduler_message_address(self) -> AddressChain: + return self.__scheduler_addr + + async def start(self): + if self.__started: + return + if self.__stopped: + raise RuntimeError('already stopped, cannot start again') + + async with self.__start_lock: + abort_start = False + + # start local server for invocation api connections + loop = asyncio.get_event_loop() + localhost = get_localhost() + localport_start = 10101 + localport_end = 11111 + localport = None + for _ in range(localport_end - localport_start): # big but finite + localport = random.randint(localport_start, localport_end) + try: + self.__local_invocation_server = await loop.create_server( + lambda: self.__worker_invocation_protocol_factory(self), + localhost, + localport + ) + break + except OSError as e: + if e.errno != errno.EADDRINUSE: + raise + continue + else: + raise RuntimeError('could not find an open port!') + self.__local_invocation_server_address_string = f'{localhost}:{localport}' + + # start message processor + my_ip = get_addr_to(self.__scheduler_addr.split_address()[0]) + my_port = default_worker_start_port() + for i in range(1024): # big but finite + try: + self.__message_processor = self.__message_processor_factory(self, (my_ip, my_port)) + await self.__message_processor.start() + break + except OSError as e: + if e.errno != errno.EADDRINUSE: + raise + my_port += 1 + continue + else: + raise RuntimeError('could not find an open port!') + + self.__message_address = DirectAddress.from_host_port(my_ip, my_port) + + # now report our address to the scheduler + metadata = WorkerMetadata(get_hostname()) + try: + with SchedulerWorkerControlClient.get_scheduler_control_client(self.__scheduler_addr, self.__message_processor) as client: # type: SchedulerWorkerControlClient + # re-normalize addresses + self.__scheduler_addr, self.__my_addr_for_scheduler = await client.get_normalized_addresses() + self.__scheduler_db_uid = await client.say_hello(self.__my_addr_for_scheduler, self.__worker_type, self.__my_resources, metadata) + except MessageTransferError as e: + self.__logger.error('error connecting to scheduler during start') + abort_start = True
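Both loops in start() above probe ports the same way: attempt to bind, treat EADDRINUSE as "try the next port", and re-raise anything else. The same pattern as a self-contained sketch, with a raw socket standing in for the real servers (bind_first_free is an illustrative helper, not part of the patch):

    import errno
    import socket

    def bind_first_free(host: str, start_port: int, attempts: int) -> socket.socket:
        for port in range(start_port, start_port + attempts):  # big but finite
            sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            try:
                sock.bind((host, port))
                return sock
            except OSError as e:
                sock.close()
                if e.errno != errno.EADDRINUSE:
                    raise  # only "address in use" means "try the next port"
        raise RuntimeError('could not find an open port!')

    s = bind_first_free('127.0.0.1', 10101, 64)
    print('bound to', s.getsockname())
    s.close()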
+ # + # and report to the pool + try: + if self.__worker_id is not None: + assert self.__pool_address is not None + with WorkerPoolControlClient.get_worker_pool_control_client(self.__pool_address, self.__message_processor) as wpclient: # type: WorkerPoolControlClient + await wpclient.report_state(self.__worker_id, WorkerState.IDLE) + except ConnectionError as e: + self.__logger.error('error connecting to worker pool during start') + abort_start = True + + self.__scheduler_pinger = asyncio.create_task(self.scheduler_pinger()) + self.__started = True + self.__started_event.set() + if abort_start: + self.__logger.error('error while starting worker, aborting!') + self.stop() + else: + self.__logger.info('worker started') + + def is_started(self): + return self.__started + + def wait_till_starts(self): # we can await this function cuz it returns a future... + return self.__started_event.wait() + + def stop(self): + async def _send_byebye(): + try: + self.__logger.debug('saying bye to scheduler') + with SchedulerWorkerControlClient.get_scheduler_control_client(self.__scheduler_addr, self.__message_processor) as client: # type: SchedulerWorkerControlClient + await client.say_bye(self.__my_addr_for_scheduler) + except MessageTransferError: # if scheduler or route is down + self.__logger.info('couldn\'t say bye to scheduler as it seems to be down') + except Exception: + self.__logger.exception('couldn\'t say bye to scheduler for unknown reason') + + if not self.__started or self.__stopped: + return + with self.__stop_lock: # NOTE: there is literally no threading in worker, so this is excessive + self.__logger.info('STOPPING WORKER') + self.__components_stop_event.set() + + async def _finalizer(): + await self.__scheduler_pinger # to ensure pinger stops and won't try to contact scheduler any more + await self.cancel_task() # then we cancel task, here we still can report it to the scheduler. + # no new tasks will be picked up cuz __stopped is already set + self.__local_invocation_server.close() + await _send_byebye() # saying bye, don't bother us. 
(some delayed comms may still come through to the __server + await self.__local_invocation_server.wait_closed() + self.__message_processor.stop() + await self.__message_processor.wait_till_stops() + self.__logger.info('message processor stopped') + + self.__stopping_waiters.append(asyncio.create_task(_finalizer())) + self.__finished.set() + self.__stopped = True + + async def wait_till_stops(self): + # if self.__scheduler_pinger is not None: + # #try: + # await self.__scheduler_pinger + # #except asyncio.CancelledError: + # # self.__logger.debug('wait_to_finished: scheduler_pinger was cancelled') + # # #raise + # self.__scheduler_pinger = None + # await self.__server.wait_closed() + await self.__finished.wait() + self.__logger.info('server closed') + await self.__scheduler_pinger + self.__logger.info('pinger closed') + for waiter in self.__stopping_waiters: + await waiter + + def get_log_filepath(self, level, invocation_id: int = None): # TODO: think of a better, more generator-style way of returning logs + if self.__running_task is None and invocation_id is None: + return os.path.join(self.log_root_path, f'db_{self.__scheduler_db_uid:016x}', 'common', level) + else: + return os.path.join(self.log_root_path, f'db_{self.__scheduler_db_uid:016x}', 'invocations', str(invocation_id or self.__running_task.invocation_id()), level) + + async def delete_logs(self, invocation_id: int): + self.__logger.debug(f'removing logs for {invocation_id}') + path = os.path.join(self.log_root_path, f'db_{self.__scheduler_db_uid:016x}', 'invocations', str(invocation_id or self.__running_task.invocation_id())) + await asyncio.get_event_loop().run_in_executor(None, shutil.rmtree, path) # assume that deletion MAY take time, so allow util tasks to be processed while we wait + + async def run_task(self, task: Invocation, report_to: AddressChain): + if self.__stopped: + raise WorkerNotAvailable() + self.__logger.debug(f'locks are {self.__task_changing_state_lock.locked()}') + async with self.__task_changing_state_lock: + self.__logger.debug('run_task: task_change_state locks acquired') + # we must ensure picking up and finishing tasks is in critical section + assert len(task.job_definition().args()) > 0 + if self.__running_process is not None: + raise AlreadyRunning('Task already in progress') + + # prepare logging + self.__logger.info(f'running task {task}') + + # save external files + self.__extra_files_base_dir = None + extra_files_map: Dict[str, str] = {} + if len(task.job_definition().extra_files()) > 0: + self.__extra_files_base_dir = tempfile.mkdtemp(prefix='lifeblood_efs_') # TODO: add base temp dir to config + self.__logger.debug(f'creating extra file temporary dir at {self.__extra_files_base_dir}') + for exfilepath, exfiledata in task.job_definition().extra_files().items(): + self.__logger.info(f'saving extra job file {exfilepath}') + exfilepath_parts = exfilepath.split('/') + tmpfilepath = os.path.join(self.__extra_files_base_dir, *exfilepath_parts) + os.makedirs(os.path.dirname(tmpfilepath), exist_ok=True) + with open(tmpfilepath, 'w' if isinstance(exfiledata, str) else 'wb') as f: + f.write(exfiledata) + extra_files_map[exfilepath] = tmpfilepath + + # check args for extra file references + if len(task.job_definition().extra_files()) > 0: + args = [] + for arg in task.job_definition().args(): + if isinstance(arg, str) and arg.startswith(':/') and arg[2:] in task.job_definition().extra_files(): + args.append(extra_files_map[arg[2:]]) + else: + args.append(arg) + else: + args = task.job_definition().args() + 
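The argument rewriting just above swaps ':/'-prefixed references for the temporary paths the extra files were written to. The rule in isolation, with toy data (remap_args is an illustrative name, not part of the patch):

    from typing import Dict, List, Union

    def remap_args(args: List[Union[str, int]], extra_files_map: Dict[str, str]) -> List[Union[str, int]]:
        out = []
        for arg in args:
            # ':/name' in a job argument means "local path of extra file 'name'"
            if isinstance(arg, str) and arg.startswith(':/') and arg[2:] in extra_files_map:
                out.append(extra_files_map[arg[2:]])
            else:
                out.append(arg)
        return out

    print(remap_args(['hython', ':/script.py', 1], {'script.py': '/tmp/lifeblood_efs_abc/script.py'}))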
+ try: + if task.job_definition().environment_resolver_arguments() is None: + resolver = environment_resolver.get_resolver(self.__config.get_option_noasync('default_env_wrapper.name', 'TrivialEnvironmentResolver')) + resolver_arguments = self.__config.get_option_noasync('default_env_wrapper.arguments', {}) + else: + env_res_args = task.job_definition().environment_resolver_arguments() + resolver = env_res_args.get_resolver() + resolver_arguments = env_res_args.arguments() + except environment_resolver.ResolutionImpossibleError as e: + self.__logger.error(f'cannot run the task: Unable to resolve environment: {str(e)}') + raise + + # TODO: resolver args get_environment() actually does resolution so should be renamed to something like resolve_environment() + # Environment's resolve() actually just expands and merges everything, so naming it "resolve" is misleading next to EnvironmentResolver + + env = copy.deepcopy(task.job_definition().env() or InvocationEnvironment()) + + env.prepend('PYTHONPATH', self.__rt_module_dir) + env['LIFEBLOOD_RUNTIME_IID'] = task.invocation_id() + env['LIFEBLOOD_RUNTIME_TID'] = task.task_id() + env['LIFEBLOOD_RUNTIME_SCHEDULER_ADDR'] = self.__local_invocation_server_address_string + + env['LBDEV_TYPES'] = ','.join({dev_type for dev_type, _, _ in self.__my_resources.devices()}) + for dev_type, dev_name_list in task.resources_to_use().devices.items(): + for i, dev_name in enumerate(dev_name_list): + env[f'LBDEV_TYPE{i}'] = dev_type + env[f'LBDEV_NAME{i}'] = dev_name + env[f'LBDEV_TAGS{i}'] = ','.join(f'{tag_name}={tag_val}' for tag_name, tag_val in self.__my_device_tags.get(dev_type, {}).get(dev_name, {}).items()) + + # we do NOT set all attribs to env - just a frame list can easily hit proc env size limit + for aname, aval in task.job_definition().attributes().items(): + if aname.startswith('_'): # skip attributes starting with _ + continue + # TODO: THINK OF A BETTER LOGIC ! + if isinstance(aval, (str, int, float)): + env[f'LBATTR_{aname}'] = str(aval) + + if self.__extra_files_base_dir is not None: + env['LB_EF_ROOT'] = self.__extra_files_base_dir + try: + #with open(self.get_log_filepath('output', task.invocation_id()), 'a') as stdout: + # with open(self.get_log_filepath('error', task.invocation_id()), 'a') as stderr: + # TODO: proper child process priority adjustment should be done, for now it's implemented in constructor. + self.__running_process_start_time = time.time() + + self.__running_process: asyncio.subprocess.Process = await resolver.create_process( + resolver_arguments, + args, + extra_env=env, + resources_to_use=task.resources_to_use(), + ) + except Exception as e: + self.__logger.exception('task creation failed with error: %s' % (repr(e),)) + raise + + self.__running_task = task + self.__running_awaiter = asyncio.create_task(self._awaiter()) + self.__running_task_progress = 0 + if self.__worker_id is not None: # TODO: gracefully handle connection fails here \/ + assert self.__pool_address is not None + self.__where_to_report = AddressChain.join_address((self.__pool_address, report_to)) + with WorkerPoolControlClient.get_worker_pool_control_client(self.__pool_address, self.__message_processor) as wpclient: # type: WorkerPoolControlClient + await wpclient.report_state(self.__worker_id, WorkerState.BUSY) + else: + self.__where_to_report = report_to + + # TODO: we must keep track of _awaiter, that it's not dead. 
+        #  Either make a global watchdog task
+        #  Or wrap the whole _awaiter in try and catch errors within itself
+
+    # callback awaiter
+    async def _awaiter(self):
+        stdout_path = self.get_log_filepath('output', self.__running_task.invocation_id())
+        stderr_path = self.get_log_filepath('error', self.__running_task.invocation_id())
+        os.makedirs(os.path.dirname(stdout_path), exist_ok=True)
+        os.makedirs(os.path.dirname(stderr_path), exist_ok=True)
+        async with aiofiles.open(stdout_path, 'wb') as stdout:
+            async with aiofiles.open(stderr_path, 'wb') as stderr:
+                async def _flush():
+                    await asyncio.sleep(1)  # flush roughly once a second
+                    await stdout.flush()
+                    await stderr.flush()
+
+                await stdout.write(datetime.datetime.now().strftime('[SYS][%d.%m.%y %H:%M:%S] task initialized\n').encode('UTF-8'))
+
+                progress_reporting_task = None
+                minimum_progress_reporting_interval = await self.__config.get_option('minimum_progress_reporting_interval', 1.0)
+                last_progress_reported_timestamp = time.monotonic() - minimum_progress_reporting_interval
+                last_progress_reported = None
+                last_progress_attempted_to_report = None
+
+                tasks_to_wait = {}
+                try:
+                    rout_task = asyncio.create_task(self.__running_process.stdout.readline())
+                    rerr_task = asyncio.create_task(self.__running_process.stderr.readline())
+                    done_task = asyncio.create_task(self.__running_process.wait())
+                    flush_task = asyncio.create_task(_flush())
+                    tasks_to_wait = {rout_task, rerr_task, done_task, flush_task}
+                    while len(tasks_to_wait) != 0:
+                        done, tasks_to_wait = await asyncio.wait(tasks_to_wait, return_when=asyncio.FIRST_COMPLETED)
+                        if rout_task in done:
+                            buff_line = rout_task.result()
+                            progress = self.__running_task.job_definition().match_stdout_progress(buff_line)
+                            if progress is not None:
+                                self.__running_task_progress = progress
+                            if buff_line != b'':  # an empty line can only happen at eof
+                                await stdout.write(datetime.datetime.now().strftime('[OUT][%H:%M:%S] ').encode('UTF-8') + buff_line)
+                                rout_task = asyncio.create_task(self.__running_process.stdout.readline())
+                                tasks_to_wait.add(rout_task)
+                        if rerr_task in done:
+                            buff_line = rerr_task.result()
+                            progress = self.__running_task.job_definition().match_stderr_progress(buff_line)
+                            if progress is not None:
+                                self.__running_task_progress = progress
+                            if buff_line != b'':  # an empty line can only happen at eof
+                                message = datetime.datetime.now().strftime('[ERR][%H:%M:%S] ').encode('UTF-8') + buff_line
+                                await asyncio.gather(
+                                    stderr.write(message),
+                                    stdout.write(message)
+                                )
+                                rerr_task = asyncio.create_task(self.__running_process.stderr.readline())
+                                tasks_to_wait.add(rerr_task)
+
+                        # check if previous progress reporting task finished
+                        if progress_reporting_task is not None and progress_reporting_task.done():
+                            try:
+                                await progress_reporting_task
+                            except MessageTransferError as e:
+                                self.__logger.warning('failed to report invocation progress cuz of: %s', e)
+                            except Exception as e:
+                                self.__logger.warning('failed to report invocation progress, unexpected error: %s', e)
+                            else:
+                                last_progress_reported = last_progress_attempted_to_report
+                            progress_reporting_task = None
+                            last_progress_reported_timestamp = time.monotonic()
+
+                        # report progress if we can
+                        if last_progress_reported != self.__running_task_progress \
+                                and progress_reporting_task is None \
+                                and time.monotonic() - last_progress_reported_timestamp > minimum_progress_reporting_interval:
+                            progress_reporting_task = asyncio.create_task(self.__helper_report_progress(self.running_invocation().invocation_id(), self.__running_task_progress))
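+                            # what we attempted to report only becomes last_progress_reported once the report task succeeds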
+                            last_progress_attempted_to_report = self.__running_task_progress
+
+                        if flush_task in done and not done_task.done():
+                            flush_task = asyncio.create_task(_flush())
+                            tasks_to_wait.add(flush_task)
+                    await stdout.write(datetime.datetime.now().strftime('[SYS][%d.%m.%y %H:%M:%S] task finished\n').encode('UTF-8'))
+                except asyncio.CancelledError:
+                    self.__logger.debug('task awaiter was cancelled')
+                    for task in tasks_to_wait:
+                        task.cancel()
+                    raise
+                finally:
+                    # safer to wait for the existing progress reporting task than to cancel it,
+                    # as cancelling may disrupt the network protocol and cause a timeout waiting on the scheduler side
+                    if progress_reporting_task is not None:
+                        try:
+                            await progress_reporting_task
+                        except MessageTransferError as e:
+                            self.__logger.warning('failed to report invocation progress cuz of: %s', e)
+                        except Exception as e:
+                            self.__logger.warning('failed to report invocation progress, unexpected error: %s', e)
+                        progress_reporting_task = None
+                    # report to the pool
+                    if self.__worker_id is not None:
+                        try:
+                            assert self.__pool_address is not None
+                            with WorkerPoolControlClient.get_worker_pool_control_client(self.__pool_address, self.__message_processor) as wpclient:  # type: WorkerPoolControlClient
+                                await wpclient.report_state(self.__worker_id, WorkerState.IDLE)
+                        except (Exception, asyncio.CancelledError):
+                            self.__logger.error('failed to report task cancellation to worker pool. stopping worker')
+                            self.stop()
+
+        await self.__running_process.wait()
+        await self.task_finished()
+
+    async def __helper_report_progress(self, invocation_id: int, progress: float):
+        with SchedulerWorkerControlClient.get_scheduler_control_client(self.__where_to_report, self.__message_processor) as client:  # type: SchedulerWorkerControlClient
+            await client.report_invocation_progress(invocation_id, progress)
+
+    def is_task_running(self) -> bool:
+        return self.__running_task is not None
+
+    def running_invocation(self) -> Optional[Invocation]:
+        return self.__running_task
+
+    async def deliver_invocation_message(self, destination_invocation_id: int, destination_addressee: str, source_invocation_id: Optional[int], message_body: bytes, addressee_timeout: float = 90.0):
+        """
+        deliver a message to a task
+
+        the idea is to deliver ONLY when the message is being waited for,
+        so queues are added/removed by the receiver, not by this deliver method.
+        the current impl is NOT thread safe: it relies on async scheduling to separate the important regions
+        """
+        while True:
+            # while we wait - the invocation MAY change,
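+            #  so every pass re-checks that we are still delivering to the invocation that is actually running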
+            running_invocation = self.running_invocation()
+            if running_invocation is None or destination_invocation_id != running_invocation.invocation_id():
+                raise InvocationMessageWrongInvocationId()
+
+            while destination_addressee not in self.__worker_task_comm_queues:
+                wait_start_timestamp = time.time()
+                await asyncio.sleep(0.05)  # we MOST LIKELY are already waiting for this, so a timeout occurs
+                addressee_timeout -= time.time() - wait_start_timestamp
+                if addressee_timeout <= 0:
+                    raise InvocationMessageAddresseeTimeout()
+                # important to keep checking if the invocation was changed,
+                # and important to have no awaits (no interruptions) between the check and enqueueing
+                running_invocation = self.running_invocation()
+                if running_invocation is None or destination_invocation_id != running_invocation.invocation_id():
+                    raise InvocationMessageWrongInvocationId()
+
+            queue = self.__worker_task_comm_queues[destination_addressee]
+
+            if not queue.empty():
+                # need to return control to the loop in case 2 deliver_invocation_message calls happen at the same time,
+                # and one is stuck in the upper while loop being satisfied, but the queue is not empty yet
+                await asyncio.sleep(0.01)
+                continue
+            queue.put_nowait((source_invocation_id, message_body))
+            queue.put_nowait(())
+            break
+
+    async def worker_task_addressee_wait(self, addressee: str, timeout: float = 30) -> Tuple[int, bytes]:
+        """
+        wait for a data message to addressee to be delivered
+
+        :returns: sender invocation id, message body
+        """
+        if self.__task_switching_event.is_set():
+            self.__logger.warning('cannot wait for invocation message when task is being cancelled')
+            raise InvocationCancelled()
+
+        # get a ref to the queues, so if it's replaced under us we stay consistent
+        queues = self.__worker_task_comm_queues
+        # TODO: (j) need tests for multiple waits on the SAME addressee at the same time
+        if addressee not in queues:
+            queues[addressee] = asyncio.Queue()
+
+        cancel_event_waiter = asyncio.create_task(self.__task_switching_event.wait())
+        queue_getter = asyncio.create_task(queues[addressee].get())
+        value = None
+        try:
+            done, pend = await asyncio.wait([queue_getter, cancel_event_waiter], timeout=timeout, return_when=asyncio.FIRST_COMPLETED)
+            if queue_getter in done:
+                value = queue_getter.result()
+            else:
+                queue_getter.cancel()
+            if cancel_event_waiter in done:
+                # note - at this point both tasks are done or cancelled
+                raise InvocationCancelled()
+            else:
+                cancel_event_waiter.cancel()
+            # check for timeout
+            if len(done) == 0:
+                raise asyncio.TimeoutError()
+
+            assert value is not None, 'internal logic error, value cannot be None here'
+
+            # value = await asyncio.wait_for(queues[addressee].get(), timeout=timeout)
+            assert queues[addressee].get_nowait() == ()
+            # the empty tuple above ensures one single deliver_invocation_message delivers to one single addressee_wait
+        finally:
+            if queues[addressee].empty():  # TODO: see TODO (j) above
+                queues.pop(addressee)
+
+        return value
+
+    def is_stopping(self) -> bool:
+        """
+        returns True if stop was called on the worker,
+        so the worker is closed or in the process of closing
+        """
+        return self.__stopped
+
+    async def cancel_task(self):
+        async with self.__task_changing_state_lock, event_set_context(self.__task_switching_event):
+            self.__logger.debug('cancel_task: task_change_state locks acquired')
+            if self.__running_process is None:
+                return
+            self.__logger.info('cancelling running task')
+            self.__running_awaiter.cancel()
+            cancelling_awaiter = self.__running_awaiter
+            self.__running_awaiter = None
+
+            await kill_process_tree(self.__running_process)
+            self.__running_task.finish(None, time.time() - self.__running_process_start_time)
+
+            self.__running_process._transport.close()  # the transport is sometimes not closed straight away; ON EXIT that may cause exceptions in __del__ saying the event loop is closed
+
+            # report to scheduler that cancel was a success
+            self.__logger.info(f'reporting cancel back to {self.__where_to_report}')
+
+            proc_stdout_filepath = self.get_log_filepath('output', self.__running_task.invocation_id())
+            proc_stderr_filepath = self.get_log_filepath('error', self.__running_task.invocation_id())
+
+            # we want to append the worker's message that the job was killed
+            try:
+                message = datetime.datetime.now().strftime('\n[WORKER][%d.%m.%y %H:%M:%S] ').encode('UTF-8') + b'killed by worker.\n'
+                async with aiofiles.open(proc_stdout_filepath, 'ab') as stdout, \
+                        aiofiles.open(proc_stderr_filepath, 'ab') as stderr:
+                    await asyncio.gather(
+                        stderr.write(message),
+                        stdout.write(message)
+                    )
+            except Exception as e:
+                self.__logger.warning("failed to append worker message to the logs")
+
+            try:
+                with SchedulerWorkerControlClient.get_scheduler_control_client(self.__where_to_report, self.__message_processor) as client:  # type: SchedulerWorkerControlClient
+                    await client.report_task_canceled(self.__running_task,
+                                                      proc_stdout_filepath,
+                                                      proc_stderr_filepath)
+            except Exception as e:
+                self.__logger.exception(f'could not report cuz of {e}')
+            except:
+                self.__logger.exception('could not report cuz i have no idea')
+            # end reporting
+
+            try:
+                await self.delete_logs(self.__running_task.invocation_id())
+            except OSError:
+                self.__logger.exception("failed to delete logs, ignoring")
+
+            self.__running_task = None
+            self.__worker_task_comm_queues = {}
+            self.__running_process = None
+            self.__where_to_report = None
+            self.__running_task_progress = None
+            await self._cleanup_extra_files()
+
+            await asyncio.wait((cancelling_awaiter,))  # ensure everything is done before we proceed
+
+            # stop ourselves if we are a small task helper
+            if self.__singleshot:
+                self.stop()
+
+    def task_status(self) -> Optional[float]:
+        return self.__running_task_progress
+
+    async def task_finished(self):
+        """
+        is called when the current process finishes
+        :return:
+        """
+        async with self.__task_changing_state_lock, event_set_context(self.__task_switching_event):
+            self.__logger.debug('task_finished: task_change_state locks acquired')
+            if self.__running_process is None:
+                self.__logger.warning('task_finished called, but there is no running task. This can only normally happen if a task_cancel happened at the same moment as finish.')
+                return
+            self.__logger.info('task finished')
+            process_exit_code = await self.__running_process.wait()
+            self.__running_task.finish(process_exit_code, time.time() - self.__running_process_start_time)
+
+            # report to scheduler
+            self.__logger.info(f'reporting done back to {self.__where_to_report}')
+            try:
+                with SchedulerWorkerControlClient.get_scheduler_control_client(self.__where_to_report, self.__message_processor) as client:  # type: SchedulerWorkerControlClient
+                    await client.report_task_done(self.__running_task,
+                                                  self.get_log_filepath('output', self.__running_task.invocation_id()),
+                                                  self.get_log_filepath('error', self.__running_task.invocation_id()))
+            except Exception as e:
+                self.__logger.exception(f'could not report cuz of {e}')
+            except:
+                self.__logger.exception('could not report cuz i have no idea')
+            # end reporting
+            self.__logger.debug(f'done reporting done back to {self.__where_to_report}')
+
+            try:
+                await self.delete_logs(self.__running_task.invocation_id())
+            except OSError:
+                self.__logger.exception("failed to delete logs, ignoring")
+
+            self.__where_to_report = None
+            self.__running_task = None
+            self.__worker_task_comm_queues = {}
+            self.__running_process = None
+            self.__previous_notrunning_awaiter = self.__running_awaiter  # this is JUST so the task is not GCd
+            self.__running_awaiter = None  # TODO: lol, this function can be called from the awaiter, and if we hang below - the awaiter can be GCd, and it's all fucked
+            self.__running_task_progress = None
+            await self._cleanup_extra_files()
+
+            # stop ourselves if we are a small task helper
+            if self.__singleshot:
+                self.stop()
+
+    async def _cleanup_extra_files(self):
+        """
+        cleanup extra files transferred with the task
+        :return:
+        """
+        if self.__extra_files_base_dir is None:
+            return
+        try:
+            shutil.rmtree(self.__extra_files_base_dir)
+        except:
+            self.__logger.exception('could not cleanup extra files')
+
+    #
+    # simply ping the scheduler once in a while
+    async def scheduler_pinger(self):
+        """
+        ping the scheduler once in a while. if it misses too many pings - close the worker and wait for new broadcasts
+        :return:
+        """
+
+        async def _reintroduce_ourself():
+            for attempt in range(5):
+                self.__logger.debug(f'trying to reintroduce myself, attempt: {attempt + 1}')
+                metadata = WorkerMetadata(get_hostname())
+                try:
+                    with SchedulerWorkerControlClient.get_scheduler_control_client(self.__scheduler_addr, self.__message_processor) as client:  # type: SchedulerWorkerControlClient
+                        assert self.__my_addr_for_scheduler is not None
+                        addr = self.__my_addr_for_scheduler
+                        self.__logger.debug('saying bye')
+                        await client.say_bye(addr)
+                        self.__logger.debug('cancelling task')
+                        await self.cancel_task()
+                        self.__logger.debug('saying hello')
+                        self.__scheduler_db_uid = await client.say_hello(addr, self.__worker_type, self.__my_resources, metadata)
+                        self.__logger.debug('reintroduce done')
+                        break
+                except Exception:
+                    self.__logger.exception('failed to reintroduce myself. sleeping a bit and retrying')
+                    await asyncio.sleep(10)
+            else:  # failed to reintroduce. consider that something is wrong with the network, stop
+                self.__logger.error('failed to reintroduce myself. assuming network problems, exiting')
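+                # give up and stop the worker: per the docstring above, it will then wait for new scheduler broadcasts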
+                self.stop()
+
+        exit_wait = asyncio.create_task(self.__components_stop_event.wait())
+        while True:
+            done, pend = await asyncio.wait((exit_wait, ), timeout=self.__ping_interval, return_when=asyncio.FIRST_COMPLETED)
+            if exit_wait in done:
+                await exit_wait
+                break
+            #await asyncio.sleep(self.__ping_interval)
+            if self.__ping_missed_threshold == 0:
+                continue
+            # Here we are locking to prevent unexpected task state changes while checking for state inconsistencies
+            async with self.__task_changing_state_lock:
+                self.__logger.debug('pinger: task_change_state locks acquired')
+                try:
+                    self.__logger.debug('pinging scheduler')
+                    with SchedulerWorkerControlClient.get_scheduler_control_client(self.__scheduler_addr, self.__message_processor) as client:  # type: SchedulerWorkerControlClient
+                        result = await client.ping(self.__my_addr_for_scheduler)
+                        self.__logger.debug(f'scheduler pinged: sees me as {result}')
+                except MessageTransferError as mte:
+                    self.__logger.error('ping message delivery failed')
+                    result = None
+                except Exception as e:
+                    self.__logger.exception('unexpected exception happened')
+                    result = None
+                task_running = self.is_task_running()
+
+            if result is None:  # this means EOF
+                self.__ping_missed += 1
+                self.__logger.info(f'server ping missed. total misses: {self.__ping_missed}')
+                if self.__ping_missed >= self.__ping_missed_threshold:
+                    # assume the scheduler is down, drop everything and look for another scheduler
+                    self.stop()
+                    return
+
+            if result in (WorkerState.OFF, WorkerState.UNKNOWN):
+                # something is wrong, let's try to reintroduce ourselves.
+                # Note that we can be sure that there cannot be race conditions here:
+                # the pinger always starts working AFTER hello, OR it says hello itself.
+                # and the scheduler will immediately switch the worker state on hello, so a ping coming after a confirmed hello will ALWAYS get the newer state
+                self.__logger.warning(f'scheduler replied it thinks i\'m {result.name}. canceling tasks if any and reintroducing myself')
+                await _reintroduce_ourself()
+            elif result == WorkerState.BUSY and not task_running:
+                # Note: the order is:
+                # - sched sets worker to INVOKING
+                # - sched sends "task"
+                # - worker receives task, sets is_task_running
+                # - worker answers to sched
+                # - sched sets worker to BUSY
+                # and when finished:
+                # - worker reports done           |
+                # - sched sets worker to IDLE     | under __task_changing_state_lock
+                # - worker unsets is_task_running |
+                # so there is no way it can be not task_running AND sched state BUSY.
+                # if it is - it must be an error
+                self.__logger.warning(f'scheduler replied it thinks i\'m BUSY, but i\'m free, so something is inconsistent. resolving by reintroducing myself')
+                await _reintroduce_ourself()
+            elif result == WorkerState.IDLE and task_running:
+                # Note from the scheme above - this is not possible:
+                # the only period where the scheduler can think IDLE while is_task_running is set is the __task_changing_state_lock-ed area,
+                # but we acquired sched state and our is_task_running above inside that __task_changing_state_lock
+                self.__logger.warning(f'scheduler replied it thinks i\'m IDLE, but i\'m doing a task, so something is inconsistent. resolving by reintroducing myself')
+                await _reintroduce_ourself()
+            elif result == WorkerState.ERROR:
+                # currently the only way it can be ERROR is because of a flaky network
+                # ideally here we would check ourselves
+                # but there's nothing to check right now
+                self.__logger.warning('scheduler replied it thinks i\'m ERROR, but i\'m doing fine. probably something is wrong with the network. waiting for scheduler to resolve the problem')
+                # no, we don't reintroduce - error state on the scheduler side just means it won't give us tasks for now,
+                # and since the error is most probably due to the network - it will either resolve itself, or there is no point reintroducing if a connection cannot be established anyway
+            elif result is not None:
+                self.__ping_missed = 0
+
+    def worker_message_address(self) -> DirectAddress:
+        if self.__message_address is None:
+            raise RuntimeError('cannot get listening address of a non started worker')
+
+        return self.__message_address
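
[A minimal usage sketch, not part of the patch: the pair deliver_invocation_message() / worker_task_addressee_wait()
above only hands a message over when a waiter is registered for the addressee, and the extra empty tuple enqueued
after the payload is what lets the waiter assert one delivery per wait. The addressee name and payload below are
illustrative.]

    import asyncio

    async def demo(worker):
        # start waiting first - delivery blocks until the addressee queue exists
        waiter = asyncio.create_task(worker.worker_task_addressee_wait('render_done', timeout=30))
        await asyncio.sleep(0)  # let the waiter register its queue
        await worker.deliver_invocation_message(
            destination_invocation_id=worker.running_invocation().invocation_id(),
            destination_addressee='render_done',
            source_invocation_id=None,
            message_body=b'frame 42 ok',
        )
        src_iid, body = await waiter  # -> (None, b'frame 42 ok')
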
diff --git a/src/lifeblood/worker_invocation_protocol.py b/src/lifeblood/worker_invocation_protocol.py
index 892473f8..59877420 100644
--- a/src/lifeblood/worker_invocation_protocol.py
+++ b/src/lifeblood/worker_invocation_protocol.py
@@ -4,15 +4,14 @@
 from .attribute_serialization import deserialize_attributes
 from .enums import InvocationMessageResult
 from .exceptions import CouldNotNegotiateProtocolVersion, InvocationCancelled
-from .scheduler_message_processor import SchedulerExtraControlClient, SchedulerInvocationMessageClient
+from .scheduler_message_processor_client import SchedulerExtraControlClient, SchedulerInvocationMessageClient
 from .taskspawn import TaskSpawn
 from .net_messages.address_routing import RoutingImpossible
 from . import logging
+from .worker_core import WorkerCore
 
+from typing import Dict, Set, Sequence, Tuple
-from typing import Dict, Set, Sequence, Tuple, TYPE_CHECKING
-if TYPE_CHECKING:
-    from .worker import Worker
 
 
 async def read_string(reader) -> str:
@@ -37,7 +36,7 @@ async def handle(self, reader: asyncio.StreamReader, writer: asyncio.StreamWrite
 
 
 class WorkerInvocationProtocolHandlerV10(ProtocolHandler):
-    def __init__(self, worker: "Worker"):
+    def __init__(self, worker: WorkerCore):
         super().__init__()
         self.__worker = worker
         self.__logger = logging.get_logger(f'worker.invoc_protocol_v{".".join(str(i) for i in self.protocol_version())}')
@@ -124,7 +123,7 @@ async def comm_receive_invocation_message(self, reader: asyncio.StreamReader, wr
 
 
 class WorkerInvocationServerProtocol(asyncio.StreamReaderProtocol):
-    def __init__(self, worker: "Worker", protocol_handlers: Sequence[ProtocolHandler], limit: int = 2 ** 16):
+    def __init__(self, worker: WorkerCore, protocol_handlers: Sequence[ProtocolHandler], limit: int = 2 ** 16):
         self.__logger = logging.get_logger('worker.invoc_protocol')
         self.__timeout = 300.0
         self.__reader = asyncio.StreamReader(limit=limit)
diff --git a/src/lifeblood/worker_messsage_processor.py b/src/lifeblood/worker_message_processor.py
similarity index 70%
rename from src/lifeblood/worker_messsage_processor.py
rename to src/lifeblood/worker_message_processor.py
index 15d6bf85..7e6b110d 100644
--- a/src/lifeblood/worker_messsage_processor.py
+++ b/src/lifeblood/worker_message_processor.py
@@ -1,9 +1,8 @@
 import os
 import asyncio
 import aiofiles
-from contextlib import contextmanager
 from .exceptions import NotEnoughResources, ProcessInitializationError, WorkerNotAvailable, \
-    InvocationMessageWrongInvocationId, InvocationMessageAddresseeTimeout, InvocationMessageError
+    InvocationMessageWrongInvocationId, InvocationMessageAddresseeTimeout
 from .environment_resolver import ResolutionImpossibleError
 from . import logging
 from . import invocationjob
@@ -11,18 +10,16 @@
 from .enums import WorkerPingReply, TaskScheduleStatus, InvocationMessageResult
 from .net_messages.impl.tcp_simple_command_message_processor import TcpCommandMessageProcessor
 from .net_messages.impl.clients import CommandJsonMessageClient
-from .net_messages.address import AddressChain
+from .net_messages.address import AddressChain, DirectAddress
 from .net_messages.messages import Message
 from .net_messages.impl.message_haldlers import CommandMessageHandlerBase
+from .worker_core import WorkerCore
 
-
-from typing import Iterable, Optional, Tuple, TYPE_CHECKING, Union
-if TYPE_CHECKING:
-    from .worker import Worker
+from typing import Iterable, Tuple, Union
 
 
 class WorkerCommandHandler(CommandMessageHandlerBase):
-    def __init__(self, worker: "Worker"):
+    def __init__(self, worker: WorkerCore):
         super().__init__()
         self.__logger = logging.get_logger('worker.message_handler')
         self.__worker = worker
@@ -203,86 +200,15 @@ async def _command_invocation_message(self, args: dict, client: CommandJsonMessa
 
 
 class WorkerMessageProcessor(TcpCommandMessageProcessor):
-    def __init__(self, worker: "Worker", listening_address_or_addresses: Union[Tuple[str, int], Iterable[Tuple[str, int]]], *, backlog=4096, connection_pool_cache_time=300):
+    def __init__(
+            self,
+            worker: WorkerCore,
+            listening_address_or_addresses: Union[Tuple[str, int], Iterable[Tuple[str, int]], DirectAddress, Iterable[DirectAddress]],
+            *,
+            backlog=4096,
+            connection_pool_cache_time=300
+    ):
         super().__init__(listening_address_or_addresses, backlog=backlog, connection_pool_cache_time=connection_pool_cache_time, message_handlers=(WorkerCommandHandler(worker),))
-
-
-#
-# Client
-#
-
-
-class WorkerControlClient:
-    def __init__(self, client: CommandJsonMessageClient):
-        self.__client = client
-
-    @classmethod
-    @contextmanager
-    def get_worker_control_client(cls, worker_address: AddressChain, processor: TcpCommandMessageProcessor) -> "WorkerControlClient":
-        with processor.message_client(worker_address) as message_client:
-            yield WorkerControlClient(message_client)
-
-    async def ping(self) -> Tuple[WorkerPingReply, float]:
-        await self.__client.send_command('ping', {})
-
-        reply_message = await self.__client.receive_message()
-        data_json = await reply_message.message_body_as_json()
-        return WorkerPingReply(data_json['ps']), float(data_json['pv'])
-
-    async def give_task(self, task: invocationjob.Invocation, reply_address: Optional[AddressChain] = None) -> Tuple[TaskScheduleStatus, str, str]:
-        """
-        if reply_address is not given - message source address will be used
-        """
-        await self.__client.send_command('task', {
-            'task': await asyncio.get_event_loop().run_in_executor(None, task.serialize_to_data),
-            'reply_to': str(reply_address) if reply_address else None
-        })
-
-        reply = await (await self.__client.receive_message()).message_body_as_json()
-        return TaskScheduleStatus(reply['status']), reply.get('error_class', ''), reply.get('message', '')
-
-    async def quit_worker(self):
-        await self.__client.send_command('quit', {})
-
-        await self.__client.receive_message()
-
-    async def cancel_task(self) -> None:
-        await self.__client.send_command('drop', {})
-
-        await self.__client.receive_message()
-
-    async def status(self):
-        raise NotImplementedError()
-
-    async def get_log(self, invocation_id) -> Tuple[str, str]:
-        await self.__client.send_command('log', {
-            'invoc_id': invocation_id
-        })
-
-        reply = await (await self.__client.receive_message()).message_body_as_json()
-        return str(reply['stdout']), str(reply['stderr'])
-
-    async def send_invocation_message(self,
-                                      destination_invocation_id: int,
-                                      destination_addressee: str,
-                                      source_invocation_id: Optional[int],
-                                      message_body: bytes,
-                                      addressee_timeout: float,
-                                      overall_timeout: float) -> InvocationMessageResult:
-        """
-        Note that this command, unlike others, does not raise,
-        instead it wraps errors into InvocationMessageResult
-        """
-        await self.__client.send_command('invocation_message', {
-            'dst_invoc_id': destination_invocation_id,
-            'src_invoc_id': source_invocation_id,
-            'addressee': destination_addressee,
-            'addressee_timeout': addressee_timeout,
-            'message_data_raw': message_body.decode('latin1'),
-        })
-
-        reply = await (await self.__client.receive_message(timeout=overall_timeout)).message_body_as_json()
-        return InvocationMessageResult(reply['result'])
diff --git a/src/lifeblood/worker_message_processor_client.py b/src/lifeblood/worker_message_processor_client.py
new file mode 100644
index 00000000..c1cf6b87
--- /dev/null
+++ b/src/lifeblood/worker_message_processor_client.py
@@ -0,0 +1,82 @@
+import asyncio
+from contextlib import contextmanager
+from . import invocationjob
+from .enums import WorkerPingReply, TaskScheduleStatus, InvocationMessageResult
+from .net_messages.impl.tcp_simple_command_message_processor import TcpCommandMessageProcessor
+from .net_messages.impl.clients import CommandJsonMessageClient
+from .net_messages.address import AddressChain
+
+from typing import Optional, Tuple
+
+
+class WorkerControlClient:
+    def __init__(self, client: CommandJsonMessageClient):
+        self.__client = client
+
+    @classmethod
+    @contextmanager
+    def get_worker_control_client(cls, worker_address: AddressChain, processor: TcpCommandMessageProcessor) -> "WorkerControlClient":
+        with processor.message_client(worker_address) as message_client:
+            yield WorkerControlClient(message_client)
+
+    async def ping(self) -> Tuple[WorkerPingReply, float]:
+        await self.__client.send_command('ping', {})
+
+        reply_message = await self.__client.receive_message()
+        data_json = await reply_message.message_body_as_json()
+        return WorkerPingReply(data_json['ps']), float(data_json['pv'])
+
+    async def give_task(self, task: invocationjob.Invocation, reply_address: Optional[AddressChain] = None) -> Tuple[TaskScheduleStatus, str, str]:
+        """
+        if reply_address is not given - message source address will be used
+        """
+        await self.__client.send_command('task', {
+            'task': await asyncio.get_event_loop().run_in_executor(None, task.serialize_to_data),
+            'reply_to': str(reply_address) if reply_address else None
+        })
+
+        reply = await (await self.__client.receive_message()).message_body_as_json()
+        return TaskScheduleStatus(reply['status']), reply.get('error_class', ''), reply.get('message', '')
+
+    async def quit_worker(self):
+        await self.__client.send_command('quit', {})
+
+        await self.__client.receive_message()
+
+    async def cancel_task(self) -> None:
+        await self.__client.send_command('drop', {})
+
+        await self.__client.receive_message()
+
+    async def status(self):
+        raise NotImplementedError()
+
+    async def get_log(self, invocation_id) -> Tuple[str, str]:
+        await self.__client.send_command('log', {
+            'invoc_id': invocation_id
+        })
+
+        reply = await (await self.__client.receive_message()).message_body_as_json()
+        return str(reply['stdout']), str(reply['stderr'])
+
+    async def send_invocation_message(self,
+                                      destination_invocation_id: int,
+                                      destination_addressee: str,
+                                      source_invocation_id: Optional[int],
+                                      message_body: bytes,
+                                      addressee_timeout: float,
+                                      overall_timeout: float) -> InvocationMessageResult:
+        """
+        Note that this command, unlike others, does not raise,
+        instead it wraps errors into InvocationMessageResult
+        """
+        await self.__client.send_command('invocation_message', {
+            'dst_invoc_id': destination_invocation_id,
+            'src_invoc_id': source_invocation_id,
+            'addressee': destination_addressee,
+            'addressee_timeout': addressee_timeout,
+            'message_data_raw': message_body.decode('latin1'),
+        })
+
+        reply = await (await self.__client.receive_message(timeout=overall_timeout)).message_body_as_json()
+        return InvocationMessageResult(reply['result'])

From c43e5b1c60c14d33b80408ad187984f48556cd61 Mon Sep 17 00:00:00 2001
From: pedohorse <13556996+pedohorse@users.noreply.github.com>
Date: Mon, 28 Oct 2024 12:12:40 +0100
Subject: [PATCH 04/10] refactor: break worker_pool<->message processor dep cycle

---
 src/lifeblood/main_workerpool.py            |   4 +-
 src/lifeblood/simple_worker_pool.py         | 150 ++----------------
 src/lifeblood/simple_worker_pool_main.py    | 145 +++++++++++++++++
 .../worker_pool_message_processor.py        |  44 ++---
 .../worker_pool_message_processor_client.py |  24 +++
 src/lifeblood/worker_pool_protocol.py       |   7 +-
 6 files changed, 195 insertions(+), 179 deletions(-)
 create mode 100644 src/lifeblood/simple_worker_pool_main.py
 create mode 100644 src/lifeblood/worker_pool_message_processor_client.py

diff --git a/src/lifeblood/main_workerpool.py b/src/lifeblood/main_workerpool.py
index a37feef4..557b19d8 100644
--- a/src/lifeblood/main_workerpool.py
+++ b/src/lifeblood/main_workerpool.py
@@ -1,7 +1,7 @@
 import sys
 import argparse
 
-from . import simple_worker_pool
+from . import simple_worker_pool_main
 
 
 def main(argv):
@@ -14,7 +14,7 @@ def main(argv):
     opts = parser.parse_args(argv[:1])
     remaining_args = argv[1:]
 
-    known_types = {'simple': simple_worker_pool}
+    known_types = {'simple': simple_worker_pool_main}
     if opts.list:
         print('known pool types:\n' + '\n'.join(f'\t{x}' for x in known_types))
         return
diff --git a/src/lifeblood/simple_worker_pool.py b/src/lifeblood/simple_worker_pool.py
index 52916774..bc9940cd 100644
--- a/src/lifeblood/simple_worker_pool.py
+++ b/src/lifeblood/simple_worker_pool.py
@@ -1,40 +1,24 @@
 import sys
 import errno
-import argparse
 import asyncio
-import signal
 import shutil
 import tempfile
 import time
 import itertools
 from pathlib import Path
 from types import MappingProxyType
-import json
 from .config import get_config, Config
-from .broadcasting import await_broadcast
 from .defaults import message_proxy_port
 from .pulse_checker import PulseChecker
 from .process_utils import create_worker_process, send_stop_signal_to_worker
 from .logging import get_logger
-from .worker_pool_message_processor import WorkerPoolMessageProcessor
 from .nethelpers import get_addr_to, get_localhost
 from .enums import WorkerState, WorkerType, ProcessPriorityAdjustment
-
+from .net_messages.message_processor import MessageProcessorBase
 from .net_messages.address import AddressChain, DirectAddress
 
-from typing import Tuple, Dict, List, Optional
-
-
-async def create_worker_pool(worker_type: WorkerType = WorkerType.STANDARD, *,
-                             minimal_total_to_ensure=0, minimal_idle_to_ensure=0, maximum_total=256,
-                             idle_timeout=10, worker_suspicious_lifetime=4, housekeeping_interval: float = 10,
-                             priority=ProcessPriorityAdjustment.NO_CHANGE, scheduler_address: AddressChain):
-    swp = WorkerPool(worker_type,
-                     minimal_total_to_ensure=minimal_total_to_ensure, minimal_idle_to_ensure=minimal_idle_to_ensure, maximum_total=maximum_total,
-                     idle_timeout=idle_timeout, worker_suspicious_lifetime=worker_suspicious_lifetime, housekeeping_interval=housekeeping_interval, priority=priority, scheduler_address=scheduler_address)
-    await swp.start()
-    return swp
+from typing import Callable, Tuple, Dict, List, Optional
 
 
 class ProcData:
@@ -49,7 +33,7 @@ def __init__(self, process: asyncio.subprocess.Process, id: int):
         self.sent_term_signal = False
 
 
-class WorkerPool:  # TODO: split base class, make this just one of implementations
+class SimpleWorkerPool:  # TODO: split base class, make this just one of implementations
     def __init__(self, worker_type: WorkerType = WorkerType.STANDARD, *,
                  minimal_total_to_ensure=0, minimal_idle_to_ensure=0, maximum_total=256,
                  idle_timeout=10, worker_suspicious_lifetime=4, housekeeping_interval: float = 10,
@@ -57,19 +41,21 @@ def __init__(self, worker_type: WorkerType = WorkerType.STANDARD, *,
                  scheduler_address: AddressChain,
                  message_proxy_address: Optional[Tuple[Optional[str], Optional[int]]] = None,
                  config: Optional[Config] = None,
+                 message_processor_factory: Callable[["SimpleWorkerPool", List[Tuple[str, int]]], MessageProcessorBase],
                  ):
         """
         manages a pool of workers.
 
         :param worker_type: workers are created of given type
-        :param minimal_total_to_ensure: at minimum this amount of workers will be always upheld
-        :param minimal_idle_to_ensure: at minimum this amount of IDLE or OFF(as we assume they are OFF only while they are booting up) workers will be always upheld
+        :param minimal_total_to_ensure: at minimum this amount of workers will always be upheld
+        :param minimal_idle_to_ensure: at minimum this amount of IDLE or OFF(as we assume they are OFF only while they are booting up) workers will always be upheld
         :param scheduler_address: force created workers to use this scheduler address
         """
         # local helper workers' pool
         self.__worker_pool: Dict[asyncio.Future, ProcData] = {}
         self.__workers_to_merge: List[ProcData] = []
         self.__pool_task = None
-        self.__message_proxy: Optional[WorkerPoolMessageProcessor] = None
+        self.__message_proxy: Optional[MessageProcessorBase] = None
+        self.__message_processor_factory = message_processor_factory
         self.__stop_event = asyncio.Event()
         self.__server_closer_waiter = None
         self.__poke_event = asyncio.Event()
@@ -121,7 +107,7 @@ async def start(self):
         else:
             proxy_addresses = (proxy_addr,)
         for i in range(1024):  # somewhat big, but not too big
-            self.__message_proxy = WorkerPoolMessageProcessor(self, [(addr, proxy_port) for addr in proxy_addresses])  # TODO: config for other arguments
+            self.__message_proxy = self.__message_processor_factory(self, [(addr, proxy_port) for addr in proxy_addresses])  # TODO: config for other arguments
             try:
                 await self.__message_proxy.start()
                 break
@@ -390,121 +376,3 @@ async def _worker_state_change(self, worker_id: int, state: WorkerState):
         self.__id_to_procdata[worker_id].state = state
         self.__id_to_procdata[worker_id].state_entering_time = time.time()
         self.__poke_event.set()
-
-
-async def async_main(argv):
-    logger = get_logger('simple_worker_pool')
-    parser = argparse.ArgumentParser('lifeblood pool simple')
-    parser.add_argument('--min-idle', '-m',
-                        dest='minimal_idle_to_ensure',
-                        default=1, type=int,
-                        help='worker pool will ensure at least this amount of workers is up idle (default=1)')
-    parser.add_argument('--min-total',
-                        dest='minimal_total_to_ensure',
-                        default=0, type=int,
-                        help='worker pool will ensure at least this amount of workers is up total (default=0)')
-    parser.add_argument('--max', '-M',
-                        dest='maximum_total',
-                        default=256, type=int,
-                        help='no more than this amount of workers will be run locally at the same time (default=256)')
-    parser.add_argument('--priority', choices=tuple(x.name for x in ProcessPriorityAdjustment), default=ProcessPriorityAdjustment.LOWER.name, help='pass to spawned workers: adjust child process priority')
-
-    opts = parser.parse_args(argv)
-    opts.priority = [x for x in ProcessPriorityAdjustment if x.name == opts.priority][0]  # there MUST be exactly 1 match
-
-    graceful_closer_no_reentry = False
-
-    def graceful_closer(*args):
-        nonlocal graceful_closer_no_reentry
-        if graceful_closer_no_reentry:
-            print('DOUBLE SIGNAL CAUGHT: ALREADY EXITING')
-            return
-        graceful_closer_no_reentry = True
-        logger.info('SIGINT/SIGTERM caught')
-        nonlocal noloop
-        noloop = True
-        stop_event.set()
-        if pool:
-            pool.stop()
-
-    noasync_do_close = False
-
-    def noasync_windows_graceful_closer_event(*args):
-        nonlocal noasync_do_close
-        noasync_do_close = True
-
-    async def windows_graceful_closer():
-        while not noasync_do_close:
-            await asyncio.sleep(1)
-        graceful_closer()
-
-    logger.debug(f'starting {__name__} with: ' + ', '.join(f'{key}={val}' for key, val in opts.__dict__.items()))
-    pool = None
-    noloop = False  # TODO: add arg
-
-    # override event handlers
-    win_signal_waiting_task = None
-    try:
-        asyncio.get_event_loop().add_signal_handler(signal.SIGINT, graceful_closer)
-        asyncio.get_event_loop().add_signal_handler(signal.SIGTERM, graceful_closer)
-    except NotImplementedError:  # solution for windows
-        signal.signal(signal.SIGINT, noasync_windows_graceful_closer_event)
-        signal.signal(signal.SIGBREAK, noasync_windows_graceful_closer_event)
-        win_signal_waiting_task = asyncio.create_task(windows_graceful_closer())
-    #
-
-    stop_event = asyncio.Event()
-    stop_task = asyncio.create_task(stop_event.wait())
-    config = get_config('worker')
-
-    start_attempt_cooldown = 0
-    while True:
-        if await config.get_option('worker.listen_to_broadcast', True):
-            logger.info('listening for scheduler broadcasts...')
-            broadcast_task = asyncio.create_task(await_broadcast('lifeblood_scheduler'))
-            done, _ = await asyncio.wait((broadcast_task, stop_task), return_when=asyncio.FIRST_COMPLETED)
-            if stop_task in done:
-                broadcast_task.cancel()
-                logger.info('broadcast listening cancelled')
-                break
-            assert broadcast_task.done()
-            message = await broadcast_task
-            scheduler_info = json.loads(message)
-            logger.debug('received', scheduler_info)
-            if 'message_address' not in scheduler_info:
-                logger.debug('broadcast does not have "message_address" key, ignoring')
-                continue
-            addr = AddressChain(scheduler_info['message_address'])
-        else:
-            if stop_event.is_set():
-                break
-            logger.info('boradcast listening disabled')
-            start_attempt_cooldown = 10
-            if not config.has_option_noasync('worker.scheduler_address'):
-                raise RuntimeError('worker.scheduler_address config option must be provided')
-            addr = AddressChain(await config.get_option('worker.scheduler_address', None))
-        logger.debug(f'using {addr}')
-
-        try:
-            pool = await create_worker_pool(WorkerType.STANDARD, scheduler_address=addr, **opts.__dict__)
-        except Exception:
-            logger.exception('could not start the pool')
-            await asyncio.sleep(start_attempt_cooldown)
-        else:
-            await pool.wait_till_stops()
-            logger.info('pool quited')
-        if noloop:
-            break
-
-    if win_signal_waiting_task is not None:
-        if not win_signal_waiting_task.done():
-            win_signal_waiting_task.cancel()
-    logger.info('pool loop stopped')
-
-
-def main(argv):
-    try:
-        asyncio.run(async_main(argv))
-    except KeyboardInterrupt:
-        get_logger('simple_worker_pool').warning('SIGINT caught where it wasn\'t supposed to be caught')
-
diff --git a/src/lifeblood/simple_worker_pool_main.py b/src/lifeblood/simple_worker_pool_main.py
new file mode 100644
index 00000000..b09fdf83
--- /dev/null
+++ b/src/lifeblood/simple_worker_pool_main.py
@@ -0,0 +1,145 @@
+import argparse
+import asyncio
+import signal
+import json
+from .config import get_config
+from .broadcasting import await_broadcast
+
+from .logging import get_logger
+
+from .simple_worker_pool import SimpleWorkerPool
+from .worker_pool_message_processor import WorkerPoolMessageProcessor
+from .enums import WorkerType, ProcessPriorityAdjustment
+
+from .net_messages.address import AddressChain
+
+
+async def create_worker_pool(worker_type: WorkerType = WorkerType.STANDARD, *,
+                             minimal_total_to_ensure=0, minimal_idle_to_ensure=0, maximum_total=256,
+                             idle_timeout=10, worker_suspicious_lifetime=4, housekeeping_interval: float = 10,
+                             priority=ProcessPriorityAdjustment.NO_CHANGE, scheduler_address: AddressChain):
+    swp = SimpleWorkerPool(worker_type,
+                           minimal_total_to_ensure=minimal_total_to_ensure, minimal_idle_to_ensure=minimal_idle_to_ensure, maximum_total=maximum_total,
+                           idle_timeout=idle_timeout, worker_suspicious_lifetime=worker_suspicious_lifetime, housekeeping_interval=housekeeping_interval, priority=priority, scheduler_address=scheduler_address,
+                           message_processor_factory=WorkerPoolMessageProcessor,
+                           )
+    return swp
+
+
+async def async_main(argv):
+    logger = get_logger('simple_worker_pool')
+    parser = argparse.ArgumentParser('lifeblood pool simple')
+    parser.add_argument('--min-idle', '-m',
+                        dest='minimal_idle_to_ensure',
+                        default=1, type=int,
+                        help='worker pool will ensure at least this amount of workers is up idle (default=1)')
+    parser.add_argument('--min-total',
+                        dest='minimal_total_to_ensure',
+                        default=0, type=int,
+                        help='worker pool will ensure at least this amount of workers is up total (default=0)')
+    parser.add_argument('--max', '-M',
+                        dest='maximum_total',
+                        default=256, type=int,
+                        help='no more than this amount of workers will be run locally at the same time (default=256)')
+    parser.add_argument('--priority', choices=tuple(x.name for x in ProcessPriorityAdjustment), default=ProcessPriorityAdjustment.LOWER.name, help='pass to spawned workers: adjust child process priority')
+
+    opts = parser.parse_args(argv)
+    opts.priority = [x for x in ProcessPriorityAdjustment if x.name == opts.priority][0]  # there MUST be exactly 1 match
+
+    graceful_closer_no_reentry = False
+
+    def graceful_closer(*args):
+        nonlocal graceful_closer_no_reentry
+        if graceful_closer_no_reentry:
+            print('DOUBLE SIGNAL CAUGHT: ALREADY EXITING')
+            return
+        graceful_closer_no_reentry = True
+        logger.info('SIGINT/SIGTERM caught')
+        nonlocal noloop
+        noloop = True
+        stop_event.set()
+        if pool:
+            pool.stop()
+
+    noasync_do_close = False
+
+    def noasync_windows_graceful_closer_event(*args):
+        nonlocal noasync_do_close
+        noasync_do_close = True
+
+    async def windows_graceful_closer():
+        while not noasync_do_close:
+            await asyncio.sleep(1)
+        graceful_closer()
+
+    logger.debug(f'starting {__name__} with: ' + ', '.join(f'{key}={val}' for key, val in opts.__dict__.items()))
+    pool = None
+    noloop = False  # TODO: add arg
+
+    # override event handlers
+    win_signal_waiting_task = None
+    try:
+        asyncio.get_event_loop().add_signal_handler(signal.SIGINT, graceful_closer)
+        asyncio.get_event_loop().add_signal_handler(signal.SIGTERM, graceful_closer)
+    except NotImplementedError:  # solution for windows
+        signal.signal(signal.SIGINT, noasync_windows_graceful_closer_event)
+        signal.signal(signal.SIGBREAK, noasync_windows_graceful_closer_event)
+        win_signal_waiting_task = asyncio.create_task(windows_graceful_closer())
+    #
+
+    stop_event = asyncio.Event()
+    stop_task = asyncio.create_task(stop_event.wait())
+    config = get_config('worker')
+
+    start_attempt_cooldown = 0
+    while True:
+        if await config.get_option('worker.listen_to_broadcast', True):
+            logger.info('listening for scheduler broadcasts...')
+            broadcast_task = asyncio.create_task(await_broadcast('lifeblood_scheduler'))
+            done, _ = await asyncio.wait((broadcast_task, stop_task), return_when=asyncio.FIRST_COMPLETED)
+            if stop_task in done:
+                broadcast_task.cancel()
+                logger.info('broadcast listening cancelled')
+                break
+            assert broadcast_task.done()
+            message = await broadcast_task
+            scheduler_info = json.loads(message)
+            logger.debug(f'received {scheduler_info}')
+            if 'message_address' not in scheduler_info:
+                logger.debug('broadcast does not have "message_address" key, ignoring')
+                continue
+            addr = AddressChain(scheduler_info['message_address'])
+        else:
+            if stop_event.is_set():
+                break
+            logger.info('broadcast listening disabled')
+            start_attempt_cooldown = 10
+            if not config.has_option_noasync('worker.scheduler_address'):
+                raise RuntimeError('worker.scheduler_address config option must be provided')
+            addr = AddressChain(await config.get_option('worker.scheduler_address', None))
+        logger.debug(f'using {addr}')
+
+        try:
+            pool = await create_worker_pool(WorkerType.STANDARD, scheduler_address=addr, **opts.__dict__)
+            await pool.start()
+        except Exception:
+            logger.exception('could not start the pool')
+            await asyncio.sleep(start_attempt_cooldown)
+        else:
+            await pool.wait_till_stops()
+            logger.info('pool exited')
+        if noloop:
+            break
+
+    if win_signal_waiting_task is not None:
+        if not win_signal_waiting_task.done():
+            win_signal_waiting_task.cancel()
+    logger.info('pool loop stopped')
+
+
+def main(argv):
+    try:
+        asyncio.run(async_main(argv))
+    except KeyboardInterrupt:
+        get_logger('simple_worker_pool').warning('SIGINT caught where it wasn\'t supposed to be caught')
+
diff --git a/src/lifeblood/worker_pool_message_processor.py b/src/lifeblood/worker_pool_message_processor.py
index 08070dd0..d9de4931 100644
--- a/src/lifeblood/worker_pool_message_processor.py
+++ b/src/lifeblood/worker_pool_message_processor.py
@@ -1,19 +1,16 @@
-from contextlib import contextmanager
 from .enums import WorkerState
 from .net_messages.impl.tcp_simple_command_message_processor import TcpCommandMessageProcessor
 from .net_messages.impl.clients import CommandJsonMessageClient
-from .net_messages.address import AddressChain
+from .net_messages.address import DirectAddress
 from .net_messages.messages import Message
 from .net_messages.impl.message_haldlers import CommandMessageHandlerBase
+from .simple_worker_pool import SimpleWorkerPool
 
-
-from typing import Iterable, Tuple, TYPE_CHECKING, Union
-if TYPE_CHECKING:
-    from .simple_worker_pool import WorkerPool
+from typing import Iterable, Tuple, Union
 
 
 class WorkerPoolMessageHandler(CommandMessageHandlerBase):
-    def __init__(self, worker_pool: "WorkerPool"):
+    def __init__(self, worker_pool: SimpleWorkerPool):
         super().__init__()
         self.__worker_pool = worker_pool
@@ -42,32 +39,15 @@ async def _command_state_report(self, args: dict, client: CommandJsonMessageClie
 
 
 class WorkerPoolMessageProcessor(TcpCommandMessageProcessor):
-    def __init__(self, worker_pool: "WorkerPool", listening_address_or_addresses: Union[Tuple[str, int], Iterable[Tuple[str, int]]], *, backlog=4096, connection_pool_cache_time=300):
"WorkerPool", listening_address_or_addresses: Union[Tuple[str, int], Iterable[Tuple[str, int]]], *, backlog=4096, connection_pool_cache_time=300): + def __init__( + self, + worker_pool: SimpleWorkerPool, + listening_address_or_addresses: Union[Tuple[str, int], Iterable[Tuple[str, int]], DirectAddress, Iterable[DirectAddress]], + *, + backlog=4096, + connection_pool_cache_time=300 + ): super().__init__(listening_address_or_addresses, backlog=backlog, connection_pool_cache_time=connection_pool_cache_time, message_handlers=(WorkerPoolMessageHandler(worker_pool),)) - - -# -# Client -# - - -class WorkerPoolControlClient: - def __init__(self, client: CommandJsonMessageClient): - self.__client = client - - @classmethod - @contextmanager - def get_worker_pool_control_client(cls, scheduler_address: AddressChain, processor: TcpCommandMessageProcessor) -> "WorkerPoolControlClient": - with processor.message_client(scheduler_address) as message_client: - yield WorkerPoolControlClient(message_client) - - async def report_state(self, worker_id: int, state: WorkerState): - await self.__client.send_command('worker.state_report', { - 'worker_id': worker_id, - 'state': state.value - }) - reply = await self.__client.receive_message() - assert (await reply.message_body_as_json()).get('ok', False), 'something is not ok' diff --git a/src/lifeblood/worker_pool_message_processor_client.py b/src/lifeblood/worker_pool_message_processor_client.py new file mode 100644 index 00000000..be0a7719 --- /dev/null +++ b/src/lifeblood/worker_pool_message_processor_client.py @@ -0,0 +1,24 @@ +from contextlib import contextmanager +from .enums import WorkerState +from .net_messages.impl.tcp_simple_command_message_processor import TcpCommandMessageProcessor +from .net_messages.impl.clients import CommandJsonMessageClient +from .net_messages.address import AddressChain + + +class WorkerPoolControlClient: + def __init__(self, client: CommandJsonMessageClient): + self.__client = client + + @classmethod + @contextmanager + def get_worker_pool_control_client(cls, scheduler_address: AddressChain, processor: TcpCommandMessageProcessor) -> "WorkerPoolControlClient": + with processor.message_client(scheduler_address) as message_client: + yield WorkerPoolControlClient(message_client) + + async def report_state(self, worker_id: int, state: WorkerState): + await self.__client.send_command('worker.state_report', { + 'worker_id': worker_id, + 'state': state.value + }) + reply = await self.__client.receive_message() + assert (await reply.message_body_as_json()).get('ok', False), 'something is not ok' diff --git a/src/lifeblood/worker_pool_protocol.py b/src/lifeblood/worker_pool_protocol.py index 00ebb1e5..5182d0ff 100644 --- a/src/lifeblood/worker_pool_protocol.py +++ b/src/lifeblood/worker_pool_protocol.py @@ -2,14 +2,13 @@ import struct from .logging import get_logger from .enums import WorkerState +from .simple_worker_pool import SimpleWorkerPool -from typing import TYPE_CHECKING, Optional -if TYPE_CHECKING: - from .simple_worker_pool import WorkerPool +from typing import Optional class WorkerPoolProtocol(asyncio.StreamReaderProtocol): - def __init__(self, worker_pool: "WorkerPool", limit=2 ** 16, logger=None): + def __init__(self, worker_pool: SimpleWorkerPool, limit=2 ** 16, logger=None): self.__logger = logger or get_logger(self.__class__.__name__.lower()) self.__timeout = 60 self.__worker_pool = worker_pool From 697dd86cf95e1675fca34bd7a5853c44bf833ae9 Mon Sep 17 00:00:00 2001 From: pedohorse <13556996+pedohorse@users.noreply.github.com> 

From 697dd86cf95e1675fca34bd7a5853c44bf833ae9 Mon Sep 17 00:00:00 2001
From: pedohorse <13556996+pedohorse@users.noreply.github.com>
Date: Mon, 28 Oct 2024 12:24:27 +0100
Subject: [PATCH 05/10] refactor scheduler, break deps with protocols and message processors

---
 src/lifeblood/scheduler/pinger.py                   |   6 +-
 .../scheduler/scheduler_component_base.py           |   6 +-
 src/lifeblood/scheduler/scheduler_core.py           |   4 +-
 src/lifeblood/scheduler/state_object.py             |   0
 src/lifeblood/scheduler/task_processor.py           |   6 +-
 src/lifeblood/scheduler/ui_state_accessor.py        |   4 +-
 src/lifeblood/scheduler_message_processor.py        | 211 ++----------------
 .../scheduler_message_processor_client.py           | 194 ++++++++++++++++
 src/lifeblood/scheduler_task_protocol.py            |   7 +-
 src/lifeblood/scheduler_ui_protocol.py              |  18 +-
 10 files changed, 234 insertions(+), 222 deletions(-)
 delete mode 100644 src/lifeblood/scheduler/state_object.py
 create mode 100644 src/lifeblood/scheduler_message_processor_client.py

diff --git a/src/lifeblood/scheduler/pinger.py b/src/lifeblood/scheduler/pinger.py
index 0016b187..0af25f19 100644
--- a/src/lifeblood/scheduler/pinger.py
+++ b/src/lifeblood/scheduler/pinger.py
@@ -2,7 +2,7 @@
 import asyncio
 import time
 from .. import logging
-from ..worker_messsage_processor import WorkerControlClient
+from ..worker_message_processor_client import WorkerControlClient
 from ..enums import WorkerState, InvocationState, WorkerPingState, WorkerPingReply
 from .scheduler_component_base import SchedulerComponentBase
 from ..net_messages.address import AddressChain
@@ -11,13 +11,13 @@
 from typing import Any, Optional, TYPE_CHECKING
 if TYPE_CHECKING:  # TODO: maybe separate a subset of scheduler's methods to smth like SchedulerData class, or idunno, for now no obvious way to separate, so having a reference back
-    from .scheduler import Scheduler
+    from .scheduler_core import SchedulerCore
 
 
 class Pinger(SchedulerComponentBase):
     def __init__(
             self,
-            scheduler: "Scheduler",
+            scheduler: "SchedulerCore",
     ):
         super().__init__(scheduler)
         self.__pinger_logger = logging.get_logger('scheduler.worker_pinger')
diff --git a/src/lifeblood/scheduler/scheduler_component_base.py b/src/lifeblood/scheduler/scheduler_component_base.py
index f870e4ef..730bc287 100644
--- a/src/lifeblood/scheduler/scheduler_component_base.py
+++ b/src/lifeblood/scheduler/scheduler_component_base.py
@@ -5,11 +5,11 @@
 from typing import TYPE_CHECKING
 if TYPE_CHECKING:  # TODO: maybe separate a subset of scheduler's methods to smth like SchedulerData class, or idunno, for now no obvious way to separate, so having a reference back
-    from .scheduler import Scheduler
+    from .scheduler_core import SchedulerCore
 
 
 class SchedulerComponentBase(ComponentBase):
-    def __init__(self, scheduler: "Scheduler"):
+    def __init__(self, scheduler: "SchedulerCore"):
         super().__init__()
         self.__stop_event = asyncio.Event()
         self.__main_task = None
@@ -18,7 +18,7 @@ def __init__(self, scheduler: "Scheduler"):
         self.__mode = SchedulerMode.STANDARD
 
     @property
-    def scheduler(self) -> "Scheduler":
+    def scheduler(self) -> "SchedulerCore":
         return self.__scheduler
 
     @property
diff --git a/src/lifeblood/scheduler/scheduler_core.py b/src/lifeblood/scheduler/scheduler_core.py
index 7d9387cd..f4648bca 100644
--- a/src/lifeblood/scheduler/scheduler_core.py
+++ b/src/lifeblood/scheduler/scheduler_core.py
@@ -18,7 +18,7 @@
 from ..invocationjob import Invocation, InvocationJob, Requirements
 from ..environment_resolver import EnvironmentResolverArguments
 from ..broadcasting import create_broadcaster
-from ..simple_worker_pool import WorkerPool
+from ..simple_worker_pool import SimpleWorkerPool
 from ..nethelpers import get_broadcast_addr_for, all_interfaces
 from ..worker_metadata import WorkerMetadata
 from ..taskspawn import TaskSpawn
@@ -472,7 +472,7 @@ async def start(self):
         self.__message_processor = self.__message_processor_factory(self, self.__message_processor_addresses)
         await self.__message_processor.start()
         worker_pool_message_proxy_address = (self.__message_processor_addresses[0].split(':', 1)[0], None)  # use same ip as scheduler's message processor, but default port
-        self.__worker_pool = WorkerPool(WorkerType.SCHEDULER_HELPER,
+        self.__worker_pool = SimpleWorkerPool(WorkerType.SCHEDULER_HELPER,
                                         minimal_idle_to_ensure=self.__worker_pool_helpers_minimal_idle_to_ensure,
                                         scheduler_address=self.server_message_address(DirectAddress(worker_pool_message_proxy_address[0])),
                                         message_proxy_address=worker_pool_message_proxy_address,
diff --git a/src/lifeblood/scheduler/state_object.py b/src/lifeblood/scheduler/state_object.py
deleted file mode 100644
index e69de29b..00000000
diff --git a/src/lifeblood/scheduler/task_processor.py b/src/lifeblood/scheduler/task_processor.py
index 6657107f..b7b4b074 100644
--- a/src/lifeblood/scheduler/task_processor.py
+++ b/src/lifeblood/scheduler/task_processor.py
@@ -10,7 +10,7 @@
 from ..basenode_serialization import FailedToDeserialize
 from ..enums import WorkerState, InvocationState, TaskState, TaskGroupArchivedState, TaskScheduleStatus
 from ..misc import atimeit
-from ..worker_messsage_processor import WorkerControlClient
+from ..worker_message_processor_client import WorkerControlClient
 from ..invocationjob import InvocationJob, InvocationRequirements, Invocation
 from ..environment_resolver import EnvironmentResolverArguments
 from ..nodethings import ProcessingResult
@@ -25,7 +25,7 @@
 from typing import List, Optional, TYPE_CHECKING
 if TYPE_CHECKING:  # TODO: maybe separate a subset of scheduler's methods to smth like SchedulerData class, or idunno, for now no obvious way to separate, so having a reference back
-    from .scheduler import Scheduler
+    from .scheduler_core import SchedulerCore
 
 
 # import tracemalloc
@@ -34,7 +34,7 @@ class TaskProcessor(SchedulerComponentBase):
     def __init__(
             self,
-            scheduler: "Scheduler",
+            scheduler: "SchedulerCore",
     ):
         super().__init__(scheduler)
         self.__logger = logging.get_logger('scheduler.task_processor')
diff --git a/src/lifeblood/scheduler/ui_state_accessor.py b/src/lifeblood/scheduler/ui_state_accessor.py
index 27239ca5..b9224a46 100644
--- a/src/lifeblood/scheduler/ui_state_accessor.py
+++ b/src/lifeblood/scheduler/ui_state_accessor.py
@@ -19,7 +19,7 @@
 from typing import Dict, Iterable, List, Optional, Tuple, TYPE_CHECKING, Set, Union
 if TYPE_CHECKING:  # TODO: maybe separate a subset of scheduler's methods to smth like SchedulerData class, or idunno, for now no obvious way to separate, so having a reference back
-    from .scheduler import Scheduler
+    from .scheduler_core import SchedulerCore
 
 
 class QueueEventType(Enum):
@@ -38,7 +38,7 @@ def is_expired(self) -> bool:
 
 
 class UIStateAccessor(SchedulerComponentBase):
-    def __init__(self, scheduler: "Scheduler"):
+    def __init__(self, scheduler: "SchedulerCore"):
         super().__init__(scheduler)
         self.__logger = get_logger('scheduler.ui_state_accessor')
         self.__data_access = scheduler.data_access
diff --git a/src/lifeblood/scheduler_message_processor.py b/src/lifeblood/scheduler_message_processor.py
index 1202ac0a..c4ba3de6 100644
--- a/src/lifeblood/scheduler_message_processor.py
+++ b/src/lifeblood/scheduler_message_processor.py
@@ -1,29 +1,25 @@
 import asyncio
-import aiofiles
-from contextlib import contextmanager
from . import logging from . import invocationjob from .taskspawn import TaskSpawn from .enums import WorkerState, WorkerType, SpawnStatus, InvocationState, InvocationMessageResult -from .worker_messsage_processor import WorkerControlClient +from .worker_message_processor_client import WorkerControlClient from .hardware_resources import HardwareResources from .worker_metadata import WorkerMetadata from .net_messages.impl.tcp_simple_command_message_processor import TcpCommandMessageProcessor from .net_messages.impl.clients import CommandJsonMessageClient -from .net_messages.address import AddressChain +from .net_messages.address import DirectAddress from .net_messages.messages import Message -from .net_messages.exceptions import MessageTransferTimeoutError, MessageReceiveTimeoutError, MessageTransferError +from .net_messages.exceptions import MessageTransferTimeoutError, MessageTransferError from .net_messages.impl.message_haldlers import CommandMessageHandlerBase +from .scheduler.scheduler_core import SchedulerCore - -from typing import Awaitable, Callable, Dict, Iterable, List, Optional, Set, Tuple, TYPE_CHECKING, Union -if TYPE_CHECKING: - from .scheduler import Scheduler +from typing import Awaitable, Callable, Dict, Iterable, Optional, Tuple, Union class SchedulerCommandHandler(CommandMessageHandlerBase): - def __init__(self, scheduler: "Scheduler"): + def __init__(self, scheduler: SchedulerCore): super().__init__() self.__scheduler = scheduler @@ -222,7 +218,7 @@ async def _command_forward_invocation_message(self, args: dict, client: CommandJ class SchedulerExtraCommandHandler(CommandMessageHandlerBase): - def __init__(self, scheduler: "Scheduler"): + def __init__(self, scheduler: SchedulerCore): super().__init__() self.__scheduler = scheduler @@ -287,194 +283,17 @@ async def comm_update_task_attributes(self, args: dict, client: CommandJsonMessa class SchedulerMessageProcessor(TcpCommandMessageProcessor): - def __init__(self, scheduler: "Scheduler", listening_address_or_addresses: Union[Tuple[str, int], Iterable[Tuple[str, int]]], *, backlog=4096, connection_pool_cache_time=300): + def __init__( + self, + scheduler: SchedulerCore, + listening_address_or_addresses: Union[Tuple[str, int], Iterable[Tuple[str, int]], DirectAddress, Iterable[DirectAddress]], + *, + backlog=4096, + connection_pool_cache_time=300 + ): super().__init__(listening_address_or_addresses, backlog=backlog, connection_pool_cache_time=connection_pool_cache_time, message_handlers=(SchedulerCommandHandler(scheduler), SchedulerExtraCommandHandler(scheduler))) self.__logger = logging.get_logger('scheduler.message_processor') - - -# -# Client -# - - -class SchedulerBaseClient: - def __init__(self, client: CommandJsonMessageClient): - self.__client = client - - @classmethod - @contextmanager - def get_scheduler_control_client(cls, scheduler_address: AddressChain, processor: TcpCommandMessageProcessor) -> "SchedulerBaseClient": - with processor.message_client(scheduler_address) as message_client: - yield SchedulerBaseClient(message_client) - - async def pulse(self): - await self.__client.send_command('pulse', {}) - reply = await self.__client.receive_message() - assert (await reply.message_body_as_json()).get('ok', False), 'something is not ok' - - async def get_normalized_addresses(self) -> Tuple[AddressChain, AddressChain]: - """ - TODO: this should be available to ALL clients/processors - normalized address chain is the one that has all the intermediate addresses - so that reversed address can be used as is to send 
messages back - example of non-normalized address would be - 192.168.0.11:1234|10.0.0.22:2345 - this assumes that target at 192.168.0.11:1234 can send messages to different subnets - while such address is correct, it cannot be used reversed as return address without additional actions - a normalized version of this address is something like - 192.168.0.11:1234|10.0.0.11:1234|10.0.0.22:2345 - - :returns: tuple of normalized addresses of destination message processor, and a reversed address of this client's processor as seen by destination processor - """ - - await self.__client.send_command('what_is_my_address', {}) - reply = await self.__client.receive_message() - reply_body = await reply.message_body_as_json() - assert reply_body.get('ok', False), 'something is not ok' - return reply.message_source(), reply_body['my_address'] - - -class SchedulerWorkerControlClient(SchedulerBaseClient): - def __init__(self, client: CommandJsonMessageClient): - super().__init__(client) - self.__client = client - - @classmethod - @contextmanager - def get_scheduler_control_client(cls, scheduler_address: AddressChain, processor: TcpCommandMessageProcessor) -> "SchedulerWorkerControlClient": - with processor.message_client(scheduler_address) as message_client: - yield SchedulerWorkerControlClient(message_client) - - async def ping(self, addr: AddressChain) -> WorkerState: - await self.__client.send_command('worker.ping', { - 'worker_addr': str(addr) - }) - reply = await self.__client.receive_message() - return WorkerState((await reply.message_body_as_json())['state']) - - async def report_task_done(self, task: invocationjob.Invocation, stdout_file: str, stderr_file: str): - async with aiofiles.open(stdout_file, 'r', errors='replace') as f: - stdout = await f.read() - async with aiofiles.open(stderr_file, 'r', errors='replace') as f: - stderr = await f.read() - await self.__client.send_command('worker.done', { - 'task': await asyncio.get_event_loop().run_in_executor(None, task.serialize_to_data), - 'stdout': stdout, - 'stderr': stderr - }) - reply = await self.__client.receive_message() - assert (await reply.message_body_as_json()).get('ok', False), 'something is not ok' - - async def report_task_canceled(self, task: invocationjob.Invocation, stdout_file: str, stderr_file: str): - async with aiofiles.open(stdout_file, 'r') as f: - stdout = await f.read() - async with aiofiles.open(stderr_file, 'r') as f: - stderr = await f.read() - await self.__client.send_command('worker.dropped', { - 'task': await asyncio.get_event_loop().run_in_executor(None, task.serialize_to_data), - 'stdout': stdout, - 'stderr': stderr - }) - reply = await self.__client.receive_message() - assert (await reply.message_body_as_json()).get('ok', False), 'something is not ok' - - async def report_invocation_progress(self, invocation_id: int, progress: float): - await self.__client.send_command('worker.progress_report', { - 'invocation_id': invocation_id, - 'progress': progress, - }) - reply = await self.__client.receive_message() - assert (await reply.message_body_as_json()).get('ok', False), 'something is not ok' - - async def say_hello(self, address_to_advertise: AddressChain, worker_type: WorkerType, worker_resources: HardwareResources, worker_metadata: WorkerMetadata) -> int: - await self.__client.send_command('worker.hello', { - 'worker_addr': str(address_to_advertise), - 'worker_type': worker_type.value, - 'worker_res': worker_resources.serialize().decode('latin1'), - 'meta_hostname': worker_metadata.hostname, - }) - reply = await 
self.__client.receive_message() - return (await reply.message_body_as_json())['db_uid'] - - async def say_bye(self, address_of_worker: str): - await self.__client.send_command('worker.bye', { - 'worker_addr': str(address_of_worker) - }) - reply = await self.__client.receive_message() - assert (await reply.message_body_as_json()).get('ok', False), 'something is not ok' - - -class SchedulerExtraControlClient(SchedulerBaseClient): - def __init__(self, client: CommandJsonMessageClient): - super().__init__(client) - self.__client = client - - @classmethod - @contextmanager - def get_scheduler_control_client(cls, scheduler_address: AddressChain, processor: TcpCommandMessageProcessor) -> "SchedulerExtraControlClient": - with processor.message_client(scheduler_address) as message_client: - yield SchedulerExtraControlClient(message_client) - - async def spawn(self, task_spawn: TaskSpawn) -> Tuple[SpawnStatus, Optional[int]]: - await self.__client.send_command('spawn', { - 'task': task_spawn.serialize().decode('latin1') - }) - reply = await self.__client.receive_message() - ret_data = await reply.message_body_as_json() - return SpawnStatus(ret_data['status']), ret_data['task_id'] - - async def node_name_to_id(self, name: str) -> List[int]: - await self.__client.send_command('nodenametoid', { - 'name': name - }) - reply = await self.__client.receive_message() - ret_data = await reply.message_body_as_json() - return list(ret_data['node_ids']) - - async def update_task_attributes(self, task_id: int, attribs_to_update: dict, attribs_to_delete: Set[str]): - await self.__client.send_command('tupdateattribs', { - 'task_id': task_id, - 'attribs_to_update': attribs_to_update, - 'attribs_to_delete': list(attribs_to_delete), - }) - reply = await self.__client.receive_message() - assert (await reply.message_body_as_json()).get('ok') - - -class SchedulerInvocationMessageClient: - def __init__(self, client: CommandJsonMessageClient): - self.__client = client - - @classmethod - @contextmanager - def get_scheduler_control_client(cls, scheduler_address: AddressChain, processor: TcpCommandMessageProcessor) -> "SchedulerInvocationMessageClient": - with processor.message_client(scheduler_address) as message_client: - yield SchedulerInvocationMessageClient(message_client) - - async def send_invocation_message(self, - destination_invocation_id: int, - destination_addressee: str, - source_invocation_id: Optional[int], - message_body: bytes, - *, - addressee_timeout: float = 90, - overall_timeout: float = 300) -> InvocationMessageResult: - if overall_timeout < addressee_timeout: - overall_timeout = addressee_timeout - - await self.__client.send_command('forward_invocation_message', { - 'dst_invoc_id': destination_invocation_id, - 'src_invoc_id': source_invocation_id, - 'addressee': destination_addressee, - 'addressee_timeout': addressee_timeout, - 'overall_timeout': overall_timeout, - 'message_data_raw': message_body.decode('latin1'), - }) - try: - return InvocationMessageResult((await (await self.__client.receive_message(timeout=overall_timeout)).message_body_as_json())['result']) - except MessageReceiveTimeoutError: - return InvocationMessageResult.ERROR_DELIVERY_TIMEOUT diff --git a/src/lifeblood/scheduler_message_processor_client.py b/src/lifeblood/scheduler_message_processor_client.py new file mode 100644 index 00000000..72850d49 --- /dev/null +++ b/src/lifeblood/scheduler_message_processor_client.py @@ -0,0 +1,194 @@ +import asyncio + +import aiofiles +from contextlib import contextmanager +from . 
import invocationjob +from .taskspawn import TaskSpawn +from .enums import WorkerState, WorkerType, SpawnStatus, InvocationMessageResult +from .hardware_resources import HardwareResources +from .worker_metadata import WorkerMetadata +from .net_messages.impl.tcp_simple_command_message_processor import TcpCommandMessageProcessor +from .net_messages.impl.clients import CommandJsonMessageClient +from .net_messages.address import AddressChain +from .net_messages.exceptions import MessageReceiveTimeoutError + +from typing import List, Optional, Set, Tuple + + +class SchedulerBaseClient: + def __init__(self, client: CommandJsonMessageClient): + self.__client = client + + @classmethod + @contextmanager + def get_scheduler_control_client(cls, scheduler_address: AddressChain, processor: TcpCommandMessageProcessor) -> "SchedulerBaseClient": + with processor.message_client(scheduler_address) as message_client: + yield SchedulerBaseClient(message_client) + + async def pulse(self): + await self.__client.send_command('pulse', {}) + reply = await self.__client.receive_message() + assert (await reply.message_body_as_json()).get('ok', False), 'something is not ok' + + async def get_normalized_addresses(self) -> Tuple[AddressChain, AddressChain]: + """ + TODO: this should be available to ALL clients/processors + normalized address chain is the one that has all the intermediate addresses + so that reversed address can be used as is to send messages back + example of non-normalized address would be + 192.168.0.11:1234|10.0.0.22:2345 + this assumes that target at 192.168.0.11:1234 can send messages to different subnets + while such address is correct, it cannot be used reversed as return address without additional actions + a normalized version of this address is something like + 192.168.0.11:1234|10.0.0.11:1234|10.0.0.22:2345 + + :returns: tuple of normalized addresses of destination message processor, and a reversed address of this client's processor as seen by destination processor + """ + + await self.__client.send_command('what_is_my_address', {}) + reply = await self.__client.receive_message() + reply_body = await reply.message_body_as_json() + assert reply_body.get('ok', False), 'something is not ok' + return reply.message_source(), reply_body['my_address'] + + +class SchedulerWorkerControlClient(SchedulerBaseClient): + def __init__(self, client: CommandJsonMessageClient): + super().__init__(client) + self.__client = client + + @classmethod + @contextmanager + def get_scheduler_control_client(cls, scheduler_address: AddressChain, processor: TcpCommandMessageProcessor) -> "SchedulerWorkerControlClient": + with processor.message_client(scheduler_address) as message_client: + yield SchedulerWorkerControlClient(message_client) + + async def ping(self, addr: AddressChain) -> WorkerState: + await self.__client.send_command('worker.ping', { + 'worker_addr': str(addr) + }) + reply = await self.__client.receive_message() + return WorkerState((await reply.message_body_as_json())['state']) + + async def report_task_done(self, task: invocationjob.Invocation, stdout_file: str, stderr_file: str): + async with aiofiles.open(stdout_file, 'r', errors='replace') as f: + stdout = await f.read() + async with aiofiles.open(stderr_file, 'r', errors='replace') as f: + stderr = await f.read() + await self.__client.send_command('worker.done', { + 'task': await asyncio.get_event_loop().run_in_executor(None, task.serialize_to_data), + 'stdout': stdout, + 'stderr': stderr + }) + reply = await self.__client.receive_message() + 
assert (await reply.message_body_as_json()).get('ok', False), 'something is not ok' + + async def report_task_canceled(self, task: invocationjob.Invocation, stdout_file: str, stderr_file: str): + async with aiofiles.open(stdout_file, 'r') as f: + stdout = await f.read() + async with aiofiles.open(stderr_file, 'r') as f: + stderr = await f.read() + await self.__client.send_command('worker.dropped', { + 'task': await asyncio.get_event_loop().run_in_executor(None, task.serialize_to_data), + 'stdout': stdout, + 'stderr': stderr + }) + reply = await self.__client.receive_message() + assert (await reply.message_body_as_json()).get('ok', False), 'something is not ok' + + async def report_invocation_progress(self, invocation_id: int, progress: float): + await self.__client.send_command('worker.progress_report', { + 'invocation_id': invocation_id, + 'progress': progress, + }) + reply = await self.__client.receive_message() + assert (await reply.message_body_as_json()).get('ok', False), 'something is not ok' + + async def say_hello(self, address_to_advertise: AddressChain, worker_type: WorkerType, worker_resources: HardwareResources, worker_metadata: WorkerMetadata) -> int: + await self.__client.send_command('worker.hello', { + 'worker_addr': str(address_to_advertise), + 'worker_type': worker_type.value, + 'worker_res': worker_resources.serialize().decode('latin1'), + 'meta_hostname': worker_metadata.hostname, + }) + reply = await self.__client.receive_message() + return (await reply.message_body_as_json())['db_uid'] + + async def say_bye(self, address_of_worker: str): + await self.__client.send_command('worker.bye', { + 'worker_addr': str(address_of_worker) + }) + reply = await self.__client.receive_message() + assert (await reply.message_body_as_json()).get('ok', False), 'something is not ok' + + +class SchedulerExtraControlClient(SchedulerBaseClient): + def __init__(self, client: CommandJsonMessageClient): + super().__init__(client) + self.__client = client + + @classmethod + @contextmanager + def get_scheduler_control_client(cls, scheduler_address: AddressChain, processor: TcpCommandMessageProcessor) -> "SchedulerExtraControlClient": + with processor.message_client(scheduler_address) as message_client: + yield SchedulerExtraControlClient(message_client) + + async def spawn(self, task_spawn: TaskSpawn) -> Tuple[SpawnStatus, Optional[int]]: + await self.__client.send_command('spawn', { + 'task': task_spawn.serialize().decode('latin1') + }) + reply = await self.__client.receive_message() + ret_data = await reply.message_body_as_json() + return SpawnStatus(ret_data['status']), ret_data['task_id'] + + async def node_name_to_id(self, name: str) -> List[int]: + await self.__client.send_command('nodenametoid', { + 'name': name + }) + reply = await self.__client.receive_message() + ret_data = await reply.message_body_as_json() + return list(ret_data['node_ids']) + + async def update_task_attributes(self, task_id: int, attribs_to_update: dict, attribs_to_delete: Set[str]): + await self.__client.send_command('tupdateattribs', { + 'task_id': task_id, + 'attribs_to_update': attribs_to_update, + 'attribs_to_delete': list(attribs_to_delete), + }) + reply = await self.__client.receive_message() + assert (await reply.message_body_as_json()).get('ok') + + +class SchedulerInvocationMessageClient: + def __init__(self, client: CommandJsonMessageClient): + self.__client = client + + @classmethod + @contextmanager + def get_scheduler_control_client(cls, scheduler_address: AddressChain, processor: 
TcpCommandMessageProcessor) -> "SchedulerInvocationMessageClient": + with processor.message_client(scheduler_address) as message_client: + yield SchedulerInvocationMessageClient(message_client) + + async def send_invocation_message(self, + destination_invocation_id: int, + destination_addressee: str, + source_invocation_id: Optional[int], + message_body: bytes, + *, + addressee_timeout: float = 90, + overall_timeout: float = 300) -> InvocationMessageResult: + if overall_timeout < addressee_timeout: + overall_timeout = addressee_timeout + + await self.__client.send_command('forward_invocation_message', { + 'dst_invoc_id': destination_invocation_id, + 'src_invoc_id': source_invocation_id, + 'addressee': destination_addressee, + 'addressee_timeout': addressee_timeout, + 'overall_timeout': overall_timeout, + 'message_data_raw': message_body.decode('latin1'), + }) + try: + return InvocationMessageResult((await (await self.__client.receive_message(timeout=overall_timeout)).message_body_as_json())['result']) + except MessageReceiveTimeoutError: + return InvocationMessageResult.ERROR_DELIVERY_TIMEOUT diff --git a/src/lifeblood/scheduler_task_protocol.py b/src/lifeblood/scheduler_task_protocol.py index 81c15b02..8b5a6648 100644 --- a/src/lifeblood/scheduler_task_protocol.py +++ b/src/lifeblood/scheduler_task_protocol.py @@ -11,14 +11,13 @@ from .enums import WorkerType, SpawnStatus, WorkerState from .hardware_resources import HardwareResources from .worker_metadata import WorkerMetadata +from .scheduler.scheduler_core import SchedulerCore -from typing import TYPE_CHECKING, Optional, Tuple -if TYPE_CHECKING: - from .scheduler import Scheduler +from typing import Optional, Tuple class SchedulerTaskProtocol(asyncio.StreamReaderProtocol): - def __init__(self, scheduler: "Scheduler", limit=2**16): + def __init__(self, scheduler: SchedulerCore, limit=2**16): self.__logger = logging.get_logger('scheduler') self.__timeout = 300.0 self.__reader = asyncio.StreamReader(limit=limit) diff --git a/src/lifeblood/scheduler_ui_protocol.py b/src/lifeblood/scheduler_ui_protocol.py index ef89f626..85252685 100644 --- a/src/lifeblood/scheduler_ui_protocol.py +++ b/src/lifeblood/scheduler_ui_protocol.py @@ -1,13 +1,13 @@ import struct import pickle -import json import asyncio import time from asyncio.exceptions import IncompleteReadError from . 
import logging from .attribute_serialization import serialize_attributes_core, deserialize_attributes_core -from .uidata import NodeUi, Parameter, ParameterLocked, ParameterReadonly, ParameterNotFound, ParameterCannotHaveExpressions -from .ui_protocol_data import NodeGraphStructureData, TaskGroupBatchData, TaskBatchData, WorkerBatchData, UiData, InvocationLogData, IncompleteInvocationLogData +from .node_ui import NodeUi +from .node_parameters import Parameter, ParameterLocked, ParameterReadonly, ParameterNotFound, ParameterCannotHaveExpressions +from .ui_protocol_data import NodeGraphStructureData, TaskGroupBatchData, TaskBatchData, WorkerBatchData, InvocationLogData, IncompleteInvocationLogData from .ui_events import TaskEvent from .enums import NodeParameterType, TaskState, SpawnStatus, TaskGroupArchivedState from .exceptions import NotSubscribedError, DataIntegrityError, UiClientOperationFailed @@ -17,11 +17,11 @@ from .snippets import NodeSnippetData, NodeSnippetDataPlaceholder from .environment_resolver import EnvironmentResolverArguments from .buffered_connection import BufferedConnection +from .basenode import BaseNode +from .scheduler.scheduler_core import SchedulerCore + +from typing import Any, Dict, Iterable, Optional, Tuple, List, Union -from typing import Any, Dict, Iterable, TYPE_CHECKING, Optional, Tuple, List, Union -if TYPE_CHECKING: - from .basenode import BaseNode - from .scheduler import Scheduler def _serialize_attrib_dict(d: dict) -> bytes: @@ -33,9 +33,9 @@ def _deserialize_attrib_dict(data: bytes) -> dict: class SchedulerUiProtocol(asyncio.StreamReaderProtocol): - def __init__(self, scheduler): + def __init__(self, scheduler: SchedulerCore): self.__logger = logging.get_logger('scheduler.uiprotocol') - self.__scheduler: "Scheduler" = scheduler + self.__scheduler: SchedulerCore = scheduler self.__reader = asyncio.StreamReader() self.__timeout = 60.0 self.__saved_references = [] From b7cd8fb78a42048a33dc72364e1940975001bc9f Mon Sep 17 00:00:00 2001 From: pedohorse <13556996+pedohorse@users.noreply.github.com> Date: Mon, 28 Oct 2024 12:25:21 +0100 Subject: [PATCH 06/10] fix type hints --- .../impl/tcp_simple_command_message_processor.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/lifeblood/net_messages/impl/tcp_simple_command_message_processor.py b/src/lifeblood/net_messages/impl/tcp_simple_command_message_processor.py index b80ebe18..6bb54efa 100644 --- a/src/lifeblood/net_messages/impl/tcp_simple_command_message_processor.py +++ b/src/lifeblood/net_messages/impl/tcp_simple_command_message_processor.py @@ -1,4 +1,5 @@ from .clients import JsonMessageClientFactory, CommandJsonMessageClientFactory +from ..address import DirectAddress from ..message_handler import MessageHandlerBase from .tcp_message_processor import TcpMessageProcessor @@ -6,7 +7,7 @@ class TcpJsonMessageProcessor(TcpMessageProcessor): - def __init__(self, listening_address_or_addresses: Union[Tuple[str, int], Iterable[Tuple[str, int]]], *, + def __init__(self, listening_address_or_addresses: Union[Tuple[str, int], Iterable[Tuple[str, int]], DirectAddress, Iterable[DirectAddress]], *, backlog=4096, connection_pool_cache_time=300, message_client_factory: Optional[JsonMessageClientFactory] = None, @@ -19,7 +20,7 @@ def __init__(self, listening_address_or_addresses: Union[Tuple[str, int], Iterab class TcpCommandMessageProcessor(TcpJsonMessageProcessor): - def __init__(self, listening_address_or_addresses: Union[Tuple[str, int], Iterable[Tuple[str, int]]], *, + def 
__init__(self, listening_address_or_addresses: Union[Tuple[str, int], Iterable[Tuple[str, int]], DirectAddress, Iterable[DirectAddress]], *, backlog=4096, connection_pool_cache_time=300, message_handlers: Sequence[MessageHandlerBase] = ()): From 97a34864dfba1103e6cf32326039122f4b81c79b Mon Sep 17 00:00:00 2001 From: pedohorse <13556996+pedohorse@users.noreply.github.com> Date: Mon, 28 Oct 2024 12:26:40 +0100 Subject: [PATCH 07/10] refactor: adjust imports/types according to prev decisions --- src/lifeblood/basenode_serializer_v2.py | 2 +- .../core_nodes/environment_resolver_setter.py | 2 - src/lifeblood/core_nodes/mod_attrib.py | 4 - .../core_nodes/parent_children_waiter.py | 11 +- src/lifeblood/core_nodes/python.py | 11 +- src/lifeblood/core_nodes/rename_attrib.py | 3 - src/lifeblood/core_nodes/split_waiter.py | 8 +- src/lifeblood/core_nodes/switch.py | 4 +- src/lifeblood/core_nodes/test.py | 6 +- src/lifeblood/core_nodes/wait_for_task.py | 3 - src/lifeblood/core_nodes/wedge.py | 1 - src/lifeblood/node_plugin_base.py | 4 +- src/lifeblood/node_type_metadata.py | 2 +- src/lifeblood/pulse_checker.py | 3 +- src/lifeblood/stock_nodes/ffmpeg.py | 3 +- src/lifeblood/stock_nodes/file_watcher.py | 3 +- src/lifeblood/stock_nodes/fileop.py | 2 +- src/lifeblood/stock_nodes/filepattern.py | 2 +- src/lifeblood/stock_nodes/imagemagik.py | 4 +- src/lifeblood/uidata.py | 1462 ----------------- .../integration_common.py | 13 +- src/lifeblood_testing_common/nodes_common.py | 16 +- src/lifeblood_viewer/connection_worker.py | 4 +- .../graphics_items/graphics_items.py | 4 +- .../pretty_items/fancy_items/scene_node.py | 3 +- .../graphics_scene_with_data_controller.py | 3 +- src/lifeblood_viewer/scene_data_controller.py | 2 +- src/lifeblood_viewer/scene_ops.py | 2 +- tests/test_nodeui.py | 5 +- tests/test_scheduler.py | 2 +- tests/test_spawn_tasks_race.py | 6 +- tests/test_worker_pool.py | 10 +- ...ker_restart_double_invocation_edge_case.py | 8 +- 33 files changed, 70 insertions(+), 1548 deletions(-) delete mode 100644 src/lifeblood/uidata.py diff --git a/src/lifeblood/basenode_serializer_v2.py b/src/lifeblood/basenode_serializer_v2.py index f30261f8..603544a3 100644 --- a/src/lifeblood/basenode_serializer_v2.py +++ b/src/lifeblood/basenode_serializer_v2.py @@ -4,7 +4,7 @@ from .basenode_serialization import NodeSerializerBase, IncompatibleDeserializationMethod, FailedToApplyNodeState, FailedToApplyParameters from .basenode import BaseNode from .enums import NodeParameterType -from .uidata import ParameterFullValue +from .node_parameters import ParameterFullValue from typing import Optional, Tuple, Union diff --git a/src/lifeblood/core_nodes/environment_resolver_setter.py b/src/lifeblood/core_nodes/environment_resolver_setter.py index e303d4b1..69624a2f 100644 --- a/src/lifeblood/core_nodes/environment_resolver_setter.py +++ b/src/lifeblood/core_nodes/environment_resolver_setter.py @@ -3,8 +3,6 @@ from lifeblood.processingcontext import ProcessingContext from lifeblood.enums import NodeParameterType from lifeblood.environment_resolver import EnvironmentResolverArguments -from lifeblood.uidata import NodeUi, MultiGroupLayout, Parameter -from lifeblood.node_visualization_classes import NodeColorScheme from typing import Iterable diff --git a/src/lifeblood/core_nodes/mod_attrib.py b/src/lifeblood/core_nodes/mod_attrib.py index 09db58a6..72e57cdc 100644 --- a/src/lifeblood/core_nodes/mod_attrib.py +++ b/src/lifeblood/core_nodes/mod_attrib.py @@ -1,10 +1,6 @@ from lifeblood.node_plugin_base import BaseNode from 
lifeblood.nodethings import ProcessingResult -from lifeblood.taskspawn import TaskSpawn -from lifeblood.exceptions import NodeNotReadyToProcess from lifeblood.enums import NodeParameterType -from lifeblood.uidata import NodeUi, MultiGroupLayout, Parameter -from lifeblood.node_visualization_classes import NodeColorScheme from typing import Iterable diff --git a/src/lifeblood/core_nodes/parent_children_waiter.py b/src/lifeblood/core_nodes/parent_children_waiter.py index 6c3ce740..09eb258b 100644 --- a/src/lifeblood/core_nodes/parent_children_waiter.py +++ b/src/lifeblood/core_nodes/parent_children_waiter.py @@ -1,20 +1,13 @@ import dataclasses from dataclasses import dataclass -from lifeblood.attribute_serialization import deserialize_attributes_core -from lifeblood.node_plugin_base import BaseNode, ProcessingError +from lifeblood.node_plugin_base import BaseNode from lifeblood.nodethings import ProcessingResult -from lifeblood.taskspawn import TaskSpawn from lifeblood.exceptions import NodeNotReadyToProcess from lifeblood.enums import NodeParameterType -from lifeblood.uidata import NodeUi -from lifeblood.processingcontext import ProcessingContext from threading import Lock -from typing import Any, Dict, Iterable, List, Optional, Set, TypedDict , TYPE_CHECKING - -if TYPE_CHECKING: - from lifeblood.scheduler import Scheduler +from typing import Dict, Iterable, List, Set def node_class(): diff --git a/src/lifeblood/core_nodes/python.py b/src/lifeblood/core_nodes/python.py index 44ea60ac..3e456a07 100644 --- a/src/lifeblood/core_nodes/python.py +++ b/src/lifeblood/core_nodes/python.py @@ -1,16 +1,13 @@ import re -import time from lifeblood.node_plugin_base import BaseNodeWithTaskRequirements -from lifeblood.invocationjob import InvocationJob, InvocationEnvironment +from lifeblood.invocationjob import InvocationJob from lifeblood.processingcontext import ProcessingContext from lifeblood.nodethings import ProcessingResult, ProcessingError -from lifeblood.uidata import NodeParameterType +from lifeblood.enums import NodeParameterType + +from typing import Iterable -from types import MappingProxyType -from typing import TYPE_CHECKING, Iterable -if TYPE_CHECKING: - from lifeblood.scheduler import Scheduler def node_class(): diff --git a/src/lifeblood/core_nodes/rename_attrib.py b/src/lifeblood/core_nodes/rename_attrib.py index 82e138de..60a6a2df 100644 --- a/src/lifeblood/core_nodes/rename_attrib.py +++ b/src/lifeblood/core_nodes/rename_attrib.py @@ -1,9 +1,6 @@ from lifeblood.node_plugin_base import BaseNode from lifeblood.nodethings import ProcessingResult, ProcessingError -from lifeblood.taskspawn import TaskSpawn -from lifeblood.exceptions import NodeNotReadyToProcess from lifeblood.enums import NodeParameterType -from lifeblood.uidata import NodeUi from typing import Iterable diff --git a/src/lifeblood/core_nodes/split_waiter.py b/src/lifeblood/core_nodes/split_waiter.py index c0e863f9..501ae4a8 100644 --- a/src/lifeblood/core_nodes/split_waiter.py +++ b/src/lifeblood/core_nodes/split_waiter.py @@ -1,18 +1,12 @@ from dataclasses import dataclass from lifeblood.node_plugin_base import BaseNode from lifeblood.nodethings import ProcessingResult -from lifeblood.taskspawn import TaskSpawn from lifeblood.exceptions import NodeNotReadyToProcess from lifeblood.enums import NodeParameterType -from lifeblood.uidata import NodeUi -from lifeblood.processingcontext import ProcessingContext from threading import Lock -from typing import Dict, TypedDict, Set, Iterable, Optional, Any, TYPE_CHECKING - -if 
TYPE_CHECKING: - from lifeblood.scheduler import Scheduler +from typing import Dict, Set, Iterable, Optional @dataclass diff --git a/src/lifeblood/core_nodes/switch.py b/src/lifeblood/core_nodes/switch.py index 105e0f01..2b39aae5 100644 --- a/src/lifeblood/core_nodes/switch.py +++ b/src/lifeblood/core_nodes/switch.py @@ -1,9 +1,7 @@ from lifeblood.node_plugin_base import BaseNode -from lifeblood.nodethings import ProcessingResult, ProcessingError +from lifeblood.nodethings import ProcessingResult from lifeblood.processingcontext import ProcessingContext from lifeblood.enums import NodeParameterType -from lifeblood.uidata import NodeUi, Parameter, VerticalParametersLayout, ParameterHierarchyItem, ParametersLayoutBase -from lifeblood.node_visualization_classes import NodeColorScheme from typing import Iterable diff --git a/src/lifeblood/core_nodes/test.py b/src/lifeblood/core_nodes/test.py index 34922267..8eaba8e9 100644 --- a/src/lifeblood/core_nodes/test.py +++ b/src/lifeblood/core_nodes/test.py @@ -3,11 +3,9 @@ from lifeblood.node_plugin_base import BaseNode from lifeblood.invocationjob import InvocationJob, InvocationEnvironment from lifeblood.nodethings import ProcessingResult -from lifeblood.uidata import NodeParameterType +from lifeblood.enums import NodeParameterType -from typing import TYPE_CHECKING, Iterable -if TYPE_CHECKING: - from lifeblood.scheduler import Scheduler +from typing import Iterable def node_class(): diff --git a/src/lifeblood/core_nodes/wait_for_task.py b/src/lifeblood/core_nodes/wait_for_task.py index 7ef064ce..4f6a4993 100644 --- a/src/lifeblood/core_nodes/wait_for_task.py +++ b/src/lifeblood/core_nodes/wait_for_task.py @@ -2,11 +2,8 @@ import shlex from lifeblood.node_plugin_base import BaseNode from lifeblood.nodethings import ProcessingResult, ProcessingContext -from lifeblood.taskspawn import TaskSpawn from lifeblood.exceptions import NodeNotReadyToProcess from lifeblood.enums import NodeParameterType -from lifeblood.uidata import NodeUi, MultiGroupLayout, Parameter -from lifeblood.node_visualization_classes import NodeColorScheme from typing import Dict, Iterable, List, Optional, Set diff --git a/src/lifeblood/core_nodes/wedge.py b/src/lifeblood/core_nodes/wedge.py index 30170f52..c9b2a397 100644 --- a/src/lifeblood/core_nodes/wedge.py +++ b/src/lifeblood/core_nodes/wedge.py @@ -1,7 +1,6 @@ from lifeblood.node_plugin_base import BaseNode from lifeblood.nodethings import ProcessingResult, ProcessingError from lifeblood.enums import NodeParameterType -from lifeblood.uidata import NodeUi, MultiGroupLayout, Parameter from typing import Iterable diff --git a/src/lifeblood/node_plugin_base.py b/src/lifeblood/node_plugin_base.py index 9c377e3c..440415f9 100644 --- a/src/lifeblood/node_plugin_base.py +++ b/src/lifeblood/node_plugin_base.py @@ -10,7 +10,7 @@ from .invocationjob import ResourceRequirement, ResourceRequirements from .nodethings import ProcessingResult, ProcessingError # unused import - for easy reexport to plugins from .worker_resource_definition import WorkerResourceDefinition, WorkerResourceDataType, WorkerDeviceTypeDefinition -from .uidata import NodeUi +from .node_ui import NodeUi from .scheduler.scheduler import Scheduler from typing import Dict, Optional, Tuple, Union @@ -142,7 +142,7 @@ def set_parent(self, graph_holder: NodeGraphHolderBase, node_id_in_graph: int): def __apply_requirements(self, task_dict: dict, node_config: dict, result: ProcessingResult): if result.invocation_job is not None: - context = ProcessingContext(self, task_dict, 
node_config) + context = ProcessingContext(self.name(), self.label(), self.get_ui(), task_dict, node_config) raw_groups = context.param_value('__requirements__.worker_groups').strip() reqs = result.invocation_job.requirements() if raw_groups != '': diff --git a/src/lifeblood/node_type_metadata.py b/src/lifeblood/node_type_metadata.py index 9ea75943..4dc593e8 100644 --- a/src/lifeblood/node_type_metadata.py +++ b/src/lifeblood/node_type_metadata.py @@ -2,7 +2,7 @@ from .node_dataprovider_base import NodeDataProvider from .plugin_info import PluginInfo -from typing import Optional, TYPE_CHECKING, Tuple, Set +from typing import Optional, Tuple, Set class NodeTypePluginMetadata: diff --git a/src/lifeblood/pulse_checker.py b/src/lifeblood/pulse_checker.py index 431d8981..7e385de9 100644 --- a/src/lifeblood/pulse_checker.py +++ b/src/lifeblood/pulse_checker.py @@ -1,7 +1,6 @@ import asyncio from .logging import get_logger -#from .scheduler_task_protocol import SchedulerTaskClient -from .scheduler_message_processor import SchedulerWorkerControlClient +from .scheduler_message_processor_client import SchedulerWorkerControlClient from .net_messages.address import AddressChain from .net_messages.message_processor import MessageProcessorBase from .net_messages.exceptions import MessageTransferError diff --git a/src/lifeblood/stock_nodes/ffmpeg.py b/src/lifeblood/stock_nodes/ffmpeg.py index ba7bd9e7..434261fa 100644 --- a/src/lifeblood/stock_nodes/ffmpeg.py +++ b/src/lifeblood/stock_nodes/ffmpeg.py @@ -5,9 +5,8 @@ from lifeblood.node_plugin_base import BaseNode from lifeblood.nodethings import ProcessingResult, ProcessingError -from lifeblood.uidata import NodeParameterType +from lifeblood.enums import NodeParameterType from lifeblood.invocationjob import InvocationJob -from lifeblood.invocationjob import InvocationJob, InvocationEnvironment from typing import Iterable diff --git a/src/lifeblood/stock_nodes/file_watcher.py b/src/lifeblood/stock_nodes/file_watcher.py index 1c861ccf..8d996434 100644 --- a/src/lifeblood/stock_nodes/file_watcher.py +++ b/src/lifeblood/stock_nodes/file_watcher.py @@ -4,10 +4,9 @@ from lifeblood.node_plugin_base import BaseNode from lifeblood.nodethings import ProcessingResult, ProcessingError -from lifeblood.uidata import NodeParameterType from lifeblood.processingcontext import ProcessingContext from lifeblood.invocationjob import InvocationJob, InvocationRequirements -from lifeblood.enums import WorkerType +from lifeblood.enums import WorkerType, NodeParameterType from typing import Iterable diff --git a/src/lifeblood/stock_nodes/fileop.py b/src/lifeblood/stock_nodes/fileop.py index b19a8656..83821a0f 100644 --- a/src/lifeblood/stock_nodes/fileop.py +++ b/src/lifeblood/stock_nodes/fileop.py @@ -4,7 +4,7 @@ from lifeblood.nodethings import ProcessingResult, ProcessingError from lifeblood.invocationjob import InvocationJob from lifeblood.processingcontext import ProcessingContext -from lifeblood.uidata import NodeParameterType +from lifeblood.enums import NodeParameterType from typing import Iterable diff --git a/src/lifeblood/stock_nodes/filepattern.py b/src/lifeblood/stock_nodes/filepattern.py index 0c886bb9..58e56b66 100644 --- a/src/lifeblood/stock_nodes/filepattern.py +++ b/src/lifeblood/stock_nodes/filepattern.py @@ -6,7 +6,7 @@ from lifeblood.nodethings import ProcessingResult, ProcessingError from lifeblood.invocationjob import InvocationJob from lifeblood.processingcontext import ProcessingContext -from lifeblood.uidata import NodeParameterType +from 
lifeblood.enums import NodeParameterType from lifeblood.enums import WorkerType from lifeblood.text import match_pattern diff --git a/src/lifeblood/stock_nodes/imagemagik.py b/src/lifeblood/stock_nodes/imagemagik.py index db3f7110..34fbaa0a 100644 --- a/src/lifeblood/stock_nodes/imagemagik.py +++ b/src/lifeblood/stock_nodes/imagemagik.py @@ -1,13 +1,11 @@ import os import shlex import re -from math import sqrt, floor, ceil from lifeblood.node_plugin_base import BaseNode from lifeblood.nodethings import ProcessingResult, ProcessingError -from lifeblood.uidata import NodeParameterType +from lifeblood.enums import NodeParameterType from lifeblood.invocationjob import InvocationJob -from lifeblood.invocationjob import InvocationJob, InvocationEnvironment from typing import Iterable diff --git a/src/lifeblood/uidata.py b/src/lifeblood/uidata.py deleted file mode 100644 index de73bd98..00000000 --- a/src/lifeblood/uidata.py +++ /dev/null @@ -1,1462 +0,0 @@ -import asyncio -from dataclasses import dataclass -import pickle -import os -import pathlib -import math -from copy import deepcopy -from .enums import NodeParameterType -from .processingcontext import ProcessingContext -from .node_visualization_classes import NodeColorScheme -from .logging import get_logger -import re - -from typing import TYPE_CHECKING, TypedDict, Dict, Any, List, Set, Optional, Tuple, Union, Iterable, FrozenSet, Type, Callable - -if TYPE_CHECKING: - from .basenode import BaseNode - - -class ParameterExpressionError(Exception): - def __init__(self, inner_exception): - self.__inner_exception = inner_exception - - def __str__(self): - return f'ParameterExpressionError: {str(self.__inner_exception)}' - - def inner_expection(self): - return self.__inner_exception - - -class ParameterExpressionCastError(ParameterExpressionError): - """ - represents error with type casting of the expression result - """ - pass - - -class LayoutError(RuntimeError): - pass - - -class LayoutReadonlyError(LayoutError): - pass - - -# if TYPE_CHECKING: -# class Parameter(TypedDict): -# type: NodeParameterType -# value: Any -class ParameterHierarchyItem: - def __init__(self): - self.__parent: Optional["ParameterHierarchyItem"] = None - self.__children: Set["ParameterHierarchyItem"] = set() - - def parent(self) -> Optional["ParameterHierarchyItem"]: - return self.__parent - - def set_parent(self, item: Optional["ParameterHierarchyItem"]): - if self.__parent == item: - return - if self.__parent is not None: - assert self in self.__parent.__children - self.__parent._child_about_to_be_removed(self) - self.__parent.__children.remove(self) - self.__parent = item - if self.__parent is not None: - self.__parent.__children.add(self) - self.__parent._child_added(self) - - def _child_about_to_be_removed(self, child: "ParameterHierarchyItem"): - """ - callback for just before a child is removed - :param child: - :return: - """ - pass - - def _child_added(self, child: "ParameterHierarchyItem"): - """ - callback for just after child is added - :param child: - :return: - """ - pass - - def children(self) -> FrozenSet["ParameterHierarchyItem"]: - return frozenset(self.__children) - - def _children_definition_changed(self, children: Iterable["ParameterHierarchyItem"]): - if self.__parent is not None: - self.__parent._children_definition_changed([self]) - - def _children_appearance_changed(self, children: Iterable["ParameterHierarchyItem"]): - if self.__parent is not None: - self.__parent._children_appearance_changed([self]) - - def _children_value_changed(self, 
children: Iterable["ParameterHierarchyItem"]): - if self.__parent is not None: - self.__parent._children_value_changed([self]) - - def visible(self) -> bool: - return False - - -class ParameterHierarchyLeaf(ParameterHierarchyItem): - def _children_definition_changed(self, children: Iterable["ParameterHierarchyItem"]): - return - - def _children_value_changed(self, children: Iterable["ParameterHierarchyItem"]): - return - - def _children_appearance_changed(self, children: Iterable["ParameterHierarchyItem"]): - return - - def _child_added(self, child: "ParameterHierarchyItem"): - raise RuntimeError('cannot add children to ParameterHierarchyLeaf') - - def _child_about_to_be_removed(self, child: "ParameterHierarchyItem"): - raise RuntimeError('cannot remove children from ParameterHierarchyLeaf') - - -def evaluate_expression(expression, context: Optional[ProcessingContext]): - try: - return eval(expression, - {'os': os, 're': re, 'pathlib': pathlib, 'Path': pathlib.Path, **{k: getattr(math, k) for k in dir(math) if not k.startswith('_')}}, - context.locals() if context is not None else {}) - except Exception as e: - raise ParameterExpressionError(e) from None - - -class Separator(ParameterHierarchyLeaf): - pass - - -class Parameter(ParameterHierarchyLeaf): - __re_expand_pattern = None - __re_escape_backticks_pattern = None - - class DontChange: - pass - - def __init__(self, param_name: str, param_label: Optional[str], param_type: NodeParameterType, param_val: Any, can_have_expression: bool = True, readonly: bool = False, default_value = None): - super(Parameter, self).__init__() - self.__name = param_name - self.__label = param_label - self.__type = param_type - self.__value = None - self.__menu_items: Optional[Dict[str, str]] = None - self.__menu_items_order: List[str] = [] - self.__vis_when = [] - self.__force_hidden = False - self.__is_readonly = False # set it False until the end of constructor - self.__locked = False # same as readonly, but is settable by user - - self.__expression = None - self.__can_have_expressions = can_have_expression - - if Parameter.__re_expand_pattern is None: - Parameter.__re_expand_pattern = re.compile(r'((? str: - return self.__name - - def _set_name(self, name: str): - """ - this should only be called by layout classes - """ - self.__name = name - if self.parent() is not None: - self.parent()._children_definition_changed([self]) - - def label(self) -> Optional[str]: - return self.__label - - def type(self) -> NodeParameterType: - return self.__type - - def unexpanded_value(self, context: Optional[ProcessingContext] = None): # TODO: why context parameter here? 
- return self.__value - - def default_value(self): - """ - note that this value will be unexpanded - - :return: - """ - return self.__default_value - - def value(self, context: Optional[ProcessingContext] = None) -> Any: - """ - returns value of this parameter - :param context: optional dict like locals, for expression evaluations - """ - - if self.__expression is not None: - result = evaluate_expression(self.__expression, context) - # check type and cast - try: - if self.__type == NodeParameterType.INT: - result = int(result) - elif self.__type == NodeParameterType.FLOAT: - result = float(result) - elif self.__type == NodeParameterType.STRING and not isinstance(result, str): - result = str(result) - elif self.__type == NodeParameterType.BOOL: - result = bool(result) - except ValueError: - raise ParameterExpressionCastError(f'could not cast {result} to {self.__type.name}') from None - #check limits - if self.__type in (NodeParameterType.INT, NodeParameterType.FLOAT): - if self.__hard_borders[0] is not None and result < self.__hard_borders[0]: - result = self.__hard_borders[0] - if self.__hard_borders[1] is not None and result > self.__hard_borders[1]: - result = self.__hard_borders[1] - return result - - if self.__type != NodeParameterType.STRING: - return self.__value - - # for string parameters we expand expressions in ``, kinda like bash - parts = self.__re_expand_pattern.split(self.__value) - for i, part in enumerate(parts): - if part.startswith('`') and part.endswith('`'): # expression - parts[i] = str(evaluate_expression(self.__re_escape_backticks_pattern.sub('`', part[1:-1]), context)) - else: - parts[i] = self.__re_escape_backticks_pattern.sub('`', part) - return ''.join(parts) - # return self.__re_expand_pattern.sub(lambda m: str(evaluate_expression(m.group(1), context)), self.__value) - - def set_slider_visualization(self, value_min=DontChange, value_max=DontChange): # type: (Union[int, float], Union[int, float]) -> Parameter - """ - set a visual slider's minimum and maximum - this does nothing to the parameter itself, and it's up to parameter renderer to interpret this data - - :return: self to be chained - """ - if self.__type not in (NodeParameterType.INT, NodeParameterType.FLOAT): - raise ParameterDefinitionError('cannot set limits for parameters of types other than INT and FLOAT') - - if self.__type == NodeParameterType.INT: - value_min = int(value_min) - elif self.__type == NodeParameterType.FLOAT: - value_min = float(value_min) - - if self.__type == NodeParameterType.INT: - value_max = int(value_max) - elif self.__type == NodeParameterType.FLOAT: - value_max = float(value_max) - - self.__display_borders = (value_min, value_max) - return self - - def set_value_limits(self, value_min=DontChange, value_max=DontChange): # type: (Union[int, float, None, Type[DontChange]], Union[int, float, None, Type[DontChange]]) -> Parameter - """ - set minimum and maximum values that parameter will enforce - None means no limit (unset limit) - - :return: self to be chained - """ - if self.__type not in (NodeParameterType.INT, NodeParameterType.FLOAT): - raise ParameterDefinitionError('cannot set limits for parameters of types other than INT and FLOAT') - if value_min == self.DontChange: - value_min = self.__hard_borders[0] - elif value_min is not None: - if self.__type == NodeParameterType.INT: - value_min = int(value_min) - elif self.__type == NodeParameterType.FLOAT: - value_min = float(value_min) - if value_max == self.DontChange: - value_max = self.__hard_borders[1] - elif value_max is 
not None: - if self.__type == NodeParameterType.INT: - value_max = int(value_max) - elif self.__type == NodeParameterType.FLOAT: - value_max = float(value_max) - assert value_min != self.DontChange - assert value_max != self.DontChange - - self.__hard_borders = (value_min, value_max) - if value_min is not None and self.__value < value_min: - self.__value = value_min - if value_max is not None and self.__value > value_max: - self.__value = value_max - return self - - def set_text_multiline(self, syntax_hint: Optional[str] = None): - if self.__type != NodeParameterType.STRING: - raise ParameterDefinitionError('multiline can be only set for string parameters') - self.__string_multiline = True - self.__string_multiline_syntax_hint = syntax_hint - return self - - def is_text_multiline(self): - return self.__string_multiline - - def syntax_hint(self) -> Optional[str]: - """ - may hint an arbitrary string hint to the renderer - it's up to renderer to decide what to do. - common conception is to use language name lowercase, like: python - None means no hint - """ - return self.__string_multiline_syntax_hint - - def display_value_limits(self) -> Tuple[Union[int, float, None], Union[int, float, None]]: - """ - returns a tuple of limits for display purposes. - parameter itself ignores this totally. - it's up to parameter renderer to interpret this info - """ - return self.__display_borders - - def value_limits(self) -> Tuple[Union[int, float, None], Union[int, float, None]]: - """ - returns a tuple of hard limits. - these limits are enforced by the parameter itself - """ - return self.__hard_borders - - def is_readonly(self): - return self.__is_readonly - - def is_locked(self): - return self.__locked - - def set_locked(self, locked: bool): - if locked == self.__locked: - return - self.__locked = locked - if self.parent() is not None: - self.parent()._children_definition_changed([self]) - - def set_value(self, value: Any): - if self.__is_readonly: - raise ParameterReadonly() - if self.__locked: - raise ParameterLocked() - if self.__type == NodeParameterType.FLOAT: - param_value = float(value) - if self.__hard_borders[0] is not None: - param_value = max(param_value, self.__hard_borders[0]) - if self.__hard_borders[1] is not None: - param_value = min(param_value, self.__hard_borders[1]) - elif self.__type == NodeParameterType.INT: - param_value = int(value) - if self.__hard_borders[0] is not None: - param_value = max(param_value, self.__hard_borders[0]) - if self.__hard_borders[1] is not None: - param_value = min(param_value, self.__hard_borders[1]) - elif self.__type == NodeParameterType.BOOL: - param_value = bool(value) - elif self.__type == NodeParameterType.STRING: - param_value = str(value) - else: - raise NotImplementedError() - self.__value = param_value - for other_param in self.__params_referencing_me: - other_param._referencing_param_value_changed(self) - - if self.parent() is not None: - self.parent()._children_value_changed([self]) - - def can_have_expressions(self): - return self.__can_have_expressions - - def has_expression(self): - return self.__expression is not None - - def expression(self): - return self.__expression - - def set_expression(self, expression: Union[str, None]): - """ - sets or removes expression from a parameter - :param expression: either expression code or None means removing expression - :return: - """ - if self.__is_readonly: - raise ParameterReadonly() - if self.__locked: - raise ParameterLocked() - if not self.__can_have_expressions: - raise 
ParameterCannotHaveExpressions() - if expression != self.__expression: - self.__expression = expression - if self.parent() is not None: - self.parent()._children_definition_changed([self]) - - def remove_expression(self): - self.set_expression(None) - - @classmethod - def python_from_expandable_string(cls, expandable_string, context: Optional[ProcessingContext] = None) -> str: - """ - given string value that may contain backtick expressions return python equivalent - """ - expression_parts = [] - parts = cls.__re_expand_pattern.split(expandable_string) - for i, part in enumerate(parts): - if part.startswith('`') and part.endswith('`'): # expression - maybe_expr = f'({cls.__re_escape_backticks_pattern.sub("`", part[1:-1])})' - try: - val = evaluate_expression(maybe_expr, context) - if not isinstance(val, str): - maybe_expr = f'str{maybe_expr}' # note, maybe_expr is already enclosed in parentheses - except ParameterExpressionError as e: - # we just catch syntax errors, other runtime errors are allowed as real context is set per task - if isinstance(e.inner_expection(), SyntaxError): - maybe_expr = '""' - expression_parts.append(maybe_expr) - else: - val = cls.__re_escape_backticks_pattern.sub('`', part) - if not val: - continue - expression_parts.append(repr(val)) - - return ' + '.join(expression_parts) - - def _referencing_param_value_changed(self, other_parameter): - """ - when a parameter that we are referencing changes - it will report here - :param other_parameter: - """ - # TODO: this now only works with referencing param in visibility condition - # TODO: butt we want general references, including from parameter expressions - # TODO: OOOORR will i need references for expressions at all? - # TODO: references between node bring SOOOO much pain when serializing them separately - if self.__vis_when: - self.__vis_cache = None - if self.parent() is not None and isinstance(self.parent(), ParametersLayoutBase): - self.parent()._children_appearance_changed([self]) - - def set_hidden(self, hidden): - self.__force_hidden = hidden - - def visible(self) -> bool: - if self.__force_hidden: - return False - if self.__vis_cache is not None: - return self.__vis_cache - if self.__vis_when: - for other_param, op, value in self.__vis_when: - if op == '==' and other_param.value() != value \ - or op == '!=' and other_param.value() == value \ - or op == '>' and other_param.value() <= value \ - or op == '>=' and other_param.value() < value \ - or op == '<' and other_param.value() >= value \ - or op == '<=' and other_param.value() > value \ - or op == 'in' and other_param.value() not in value \ - or op == 'not in' and other_param.value() in value: - self.__vis_cache = False - return False - self.__vis_cache = True - return True - - def _add_referencing_me(self, other_parameter: "Parameter"): - """ - other_parameter MUST belong to the same node to avoid cross-node references - :param other_parameter: - :return: - """ - assert self.has_same_parent(other_parameter), 'references MUST belong to the same node' - self.__params_referencing_me.add(other_parameter) - - def _remove_referencing_me(self, other_parameter: "Parameter"): - assert other_parameter in self.__params_referencing_me - self.__params_referencing_me.remove(other_parameter) - - def references(self) -> Tuple["Parameter", ...]: - """ - returns tuple of parameters referenced by this parameter's definition - static/dynamic references from expressions ARE NOT INCLUDED - they are not parameter's DEFINITION - currently the only thing that can be a reference is 
parameter from visibility conditions - """ - return tuple(x[0] for x in self.__vis_when) - - def visibility_conditions(self) -> Tuple[Tuple["Parameter", str, Union[bool, int, float, str, tuple]], ...]: - return tuple(self.__vis_when) - - def append_visibility_condition(self, other_param: "Parameter", condition: str, value: Union[bool, int, float, str, tuple]) -> "Parameter": - """ - condition currently can only be a simplest - :param other_param: - :param condition: - :param value: - :return: self to allow easy chaining - """ - allowed_conditions = ('==', '!=', '>=', '<=', '<', '>', 'in', 'not in') - if condition not in allowed_conditions: - raise ParameterDefinitionError(f'condition must be one of: {", ".join(x for x in allowed_conditions)}') - if condition in ('in', 'not in') and not isinstance(value, tuple): - raise ParameterDefinitionError('for in/not in conditions value must be a tuple of possible values') - elif condition not in ('in', 'not in') and isinstance(value, tuple): - raise ParameterDefinitionError('value can be tuple only for in/not in conditions') - - otype = other_param.type() - if otype == NodeParameterType.INT: - if not isinstance(value, tuple): - value = int(value) - elif otype == NodeParameterType.BOOL: - if not isinstance(value, tuple): - value = bool(value) - elif otype == NodeParameterType.FLOAT: - if not isinstance(value, tuple): - value = float(value) - elif otype == NodeParameterType.STRING: - if not isinstance(value, tuple): - value = str(value) - else: # for future - raise ParameterDefinitionError(f'cannot add visibility condition check based on this type of parameters: {otype}') - self.__vis_when.append((other_param, condition, value)) - other_param._add_referencing_me(self) - self.__vis_cache = None - - self.parent()._children_definition_changed([self]) - return self - - def add_menu(self, menu_items_pairs) -> "Parameter": - """ - adds UI menu to parameter param_name - :param menu_items_pairs: dict of label -> value for parameter menu. type of value MUST match type of parameter param_name. 
type of label MUST be string - :return: self to allow easy chaining - """ - # sanity check and regroup - my_type = self.type() - menu_items = {} - menu_order = [] - for key, value in menu_items_pairs: - menu_items[key] = value - menu_order.append(key) - if not isinstance(key, str): - raise ParameterDefinitionError('menu label type must be string') - if my_type == NodeParameterType.INT and not isinstance(value, int): - raise ParameterDefinitionError(f'wrong menu value for int parameter "{self.name()}"') - elif my_type == NodeParameterType.BOOL and not isinstance(value, bool): - raise ParameterDefinitionError(f'wrong menu value for bool parameter "{self.name()}"') - elif my_type == NodeParameterType.FLOAT and not isinstance(value, float): - raise ParameterDefinitionError(f'wrong menu value for float parameter "{self.name()}"') - elif my_type == NodeParameterType.STRING and not isinstance(value, str): - raise ParameterDefinitionError(f'wrong menu value for string parameter "{self.name()}"') - - self.__menu_items = menu_items - self.__menu_items_order = menu_order - self.parent()._children_definition_changed([self]) - return self - - def has_menu(self): - return self.__menu_items is not None - - def get_menu_items(self): - return self.__menu_items_order, self.__menu_items - - def has_same_parent(self, other_parameter: "Parameter") -> bool: - """ - finds if somewhere down the hierarchy there is a shared parent of self and other_parameter - """ - my_ancestry_line = set() - ancestor = self - while ancestor is not None: - my_ancestry_line.add(ancestor) - ancestor = ancestor.parent() - - ancestor = other_parameter - while ancestor is not None: - if ancestor in my_ancestry_line: - return True - ancestor = ancestor.parent() - return False - - def nodeui(self) -> Optional["NodeUi"]: - """ - returns parent nodeui if it is the root of current hierarchy. otherwise returns None - """ - ancestor = self - while ancestor is not None: - if isinstance(ancestor, NodeUi): - return ancestor - ancestor = ancestor.parent() - return None - - def __setstate__(self, state): - """ - overriden for easier parameter class iterations during active development. 
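The __setstate__ override being deleted here along with the rest of uidata.py documents a useful trick: re-running __init__ before applying the pickled state means parameter objects serialized by an older class version still get defaults for attributes added since, which matters because node UI data is pickled into the scheduler database and old blobs outlive class changes. A self-contained sketch of the same pattern (Widget is a hypothetical stand-in, not Lifeblood code):

import pickle


class Widget:
    def __init__(self, name: str = '', size: int = 0):
        self.name = name
        self.size = size
        self.color = 'gray'  # attribute added after old objects were pickled

    def __setstate__(self, state):
        # re-run __init__ first so newly introduced attributes exist with
        # defaults, then overlay whatever the old pickle actually stored
        self.__init__()
        self.__dict__.update(state)


w_old = Widget('node', 3)
del w_old.color               # simulate a pickle made before 'color' existed
blob = pickle.dumps(w_old)
w = pickle.loads(blob)
print(w.name, w.size, w.color)  # -> node 3 gray: default filled in on load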
-        otherwise all node ui data should be recreated from zero in DB every time a change is made
-        """
-        # this init here is only to init newly added fields when unpickling old parameters without resetting the DB every time
-        self.__init__('', '', NodeParameterType.INT, 0, False)
-        self.__dict__.update(state)
-
-
-class ParameterError(RuntimeError):
-    pass
-
-
-class ParameterDefinitionError(ParameterError):
-    pass
-
-
-class ParameterNotFound(ParameterError):
-    pass
-
-
-class ParameterNameCollisionError(ParameterError):
-    pass
-
-
-class ParameterReadonly(ParameterError):
-    pass
-
-
-class ParameterLocked(ParameterError):
-    pass
-
-
-class ParameterCannotHaveExpressions(ParameterError):
-    pass
-
-
-class ParametersLayoutBase(ParameterHierarchyItem):
-    def __init__(self):
-        super(ParametersLayoutBase, self).__init__()
-        self.__parameters: Dict[str, Parameter] = {}  # just for quicker access
-        self.__layouts: Set[ParametersLayoutBase] = set()
-        self.__block_ui_callbacks = False
-
-    def initializing_interface_lock(self):
-        return self.block_ui_callbacks()
-
-    def block_ui_callbacks(self):
-        class _iiLock:
-            def __init__(self, lockable):
-                self.__nui = lockable
-                self.__prev_state = False
-
-            def __enter__(self):
-                self.__prev_state = self.__nui._ParametersLayoutBase__block_ui_callbacks
-                self.__nui._ParametersLayoutBase__block_ui_callbacks = True
-
-            def __exit__(self, exc_type, exc_val, exc_tb):
-                self.__nui._ParametersLayoutBase__block_ui_callbacks = self.__prev_state
-
-        return _iiLock(self)
-
-    def _is_initialize_lock_set(self):
-        return self.__block_ui_callbacks
-
-    def add_parameter(self, new_parameter: Parameter):
-        self.add_generic_leaf(new_parameter)
-
-    def add_generic_leaf(self, item: ParameterHierarchyLeaf):
-        if not self._is_initialize_lock_set():
-            raise LayoutError('initializing interface not inside initializing_interface_lock')
-        item.set_parent(self)
-
-    def add_layout(self, new_layout: "ParametersLayoutBase"):
-        if not self._is_initialize_lock_set():
-            raise LayoutError('initializing interface not inside initializing_interface_lock')
-        new_layout.set_parent(self)
-
-    def items(self, recursive=False) -> Iterable["ParameterHierarchyItem"]:
-        for child in self.children():
-            yield child
-            if not recursive:
-                continue
-            elif isinstance(child, ParametersLayoutBase):
-                for child_param in child.parameters(recursive=recursive):
-                    yield child_param
-
-    def parameters(self, recursive=False) -> Iterable[Parameter]:
-        for item in self.items(recursive=recursive):
-            if isinstance(item, Parameter):
-                yield item
-
-    def parameter(self, name: str) -> Parameter:
-        if name in self.__parameters:
-            return self.__parameters[name]
-        for layout in self.__layouts:
-            try:
-                return layout.parameter(name)
-            except ParameterNotFound:
-                continue
-        raise ParameterNotFound(f'parameter "{name}" not found in layout hierarchy')
-
-    def visible(self) -> bool:
-        return len(self.children()) != 0 and any(x.visible() for x in self.items())
-
-    def _child_added(self, child: "ParameterHierarchyItem"):
-        super(ParametersLayoutBase, self)._child_added(child)
-        if isinstance(child, Parameter):
-            # check global parameter name uniqueness
-            rootparent = self
-            while isinstance(rootparent.parent(), ParametersLayoutBase):
-                rootparent = rootparent.parent()
-            if child.name() in (x.name() for x in rootparent.parameters(recursive=True) if x != child):
-                raise ParameterNameCollisionError('cannot add parameters with the same name to the same layout hierarchy')
-            self.__parameters[child.name()] = child
-        elif isinstance(child,
ParametersLayoutBase): - self.__layouts.add(child) - # check global parameter name uniqueness - rootparent = self - while isinstance(rootparent.parent(), ParametersLayoutBase): - rootparent = rootparent.parent() - new_params = list(child.parameters(recursive=True)) - existing_params = set(x.name() for x in rootparent.parameters(recursive=True) if x not in new_params) - for new_param in new_params: - if new_param.name() in existing_params: - raise ParameterNameCollisionError('cannot add parameters with the same name to the same layout hierarchy') - - def _child_about_to_be_removed(self, child: "ParameterHierarchyItem"): - if isinstance(child, Parameter): - del self.__parameters[child.name()] - elif isinstance(child, ParametersLayoutBase): - self.__layouts.remove(child) - super(ParametersLayoutBase, self)._child_about_to_be_removed(child) - - def _children_definition_changed(self, changed_children: Iterable["ParameterHierarchyItem"]): - """ - :param children: - :return: - """ - super(ParametersLayoutBase, self)._children_definition_changed(changed_children) - # check self.__parameters consistency - reversed_parameters: Dict[Parameter, str] = {v: k for k, v in self.__parameters.items()} - for child in changed_children: - if not isinstance(child, Parameter): - continue - if child in reversed_parameters: - del self.__parameters[reversed_parameters[child]] - self.__parameters[child.name()] = child - - def _children_value_changed(self, children: Iterable["ParameterHierarchyItem"]): - """ - :param children: - :return: - """ - super(ParametersLayoutBase, self)._children_value_changed(children) - - def _children_appearance_changed(self, children: Iterable["ParameterHierarchyItem"]): - super(ParametersLayoutBase, self)._children_appearance_changed(children) - - def relative_size_for_child(self, child: ParameterHierarchyItem) -> Tuple[float, float]: - """ - get relative size of a child in this layout - the exact interpretation of size is up to subclass to decide - :param child: - :return: - """ - raise NotImplementedError() - - -class OrderedParametersLayout(ParametersLayoutBase): - def __init__(self): - super(OrderedParametersLayout, self).__init__() - self.__parameter_order: List[ParameterHierarchyItem] = [] - - def _child_added(self, child: "ParameterHierarchyItem"): - super(OrderedParametersLayout, self)._child_added(child) - self.__parameter_order.append(child) - - def _child_about_to_be_removed(self, child: "ParameterHierarchyItem"): - self.__parameter_order.remove(child) - super(OrderedParametersLayout, self)._child_about_to_be_removed(child) - - def items(self, recursive=False): - """ - unlike base method, we need to return parameters in order - :param recursive: - :return: - """ - for child in self.__parameter_order: - yield child - if not recursive: - continue - elif isinstance(child, ParametersLayoutBase): - for child_param in child.items(recursive=recursive): - yield child_param - - def relative_size_for_child(self, child: ParameterHierarchyItem) -> Tuple[float, float]: - """ - get relative size of a child in this layout - the exact interpretation of size is up to subclass to decide - :param child: - :return: - """ - assert child in self.children() - return 1.0, 1.0 - - -class VerticalParametersLayout(OrderedParametersLayout): - """ - simple vertical parameter layout. 
- """ - pass - - -class CollapsableVerticalGroup(VerticalParametersLayout): - """ - a vertical parameter layout to be drawn as collapsable block - """ - def __init__(self, group_name, group_label): - super(CollapsableVerticalGroup, self).__init__() - - # for now it's here just to ensure name uniqueness. in future - maybe store collapsed state - self.__unused_param = Parameter(group_name, group_name, NodeParameterType.BOOL, True) - - self.__group_name = group_name - self.__group_label = group_label - - def is_collapsed(self): - return True - - def name(self): - return self.__group_name - - def label(self): - return self.__group_label - - -class OneLineParametersLayout(OrderedParametersLayout): - """ - horizontal parameter layout. - unlike vertical, this one has to keep track of portions of line it's parameters are taking - parameters of this group should be rendered in one line - """ - def __init__(self): - super(OneLineParametersLayout, self).__init__() - self.__hsizes = {} - - def _children_appearance_changed(self, children: Iterable["ParameterHierarchyItem"]): - super(ParametersLayoutBase, self)._children_appearance_changed(children) - self.__hsizes = {} - - def _children_definition_changed(self, children: Iterable["ParameterHierarchyItem"]): - super(OneLineParametersLayout, self)._children_definition_changed(children) - self.__hsizes = {} - - def relative_size_for_child(self, child: ParameterHierarchyItem) -> Tuple[float, float]: - assert child in self.children() - if child not in self.__hsizes: - self._update_hsizes() - assert child in self.__hsizes - return self.__hsizes[child], 1.0 - - def _update_hsizes(self): - self.__hsizes = {} - totalitems = 0 - for item in self.items(): - if item.visible(): - totalitems += 1 - if totalitems == 0: - uniform_size = 1.0 - else: - uniform_size = 1.0 / float(totalitems) - for item in self.items(): - self.__hsizes[item] = uniform_size - - -class MultiGroupLayout(OrderedParametersLayout): - """ - this group can dynamically spawn more parameters according to it's template - spawning more parameters does NOT count as definition change - """ - def __init__(self, name, label=None): - super(MultiGroupLayout, self).__init__() - self.__template: Union[ParametersLayoutBase, Parameter, None] = None - if label is None: - label = 'count' - self.__count_param = Parameter(name, label, NodeParameterType.INT, 0, can_have_expression=False) - self.__count_param.set_parent(self) - self.__last_count = 0 - self.__nested_indices = [] - - def nested_indices(self): - """ - if a multiparam is inside other multiparams - those multiparams should add their indices - to this one, so that this multiparam will be able to uniquely and predictable name it's parameters - """ - return tuple(self.__nested_indices) - - def __append_nested_index(self, index: int): - """ - this should be called only when a multiparam is instanced by another multiparam - """ - self.__nested_indices.append(index) - - def set_spawning_template(self, layout: ParametersLayoutBase): - self.__template = deepcopy(layout) - - def add_layout(self, new_layout: "ParametersLayoutBase"): - """ - this function is unavailable cuz of the nature of this layout - """ - raise LayoutError('NO') - - def add_parameter(self, new_parameter: Parameter): - """ - this function is unavailable cuz of the nature of this layout - """ - raise LayoutError('NO') - - def add_template_instance(self): - self.__count_param.set_value(self.__count_param.value() + 1) - - def _children_value_changed(self, children: 
Iterable["ParameterHierarchyItem"]): - - for child in children: - if child == self.__count_param: - break - else: - super(MultiGroupLayout, self)._children_value_changed(children) - return - if self.__count_param.value() < 0: - self.__count_param.set_value(0) - super(MultiGroupLayout, self)._children_value_changed(children) - return - - new_count = self.__count_param.value() - if self.__last_count < new_count: - if self.__template is None: - raise LayoutError('template is not set') - for _ in range(new_count - self.__last_count): - # note: the check below is good, but it's not needed currently, cuz visibility condition on append checks common parent - # and nodes not from template do not share parents with template, so that prevents external references - for param in self.__template.parameters(recursive=True): - # sanity check - for now we only support references within the same template block only - for ref_param in param.references(): - if not ref_param.has_same_parent(param): - raise ParameterDefinitionError('Parameters within MultiGroupLayout\'s template currently cannot reference outer parameters') - ## - new_layout = deepcopy(self.__template) - i = len(self.children()) - 1 - for param in new_layout.parameters(recursive=True): - param._set_name(param.name() + '_' + '.'.join(str(x) for x in (*self.nested_indices(), i))) - parent = param.parent() - if isinstance(parent, MultiGroupLayout): - for idx in self.nested_indices(): - parent.__append_nested_index(idx) - parent.__append_nested_index(i) - new_layout.set_parent(self) - elif self.__last_count > self.__count_param.value(): - for _ in range(self.__last_count - new_count): - instances = list(self.items(recursive=False)) - assert len(instances) > 1 - instances[-1].set_parent(None) - self.__last_count = new_count - super(MultiGroupLayout, self)._children_value_changed(children) - - def _child_added(self, child: "ParameterHierarchyItem"): - super(MultiGroupLayout, self)._child_added(child) - - def _child_about_to_be_removed(self, child: "ParameterHierarchyItem"): - super(MultiGroupLayout, self)._child_about_to_be_removed(child) - - -class _SpecialOutputCountChangingLayout(VerticalParametersLayout): - def __init__(self, nodeui: "NodeUi", parameter_name, parameter_label): - super(_SpecialOutputCountChangingLayout, self).__init__() - self.__my_nodeui = nodeui - newparam = Parameter(parameter_name, parameter_label, NodeParameterType.INT, 2, can_have_expression=False) - newparam.set_value_limits(2) - with self.initializing_interface_lock(): - self.add_parameter(newparam) - - def add_layout(self, new_layout: "ParametersLayoutBase"): - """ - this function is unavailable cuz of the nature of this layout - """ - raise LayoutError('NO') - - def add_parameter(self, new_parameter: Parameter): - """ - this function is unavailable cuz of the nature of this layout - """ - if len(list(self.parameters())) > 0: - raise LayoutError('NO') - super(_SpecialOutputCountChangingLayout, self).add_parameter(new_parameter) - - def _children_value_changed(self, children: Iterable["ParameterHierarchyItem"]): - # we expect this special layout to have only one single specific child - child = None - for child in children: - break - if child is None: - return - assert isinstance(child, Parameter) - new_num_outputs = child.value() - num_outputs = len(self.__my_nodeui.outputs_names()) - if num_outputs == new_num_outputs: - return - - if num_outputs < new_num_outputs: - for i in range(num_outputs, new_num_outputs): - self.__my_nodeui._add_output_unsafe(f'output{i}') - else: 
# num_outputs > new_num_outputs - for _ in range(new_num_outputs, num_outputs): - self.__my_nodeui._remove_last_output_unsafe() - self.__my_nodeui._outputs_definition_changed() - - -@dataclass -class ParameterFullValue: - unexpanded_value: Union[int, float, str, bool] - expression: Optional[str] - - -class NodeUiError(RuntimeError): - pass - - -class NodeUiDefinitionError(RuntimeError): - pass - - -class NodeUi(ParameterHierarchyItem): - def __init__(self, attached_node: "BaseNode"): - super(NodeUi, self).__init__() - self.__logger = get_logger('scheduler.nodeUI') - self.__parameter_layout = VerticalParametersLayout() - self.__parameter_layout.set_parent(self) - self.__attached_node: Optional[BaseNode] = attached_node - self.__block_ui_callbacks = False - self.__lock_ui_readonly = False - self.__postpone_ui_callbacks = False - self.__postponed_callbacks = None - self.__inputs_names = ('main',) - self.__outputs_names = ('main',) - - self.__groups_stack = [] - - self.__have_output_parameter_set: bool = False - - # default colorscheme - self.__color_scheme = NodeColorScheme() - self.__color_scheme.set_main_color(0.1882, 0.2510, 0.1882) # dark-greenish - - def is_attached_to_node(self): - return self.__attached_node is not None - - def attach_to_node(self, node: "BaseNode"): - self.__attached_node = node - - def color_scheme(self): - return self.__color_scheme - - def main_parameter_layout(self): - return self.__parameter_layout - - def parent(self) -> Optional["ParameterHierarchyItem"]: - return None - - def set_parent(self, item: Optional["ParameterHierarchyItem"]): - if item is not None: - raise RuntimeError('NodeUi class is supposed to be tree root') - - def initializing_interface_lock(self): - return self.block_ui_callbacks() - - def block_ui_callbacks(self): - class _iiLock: - def __init__(self, lockable): - self.__nui = lockable - self.__prev_state = None - - def __enter__(self): - self.__prev_state = self.__nui._NodeUi__block_ui_callbacks - self.__nui._NodeUi__block_ui_callbacks = True - - def __exit__(self, exc_type, exc_val, exc_tb): - self.__nui._NodeUi__block_ui_callbacks = self.__prev_state - - if self.__lock_ui_readonly: - raise LayoutReadonlyError() - return _iiLock(self) - - def lock_interface_readonly(self): - raise NotImplementedError("read trello task, read TODO. 
this does NOT work multithreaded, leads to permalocks, needs rethinking")
-        class _roLock:
-            def __init__(self, lockable):
-                self.__nui = lockable
-                self.__prev_state = None
-
-            def __enter__(self):
-                self.__prev_state = self.__nui._NodeUi__lock_ui_readonly
-                self.__nui._NodeUi__lock_ui_readonly = True
-
-            def __exit__(self, exc_type, exc_val, exc_tb):
-                self.__nui._NodeUi__lock_ui_readonly = self.__prev_state
-
-        return _roLock(self)
-
-    def postpone_ui_callbacks(self):
-        """
-        use this in a with-statement
-        for mass changes of parameters it may be more efficient to perform changes in batches
-        """
-        class _iiPostpone:
-            def __init__(self, nodeui):
-                self.__nui = nodeui
-                self.__val = None
-
-            def __enter__(self):
-                if not self.__nui._NodeUi__postpone_ui_callbacks:
-                    assert self.__nui._NodeUi__postponed_callbacks is None
-                    self.__val = self.__nui._NodeUi__postpone_ui_callbacks
-                    self.__nui._NodeUi__postpone_ui_callbacks = True
-                # otherwise: already blocked - we are in nested block, ignore
-
-            def __exit__(self, exc_type, exc_val, exc_tb):
-                if self.__val is None:
-                    return
-                assert not self.__val  # nested block should do nothing
-                self.__nui._NodeUi__postpone_ui_callbacks = self.__val
-                if self.__nui._NodeUi__postponed_callbacks is not None:
-                    self.__nui._NodeUi__ui_callback(self.__nui._NodeUi__postponed_callbacks)
-                    self.__nui._NodeUi__postponed_callbacks = None
-
-        return _iiPostpone(self)
-
-    class _slwrapper:
-        def __init__(self, ui: "NodeUi", layout_creator, layout_creator_kwargs=None):
-            self.__ui = ui
-            self.__layout_creator = layout_creator
-            self.__layout_creator_kwargs = layout_creator_kwargs or {}
-
-        def __enter__(self):
-            new_layout = self.__layout_creator(**self.__layout_creator_kwargs)
-            self.__ui._NodeUi__groups_stack.append(new_layout)
-            with self.__ui._NodeUi__parameter_layout.initializing_interface_lock():
-                self.__ui._NodeUi__parameter_layout.add_layout(new_layout)
-
-        def __exit__(self, exc_type, exc_val, exc_tb):
-            layout = self.__ui._NodeUi__groups_stack.pop()
-            self.__ui._add_layout(layout)
-
-    def parameters_on_same_line_block(self):
-        """
-        use it in with statement
-        :return:
-        """
-        return self.parameter_layout_block(OneLineParametersLayout)
-
-    def parameter_layout_block(self, parameter_layout_producer: Callable[[], ParametersLayoutBase]):
-        """
-        arbitrary simple parameter override block
-        use it in with statement
-        :return:
-        """
-        if self.__lock_ui_readonly:
-            raise LayoutReadonlyError()
-        if not self.__block_ui_callbacks:
-            raise NodeUiDefinitionError('initializing NodeUi interface not inside initializing_interface_lock')
-        return NodeUi._slwrapper(self, parameter_layout_producer)
-
-    def add_parameter_to_control_output_count(self, parameter_name: str, parameter_label: str):
-        """
-        a very special function for a very special case when you want the number of outputs to be controlled
-        by a parameter
-
-        from now on output names will be: 'main', 'output1', 'output2', ...
- - :return: - """ - if self.__lock_ui_readonly: - raise LayoutReadonlyError() - if not self.__block_ui_callbacks: - raise NodeUiDefinitionError('initializing NodeUi interface not inside initializing_interface_lock') - if self.__have_output_parameter_set: - raise NodeUiDefinitionError('there can only be one parameter to control output count') - self.__have_output_parameter_set = True - self.__outputs_names = ('main', 'output1') - - with self.parameter_layout_block(lambda: _SpecialOutputCountChangingLayout(self, parameter_name, parameter_label)): - # no need to do anything, with block will add that layout to stack, and parameter is created in that layout's constructor - layout = self.current_layout() - # this layout should always have exactly one parameter - assert len(list(layout.parameters())) == 1, f'oh no, {len(list(layout.parameters()))}' - return layout.parameter(parameter_name) - - def multigroup_parameter_block(self, name: str, label: Optional[str] = None): - """ - use it in with statement - creates a block like multiparameter block in houdini - any parameters added will be actually added to template to be instanced later as needed - :return: - """ - class _slwrapper_multi: - def __init__(self, ui: "NodeUi", name: str, label: Optional[str] = None): - self.__ui = ui - self.__new_layout = None - self.__name = name - self.__label = label - - def __enter__(self): - self.__new_layout = VerticalParametersLayout() - self.__ui._NodeUi__groups_stack.append(self.__new_layout) - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - assert self.__ui._NodeUi__groups_stack.pop() == self.__new_layout - with self.__ui._NodeUi__parameter_layout.initializing_interface_lock(): - multi_layout = MultiGroupLayout(self.__name, self.__label) - with multi_layout.initializing_interface_lock(): - multi_layout.set_spawning_template(self.__new_layout) - self.__ui._add_layout(multi_layout) - - def multigroup(self) -> VerticalParametersLayout: - return self.__new_layout - - if self.__lock_ui_readonly: - raise LayoutReadonlyError() - if not self.__block_ui_callbacks: - raise NodeUiDefinitionError('initializing NodeUi interface not inside initializing_interface_lock') - return _slwrapper_multi(self, name, label) - - def current_layout(self): - """ - get current layout to which add_parameter would add parameter - this can be main nodeUI's layout, but can be something else, if we are in some with block, - like for ex: collapsable_group_block or parameters_on_same_line_block - - :return: - """ - if not self.__block_ui_callbacks: - raise NodeUiDefinitionError('initializing NodeUi interface not inside initializing_interface_lock') - layout = self.__parameter_layout - if len(self.__groups_stack) != 0: - layout = self.__groups_stack[-1] - return layout - - def collapsable_group_block(self, group_name: str, group_label: str = ''): - """ - use it in with statement - creates a visually distinct group of parameters that renderer should draw as a collapsable block - - :return: - """ - if self.__lock_ui_readonly: - raise LayoutReadonlyError() - if not self.__block_ui_callbacks: - raise NodeUiDefinitionError('initializing NodeUi interface not inside initializing_interface_lock') - return NodeUi._slwrapper(self, CollapsableVerticalGroup, {'group_name': group_name, 'group_label': group_label}) - - def _add_layout(self, new_layout): - if not self.__block_ui_callbacks: - raise NodeUiDefinitionError('initializing NodeUi interface not inside initializing_interface_lock') - layout = self.__parameter_layout - if 
len(self.__groups_stack) != 0: - layout = self.__groups_stack[-1] - with layout.initializing_interface_lock(): - layout.add_layout(new_layout) - - def add_parameter(self, param_name: str, param_label: Optional[str], param_type: NodeParameterType, param_val: Any, can_have_expressions: bool = True, readonly: bool = False) -> Parameter: - if self.__lock_ui_readonly: - raise LayoutReadonlyError() - if not self.__block_ui_callbacks: - raise NodeUiDefinitionError('initializing NodeUi interface not inside initializing_interface_lock') - layout = self.__parameter_layout - if len(self.__groups_stack) != 0: - layout = self.__groups_stack[-1] - with layout.initializing_interface_lock(): - newparam = Parameter(param_name, param_label, param_type, param_val, can_have_expressions, readonly) - layout.add_parameter(newparam) - return newparam - - def add_separator(self): - if self.__lock_ui_readonly: - raise LayoutReadonlyError() - if not self.__block_ui_callbacks: - raise NodeUiDefinitionError('initializing NodeUi interface not inside initializing_interface_lock') - layout = self.__parameter_layout - if len(self.__groups_stack) != 0: - layout = self.__groups_stack[-1] - with layout.initializing_interface_lock(): - newsep = Separator() - layout.add_generic_leaf(newsep) - return newsep - - def add_input(self, input_name): - if self.__lock_ui_readonly: - raise LayoutReadonlyError() - if not self.__block_ui_callbacks: - raise NodeUiDefinitionError('initializing NodeUi interface not inside initializing_interface_lock') - if input_name not in self.__inputs_names: - self.__inputs_names += (input_name,) - - def _add_output_unsafe(self, output_name): - if output_name not in self.__outputs_names: - self.__outputs_names += (output_name,) - - def add_output(self, output_name): - if self.__lock_ui_readonly: - raise LayoutReadonlyError() - if not self.__block_ui_callbacks: - raise NodeUiDefinitionError('initializing NodeUi interface not inside initializing_interface_lock') - if self.__have_output_parameter_set: - raise NodeUiDefinitionError('cannot add outputs when output count is controlled by a parameter') - return self._add_output_unsafe(output_name) - - def _remove_last_output_unsafe(self): - if len(self.__outputs_names) < 2: - return - self.__outputs_names = self.__outputs_names[:-1] - - def remove_last_output(self): - if self.__lock_ui_readonly: - raise LayoutReadonlyError() - if not self.__block_ui_callbacks: - raise NodeUiDefinitionError('initializing NodeUi interface not inside initializing_interface_lock') - if self.__have_output_parameter_set: - raise NodeUiDefinitionError('cannot add outputs when output count is controlled by a parameter') - return self._remove_last_output_unsafe() - - def add_output_for_spawned_tasks(self): - return self.add_output('spawned') - - def _children_definition_changed(self, children: Iterable["ParameterHierarchyItem"]): - self.__ui_callback(definition_changed=True) - - def _children_value_changed(self, children: Iterable["ParameterHierarchyItem"]): - self.__ui_callback(definition_changed=False) - - def _outputs_definition_changed(self): #TODO: not entirely sure how safe this is right now - self.__ui_callback(definition_changed=True) - - def __ui_callback(self, definition_changed=False): - if self.__lock_ui_readonly: - raise LayoutReadonlyError() - if self.__postpone_ui_callbacks: - # so we save definition_changed to __postponed_callbacks - self.__postponed_callbacks = self.__postponed_callbacks or definition_changed - return - - if self.__attached_node is not None and not 
self.__block_ui_callbacks: - self.__attached_node._ui_changed(definition_changed) - - def inputs_names(self) -> Tuple[str]: - return self.__inputs_names - - def outputs_names(self) -> Tuple[str]: - return self.__outputs_names - - def parameter(self, param_name: str) -> Parameter: - return self.__parameter_layout.parameter(param_name) - - def parameters(self) -> Iterable[Parameter]: - return self.__parameter_layout.parameters(recursive=True) - - def items(self, recursive=False) -> Iterable[ParameterHierarchyItem]: - return self.__parameter_layout.items(recursive=recursive) - - def set_parameters_batch(self, parameters: Dict[str, ParameterFullValue]): - """ - If signal blocking is needed - caller can do it - - for now it's implemented the stupid way - """ - names_to_set = list(parameters.keys()) - names_to_set.append(None) - something_set_this_iteration = False - parameters_were_postponed = False - for param_name in names_to_set: - if param_name is None: - if parameters_were_postponed: - if not something_set_this_iteration: - self.__logger.warning(f'failed to set all parameters!') - break - names_to_set.append(None) - something_set_this_iteration = False - continue - assert isinstance(param_name, str) - param = self.parameter(param_name) - if param is None: - parameters_were_postponed = True - continue - param_value = parameters[param_name] - try: - param.set_value(param_value.unexpanded_value) - except (ParameterReadonly, ParameterLocked): - # if value is already correct - just skip - if param.unexpanded_value() != param_value.unexpanded_value: - self.__logger.error(f'unable to set value for "{param_name}"') - # shall we just ignore the error? - except ParameterError as e: - self.__logger.error(f'failed to set value for "{param_name}" because {repr(e)}') - if param.can_have_expressions(): - try: - param.set_expression(param_value.expression) - except (ParameterReadonly, ParameterLocked): - # if value is already correct - just skip - if param.expression() != param_value.expression: - self.__logger.error(f'unable to set expression for "{param_name}"') - # shall we just ignore the error? 
- except ParameterError as e: - self.__logger.error(f'failed to set expression for "{param_name}" because {repr(e)}') - elif param_value.expression is not None: - self.__logger.error(f'parameter "{param_name}" cannot have expressions, yet expression is stored for it') - - something_set_this_iteration = True - - def __deepcopy__(self, memo): - cls = self.__class__ - crap = cls.__new__(cls) - newdict = self.__dict__.copy() - newdict['_NodeUi__attached_node'] = None - newdict['_NodeUi__lock_ui_readonly'] = False - assert id(self) not in memo - memo[id(self)] = crap # to avoid recursion, though manual tells us to treat memo as opaque object - for k, v in newdict.items(): - crap.__dict__[k] = deepcopy(v, memo) - return crap - - def __setstate__(self, state): - ensure_attribs = { # this exists only for the ease of upgrading NodeUi classes during development - '_NodeUi__lock_ui_readonly': False, - '_NodeUi__postpone_ui_callbacks': False - } - self.__dict__.update(state) - for attrname, default_value in ensure_attribs.items(): - if not hasattr(self, attrname): - setattr(self, attrname, default_value) - - def serialize(self) -> bytes: - """ - note - this serialization disconnects the node to which this UI is connected - :return: - """ - obj = deepcopy(self) - assert obj.__attached_node is None - return pickle.dumps(obj) - - async def serialize_async(self) -> bytes: - return await asyncio.get_event_loop().run_in_executor(None, self.serialize) - - def __repr__(self): - return 'NodeUi: ' + ', '.join(('%s: %s' % (x.name() if isinstance(x, Parameter) else '-layout-', x) for x in self.__parameter_layout.items())) - - @classmethod - def deserialize(cls, data: bytes) -> "NodeUi": - return pickle.loads(data) - - @classmethod - async def deserialize_async(cls, data: bytes) -> "NodeUi": - return await asyncio.get_event_loop().run_in_executor(None, cls.deserialize, data) diff --git a/src/lifeblood_testing_common/integration_common.py b/src/lifeblood_testing_common/integration_common.py index 56863616..b5be45bc 100644 --- a/src/lifeblood_testing_common/integration_common.py +++ b/src/lifeblood_testing_common/integration_common.py @@ -9,9 +9,10 @@ from lifeblood.config import Config from lifeblood_testing_common.common import create_default_scheduler from lifeblood.nethelpers import get_default_addr -from lifeblood.simple_worker_pool import WorkerPool +from lifeblood.simple_worker_pool import SimpleWorkerPool from lifeblood.net_messages.address import AddressChain from lifeblood.taskspawn import NewTask +from lifeblood.worker_pool_message_processor import WorkerPoolMessageProcessor from lifeblood.worker_resource_definition import WorkerResourceDefinition, WorkerDeviceTypeDefinition from lifeblood.enums import SpawnStatus @@ -63,21 +64,23 @@ async def asyncSetUp(self): device_type_definitions=self._device_type_definitions(), helpers_minimal_idle_to_ensure=self._minimal_helper_idle_to_ensure(), ) - self.worker_pool = WorkerPool( + self.worker_pool = SimpleWorkerPool( scheduler_address=AddressChain(f'{get_default_addr()}:{test_server_port2}'), minimal_idle_to_ensure=self._minimal_idle_to_ensure(), minimal_total_to_ensure=self._minimal_total_to_ensure(), maximum_total=self._maximum_total(), - config=self._worker_config() + config=self._worker_config(), + message_processor_factory=WorkerPoolMessageProcessor, ) self.worker_pool2 = None if worker_config2 := self._worker_config2(): - self.worker_pool2 = WorkerPool( + self.worker_pool2 = SimpleWorkerPool( 
scheduler_address=AddressChain(f'{get_default_addr()}:{test_server_port2}'), minimal_idle_to_ensure=self._minimal_idle_to_ensure(), minimal_total_to_ensure=self._minimal_total_to_ensure(), maximum_total=self._maximum_total(), - config=worker_config2 + config=worker_config2, + message_processor_factory=WorkerPoolMessageProcessor, ) await self.scheduler.start() diff --git a/src/lifeblood_testing_common/nodes_common.py b/src/lifeblood_testing_common/nodes_common.py index f19d2334..c2902699 100644 --- a/src/lifeblood_testing_common/nodes_common.py +++ b/src/lifeblood_testing_common/nodes_common.py @@ -97,7 +97,7 @@ def set_input_name(self, name: str): self.__task_dict['node_input_name'] = self.__input_name def get_context_for(self, node: BaseNode) -> ProcessingContext: - return ProcessingContext(node, self.task_dict(), {}) + return ProcessingContext(node.name(), node.label(), node.get_ui(), self.task_dict(), {}) def task_dict(self) -> dict: return {**self.__task_dict, **{ @@ -233,8 +233,10 @@ async def _helper_test_worker_node(self, workers = [] for i in range(worker_count): - worker = Worker(sched.server_message_addresses()[0], - scheduler_ping_interval=9001) + worker = Worker( + sched.server_message_addresses()[0], + scheduler_ping_interval=9001, + ) await worker.start() workers.append(worker) @@ -336,7 +338,7 @@ async def _logic(scheduler, workers, tmp_script_path, done_waiter): 'outimage': out_exr_path, 'frames': [1, 2, 3] } - res = node.process_task(ProcessingContext(node, {'attributes': serialize_attributes_core(start_attrs)}, {})) + res = node.process_task(ProcessingContext(node.name(), node.label(), node.get_ui(), {'attributes': serialize_attributes_core(start_attrs)}, {})) ij = res.invocation_job self.assertTrue(ij is not None) @@ -367,7 +369,7 @@ async def _logic(scheduler, workers, tmp_script_path, done_waiter): await asyncio.wait([done_waiter], timeout=30) # now postprocess task - res = node.postprocess_task(ProcessingContext(node, {'attributes': serialize_attributes_core({ + res = node.postprocess_task(ProcessingContext(node.name(), node.label(), node.get_ui(), {'attributes': serialize_attributes_core({ **start_attrs, **updated_attrs })}, {})) @@ -440,7 +442,7 @@ async def _logic(scheduler, workers, script_path, done_waiter): for param, val in params.items(): node.set_param_value(param, val) - res = node.process_task(ProcessingContext(node, {'attributes': serialize_attributes_core(task_attrs)}, {})) + res = node.process_task(ProcessingContext(node.name(), node.label(), node.get_ui(), {'attributes': serialize_attributes_core(task_attrs)}, {})) if res.attributes_to_set: updated_attrs.update(res.attributes_to_set) @@ -477,7 +479,7 @@ async def _logic(scheduler, workers, script_path, done_waiter): await asyncio.wait([done_waiter], timeout=30) # now postprocess task - res = node.postprocess_task(ProcessingContext(node, {'attributes': serialize_attributes_core({ + res = node.postprocess_task(ProcessingContext(node.name(), node.label(), node.get_ui(), {'attributes': serialize_attributes_core({ **task_attrs, **updated_attrs })}, {})) diff --git a/src/lifeblood_viewer/connection_worker.py b/src/lifeblood_viewer/connection_worker.py index 44e073ef..50c8f0a7 100644 --- a/src/lifeblood_viewer/connection_worker.py +++ b/src/lifeblood_viewer/connection_worker.py @@ -3,7 +3,7 @@ import json import time -from lifeblood.uidata import NodeUi +from lifeblood.node_ui import NodeUi from lifeblood.invocationjob import InvocationJob from lifeblood.nethelpers import address_to_ip_port, get_default_addr 
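# annotation on the retargeted imports in this and the following hunks: the
# series splits the old monolithic `lifeblood.uidata` module -- NodeUi moves to
# `lifeblood.node_ui`, while Parameter and the layout/exception classes move to
# `lifeblood.node_parameters`. A minimal sketch of a backward-compatibility
# shim (hypothetical -- this series does not ship one) could re-export the
# moved names under the old path:
#
#     # uidata.py -- hypothetical re-export stub
#     from lifeblood.node_ui import NodeUi  # noqa: F401
#     from lifeblood.node_parameters import Parameter, ParameterError  # noqa: F401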
from lifeblood import logging @@ -11,7 +11,7 @@ from lifeblood.broadcasting import await_broadcast from lifeblood.config import get_config from lifeblood.exceptions import UiClientOperationFailed -from lifeblood.uidata import Parameter +from lifeblood.node_parameters import Parameter from lifeblood.taskspawn import NewTask from lifeblood.snippets import NodeSnippetData from lifeblood.defaults import ui_port diff --git a/src/lifeblood_viewer/graphics_items/graphics_items.py b/src/lifeblood_viewer/graphics_items/graphics_items.py index fe0a26af..18e58989 100644 --- a/src/lifeblood_viewer/graphics_items/graphics_items.py +++ b/src/lifeblood_viewer/graphics_items/graphics_items.py @@ -7,7 +7,7 @@ from .scene_network_item import SceneNetworkItem, SceneNetworkItemWithUI from .graphics_scene_base import GraphicsSceneBase -from lifeblood.uidata import NodeUi +from lifeblood.node_ui import NodeUi from lifeblood.ui_protocol_data import TaskData, TaskDelta, DataNotSet, IncompleteInvocationLogData, InvocationLogData from lifeblood.basenode import BaseNode from lifeblood.enums import TaskState @@ -84,7 +84,7 @@ def set_selected(self, selected: bool, *, unselect_others=False): def update_nodeui(self, nodeui: NodeUi): self.__nodeui = nodeui - self.__nodeui.attach_to_node(Node.PseudoNode(self)) + self.__nodeui.set_ui_change_callback_receiver(Node.PseudoNode(self)) self.reanalyze_nodeui() def reanalyze_nodeui(self): diff --git a/src/lifeblood_viewer/graphics_items/pretty_items/fancy_items/scene_node.py b/src/lifeblood_viewer/graphics_items/pretty_items/fancy_items/scene_node.py index ae763541..cf495078 100644 --- a/src/lifeblood_viewer/graphics_items/pretty_items/fancy_items/scene_node.py +++ b/src/lifeblood_viewer/graphics_items/pretty_items/fancy_items/scene_node.py @@ -2,7 +2,8 @@ from lifeblood import logging from lifeblood.config import get_config from lifeblood.enums import NodeParameterType -from lifeblood.uidata import CollapsableVerticalGroup, OneLineParametersLayout, Parameter, ParameterExpressionError, ParametersLayoutBase, Separator, NodeUi +from lifeblood.node_parameters import CollapsableVerticalGroup, OneLineParametersLayout, Parameter, ParameterExpressionError, ParametersLayoutBase, Separator +from lifeblood.node_ui import NodeUi from lifeblood_viewer.graphics_items import Node from ...utils import call_later from ..decorated_node import DecoratedNode diff --git a/src/lifeblood_viewer/graphics_scene_with_data_controller.py b/src/lifeblood_viewer/graphics_scene_with_data_controller.py index 39c0628b..2fa82e01 100644 --- a/src/lifeblood_viewer/graphics_scene_with_data_controller.py +++ b/src/lifeblood_viewer/graphics_scene_with_data_controller.py @@ -23,7 +23,8 @@ ParameterChangeOp) from lifeblood.misc import timeit -from lifeblood.uidata import NodeUi, Parameter +from lifeblood.node_ui import NodeUi +from lifeblood.node_parameters import Parameter from lifeblood.ui_protocol_data import TaskBatchData, NodeGraphStructureData, TaskDelta, DataNotSet, IncompleteInvocationLogData, InvocationLogData from lifeblood.enums import TaskState, TaskGroupArchivedState from lifeblood import logging diff --git a/src/lifeblood_viewer/scene_data_controller.py b/src/lifeblood_viewer/scene_data_controller.py index 6ee0ac7f..13a19c02 100644 --- a/src/lifeblood_viewer/scene_data_controller.py +++ b/src/lifeblood_viewer/scene_data_controller.py @@ -1,7 +1,7 @@ from .undo_stack import UndoableOperation, OperationCompletionDetails from .long_op import LongOperation, LongOperationData from .ui_snippets import 
NodeSnippetData -from lifeblood.uidata import Parameter +from lifeblood.node_parameters import Parameter from lifeblood.ui_protocol_data import InvocationLogData from lifeblood.node_type_metadata import NodeTypeMetadata from lifeblood.enums import TaskState, TaskGroupArchivedState diff --git a/src/lifeblood_viewer/scene_ops.py b/src/lifeblood_viewer/scene_ops.py index 355a864d..b43d80a1 100644 --- a/src/lifeblood_viewer/scene_ops.py +++ b/src/lifeblood_viewer/scene_ops.py @@ -6,7 +6,7 @@ from .graphics_scene import GraphicsScene from .scene_data_controller import SceneDataController from lifeblood.snippets import NodeSnippetData -from lifeblood.uidata import ParameterLocked, ParameterReadonly +from lifeblood.node_parameters import ParameterLocked, ParameterReadonly from PySide2.QtCore import QPointF from typing import Callable, Optional, Tuple, Iterable diff --git a/tests/test_nodeui.py b/tests/test_nodeui.py index d947db05..57d31c3f 100644 --- a/tests/test_nodeui.py +++ b/tests/test_nodeui.py @@ -2,9 +2,10 @@ import random from itertools import chain from lifeblood import basenode -from lifeblood.uidata import NodeUi, NodeParameterType, ParameterNameCollisionError, ParameterNotFound, ParameterError, \ +from lifeblood.node_parameters import NodeParameterType, ParameterNameCollisionError, ParameterNotFound, ParameterError, \ ParameterExpressionError, ParameterExpressionCastError, ParameterReadonly, ParameterLocked, LayoutError, \ - ParameterDefinitionError, NodeUiDefinitionError, ParameterCannotHaveExpressions + ParameterDefinitionError, ParameterCannotHaveExpressions +from lifeblood.node_ui import NodeUi, NodeUiDefinitionError from typing import Iterable diff --git a/tests/test_scheduler.py b/tests/test_scheduler.py index a9e05c3a..421110ba 100644 --- a/tests/test_scheduler.py +++ b/tests/test_scheduler.py @@ -7,7 +7,7 @@ from lifeblood.db_misc import sql_init_script from lifeblood.scheduler.scheduler import Scheduler from lifeblood.scheduler.pinger import Pinger -from lifeblood.scheduler_message_processor import SchedulerWorkerControlClient +from lifeblood.scheduler_message_processor_client import SchedulerWorkerControlClient from lifeblood.net_messages.address import AddressChain from lifeblood.net_messages.impl.tcp_simple_command_message_processor import TcpJsonMessageProcessor from lifeblood.net_messages.exceptions import MessageTransferError diff --git a/tests/test_spawn_tasks_race.py b/tests/test_spawn_tasks_race.py index 501ed172..9892c8a5 100644 --- a/tests/test_spawn_tasks_race.py +++ b/tests/test_spawn_tasks_race.py @@ -14,7 +14,11 @@ async def test_race(self): and spawn called by message server for example. 
""" config = SchedulerConfigProviderOverrides(main_db_location=self.db_file, do_broadcast=False) - sched = Scheduler(scheduler_config_provider=config, node_data_provider=None, node_serializers=[None]) + sched = Scheduler( + scheduler_config_provider=config, + node_data_provider=None, + node_serializers=[None], + ) async with sched.data_access.data_connection() as con: await con.execute('BEGIN IMMEDIATE') task1 = asyncio.create_task(sched.spawn_tasks([NewTask('foo1', 1, None, {})])) diff --git a/tests/test_worker_pool.py b/tests/test_worker_pool.py index 92851cb7..44a05064 100644 --- a/tests/test_worker_pool.py +++ b/tests/test_worker_pool.py @@ -9,7 +9,8 @@ import tracemalloc from lifeblood.logging import get_logger -from lifeblood.simple_worker_pool import WorkerPool, create_worker_pool +from lifeblood.simple_worker_pool import SimpleWorkerPool +from lifeblood.simple_worker_pool_main import create_worker_pool from lifeblood.enums import WorkerType, WorkerState from lifeblood.config import get_config from lifeblood.nethelpers import get_default_addr @@ -52,11 +53,12 @@ def tearDownClass(cls) -> None: def __init__(self, method='runTest'): super(WorkerPoolTests, self).__init__(method) - get_logger(WorkerPool.__name__.lower()).setLevel('DEBUG') + get_logger(SimpleWorkerPool.__name__.lower()).setLevel('DEBUG') async def _helper_test_basic(self, rnd): print('a') swp = await create_worker_pool(idle_timeout=30, scheduler_address=WorkerPoolTests.sched_addr) + await swp.start() print('b') await swp.add_worker() print('c') @@ -107,6 +109,7 @@ async def _helper_test_min1(self, rnd): minimal_idle_to_ensure=mini, worker_suspicious_lifetime=0, scheduler_address=WorkerPoolTests.sched_addr) + await swp.start() await asyncio.sleep(rnd.uniform(0, 1)) workers = swp.list_workers() self.assertEqual(mint, len(workers)) @@ -138,6 +141,7 @@ async def _helper_test_min_from_idle(self, rnd, mini: int, test_total: bool = Fa housekeeping_interval=0.2, idle_timeout=0.3, scheduler_address=WorkerPoolTests.sched_addr) + await swp.start() await asyncio.sleep(rnd.uniform(0, 1)) workers = swp.list_workers() self.assertEqual(mini*2, len(workers)) @@ -182,6 +186,7 @@ async def test_min1(self): async def _helper_test_max1(self, rnd): maxt = 5 swp = await create_worker_pool(scheduler_address=WorkerPoolTests.sched_addr) + await swp.start() swp.set_maximum_workers(maxt) for i in range(maxt+5): await swp.add_worker() @@ -199,6 +204,7 @@ async def test_max1(self): async def _helper_test_smth1(self, rnd): swp = await create_worker_pool(minimal_idle_to_ensure=1, scheduler_address=WorkerPoolTests.sched_addr) + await swp.start() await asyncio.sleep(2) swp.stop() await swp.wait_till_stops() diff --git a/tests/test_worker_restart_double_invocation_edge_case.py b/tests/test_worker_restart_double_invocation_edge_case.py index 2658c3b0..f5c8df8c 100644 --- a/tests/test_worker_restart_double_invocation_edge_case.py +++ b/tests/test_worker_restart_double_invocation_edge_case.py @@ -130,7 +130,11 @@ async def test_multi_invoc3_empty1_delays2(self): async def _helper_test_multi_invoc(self, racing_tasks_count: int, num_empty_invocs: int = 0, delays: Optional[List[int]] = None): config = SchedulerConfigProviderOverrides(self.db_file, 60, do_broadcast=False) - sched = Scheduler(scheduler_config_provider=config, node_data_provider=None, node_serializers=[None]) + sched = Scheduler( + scheduler_config_provider=config, + node_data_provider=None, + node_serializers=[None], + ) data_access = DataAccess(config_provider=config) m = mock.MagicMock() 
m.data_access = data_access @@ -162,7 +166,7 @@ async def _helper_test_multi_invoc(self, racing_tasks_count: int, num_empty_invo with mock.patch('lifeblood.scheduler.Scheduler._update_worker_resouce_usage'), \ mock.patch('lifeblood.scheduler.Scheduler.server_message_address'), \ mock.patch('lifeblood.scheduler.data_access.DataAccess.get_invocation_resources_assigned_to') as res_mock, \ - mock.patch('lifeblood.worker_messsage_processor.WorkerControlClient.get_worker_control_client') as get_client_mock: + mock.patch('lifeblood.worker_message_processor_client.WorkerControlClient.get_worker_control_client') as get_client_mock: res_mock.side_effect = get_invocation_resources_assigned_to_mock get_client_mock.side_effect = get_worker_control_client_mock if delays: From d68eeeac62c33c5ed347fec8f45ec445a4f33d98 Mon Sep 17 00:00:00 2001 From: pedohorse <13556996+pedohorse@users.noreply.github.com> Date: Mon, 28 Oct 2024 12:56:38 +0100 Subject: [PATCH 08/10] remove Scheduler import in scheduler package __init__.py --- src/lifeblood/main_scheduler.py | 2 +- src/lifeblood/scheduler/__init__.py | 1 - src/lifeblood_testing_common/nodes_common.py | 2 +- tests/nodes/test_attribute_splitter.py | 2 +- tests/nodes/test_hip_script.py | 2 +- tests/nodes/test_parent_children_waiter.py | 2 +- tests/nodes/test_rename_attrib.py | 2 +- tests/nodes/test_spawn_children.py | 2 +- tests/nodes/test_split_waiter.py | 2 +- tests/nodes/test_wait_for_task.py | 2 +- tests/nodes/test_wedge.py | 2 +- tests/test_scheduler.py | 2 -- tests/test_ui_state_accessor_integration.py | 1 - tests/test_worker.py | 1 - tests/test_worker_restart_double_invocation_edge_case.py | 2 -- 15 files changed, 10 insertions(+), 17 deletions(-) diff --git a/src/lifeblood/main_scheduler.py b/src/lifeblood/main_scheduler.py index b3653648..dadc8550 100644 --- a/src/lifeblood/main_scheduler.py +++ b/src/lifeblood/main_scheduler.py @@ -5,7 +5,7 @@ import signal from .config import get_config, create_default_user_config_file from .pluginloader import PluginNodeDataProvider -from .scheduler import Scheduler +from .scheduler.scheduler import Scheduler from .basenode_serializer_v1 import NodeSerializerV1 from .basenode_serializer_v2 import NodeSerializerV2 from .scheduler_config_provider_base import SchedulerConfigProviderBase diff --git a/src/lifeblood/scheduler/__init__.py b/src/lifeblood/scheduler/__init__.py index f3ba972f..e69de29b 100644 --- a/src/lifeblood/scheduler/__init__.py +++ b/src/lifeblood/scheduler/__init__.py @@ -1 +0,0 @@ -from .scheduler import Scheduler diff --git a/src/lifeblood_testing_common/nodes_common.py b/src/lifeblood_testing_common/nodes_common.py index c2902699..bde337d6 100644 --- a/src/lifeblood_testing_common/nodes_common.py +++ b/src/lifeblood_testing_common/nodes_common.py @@ -14,7 +14,7 @@ from lifeblood.basenode import BaseNode from lifeblood.nodethings import ProcessingResult from lifeblood.exceptions import NodeNotReadyToProcess -from lifeblood.scheduler import Scheduler +from lifeblood.scheduler.scheduler import Scheduler from lifeblood_testing_common.common import create_default_scheduler from lifeblood.worker import Worker from lifeblood.invocationjob import Invocation, InvocationJob, InvocationResources, Environment diff --git a/tests/nodes/test_attribute_splitter.py b/tests/nodes/test_attribute_splitter.py index 017e7e8d..5f71fc84 100644 --- a/tests/nodes/test_attribute_splitter.py +++ b/tests/nodes/test_attribute_splitter.py @@ -1,5 +1,5 @@ from asyncio import Event -from lifeblood.scheduler import Scheduler +from 
lifeblood.scheduler.scheduler import Scheduler from lifeblood.worker import Worker from lifeblood.nodethings import ProcessingError from lifeblood_testing_common.nodes_common import TestCaseBase, PseudoContext diff --git a/tests/nodes/test_hip_script.py b/tests/nodes/test_hip_script.py index a47f7f04..13c0b913 100644 --- a/tests/nodes/test_hip_script.py +++ b/tests/nodes/test_hip_script.py @@ -1,7 +1,7 @@ import random import ast from asyncio import Event -from lifeblood.scheduler import Scheduler +from lifeblood.scheduler.scheduler import Scheduler from lifeblood.worker import Worker from lifeblood.nodethings import ProcessingError from lifeblood_testing_common.nodes_common import TestCaseBase, PseudoContext diff --git a/tests/nodes/test_parent_children_waiter.py b/tests/nodes/test_parent_children_waiter.py index fb9dedbe..8a3c23c2 100644 --- a/tests/nodes/test_parent_children_waiter.py +++ b/tests/nodes/test_parent_children_waiter.py @@ -1,7 +1,7 @@ from asyncio import Event import json import random -from lifeblood.scheduler import Scheduler +from lifeblood.scheduler.scheduler import Scheduler from lifeblood.worker import Worker from lifeblood.basenode import BaseNode from lifeblood.enums import TaskState diff --git a/tests/nodes/test_rename_attrib.py b/tests/nodes/test_rename_attrib.py index a8519e07..787f5388 100644 --- a/tests/nodes/test_rename_attrib.py +++ b/tests/nodes/test_rename_attrib.py @@ -1,6 +1,6 @@ import random from asyncio import Event -from lifeblood.scheduler import Scheduler +from lifeblood.scheduler.scheduler import Scheduler from lifeblood.worker import Worker from lifeblood.nodethings import ProcessingError from lifeblood_testing_common.nodes_common import TestCaseBase, PseudoContext diff --git a/tests/nodes/test_spawn_children.py b/tests/nodes/test_spawn_children.py index 7503bee5..b59a6242 100644 --- a/tests/nodes/test_spawn_children.py +++ b/tests/nodes/test_spawn_children.py @@ -1,7 +1,7 @@ from asyncio import Event import string import random -from lifeblood.scheduler import Scheduler +from lifeblood.scheduler.scheduler import Scheduler from lifeblood.worker import Worker from lifeblood.basenode import BaseNode from lifeblood_testing_common.nodes_common import TestCaseBase, PseudoContext diff --git a/tests/nodes/test_split_waiter.py b/tests/nodes/test_split_waiter.py index 3f16a835..f532b88e 100644 --- a/tests/nodes/test_split_waiter.py +++ b/tests/nodes/test_split_waiter.py @@ -1,7 +1,7 @@ import random import json from asyncio import Event -from lifeblood.scheduler import Scheduler +from lifeblood.scheduler.scheduler import Scheduler from lifeblood.worker import Worker from lifeblood.basenode import BaseNode from lifeblood_testing_common.nodes_common import TestCaseBase, PseudoContext diff --git a/tests/nodes/test_wait_for_task.py b/tests/nodes/test_wait_for_task.py index ebea84e1..ec41b86c 100644 --- a/tests/nodes/test_wait_for_task.py +++ b/tests/nodes/test_wait_for_task.py @@ -1,7 +1,7 @@ from asyncio import Event import json import random -from lifeblood.scheduler import Scheduler +from lifeblood.scheduler.scheduler import Scheduler from lifeblood.worker import Worker from lifeblood.basenode import BaseNode from lifeblood.exceptions import NodeNotReadyToProcess diff --git a/tests/nodes/test_wedge.py b/tests/nodes/test_wedge.py index a4442e13..c9b08da2 100644 --- a/tests/nodes/test_wedge.py +++ b/tests/nodes/test_wedge.py @@ -1,6 +1,6 @@ import random from asyncio import Event -from lifeblood.scheduler import Scheduler +from lifeblood.scheduler.scheduler 
import Scheduler from lifeblood.worker import Worker from lifeblood.nodethings import ProcessingError from lifeblood_testing_common.nodes_common import TestCaseBase, PseudoContext diff --git a/tests/test_scheduler.py b/tests/test_scheduler.py index 421110ba..c0763ce0 100644 --- a/tests/test_scheduler.py +++ b/tests/test_scheduler.py @@ -5,9 +5,7 @@ import sqlite3 from lifeblood.enums import InvocationState, TaskState from lifeblood.db_misc import sql_init_script -from lifeblood.scheduler.scheduler import Scheduler from lifeblood.scheduler.pinger import Pinger -from lifeblood.scheduler_message_processor_client import SchedulerWorkerControlClient from lifeblood.net_messages.address import AddressChain from lifeblood.net_messages.impl.tcp_simple_command_message_processor import TcpJsonMessageProcessor from lifeblood.net_messages.exceptions import MessageTransferError diff --git a/tests/test_ui_state_accessor_integration.py b/tests/test_ui_state_accessor_integration.py index 13d886c7..5e35d654 100644 --- a/tests/test_ui_state_accessor_integration.py +++ b/tests/test_ui_state_accessor_integration.py @@ -5,7 +5,6 @@ from lifeblood.logging import set_default_loglevel from lifeblood.enums import TaskState from lifeblood.exceptions import NotSubscribedError -from lifeblood.scheduler.scheduler import Scheduler from lifeblood.taskspawn import NewTask from lifeblood.ui_events import TaskFullState, TasksChanged, TasksUpdated, TasksRemoved from lifeblood.ui_protocol_data import DataNotSet diff --git a/tests/test_worker.py b/tests/test_worker.py index 34f7646d..df64ce7d 100644 --- a/tests/test_worker.py +++ b/tests/test_worker.py @@ -4,7 +4,6 @@ import tempfile from unittest import IsolatedAsyncioTestCase, mock from lifeblood.worker import Worker -from lifeblood.scheduler import Scheduler from lifeblood.logging import set_default_loglevel from lifeblood.invocationjob import Invocation, InvocationJob, InvocationEnvironment, InvocationResources from lifeblood.environment_resolver import EnvironmentResolverArguments diff --git a/tests/test_worker_restart_double_invocation_edge_case.py b/tests/test_worker_restart_double_invocation_edge_case.py index f5c8df8c..413a53b8 100644 --- a/tests/test_worker_restart_double_invocation_edge_case.py +++ b/tests/test_worker_restart_double_invocation_edge_case.py @@ -1,6 +1,5 @@ import aiosqlite import asyncio -import os from contextlib import contextmanager from lifeblood_testing_common.integration_common import IsolatedAsyncioTestCaseWithDb from lifeblood_testing_common.common import chain @@ -9,7 +8,6 @@ from lifeblood.enums import TaskState, WorkerState, WorkerPingState, TaskScheduleStatus, InvocationState from lifeblood.invocationjob import InvocationJob, InvocationResources from lifeblood.scheduler.data_access import DataAccess, TaskSpawnData -from lifeblood.scheduler.task_processor import TaskProcessor from lifeblood.scheduler.scheduler import Scheduler from typing import List, Optional From ad74c35c8334ca35684bbde75e5ab15bf00ac3e9 Mon Sep 17 00:00:00 2001 From: pedohorse <13556996+pedohorse@users.noreply.github.com> Date: Mon, 28 Oct 2024 12:59:36 +0100 Subject: [PATCH 09/10] fix mock patch --- src/lifeblood_testing_common/nodes_common.py | 2 +- tests/test_scheduler.py | 2 +- tests/test_scheduler_worker_comm.py | 2 +- tests/test_worker.py | 2 +- tests/test_worker_restart_double_invocation_edge_case.py | 4 ++-- 5 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/lifeblood_testing_common/nodes_common.py b/src/lifeblood_testing_common/nodes_common.py 
index bde337d6..1b4ca963 100644 --- a/src/lifeblood_testing_common/nodes_common.py +++ b/src/lifeblood_testing_common/nodes_common.py @@ -222,7 +222,7 @@ async def _helper_test_worker_node(self, """ purge_db() - with mock.patch('lifeblood.scheduler.scheduler.Pinger') as ppatch, \ + with mock.patch('lifeblood.scheduler.scheduler_core.Pinger') as ppatch, \ mock.patch('lifeblood.worker.Worker.scheduler_pinger') as wppatch: ppatch.return_value = mock.AsyncMock(Pinger) diff --git a/tests/test_scheduler.py b/tests/test_scheduler.py index c0763ce0..beb82970 100644 --- a/tests/test_scheduler.py +++ b/tests/test_scheduler.py @@ -183,7 +183,7 @@ async def _helper_test_nonmessage_connection_when_stopping(self, try_open_new=Fa async def test_get_invocation_workers(self): purge_db() - with mock.patch('lifeblood.scheduler.scheduler.Pinger') as ppatch: + with mock.patch('lifeblood.scheduler.scheduler_core.Pinger') as ppatch: ppatch.return_value = mock.AsyncMock(Pinger) sched = create_default_scheduler('test_swc1.db', do_broadcasting=False, helpers_minimal_idle_to_ensure=0) diff --git a/tests/test_scheduler_worker_comm.py b/tests/test_scheduler_worker_comm.py index 335d2e91..f924294f 100644 --- a/tests/test_scheduler_worker_comm.py +++ b/tests/test_scheduler_worker_comm.py @@ -401,7 +401,7 @@ async def test_task_get_order(self): if time.time() - sttime > 60: raise AssertionError('timeout reached!') wrun = worker.is_task_running() - wlocked = worker._Worker__task_changing_state_lock.locked() + wlocked = worker._WorkerCore__task_changing_state_lock.locked() with sqlite3.connect(database='test_swc.db') as con: cur = con.cursor() cur.execute('SELECT "state" FROM workers WHERE "id" = 1') diff --git a/tests/test_worker.py b/tests/test_worker.py index df64ce7d..d49b9cee 100644 --- a/tests/test_worker.py +++ b/tests/test_worker.py @@ -112,7 +112,7 @@ async def test_run_task_report(self): task_id=6492, resources_to_use=InvocationResources({}, {}) ) - with mock.patch('lifeblood.worker.SchedulerWorkerControlClient.get_scheduler_control_client') as m: + with mock.patch('lifeblood.worker_core.SchedulerWorkerControlClient.get_scheduler_control_client') as m: cm = mock.AsyncMock() m.return_value = cm cm.__enter__.return_value = cm diff --git a/tests/test_worker_restart_double_invocation_edge_case.py b/tests/test_worker_restart_double_invocation_edge_case.py index 413a53b8..3f717a92 100644 --- a/tests/test_worker_restart_double_invocation_edge_case.py +++ b/tests/test_worker_restart_double_invocation_edge_case.py @@ -161,8 +161,8 @@ async def _helper_test_multi_invoc(self, racing_tasks_count: int, num_empty_invo (fake_task_row['work_data'], fake_task_row['id'])) await con.commit() - with mock.patch('lifeblood.scheduler.Scheduler._update_worker_resouce_usage'), \ - mock.patch('lifeblood.scheduler.Scheduler.server_message_address'), \ + with mock.patch('lifeblood.scheduler.scheduler_core.SchedulerCore._update_worker_resouce_usage'), \ + mock.patch('lifeblood.scheduler.scheduler_core.SchedulerCore.server_message_address'), \ mock.patch('lifeblood.scheduler.data_access.DataAccess.get_invocation_resources_assigned_to') as res_mock, \ mock.patch('lifeblood.worker_message_processor_client.WorkerControlClient.get_worker_control_client') as get_client_mock: res_mock.side_effect = get_invocation_resources_assigned_to_mock From 683611311264c2d3897cefdd0066686eee516462 Mon Sep 17 00:00:00 2001 From: pedohorse <13556996+pedohorse@users.noreply.github.com> Date: Mon, 28 Oct 2024 14:47:18 +0100 Subject: [PATCH 10/10] cleanup --- 
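Note on the mock retargeting in PATCH 09 above: unittest.mock.patch replaces a
name in the namespace where it is looked up, not where it was defined; once the
Scheduler and Worker internals moved into scheduler_core.SchedulerCore and
worker_core, the patch target strings had to follow them. A minimal sketch of
the rule, using a hypothetical module mymod.py that does
`from os.path import join`:

    from unittest import mock

    # patching the definition site leaves mymod's already-bound name alone
    with mock.patch('os.path.join'):
        pass  # mymod.join still calls the real os.path.join

    # patching the lookup site is what actually intercepts mymod's calls
    with mock.patch('mymod.join'):
        pass  # mymod.join is now the mock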
 src/lifeblood/basenode.py                        |  5 +----
 src/lifeblood/broadcasting.py                    |  5 +----
 src/lifeblood/hardware_resources.py              |  6 ++----
 src/lifeblood/local_notifier.py                  |  3 +--
 src/lifeblood/main_scheduler.py                  |  2 +-
 src/lifeblood/misc.py                            |  3 +--
 src/lifeblood/names.py                           |  2 --
 .../net_messages/impl/message_protocol.py        |  5 +----
 .../net_messages/impl/tcp_message_processor.py   |  1 -
 .../net_messages/impl/tcp_message_receiver.py    |  2 +-
 .../impl/tcp_message_stream_factory.py           |  2 +-
 src/lifeblood/net_messages/interfaces.py         |  4 ++--
 src/lifeblood/net_messages/messages.py           | 16 +---------------
 src/lifeblood/net_messages/stream_wrappers.py    |  2 +-
 src/lifeblood/node_visualization_classes.py      |  3 ---
 src/lifeblood/pulse_checker.py                   |  2 +-
 src/lifeblood/scheduler/data_access.py           |  2 +-
 src/lifeblood/scheduler/pinger.py                |  2 +-
 src/lifeblood/scheduler/task_processor.py        |  5 +++--
 src/lifeblood/scheduler/ui_state_accessor.py     |  4 ++--
 src/lifeblood/scheduler_config_provider_file.py  |  2 +-
 src/lifeblood/scheduler_event_log.py             |  2 +-
 src/lifeblood/scheduler_task_protocol.py         |  7 +------
 src/lifeblood/ui_events.py                       |  4 ++--
 src/lifeblood/ui_protocol_data.py                |  2 +-
 src/lifeblood/worker_resource_definition.py      |  2 +-
 src/lifeblood_testing_common/nodes_common.py     |  4 +---
 27 files changed, 30 insertions(+), 69 deletions(-)

diff --git a/src/lifeblood/basenode.py b/src/lifeblood/basenode.py
index 8b16650f..fc70284e 100644
--- a/src/lifeblood/basenode.py
+++ b/src/lifeblood/basenode.py
@@ -4,16 +4,13 @@
 from logging import Logger
 
 from .nodethings import ProcessingResult
 from .node_ui import NodeUi
-from .node_parameters import ParameterNotFound, Parameter 
+from .node_parameters import ParameterNotFound, Parameter
 from .processingcontext import ProcessingContext
 from .logging import get_logger
 from .plugin_info import PluginInfo, empty_plugin_info
 from .nodegraph_holder_base import NodeGraphHolderBase
 from .node_ui_callback_receiver_base import NodeUiCallbackReceiverBase
 
-# reexport
-from .nodethings import ProcessingError
-
 from typing import Iterable
 
diff --git a/src/lifeblood/broadcasting.py b/src/lifeblood/broadcasting.py
index 1f8e421c..b10ddba6 100644
--- a/src/lifeblood/broadcasting.py
+++ b/src/lifeblood/broadcasting.py
@@ -1,14 +1,11 @@
 import asyncio
 import socket
-from string import ascii_letters
-import random
 import struct
-
 from . import logging
 from .nethelpers import get_localhost
 from .defaults import broadcast_port as default_broadcast_port
-from . import os_based_cheats
+from . import os_based_cheats  # import needed for windows
 
 from typing import Tuple, Union, Optional, Callable, Coroutine, Any
 
diff --git a/src/lifeblood/hardware_resources.py b/src/lifeblood/hardware_resources.py
index 9934984b..33c0d859 100644
--- a/src/lifeblood/hardware_resources.py
+++ b/src/lifeblood/hardware_resources.py
@@ -1,5 +1,3 @@
-import psutil
-import copy
 import re
 import json
 from .misc import get_unique_machine_id
@@ -128,11 +126,11 @@ def __repr__(self):
         parts = []
         for res_name, res in self.__resources.items():
             parts.append(f'{res_name}: {res.value}')
-        for dev_type, dev_res in self.__dev_resources.items():
+        for dev_type, dev_name, dev_res in self.__dev_resources:
            dev_parts = []
             for res_name, res in dev_res.items():
                 dev_parts.append(f'{res_name}: {res.value}')
-            parts.append(f'device({dev_type})[{", ".join(dev_parts)}]')
+            parts.append(f'device(type:"{dev_type}" name:"{dev_name}")[{", ".join(dev_parts)}]')
         return f''
diff --git a/src/lifeblood/local_notifier.py b/src/lifeblood/local_notifier.py
index 6ff7c0f7..f726f97f 100644
--- a/src/lifeblood/local_notifier.py
+++ b/src/lifeblood/local_notifier.py
@@ -2,10 +2,9 @@
 import json
 import uuid
 from . import broadcasting
-from . import logging
 from .nethelpers import get_localhost
 
-from typing import Optional, Tuple, Callable, Coroutine, Any
+from typing import Optional, Tuple, Callable, Coroutine
 
 from .logging import get_logger
 
diff --git a/src/lifeblood/main_scheduler.py b/src/lifeblood/main_scheduler.py
index dadc8550..79092486 100644
--- a/src/lifeblood/main_scheduler.py
+++ b/src/lifeblood/main_scheduler.py
@@ -12,7 +12,7 @@
 from .scheduler_config_provider_file import SchedulerConfigProviderFileOverrides
 from . import logging
 
-from typing import Iterable, List, Optional, Tuple, Union
+from typing import Iterable, List, Tuple, Union
 
 
 def __construct_plugin_paths(custom_plugins_path: Union[None, str, Path], plugin_search_locations: Iterable[Union[str, Path]]) -> List[Tuple[Path, str]]:
diff --git a/src/lifeblood/misc.py b/src/lifeblood/misc.py
index d3f439d4..dd759f07 100644
--- a/src/lifeblood/misc.py
+++ b/src/lifeblood/misc.py
@@ -1,6 +1,5 @@
 import os
 import asyncio
-import random
 import uuid
 import time
 import psutil
@@ -8,7 +7,7 @@
 from contextlib import contextmanager, asynccontextmanager
 from .logging import get_logger, logging
 
-from typing import List, Optional, Union
+from typing import Union
 
 
 class DummyLock:
diff --git a/src/lifeblood/names.py b/src/lifeblood/names.py
index f407bb6f..794103bc 100644
--- a/src/lifeblood/names.py
+++ b/src/lifeblood/names.py
@@ -1,5 +1,3 @@
-import re
-
 from lifeblood.logging import get_logger
 
 from typing import Iterable
diff --git a/src/lifeblood/net_messages/impl/message_protocol.py b/src/lifeblood/net_messages/impl/message_protocol.py
index 2fd38c68..328f99c6 100644
--- a/src/lifeblood/net_messages/impl/message_protocol.py
+++ b/src/lifeblood/net_messages/impl/message_protocol.py
@@ -1,14 +1,11 @@
 import asyncio
-import uuid
 import struct
 from ..logging import get_logger
 from ..stream_wrappers import MessageReceiveStream
 from ..messages import Message
-from ..queue import MessageQueue
 from ..address import DirectAddress
-from ..exceptions import MessageReceivingError, NoMessageError, MessageTransferError, MessageTransferTimeoutError
-from ..interfaces import MessageStreamFactory
+from ..exceptions import NoMessageError, MessageTransferError, MessageTransferTimeoutError
 
 from typing import Callable, Awaitable, Tuple
 
diff --git a/src/lifeblood/net_messages/impl/tcp_message_processor.py b/src/lifeblood/net_messages/impl/tcp_message_processor.py
index c1e1dfdf..2ba09105 100644
--- a/src/lifeblood/net_messages/impl/tcp_message_processor.py
+++ b/src/lifeblood/net_messages/impl/tcp_message_processor.py
@@ -1,4 +1,3 @@
-import asyncio
 from ..message_processor import MessageProcessorBase
 from ..message_handler import MessageHandlerBase
 from ..messages import Message
diff --git a/src/lifeblood/net_messages/impl/tcp_message_receiver.py b/src/lifeblood/net_messages/impl/tcp_message_receiver.py
index 0709e589..dbcf0964 100644
--- a/src/lifeblood/net_messages/impl/tcp_message_receiver.py
+++ b/src/lifeblood/net_messages/impl/tcp_message_receiver.py
@@ -1,6 +1,6 @@
 import asyncio
 from .message_protocol import MessageProtocol, IProtocolInstanceCounter
-from ..interfaces import MessageReceiver, MessageStreamFactory
+from ..interfaces import MessageReceiver
 from ..messages import Message
 from ..address import DirectAddress
 from ..logging import get_logger
diff --git a/src/lifeblood/net_messages/impl/tcp_message_stream_factory.py b/src/lifeblood/net_messages/impl/tcp_message_stream_factory.py
index 36485b6c..7deca6e9 100644
--- a/src/lifeblood/net_messages/impl/tcp_message_stream_factory.py
+++ b/src/lifeblood/net_messages/impl/tcp_message_stream_factory.py
@@ -7,7 +7,7 @@
 from ..exceptions import MessageTransferError, MessageTransferTimeoutError
 from ..interfaces import MessageStreamFactory
 from ..stream_wrappers import MessageSendStream, MessageSendStreamBase
-from ..address import DirectAddress, AddressChain
+from ..address import DirectAddress
 from ..defaults import default_stream_timeout
 from ..messages import Message
 
diff --git a/src/lifeblood/net_messages/interfaces.py b/src/lifeblood/net_messages/interfaces.py
index 3dfad908..89d1c211 100644
--- a/src/lifeblood/net_messages/interfaces.py
+++ b/src/lifeblood/net_messages/interfaces.py
@@ -1,6 +1,6 @@
 from .messages import Message
-from .message_stream import MessageSendStreamBase, MessageReceiveStreamBase
-from .address import DirectAddress, AddressChain
+from .message_stream import MessageSendStreamBase
+from .address import DirectAddress
 
 from typing import Callable, Awaitable
 
diff --git a/src/lifeblood/net_messages/messages.py b/src/lifeblood/net_messages/messages.py
index bd8253f7..c3420b99 100644
--- a/src/lifeblood/net_messages/messages.py
+++ b/src/lifeblood/net_messages/messages.py
@@ -1,10 +1,8 @@
-import asyncio
-import struct
 import uuid
 from .enums import MessageType
 from .address import AddressChain
 
-from typing import Optional, Tuple, Union
+from typing import Optional, Union
 
 
 class MessageInterface:
@@ -70,17 +68,5 @@ def set_message_destination(self, destination: AddressChain):
     def set_message_source(self, source: AddressChain):
         self.__source = source
 
-    def create_reply_message(self, data: bytes = b''):
-        if self.__message_type in (MessageType.SESSION_START, MessageType.SESSION_MESSAGE):
-            return_type = MessageType.SESSION_MESSAGE
-        elif self.__message_type == MessageType.SESSION_END:
-            raise RuntimeError('cannot reply to session end message')
-        elif self.__message_type == MessageType.DEFAULT_MESSAGE:
-            return_type = self.__message_type
-        else:
-            raise RuntimeError(f'unknown message type {self.__message_type}')
-        return Message(data, return_type, self.__destination, self.__source, self.__session)
-
     def message_type(self) -> MessageType:
         return self.__message_type
-
diff --git a/src/lifeblood/net_messages/stream_wrappers.py b/src/lifeblood/net_messages/stream_wrappers.py
index 4c61776f..e9ba297b 100644
--- a/src/lifeblood/net_messages/stream_wrappers.py
+++ b/src/lifeblood/net_messages/stream_wrappers.py
@@ -6,7 +6,7 @@
 from .message_stream import MessageSendStreamBase, MessageReceiveStreamBase
 from .enums import MessageType
 from .address import AddressChain, DirectAddress
-from .exceptions import MessageReceivingError, MessageSendingError, MessageTransferTimeoutError, NoMessageError
+from .exceptions import MessageSendingError, MessageTransferTimeoutError, NoMessageError
 from .defaults import default_stream_timeout
 
 from typing import Optional, Tuple, Union
diff --git a/src/lifeblood/node_visualization_classes.py b/src/lifeblood/node_visualization_classes.py
index d9468dea..506354fa 100644
--- a/src/lifeblood/node_visualization_classes.py
+++ b/src/lifeblood/node_visualization_classes.py
@@ -1,7 +1,4 @@
-from typing import Tuple
-
-
 class NodeColorScheme:
     def __init__(self):
         self.__main_color = (0, 0, 0)
diff --git a/src/lifeblood/pulse_checker.py b/src/lifeblood/pulse_checker.py
index 7e385de9..d1ec04c9 100644
--- a/src/lifeblood/pulse_checker.py
+++ b/src/lifeblood/pulse_checker.py
@@ -5,7 +5,7 @@
 from .net_messages.message_processor import MessageProcessorBase
 from .net_messages.exceptions import MessageTransferError
 
-from typing import Tuple, Callable, Coroutine
+from typing import Callable, Coroutine
 
 
 class PulseChecker:
diff --git a/src/lifeblood/scheduler/data_access.py b/src/lifeblood/scheduler/data_access.py
index 7f746aa9..b47f0a18 100644
--- a/src/lifeblood/scheduler/data_access.py
+++ b/src/lifeblood/scheduler/data_access.py
@@ -17,7 +17,7 @@
 from ..worker_resource_definition import WorkerResourceDataType
 from ..invocationjob import InvocationResources
 
-from typing import Any, Dict, Iterable, List, Optional, Tuple, Type, Union
+from typing import Any, Dict, Iterable, Optional, Tuple, Union
 
 
 SCHEDULER_DB_FORMAT_VERSION = 5
diff --git a/src/lifeblood/scheduler/pinger.py b/src/lifeblood/scheduler/pinger.py
index 0af25f19..5406e174 100644
--- a/src/lifeblood/scheduler/pinger.py
+++ b/src/lifeblood/scheduler/pinger.py
@@ -3,7 +3,7 @@
 import time
 from .. import logging
 from ..worker_message_processor_client import WorkerControlClient
-from ..enums import WorkerState, InvocationState, WorkerPingState, WorkerPingReply
+from ..enums import WorkerState, WorkerPingState, WorkerPingReply
 from .scheduler_component_base import SchedulerComponentBase
 from ..net_messages.address import AddressChain
 from ..net_messages.exceptions import MessageTransferError, MessageTransferTimeoutError
diff --git a/src/lifeblood/scheduler/task_processor.py b/src/lifeblood/scheduler/task_processor.py
index b7b4b074..483f9bdf 100644
--- a/src/lifeblood/scheduler/task_processor.py
+++ b/src/lifeblood/scheduler/task_processor.py
@@ -17,12 +17,13 @@
 from ..attribute_serialization import serialize_attributes, deserialize_attributes
 from ..exceptions import *
 from .. import aiosqlite_overlay
-from ..ui_events import TaskData, TaskDelta
+from ..ui_events import TaskDelta
+from ..ui_protocol_data import TaskData
 from ..net_messages.address import AddressChain
 
 from .scheduler_component_base import SchedulerComponentBase
 
-from typing import List, Optional, TYPE_CHECKING
+from typing import List, TYPE_CHECKING
 
 if TYPE_CHECKING:  # TODO: maybe separate a subset of scheduler's methods to smth like SchedulerData class, or idunno, for now no obvious way to separate, so having a reference back
     from .scheduler_core import SchedulerCore
diff --git a/src/lifeblood/scheduler/ui_state_accessor.py b/src/lifeblood/scheduler/ui_state_accessor.py
index b9224a46..7c73aca8 100644
--- a/src/lifeblood/scheduler/ui_state_accessor.py
+++ b/src/lifeblood/scheduler/ui_state_accessor.py
@@ -5,12 +5,12 @@
 import time
 from enum import Enum
 from ..logging import get_logger
-from ..misc import atimeit, aperformance_measurer
+from ..misc import aperformance_measurer
 from ..enums import InvocationState, TaskState, TaskGroupArchivedState, WorkerState, WorkerType, UIEventType
 from ..exceptions import NotSubscribedError
 from ..scheduler_event_log import SchedulerEventLog
 from ..ui_events import TaskEvent, TaskFullState, TasksUpdated, TasksRemoved, TasksChanged
-from ..ui_protocol_data import TaskBatchData, UiData, TaskGroupData, TaskGroupBatchData, TaskGroupStatisticsData, \
+from ..ui_protocol_data import TaskBatchData, TaskGroupData, TaskGroupBatchData, TaskGroupStatisticsData, \
     NodeGraphStructureData, WorkerBatchData, WorkerData, WorkerResource, WorkerResourceType, WorkerResources, NodeConnectionData, NodeData, TaskData, TaskDelta, \
     WorkerDevice, WorkerDeviceResource
 from .scheduler_component_base import SchedulerComponentBase
diff --git a/src/lifeblood/scheduler_config_provider_file.py b/src/lifeblood/scheduler_config_provider_file.py
index 1a7002b4..f49b4410 100644
--- a/src/lifeblood/scheduler_config_provider_file.py
+++ b/src/lifeblood/scheduler_config_provider_file.py
@@ -7,7 +7,7 @@
 from .config import Config
 from .nethelpers import all_interfaces
 from .exceptions import SchedulerConfigurationError
-from .config import create_default_user_config_file, get_local_scratch_path
+from .config import get_local_scratch_path
 from .text import escape
 
 from typing import Dict, List, Mapping, Optional, Tuple
diff --git a/src/lifeblood/scheduler_event_log.py b/src/lifeblood/scheduler_event_log.py
index 342426f2..665764bc 100644
--- a/src/lifeblood/scheduler_event_log.py
+++ b/src/lifeblood/scheduler_event_log.py
@@ -2,7 +2,7 @@
 from .ui_events import SchedulerEvent
 from .enums import UIEventType
 
-from typing import Dict, List, Optional, Set, Tuple
+from typing import Dict, List, Optional, Tuple
 
 
 logger = logging.getLogger(__name__)
diff --git a/src/lifeblood/scheduler_task_protocol.py b/src/lifeblood/scheduler_task_protocol.py
index 8b5a6648..186331a0 100644
--- a/src/lifeblood/scheduler_task_protocol.py
+++ b/src/lifeblood/scheduler_task_protocol.py
@@ -1,16 +1,11 @@
 import struct
 import asyncio
-import aiofiles
-from enum import Enum
 import pickle
 import json
 from . import logging
-from . import invocationjob
 from .taskspawn import TaskSpawn
-from .enums import WorkerType, SpawnStatus, WorkerState
-from .hardware_resources import HardwareResources
-from .worker_metadata import WorkerMetadata
+from .enums import SpawnStatus, WorkerState
 from .scheduler.scheduler_core import SchedulerCore
 
 from typing import Optional, Tuple
diff --git a/src/lifeblood/ui_events.py b/src/lifeblood/ui_events.py
index 5f89072c..ab23c39d 100644
--- a/src/lifeblood/ui_events.py
+++ b/src/lifeblood/ui_events.py
@@ -3,11 +3,11 @@
 import struct
 from dataclasses import dataclass, field
 from .buffered_connection import BufferedReader
-from .ui_protocol_data import TaskData, TaskDelta, TaskBatchData, UiData
+from .ui_protocol_data import TaskDelta, TaskBatchData
 from .buffer_serializable import IBufferSerializable
 from .enums import UIEventType
 
-from typing import ClassVar, Dict, Iterable, List, Tuple, Type, Union
+from typing import ClassVar, Dict, Iterable, List, Tuple, Type
 
 
 @dataclass
diff --git a/src/lifeblood/ui_protocol_data.py b/src/lifeblood/ui_protocol_data.py
index 6592c4ae..87d5ff1f 100644
--- a/src/lifeblood/ui_protocol_data.py
+++ b/src/lifeblood/ui_protocol_data.py
@@ -10,7 +10,7 @@
 from dataclasses import dataclass
 from enum import Enum
 
-from typing import Dict, List, Tuple, Type, Optional, Set, Union
+from typing import Dict, List, Type, Optional, Set, Union
 
 
 def _serialize_string(s: str, stream: BufferedIOBase) -> int:
diff --git a/src/lifeblood/worker_resource_definition.py b/src/lifeblood/worker_resource_definition.py
index 657e357d..28e82259 100644
--- a/src/lifeblood/worker_resource_definition.py
+++ b/src/lifeblood/worker_resource_definition.py
@@ -1,7 +1,7 @@
 from dataclasses import dataclass
 from enum import Enum
 
-from typing import Dict, Union, Set, Tuple
+from typing import Union, Tuple
 
 
 class WorkerResourceDataType(Enum):
diff --git a/src/lifeblood_testing_common/nodes_common.py b/src/lifeblood_testing_common/nodes_common.py
index 1b4ca963..405e5535 100644
--- a/src/lifeblood_testing_common/nodes_common.py
+++ b/src/lifeblood_testing_common/nodes_common.py
@@ -1,6 +1,4 @@
 import asyncio
-import time
-from dataclasses import dataclass
 import os
 import shutil
 import tempfile
@@ -17,7 +15,7 @@
 from lifeblood.scheduler.scheduler import Scheduler
 from lifeblood_testing_common.common import create_default_scheduler
 from lifeblood.worker import Worker
-from lifeblood.invocationjob import Invocation, InvocationJob, InvocationResources, Environment
+from lifeblood.invocationjob import Invocation, InvocationResources, Environment
 from lifeblood.scheduler.pinger import Pinger
 from lifeblood.pluginloader import PluginNodeDataProvider
 from lifeblood.processingcontext import ProcessingContext