diff --git a/nvflare/apis/fl_constant.py b/nvflare/apis/fl_constant.py
index 5d12a0e522..364e383320 100644
--- a/nvflare/apis/fl_constant.py
+++ b/nvflare/apis/fl_constant.py
@@ -23,6 +23,7 @@ class ReturnCode(object):
     BAD_REQUEST_DATA = "BAD_REQUEST_DATA"
     BAD_TASK_DATA = "BAD_TASK_DATA"
     COMMUNICATION_ERROR = "COMMUNICATION_ERROR"
+    TIMEOUT = "TIMEOUT"
     ERROR = "ERROR"
     EXECUTION_EXCEPTION = "EXECUTION_EXCEPTION"
     EXECUTION_RESULT_ERROR = "EXECUTION_RESULT_ERROR"
@@ -104,6 +105,7 @@ class ReservedKey(object):
     JOB_IS_UNSAFE = "__job_is_unsafe__"
     CUSTOM_PROPS = "__custom_props__"
     EXCEPTIONS = "__exceptions__"
+    PROCESS_TYPE = "__process_type__"  # type of the current process (SP, CP, SJ, CJ)
 
 
 class FLContextKey(object):
@@ -184,6 +186,14 @@ class FLContextKey(object):
     CLIENT_CONFIG = "__client_config__"
     SERVER_CONFIG = "__server_config__"
     SERVER_HOST_NAME = "__server_host_name__"
+    PROCESS_TYPE = ReservedKey.PROCESS_TYPE
+
+
+class ProcessType:
+    SERVER_PARENT = "SP"
+    SERVER_JOB = "SJ"
+    CLIENT_PARENT = "CP"
+    CLIENT_JOB = "CJ"
 
 
 class ReservedTopic(object):
diff --git a/nvflare/apis/fl_context.py b/nvflare/apis/fl_context.py
index f17adc9806..2021eed65a 100644
--- a/nvflare/apis/fl_context.py
+++ b/nvflare/apis/fl_context.py
@@ -216,6 +216,9 @@ def _simple_get(self, key: str, default=None):
     def get_engine(self, default=None):
         return self._simple_get(ReservedKey.ENGINE, default)
 
+    def get_process_type(self, default=None):
+        return self._simple_get(ReservedKey.PROCESS_TYPE, default)
+
     def get_job_id(self, default=None):
         return self._simple_get(ReservedKey.RUN_NUM, default)
 
diff --git a/nvflare/apis/server_engine_spec.py b/nvflare/apis/server_engine_spec.py
index e3dc589963..4c4013f0c1 100644
--- a/nvflare/apis/server_engine_spec.py
+++ b/nvflare/apis/server_engine_spec.py
@@ -16,7 +16,6 @@
 from typing import Dict, List, Optional, Tuple
 
 from nvflare.apis.shareable import Shareable
-from nvflare.apis.streaming import ConsumerFactory, ObjectProducer, StreamContext
 from nvflare.widgets.widget import Widget
 
 from .client import Client
@@ -165,63 +164,6 @@ def fire_and_forget_aux_request(
     ) -> dict:
         return self.send_aux_request(targets, topic, request, 0.0, fl_ctx, optional, secure=secure)
 
-    @abstractmethod
-    def stream_objects(
-        self,
-        channel: str,
-        topic: str,
-        stream_ctx: StreamContext,
-        targets: List[str],
-        producer: ObjectProducer,
-        fl_ctx: FLContext,
-        optional=False,
-        secure=False,
-    ):
-        """Send a stream of Shareable objects to receivers.
-
-        Args:
-            channel: the channel for this stream
-            topic: topic of the stream
-            stream_ctx: context of the stream
-            targets: receiving sites
-            producer: the ObjectProducer that can produces the stream of Shareable objects
-            fl_ctx: the FLContext object
-            optional: whether the stream is optional
-            secure: whether to use P2P security
-
-        Returns: result from the generator's reply processing
-
-        """
-        pass
-
-    @abstractmethod
-    def register_stream_processing(
-        self,
-        channel: str,
-        topic: str,
-        factory: ConsumerFactory,
-        stream_done_cb=None,
-        **cb_kwargs,
-    ):
-        """Register a ConsumerFactory for specified app channel and topic.
-        Once a new streaming request is received for the channel/topic, the registered factory will be used
-        to create an ObjectConsumer object to handle the new stream.
-
-        Note: the factory should generate a new ObjectConsumer every time get_consumer() is called. This is because
-        multiple streaming sessions could be going on at the same time. Each streaming session should have its
-        own ObjectConsumer.
-
-        Args:
-            channel: app channel
-            topic: app topic
-            factory: the factory to be registered
-            stream_done_cb: the callback to be called when streaming is done on receiving side
-
-        Returns: None
-
-        """
-        pass
-
     @abstractmethod
     def get_widget(self, widget_id: str) -> Widget:
         """Get the widget with the specified ID.
diff --git a/nvflare/apis/shareable.py b/nvflare/apis/shareable.py
index 5008853e93..92921a0944 100644
--- a/nvflare/apis/shareable.py
+++ b/nvflare/apis/shareable.py
@@ -41,9 +41,11 @@ class Shareable(dict):
     It is recommended that keys are strings. Values must be serializable.
     """
 
-    def __init__(self):
+    def __init__(self, data: dict = None):
         """Init the Shareable."""
         super().__init__()
+        if data:
+            self.update(data)
         self[ReservedHeaderKey.HEADERS] = {}
 
     def set_header(self, key: str, value):
diff --git a/nvflare/apis/streaming.py b/nvflare/apis/streaming.py
index 64e3331100..68f509ea7a 100644
--- a/nvflare/apis/streaming.py
+++ b/nvflare/apis/streaming.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 from abc import ABC, abstractmethod
 from builtins import dict as StreamContext
-from typing import Any, Dict, Tuple
+from typing import Any, Dict, List, Tuple
 
 from nvflare.apis.fl_context import FLContext
 from nvflare.apis.shareable import Shareable
@@ -161,3 +161,73 @@ def stream_done_cb_signature(stream_ctx: StreamContext, fl_ctx: FLContext, **kwa
 
     """
     pass
+
+
+class StreamableEngine(ABC):
+    """This class defines requirements for streaming capable engines."""
+
+    @abstractmethod
+    def stream_objects(
+        self,
+        channel: str,
+        topic: str,
+        stream_ctx: StreamContext,
+        targets: List[str],
+        producer: ObjectProducer,
+        fl_ctx: FLContext,
+        optional=False,
+        secure=False,
+    ):
+        """Send a stream of Shareable objects to receivers.
+
+        Args:
+            channel: the channel for this stream
+            topic: topic of the stream
+            stream_ctx: context of the stream
+            targets: receiving sites
+            producer: the ObjectProducer that can produce the stream of Shareable objects
+            fl_ctx: the FLContext object
+            optional: whether the stream is optional
+            secure: whether to use P2P security
+
+        Returns: result from the generator's reply processing
+
+        """
+        pass
+
+    @abstractmethod
+    def register_stream_processing(
+        self,
+        channel: str,
+        topic: str,
+        factory: ConsumerFactory,
+        stream_done_cb=None,
+        **cb_kwargs,
+    ):
+        """Register a ConsumerFactory for specified app channel and topic.
+        Once a new streaming request is received for the channel/topic, the registered factory will be used
+        to create an ObjectConsumer object to handle the new stream.
+
+        Note: the factory should generate a new ObjectConsumer every time get_consumer() is called. This is because
+        multiple streaming sessions could be going on at the same time. Each streaming session should have its
+        own ObjectConsumer.
+
+        Args:
+            channel: app channel
+            topic: app topic
+            factory: the factory to be registered
+            stream_done_cb: the callback to be called when streaming is done on receiving side
+
+        Returns: None
+
+        """
+        pass
+
+    @abstractmethod
+    def shutdown_streamer(self):
+        """Shutdown the engine's streamer.
+
+        Returns: None
+
+        """
+        pass
diff --git a/nvflare/app_common/streamers/file_streamer.py b/nvflare/app_common/streamers/file_streamer.py
index 39427b41b7..fe88b93d12 100644
--- a/nvflare/app_common/streamers/file_streamer.py
+++ b/nvflare/app_common/streamers/file_streamer.py
@@ -18,7 +18,7 @@
 
 from nvflare.apis.fl_context import FLContext
 from nvflare.apis.shareable import ReturnCode, Shareable, make_reply
-from nvflare.apis.streaming import ConsumerFactory, ObjectConsumer, ObjectProducer, StreamContext
+from nvflare.apis.streaming import ConsumerFactory, ObjectConsumer, ObjectProducer, StreamableEngine, StreamContext
 from nvflare.fuel.utils.obj_utils import get_logger
 from nvflare.fuel.utils.validation_utils import check_positive_int, check_positive_number
 
@@ -179,6 +179,9 @@ def register_stream_processing(
             raise ValueError(f"dest_dir '{dest_dir}' is not a valid dir")
 
         engine = fl_ctx.get_engine()
+        if not isinstance(engine, StreamableEngine):
+            raise RuntimeError(f"engine must be StreamableEngine but got {type(engine)}")
+
         engine.register_stream_processing(
             channel=channel,
             topic=topic,
@@ -238,7 +241,12 @@ def stream_file(
         with open(file_name, "rb") as file:
             producer = _ChunkProducer(file, chunk_size, chunk_timeout)
             engine = fl_ctx.get_engine()
+
+            if not isinstance(engine, StreamableEngine):
+                raise RuntimeError(f"engine must be StreamableEngine but got {type(engine)}")
+
             stream_ctx[_KEY_FILE_NAME] = os.path.basename(file_name)
+
             return engine.stream_objects(
                 channel=channel,
                 topic=topic,
diff --git a/nvflare/private/aux_runner.py b/nvflare/private/aux_runner.py
index 83edac9ae3..0f3fb7fa79 100644
--- a/nvflare/private/aux_runner.py
+++ b/nvflare/private/aux_runner.py
@@ -18,7 +18,7 @@
 
 from nvflare.apis.client import Client
 from nvflare.apis.fl_component import FLComponent
-from nvflare.apis.fl_constant import ConfigVarName, ReturnCode, SystemConfigs
+from nvflare.apis.fl_constant import ConfigVarName, ProcessType, ReturnCode, SystemConfigs
 from nvflare.apis.fl_context import FLContext
 from nvflare.apis.shareable import ReservedHeaderKey, Shareable, make_reply
 from nvflare.fuel.f3.cellnet.core_cell import Message, MessageHeaderKey
@@ -59,7 +59,7 @@ def __init__(self, engine):
     def register_aux_message_handler(self, topic: str, message_handle_func):
         """Register aux message handling function with specified topics.
 
-        This method should be called by ServerEngine's register_aux_message_handler method.
+        This method should be called by Engine's register_aux_message_handler method.
Args: topic: the topic to be handled by the func @@ -196,7 +196,7 @@ def _process_cell_replies( if cell_replies: for reply_cell_fqcn, v in cell_replies.items(): assert isinstance(v, Message) - rc = v.get_header(MessageHeaderKey.RETURN_CODE, ReturnCode.OK) + rc = v.get_header(MessageHeaderKey.RETURN_CODE, CellReturnCode.OK) target_name = fqcn_to_name[reply_cell_fqcn] if rc == CellReturnCode.OK: result = v.payload @@ -258,7 +258,6 @@ def _send_multi_requests( if not cell: return {} - job_id = fl_ctx.get_job_id() public_props = fl_ctx.get_all_public_props() target_messages = {} fqcn_to_name = {} @@ -266,17 +265,16 @@ def _send_multi_requests( msg_target, req = t assert isinstance(msg_target, AuxMsgTarget) target_name = msg_target.name - target_fqcn = msg_target.fqcn if not isinstance(req, Shareable): raise ValueError(f"request of {target_name} should be Shareable but got {type(req)}") req.set_header(ReservedHeaderKey.TOPIC, topic) req.set_peer_props(public_props) - job_cell_fqcn = FQCN.join([target_fqcn, job_id]) - self.log_info(fl_ctx, f"sending multicast aux: {job_cell_fqcn=}") - fqcn_to_name[job_cell_fqcn] = target_name - target_messages[job_cell_fqcn] = TargetMessage( - topic=topic, channel=channel, target=job_cell_fqcn, message=Message(payload=req) + cell_fqcn = self._get_target_fqcn(msg_target, fl_ctx) + self.log_debug(fl_ctx, f"sending multicast aux: {cell_fqcn=}") + fqcn_to_name[cell_fqcn] = target_name + target_messages[cell_fqcn] = TargetMessage( + topic=topic, channel=channel, target=cell_fqcn, message=Message(payload=req) ) if timeout > 0: @@ -374,7 +372,6 @@ def _send_to_cell( request.set_header(ReservedHeaderKey.TOPIC, topic) request.set_peer_props(fl_ctx.get_all_public_props()) - job_id = fl_ctx.get_job_id() cell = self._wait_for_cell() if not cell: return {} @@ -383,9 +380,9 @@ def _send_to_cell( fqcn_to_name = {} for t in targets: # targeting job cells! 
- job_cell_fqcn = FQCN.join([t.fqcn, job_id]) - target_fqcns.append(job_cell_fqcn) - fqcn_to_name[job_cell_fqcn] = t.name + cell_fqcn = self._get_target_fqcn(t, fl_ctx) + target_fqcns.append(cell_fqcn) + fqcn_to_name[cell_fqcn] = t.name cell_msg = Message(payload=request) if timeout > 0: @@ -409,10 +406,25 @@ def _send_to_cell( ) return {} + @staticmethod + def _get_target_fqcn(target: AuxMsgTarget, fl_ctx: FLContext): + process_type = fl_ctx.get_process_type() + if process_type in [ProcessType.CLIENT_PARENT, ProcessType.SERVER_PARENT]: + # parent process + return target.fqcn + elif process_type in [ProcessType.CLIENT_JOB, ProcessType.SERVER_JOB]: + # job process + job_id = fl_ctx.get_job_id() + if not job_id: + raise RuntimeError("no job ID in fl_ctx in Job Process!") + return FQCN.join([target.fqcn, job_id]) + else: + raise RuntimeError(f"invalid process_type {process_type}") + @staticmethod def _convert_return_code(rc): rc_table = { - CellReturnCode.TIMEOUT: ReturnCode.COMMUNICATION_ERROR, + CellReturnCode.TIMEOUT: ReturnCode.TIMEOUT, CellReturnCode.COMM_ERROR: ReturnCode.COMMUNICATION_ERROR, CellReturnCode.PROCESS_EXCEPTION: ReturnCode.EXECUTION_EXCEPTION, CellReturnCode.ABORT_RUN: CellReturnCode.ABORT_RUN, diff --git a/nvflare/private/fed/app/client/client_train.py b/nvflare/private/fed/app/client/client_train.py index 7618622ad8..544567e0c5 100644 --- a/nvflare/private/fed/app/client/client_train.py +++ b/nvflare/private/fed/app/client/client_train.py @@ -108,6 +108,8 @@ def main(args): print("Waiting client cell to be created ....") time.sleep(1.0) + client_engine.initialize_comm(federated_client.cell) + with client_engine.new_context() as fl_ctx: client_engine.fire_event(EventType.SYSTEM_BOOTSTRAP, fl_ctx) diff --git a/nvflare/private/fed/client/client_engine.py b/nvflare/private/fed/client/client_engine.py index 6f293abdb4..c91b9e68e4 100644 --- a/nvflare/private/fed/client/client_engine.py +++ b/nvflare/private/fed/client/client_engine.py @@ -18,17 +18,26 @@ import shutil import sys import threading +from typing import List from nvflare.apis.event_type import EventType from nvflare.apis.fl_component import FLComponent -from nvflare.apis.fl_constant import FLContextKey, MachineStatus, SystemComponents, WorkspaceConstants +from nvflare.apis.fl_constant import FLContextKey, MachineStatus, ProcessType, SystemComponents, WorkspaceConstants from nvflare.apis.fl_context import FLContext, FLContextManager +from nvflare.apis.shareable import Shareable +from nvflare.apis.streaming import ConsumerFactory, ObjectProducer, StreamableEngine, StreamContext +from nvflare.apis.utils.fl_context_utils import gen_new_peer_ctx from nvflare.apis.workspace import Workspace -from nvflare.private.defs import ERROR_MSG_PREFIX, ClientStatusKey, EngineConstant +from nvflare.fuel.f3.cellnet.cell import Cell +from nvflare.fuel.f3.cellnet.defs import CellChannel, MessageHeaderKey, ReturnCode +from nvflare.fuel.f3.message import Message as CellMessage +from nvflare.private.aux_runner import AuxMsgTarget, AuxRunner +from nvflare.private.defs import ERROR_MSG_PREFIX, ClientStatusKey, EngineConstant, new_cell_message from nvflare.private.event import fire_event from nvflare.private.fed.server.job_meta_validator import JobMetaValidator from nvflare.private.fed.utils.app_deployer import AppDeployer from nvflare.private.fed.utils.fed_utils import security_close +from nvflare.private.stream_runner import ObjectStreamer from nvflare.security.logging import secure_format_exception, secure_log_traceback from 
.client_engine_internal_spec import ClientEngineInternalSpec @@ -45,8 +54,8 @@ def _remove_custom_path(): sys.path.remove(path) -class ClientEngine(ClientEngineInternalSpec): - """ClientEngine runs in the client parent process.""" +class ClientEngine(ClientEngineInternalSpec, StreamableEngine): + """ClientEngine runs in the client parent process (CP).""" def __init__(self, client: FederatedClient, args, rank, workers=5): """To init the ClientEngine. @@ -64,6 +73,9 @@ def __init__(self, client: FederatedClient, args, rank, workers=5): self.rank = rank self.client_executor = JobExecutor(client, os.path.join(args.workspace, "startup")) self.admin_agent = None + self.aux_runner = AuxRunner(self) + self.object_streamer = ObjectStreamer(self.aux_runner) + self.cell = None self.fl_ctx_mgr = FLContextManager( engine=self, @@ -76,6 +88,7 @@ def __init__(self, client: FederatedClient, args, rank, workers=5): SystemComponents.FED_CLIENT: client, FLContextKey.SECURE_MODE: self.client.secure_train, FLContextKey.WORKSPACE_ROOT: args.workspace, + FLContextKey.PROCESS_TYPE: ProcessType.CLIENT_PARENT, }, ) @@ -89,6 +102,189 @@ def __init__(self, client: FederatedClient, args, rank, workers=5): def fire_event(self, event_type: str, fl_ctx: FLContext): fire_event(event=event_type, handlers=self.fl_components, ctx=fl_ctx) + def get_cell(self): + """Get the communication cell. + This method must be implemented since AuxRunner calls to get cell. + + Returns: + + """ + return self.cell + + def initialize_comm(self, cell: Cell): + """This is called when communication cell has been created. + We will set up aux message handler here. + + Args: + cell: + + Returns: + + """ + cell.register_request_cb( + channel=CellChannel.AUX_COMMUNICATION, + topic="*", + cb=self._handle_aux_message, + ) + self.cell = cell + + def _handle_aux_message(self, request: CellMessage) -> CellMessage: + assert isinstance(request, CellMessage), "request must be CellMessage but got {}".format(type(request)) + data = request.payload + + topic = request.get_header(MessageHeaderKey.TOPIC) + with self.new_context() as fl_ctx: + reply = self.aux_runner.dispatch(topic=topic, request=data, fl_ctx=fl_ctx) + shared_fl_ctx = gen_new_peer_ctx(fl_ctx) + reply.set_header(key=FLContextKey.PEER_CONTEXT, value=shared_fl_ctx) + + if reply is not None: + return_message = new_cell_message({}, reply) + return_message.set_header(MessageHeaderKey.RETURN_CODE, ReturnCode.OK) + else: + return_message = new_cell_message({}, None) + return return_message + + def register_aux_message_handler(self, topic: str, message_handle_func): + """Register aux message handling function with specified topics. + + Exception is raised when: + a handler is already registered for the topic; + bad topic - must be a non-empty string + bad message_handle_func - must be callable + + Implementation Note: + This method should simply call the ServerAuxRunner's register_aux_message_handler method. + + Args: + topic: the topic to be handled by the func + message_handle_func: the func to handle the message. Must follow aux_message_handle_func_signature. + + """ + self.aux_runner.register_aux_message_handler(topic, message_handle_func) + + def send_aux_request( + self, + topic: str, + request: Shareable, + timeout: float, + fl_ctx: FLContext, + optional=False, + secure=False, + ) -> Shareable: + """Send a request to the Server via the aux channel. + + Implementation: simply calls the AuxRunner's send_aux_request method. + + Args: + topic: topic of the request. 
+ request: request to be sent + timeout: number of secs to wait for replies. 0 means fire-and-forget. + fl_ctx: FL context + optional: whether this message is optional + secure: send the aux request in a secure way + + Returns: a dict of replies (client name => reply Shareable) + + """ + reply = self.aux_runner.send_aux_request( + targets=[AuxMsgTarget.server_target()], + topic=topic, + request=request, + timeout=timeout, + fl_ctx=fl_ctx, + optional=optional, + secure=secure, + ) + + self.logger.info(f"got aux reply: {reply}") + + if len(reply) > 0: + return next(iter(reply.values())) + else: + return Shareable() + + def stream_objects( + self, + channel: str, + topic: str, + stream_ctx: StreamContext, + targets: List[str], + producer: ObjectProducer, + fl_ctx: FLContext, + optional=False, + secure=False, + ): + """Send a stream of Shareable objects to receivers. + + Args: + channel: the channel for this stream + topic: topic of the stream + stream_ctx: context of the stream + targets: receiving sites + producer: the ObjectProducer that can produces the stream of Shareable objects + fl_ctx: the FLContext object + optional: whether the stream is optional + secure: whether to use P2P security + + Returns: result from the generator's reply processing + + """ + if not self.object_streamer: + raise RuntimeError("object streamer has not been created") + + # We are CP: can only stream to SP + if targets: + for t in targets: + self.logger.debug(f"ignored target: {t}") + + return self.object_streamer.stream( + channel=channel, + topic=topic, + stream_ctx=stream_ctx, + targets=[AuxMsgTarget.server_target()], + producer=producer, + fl_ctx=fl_ctx, + secure=secure, + optional=optional, + ) + + def register_stream_processing( + self, + channel: str, + topic: str, + factory: ConsumerFactory, + stream_done_cb=None, + **cb_kwargs, + ): + """Register a ConsumerFactory for specified app channel and topic. + Once a new streaming request is received for the channel/topic, the registered factory will be used + to create an ObjectConsumer object to handle the new stream. + + Note: the factory should generate a new ObjectConsumer every time get_consumer() is called. This is because + multiple streaming sessions could be going on at the same time. Each streaming session should have its + own ObjectConsumer. + + Args: + channel: app channel + topic: app topic + factory: the factory to be registered + stream_done_cb: the callback to be called when streaming is done on receiving side + + Returns: None + + """ + if not self.object_streamer: + raise RuntimeError("object streamer has not been created") + + self.object_streamer.register_stream_processing( + topic=topic, channel=channel, factory=factory, stream_done_cb=stream_done_cb, **cb_kwargs + ) + + def shutdown_streamer(self): + if self.object_streamer: + self.object_streamer.shutdown() + def set_agent(self, admin_agent): self.admin_agent = admin_agent @@ -233,6 +429,7 @@ def shutdown(self) -> str: thread = threading.Thread(target=shutdown_client, args=(self.client, touch_file)) thread.start() + self.shutdown_streamer() return "Shutdown the client..." 
def restart(self) -> str: diff --git a/nvflare/private/fed/client/client_engine_executor_spec.py b/nvflare/private/fed/client/client_engine_executor_spec.py index 3d8897fc84..062955d505 100644 --- a/nvflare/private/fed/client/client_engine_executor_spec.py +++ b/nvflare/private/fed/client/client_engine_executor_spec.py @@ -20,7 +20,6 @@ from nvflare.apis.engine_spec import EngineSpec from nvflare.apis.fl_context import FLContext from nvflare.apis.shareable import Shareable -from nvflare.apis.streaming import ConsumerFactory, ObjectProducer, StreamContext from nvflare.apis.workspace import Workspace from nvflare.widgets.widget import Widget @@ -161,63 +160,6 @@ def fire_and_forget_aux_request( """ pass - @abstractmethod - def stream_objects( - self, - channel: str, - topic: str, - stream_ctx: StreamContext, - targets: List[str], - producer: ObjectProducer, - fl_ctx: FLContext, - optional=False, - secure=False, - ): - """Send a stream of Shareable objects to receivers. - - Args: - channel: the channel for this stream - topic: topic of the stream - stream_ctx: context of the steam - targets: receiving sites - producer: the ObjectProducer that can produce the stream of Shareable objects - fl_ctx: the FLContext object - optional: whether the stream is optional - secure: whether to use P2P security - - Returns: result from the producer's reply processing - - """ - pass - - @abstractmethod - def register_stream_processing( - self, - channel: str, - topic: str, - factory: ConsumerFactory, - stream_done_cb=None, - **cb_kwargs, - ): - """Register a ConsumerFactory for specified app channel and topic - Once a new streaming request is received for the channel/topic, the registered factory will be used - to create an ObjectConsumer object to consume objects of the stream. - - Note: the factory should generate a new ObjectConsumer every time get_consumer() is called. This is because - multiple streaming sessions could be going on at the same time. Each streaming session should have its - own ObjectConsumer. - - Args: - channel: app channel - topic: app topic - factory: the factory to be registered - stream_done_cb: callback to be called when streaming is done on receiving side - - Returns: None - - """ - pass - @abstractmethod def build_component(self, config_dict): """Build a component from the config_dict. 
diff --git a/nvflare/private/fed/client/client_run_manager.py b/nvflare/private/fed/client/client_run_manager.py index d7cc5e0dd6..21bc957f28 100644 --- a/nvflare/private/fed/client/client_run_manager.py +++ b/nvflare/private/fed/client/client_run_manager.py @@ -17,10 +17,10 @@ from typing import Dict, List, Optional, Union from nvflare.apis.fl_component import FLComponent -from nvflare.apis.fl_constant import FLContextKey, ServerCommandKey, ServerCommandNames, SiteType +from nvflare.apis.fl_constant import FLContextKey, ProcessType, ServerCommandKey, ServerCommandNames, SiteType from nvflare.apis.fl_context import FLContext, FLContextManager from nvflare.apis.shareable import Shareable -from nvflare.apis.streaming import ConsumerFactory, ObjectProducer, StreamContext +from nvflare.apis.streaming import ConsumerFactory, ObjectProducer, StreamableEngine, StreamContext from nvflare.apis.workspace import Workspace from nvflare.fuel.f3.cellnet.core_cell import FQCN from nvflare.fuel.f3.cellnet.defs import MessageHeaderKey @@ -58,8 +58,8 @@ def __init__(self, job_id): GET_CLIENTS_RETRY = 300 -class ClientRunManager(ClientEngineExecutorSpec): - """ClientRunManager provides the ClientEngine APIs implementation running in the child process.""" +class ClientRunManager(ClientEngineExecutorSpec, StreamableEngine): + """ClientRunManager provides the ClientEngine APIs implementation running in the child process (CJ).""" def __init__( self, @@ -105,6 +105,7 @@ def __init__( # get job meta! job_ctx_props = self.create_job_processing_context_properties(workspace, job_id) + job_ctx_props.update({FLContextKey.PROCESS_TYPE: ProcessType.CLIENT_JOB}) self.fl_ctx_mgr = FLContextManager( engine=self, identity_name=client_name, job_id=job_id, public_stickers={}, private_stickers=job_ctx_props ) @@ -347,6 +348,9 @@ def stream_objects( optional=False, secure=False, ): + if not self.object_streamer: + raise RuntimeError("object streamer has not been created") + return self.object_streamer.stream( channel=channel, topic=topic, @@ -366,8 +370,15 @@ def register_stream_processing( stream_done_cb=None, **cb_kwargs, ): + if not self.object_streamer: + raise RuntimeError("object streamer has not been created") + self.object_streamer.register_stream_processing(channel, topic, factory, stream_done_cb, **cb_kwargs) + def shutdown_streamer(self): + if self.object_streamer: + self.object_streamer.shutdown() + def abort_app(self, job_id: str, fl_ctx: FLContext): runner = fl_ctx.get_prop(key=FLContextKey.RUNNER, default=None) if isinstance(runner, ClientRunner): diff --git a/nvflare/private/fed/client/client_runner.py b/nvflare/private/fed/client/client_runner.py index 5d11cbd395..d063642ae0 100644 --- a/nvflare/private/fed/client/client_runner.py +++ b/nvflare/private/fed/client/client_runner.py @@ -583,6 +583,7 @@ def run(self, app_root, args): finally: self.end_run_events_sequence() ReliableMessage.shutdown() + self.engine.shutdown_streamer() with self.task_lock: self.running_tasks = {} diff --git a/nvflare/private/fed/server/fed_server.py b/nvflare/private/fed/server/fed_server.py index 3bb79495c4..0cdfa616c4 100644 --- a/nvflare/private/fed/server/fed_server.py +++ b/nvflare/private/fed/server/fed_server.py @@ -893,7 +893,7 @@ def deploy(self, args, grpc_args=None, secure_train=False): prv_key_path=grpc_args["ssl_private_key"], ) - self.engine.cell = self.cell + self.engine.initialize_comm(self.cell) self._register_cellnet_cbs() self.overseer_agent.start(self.overseer_callback) diff --git 
a/nvflare/private/fed/server/run_manager.py b/nvflare/private/fed/server/run_manager.py index 475eab0b5f..ba63130afb 100644 --- a/nvflare/private/fed/server/run_manager.py +++ b/nvflare/private/fed/server/run_manager.py @@ -17,6 +17,7 @@ from nvflare.apis.client import Client from nvflare.apis.engine_spec import EngineSpec from nvflare.apis.fl_component import FLComponent +from nvflare.apis.fl_constant import FLContextKey, ProcessType from nvflare.apis.fl_context import FLContext, FLContextManager from nvflare.apis.server_engine_spec import ServerEngineSpec from nvflare.apis.workspace import Workspace @@ -63,8 +64,9 @@ def __init__( if job_id: job_ctx_props = self.create_job_processing_context_properties(workspace, job_id) + job_ctx_props.update({FLContextKey.PROCESS_TYPE: ProcessType.SERVER_JOB}) else: - job_ctx_props = {} + job_ctx_props = {FLContextKey.PROCESS_TYPE: ProcessType.SERVER_PARENT} self.fl_ctx_mgr = FLContextManager( engine=engine, identity_name=server_name, job_id=job_id, public_stickers={}, private_stickers=job_ctx_props diff --git a/nvflare/private/fed/server/server_engine.py b/nvflare/private/fed/server/server_engine.py index 295fb31c7f..f9b11833e3 100644 --- a/nvflare/private/fed/server/server_engine.py +++ b/nvflare/private/fed/server/server_engine.py @@ -30,25 +30,26 @@ AdminCommandNames, FLContextKey, MachineStatus, - ReturnCode, RunProcessKey, ServerCommandKey, ServerCommandNames, SnapshotKey, WorkspaceConstants, ) -from nvflare.apis.fl_context import FLContext, FLContextManager +from nvflare.apis.fl_context import FLContext from nvflare.apis.fl_snapshot import RunSnapshot from nvflare.apis.impl.job_def_manager import JobDefManagerSpec from nvflare.apis.job_def import Job from nvflare.apis.job_launcher_spec import JobLauncherSpec -from nvflare.apis.shareable import Shareable, make_reply -from nvflare.apis.streaming import ConsumerFactory, ObjectProducer, StreamContext -from nvflare.apis.utils.fl_context_utils import get_serializable_data +from nvflare.apis.shareable import ReturnCode, Shareable, make_reply +from nvflare.apis.streaming import ConsumerFactory, ObjectProducer, StreamableEngine, StreamContext +from nvflare.apis.utils.fl_context_utils import gen_new_peer_ctx, get_serializable_data from nvflare.apis.workspace import Workspace -from nvflare.fuel.f3.cellnet.core_cell import FQCN +from nvflare.fuel.f3.cellnet.cell import Cell from nvflare.fuel.f3.cellnet.defs import MessageHeaderKey from nvflare.fuel.f3.cellnet.defs import ReturnCode as CellMsgReturnCode +from nvflare.fuel.f3.cellnet.fqcn import FQCN +from nvflare.fuel.f3.message import Message as CellMessage from nvflare.fuel.utils.argument_utils import parse_vars from nvflare.fuel.utils.zip_utils import zip_directory_to_bytes from nvflare.private.admin_defs import Message, MsgHeader @@ -76,7 +77,7 @@ from .server_status import ServerStatus -class ServerEngine(ServerEngineInternalSpec): +class ServerEngine(ServerEngineInternalSpec, StreamableEngine): def __init__(self, server, args, client_manager: ClientManager, snapshot_persistor, workers=3): """Server engine. @@ -93,7 +94,7 @@ def __init__(self, server, args, client_manager: ClientManager, snapshot_persist self.exception_run_processes = {} self.run_manager = None self.conf = None - # TODO:: does this class need client manager? 
+ self.cell = None self.client_manager = client_manager self.widgets = { @@ -381,10 +382,61 @@ def send_app_command(self, job_id: str, topic: str, cmd_data, timeout: float) -> return result def set_run_manager(self, run_manager: RunManager): + self.logger.debug("set_run_manager is called") self.run_manager = run_manager + + # we set the run_manager's cell if we have the cell. + if self.cell: + self.run_manager.cell = self.cell + for _, widget in self.widgets.items(): self.run_manager.add_handler(widget) + def get_cell(self): + return self.cell + + def initialize_comm(self, cell: Cell): + """This is called when the communication cell has been created. + We will set up aux message handler here. + + Args: + cell: + + Returns: + + """ + self.logger.info("initialize_comm called!") + self.cell = cell + if self.run_manager: + # Note that the aux_runner is created with the self.run_manager as the "engine". + # We must set the cell in it; otherwise it won't be able to send messages. + # The timing of the creation of the run_manager and the cell is not deterministic, we set the cell here + # only if the run_manager has been created. + self.run_manager.cell = cell + + cell.register_request_cb( + channel=CellChannel.AUX_COMMUNICATION, + topic="*", + cb=self._handle_aux_message, + ) + + def _handle_aux_message(self, request: CellMessage) -> CellMessage: + assert isinstance(request, CellMessage), "request must be CellMessage but got {}".format(type(request)) + data = request.payload + + topic = request.get_header(MessageHeaderKey.TOPIC) + with self.new_context() as fl_ctx: + reply = self.run_manager.aux_runner.dispatch(topic=topic, request=data, fl_ctx=fl_ctx) + shared_fl_ctx = gen_new_peer_ctx(fl_ctx) + reply.set_header(key=FLContextKey.PEER_CONTEXT, value=shared_fl_ctx) + + if reply is not None: + return_message = new_cell_message({}, reply) + return_message.set_header(MessageHeaderKey.RETURN_CODE, CellMsgReturnCode.OK) + else: + return_message = new_cell_message({}, None) + return return_message + def set_job_runner(self, job_runner: JobRunner, job_manager: JobDefManagerSpec): self.job_runner = job_runner self.job_def_manager = job_manager @@ -401,10 +453,8 @@ def new_context(self) -> FLContext: if self.run_manager: return self.run_manager.new_context() else: - # return FLContext() - return FLContextManager( - engine=self, identity_name=self.server.project_name, job_id="", public_stickers={}, private_stickers={} - ).new_context() + # this call should never be made before the run_manager is created! 
+ raise RuntimeError("no run_manager in Server Engine.") def add_component(self, component_id: str, component): self.server.runner_config.add_component(component_id, component) @@ -563,6 +613,12 @@ def stream_objects( optional=False, secure=False, ): + if not self.run_manager: + raise RuntimeError("run_manager has not been created") + + if not self.run_manager.object_streamer: + raise RuntimeError("object_streamer has not been created") + return self.run_manager.object_streamer.stream( channel=channel, topic=topic, @@ -582,10 +638,20 @@ def register_stream_processing( stream_done_cb=None, **cb_kwargs, ): + if not self.run_manager: + raise RuntimeError("run_manager has not been created") + + if not self.run_manager.object_streamer: + raise RuntimeError("object_streamer has not been created") + self.run_manager.object_streamer.register_stream_processing( channel=channel, topic=topic, factory=factory, stream_done_cb=stream_done_cb, **cb_kwargs ) + def shutdown_streamer(self): + if self.run_manager and self.run_manager.object_streamer: + self.run_manager.object_streamer.shutdown() + def sync_clients_from_main_process(self): # repeatedly ask the parent process to get participating clients until we receive the result # or timed out after 30 secs (already tried 30 times). @@ -888,6 +954,7 @@ def pause_server_jobs(self): def close(self): self.executor.shutdown() + self.shutdown_streamer() def server_shutdown(server, touch_file): diff --git a/nvflare/private/fed/server/server_runner.py b/nvflare/private/fed/server/server_runner.py index 4831cacfe8..81c47f4848 100644 --- a/nvflare/private/fed/server/server_runner.py +++ b/nvflare/private/fed/server/server_runner.py @@ -218,6 +218,7 @@ def run(self): self.log_info(fl_ctx, "END_RUN fired") ReliableMessage.shutdown() + self.engine.shutdown_streamer() self.log_info(fl_ctx, "Server runner finished.") def handle_event(self, event_type: str, fl_ctx: FLContext): diff --git a/nvflare/private/fed/utils/fed_utils.py b/nvflare/private/fed/utils/fed_utils.py index e62a092046..4efbe1bf17 100644 --- a/nvflare/private/fed/utils/fed_utils.py +++ b/nvflare/private/fed/utils/fed_utils.py @@ -37,6 +37,7 @@ ) from nvflare.apis.fl_exception import UnsafeComponentError from nvflare.apis.job_def import JobMetaKey +from nvflare.apis.job_launcher_spec import JobLauncherSpec from nvflare.apis.utils.decomposers import flare_decomposers from nvflare.apis.workspace import Workspace from nvflare.app_common.decomposers import common_decomposers @@ -542,7 +543,7 @@ def get_scope_prop(scope_name: str, key: str) -> Any: return data_bus.get_data(_scope_prop_key(scope_name, key)) -def get_job_launcher(job_meta: dict, fl_ctx: FLContext) -> dict: +def get_job_launcher(job_meta: dict, fl_ctx: FLContext) -> JobLauncherSpec: engine = fl_ctx.get_engine() with engine.new_context() as job_launcher_ctx: @@ -555,4 +556,8 @@ def get_job_launcher(job_meta: dict, fl_ctx: FLContext) -> dict: if not (job_launcher and isinstance(job_launcher, list)): raise RuntimeError(f"There's no job launcher can handle this job: {job_meta}.") + launcher = job_launcher[0] + if not isinstance(launcher, JobLauncherSpec): + raise RuntimeError(f"The job launcher must be JobLauncherSpec but got {type(launcher)}") + return job_launcher[0] diff --git a/nvflare/private/stream_runner.py b/nvflare/private/stream_runner.py index c3a5566d17..3bae0dc0c1 100644 --- a/nvflare/private/stream_runner.py +++ b/nvflare/private/stream_runner.py @@ -13,6 +13,7 @@ # limitations under the License. 
 import time
 import uuid
+from concurrent.futures import Future, ThreadPoolExecutor
 from threading import Lock
 from typing import Any, List, Tuple
 
@@ -22,6 +23,7 @@
 from nvflare.apis.shareable import Shareable, make_reply
 from nvflare.apis.streaming import ConsumerFactory, ObjectConsumer, ObjectProducer, StreamContext, StreamContextKey
 from nvflare.fuel.f3.cellnet.registry import Registry
+from nvflare.fuel.utils.config_service import ConfigService
 from nvflare.fuel.utils.obj_utils import get_logger
 from nvflare.fuel.utils.validation_utils import check_callable, check_object_type, check_str
 from nvflare.private.aux_runner import AuxMsgTarget, AuxRunner
@@ -109,6 +111,10 @@ def __init__(self, aux_runner: AuxRunner):
         self.tx_table = {}  # tx_id => _ProcessorInfo
         self.logger = get_logger(self)
 
+        # Note: the ConfigService has been initialized
+        max_concurrent_streaming_sessions = ConfigService.get_int_var("max_concurrent_streaming_sessions", default=20)
+        self.streaming_executor = ThreadPoolExecutor(max_workers=max_concurrent_streaming_sessions)
+
         aux_runner.register_aux_message_handler(
             topic=TOPIC_STREAM_REQUEST,
             message_handle_func=self._handle_request,
@@ -118,6 +124,13 @@
             message_handle_func=self._handle_abort,
         )
 
+    def shutdown(self):
+        e = self.streaming_executor
+        self.streaming_executor = None
+        if e:
+            e.shutdown(wait=False, cancel_futures=True)
+        self.logger.info("Stream Runner is shut down")
+
     def register_stream_processing(
         self,
         channel: str,
@@ -429,3 +442,29 @@ def stream(
             self._notify_abort_streaming(targets, tx_id, secure, fl_ctx)
         self.logger.debug(f"Done streaming: {rc}")
         return rc, result
+
+    def stream_no_wait(
+        self,
+        channel: str,
+        topic: str,
+        stream_ctx: StreamContext,
+        targets: List[AuxMsgTarget],
+        producer: ObjectProducer,
+        fl_ctx: FLContext,
+        secure=False,
+        optional=False,
+    ) -> Future:
+        if not self.streaming_executor:
+            raise RuntimeError("streaming_executor is not available: the streamer has been shut down!")
+
+        return self.streaming_executor.submit(
+            self.stream,
+            channel=channel,
+            topic=topic,
+            stream_ctx=stream_ctx,
+            targets=targets,
+            producer=producer,
+            fl_ctx=fl_ctx,
+            secure=secure,
+            optional=optional,
+        )
diff --git a/tests/unit_test/app_common/job_schedulers/job_scheduler_test.py b/tests/unit_test/app_common/job_schedulers/job_scheduler_test.py
index 9ff5a3f34a..8e0520053c 100644
--- a/tests/unit_test/app_common/job_schedulers/job_scheduler_test.py
+++ b/tests/unit_test/app_common/job_schedulers/job_scheduler_test.py
@@ -24,7 +24,6 @@
 from nvflare.apis.job_scheduler_spec import DispatchInfo
 from nvflare.apis.resource_manager_spec import ResourceManagerSpec
 from nvflare.apis.server_engine_spec import ServerEngineSpec
-from nvflare.apis.streaming import ConsumerFactory, ObjectProducer, StreamContext
 from nvflare.app_common.job_schedulers.job_scheduler import DefaultJobScheduler
 from nvflare.app_common.resource_managers.list_resource_manager import ListResourceManager
 
@@ -126,29 +125,6 @@ def multicast_aux_requests(
     ) -> dict:
         pass
 
-    def stream_objects(
-        self,
-        channel: str,
-        topic: str,
-        stream_ctx: StreamContext,
-        targets: List[str],
-        producer: ObjectProducer,
-        fl_ctx: FLContext,
-        optional=False,
-        secure=False,
-    ):
-        pass
-
-    def register_stream_processing(
-        self,
-        channel: str,
-        topic: str,
-        factory: ConsumerFactory,
-        stream_done_cb=None,
-        **cb_kwargs,
-    ):
-        pass
-
     def get_widget(self, widget_id: str):
         pass
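
As a quick orientation (not part of the diff), the following is a minimal usage sketch of the new StreamableEngine interface introduced above. The guard mirrors the isinstance() check added to FileStreamer; the channel/topic names and the `my_factory`/`my_producer` arguments are hypothetical placeholders for application-defined ConsumerFactory and ObjectProducer implementations.

```python
# A minimal sketch, assuming a ConsumerFactory and an ObjectProducer implementation
# (placeholders `my_factory` and `my_producer`) are defined elsewhere by the app.
from nvflare.apis.fl_context import FLContext
from nvflare.apis.streaming import StreamableEngine, StreamContext


def setup_receiver(fl_ctx: FLContext, my_factory):
    # Receiving side: register a ConsumerFactory for a channel/topic pair.
    # The factory must return a fresh ObjectConsumer from every get_consumer() call,
    # since multiple streaming sessions may run concurrently.
    engine = fl_ctx.get_engine()
    if not isinstance(engine, StreamableEngine):
        raise RuntimeError(f"engine must be StreamableEngine but got {type(engine)}")
    engine.register_stream_processing(channel="demo", topic="weights", factory=my_factory)


def send_stream(fl_ctx: FLContext, my_producer, targets):
    # Sending side: stream the objects produced by my_producer to the given targets.
    # The aux runner resolves each target FQCN from the process type (parent CP/SP vs.
    # job CJ/SJ) stored in the FLContext; a CP engine always streams to the server.
    engine = fl_ctx.get_engine()
    if not isinstance(engine, StreamableEngine):
        raise RuntimeError(f"engine must be StreamableEngine but got {type(engine)}")
    return engine.stream_objects(
        channel="demo",
        topic="weights",
        stream_ctx=StreamContext(),  # StreamContext is a plain dict
        targets=targets,
        producer=my_producer,
        fl_ctx=fl_ctx,
    )
```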