Merge remote-tracking branch 'upstream/main' into tf_algos
Merge latest changes from main into tf_algos
falibabaei committed Jul 22, 2024
2 parents 4a0c7a5 + 7ed01c1 · commit f70c9b8
Showing 8 changed files with 268 additions and 45 deletions.
1 change: 1 addition & 0 deletions examples/advanced/job-level-authorization/setup.sh
@@ -2,6 +2,7 @@
DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )"
cd $DIR
rm -rf workspace
+ nvflare config -pw /tmp/nvflare/poc
nvflare poc prepare -i project.yml -c site_a
WORKSPACE="/tmp/nvflare/poc/job-level-authorization/prod_00"
cp -r security/site_a/* $WORKSPACE/site_a/local
47 changes: 29 additions & 18 deletions nvflare/apis/utils/reliable_message.py
@@ -111,7 +111,7 @@ def process(self, request: Shareable, fl_ctx: FLContext) -> Shareable:
self.tx_timeout = request.get_header(HEADER_TX_TIMEOUT)

# start processing
- ReliableMessage.info(fl_ctx, f"started processing request of topic {self.topic}")
+ ReliableMessage.debug(fl_ctx, f"started processing request of topic {self.topic}")
self.executor.submit(self._do_request, request, fl_ctx)
return _status_reply(STATUS_IN_PROCESS) # ack
elif self.result:
@@ -143,14 +143,14 @@ def process(self, request: Shareable, fl_ctx: FLContext) -> Shareable:
ReliableMessage.error(fl_ctx, f"aborting processing since exceeded max tx time {self.tx_timeout}")
return _status_reply(STATUS_ABORTED)
else:
- ReliableMessage.info(fl_ctx, "got query: request is in-process")
+ ReliableMessage.debug(fl_ctx, "got query: request is in-process")
return _status_reply(STATUS_IN_PROCESS)

def _try_reply(self, fl_ctx: FLContext):
engine = fl_ctx.get_engine()
self.replying = True
start_time = time.time()
- ReliableMessage.info(fl_ctx, f"try to send reply back to {self.source}: {self.per_msg_timeout=}")
+ ReliableMessage.debug(fl_ctx, f"try to send reply back to {self.source}: {self.per_msg_timeout=}")
ack = engine.send_aux_request(
targets=[self.source],
topic=TOPIC_RELIABLE_REPLY,
@@ -175,7 +175,7 @@ def _try_reply(self, fl_ctx: FLContext):

def _do_request(self, request: Shareable, fl_ctx: FLContext):
start_time = time.time()
- ReliableMessage.info(fl_ctx, "invoking request handler")
+ ReliableMessage.debug(fl_ctx, "invoking request handler")
try:
result = self.request_handler_f(self.topic, request, fl_ctx)
except Exception as e:
@@ -187,7 +187,7 @@ def _do_request(self, request: Shareable, fl_ctx: FLContext):
result.set_header(HEADER_OP, OP_REPLY)
result.set_header(HEADER_TOPIC, self.topic)
self.result = result
- ReliableMessage.info(fl_ctx, f"finished request handler in {time.time()-start_time} secs")
+ ReliableMessage.debug(fl_ctx, f"finished request handler in {time.time()-start_time} secs")
self._try_reply(fl_ctx)


@@ -277,12 +277,14 @@ def _receive_request(cls, topic: str, request: Shareable, fl_ctx: FLContext):
cls.error(fl_ctx, f"no handler registered for request {rm_topic=}")
return make_reply(ReturnCode.TOPIC_UNKNOWN)
receiver = cls._get_or_create_receiver(rm_topic, request, handler_f)
- cls.info(fl_ctx, f"received request {rm_topic=}")
+ cls.debug(fl_ctx, f"received request {rm_topic=}")
return receiver.process(request, fl_ctx)
elif op == OP_QUERY:
receiver = cls._req_receivers.get(tx_id)
if not receiver:
- cls.error(fl_ctx, f"received query but the request ({rm_topic=}) is not received or already done!")
+ cls.warning(
+     fl_ctx, f"received query but the request ({rm_topic=} {tx_id=}) is not received or already done!"
+ )
return _status_reply(STATUS_NOT_RECEIVED) # meaning the request wasn't received
else:
return receiver.process(request, fl_ctx)
@@ -300,7 +302,7 @@ def _receive_reply(cls, topic: str, request: Shareable, fl_ctx: FLContext):
cls.error(fl_ctx, "received reply but we are no longer waiting for it")
else:
assert isinstance(receiver, _ReplyReceiver)
- cls.info(fl_ctx, f"received reply in {time.time()-receiver.tx_start_time} secs - set waiter")
+ cls.debug(fl_ctx, f"received reply in {time.time()-receiver.tx_start_time} secs - set waiter")
receiver.process(request)
return make_reply(ReturnCode.OK)

@@ -415,6 +417,10 @@ def _log_msg(cls, fl_ctx: FLContext, msg: str):
def info(cls, fl_ctx: FLContext, msg: str):
cls._logger.info(cls._log_msg(fl_ctx, msg))

+ @classmethod
+ def warning(cls, fl_ctx: FLContext, msg: str):
+     cls._logger.warning(cls._log_msg(fl_ctx, msg))

@classmethod
def error(cls, fl_ctx: FLContext, msg: str):
cls._logger.error(cls._log_msg(fl_ctx, msg))
@@ -511,7 +517,7 @@ def _send_request(
return make_reply(ReturnCode.COMMUNICATION_ERROR)

if num_tries > 0:
- cls.info(fl_ctx, f"retry #{num_tries} sending request: {per_msg_timeout=}")
+ cls.debug(fl_ctx, f"retry #{num_tries} sending request: {per_msg_timeout=}")

ack = engine.send_aux_request(
targets=[target],
@@ -528,23 +534,23 @@
# the reply is already the result - we are done!
# this could happen when we didn't get positive ack for our first request, and the result was
# already produced when we did the 2nd request (this request).
- cls.info(fl_ctx, f"C1: received result in {time.time()-receiver.tx_start_time} seconds; {rc=}")
+ cls.debug(fl_ctx, f"C1: received result in {time.time()-receiver.tx_start_time} seconds; {rc=}")
return ack

# the ack is a status report - check status
status = ack.get_header(HEADER_STATUS)
if status and status != STATUS_NOT_RECEIVED:
# status should never be STATUS_NOT_RECEIVED, unless there is a bug in the receiving logic
# STATUS_NOT_RECEIVED is only possible during "query" phase.
- cls.info(fl_ctx, f"received status ack: {rc=} {status=}")
+ cls.debug(fl_ctx, f"received status ack: {rc=} {status=}")
break

if time.time() + cls._query_interval - receiver.tx_start_time >= tx_timeout:
cls.error(fl_ctx, f"aborting send_request since it will exceed {tx_timeout=}")
return make_reply(ReturnCode.COMMUNICATION_ERROR)

# we didn't get a positive ack - wait a short time and re-send the request.
- cls.info(fl_ctx, f"unsure the request was received ({rc=}): will retry in {cls._query_interval} secs")
+ cls.debug(fl_ctx, f"unsure the request was received ({rc=}): will retry in {cls._query_interval} secs")
num_tries += 1
start = time.time()
while time.time() - start < cls._query_interval:
@@ -553,7 +559,7 @@ def _send_request
return make_reply(ReturnCode.TASK_ABORTED)
time.sleep(0.1)

- cls.info(fl_ctx, "request was received by the peer - will query for result")
+ cls.debug(fl_ctx, "request was received by the peer - will query for result")
return cls._query_result(target, abort_signal, fl_ctx, receiver)

@classmethod
@@ -585,7 +591,7 @@ def _query_result(
# we already received result sent by the target.
# Note that we don't wait forever here - we only wait for _query_interval, so we could
# check other condition and/or send query to ask for result.
- cls.info(fl_ctx, f"C2: received result in {time.time()-receiver.tx_start_time} seconds")
+ cls.debug(fl_ctx, f"C2: received result in {time.time()-receiver.tx_start_time} seconds")
return receiver.result

if abort_signal and abort_signal.triggered:
@@ -599,21 +605,26 @@ def _query_result(
# send a query. The ack of the query could be the result itself, or a status report.
# Note: the ack could be the result because we failed to receive the result sent by the target earlier.
num_tries += 1
- cls.info(fl_ctx, f"query #{num_tries}: try to get result from {target}: {per_msg_timeout=}")
+ cls.debug(fl_ctx, f"query #{num_tries}: try to get result from {target}: {per_msg_timeout=}")
ack = engine.send_aux_request(
targets=[target],
topic=TOPIC_RELIABLE_REQUEST,
request=query,
timeout=per_msg_timeout,
fl_ctx=fl_ctx,
)

+ # Ignore query result if result is already received
+ if receiver.result_ready.is_set():
+     return receiver.result

last_query_time = time.time()
ack, rc = _extract_result(ack, target)
if ack and rc not in [ReturnCode.COMMUNICATION_ERROR]:
op = ack.get_header(HEADER_OP)
if op == OP_REPLY:
# the ack is result itself!
- cls.info(fl_ctx, f"C3: received result in {time.time()-receiver.tx_start_time} seconds")
+ cls.debug(fl_ctx, f"C3: received result in {time.time()-receiver.tx_start_time} seconds")
return ack

status = ack.get_header(HEADER_STATUS)
@@ -625,6 +636,6 @@ def _query_result(
cls.error(fl_ctx, f"peer {target} aborted processing!")
return _error_reply(ReturnCode.EXECUTION_EXCEPTION, "Aborted")

- cls.info(fl_ctx, f"will retry query in {cls._query_interval} secs: {rc=} {status=} {op=}")
+ cls.debug(fl_ctx, f"will retry query in {cls._query_interval} secs: {rc=} {status=} {op=}")
else:
- cls.info(fl_ctx, f"will retry query in {cls._query_interval} secs: {rc=}")
+ cls.debug(fl_ctx, f"will retry query in {cls._query_interval} secs: {rc=}")
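
Most of the edits above downgrade per-retry log calls from info to debug; these fire on every pass of ReliableMessage's two-phase loop (re-send until the peer acks the request, then poll until the result is ready). As context, a minimal sketch of that loop in plain Python; send_once and query_once are hypothetical stand-ins for the aux-channel calls, not NVFlare APIs:

import time

STATUS_IN_PROCESS = "in_process"
STATUS_NOT_RECEIVED = "not_received"

def reliable_send(send_once, query_once, tx_timeout, query_interval=2.0):
    """Re-send a request until it is acked, then poll for the result."""
    tx_start = time.time()
    # Phase 1: re-send until the peer confirms it received the request.
    while True:
        ack = send_once()
        if ack == STATUS_IN_PROCESS:
            break  # peer has the request; move on to querying for the result
        if time.time() + query_interval - tx_start >= tx_timeout:
            raise TimeoutError("request was never acknowledged")
        time.sleep(query_interval)
    # Phase 2: poll until the result is produced. The reply to a query may
    # be the result itself (the C1/C2/C3 branches above).
    while True:
        reply = query_once()
        if reply not in (STATUS_IN_PROCESS, STATUS_NOT_RECEIVED):
            return reply
        if time.time() - tx_start >= tx_timeout:
            raise TimeoutError("result not produced within tx_timeout")
        time.sleep(query_interval)

The new result_ready.is_set() check after each query closes a race in the real implementation: if the peer's reply lands while a query is still in flight, the fresh result is returned instead of the stale query ack.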
13 changes: 13 additions & 0 deletions nvflare/app_common/utils/fl_model_utils.py
@@ -19,6 +19,7 @@
from nvflare.apis.fl_context import FLContext
from nvflare.apis.shareable import Shareable
from nvflare.app_common.abstract.fl_model import FLModel, FLModelConst, MetaKey, ParamsType
+ from nvflare.app_common.abstract.model import ModelLearnable, ModelLearnableKey, make_model_learnable
from nvflare.app_common.app_constant import AppConstants
from nvflare.fuel.utils.validation_utils import check_object_type

@@ -188,6 +189,18 @@ def from_dxo(dxo: DXO) -> FLModel:
meta=meta,
)

+ @staticmethod
+ def to_model_learnable(fl_model: FLModel) -> ModelLearnable:
+     return make_model_learnable(weights=fl_model.params, meta_props=fl_model.meta)
+
+ @staticmethod
+ def from_model_learnable(model_learnable: ModelLearnable) -> FLModel:
+     return FLModel(
+         params_type=ParamsType.FULL,
+         params=model_learnable[ModelLearnableKey.WEIGHTS],
+         meta=model_learnable[ModelLearnableKey.META],
+     )

@staticmethod
def get_meta_prop(model: FLModel, key: str, default=None):
check_object_type("model", model, FLModel)
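
The two new FLModelUtils helpers form a symmetric pair between FLModel and ModelLearnable. A quick round-trip sketch under the imports shown in the hunk above (toy weights; note that from_model_learnable always restores the params_type as FULL):

from nvflare.app_common.abstract.fl_model import FLModel, ParamsType
from nvflare.app_common.utils.fl_model_utils import FLModelUtils

fl_model = FLModel(
    params_type=ParamsType.FULL,
    params={"layer0.weight": [0.1, 0.2]},  # toy weights
    meta={"round": 3},
)

learnable = FLModelUtils.to_model_learnable(fl_model)  # params -> WEIGHTS, meta -> META
restored = FLModelUtils.from_model_learnable(learnable)

assert restored.params == fl_model.params
assert restored.meta == fl_model.meta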
13 changes: 11 additions & 2 deletions nvflare/app_common/workflows/base_model_controller.py
@@ -202,7 +202,7 @@ def _prepare_task(
name=task_name,
data=data_shareable,
operator=operator,
- props={AppConstants.TASK_PROP_CALLBACK: callback},
+ props={AppConstants.TASK_PROP_CALLBACK: callback, AppConstants.META_DATA: data.meta},
timeout=timeout,
before_task_sent_cb=self._prepare_task_data,
result_received_cb=self._process_result,
@@ -221,6 +221,7 @@ def _process_result(self, client_task: ClientTask, fl_ctx: FLContext) -> None:

# Turn result into FLModel
result_model = FLModelUtils.from_shareable(result)
+ result_model.meta["props"] = client_task.task.props[AppConstants.META_DATA]
result_model.meta["client_name"] = client_name

if result_model.current_round is not None:
@@ -235,7 +236,7 @@ def _process_result(self, client_task: ClientTask, fl_ctx: FLContext) -> None:
try:
callback(result_model)
except Exception as e:
- self.error(f"Unsuccessful callback {callback} for task {client_task.task.name}")
+ self.error(f"Unsuccessful callback {callback} for task {client_task.task.name}: {e}")
else:
self._results.append(result_model)

@@ -338,6 +339,14 @@ def load_model(self):

return model

+ def get_run_dir(self):
+     """Get current run directory."""
+     return self.engine.get_workspace().get_run_dir(self.fl_ctx.get_job_id())
+
+ def get_app_dir(self):
+     """Get current app directory."""
+     return self.engine.get_workspace().get_app_dir(self.fl_ctx.get_job_id())

def save_model(self, model):
if self.persistor:
self.info("Start persist model on server.")
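
A hedged sketch of how a workflow could use the new pieces together: the run/app directory helpers, plus the task meta that _prepare_task now stores in the task props and _process_result copies into each result's meta["props"]. MyController and the marker file are hypothetical; the sketch assumes the standard ModelController API (run, load_model, send_model_and_wait, save_model):

import os

from nvflare.app_common.workflows.model_controller import ModelController

class MyController(ModelController):
    """Hypothetical workflow exercising the helpers added in this change."""

    def run(self):
        model = self.load_model()

        # New helpers: resolve the job's run/app directories directly.
        run_dir = self.get_run_dir()
        self.info(f"run_dir={run_dir}, app_dir={self.get_app_dir()}")

        # data.meta rides along in the task props and comes back on
        # every result as meta["props"], next to meta["client_name"].
        results = self.send_model_and_wait(data=model)
        for result in results:
            self.info(f"{result.meta['client_name']}: props={result.meta['props']}")

        # e.g. keep a small marker file alongside other run artifacts
        with open(os.path.join(run_dir, "round_done.txt"), "w") as f:
            f.write("done")

        self.save_model(model)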
(4 more changed files not shown.)
