diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml new file mode 100644 index 00000000..fb466764 --- /dev/null +++ b/.github/workflows/lint.yml @@ -0,0 +1,41 @@ +name: lint + +on: + push: + +jobs: +# pylint: +# runs-on: ubuntu-latest +# +# steps: +# - uses: actions/checkout@v3 +# - name: Set up Python 3.10 +# uses: actions/setup-python@v5 +# with: +# python-version: '3.10' +# - name: Install dependencies +# run: | +# python -m pip install --upgrade pip +# pip install -r requirements_tests.txt +# pip install pylint +# - name: Analysing the code with pylint +# run: | +# pylint `ls -R|grep .py$|xargs` + + flake8: + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v3 + - name: Set up Python 3.10 + uses: actions/setup-python@v5 + with: + python-version: '3.10' + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install -r requirements_tests.txt + pip install flake8 + - name: Analysing the code with flake8 + run: | + flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics \ No newline at end of file diff --git a/.github/workflows/pylint.yml b/.github/workflows/pylint.yml deleted file mode 100644 index 9891ca6a..00000000 --- a/.github/workflows/pylint.yml +++ /dev/null @@ -1,22 +0,0 @@ -name: Pylint - -on: [push] - -jobs: - build: - - runs-on: ubuntu-latest - - steps: - - uses: actions/checkout@v2 - - name: Set up Python 3.8 - uses: actions/setup-python@v2 - with: - python-version: 3.8 - - name: Install dependencies - run: | - python -m pip install --upgrade pip - pip install pylint - - name: Analysing the code with pylint - run: | - pylint `ls -R|grep .py$|xargs` diff --git a/src/lifeblood/basenode.py b/src/lifeblood/basenode.py index 06131cac..eea067dd 100644 --- a/src/lifeblood/basenode.py +++ b/src/lifeblood/basenode.py @@ -7,20 +7,21 @@ from typing import Dict, Optional, List, Any from .nodethings import ProcessingResult, ProcessingError from .uidata import NodeUi, ParameterNotFound, ParameterReadonly, ParameterLocked, ParameterCannotHaveExpressions, Parameter -from .pluginloader import create_node, plugin_hash, nodes_settings from .processingcontext import ProcessingContext from .logging import get_logger from .enums import NodeParameterType, WorkerType -from .plugin_info import PluginInfo +from .plugin_info import PluginInfo, empty_plugin_info +from .nodegraph_holder_base import NodeGraphHolderBase from typing import TYPE_CHECKING, Iterable if TYPE_CHECKING: - from .scheduler import Scheduler from logging import Logger class BaseNode: + _plugin_data = None # To be set on module level by loader, set to empty_plugin_info by default + @classmethod def label(cls) -> str: raise NotImplementedError() @@ -38,7 +39,9 @@ def description(cls) -> str: return 'this node type does not have a description' def __init__(self, name: str): - self.__parent: Scheduler = None + if BaseNode._plugin_data is None: + BaseNode._plugin_data = empty_plugin_info + self.__parent: NodeGraphHolderBase = None self.__parent_nid: int = None self._parameters: NodeUi = NodeUi(self) self.__name = name @@ -49,9 +52,9 @@ def __init__(self, name: str): self.__logger = get_logger(f'BaseNode.{mytype}' if mytype is not None else 'BaseNode') # subclass is expected to add parameters at this point - def _set_parent(self, parent_scheduler, node_id): - self.__parent = parent_scheduler - self.__parent_nid = node_id + def set_parent(self, graph_holder: NodeGraphHolderBase, node_id_in_graph: int): + self.__parent = graph_holder + self.__parent_nid = 
node_id_in_graph def logger(self) -> "Logger": return self.__logger @@ -211,13 +214,7 @@ def copy_ui_to(self, to_node: "BaseNode"): to_node._parameters = newui newui.attach_to_node(to_node) - def apply_settings(self, settings_name: str) -> None: - mytype = self.type_name() - if mytype not in nodes_settings: - raise RuntimeError(f'no settings found for "{mytype}"') - if settings_name not in nodes_settings[mytype]: - raise RuntimeError(f'requested settings "{settings_name}" not found for type "{mytype}"') - settings = nodes_settings[mytype][settings_name] + def apply_settings(self, settings: Dict[str, Dict[str, Any]]) -> None: with self.get_ui().postpone_ui_callbacks(): for param_name, value in settings.items(): try: @@ -232,10 +229,10 @@ def apply_settings(self, settings_name: str) -> None: param.remove_expression() param.set_value(value) except ParameterNotFound: - self.logger().warning(f'applying settings "{settings_name}": skipping unrecognized parameter "{param_name}"') + self.logger().warning(f'applying settings: skipping unrecognized parameter "{param_name}"') continue except ValueError as e: - self.logger().warning(f'applying settings "{settings_name}": skipping parameter "{param_name}": bad value type: {str(e)}') + self.logger().warning(f'applying settings: skipping parameter "{param_name}": bad value type: {str(e)}') continue # # some helpers @@ -247,98 +244,32 @@ def apply_settings(self, settings_name: str) -> None: # @classmethod def my_plugin(cls) -> PluginInfo: - from . import pluginloader - type_name = cls.type_name() # this case was for nodetypes that are present in DB, but not loaded cuz of configuration errors # but it doesn't make sense in current form - if node is created - plugin info will be present # this needs to be rethought # # if type_name not in pluginloader.plugins: # # return None - return pluginloader.plugins[type_name]._plugin_info + return cls._plugin_data # # Serialize and back # - def __reduce__(self): - # typename = type(self).__module__ - # if '.' in typename: - # typename = typename.rsplit('.', 1)[-1] - return create_node, (self.type_name(), '', None, None), self.__getstate__() - - def __getstate__(self): - # TODO: if u ever implement parameter expressions - be VERY careful with pickling expressions referencing across nodes - d = copy(self.__dict__) - assert '_BaseNode__parent' in d - d['_BaseNode__parent'] = None - d['_BaseNode__parent_nid'] = None - d['_BaseNode__saved_plugin_hash'] = plugin_hash(self.type_name()) # we will use this hash to detect plugin module changes on load - return d - - def __setstate__(self, state): - # the idea here is to update node's class instance IF plugin hash is different from the saved one - # the hash being different means that node's definition was updated - we don't know how - # so what we do is save all parameter values, merge old state values with new - # and hope for the best... 
- - hash = plugin_hash(self.type_name()) - if hash != state.get('_BaseNode__saved_plugin_hash', None): - self.__init__(state.get('name', '')) - # update all except ui - try: - if '_parameters' in state: - old_ui: NodeUi = state['_parameters'] - del state['_parameters'] - self.__dict__.update(state) - new_ui = self.get_ui() - for param in old_ui.parameters(): - try: - newparam = new_ui.parameter(param.name()) - except ParameterNotFound: - continue - try: - newparam.set_value(param.unexpanded_value()) - except ParameterReadonly: - newparam._Parameter__value = param.unexpanded_value() - except ParameterLocked: - newparam.set_locked(False) - newparam.set_value(param.unexpanded_value()) - newparam.set_locked(True) - if param.has_expression(): - try: - newparam.set_expression(param.expression()) - except ParameterCannotHaveExpressions: - pass - else: - self.__dict__.update(state) - except AttributeError: - # something changed so much that some core attrs are different - get_logger('BaseNode').exception(f'could not update interface for some node of type {self.type_name()}. resetting node\'s inrerface') - - - # TODO: if and whenever expressions are introduced - u need to take care of expressions here too! - else: - self.__dict__.update(state) - def serialize(self) -> bytes: - """ - by default we just serialize - :return: + def get_state(self) -> Optional[dict]: """ - return pickle.dumps(self) + override this to be able to save node's unique state if it has one + None means node does not and will not have an internal state + if node CAN have an internal state and it's just empty - return empty dict instead - async def serialize_async(self) -> bytes: - return await asyncio.get_event_loop().run_in_executor(None, self.serialize) - - @classmethod - def deserialize(cls, data: bytes, parent_scheduler, node_id): - newobj = pickle.loads(data) - newobj.__parent = parent_scheduler - newobj.__parent_nid = node_id - return newobj + note: state will only be saved on normal exit, it won't be saved on crash, it's not part of any transaction + """ + return None - @classmethod - async def deserialize_async(cls, data: bytes, parent_scheduler, node_id): - return await asyncio.get_event_loop().run_in_executor(None, cls.deserialize, data, parent_scheduler, node_id) + def set_state(self, state: dict): + """ + restore state as given by get_state + """ + pass class BaseNodeWithTaskRequirements(BaseNode): diff --git a/src/lifeblood/basenode_serialization.py b/src/lifeblood/basenode_serialization.py new file mode 100644 index 00000000..970f6ef6 --- /dev/null +++ b/src/lifeblood/basenode_serialization.py @@ -0,0 +1,27 @@ +import asyncio +from .basenode import BaseNode +from .node_dataprovider_base import NodeDataProvider +from .nodegraph_holder_base import NodeGraphHolderBase + +from typing import Optional, Tuple + + +class FailedToDeserialize(RuntimeError): + pass + + +class NodeSerializerBase: + def serialize(self, node: BaseNode) -> Tuple[bytes, Optional[bytes]]: + raise NotImplementedError() + + def serialize_state_only(self, node: BaseNode) -> Optional[bytes]: + raise NotImplementedError() + + def deserialize(self, parent: NodeGraphHolderBase, node_id: int, node_data_provider: NodeDataProvider, data: bytes, state: Optional[bytes]) -> BaseNode: + raise NotImplementedError() + + async def deserialize_async(self, parent: NodeGraphHolderBase, node_id: int, node_data_provider: NodeDataProvider, data: bytes, state: Optional[bytes]) -> BaseNode: + return await asyncio.get_event_loop().run_in_executor(None, self.deserialize, 
parent, node_id, node_data_provider, data, state) + + async def serialize_async(self, node: BaseNode) -> Tuple[bytes, Optional[bytes]]: + return await asyncio.get_event_loop().run_in_executor(None, self.serialize, node) diff --git a/src/lifeblood/basenode_serializer_v1.py b/src/lifeblood/basenode_serializer_v1.py new file mode 100644 index 00000000..4b9eba52 --- /dev/null +++ b/src/lifeblood/basenode_serializer_v1.py @@ -0,0 +1,52 @@ +import pickle +from io import BytesIO +from dataclasses import dataclass, is_dataclass +import json +from .basenode_serialization import NodeSerializerBase, FailedToDeserialize +from .basenode import BaseNode, NodeParameterType + +from typing import Callable, Optional, Tuple, Union + +from .node_dataprovider_base import NodeDataProvider +from .nodegraph_holder_base import NodeGraphHolderBase + + +@dataclass +class ParameterData: + name: str + type: NodeParameterType + unexpanded_value: Union[int, float, str, bool] + expression: Optional[str] + + +def create_node_maker(node_data_provider: NodeDataProvider) -> Callable[[str, str, NodeGraphHolderBase, int], BaseNode]: + def create_node(type_name: str, name: str, sched_parent, node_id) -> BaseNode: + node = node_data_provider.node_factory(type_name)(name) + node.set_parent(sched_parent, node_id) + return node + return create_node + + +class NodeSerializerV1(NodeSerializerBase): + def serialize(self, node: BaseNode) -> Tuple[bytes, Optional[bytes]]: + raise DeprecationWarning('no use this!') + + def deserialize(self, parent: NodeGraphHolderBase, node_id: int, node_data_provider: NodeDataProvider, data: bytes, state: Optional[bytes]) -> BaseNode: + # this be pickled + # we do hacky things here fo backward compatibility + class Unpickler(pickle.Unpickler): + def find_class(self, module, name): + if module == 'lifeblood.pluginloader' and name == 'create_node': + return create_node_maker(node_data_provider) + return super(Unpickler, self).find_class(module, name) + + if state is not None: + raise FailedToDeserialize(f'deserialization v1 is not expecting a separate state data') + + try: + newobj: BaseNode = Unpickler(BytesIO(data)).load() + except Exception as e: + raise FailedToDeserialize(f'error loading pickle: {e}') from None + + newobj.set_parent(parent, node_id) + return newobj diff --git a/src/lifeblood/basenode_serializer_v2.py b/src/lifeblood/basenode_serializer_v2.py new file mode 100644 index 00000000..0d5d4cb5 --- /dev/null +++ b/src/lifeblood/basenode_serializer_v2.py @@ -0,0 +1,140 @@ +from dataclasses import dataclass, is_dataclass +import json +from .basenode_serialization import NodeSerializerBase, FailedToDeserialize +from .basenode import BaseNode, NodeParameterType +from .uidata import ParameterFullValue + +from typing import Optional, Tuple, Union + +from .node_dataprovider_base import NodeDataProvider +from .nodegraph_holder_base import NodeGraphHolderBase + + +@dataclass +class ParameterData: + name: str + type: NodeParameterType + unexpanded_value: Union[int, float, str, bool] + expression: Optional[str] + + +class NodeSerializerV2(NodeSerializerBase): + """ + Universal json-like serializer + Note, this supports more things than json, such as: + - tuples + - sets + - int dict keys + - tuple dict keys + - limited set of dataclasses + + the final string though is json-compliant + """ + + class Serializer(json.JSONEncoder): + def __reform(self, obj): + if type(obj) is set: + return { + '__special_object_type__': 'set', + 'items': self.__reform(list(obj)) + } + elif type(obj) is tuple: + return { + 
'__special_object_type__': 'tuple', + 'items': self.__reform(list(obj)) + } + elif type(obj) is dict: # int keys case + if any(isinstance(x, (int, float, tuple)) for x in obj.keys()): + return { + '__special_object_type__': 'kvp', + 'items': self.__reform([[k, v] for k, v in obj.items()]) + } + return {k: self.__reform(v) for k, v in obj.items()} + elif is_dataclass(obj): + dcs = self.__reform(obj.__dict__) # dataclasses.asdict is recursive, kills inner dataclasses + dcs['__dataclass__'] = obj.__class__.__name__ + dcs['__special_object_type__'] = 'dataclass' + return dcs + elif isinstance(obj, NodeParameterType): + return {'value': obj.value, + '__special_object_type__': 'NodeParameterType' + } + elif isinstance(obj, list): + return [self.__reform(x) for x in obj] + elif isinstance(obj, (int, float, str, bool)) or obj is None: + return obj + raise NotImplementedError(f'serialization not implemented for type "{type(obj)}"') + + def encode(self, o): + return super().encode(self.__reform(o)) + + def default(self, obj): + return super(NodeSerializerV2.Serializer, self).default(obj) + + class Deserializer(json.JSONDecoder): + def dedata(self, obj): + special_type = obj.get('__special_object_type__') + if special_type == 'set': + return set(obj.get('items')) + elif special_type == 'tuple': + return tuple(obj.get('items')) + elif special_type == 'kvp': + return {k: v for k, v in obj.get('items')} + elif special_type == 'dataclass': + data = globals()[obj['__dataclass__']](**{k: v for k, v in obj.items() if k not in ('__dataclass__', '__special_object_type__')}) + if obj['__dataclass__'] == 'NodeData': + data.pos = tuple(data.pos) + return data + elif special_type == 'NodeParameterType': + return NodeParameterType(obj['value']) + return obj + + def __init__(self): + super(NodeSerializerV2.Deserializer, self).__init__(object_hook=self.dedata) + + def serialize(self, node: BaseNode) -> Tuple[bytes, Optional[bytes]]: + param_values = {} + for param in node.get_ui().parameters(): + param_values[param.name()] = ParameterData( + param.name(), + param.type(), + param.unexpanded_value(), + param.expression() + ) + + data_dict = { + 'format_version': 2, + 'type_name': node.type_name(), + 'name': node.name(), + 'ingraph_id': node.id(), # node_id will be overriden on deserialize, to make sure scheduler is consistent + 'type_definition_hash': node.my_plugin().hash(), + 'parameters': param_values, + } + + return ( + json.dumps(data_dict, cls=NodeSerializerV2.Serializer).encode('latin1'), + self.serialize_state_only(node) + ) + + def serialize_state_only(self, node: BaseNode) -> Optional[bytes]: + state = node.get_state() + return None if state is None else json.dumps(state, cls=NodeSerializerV2.Serializer).encode('latin1') + + def deserialize(self, parent: NodeGraphHolderBase, node_id: int, node_data_provider: NodeDataProvider, data: bytes, state: Optional[bytes]) -> BaseNode: + try: + data_dict = json.loads(data.decode('latin1'), cls=NodeSerializerV2.Deserializer) + except json.JSONDecodeError: + raise FailedToDeserialize('not a json') from None + for musthave in ('format_version', 'type_name', 'type_definition_hash', 'parameters', 'name', 'ingraph_id'): + if musthave not in data_dict: + raise FailedToDeserialize('missing required fields') + if (fv := data_dict['format_version']) != 2: + raise FailedToDeserialize(f'format_version {fv} is not supported') + new_node = node_data_provider.node_factory(data_dict['type_name'])(data_dict['name']) + new_node.set_parent(parent, node_id) + with 
new_node.get_ui().block_ui_callbacks(): + new_node.get_ui().set_parameters_batch({name: ParameterFullValue(val.unexpanded_value, val.expression) for name, val in data_dict['parameters'].items()}) + if state: + new_node.set_state(json.loads(state.decode('latin1'), cls=NodeSerializerV2.Deserializer)) + + return new_node diff --git a/src/lifeblood/core_nodes/parent_children_waiter.py b/src/lifeblood/core_nodes/parent_children_waiter.py index e47a7568..356f8dc4 100644 --- a/src/lifeblood/core_nodes/parent_children_waiter.py +++ b/src/lifeblood/core_nodes/parent_children_waiter.py @@ -1,3 +1,5 @@ +import dataclasses +from dataclasses import dataclass import json from lifeblood.basenode import BaseNode, ProcessingError from lifeblood.nodethings import ProcessingResult @@ -25,11 +27,11 @@ class ParentChildrenWaiterNode(BaseNode): when a task from first input has all it's children arriving from the second input possibly recursively """ + @dataclass class Entry: - def __init__(self): - self.children: Set[int] = set() - self.parent_ready: bool = False - self.all_children_dicts: Dict[int, dict] = {} + children: Set[int] = dataclasses.field(default_factory=set) + parent_ready: bool = False + all_children_dicts: Dict[int, dict] = dataclasses.field(default_factory=dict) @classmethod def label(cls) -> str: @@ -211,8 +213,22 @@ def _debug_has_internal_data_for_task(self, task_id: int): return True return any(task_id in l.children for l in self.__cache_children.values()) - def __getstate__(self): - d = super(ParentChildrenWaiterNode, self).__getstate__() - assert '_ParentChildrenWaiterNode__main_lock' in d - del d['_ParentChildrenWaiterNode__main_lock'] - return d + def get_state(self) -> dict: + return { + 'cache_children': { + name: { + 'children': list(val.children), + 'parent_ready': val.parent_ready, + 'all_children_dicts': val.all_children_dicts, + } for name, val in self.__cache_children.items() + } + } + + def set_state(self, state: dict): + self.__cache_children = { + int(name): ParentChildrenWaiterNode.Entry( + set(val['children']), + val['parent_ready'], + {int(k): v for k, v in val['all_children_dicts'].items()}, + ) for name, val in state['cache_children'].items() + } diff --git a/src/lifeblood/core_nodes/split_waiter.py b/src/lifeblood/core_nodes/split_waiter.py index 356dd5be..9da218ad 100644 --- a/src/lifeblood/core_nodes/split_waiter.py +++ b/src/lifeblood/core_nodes/split_waiter.py @@ -1,3 +1,4 @@ +from dataclasses import dataclass import time import json from lifeblood.basenode import BaseNode @@ -16,9 +17,11 @@ from lifeblood.scheduler import Scheduler -class SplitAwaiting(TypedDict): +@dataclass +class SplitAwaiting: arrived: Dict[int, dict] # num in split -2-> attributes awaiting: Set[int] + processed: Set[int] first_to_arrive: Optional[int] @@ -42,7 +45,7 @@ def type_name(cls) -> str: def __init__(self, name: str): super(SplitAwaiterNode, self).__init__(name) - self.__cache: Dict[int: SplitAwaiting] = {} + self.__cache: Dict[int, SplitAwaiting] = {} self.__main_lock = Lock() ui = self.get_ui() with ui.initializing_interface_lock(): @@ -71,7 +74,7 @@ def __get_promote_attribs(self, context): sort_reversed = context.param_value(f'reversed_{i}') if transfer_type == 'append': gathered_values = [] - for attribs in sorted(self.__cache[split_id]['arrived'].values(), key=lambda x: x.get(sort_attr_name, 0), reverse=sort_reversed): + for attribs in sorted(self.__cache[split_id].arrived.values(), key=lambda x: x.get(sort_attr_name, 0), reverse=sort_reversed): if src_attr_name not in attribs: 
continue @@ -80,7 +83,7 @@ def __get_promote_attribs(self, context): attribs_to_promote[dst_attr_name] = gathered_values elif transfer_type == 'extend': gathered_values = [] - for attribs in sorted(self.__cache[split_id]['arrived'].values(), key=lambda x: x.get(sort_attr_name, 0), reverse=sort_reversed): + for attribs in sorted(self.__cache[split_id].arrived.values(), key=lambda x: x.get(sort_attr_name, 0), reverse=sort_reversed): if src_attr_name not in attribs: continue @@ -91,7 +94,7 @@ def __get_promote_attribs(self, context): gathered_values.append(attr_val) attribs_to_promote[dst_attr_name] = gathered_values elif transfer_type == 'first': - _acd = self.__cache[split_id]['arrived'] + _acd = self.__cache[split_id].arrived if len(_acd) > 0: if sort_reversed: attribs = max(_acd.values(), key=lambda x: x.get(sort_attr_name, 0)) @@ -102,7 +105,7 @@ def __get_promote_attribs(self, context): elif transfer_type == 'sum': # we don't care about the order, assume sum is associative gathered_values = None - for attribs in self.__cache[split_id]['arrived'].values(): + for attribs in self.__cache[split_id].arrived.values(): if src_attr_name not in attribs: continue if gathered_values is None: @@ -120,8 +123,8 @@ def ready_to_process_task(self, task_dict) -> bool: # we don't even need to lock return split_id not in self.__cache or \ not context.param_value('wait for all') or \ - context.task_field('split_element') not in self.__cache[split_id]['arrived'] or \ - self.__cache[split_id]['arrived'].keys() == self.__cache[split_id]['awaiting'] + context.task_field('split_element') not in self.__cache[split_id].arrived or \ + self.__cache[split_id].arrived.keys() == self.__cache[split_id].awaiting def process_task(self, context) -> ProcessingResult: #TODO: not finished, attrib not taken into account, rethink return type orig_id = context.task_field('split_origin_task_id') @@ -131,15 +134,17 @@ def process_task(self, context) -> ProcessingResult: #TODO: not finished, attrib return ProcessingResult() with self.__main_lock: if split_id not in self.__cache: - self.__cache[split_id] = {'arrived': {}, - 'awaiting': set(range(context.task_field('split_count'))), - 'processed': set(), - 'first_to_arrive': None} - if self.__cache[split_id]['first_to_arrive'] is None and len(self.__cache[split_id]['arrived']) == 0: - self.__cache[split_id]['first_to_arrive'] = task_id - if context.task_field('split_element') not in self.__cache[split_id]['arrived']: - self.__cache[split_id]['arrived'][context.task_field('split_element')] = json.loads(context.task_field('attributes')) - self.__cache[split_id]['arrived'][context.task_field('split_element')]['_builtin_id'] = task_id + self.__cache[split_id] = SplitAwaiting( + {}, + set(range(context.task_field('split_count'))), + set(), + None + ) + if self.__cache[split_id].first_to_arrive is None and len(self.__cache[split_id].arrived) == 0: + self.__cache[split_id].first_to_arrive = task_id + if context.task_field('split_element') not in self.__cache[split_id].arrived: + self.__cache[split_id].arrived[context.task_field('split_element')] = json.loads(context.task_field('attributes')) + self.__cache[split_id].arrived[context.task_field('split_element')]['_builtin_id'] = task_id # we will not wait in loop or we risk deadlocking threadpool # check if everyone is ready @@ -147,12 +152,12 @@ def process_task(self, context) -> ProcessingResult: #TODO: not finished, attrib try: if context.param_value('wait for all'): with self.__main_lock: - if self.__cache[split_id]['arrived'].keys() == 
self.__cache[split_id]['awaiting']: + if self.__cache[split_id].arrived.keys() == self.__cache[split_id].awaiting: res = ProcessingResult() res.kill_task() - self.__cache[split_id]['processed'].add(context.task_field('split_element')) - if self.__cache[split_id]['first_to_arrive'] == task_id: - # transfer attributes # TODO: delete cache for already processed splits + self.__cache[split_id].processed.add(context.task_field('split_element')) + if self.__cache[split_id].first_to_arrive == task_id: + # transfer attributes attribs_to_promote = self.__get_promote_attribs(context) res.remove_split(attributes_to_set=attribs_to_promote) @@ -162,16 +167,16 @@ def process_task(self, context) -> ProcessingResult: #TODO: not finished, attrib with self.__main_lock: res = ProcessingResult() res.kill_task() - self.__cache[split_id]['processed'].add(context.task_field('split_element')) - if self.__cache[split_id]['first_to_arrive'] == task_id: + self.__cache[split_id].processed.add(context.task_field('split_element')) + if self.__cache[split_id].first_to_arrive == task_id: res.remove_split() changed = True return res finally: - if self.__cache[split_id]['processed'] == self.__cache[split_id]['awaiting']: # kinda precheck, to avoid extra lockings + if self.__cache[split_id].processed == self.__cache[split_id].awaiting: # kinda precheck, to avoid extra lockings with self.__main_lock: - if self.__cache[split_id]['processed'] == self.__cache[split_id]['awaiting']: # and proper check inside lock + if self.__cache[split_id].processed == self.__cache[split_id].awaiting: # and proper check inside lock del self.__cache[split_id] # if changed: # self._state_changed() # this cannot be called from non asyncio thread as this. @@ -188,8 +193,20 @@ def _debug_has_internal_data_for_split(self, split_id: int) -> bool: """ return split_id in self.__cache - def __getstate__(self): - d = super(SplitAwaiterNode, self).__getstate__() - assert '_SplitAwaiterNode__main_lock' in d - del d['_SplitAwaiterNode__main_lock'] - return d + def get_state(self) -> dict: + return { + 'cache': {k: { + 'arrived': v.arrived, + 'awaiting': list(v.awaiting), + 'processed': list(v.processed), + 'first_to_arrive': v.first_to_arrive, + } for k, v in self.__cache.items()}, + } + + def set_state(self, state: dict): + self.__cache = {int(k): SplitAwaiting( + {int(i): j for i, j in v['arrived'].items()}, + set(v['awaiting']), + set(v['processed']), + v['first_to_arrive'] + ) for k, v in state['cache'].items()} diff --git a/src/lifeblood/core_nodes/wait_for_task.py b/src/lifeblood/core_nodes/wait_for_task.py index 2d4d96bb..97a9029c 100644 --- a/src/lifeblood/core_nodes/wait_for_task.py +++ b/src/lifeblood/core_nodes/wait_for_task.py @@ -117,8 +117,12 @@ def process_task(self, context: ProcessingContext) -> ProcessingResult: return self.__get_default_result(context) raise NodeNotReadyToProcess() - def __getstate__(self): - d = super(WaitForTaskValue, self).__getstate__() - assert '_WaitForTaskValue__main_lock' in d - del d['_WaitForTaskValue__main_lock'] - return d + def get_state(self) -> dict: + return { + 'values_map': self.__values_map, + 'values_set_cache': list(self.__values_set_cache) + } + + def set_state(self, state: dict): + self.__values_map = {int(k): v for k,v in state['values_map'].items()} + self.__values_set_cache = set(state['values_set_cache']) diff --git a/src/lifeblood/db_misc.py b/src/lifeblood/db_misc.py index c4046067..152cf7f2 100644 --- a/src/lifeblood/db_misc.py +++ b/src/lifeblood/db_misc.py @@ -81,7 +81,8 @@ "id" INTEGER NOT 
NULL PRIMARY KEY AUTOINCREMENT, "type" TEXT NOT NULL, "name" TEXT, - "node_object" BLOB + "node_object" BLOB, + "node_object_state" BLOB ); CREATE TABLE IF NOT EXISTS "task_groups" ( "task_id" INTEGER NOT NULL, diff --git a/src/lifeblood/launch.py b/src/lifeblood/launch.py index 5b894222..ad8be351 100644 --- a/src/lifeblood/launch.py +++ b/src/lifeblood/launch.py @@ -31,13 +31,13 @@ def main(argv): logging.set_default_loglevel(opts.loglevel) if opts.command == 'scheduler': - from .scheduler import main + from .main_scheduler import main return main(cmd_argv) elif opts.command == 'worker': - from .worker import main + from .main_worker import main return main(cmd_argv) elif opts.command == 'pool': - from .worker_pool import main + from .main_workerpool import main return main(cmd_argv) elif opts.command == 'viewer': try: diff --git a/src/lifeblood/main_scheduler.py b/src/lifeblood/main_scheduler.py new file mode 100644 index 00000000..df73d68e --- /dev/null +++ b/src/lifeblood/main_scheduler.py @@ -0,0 +1,176 @@ +import sys +import os +import re +from pathlib import Path +import asyncio +import signal +from .pluginloader import PluginNodeDataProvider +from .scheduler import Scheduler +from .basenode_serializer_v1 import NodeSerializerV1 +from .basenode_serializer_v2 import NodeSerializerV2 +from .defaults import scheduler_port as default_scheduler_port, ui_port as default_ui_port +from .config import get_config, create_default_user_config_file, get_local_scratch_path +from . import logging +from . import paths +from .text import escape + +from typing import Optional, Tuple + + +__esc = '\\"' + +default_config = f''' +[core] +## you can uncomment stuff below to specify some static values +## +# server_ip = "192.168.0.2" +# server_port = {default_scheduler_port()} +# ui_ip = "192.168.0.2" +# ui_port = {default_ui_port()} + +## you can turn off scheduler broadcasting if you want to manually configure viewer and workers to connect +## to a specific address +# broadcast = false + +[scheduler] + +[scheduler.globals] +## entries from this section will be available to any node from config[key] +## +## if you use more than 1 machine - you must change this to a network location shared among all workers +## by default it's set to scheduler's machine local temp path, and will only work for 1 machine setup +global_scratch_location = "{escape(get_local_scratch_path(), __esc)}" + +[scheduler.database] +## you can specify default database path, +## but this can be overriden with command line argument --db-path +# path = "/path/to/database.db" + +## uncomment line below to store task logs outside of the database +## it works in a way that all NEW logs will be saved according to settings below +## existing logs will be kept where they are +## external logs will ALWAYS be looked for in location specified by store_logs_externally_location +## so if you have ANY logs saved externally - you must keep store_logs_externally_location defined in the config, +## or those logs will be inaccessible +## but you can safely move logs and change location in config accordingly, but be sure scheduler is not accessing them at that time +# store_logs_externally = true +# store_logs_externally_location = /path/to/dir/where/to/store/logs +''' + + +def create_default_scheduler(db_file_path, *, + do_broadcasting: Optional[bool] = None, + broadcast_interval: Optional[int] = None, + helpers_minimal_idle_to_ensure=1, + server_addr: Optional[Tuple[str, int, int]] = None, + server_ui_addr: Optional[Tuple[str, int]] = None) -> 
Scheduler: + return Scheduler( + db_file_path, + node_data_provider=PluginNodeDataProvider(), + node_serializers=[NodeSerializerV2(), NodeSerializerV1()], + do_broadcasting=do_broadcasting, + broadcast_interval=broadcast_interval, + helpers_minimal_idle_to_ensure=helpers_minimal_idle_to_ensure, + server_addr=server_addr, + server_ui_addr=server_ui_addr, + ) + + +async def main_async(db_path=None, *, broadcast_interval: Optional[int] = None): + def graceful_closer(*args): + scheduler.stop() + + noasync_do_close = False + def noasync_windows_graceful_closer_event(*args): + nonlocal noasync_do_close + noasync_do_close = True + + async def windows_graceful_closer(): + while not noasync_do_close: + await asyncio.sleep(1) + graceful_closer() + + scheduler = create_default_scheduler( + db_path, + do_broadcasting=broadcast_interval > 0 if broadcast_interval is not None else None, + broadcast_interval=broadcast_interval + ) + win_signal_waiting_task = None + try: + asyncio.get_event_loop().add_signal_handler(signal.SIGINT, graceful_closer) + asyncio.get_event_loop().add_signal_handler(signal.SIGTERM, graceful_closer) + except NotImplementedError: # solution for windows + signal.signal(signal.SIGINT, noasync_windows_graceful_closer_event) + signal.signal(signal.SIGBREAK, noasync_windows_graceful_closer_event) + win_signal_waiting_task = asyncio.create_task(windows_graceful_closer()) + + await scheduler.start() + await scheduler.wait_till_stops() + if win_signal_waiting_task is not None: + if not win_signal_waiting_task.done(): + win_signal_waiting_task.cancel() + logging.get_logger('scheduler').info('SCHEDULER STOPPED') + + +def main(argv): + import argparse + import tempfile + + parser = argparse.ArgumentParser('lifeblood scheduler') + parser.add_argument('--db-path', help='path to sqlite database to use') + parser.add_argument('--ephemeral', action='store_true', help='start with an empty one time use database, that is placed into shared memory IF POSSIBLE') + parser.add_argument('--verbosity-pinger', help='set individual verbosity for worker pinger') + parser.add_argument('--broadcast-interval', type=int, help='help easily override broadcasting interval (in seconds). 
value 0 disables broadcasting') + opts = parser.parse_args(argv) + + # check and create default config if none + create_default_user_config_file('scheduler', default_config) + + config = get_config('scheduler') + if opts.db_path is not None: + db_path = opts.db_path + else: + db_path = config.get_option_noasync('scheduler.database.path', str(paths.default_main_database_location())) + + global_logger = logging.get_logger('scheduler') + + fd = None + if opts.ephemeral: + if opts.db_path is not None: + parser.error('only one of --db-path or --ephemeral must be provided, not both') + # 'file:memorydb?mode=memory&cache=shared' + # this does not work ^ cuz shared cache means that all runs on the *same connection* + # and when there is a transaction conflict on the same connection - we get instalocked (SQLITE_LOCKED) + # and there is no way to emulate normal DB in memory but with shared cache + + # look for shm (UNIX only) + shm_path = Path('/dev/shm') + lb_shm_path = None + if shm_path.exists(): + lb_shm_path = shm_path/f'u{os.getuid()}-lifeblood' + try: + lb_shm_path.mkdir(exist_ok=True) + except Exception as e: + global_logger.warning('/dev/shm is not accessible (permission issues?), creating ephemeral database in temp dir') + lb_shm_path = None + else: + global_logger.warning('/dev/shm is not supported by OS, creating ephemeral database in temp dir') + + fd, db_path = tempfile.mkstemp(dir=lb_shm_path, prefix='shedb-') + + if opts.verbosity_pinger: + logging.get_logger('scheduler.worker_pinger').setLevel(opts.verbosity_pinger) + try: + asyncio.run(main_async(db_path, broadcast_interval=opts.broadcast_interval)) + except KeyboardInterrupt: + global_logger.warning('SIGINT caught') + global_logger.info('SIGINT caught. Scheduler is stopped now.') + finally: + if opts.ephemeral: + assert fd is not None + os.close(fd) + os.unlink(db_path) + + +if __name__ == '__main__': + main(sys.argv[1:]) diff --git a/src/lifeblood/main_worker.py b/src/lifeblood/main_worker.py new file mode 100644 index 00000000..aa5ec34f --- /dev/null +++ b/src/lifeblood/main_worker.py @@ -0,0 +1,200 @@ +import asyncio +import json +import signal +from . import logging +from .nethelpers import get_default_addr +from .broadcasting import await_broadcast +from .config import get_config, create_default_user_config_file +from .enums import WorkerType, ProcessPriorityAdjustment +from .net_messages.address import AddressChain +from .worker import Worker + +from typing import Optional + +default_config = ''' +[worker] +listen_to_broadcast = true + +[default_env_wrapper] +## here you can uncomment lines below to specify your own default environment wrapper and default arguments +## this will only be used by invocation jobs that have NO environment wrappers specified +# name = TrivialEnvironmentResolver +# arguments = [ "project_name", "or", "config_name", "idunno", "maybe rez packages requirements?", [1,4,11] ] + +[resources] +## here you can override resources that this machine has +## if you don't specify anything - resources will be detected automatically +## NOTE: automatic detection DOES NOT WORK FOR GPU yet, you have to specify it manually +# cpu_count = 32 # by default treated as the number of cores +# cpu_mem = "128G" # you can either specify int amount of bytes, or use string ending with one of "K" "M" "G" "T" "P" meaning Kilo, Mega, Giga, ... 
+# gpu_count = 1 # by default treated as the number of devices +# gpu_mem = "8G" # you can either specify int amount of bytes, or use string ending with one of "K" "M" "G" "T" "P" meaning Kilo, Mega, Giga, ... +''' + + +async def main_async(worker_type=WorkerType.STANDARD, + child_priority_adjustment: ProcessPriorityAdjustment = ProcessPriorityAdjustment.NO_CHANGE, + singleshot: bool = False, worker_id: Optional[int] = None, pool_address=None, noloop=False): + """ + listen to scheduler broadcast in a loop. + if received - create the worker and work + if worker cannot ping the scheduler a number of times - it stops + and listening for broadcast starts again + :return: Never! + """ + graceful_closer_no_reentry = False + + def graceful_closer(*args): + nonlocal graceful_closer_no_reentry + if graceful_closer_no_reentry: + print('DOUBLE SIGNAL CAUGHT: ALREADY EXITING') + return + graceful_closer_no_reentry = True + logging.get_logger('worker').info('SIGINT/SIGTERM caught') + nonlocal noloop + noloop = True + stop_event.set() + if worker is not None: + worker.stop() + + noasync_do_close = False + + def noasync_windows_graceful_closer_event(*args): + nonlocal noasync_do_close + noasync_do_close = True + + async def windows_graceful_closer(): + while not noasync_do_close: + await asyncio.sleep(1) + graceful_closer() + + worker = None + stop_event = asyncio.Event() + win_signal_waiting_task = None + try: + asyncio.get_event_loop().add_signal_handler(signal.SIGINT, graceful_closer) + asyncio.get_event_loop().add_signal_handler(signal.SIGTERM, graceful_closer) + except NotImplementedError: # solution for windows + signal.signal(signal.SIGINT, noasync_windows_graceful_closer_event) + signal.signal(signal.SIGBREAK, noasync_windows_graceful_closer_event) + win_signal_waiting_task = asyncio.create_task(windows_graceful_closer()) + + config = get_config('worker') + logger = logging.get_logger('worker') + if await config.get_option('worker.listen_to_broadcast', True): + stop_task = asyncio.create_task(stop_event.wait()) + while True: + logger.info('listening for scheduler broadcasts...') + broadcast_task = asyncio.create_task(await_broadcast('lifeblood_scheduler')) + done, _ = await asyncio.wait((broadcast_task, stop_task), return_when=asyncio.FIRST_COMPLETED) + if stop_task in done: + broadcast_task.cancel() + logger.info('broadcast listening cancelled') + break + assert broadcast_task.done() + message = await broadcast_task + scheduler_info = json.loads(message) + logger.debug(f'received {scheduler_info}') + addr = AddressChain(scheduler_info['message_address']) + try: + worker = Worker(addr, child_priority_adjustment=child_priority_adjustment, worker_type=worker_type, singleshot=singleshot, worker_id=worker_id, pool_address=pool_address) + await worker.start() # note that server is already started at this point + except Exception: + logger.exception('could not start the worker') + else: + await worker.wait_till_stops() + logger.info('worker quit') + if noloop: + break + else: + logger.info('broadcast listening disabled') + while True: + addr = AddressChain(await config.get_option('worker.scheduler_address', get_default_addr())) + logger.debug(f'using {addr}') + try: + worker = Worker(addr, child_priority_adjustment=child_priority_adjustment, worker_type=worker_type, singleshot=singleshot, worker_id=worker_id, pool_address=pool_address) + await worker.start() # note that server is already started at this point + except ConnectionRefusedError as e: + logger.exception(f'Connection error: {e}') + await 
asyncio.sleep(10) + continue + await worker.wait_till_stops() + logger.info('worker quit') + if noloop: + break + + if win_signal_waiting_task is not None: # this happens only on windows + if not win_signal_waiting_task.done(): + win_signal_waiting_task.cancel() + else: + asyncio.get_event_loop().remove_signal_handler(signal.SIGINT) # this seems to fix the bad signal fd error + asyncio.get_event_loop().remove_signal_handler(signal.SIGTERM) # my guess what happens is that loop closes, but signal handlers remain if not unset + + +def main(argv): + # import signal + # prev = None + # def signal_handler(sig, frame): + # print('!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! You pressed Ctrl+C !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!') + # prev(sig, frame) + # + # prev = signal.signal(signal.SIGINT, signal_handler) + import argparse + parser = argparse.ArgumentParser('lifeblood worker', description='executes invocations from scheduler') + parser.add_argument('--scheduler-address', help='manually specify scheduler to connect to. if not specified - by default worker will start listening to broadcasts from schedulers') + parser.add_argument('--no-listen-broadcast', action='store_true', help='do not listen to scheduler\'s broadcast, use config') + parser.add_argument('--no-loop', action='store_true', help='by default worker will return into the loop of waiting for scheduler every time it quits because of connection loss, or other errors. ' + 'but this flag will force worker to just completely quit instead') + parser.add_argument('--singleshot', action='store_true', help='worker will pick one job and exit after that job is completed or cancelled. ' + 'this is on by default when type=SCHEDULER_HELPER') + parser.add_argument('--type', choices=('STANDARD', 'SCHEDULER_HELPER'), default='STANDARD') + parser.add_argument('--id', help='integer identifier which worker should use when talking to worker pool') + parser.add_argument('--pool-address', help='if this worker is a part of a pool - pool address. currently pool can only be on the same host') + parser.add_argument('--priority', choices=tuple(x.name for x in ProcessPriorityAdjustment), default=ProcessPriorityAdjustment.NO_CHANGE.name, help='adjust child process priority') + parser.add_argument('--generate-config-only', action='store_true', help='just generate initial config and exit. 
Note that existing config will NOT be overriden') + + args = parser.parse_args(argv) + + # check and create default config if none + create_default_user_config_file('worker', default_config) + + if args.generate_config_only: + return + + if args.type == 'STANDARD': + wtype = WorkerType.STANDARD + elif args.type == 'SCHEDULER_HELPER': + wtype = WorkerType.SCHEDULER_HELPER + else: + raise NotImplementedError(f'worker type {args.type} is not yet implemented') + + priority_adjustment = [x for x in ProcessPriorityAdjustment if x.name == args.priority][0] # there MUST be exactly 1 match + + global_logger = logging.get_logger('worker') + + # check and create default config if none + create_default_user_config_file('worker', default_config) + + # check legality of the address + paddr = AddressChain(args.pool_address) + + config = get_config('worker') + if args.no_listen_broadcast: + config.set_override('worker.listen_to_broadcast', False) + if args.scheduler_address is not None: + config.set_override('worker.listen_to_broadcast', False) + saddr = AddressChain(args.scheduler_address) + config.set_override('worker.scheduler_address', str(saddr)) + try: + asyncio.run(main_async(wtype, child_priority_adjustment=priority_adjustment, singleshot=args.singleshot, worker_id=int(args.id) if args.id is not None else None, pool_address=paddr, noloop=args.no_loop)) + except KeyboardInterrupt: + # if u see errors in pycharm around this area when running from scheduler - + # it's because pycharm and most shells send SIGINTs to this child process on top of SIGINT that pool sends + # this stuff above tries to suppress that double SIGINTing, but it's not 100% solution + global_logger.warning('SIGINT caught where it wasn\'t supposed to be caught') + global_logger.info('SIGINT caught. Worker is stopped now.') + + +if __name__ == '__main__': + import sys + main(sys.argv) diff --git a/src/lifeblood/worker_pool.py b/src/lifeblood/main_workerpool.py similarity index 100% rename from src/lifeblood/worker_pool.py rename to src/lifeblood/main_workerpool.py diff --git a/src/lifeblood/net_classes.py b/src/lifeblood/net_classes.py index 342c207b..d54bbb9f 100644 --- a/src/lifeblood/net_classes.py +++ b/src/lifeblood/net_classes.py @@ -3,14 +3,10 @@ import pickle import copy import re -from .base import TypeMetadata from .misc import get_unique_machine_id from .logging import get_logger -from typing import Optional, TYPE_CHECKING, Tuple, Type, Union, Set -if TYPE_CHECKING: - from .basenode import BaseNode - from .plugin_info import PluginInfo +from typing import Optional, Union __logger = get_logger('worker_resources') @@ -52,55 +48,6 @@ def _try_parse_mem_spec(s: Union[str, int], default: Optional[int] = None): return int(bytes_count) -class NodeTypePluginMetadata: - def __init__(self, plugin_info: "PluginInfo"): - self.__package_name = plugin_info.package_name() - self.__category = plugin_info.category() - - @property - def package_name(self) -> Optional[str]: - return self.__package_name - - @property - def category(self) -> str: - return self.__category - - -class NodeTypeMetadata(TypeMetadata): - def __init__(self, node_type: Type["BaseNode"]): - from . 
import pluginloader # here cuz it should only be created from lifeblood, but can be used from viewer too - self.__type_name = node_type.type_name() - self.__plugin_info = NodeTypePluginMetadata(node_type.my_plugin()) - self.__label = node_type.label() - self.__tags = set(node_type.tags()) - self.__description = node_type.description() - self.__settings_names = tuple(pluginloader.nodes_settings.get(node_type.type_name(), {}).keys()) - - @property - def type_name(self) -> str: - return self.__type_name - - @property - def plugin_info(self) -> NodeTypePluginMetadata: - return self.__plugin_info - - @property - def label(self) -> Optional[str]: - return self.__label - - @property - def tags(self) -> Set[str]: - return self.__tags - - @property - def description(self) -> str: - return self.__description - - @property - def settings_names(self) -> Tuple[str, ...]: - return self.__settings_names - - class WorkerResources: __res_names = ('cpu_count', 'cpu_mem', 'gpu_count', 'gpu_mem') # name of all main resources __resource_epsilon = 1e-5 diff --git a/src/lifeblood/node_dataprovider_base.py b/src/lifeblood/node_dataprovider_base.py new file mode 100644 index 00000000..2e390e71 --- /dev/null +++ b/src/lifeblood/node_dataprovider_base.py @@ -0,0 +1,41 @@ +from pathlib import Path +from .basenode import BaseNode +from .snippets import NodeSnippetData + +from typing import Any, Callable, Dict, Optional, Set, Tuple, Type, Union + + +class NodeDataProvider: + def node_settings_names(self, type_name: str) -> Set[str]: + raise NotImplementedError() + + def node_settings(self, type_name: str, settings_name: str) -> dict: + raise NotImplementedError() + + def node_type_names(self) -> Set[str]: + raise NotImplementedError() + + def node_class(self, type_name) -> Type[BaseNode]: + raise NotImplementedError() + + def node_factory(self, node_type: str) -> Callable[[str], BaseNode]: + raise NotImplementedError() + + def has_node_factory(self, node_type: str) -> bool: + raise NotImplementedError() + + def node_preset_packages(self) -> Set[str]: + raise NotImplementedError() + + # node presets - + def node_preset_names(self, package_name: str) -> Set[str]: + raise NotImplementedError() + + def node_preset(self, package_name: str, preset_name: str) -> NodeSnippetData: + raise NotImplementedError() + + def add_settings_to_existing_package(self, package_name_or_path: Union[str, Path], node_type_name: str, settings_name: str, settings: Dict[str, Any]): + raise NotImplementedError() + + def set_settings_as_default(self, node_type_name: str, settings_name: Optional[str]): + raise NotImplementedError() diff --git a/src/lifeblood/node_type_metadata.py b/src/lifeblood/node_type_metadata.py new file mode 100644 index 00000000..9ea75943 --- /dev/null +++ b/src/lifeblood/node_type_metadata.py @@ -0,0 +1,54 @@ +from .base import TypeMetadata +from .node_dataprovider_base import NodeDataProvider +from .plugin_info import PluginInfo + +from typing import Optional, TYPE_CHECKING, Tuple, Set + + +class NodeTypePluginMetadata: + def __init__(self, plugin_info: PluginInfo): + self.__package_name = plugin_info.package_name() + self.__category = plugin_info.category() + + @property + def package_name(self) -> Optional[str]: + return self.__package_name + + @property + def category(self) -> str: + return self.__category + + +class NodeTypeMetadata(TypeMetadata): + def __init__(self, node_data_provider: NodeDataProvider, node_type_name: str): + self.__type_name = node_type_name + node_class = 
node_data_provider.node_class(node_type_name) + self.__plugin_info = NodeTypePluginMetadata(node_class.my_plugin()) + self.__label = node_class.label() + self.__tags = set(node_class.tags()) + self.__description = node_class.description() + self.__settings_names = tuple(node_data_provider.node_settings_names(node_type_name)) + + @property + def type_name(self) -> str: + return self.__type_name + + @property + def plugin_info(self) -> NodeTypePluginMetadata: + return self.__plugin_info + + @property + def label(self) -> Optional[str]: + return self.__label + + @property + def tags(self) -> Set[str]: + return self.__tags + + @property + def description(self) -> str: + return self.__description + + @property + def settings_names(self) -> Tuple[str, ...]: + return self.__settings_names diff --git a/src/lifeblood/nodegraph_holder_base.py b/src/lifeblood/nodegraph_holder_base.py new file mode 100644 index 00000000..2ab772ed --- /dev/null +++ b/src/lifeblood/nodegraph_holder_base.py @@ -0,0 +1,16 @@ +import asyncio +from typing import Optional + + +class NodeGraphHolderBase: + async def get_node_input_connections(self, node_id: int, input_name: Optional[str] = None): + raise NotImplementedError() + + async def get_node_output_connections(self, node_id: int, output_name: Optional[str] = None): + raise NotImplementedError() + + async def node_reports_changes_needs_saving(self, node_id): + raise NotImplementedError() + + def get_event_loop(self) -> asyncio.AbstractEventLoop: + raise NotImplementedError() diff --git a/src/lifeblood/plugin_info.py b/src/lifeblood/plugin_info.py index 3da37994..6eb9795c 100644 --- a/src/lifeblood/plugin_info.py +++ b/src/lifeblood/plugin_info.py @@ -1,6 +1,6 @@ from pathlib import Path -from typing import Optional +from typing import Optional, Union class PluginInfo: @@ -8,9 +8,10 @@ class PluginInfo: class for getting information about a plugin """ - def __init__(self, file_path, category, parent_package=None): + def __init__(self, file_path: Union[str, Path], plugin_hash: str, category: str, parent_package: Union[None, str, Path] = None): self.__file_path = Path(file_path) self.__category = category + self.__hash = plugin_hash self.__parent_package = Path(parent_package) if parent_package is not None else None self.__parent_package_data = None @@ -24,6 +25,9 @@ def __init__(self, file_path, category, parent_package=None): def category(self) -> str: return self.__category + def hash(self) -> str: + return self.__hash + def package_name(self) -> Optional[str]: if self.__parent_package is None: return None @@ -42,3 +46,6 @@ def node_definition_file_path(self) -> Path: def __str__(self): return f'Plugin from {self.node_definition_file_path()}, part of {self.package_name()}' + + +empty_plugin_info = PluginInfo('', '', 'invalid', None) \ No newline at end of file diff --git a/src/lifeblood/pluginloader.py b/src/lifeblood/pluginloader.py index c177e9ea..1bf4d248 100644 --- a/src/lifeblood/pluginloader.py +++ b/src/lifeblood/pluginloader.py @@ -7,285 +7,330 @@ import toml from pathlib import Path +from .basenode import BaseNode +from .node_dataprovider_base import NodeDataProvider from .snippets import NodeSnippetData from . 
import logging, plugin_info, paths -from typing import List, Tuple, Dict, Any, Union, Optional, TYPE_CHECKING - -if TYPE_CHECKING: - from .basenode import BaseNode - - -plugins = {} -presets: Dict[str, Dict[str, NodeSnippetData]] = {} -# map of node type -2-> -# preset_name -2-> -# dict of parameter name -2-> value -nodes_settings: Dict[str, Dict[str, Dict[str, Any]]] = {} -default_settings_config: Dict[str, str] = {} -__plugin_file_hashes = {} - -# package is identified by it's path, but it's handy to address them by short names -# short name is generated from dir name. packages can have same dir names, but then -# only first one will get into this locations dict -__package_locations: Dict[str, Path] = {} - -logger = logging.get_logger('plugin_loader') - - -def _install_node(filepath, plugin_category, parent_package=None): - """ - - :param filepath: - :param plugin_category: - :param parent_package: path to the base of the package, if this plugin is part of one, else - None - :return: - """ - filename = os.path.basename(filepath) - filebasename, fileext = os.path.splitext(filename) - - modpath = f'lifeblood.nodeplugins.{plugin_category}.{filebasename}' - mod_spec = importlib.util.spec_from_file_location(modpath, filepath) - try: - mod = importlib.util.module_from_spec(mod_spec) - mod_spec.loader.exec_module(mod) - mod._plugin_info = plugin_info.PluginInfo(filepath, plugin_category, parent_package) - except: - logger.exception(f'failed to load plugin "{filebasename}". skipping.') - return - for requred_attr in ('node_class',): - if not hasattr(mod, requred_attr): - logger.error(f'error loading plugin "{filebasename}". ' - f'required method {requred_attr} is missing.') - return - plugins[mod.node_class().type_name()] = mod - hasher = hashlib.md5() - with open(filepath, 'rb') as f: - hasher.update(f.read()) - __plugin_file_hashes[mod.node_class().type_name()] = hasher.hexdigest() - sys.modules[modpath] = mod - - -def _install_package(package_path, plugin_category): - """ - package structure: - [package_name:dir] - |_bin - | |_any <- this is always added to PATH - | |_system-arch1 <- these are added to PATH only if system+arch match - | |_system-arch2 <-/ - |_python - | |_X <- these are added to PYTHONPATH based on X.Y - | |_X.Y <-/ - |_nodes - | |_node1.py <- these are loaded as usual node plugins - | |_node2.py <-/ - |_data <- just a convenient place to store shit, can be accessed with data from plugin - |_settings <- for future saved nodes settings. 
not implemented yet - | |_node_type_name1 - | | |_settings1.lbs - | | |_settings2.lbs - | |_node_type_name2 - | | |_settings1.lbs - | | |_settings2.lbs - |_whatever_file1.lol - |_whatever_dir1 - |_whatever_file2.lol - - :param package_path: - :param plugin_category: - :return: - """ - package_name = os.path.basename(package_path) - global __package_locations - if package_name not in __package_locations: # read logic of this up - __package_locations[package_name] = Path(package_path) - # add extra bin paths - extra_bins = [] - for subbin in (f'{platform.system().lower()}-{platform.machine().lower()}', 'any'): - bin_base_path = os.path.join(package_path, 'bin', subbin) - if not os.path.exists(bin_base_path): - continue - extra_bins.append(bin_base_path) - if extra_bins: - os.environ['PATH'] = os.pathsep.join(extra_bins) + os.environ['PATH'] - - # install extra python modules - python_base_path = os.path.join(package_path, 'python') - if os.path.exists(python_base_path): - sysver = sys.version_info - pyvers = [tuple(int(y) for y in x.split('.')) for x in os.listdir(python_base_path) if x.isdigit() or re.match(r'^\d+\.\d+$', x)] - pyvers = [x for x in pyvers if x[0] == sysver.major - and (len(x) < 2 or x[1] == sysver.minor) - and (len(x) < 3 or x[2] == sysver.micro)] - pyvers = sorted(pyvers, key=lambda x: len(x), reverse=True) - for pyver in pyvers: - extra_python = os.path.join(python_base_path, '.'.join(str(x) for x in pyver)) - sys.path.append(extra_python) - os.environ['PYTHONPATH'] = os.pathsep.join((extra_python, os.environ['PYTHONPATH'])) if 'PYTHONPATH' in os.environ else extra_python - - # install nodes - nodes_path = os.path.join(package_path, 'nodes') - if os.path.exists(nodes_path): - for filename in os.listdir(nodes_path): - filebasename, fileext = os.path.splitext(filename) - if fileext != '.py': +from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union, Set + + +class PluginNodeDataProvider(NodeDataProvider): + __instance = None + + @classmethod + def instance(cls): + if cls.__instance is None: + cls.__instance = PluginNodeDataProvider() + return cls.__instance + + def __init__(self): + if self.__instance is not None: + raise RuntimeError("cannot have more than one PluginNodeDataProvider instance, as it manages global state") + + self.__plugins = {} + self.__presets: Dict[str, Dict[str, NodeSnippetData]] = {} + # map of node type -2-> + # preset_name -2-> + # dict of parameter name -2-> value + self.__nodes_settings: Dict[str, Dict[str, Dict[str, Any]]] = {} + self.__default_settings_config: Dict[str, str] = {} + self.__plugin_file_hashes = {} + + # package is identified by it's path, but it's handy to address them by short names + # short name is generated from dir name. 
packages can have same dir names, but then + # only first one will get into this locations dict + self.__package_locations: Dict[str, Path] = {} + + self.logger = logging.get_logger('plugin_loader') + + # now do initial scannings + + self.logger.info('loading core plugins') + self.__plugins = {} + plugin_paths: List[Tuple[str, str]] = [] # list of tuples of path to dir, plugin category + core_plugins_path = os.path.join(os.path.dirname(__file__), 'core_nodes') + stock_plugins_path = os.path.join(os.path.dirname(__file__), 'stock_nodes') + custom_plugins_path = paths.config_path('', 'custom_plugins') + plugin_paths.append((core_plugins_path, 'core')) + plugin_paths.append((stock_plugins_path, 'stock')) + (custom_plugins_path/'custom_default').mkdir(parents=True, exist_ok=True) + + plugin_paths.append((str(custom_plugins_path), 'user')) + + extra_paths = [] + for path in os.environ.get('LIFEBLOOD_PLUGIN_PATH', '').split(os.pathsep): + if path == '': continue - _install_node(os.path.join(nodes_path, filename), plugin_category, package_path) - - # install presets - presets_path = os.path.join(package_path, 'presets') - if os.path.exists(presets_path): - for filename in os.listdir(presets_path): - filebasename, fileext = os.path.splitext(filename) - if fileext != '.lbp': + if not os.path.isabs(path): + self.logger.warning(f'"{path}" is not absolute, skipping') continue - try: - with open(os.path.join(presets_path, filename), 'rb') as f: - snippet = NodeSnippetData.deserialize(f.read()) - snippet.add_tag('preset') - except Exception as e: - logger.error(f'failed to load snippet {filebasename}, error: {str(e)}') + if not os.path.exists(path): + self.logger.warning(f'"{path}" does not exist, skipping') continue - - if package_name not in presets: - presets[package_name] = {} - presets[package_name][snippet.label] = snippet - - # install node settings - settings_path = os.path.join(package_path, 'settings') - if os.path.exists(settings_path): - for nodetype_name in os.listdir(settings_path): - if nodetype_name not in nodes_settings: - nodes_settings[nodetype_name] = {} - nodetype_path = os.path.join(settings_path, nodetype_name) - for preset_filename in os.listdir(nodetype_path): - preset_name, fileext = os.path.splitext(preset_filename) - if fileext != '.lbs': + extra_paths.append(path) + self.logger.debug(f'using extra plugin path: "{path}"') + + plugin_paths.extend((x, 'extra') for x in extra_paths) + + for plugin_path, plugin_category in plugin_paths: + for filename in os.listdir(plugin_path): + filepath = os.path.join(plugin_path, filename) + if os.path.isdir(filepath): + self._install_package(filepath, plugin_category) + else: + filebasename, fileext = os.path.splitext(filename) + if fileext != '.py': + continue + self._install_node(filepath, plugin_category) + + self.logger.info('loaded node types:\n\t' + '\n\t'.join(self.__plugins.keys())) + self.logger.info('loaded node presets:\n\t' + '\n\t'.join(f'{pkg}::{label}' for pkg, pkgdata in self.__presets.items() for label in pkgdata.keys())) + + # load default settings + default_settings_config_path = paths.config_path('defaults.toml', 'scheduler.nodes') + if default_settings_config_path.exists(): + with open(default_settings_config_path) as f: + self.__default_settings_config = toml.load(f) + + bad_defaults = [] + for node_type, settings_name in self.__default_settings_config.items(): + if settings_name not in self.__nodes_settings.get(node_type, {}): + self.logger.warning(f'"{settings_name}" is set as default for "{node_type}", but no such 
settings is loaded') + bad_defaults.append(node_type) continue - try: - with open(os.path.join(nodetype_path, preset_filename), 'r') as f: - nodes_settings[nodetype_name][preset_name] = toml.load(f) - except Exception as e: - logger.error(f'failed to load settings {nodetype_name}/{preset_name}, error: {str(e)}') - - -def init(): - logger.info('loading core plugins') - global plugins - plugins = {} - plugin_paths: List[Tuple[str, str]] = [] # list of tuples of path to dir, plugin category - core_plugins_path = os.path.join(os.path.dirname(__file__), 'core_nodes') - stock_plugins_path = os.path.join(os.path.dirname(__file__), 'stock_nodes') - custom_plugins_path = paths.config_path('', 'custom_plugins') - plugin_paths.append((core_plugins_path, 'core')) - plugin_paths.append((stock_plugins_path, 'stock')) - (custom_plugins_path/'custom_default').mkdir(parents=True, exist_ok=True) - - plugin_paths.append((str(custom_plugins_path), 'user')) - - extra_paths = [] - for path in os.environ.get('LIFEBLOOD_PLUGIN_PATH', '').split(os.pathsep): - if path == '': - continue - if not os.path.isabs(path): - logger.warning(f'"{path}" is not absolute, skipping') - continue - if not os.path.exists(path): - logger.warning(f'"{path}" does not exist, skipping') - continue - extra_paths.append(path) - logger.debug(f'using extra plugin path: "{path}"') - - plugin_paths.extend((x, 'extra') for x in extra_paths) - - for plugin_path, plugin_category in plugin_paths: - for filename in os.listdir(plugin_path): - filepath = os.path.join(plugin_path, filename) - if os.path.isdir(filepath): - _install_package(filepath, plugin_category) - else: + + def _install_node(self, filepath, plugin_category, parent_package=None): + """ + + :param filepath: + :param plugin_category: + :param parent_package: path to the base of the package, if this plugin is part of one, else - None + :return: + """ + filename = os.path.basename(filepath) + filebasename, fileext = os.path.splitext(filename) + + # calc module hash + hasher = hashlib.md5() + with open(filepath, 'rb') as f: + hasher.update(f.read()) + plugin_hash = hasher.hexdigest() + + modpath = f'lifeblood.nodeplugins.{plugin_category}.{filebasename}' + mod_spec = importlib.util.spec_from_file_location(modpath, filepath) + try: + mod = importlib.util.module_from_spec(mod_spec) + mod_spec.loader.exec_module(mod) + pluginfo = plugin_info.PluginInfo(filepath, plugin_hash, plugin_category, parent_package) + mod._plugin_info = pluginfo + except: + self.logger.exception(f'failed to load plugin "{filebasename}". skipping.') + return + for requred_attr in ('node_class',): + if not hasattr(mod, requred_attr): + self.logger.error(f'error loading plugin "{filebasename}". ' + f'required method {requred_attr} is missing.') + return + node_class = mod.node_class() + node_class._plugin_data = pluginfo + self.__plugins[node_class.type_name()] = mod + self.__plugin_file_hashes[mod.node_class().type_name()] = plugin_hash + + # TODO: what if it's overriding existing module? 
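+            #  registering the dynamically loaded module under its synthetic dotted name makes it
+            #  importable later and gives pickle an entry point to resolve node classes defined in it
+            #  by module path; note that two plugin files sharing a basename within the same category
+            #  map to the same modpath, so a later registration silently replaces the earlier
+            #  sys.modules entry (that is the concern raised in the TODO above)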
+ sys.modules[modpath] = mod + + def _install_package(self, package_path, plugin_category): + """ + package structure: + [package_name:dir] + |_bin + | |_any <- this is always added to PATH + | |_system-arch1 <- these are added to PATH only if system+arch match + | |_system-arch2 <-/ + |_python + | |_X <- these are added to PYTHONPATH based on X.Y + | |_X.Y <-/ + |_nodes + | |_node1.py <- these are loaded as usual node plugins + | |_node2.py <-/ + |_data <- just a convenient place to store shit, can be accessed with data from plugin + |_settings <- for future saved nodes settings. not implemented yet + | |_node_type_name1 + | | |_settings1.lbs + | | |_settings2.lbs + | |_node_type_name2 + | | |_settings1.lbs + | | |_settings2.lbs + |_whatever_file1.lol + |_whatever_dir1 + |_whatever_file2.lol + + :param package_path: + :param plugin_category: + :return: + """ + package_name = os.path.basename(package_path) + if package_name not in self.__package_locations: # read logic of this up + self.__package_locations[package_name] = Path(package_path) + # add extra bin paths + extra_bins = [] + for subbin in (f'{platform.system().lower()}-{platform.machine().lower()}', 'any'): + bin_base_path = os.path.join(package_path, 'bin', subbin) + if not os.path.exists(bin_base_path): + continue + extra_bins.append(bin_base_path) + if extra_bins: + os.environ['PATH'] = os.pathsep.join(extra_bins) + os.environ['PATH'] + + # install extra python modules + python_base_path = os.path.join(package_path, 'python') + if os.path.exists(python_base_path): + sysver = sys.version_info + pyvers = [tuple(int(y) for y in x.split('.')) for x in os.listdir(python_base_path) if x.isdigit() or re.match(r'^\d+\.\d+$', x)] + pyvers = [x for x in pyvers if x[0] == sysver.major + and (len(x) < 2 or x[1] == sysver.minor) + and (len(x) < 3 or x[2] == sysver.micro)] + pyvers = sorted(pyvers, key=lambda x: len(x), reverse=True) + for pyver in pyvers: + extra_python = os.path.join(python_base_path, '.'.join(str(x) for x in pyver)) + sys.path.append(extra_python) + + # TODO: this is questionable, this will affect all child processes, we don't want that + os.environ['PYTHONPATH'] = os.pathsep.join((extra_python, os.environ['PYTHONPATH'])) if 'PYTHONPATH' in os.environ else extra_python + + # install nodes + nodes_path = os.path.join(package_path, 'nodes') + if os.path.exists(nodes_path): + for filename in os.listdir(nodes_path): filebasename, fileext = os.path.splitext(filename) if fileext != '.py': continue - _install_node(filepath, plugin_category) - - logger.info('loaded node types:\n\t' + '\n\t'.join(plugins.keys())) - logger.info('loaded node presets:\n\t' + '\n\t'.join(f'{pkg}::{label}' for pkg, pkgdata in presets.items() for label in pkgdata.keys())) - - # load default settings - default_settings_config_path = paths.config_path('defaults.toml', 'scheduler.nodes') - global default_settings_config - if default_settings_config_path.exists(): - with open(default_settings_config_path) as f: - default_settings_config = toml.load(f) - - bad_defaults = [] - for node_type, settings_name in default_settings_config.items(): - if settings_name not in nodes_settings.get(node_type, {}): - logger.warning(f'"{settings_name}" is set as default for "{node_type}", but no such settings is loaded') - bad_defaults.append(node_type) - continue + self._install_node(os.path.join(nodes_path, filename), plugin_category, package_path) + # install presets + presets_path = os.path.join(package_path, 'presets') + if os.path.exists(presets_path): + for filename 
in os.listdir(presets_path): + filebasename, fileext = os.path.splitext(filename) + if fileext != '.lbp': + continue + try: + with open(os.path.join(presets_path, filename), 'rb') as f: + snippet = NodeSnippetData.deserialize(f.read()) + snippet.add_tag('preset') + except Exception as e: + self.logger.error(f'failed to load snippet {filebasename}, error: {str(e)}') + continue -def plugin_hash(plugin_name) -> str: - return __plugin_file_hashes[plugin_name] - - -def add_settings_to_existing_package(package_name_or_path: Union[str, Path], node_type_name: str, settings_name: str, settings: Dict[str, Any]): - - if isinstance(package_name_or_path, str) and package_name_or_path in __package_locations: - package_name_or_path = __package_locations[package_name_or_path] - else: - package_name_or_path = Path(package_name_or_path) - if package_name_or_path not in __package_locations.values(): - raise RuntimeError('no package with that name or pathfound') - - # at this point package_name_or_path is path - assert(package_name_or_path.exists()) - base_path = package_name_or_path / 'settings' / node_type_name - if not base_path.exists(): - base_path.mkdir(parents=True, exist_ok=True) - with open(base_path / (settings_name + '.lbs'), 'w') as f: - toml.dump(settings, f) - - # add to settings - nodes_settings.setdefault(node_type_name, {})[settings_name] = settings - - -def set_settings_as_default(node_type_name: str, settings_name: Optional[str]): - """ - - :param node_type_name: - :param settings_name: if None - unset any defaults - :return: - """ - if node_type_name not in nodes_settings: - raise RuntimeError(f'node type "{nodes_settings}" is unknown') - if settings_name is not None and settings_name not in nodes_settings[node_type_name]: - raise RuntimeError(f'node type "{nodes_settings}" doesn\'t have settings "{settings_name}"') - if settings_name is None and node_type_name in default_settings_config: - del default_settings_config[node_type_name] - else: - default_settings_config[node_type_name] = settings_name - with open(paths.config_path('defaults.toml', 'scheduler.nodes'), 'w') as f: - toml.dump(default_settings_config, f) - - -def create_node(type_name: str, name, scheduler_parent, node_id): # type: (str, str, Scheduler, int) -> BaseNode - """ - this function is a global node creation point. - it has to be available somewhere global, so plugins loaded from dynamically created modules have an entry point for pickle - """ - if type_name not in plugins: - if type_name == 'basenode': # debug case! base class should never be created directly! - logger.warning('creating BASENODE. 
if it\'s not for debug/test purposes - it\'s bad!') - from .basenode import BaseNode - node = BaseNode(name) - raise RuntimeError('unknown plugin') - node = plugins[type_name].node_class()(name) - # now set defaults, before parent is set to prevent ui callbacks to parent - if type_name in default_settings_config: - node.apply_settings(default_settings_config[type_name]) - node._set_parent(scheduler_parent, node_id) - return node + if package_name not in self.__presets: + self.__presets[package_name] = {} + self.__presets[package_name][snippet.label] = snippet + + # install node settings + settings_path = os.path.join(package_path, 'settings') + if os.path.exists(settings_path): + for nodetype_name in os.listdir(settings_path): + if nodetype_name not in self.__nodes_settings: + self.__nodes_settings[nodetype_name] = {} + nodetype_path = os.path.join(settings_path, nodetype_name) + for preset_filename in os.listdir(nodetype_path): + preset_name, fileext = os.path.splitext(preset_filename) + if fileext != '.lbs': + continue + try: + with open(os.path.join(nodetype_path, preset_filename), 'r') as f: + self.__nodes_settings[nodetype_name][preset_name] = toml.load(f) + except Exception as e: + self.logger.error(f'failed to load settings {nodetype_name}/{preset_name}, error: {str(e)}') + + def plugin_hash(self, plugin_name) -> str: + return self.__plugin_file_hashes[plugin_name] + + def node_settings_names(self, type_name: str) -> Set[str]: + if type_name not in self.__nodes_settings: + return set() + return set(self.__nodes_settings[type_name].keys()) + + def node_settings(self, type_name: str, settings_name: str) -> dict: + return self.__nodes_settings[type_name][settings_name] + + def node_type_names(self) -> Set[str]: + return set(self.__plugins.keys()) + + def node_class(self, type_name) -> Type[BaseNode]: + return self.__plugins[type_name].node_class() + + def node_factory(self, node_type: str) -> Callable[[str], BaseNode]: + return self.node_class(node_type) + + def has_node_factory(self, node_type: str) -> bool: + return node_type in self.node_type_names() + + def node_preset_packages(self) -> Set[str]: + return set(self.__presets.keys()) + + # node presets - + def node_preset_names(self, package_name: str) -> Set[str]: + return set(self.__presets[package_name]) + + def node_preset(self, package_name: str, preset_name: str) -> NodeSnippetData: + return self.__presets[package_name][preset_name] + + def add_settings_to_existing_package(self, package_name_or_path: Union[str, Path], node_type_name: str, settings_name: str, settings: Dict[str, Any]): + + if isinstance(package_name_or_path, str) and package_name_or_path in self.__package_locations: + package_name_or_path = self.__package_locations[package_name_or_path] + else: + package_name_or_path = Path(package_name_or_path) + if package_name_or_path not in self.__package_locations.values(): + raise RuntimeError('no package with that name or pathfound') + + # at this point package_name_or_path is path + assert package_name_or_path.exists() + base_path = package_name_or_path / 'settings' / node_type_name + if not base_path.exists(): + base_path.mkdir(parents=True, exist_ok=True) + with open(base_path / (settings_name + '.lbs'), 'w') as f: + toml.dump(settings, f) + + # add to settings + self.__nodes_settings.setdefault(node_type_name, {})[settings_name] = settings + + def set_settings_as_default(self, node_type_name: str, settings_name: Optional[str]): + """ + + :param node_type_name: + :param settings_name: if None - unset any defaults + 
:return: + """ + if node_type_name not in self.__nodes_settings: + raise RuntimeError(f'node type "{self.__nodes_settings}" is unknown') + if settings_name is not None and settings_name not in self.__nodes_settings[node_type_name]: + raise RuntimeError(f'node type "{self.__nodes_settings}" doesn\'t have settings "{settings_name}"') + if settings_name is None and node_type_name in self.__default_settings_config: + self.__default_settings_config.pop(node_type_name) + else: + self.__default_settings_config[node_type_name] = settings_name + with open(paths.config_path('defaults.toml', 'scheduler.nodes'), 'w') as f: + toml.dump(self.__default_settings_config, f) + + # def apply_settings(self, node: BaseNode, settings_name: str) -> None: + # if settings_name not in self.node_settings_names(node.type_name()): + # raise RuntimeError(f'requested settings "{settings_name}" not found for type "{node.type_name()}"') + # settings = self.node_settings(node.type_name(), settings_name) + # node.apply_settings(settings) + + # def create_node(self, type_name: str, name: str) -> BaseNode: + # if type_name not in self.__plugins: + # if type_name == 'basenode': # debug case! base class should never be created directly! + # self.logger.warning('creating BASENODE. if it\'s not for debug/test purposes - it\'s bad!') + # from .basenode import BaseNode + # node = BaseNode(name) + # raise RuntimeError('unknown plugin') + # node: "BaseNode" = self.__plugins[type_name].node_class()(name) + # # now set defaults, before parent is set to prevent ui callbacks to parent + # if type_name in self.__default_settings_config: + # node.apply_settings(self.__default_settings_config[type_name]) + # node.set_data_provider(self) + # return node diff --git a/src/lifeblood/scheduler/__init__.py b/src/lifeblood/scheduler/__init__.py index 2604268d..f3ba972f 100644 --- a/src/lifeblood/scheduler/__init__.py +++ b/src/lifeblood/scheduler/__init__.py @@ -1,153 +1 @@ -import sys -import os -import re -from pathlib import Path -import asyncio -import signal -from ..defaults import scheduler_port as default_scheduler_port, ui_port as default_ui_port -from ..config import get_config, create_default_user_config_file, get_local_scratch_path -from .. import logging -from .. 
import paths -from ..text import escape - from .scheduler import Scheduler - -from typing import Optional - -__esc = '\\"' - -default_config = f''' -[core] -## you can uncomment stuff below to specify some static values -## -# server_ip = "192.168.0.2" -# server_port = {default_scheduler_port()} -# ui_ip = "192.168.0.2" -# ui_port = {default_ui_port()} - -## you can turn off scheduler broadcasting if you want to manually configure viewer and workers to connect -## to a specific address -# broadcast = false - -[scheduler] - -[scheduler.globals] -## entries from this section will be available to any node from config[key] -## -## if you use more than 1 machine - you must change this to a network location shared among all workers -## by default it's set to scheduler's machine local temp path, and will only work for 1 machine setup -global_scratch_location = "{escape(get_local_scratch_path(), __esc)}" - -[scheduler.database] -## you can specify default database path, -## but this can be overriden with command line argument --db-path -# path = "/path/to/database.db" - -## uncomment line below to store task logs outside of the database -## it works in a way that all NEW logs will be saved according to settings below -## existing logs will be kept where they are -## external logs will ALWAYS be looked for in location specified by store_logs_externally_location -## so if you have ANY logs saved externally - you must keep store_logs_externally_location defined in the config, -## or those logs will be inaccessible -## but you can safely move logs and change location in config accordingly, but be sure scheduler is not accessing them at that time -# store_logs_externally = true -# store_logs_externally_location = /path/to/dir/where/to/store/logs -''' - - -async def main_async(db_path=None, *, broadcast_interval: Optional[int] = None): - def graceful_closer(*args): - scheduler.stop() - - noasync_do_close = False - def noasync_windows_graceful_closer_event(*args): - nonlocal noasync_do_close - noasync_do_close = True - - async def windows_graceful_closer(): - while not noasync_do_close: - await asyncio.sleep(1) - graceful_closer() - - scheduler = Scheduler(db_path, - do_broadcasting=broadcast_interval > 0 if broadcast_interval is not None else None, - broadcast_interval=broadcast_interval) - win_signal_waiting_task = None - try: - asyncio.get_event_loop().add_signal_handler(signal.SIGINT, graceful_closer) - asyncio.get_event_loop().add_signal_handler(signal.SIGTERM, graceful_closer) - except NotImplementedError: # solution for windows - signal.signal(signal.SIGINT, noasync_windows_graceful_closer_event) - signal.signal(signal.SIGBREAK, noasync_windows_graceful_closer_event) - win_signal_waiting_task = asyncio.create_task(windows_graceful_closer()) - - await scheduler.start() - await scheduler.wait_till_stops() - if win_signal_waiting_task is not None: - if not win_signal_waiting_task.done(): - win_signal_waiting_task.cancel() - logging.get_logger('scheduler').info('SCHEDULER STOPPED') - - -def main(argv): - import argparse - import tempfile - - parser = argparse.ArgumentParser('lifeblood scheduler') - parser.add_argument('--db-path', help='path to sqlite database to use') - parser.add_argument('--ephemeral', action='store_true', help='start with an empty one time use database, that is placed into shared memory IF POSSIBLE') - parser.add_argument('--verbosity-pinger', help='set individual verbosity for worker pinger') - parser.add_argument('--broadcast-interval', type=int, help='help easily override 
broadcasting interval (in seconds). value 0 disables broadcasting') - opts = parser.parse_args(argv) - - # check and create default config if none - create_default_user_config_file('scheduler', default_config) - - config = get_config('scheduler') - if opts.db_path is not None: - db_path = opts.db_path - else: - db_path = config.get_option_noasync('scheduler.database.path', str(paths.default_main_database_location())) - - global_logger = logging.get_logger('scheduler') - - fd = None - if opts.ephemeral: - if opts.db_path is not None: - parser.error('only one of --db-path or --ephemeral must be provided, not both') - # 'file:memorydb?mode=memory&cache=shared' - # this does not work ^ cuz shared cache means that all runs on the *same connection* - # and when there is a transaction conflict on the same connection - we get instalocked (SQLITE_LOCKED) - # and there is no way to emulate normal DB in memory but with shared cache - - # look for shm (UNIX only) - shm_path = Path('/dev/shm') - lb_shm_path = None - if shm_path.exists(): - lb_shm_path = shm_path/f'u{os.getuid()}-lifeblood' - try: - lb_shm_path.mkdir(exist_ok=True) - except Exception as e: - global_logger.warning('/dev/shm is not accessible (permission issues?), creating ephemeral database in temp dir') - lb_shm_path = None - else: - global_logger.warning('/dev/shm is not supported by OS, creating ephemeral database in temp dir') - - fd, db_path = tempfile.mkstemp(dir=lb_shm_path, prefix='shedb-') - - if opts.verbosity_pinger: - logging.get_logger('scheduler.worker_pinger').setLevel(opts.verbosity_pinger) - try: - asyncio.run(main_async(db_path, broadcast_interval=opts.broadcast_interval)) - except KeyboardInterrupt: - global_logger.warning('SIGINT caught') - global_logger.info('SIGINT caught. Scheduler is stopped now.') - finally: - if opts.ephemeral: - assert fd is not None - os.close(fd) - os.unlink(db_path) - - -if __name__ == '__main__': - main(sys.argv[1:]) diff --git a/src/lifeblood/scheduler/data_access.py b/src/lifeblood/scheduler/data_access.py index f6841f2a..b2742b72 100644 --- a/src/lifeblood/scheduler/data_access.py +++ b/src/lifeblood/scheduler/data_access.py @@ -8,7 +8,7 @@ from ..shared_lazy_sqlite_connection import SharedLazyAiosqliteConnection from .. import aiosqlite_overlay -SCHEDULER_DB_FORMAT_VERSION = 1 +SCHEDULER_DB_FORMAT_VERSION = 2 class DataAccess: @@ -40,6 +40,14 @@ def __init__(self, db_path, db_connection_timeout): cur = con.execute('SELECT * FROM lifeblood_metadata') metadata = cur.fetchone() # there should be exactly one single row. cur.close() + elif metadata['version'] != SCHEDULER_DB_FORMAT_VERSION: + self.__database_schema_upgrade(con, metadata['version'], SCHEDULER_DB_FORMAT_VERSION) # returns true if commit needed, but we do update next line anyway + con.execute('UPDATE lifeblood_metadata SET "version" = ?', (SCHEDULER_DB_FORMAT_VERSION,)) + con.commit() + # reget metadata + cur = con.execute('SELECT * FROM lifeblood_metadata') + metadata = cur.fetchone() # there should be exactly one single row. 
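+                    # the re-read above keeps the in-memory metadata row consistent with the freshly
+                    # committed schema version; unique_db_id from this row is unpacked just below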
+ cur.close() self.__db_uid = struct.unpack('>Q', struct.pack('>q', metadata['unique_db_id']))[0] # reinterpret signed as unsigned @property @@ -63,3 +71,29 @@ async def write_back_cache(self): 'WHERE "id"=?', (cached_row['last_seen'], cached_row['last_checked'], cached_row['ping_state'], wid)) await con.commit() + + # + # db schema update logic + # + def __database_schema_upgrade(self, con: sqlite3.Connection, from_version: int, to_version: int) -> bool: + if from_version == to_version: + return False + if from_version < 1 or to_version > 2: + raise NotImplementedError(f"Don't know how to update db schema from v{from_version} to v{to_version}") + if to_version < from_version: + raise ValueError(f'to_version cannot be less than from_version ({to_version}<{from_version})') + if from_version - to_version > 1: + need_commit = False + for i in range(from_version, to_version): + need_commit = self.__database_schema_upgrade(con, from_version, from_version + 1) or need_commit + return need_commit + + # at this point we are sure that from_version +1 = to_version + assert from_version + 1 == to_version + self.__logger.warning(f'updating database schema from {from_version} to {to_version}') + + # actual logic + if to_version == 2: + # need to ensure new node_object_state field is present + con.execute('ALTER TABLE "nodes" ADD COLUMN "node_object_state" BLOB') + return True diff --git a/src/lifeblood/scheduler/scheduler.py b/src/lifeblood/scheduler/scheduler.py index 3b7f019d..9a7ddc65 100644 --- a/src/lifeblood/scheduler/scheduler.py +++ b/src/lifeblood/scheduler/scheduler.py @@ -13,6 +13,7 @@ from .. import logging from .. import paths +from ..nodegraph_holder_base import NodeGraphHolderBase #from ..worker_task_protocol import WorkerTaskClient from ..worker_messsage_processor import WorkerControlClient from ..scheduler_task_protocol import SchedulerTaskProtocol, SpawnStatus @@ -26,7 +27,8 @@ from ..taskspawn import TaskSpawn from ..basenode import BaseNode from ..exceptions import * -from .. 
import pluginloader +from ..node_dataprovider_base import NodeDataProvider +from ..basenode_serialization import NodeSerializerBase, FailedToDeserialize from ..enums import WorkerState, WorkerPingState, TaskState, InvocationState, WorkerType, \ SchedulerMode, TaskGroupArchivedState from ..config import get_config @@ -47,8 +49,12 @@ from typing import Optional, Any, Tuple, List, Iterable, Union, Dict -class Scheduler: - def __init__(self, db_file_path, *, do_broadcasting: Optional[bool] = None, broadcast_interval: Optional[int] = None, +class Scheduler(NodeGraphHolderBase): + def __init__(self, db_file_path, *, + node_data_provider: NodeDataProvider, + node_serializers: List[NodeSerializerBase], + do_broadcasting: Optional[bool] = None, + broadcast_interval: Optional[int] = None, helpers_minimal_idle_to_ensure=1, server_addr: Optional[Tuple[str, int, int]] = None, server_ui_addr: Optional[Tuple[str, int]] = None): @@ -61,11 +67,15 @@ def __init__(self, db_file_path, *, do_broadcasting: Optional[bool] = None, broa :param server_addr: :param server_ui_addr: """ + self.__node_data_provider: NodeDataProvider = node_data_provider + if len(node_serializers) < 1: + raise ValueError('at least one serializer must be provided!') + self.__node_serializers = list(node_serializers) self.__logger = logging.get_logger('scheduler') self.__logger.info('loading core plugins') - pluginloader.init() # TODO: move it outside of constructor self.__node_objects: Dict[int, BaseNode] = {} self.__node_objects_locks: Dict[int, RWLock] = {} + self.__node_objects_creation_locks: Dict[int, asyncio.Lock] = {} config = get_config('scheduler') # this lock will prevent tasks from being reported cancelled and done at the same exact time should that ever happen @@ -156,6 +166,9 @@ def __init__(self, db_file_path, *, do_broadcasting: Optional[bool] = None, broa def get_event_loop(self): return self.__event_loop + def node_data_provider(self) -> NodeDataProvider: + return self.__node_data_provider + def _scheduler_protocol_factory(self): return SchedulerTaskProtocol(self) @@ -235,29 +248,45 @@ async def _get_node_object_by_id(self, node_id: int) -> BaseNode: """ if node_id in self.__node_objects: return self.__node_objects[node_id] - async with self.data_access.data_connection() as con: - con.row_factory = aiosqlite.Row - async with con.execute('SELECT * FROM "nodes" WHERE "id" = ?', (node_id,)) as nodecur: - node_row = await nodecur.fetchone() - if node_row is None: - raise RuntimeError('node id is invalid') - - node_type = node_row['type'] - if node_type not in pluginloader.plugins: - raise RuntimeError('node type is unsupported') - - if node_row['node_object'] is not None: - self.__node_objects[node_id] = await BaseNode.deserialize_async(node_row['node_object'], self, node_id) + async with self.__get_node_creation_lock_by_id(node_id): + # in case by the time we got here and the node was already created + if node_id in self.__node_objects: return self.__node_objects[node_id] + # if no - need to create one after all + async with self.data_access.data_connection() as con: + con.row_factory = aiosqlite.Row + async with con.execute('SELECT * FROM "nodes" WHERE "id" = ?', (node_id,)) as nodecur: + node_row = await nodecur.fetchone() + if node_row is None: + raise RuntimeError('node id is invalid') + + node_type = node_row['type'] + if not self.__node_data_provider.has_node_factory(node_type): + raise RuntimeError('node type is unsupported') + + if node_row['node_object'] is not None: + for serializer in self.__node_serializers: + 
try: + node_object = await serializer.deserialize_async(self, node_id, self.__node_data_provider, node_row['node_object'], node_row['node_object_state']) + break + except FailedToDeserialize as e: + self.__logger.warning(f'deserialization method failed with {e} ({serializer})') + continue + else: + raise RuntimeError(f'node entry {node_id} has unknown serialization method') + self.__node_objects[node_id] = node_object + return self.__node_objects[node_id] - # newnode: BaseNode = pluginloader.plugins[node_type].create_node_object(node_row['name'], self) - newnode: BaseNode = pluginloader.create_node(node_type, node_row['name'], self, node_id) - self.__node_objects[node_id] = newnode - await con.execute('UPDATE "nodes" SET node_object = ? WHERE "id" = ?', - (await newnode.serialize_async(), node_id)) - await con.commit() + newnode = self.__node_data_provider.node_factory(node_type)(node_row['name']) + newnode.set_parent(self, node_id) - return newnode + self.__node_objects[node_id] = newnode + node_data, state_data = await self.__node_serializers[0].serialize_async(newnode) + await con.execute('UPDATE "nodes" SET node_object = ?, node_object_state = ? WHERE "id" = ?', + (node_data, state_data, node_id)) + await con.commit() + + return newnode def get_node_lock_by_id(self, node_id: int) -> RWLock: """ @@ -274,6 +303,14 @@ def get_node_lock_by_id(self, node_id: int) -> RWLock: self.__node_objects_locks[node_id] = RWLock(fast=True) # read about fast on github. the points is if we have awaits inside critical section - it's safe to use fast return self.__node_objects_locks[node_id] + def __get_node_creation_lock_by_id(self, node_id: int) -> asyncio.Lock: + """ + This lock is for node creation/deserialization sections ONLY + """ + if node_id not in self.__node_objects_creation_locks: + self.__node_objects_creation_locks[node_id] = asyncio.Lock() + return self.__node_objects_creation_locks[node_id] + async def get_task_attributes(self, task_id: int) -> Tuple[Dict[str, Any], Optional[EnvironmentResolverArguments]]: """ get tasks, atributes and it's enviroment resolver's attributes @@ -1183,12 +1220,12 @@ async def duplicate_nodes(self, node_ids: Iterable[int]) -> Dict[int, int]: """ old_to_new = {} for nid in node_ids: - node_obj = await self._get_node_object_by_id(nid) - node_type, node_name = await self.get_node_type_and_name_by_id(nid) - new_id = await self.add_node(node_type, f'{node_name} copy') - new_node_obj = await self._get_node_object_by_id(new_id) - node_obj.copy_ui_to(new_node_obj) - old_to_new[nid] = new_id + async with self.node_object_by_id_for_reading(nid) as node_obj: + node_type, node_name = await self.get_node_type_and_name_by_id(nid) + new_id = await self.add_node(node_type, f'{node_name} copy') + async with self.node_object_by_id_for_writing(new_id) as new_node_obj: + node_obj.copy_ui_to(new_node_obj) + old_to_new[nid] = new_id # now copy connections async with self.data_access.data_connection() as con: @@ -1208,7 +1245,6 @@ async def duplicate_nodes(self, node_ids: Iterable[int]) -> Dict[int, int]: # node reports it's interface was changed. not sure why it exists async def node_reports_changes_needs_saving(self, node_id): assert node_id in self.__node_objects, 'this may be caused by race condition with node deletion' - # TODO: introduce __node_objects lock? or otherwise secure access await self.save_node_to_database(node_id) # @@ -1225,13 +1261,15 @@ async def save_node_to_database(self, node_id): # why? this happens on ui_update, which can happen cuz of request from viewer. 
# while node processing happens in a different thread, so this CAN happen at the same time with this # AND THIS IS BAD! (potentially) if a node has changing internal state - this can save some inconsistent snapshot of node state! + # this works now only cuz scheduler_ui_protocol does the locking for param settings node_object = self.__node_objects[node_id] if node_object is None: self.__logger.error('node_object is None while') return + node_data, state_data = await self.__node_serializers[0].serialize_async(node_object) async with self.data_access.data_connection() as con: - await con.execute('UPDATE "nodes" SET node_object = ? WHERE "id" = ?', - (await node_object.serialize_async(), node_id)) + await con.execute('UPDATE "nodes" SET node_object = ?, node_object_state = ? WHERE "id" = ?', + (node_data, state_data, node_id)) await con.commit() # @@ -1308,8 +1346,9 @@ async def remove_node_connection(self, node_connection_id: int): # # add node async def add_node(self, node_type: str, node_name: str) -> int: - if node_type not in pluginloader.plugins: - raise RuntimeError('unknown node type') + if not self.__node_data_provider.has_node_factory(node_type): # preliminary check + raise RuntimeError(f'unknown node type: "{node_type}"') + async with self.data_access.data_connection() as con: con.row_factory = aiosqlite.Row async with con.execute('INSERT INTO "nodes" ("type", "name") VALUES (?,?)', @@ -1319,6 +1358,12 @@ async def add_node(self, node_type: str, node_name: str) -> int: self.ui_state_access.bump_graph_update_id() return ret + async def apply_node_settings(self, node_id: int, settings_name: str): + async with self.node_object_by_id_for_writing(node_id) as node_object: + settings = self.__node_data_provider.node_settings(node_object.type_name(), settings_name) + async with self.node_object_by_id_for_writing(node_id) as node: + await asyncio.get_event_loop().run_in_executor(None, node.apply_settings, settings) + async def remove_node(self, node_id: int) -> bool: try: async with self.data_access.data_connection() as con: diff --git a/src/lifeblood/scheduler/task_processor.py b/src/lifeblood/scheduler/task_processor.py index 72937138..d58a87ab 100644 --- a/src/lifeblood/scheduler/task_processor.py +++ b/src/lifeblood/scheduler/task_processor.py @@ -15,7 +15,6 @@ from ..environment_resolver import EnvironmentResolverArguments from ..nodethings import ProcessingResult from ..exceptions import * -from .. import pluginloader from .. import aiosqlite_overlay from ..config import get_config from ..ui_events import TaskData, TaskDelta @@ -179,9 +178,6 @@ async def _awaiter(self, processor_to_run, task_row, abort_state: TaskState, ski async with con.execute('SELECT attributes FROM tasks WHERE "id" = ?', (task_row['split_origin_task_id'],)) as attcur: attributes = await asyncio.get_event_loop().run_in_executor(None, json.loads, (await attcur.fetchone())['attributes']) attributes.update(process_result.split_attributes_to_set) - for k, v in process_result.split_attributes_to_set.items(): - if v is None: - del attributes[k] result_serialized = await asyncio.get_event_loop().run_in_executor(None, json.dumps, attributes) await con.execute('UPDATE tasks SET "attributes" = ? 
WHERE "id" = ?', (result_serialized, task_row['split_origin_task_id'])) @@ -473,7 +469,7 @@ def _gszofdr(obj): awaiters = [] set_to_stuff = [] for task_row in all_task_rows: - if task_row['node_type'] not in pluginloader.plugins: + if task_row['node_type'] not in self.scheduler.node_data_provider().node_type_names(): self.__logger.error(f'plugin to process "{task_row["node_type"]}" not found!') # await con.execute('UPDATE tasks SET "state" = ? WHERE "id" = ?', # (TaskState.ERROR.value, task_row['id'])) @@ -495,6 +491,7 @@ def _gszofdr(obj): set_to_stuff.append((TaskState.GENERATING.value, task_row['id'])) total_state_changes += 1 # NOTE: awaiters are NOT started here, just coroutines created + # TODO: catch node getting errors here and in other places in task_processor. awaiters.append(self._awaiter((await self.scheduler._get_node_object_by_id(task_row['node_id']))._process_task_wrapper, dict(task_row), abort_state=TaskState.WAITING, skip_state=TaskState.POST_WAITING)) if set_to_stuff: @@ -514,7 +511,7 @@ def _gszofdr(obj): awaiters = [] set_to_stuff = [] for task_row in all_task_rows: - if task_row['node_type'] not in pluginloader.plugins: + if task_row['node_type'] not in self.scheduler.node_data_provider().node_type_names(): self.__logger.error(f'plugin to process "{task_row["node_type"]}" not found!') # await con.execute('UPDATE tasks SET "state" = ? WHERE "id" = ?', # (TaskState.ERROR.value, task_row['id'])) diff --git a/src/lifeblood/scheduler_ui_protocol.py b/src/lifeblood/scheduler_ui_protocol.py index cd3ad4b2..424132d5 100644 --- a/src/lifeblood/scheduler_ui_protocol.py +++ b/src/lifeblood/scheduler_ui_protocol.py @@ -9,9 +9,8 @@ from .ui_events import TaskEvent from .enums import NodeParameterType, TaskState, SpawnStatus, TaskGroupArchivedState from .exceptions import NotSubscribedError -from . 
import pluginloader from .invocationjob import InvocationJob -from .net_classes import NodeTypeMetadata +from .node_type_metadata import NodeTypeMetadata from .taskspawn import NewTask from .snippets import NodeSnippetData, NodeSnippetDataPlaceholder from .environment_resolver import EnvironmentResolverArguments @@ -168,9 +167,9 @@ async def comm_get_task_invocation(): # elif command == b'gettaskinvoc': # node related commands async def comm_list_node_types(): # elif command == b'listnodetypes': typemetas = [] - for type_name, module in pluginloader.plugins.items(): - cls = module.node_class() - typemetas.append(NodeTypeMetadata(cls)) + data_provider = self.__scheduler.node_data_provider() + for type_name in data_provider.node_type_names(): + typemetas.append(NodeTypeMetadata(data_provider, type_name)) writer.write(struct.pack('>Q', len(typemetas))) for typemeta in typemetas: data: bytes = await asyncio.get_event_loop().run_in_executor(None, pickle.dumps, typemeta) @@ -178,7 +177,13 @@ async def comm_list_node_types(): # elif command == b'listnodetypes': writer.write(data) async def comm_list_presets(): - preset_metadata: Dict[str, Dict[str, NodeSnippetDataPlaceholder]] = {pack: {pres: NodeSnippetDataPlaceholder.from_nodesnippetdata(snip) for pres, snip in packdata.items()} for pack, packdata in pluginloader.presets.items()} + data_provider = self.__scheduler.node_data_provider() + preset_metadata: Dict[str, Dict[str, NodeSnippetDataPlaceholder]] = { + pack: { + pres: NodeSnippetDataPlaceholder.from_nodesnippetdata(data_provider.node_preset(pack, pres)) + for pres in data_provider.node_preset_names(pack) + } for pack in data_provider.node_preset_packages() + } data: bytes = await asyncio.get_event_loop().run_in_executor(None, pickle.dumps, preset_metadata) writer.write(struct.pack('>Q', len(data))) writer.write(data) @@ -186,12 +191,13 @@ async def comm_list_presets(): async def comm_get_node_preset(): package_name = await read_string() preset_name = await read_string() - if package_name not in pluginloader.presets or preset_name not in pluginloader.presets[package_name]: + data_provider = self.__scheduler.node_data_provider() + if package_name not in data_provider.node_preset_packages() or preset_name not in data_provider.node_preset_names(package_name): self.__logger.warning(f'requested preset {package_name}::{preset_name} is not found') writer.write(struct.pack('>?', False)) return writer.write(struct.pack('>?', True)) - data = pluginloader.presets[package_name][preset_name].serialize(ascii=False) + data = data_provider.node_preset(package_name, preset_name).serialize(ascii=False) writer.write(struct.pack('>Q', len(data))) writer.write(data) @@ -408,8 +414,7 @@ async def comm_apply_node_settings(): node_id = struct.unpack('>Q', await reader.readexactly(8))[0] settings_name = await read_string() try: - async with self.__scheduler.node_object_by_id_for_writing(node_id) as node: - await asyncio.get_event_loop().run_in_executor(None, node.apply_settings, settings_name) + await self.__scheduler.apply_node_settings(node_id, settings_name) except Exception: self.__logger.exception(f'FAILED to apply node settings for node {node_id}, settings name "{settings_name}"') writer.write(b'\0') @@ -422,8 +427,9 @@ async def comm_save_custom_node_settings(): datasize = struct.unpack('>Q', await reader.readexactly(8))[0] settings = await asyncio.get_event_loop().run_in_executor(None, pickle.loads, await reader.readexactly(datasize)) + data_provider = self.__scheduler.node_data_provider() try: - 
pluginloader.add_settings_to_existing_package('custom_default', node_type_name, settings_name, settings) + data_provider.add_settings_to_existing_package('custom_default', node_type_name, settings_name, settings) except RuntimeError as e: self.__logger.error(f'failed to add custom node settings: {str(e)}') writer.write(b'\0') @@ -437,8 +443,9 @@ async def comm_set_settings_default(): settings_name = await read_string() else: settings_name = None + data_provider = self.__scheduler.node_data_provider() try: - pluginloader.set_settings_as_default(node_type_name, settings_name) + data_provider.set_settings_as_default(node_type_name, settings_name) except RuntimeError: self.__logger.error(f'failed to set node default settings: {str(e)}') writer.write(b'\0') diff --git a/src/lifeblood/uidata.py b/src/lifeblood/uidata.py index cfc43c87..c5f11dae 100644 --- a/src/lifeblood/uidata.py +++ b/src/lifeblood/uidata.py @@ -1,4 +1,5 @@ import asyncio +from dataclasses import dataclass import pickle import os import pathlib @@ -7,6 +8,7 @@ from .enums import NodeParameterType from .processingcontext import ProcessingContext from .node_visualization_classes import NodeColorScheme +from .logging import get_logger import re from typing import TYPE_CHECKING, TypedDict, Dict, Any, List, Set, Optional, Tuple, Union, Iterable, FrozenSet, Type, Callable @@ -610,6 +612,9 @@ def __init__(self): self.__block_ui_callbacks = False def initializing_interface_lock(self): + return self.block_ui_callbacks() + + def block_ui_callbacks(self): class _iiLock: def __init__(self, lockable): self.__nui = lockable @@ -980,6 +985,12 @@ def _children_value_changed(self, children: Iterable["ParameterHierarchyItem"]): self.__my_nodeui._outputs_definition_changed() +@dataclass +class ParameterFullValue: + unexpanded_value: Union[int, float, str, bool] + expression: Optional[str] + + class NodeUiError(RuntimeError): pass @@ -991,6 +1002,7 @@ class NodeUiDefinitionError(RuntimeError): class NodeUi(ParameterHierarchyItem): def __init__(self, attached_node: "BaseNode"): super(NodeUi, self).__init__() + self.__logger = get_logger('scheduler.nodeUI') self.__parameter_layout = VerticalParametersLayout() self.__parameter_layout.set_parent(self) self.__attached_node: Optional[BaseNode] = attached_node @@ -1029,6 +1041,9 @@ def set_parent(self, item: Optional["ParameterHierarchyItem"]): raise RuntimeError('NodeUi class is supposed to be tree root') def initializing_interface_lock(self): + return self.block_ui_callbacks() + + def block_ui_callbacks(self): class _iiLock: def __init__(self, lockable): self.__nui = lockable @@ -1321,6 +1336,55 @@ def parameters(self) -> Iterable[Parameter]: def items(self, recursive=False) -> Iterable[ParameterHierarchyItem]: return self.__parameter_layout.items(recursive=recursive) + def set_parameters_batch(self, parameters: Dict[str, ParameterFullValue]): + """ + If signal blocking is needed - caller can do it + + for now it's implemented the stupid way + """ + names_to_set = list(parameters.keys()) + names_to_set.append(None) + something_set_this_iteration = False + parameters_were_postponed = False + for param_name in names_to_set: + if param_name is None: + if parameters_were_postponed: + if not something_set_this_iteration: + self.__logger.warning(f'failed to set all parameters!') + break + names_to_set.append(None) + something_set_this_iteration = False + continue + assert isinstance(param_name, str) + param = self.parameter(param_name) + if param is None: + parameters_were_postponed = True + continue + 
param_value = parameters[param_name] + try: + param.set_value(param_value.unexpanded_value) + except (ParameterReadonly, ParameterLocked): + # if value is already correct - just skip + if param.unexpanded_value() != param_value.unexpanded_value: + self.__logger.error(f'unable to set value for "{param_name}"') + # shall we just ignore the error? + except ParameterError as e: + self.__logger.error(f'failed to set value for "{param_name}" because {repr(e)}') + if param.can_have_expressions(): + try: + param.set_expression(param_value.expression) + except (ParameterReadonly, ParameterLocked): + # if value is already correct - just skip + if param.expression() != param_value.expression: + self.__logger.error(f'unable to set expression for "{param_name}"') + # shall we just ignore the error? + except ParameterError as e: + self.__logger.error(f'failed to set expression for "{param_name}" because {repr(e)}') + elif param_value.expression is not None: + self.__logger.error(f'parameter "{param_name}" cannot have expressions, yet expression is stored for it') + + something_set_this_iteration = True + def __deepcopy__(self, memo): cls = self.__class__ crap = cls.__new__(cls) diff --git a/src/lifeblood/worker.py b/src/lifeblood/worker.py index 81907b54..88236272 100644 --- a/src/lifeblood/worker.py +++ b/src/lifeblood/worker.py @@ -56,18 +56,6 @@ is_posix = not sys.platform.startswith('win') -async def create_worker(scheduler_address: AddressChain, *, - child_priority_adjustment: ProcessPriorityAdjustment = ProcessPriorityAdjustment.NO_CHANGE, - worker_type: WorkerType = WorkerType.STANDARD, - singleshot: bool = False, - worker_id: Optional[int] = None, - pool_address: Optional[Tuple[str, int]] = None): - worker = Worker(scheduler_address, child_priority_adjustment=child_priority_adjustment, worker_type=worker_type, singleshot=singleshot, worker_id=worker_id, pool_address=pool_address) - - await worker.start() # note that server is already started at this point - return worker - - class Worker: def __init__(self, scheduler_addr: AddressChain, *, child_priority_adjustment: ProcessPriorityAdjustment = ProcessPriorityAdjustment.NO_CHANGE, @@ -748,190 +736,3 @@ async def _reintroduce_ourself(): # and since error is most probably due to network - it will either resolve itself, or there is no point reintroducing if connection cannot be established anyway elif result is not None: self.__ping_missed = 0 - - -async def main_async(worker_type=WorkerType.STANDARD, - child_priority_adjustment: ProcessPriorityAdjustment = ProcessPriorityAdjustment.NO_CHANGE, - singleshot: bool = False, worker_id: Optional[int] = None, pool_address=None, noloop=False): - """ - listen to scheduler broadcast in a loop. - if received - create the worker and work - if worker cannot ping the scheduler a number of times - it stops - and listenting for broadcast starts again - :return: Never! 
- """ - graceful_closer_no_reentry = False - - def graceful_closer(*args): - nonlocal graceful_closer_no_reentry - if graceful_closer_no_reentry: - print('DOUBLE SIGNAL CAUGHT: ALREADY EXITING') - return - graceful_closer_no_reentry = True - logging.get_logger('worker').info('SIGINT/SIGTERM caught') - nonlocal noloop - noloop = True - stop_event.set() - if worker is not None: - worker.stop() - - noasync_do_close = False - - def noasync_windows_graceful_closer_event(*args): - nonlocal noasync_do_close - noasync_do_close = True - - async def windows_graceful_closer(): - while not noasync_do_close: - await asyncio.sleep(1) - graceful_closer() - - worker = None - stop_event = asyncio.Event() - win_signal_waiting_task = None - try: - asyncio.get_event_loop().add_signal_handler(signal.SIGINT, graceful_closer) - asyncio.get_event_loop().add_signal_handler(signal.SIGTERM, graceful_closer) - except NotImplementedError: # solution for windows - signal.signal(signal.SIGINT, noasync_windows_graceful_closer_event) - signal.signal(signal.SIGBREAK, noasync_windows_graceful_closer_event) - win_signal_waiting_task = asyncio.create_task(windows_graceful_closer()) - - config = get_config('worker') - logger = logging.get_logger('worker') - if await config.get_option('worker.listen_to_broadcast', True): - stop_task = asyncio.create_task(stop_event.wait()) - while True: - logger.info('listening for scheduler broadcasts...') - broadcast_task = asyncio.create_task(await_broadcast('lifeblood_scheduler')) - done, _ = await asyncio.wait((broadcast_task, stop_task), return_when=asyncio.FIRST_COMPLETED) - if stop_task in done: - broadcast_task.cancel() - logger.info('broadcast listening cancelled') - break - assert broadcast_task.done() - message = await broadcast_task - scheduler_info = json.loads(message) - logger.debug('received', scheduler_info) - addr = AddressChain(scheduler_info['message_address']) - try: - worker = await create_worker(addr, child_priority_adjustment=child_priority_adjustment, worker_type=worker_type, singleshot=singleshot, worker_id=worker_id, pool_address=pool_address) - except Exception: - logger.exception('could not start the worker') - else: - await worker.wait_till_stops() - logger.info('worker quited') - if noloop: - break - else: - logger.info('boradcast listening disabled') - while True: - addr = AddressChain(await config.get_option('worker.scheduler_address', get_default_addr())) - logger.debug(f'using {addr}') - try: - worker = await create_worker(addr, child_priority_adjustment=child_priority_adjustment, worker_type=worker_type, singleshot=singleshot, worker_id=worker_id, pool_address=pool_address) - except ConnectionRefusedError as e: - logger.exception('Connection error', str(e)) - await asyncio.sleep(10) - continue - await worker.wait_till_stops() - logger.info('worker quited') - if noloop: - break - - if win_signal_waiting_task is not None: # this happens only on windows - if not win_signal_waiting_task.done(): - win_signal_waiting_task.cancel() - else: - asyncio.get_event_loop().remove_signal_handler(signal.SIGINT) # this seem to fix the bad signal fd error - asyncio.get_event_loop().remove_signal_handler(signal.SIGTERM) # my guess what happens is that loop closes, but signal handlers remain if not unsed - - -default_config = ''' -[worker] -listen_to_broadcast = true - -[default_env_wrapper] -## here you can uncomment lines below to specify your own default environment wrapper and default arguments -## this will only be used by invocation jobs that have NO environment wrappers 
specified -# name = TrivialEnvironmentResolver -# arguments = [ "project_name", "or", "config_name", "idunno", "maybe rez packages requirements?", [1,4,11] ] - -[resources] -## here you can override resources that this machine has -## if you don't specify anything - resources will be detected automatically -## NOTE: automatic detection DOES NOT WORK FOR GPU yet, you have to specify it manually -# cpu_count = 32 # by default treated as the number of cores -# cpu_mem = "128G" # you can either specify int amount of bytes, or use string ending with one of "K" "M" "G" "T" "P" meaning Kilo, Mega, Giga, ... -# gpu_count = 1 # by default treated as the number devices -# gpu_mem = "8G" # you can either specify int amount of bytes, or use string ending with one of "K" "M" "G" "T" "P" meaning Kilo, Mega, Giga, ... -''' - - -def main(argv): - # import signal - # prev = None - # def signal_handler(sig, frame): - # print('!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! You pressed Ctrl+C !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!') - # prev(sig, frame) - # - # prev = signal.signal(signal.SIGINT, signal_handler) - import argparse - parser = argparse.ArgumentParser('lifeblood worker', description='executes invocations from scheduler') - parser.add_argument('--scheduler-address', help='manually specify scheduler to connect to. if not specified - by default worker will start listening to broadcasts from schedulers') - parser.add_argument('--no-listen-broadcast', action='store_true', help='do not listen to scheduler\'s broadcast, use config') - parser.add_argument('--no-loop', action='store_true', help='by default worker will return into the loop of waiting for scheduler every time it quits because of connection loss, or other errors. ' - 'but this flag will force worker to just completely quit instead') - parser.add_argument('--singleshot', action='store_true', help='worker will pick one job and exit after that job is completed or cancelled. ' - 'this is on by default when type=SCHEDULER_HELPER') - parser.add_argument('--type', choices=('STANDARD', 'SCHEDULER_HELPER'), default='STANDARD') - parser.add_argument('--id', help='integer identifier which worker should use when talking to worker pool') - parser.add_argument('--pool-address', help='if this worker is a part of a pool - pool address. currently pool can only be on the same host') - parser.add_argument('--priority', choices=tuple(x.name for x in ProcessPriorityAdjustment), default=ProcessPriorityAdjustment.NO_CHANGE.name, help='adjust child process priority') - parser.add_argument('--generate-config-only', action='store_true', help='just generate initial config and exit. 
Note that existing config will NOT be overriden') - - args = parser.parse_args(argv) - - # check and create default config if none - create_default_user_config_file('worker', default_config) - - if args.generate_config_only: - return - - if args.type == 'STANDARD': - wtype = WorkerType.STANDARD - elif args.type == 'SCHEDULER_HELPER': - wtype = WorkerType.SCHEDULER_HELPER - else: - raise NotImplementedError(f'worker type {args.type} is not yet implemented') - - priority_adjustment = [x for x in ProcessPriorityAdjustment if x.name == args.priority][0] # there MUST be exactly 1 match - - global_logger = logging.get_logger('worker') - - # check and create default config if none - create_default_user_config_file('worker', default_config) - - # check legality of the address - paddr = AddressChain(args.pool_address) - - config = get_config('worker') - if args.no_listen_broadcast: - config.set_override('worker.listen_to_broadcast', False) - if args.scheduler_address is not None: - config.set_override('worker.listen_to_broadcast', False) - saddr = AddressChain(args.scheduler_address) - config.set_override('worker.scheduler_address', str(saddr)) - try: - asyncio.run(main_async(wtype, child_priority_adjustment=priority_adjustment, singleshot=args.singleshot, worker_id=int(args.id) if args.id is not None else None, pool_address=paddr, noloop=args.no_loop)) - except KeyboardInterrupt: - # if u see errors in pycharm around this area when running from scheduler - - # it's because pycharm and most shells send SIGINTs to this child process on top of SIGINT that pool sends - # this stuff above tries to suppress that double SIGINTing, but it's not 100% solution - global_logger.warning('SIGINT caught where it wasn\'t supposed to be caught') - global_logger.info('SIGINT caught. 
Worker is stopped now.') - - -if __name__ == '__main__': - import sys - main(sys.argv) diff --git a/src/lifeblood_viewer/connection_worker.py b/src/lifeblood_viewer/connection_worker.py index c8b5d334..e2e3e645 100644 --- a/src/lifeblood_viewer/connection_worker.py +++ b/src/lifeblood_viewer/connection_worker.py @@ -15,7 +15,7 @@ from lifeblood.broadcasting import await_broadcast from lifeblood.config import get_config from lifeblood.uidata import Parameter -from lifeblood.net_classes import NodeTypeMetadata +from lifeblood.node_type_metadata import NodeTypeMetadata from lifeblood.taskspawn import NewTask from lifeblood.snippets import NodeSnippetData, NodeSnippetDataPlaceholder from lifeblood.defaults import ui_port diff --git a/src/lifeblood_viewer/graphics_scene.py b/src/lifeblood_viewer/graphics_scene.py index 3902e29e..4c459366 100644 --- a/src/lifeblood_viewer/graphics_scene.py +++ b/src/lifeblood_viewer/graphics_scene.py @@ -17,7 +17,7 @@ from lifeblood.ui_protocol_data import UiData, TaskGroupBatchData, TaskBatchData, NodeGraphStructureData, TaskDelta, DataNotSet, IncompleteInvocationLogData, InvocationLogData from lifeblood.enums import TaskState, NodeParameterType, TaskGroupArchivedState from lifeblood import logging -from lifeblood.net_classes import NodeTypeMetadata +from lifeblood.node_type_metadata import NodeTypeMetadata from lifeblood.taskspawn import NewTask from lifeblood.invocationjob import InvocationJob from lifeblood.snippets import NodeSnippetData, NodeSnippetDataPlaceholder diff --git a/src/lifeblood_viewer/nodeeditor.py b/src/lifeblood_viewer/nodeeditor.py index 5e728040..91f70550 100644 --- a/src/lifeblood_viewer/nodeeditor.py +++ b/src/lifeblood_viewer/nodeeditor.py @@ -20,7 +20,7 @@ from lifeblood.config import get_config from lifeblood import logging from lifeblood import paths -from lifeblood.net_classes import NodeTypeMetadata +from lifeblood.node_type_metadata import NodeTypeMetadata from lifeblood.taskspawn import NewTask from lifeblood.invocationjob import InvocationJob from lifeblood.snippets import NodeSnippetData, NodeSnippetDataPlaceholder diff --git a/tests/nodes/common.py b/tests/nodes/common.py index 9bc816b9..29f983b4 100644 --- a/tests/nodes/common.py +++ b/tests/nodes/common.py @@ -12,10 +12,11 @@ from lifeblood.basenode import BaseNode, ProcessingResult from lifeblood.exceptions import NodeNotReadyToProcess from lifeblood.scheduler import Scheduler +from lifeblood.main_scheduler import create_default_scheduler from lifeblood.worker import Worker from lifeblood.invocationjob import InvocationJob, Environment from lifeblood.scheduler.pinger import Pinger -from lifeblood.pluginloader import create_node +from lifeblood.pluginloader import PluginNodeDataProvider from lifeblood.processingcontext import ProcessingContext from lifeblood.process_utils import oh_no_its_windows from lifeblood.environment_resolver import EnvironmentResolverArguments @@ -23,6 +24,15 @@ from typing import Any, Callable, Dict, List, Optional, Set, Union +plugin_data_provider = PluginNodeDataProvider() + + +def create_node(node_type: str, node_name: str, scheduler, node_id): + node = plugin_data_provider.node_factory(node_type)(node_name) + node.set_parent(scheduler, node_id) + return node + + class FakeEnvArgs(EnvironmentResolverArguments): def __init__(self, rel_path_to_bin: str): super().__init__() @@ -203,7 +213,7 @@ async def _helper_test_worker_node(self, ppatch.return_value = mock.AsyncMock(Pinger) wppatch.return_value = mock.AsyncMock() - sched = 
Scheduler('test_swc.db', do_broadcasting=False, helpers_minimal_idle_to_ensure=0) + sched = create_default_scheduler('test_swc.db', do_broadcasting=False, helpers_minimal_idle_to_ensure=0) await sched.start() workers = [] diff --git a/tests/nodes/test_ffmpeg.py b/tests/nodes/test_ffmpeg.py index 9af71f9d..ce2a22dc 100644 --- a/tests/nodes/test_ffmpeg.py +++ b/tests/nodes/test_ffmpeg.py @@ -8,12 +8,11 @@ from lifeblood.invocationjob import InvocationJob, Environment from lifeblood.enums import SpawnStatus -from lifeblood.pluginloader import create_node from lifeblood.processingcontext import ProcessingContext from lifeblood.environment_resolver import EnvironmentResolverArguments from lifeblood.process_utils import oh_no_its_windows -from .common import TestCaseBase +from .common import TestCaseBase, create_node class FakeEnvArgs(EnvironmentResolverArguments): diff --git a/tests/nodes/test_husk.py b/tests/nodes/test_husk.py index 4239d2fd..f1967104 100644 --- a/tests/nodes/test_husk.py +++ b/tests/nodes/test_husk.py @@ -7,12 +7,11 @@ from lifeblood.invocationjob import InvocationJob, Environment from lifeblood.enums import SpawnStatus -from lifeblood.pluginloader import create_node from lifeblood.processingcontext import ProcessingContext from lifeblood.environment_resolver import EnvironmentResolverArguments from lifeblood.process_utils import oh_no_its_windows -from .common import TestCaseBase +from .common import TestCaseBase, create_node class FakeEnvArgs(EnvironmentResolverArguments): diff --git a/tests/nodes/test_mantra.py b/tests/nodes/test_mantra.py index afd943b9..495c8648 100644 --- a/tests/nodes/test_mantra.py +++ b/tests/nodes/test_mantra.py @@ -7,12 +7,11 @@ from lifeblood.invocationjob import InvocationJob, Environment from lifeblood.enums import SpawnStatus -from lifeblood.pluginloader import create_node from lifeblood.processingcontext import ProcessingContext from lifeblood.environment_resolver import EnvironmentResolverArguments from lifeblood.process_utils import oh_no_its_windows -from .common import TestCaseBase +from .common import TestCaseBase, create_node # TODO: tests are currently very shallow !! 
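The node tests above now build nodes through the create_node helper added to tests/nodes/common.py instead of the removed lifeblood.pluginloader.create_node: a PluginNodeDataProvider resolves the node type to a factory, and set_parent() attaches the resulting node to its graph holder. A minimal usage sketch under those assumptions; the absolute import path, the 'ffmpeg' type name and the node id 1 are placeholders for illustration only, not part of the change:

# sketch only: construct a node the way the updated node tests do
from tests.nodes.common import create_node  # helper added in the diff above

def build_example_node(sched):
    # node_factory(node_type) returns a callable taking the node name;
    # set_parent(scheduler, node_id) then attaches the node to its graph holder
    return create_node('ffmpeg', 'example node', sched, 1)  # type name and id are placeholders
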
diff --git a/tests/nodes/test_parent_children_waiter.py b/tests/nodes/test_parent_children_waiter.py index 3351fc23..424c4ee7 100644 --- a/tests/nodes/test_parent_children_waiter.py +++ b/tests/nodes/test_parent_children_waiter.py @@ -1,4 +1,5 @@ from asyncio import Event +import json import random from lifeblood.scheduler import Scheduler from lifeblood.worker import Worker @@ -45,6 +46,67 @@ async def test_recursive_when_unreached_child_becomes_dead(self): await self._helper_test_recursive(rng.randint(1, 10), 0) await self._helper_test_recursive(rng.randint(1, 10), 100) + async def test_serialization(self): + async def _logic(sched: Scheduler, workers: List[Worker], done_waiter: Event, context: PseudoContext): + rng = random.Random(1666661) + task_parent = context.create_pseudo_task_with_attrs({}, 333, None) + task_child0 = context.create_pseudo_task_with_attrs({'foo': 2}, 456, 333) + task_child1 = context.create_pseudo_task_with_attrs({'foo': 5}, 123, 333) + task_child2 = context.create_pseudo_task_with_attrs({'foo': 9}, 666, 333) + + task_child0.set_input_name('children') + task_child1.set_input_name('children') + task_child2.set_input_name('children') + + node: BaseNode = context.create_node('parent_children_waiter', 'footest') + + def _prune_dicts(dic): + return { + k: v for k, v in dic.items() if k not in ( + '_ParentChildrenWaiterNode__main_lock', # lock is not part of the state, locks should be unique + '_parameters', # parameters to be compared separately, by a different, generic test set, they are not part of state anyway + '_BaseNode__parent_nid' # node ids will be different, state does not cover it + ) + } + + def _do_test(): + state = node.get_state() + node_test: BaseNode = context.create_node('parent_children_waiter', 'footest') + node_test.set_state(state) + self.assertDictEqual(_prune_dicts(node.__dict__), _prune_dicts(node_test.__dict__)) + + # we expect state to be json-serializeable + state = json.loads(json.dumps(state)) + node_test: BaseNode = context.create_node('parent_children_waiter', 'footest') + node_test.set_state(state) + self.assertDictEqual(_prune_dicts(node.__dict__), _prune_dicts(node_test.__dict__)) + + all_tasks = (task_child0, task_child1, task_child2, task_parent) + for _ in range(1): + tasks = [task_child0, task_child1, task_child2] + rng.shuffle(tasks) + for _ in range(rng.randint(0, 5)): + tasks.append(rng.choice(tasks)) + + # note that we need to simulate a valid situation if we want a correct inner state behaviour + _do_test() + for task in tasks: + context.process_task(node, task) + _do_test() + context.process_task(node, task_parent) + _do_test() + + for task in tasks[:3]: + context.process_task(node, task) + _do_test() + _do_test() + # internal state should be empty after + self.assertDictEqual({'cache_children': {}}, node.get_state()) + + await self._helper_test_node_with_arg_update( + _logic + ) + async def _helper_test_regular(self, dying_children_count, random_tasks_count): async def _logic(sched: Scheduler, workers: List[Worker], done_waiter: Event, context: PseudoContext): rng = random.Random(16666661) diff --git a/tests/nodes/test_redshift.py b/tests/nodes/test_redshift.py index 65867fa4..20dc38ed 100644 --- a/tests/nodes/test_redshift.py +++ b/tests/nodes/test_redshift.py @@ -6,12 +6,11 @@ from lifeblood.invocationjob import InvocationJob, Environment from lifeblood.enums import SpawnStatus -from lifeblood.pluginloader import create_node from lifeblood.processingcontext import ProcessingContext from lifeblood.environment_resolver 
import EnvironmentResolverArguments from lifeblood.process_utils import oh_no_its_windows -from .common import TestCaseBase +from .common import TestCaseBase, create_node class FakeEnvArgs(EnvironmentResolverArguments): diff --git a/tests/nodes/test_split_waiter.py b/tests/nodes/test_split_waiter.py index 7e5ec194..1bb2b49d 100644 --- a/tests/nodes/test_split_waiter.py +++ b/tests/nodes/test_split_waiter.py @@ -1,4 +1,5 @@ import random +import json from asyncio import Event from lifeblood.scheduler import Scheduler from lifeblood.worker import Worker @@ -22,6 +23,58 @@ async def test_basic_functions_nowait(self): async def test_basic_functions_nowait_with_garbage(self): await self._helper_test_basic_functions(20, False) + async def test_serialization(self): + async def _logic(sched: Scheduler, workers: List[Worker], done_waiter: Event, context: PseudoContext): + rng = random.Random(1666661) + main_task = context.create_pseudo_task_with_attrs({'boo': 32}) + tsplits = context.create_pseudo_split_of(main_task, [ + {'foo': 5}, + {'foo': 8}, + {'foo': 1}, + {'foo': 3}, + ]) + + node: BaseNode = context.create_node('split_waiter', 'footest') + + def _prune_dicts(dic): + return { + k: v for k, v in dic.items() if k not in ( + '_SplitAwaiterNode__main_lock', # lock is not part of the state, locks should be unique + '_parameters', # parameters to be compared separately, by a different, generic test set, they are not part of state anyway + '_BaseNode__parent_nid' # node ids will be different, state does not cover it + ) + } + + def _do_test(): + state = node.get_state() + node_test: BaseNode = context.create_node('split_waiter', 'footest') + node_test.set_state(state) + self.assertDictEqual(_prune_dicts(node.__dict__), _prune_dicts(node_test.__dict__)) + + # we expect state to be json-serializeable + state = json.loads(json.dumps(state)) + node_test: BaseNode = context.create_node('split_waiter', 'footest') + node_test.set_state(state) + self.assertDictEqual(_prune_dicts(node.__dict__), _prune_dicts(node_test.__dict__)) + + all_tasks = tsplits + [main_task] + for _ in range(100): + tasks = list(tsplits) + rng.shuffle(tasks) + for _ in range(rng.randint(0, 5)): + tasks.append(rng.choice(all_tasks)) + + _do_test() + for task in tasks: + context.process_task(node, task) + _do_test() + context.process_task(node, main_task) + _do_test() + + await self._helper_test_node_with_arg_update( + _logic + ) + async def _helper_test_basic_functions_wait(self, garbage_split_count): async def _logic(sched: Scheduler, workers: List[Worker], done_waiter: Event, context: PseudoContext): rng = random.Random(716394) diff --git a/tests/nodes/test_wait_for_task.py b/tests/nodes/test_wait_for_task.py index a2d16b0a..f49d4308 100644 --- a/tests/nodes/test_wait_for_task.py +++ b/tests/nodes/test_wait_for_task.py @@ -1,4 +1,6 @@ from asyncio import Event +import json +import random from lifeblood.scheduler import Scheduler from lifeblood.worker import Worker from lifeblood.basenode import BaseNode @@ -242,6 +244,76 @@ async def test_reschedule_with_from_empty_cond_with_exp(self): """ await self._helper_test_reschedule_with_from_empty_cond(1) + async def test_serialization(self): + async def _logic(sched: Scheduler, workers: List[Worker], done_waiter: Event, context: PseudoContext): + rng = random.Random(1666661) + + task0_attrs = { + 'cond': '1', + 'exp': '2 3 4' + } + task1_attrs = { + 'cond': '2', + 'exp': '1 3' + } + task2_attrs = { + 'cond': '3', + 'exp': '1 2' + } + task3_attrs = { + 'cond': '4', + 'exp': '1 2 3' + } + 
task0 = context.create_pseudo_task_with_attrs(task0_attrs, 234) + task1 = context.create_pseudo_task_with_attrs(task1_attrs, 235) + task2 = context.create_pseudo_task_with_attrs(task2_attrs, 236) + task3 = context.create_pseudo_task_with_attrs(task3_attrs, 237) + + def _prune_dicts(dic): + return { + k: v for k, v in dic.items() if k not in ( + '_WaitForTaskValue__main_lock', # lock is not part of the state, locks should be unique + '_parameters', # parameters to be compared separately, by a different, generic test set, they are not part of state anyway + '_BaseNode__parent_nid' # node ids will be different, state does not cover it + ) + } + + def _do_test(): + state = node.get_state() + node_test: BaseNode = context.create_node('wait_for_task_value', 'footest') + node_test.set_param_value('condition value', '`task["cond"]`') + node_test.set_param_value('expected values', '`task["exp"]`') + node_test.set_state(state) + self.assertDictEqual(_prune_dicts(node.__dict__), _prune_dicts(node_test.__dict__)) + + # we expect state to be json-serializeable + state = json.loads(json.dumps(state)) + node_test: BaseNode = context.create_node('wait_for_task_value', 'footest') + node_test.set_param_value('condition value', '`task["cond"]`') + node_test.set_param_value('expected values', '`task["exp"]`') + node_test.set_state(state) + self.assertDictEqual(_prune_dicts(node.__dict__), _prune_dicts(node_test.__dict__)) + + all_tasks = [task0, task1, task2, task3] + for _ in range(100): + node: BaseNode = context.create_node('wait_for_task_value', 'footest') + node.set_param_value('condition value', '`task["cond"]`') + node.set_param_value('expected values', '`task["exp"]`') + + tasks = list(all_tasks) + rng.shuffle(tasks) + for _ in range(rng.randint(0, 5)): + tasks.append(rng.choice(all_tasks)) + + _do_test() + for task in tasks: + context.process_task(node, task) + _do_test() + + await self._helper_test_node_with_arg_update( + _logic + ) + async def _helper_test_reschedule_with_from_empty_cond(self, var: int): async def _logic(sched: Scheduler, workers: List[Worker], done_waiter: Event, context: PseudoContext): node: BaseNode = context.create_node('wait_for_task_value', 'footest') diff --git a/tests/test_config.py b/tests/test_config.py index 6e17f35c..8fa52e10 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -96,7 +96,7 @@ def test_condigd(self): class DefaultComponentConfigTest(unittest.TestCase): def test_default_schediler(self): - from lifeblood.scheduler import default_config + from lifeblood.main_scheduler import default_config data = toml.loads(default_config) self.assertIn('scheduler', data) self.assertIn('globals', data['scheduler']) @@ -106,7 +106,7 @@ def test_default_worker(self): """ simply checking that config is a valid toml """ - from lifeblood.worker import default_config + from lifeblood.main_worker import default_config data = toml.loads(default_config) def test_default_viewer(self): diff --git a/tests/test_scheduler.py b/tests/test_scheduler.py index dea9c8d6..3e2dccc1 100644 --- a/tests/test_scheduler.py +++ b/tests/test_scheduler.py @@ -12,6 +12,7 @@ from lifeblood.net_messages.impl.tcp_simple_command_message_processor import TcpJsonMessageProcessor from lifeblood.net_messages.exceptions import MessageTransferError from lifeblood.scheduler_task_protocol import SchedulerTaskClient +from lifeblood.main_scheduler import create_default_scheduler def purge_db(recreate=True): @@ -37,7 +38,7 @@ def tearDownClass(cls) -> None: async def test_stopping_normal(self): purge_db() 
- sched = Scheduler('test_swc1.db', do_broadcasting=False, helpers_minimal_idle_to_ensure=0) + sched = create_default_scheduler('test_swc1.db', do_broadcasting=False, helpers_minimal_idle_to_ensure=0) await sched.start() # crudely assert that there are corutines running from scheduler self.assertTrue(any([Path(x.get_coro().cr_code.co_filename).parts[-3:-1] == ('lifeblood', 'scheduler') for x in asyncio.all_tasks()])) @@ -56,7 +57,7 @@ async def test_stopping_nowait(self): tests that scheduler stops even without call to wait_till_stops """ purge_db() - sched = Scheduler('test_swc1.db', do_broadcasting=False, helpers_minimal_idle_to_ensure=0) + sched = create_default_scheduler('test_swc1.db', do_broadcasting=False, helpers_minimal_idle_to_ensure=0) await sched.start() # crudely assert that there are corutines running from scheduler self.assertTrue(any([Path(x.get_coro().cr_code.co_filename).parts[-3:-1] == ('lifeblood', 'scheduler') for x in asyncio.all_tasks()])) @@ -99,7 +100,7 @@ async def _helper_test_connection_when_stopping(self, try_open_new=False): and it is expected to fail to be opened """ purge_db() - sched = Scheduler('test_swc1.db', do_broadcasting=False, helpers_minimal_idle_to_ensure=0, server_addr=('127.0.0.1', 11847, 11848)) + sched = create_default_scheduler('test_swc1.db', do_broadcasting=False, helpers_minimal_idle_to_ensure=0, server_addr=('127.0.0.1', 11847, 11848)) await sched.start() proc = TcpJsonMessageProcessor(('127.0.0.1', 11850)) await proc.start() @@ -145,7 +146,7 @@ async def _helper_test_nonmessage_connection_when_stopping(self, try_open_new=Fa and it is expected to fail to be opened """ purge_db() - sched = Scheduler('test_swc1.db', do_broadcasting=False, helpers_minimal_idle_to_ensure=0, server_addr=('127.0.0.1', 11847, 11848)) + sched = create_default_scheduler('test_swc1.db', do_broadcasting=False, helpers_minimal_idle_to_ensure=0, server_addr=('127.0.0.1', 11847, 11848)) await sched.start() async with SchedulerTaskClient('127.0.0.1', 11847) as client: @@ -187,7 +188,7 @@ async def test_get_invocation_workers(self): with mock.patch('lifeblood.scheduler.scheduler.Pinger') as ppatch: ppatch.return_value = mock.AsyncMock(Pinger) - sched = Scheduler('test_swc1.db', do_broadcasting=False, helpers_minimal_idle_to_ensure=0) + sched = create_default_scheduler('test_swc1.db', do_broadcasting=False, helpers_minimal_idle_to_ensure=0) await sched.start() with sqlite3.connect('test_swc1.db') as con: diff --git a/tests/test_scheduler_worker_comm.py b/tests/test_scheduler_worker_comm.py index a956be0e..4ee10bf3 100644 --- a/tests/test_scheduler_worker_comm.py +++ b/tests/test_scheduler_worker_comm.py @@ -21,7 +21,7 @@ from lifeblood.config import get_config from lifeblood.nethelpers import get_default_addr from lifeblood.net_messages.address import AddressChain -from lifeblood import launch +from lifeblood.main_scheduler import create_default_scheduler from typing import Awaitable, Callable, List, Optional, Tuple @@ -50,7 +50,7 @@ def tearDownClass(cls) -> None: async def test_simple_start_stop(self): purge_db() - sched = Scheduler('test_swc.db', do_broadcasting=False, helpers_minimal_idle_to_ensure=0) + sched = create_default_scheduler('test_swc.db', do_broadcasting=False, helpers_minimal_idle_to_ensure=0) await sched.start() worker = Worker(sched.server_message_address()) @@ -295,7 +295,7 @@ async def _logic(scheduler, workers: List[Worker], tmp_script_path, done_waiter) async def _helper_test_worker_invocation_api(self, runcode: str, logic: Callable, *, worker_count: 
int = 1, tasks_to_complete=None): purge_db() - sched = Scheduler('test_swc.db', do_broadcasting=False, helpers_minimal_idle_to_ensure=0) + sched = create_default_scheduler('test_swc.db', do_broadcasting=False, helpers_minimal_idle_to_ensure=0) await sched.start() workers = [] @@ -342,7 +342,7 @@ def _side_effect(task: InvocationJob, stdout: str, stderr: str): async def test_task_get_order(self): purge_db() - sched = Scheduler('test_swc.db', do_broadcasting=False, helpers_minimal_idle_to_ensure=0) + sched = create_default_scheduler('test_swc.db', do_broadcasting=False, helpers_minimal_idle_to_ensure=0) await sched.start() worker = Worker(sched.server_message_address(), scheduler_ping_interval=999) # huge ping interval to prevent pinger from interfering with the test diff --git a/tests/test_serializationv2.py b/tests/test_serializationv2.py new file mode 100644 index 00000000..7992a7d1 --- /dev/null +++ b/tests/test_serializationv2.py @@ -0,0 +1,150 @@ +import asyncio +from typing import Any, Callable, Dict, Optional, Set, Type, Union +from pathlib import Path +from unittest import TestCase +from lifeblood.enums import NodeParameterType +from lifeblood.basenode_serializer_v2 import NodeSerializerV2 +from lifeblood.basenode import BaseNode, NodeUi +from lifeblood.nodegraph_holder_base import NodeGraphHolderBase +from lifeblood.node_dataprovider_base import NodeDataProvider +from lifeblood.snippets import NodeSnippetData + + +class TestGraphHolder(NodeGraphHolderBase): + async def get_node_input_connections(self, node_id: int, input_name: Optional[str] = None): + raise NotImplementedError() + + async def get_node_output_connections(self, node_id: int, output_name: Optional[str] = None): + raise NotImplementedError() + + async def node_reports_changes_needs_saving(self, node_id): + print(self, 'node_reports_changes_needs_saving called') + + def get_event_loop(self) -> asyncio.AbstractEventLoop: + asyncio.get_event_loop() + + +class TestDataProvider(NodeDataProvider): + def node_settings_names(self, type_name: str) -> Set[str]: + raise NotImplementedError() + + def node_settings(self, type_name: str, settings_name: str) -> dict: + raise NotImplementedError() + + def node_type_names(self) -> Set[str]: + raise NotImplementedError() + + def node_class(self, type_name) -> Type[BaseNode]: + raise NotImplementedError() + + def node_factory(self, node_type: str) -> Callable[[str], BaseNode]: + if node_type == 'lelolelolelo': + return TestNode1 + raise NotImplementedError() + + def has_node_factory(self, node_type: str) -> bool: + raise NotImplementedError() + + def node_preset_packages(self) -> Set[str]: + raise NotImplementedError() + + # node presets - + def node_preset_names(self, package_name: str) -> Set[str]: + raise NotImplementedError() + + def node_preset(self, package_name: str, preset_name: str) -> NodeSnippetData: + raise NotImplementedError() + + def add_settings_to_existing_package(self, package_name_or_path: Union[str, Path], node_type_name: str, settings_name: str, settings: Dict[str, Any]): + raise NotImplementedError() + + def set_settings_as_default(self, node_type_name: str, settings_name: Optional[str]): + raise NotImplementedError() + + +class TestNode1(BaseNode): + def __init__(self, name): + super().__init__(name) + ui = self.get_ui() + with ui.initializing_interface_lock(): + ui.add_parameter('test param 1', 'oh, i\'m a label', NodeParameterType.FLOAT, 1.2) + ui.add_parameter('test param 2', 'label 2', NodeParameterType.STRING, 'fofafqf q !@#') + ui.add_parameter('test param 
11', 'oh, i\'m a label', NodeParameterType.INT, 4) + with ui.multigroup_parameter_block('oh, multiparam !'): + ui.add_parameter('test multi param 1', 'rerlo', NodeParameterType.BOOL, False) + ui.add_parameter('test multi param 2', 'rerlo', NodeParameterType.STRING, 'f q w ') + with ui.multigroup_parameter_block('oh, mememe multiparam !'): + ui.add_parameter('test multi multi param 1', 'rerlo 123', NodeParameterType.INT, 2) + ui.add_parameter('test multi multi param 2', 'rerlo 234', NodeParameterType.STRING, 'f q w ') + + @classmethod + def type_name(cls) -> str: + return 'lelolelolelo' + + _test_state = { + 'foo': 'barr', + 'dic': {1: 'foo', 2: 'bar', 'viv': {}, -2.3: False, (1, 'a'): {111: 2.34, 'nope': None}}, + 'tup': (2, 3, 'qwe', {'f': 124, 'y': 44, 'z': -3.1, 'e': {1: 123, '1': 234}}), + 'seth': {3, 5, -12, (1, 3, 6, 'bobo')} + } + + def get_state(self) -> Optional[dict]: + return self._test_state + + def set_state(self, state: dict): + assert state == self._test_state + + def _ui_changed(self, definition_changed=False): + print(self, '_ui_changed called') + + +class TestSerialization(TestCase): + def nodes_are_same(self, node1: BaseNode, node2: BaseNode): + # currently there is no param hierarchy comparison + # so THIS IS A LIGHTER CHECK + # TODO: implement proper hierarchy comparison, + # and replace this with just node1 == node2 + + # first sanity check that we compare things we expect + node1_dict = node1.__dict__ + node2_dict = node2.__dict__ + assert '_parameters' in node1_dict + assert '_parameters' in node2_dict + assert isinstance(node1_dict['_parameters'], NodeUi) + assert isinstance(node2_dict['_parameters'], NodeUi) + ui1: NodeUi = node1_dict.pop('_parameters') + ui2: NodeUi = node2_dict.pop('_parameters') + + self.assertDictEqual( + {param.name(): (param.value(), param.expression()) for param in ui1.parameters()}, + {param.name(): (param.value(), param.expression()) for param in ui2.parameters()} + ) + self.assertDictEqual(node1_dict, node2_dict) + + def test_simple(self): + ser = NodeSerializerV2() + parent = TestGraphHolder() + dataprov = TestDataProvider() + + test1 = TestNode1('footest') + test1.set_parent(parent, 123) + + test1.param('test param 1').set_value(3.456) + test1.param('test param 2').set_value(-42) + test1.param('oh, multiparam !').set_value(3) + test1.param('test multi param 1_0').set_value(True) + test1.param('test multi param 2_0').set_value('sheeeeeesh') + test1.param('test multi param 1_1').set_value(True) + test1.param('test multi param 2_0').set_value('sheeeeeeeeeesh') + test1.param('oh, mememe multiparam !_0').set_value(2) + test1.param('oh, mememe multiparam !_2').set_value(2) + test1.param('test multi multi param 1_0.1').set_value(345) + test1.param('test multi multi param 1_0.0').set_value(246) + test1.param('test multi multi param 1_2.1').set_value(456) + test1.param('test multi multi param 1_2.0').set_value(3571) + + test1_data, test1_state = ser.serialize(test1) + + test1_act = ser.deserialize(parent, 123, dataprov, test1_data, test1_state) + + self.nodes_are_same(test1, test1_act) diff --git a/tests/test_ui_state_accessor_integration.py b/tests/test_ui_state_accessor_integration.py index 0f9a351b..8517bc81 100644 --- a/tests/test_ui_state_accessor_integration.py +++ b/tests/test_ui_state_accessor_integration.py @@ -12,6 +12,7 @@ from lifeblood.environment_resolver import EnvironmentResolverArguments from lifeblood.shared_lazy_sqlite_connection import SharedLazyAiosqliteConnection from lifeblood.logging import get_logger +from 
lifeblood.main_scheduler import create_default_scheduler def purge_db(testdbpath): @@ -37,7 +38,7 @@ async def asyncSetUp(self) -> None: SharedLazyAiosqliteConnection.connection_pools.pop(loop) purge_db('test_uilog.db') - self.sched = Scheduler('test_uilog.db', do_broadcasting=False, helpers_minimal_idle_to_ensure=0) + self.sched = create_default_scheduler('test_uilog.db', do_broadcasting=False, helpers_minimal_idle_to_ensure=0) await self.sched.start() async def asyncTearDown(self) -> None: diff --git a/tests/test_worker.py b/tests/test_worker.py index 7d36ee94..a0110c1f 100644 --- a/tests/test_worker.py +++ b/tests/test_worker.py @@ -8,6 +8,7 @@ from lifeblood.logging import set_default_loglevel from lifeblood.invocationjob import InvocationJob, InvocationEnvironment from lifeblood.net_messages.address import AddressChain +from lifeblood.main_scheduler import create_default_scheduler class RunningSchedulerTests(IsolatedAsyncioTestCase): @@ -29,7 +30,7 @@ def tearDownClass(cls) -> None: print('tearingdown done') async def asyncSetUp(self) -> None: - self.scheduler = Scheduler(self.__db_path, do_broadcasting=False, helpers_minimal_idle_to_ensure=0, server_addr=('127.0.0.1', 12347, 12345), server_ui_addr=('127.0.0.1', 12346)) + self.scheduler = create_default_scheduler(self.__db_path, do_broadcasting=False, helpers_minimal_idle_to_ensure=0, server_addr=('127.0.0.1', 12347, 12345), server_ui_addr=('127.0.0.1', 12346)) if not self.scheduler.is_started(): await self.scheduler.start()
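
Taken together, the updated tests settle on two conventions: test schedulers are created through create_default_scheduler from lifeblood.main_scheduler rather than by instantiating Scheduler directly, and stateful nodes (split_waiter, parent_children_waiter, wait_for_task_value) must provide a get_state() whose result survives a JSON round trip and a set_state() that rebuilds an equivalent node. A minimal sketch of that state contract, where make_node and prune are hypothetical stand-ins for the context.create_node(...) calls and _prune_dicts helpers used in the tests above:

import json

def assert_state_json_roundtrip(make_node, prune=lambda d: dict(d)):
    # make_node: hypothetical factory returning a fresh node under test
    # prune: hypothetical filter dropping locks, parameters and parent node ids,
    #        mirroring the _prune_dicts helpers in the serialization tests above
    node = make_node()
    state = json.loads(json.dumps(node.get_state()))  # state must survive a JSON round trip
    restored = make_node()
    restored.set_state(state)
    assert prune(node.__dict__) == prune(restored.__dict__)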