diff --git a/examples/advanced/job-level-authorization/setup.sh b/examples/advanced/job-level-authorization/setup.sh index a9c9a71f8a..cdf698cc43 100755 --- a/examples/advanced/job-level-authorization/setup.sh +++ b/examples/advanced/job-level-authorization/setup.sh @@ -2,6 +2,7 @@ DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )" cd $DIR rm -rf workspace +nvflare config -pw /tmp/nvflare/poc nvflare poc prepare -i project.yml -c site_a WORKSPACE="/tmp/nvflare/poc/job-level-authorization/prod_00" cp -r security/site_a/* $WORKSPACE/site_a/local diff --git a/nvflare/apis/utils/reliable_message.py b/nvflare/apis/utils/reliable_message.py index c046d459bc..71f7365847 100644 --- a/nvflare/apis/utils/reliable_message.py +++ b/nvflare/apis/utils/reliable_message.py @@ -111,7 +111,7 @@ def process(self, request: Shareable, fl_ctx: FLContext) -> Shareable: self.tx_timeout = request.get_header(HEADER_TX_TIMEOUT) # start processing - ReliableMessage.info(fl_ctx, f"started processing request of topic {self.topic}") + ReliableMessage.debug(fl_ctx, f"started processing request of topic {self.topic}") self.executor.submit(self._do_request, request, fl_ctx) return _status_reply(STATUS_IN_PROCESS) # ack elif self.result: @@ -143,14 +143,14 @@ def process(self, request: Shareable, fl_ctx: FLContext) -> Shareable: ReliableMessage.error(fl_ctx, f"aborting processing since exceeded max tx time {self.tx_timeout}") return _status_reply(STATUS_ABORTED) else: - ReliableMessage.info(fl_ctx, "got query: request is in-process") + ReliableMessage.debug(fl_ctx, "got query: request is in-process") return _status_reply(STATUS_IN_PROCESS) def _try_reply(self, fl_ctx: FLContext): engine = fl_ctx.get_engine() self.replying = True start_time = time.time() - ReliableMessage.info(fl_ctx, f"try to send reply back to {self.source}: {self.per_msg_timeout=}") + ReliableMessage.debug(fl_ctx, f"try to send reply back to {self.source}: {self.per_msg_timeout=}") ack = engine.send_aux_request( targets=[self.source], topic=TOPIC_RELIABLE_REPLY, @@ -175,7 +175,7 @@ def _try_reply(self, fl_ctx: FLContext): def _do_request(self, request: Shareable, fl_ctx: FLContext): start_time = time.time() - ReliableMessage.info(fl_ctx, "invoking request handler") + ReliableMessage.debug(fl_ctx, "invoking request handler") try: result = self.request_handler_f(self.topic, request, fl_ctx) except Exception as e: @@ -187,7 +187,7 @@ def _do_request(self, request: Shareable, fl_ctx: FLContext): result.set_header(HEADER_OP, OP_REPLY) result.set_header(HEADER_TOPIC, self.topic) self.result = result - ReliableMessage.info(fl_ctx, f"finished request handler in {time.time()-start_time} secs") + ReliableMessage.debug(fl_ctx, f"finished request handler in {time.time()-start_time} secs") self._try_reply(fl_ctx) @@ -277,12 +277,14 @@ def _receive_request(cls, topic: str, request: Shareable, fl_ctx: FLContext): cls.error(fl_ctx, f"no handler registered for request {rm_topic=}") return make_reply(ReturnCode.TOPIC_UNKNOWN) receiver = cls._get_or_create_receiver(rm_topic, request, handler_f) - cls.info(fl_ctx, f"received request {rm_topic=}") + cls.debug(fl_ctx, f"received request {rm_topic=}") return receiver.process(request, fl_ctx) elif op == OP_QUERY: receiver = cls._req_receivers.get(tx_id) if not receiver: - cls.error(fl_ctx, f"received query but the request ({rm_topic=}) is not received or already done!") + cls.warning( + fl_ctx, f"received query but the request ({rm_topic=} {tx_id=}) is not received or already done!" + ) return _status_reply(STATUS_NOT_RECEIVED) # meaning the request wasn't received else: return receiver.process(request, fl_ctx) @@ -300,7 +302,7 @@ def _receive_reply(cls, topic: str, request: Shareable, fl_ctx: FLContext): cls.error(fl_ctx, "received reply but we are no longer waiting for it") else: assert isinstance(receiver, _ReplyReceiver) - cls.info(fl_ctx, f"received reply in {time.time()-receiver.tx_start_time} secs - set waiter") + cls.debug(fl_ctx, f"received reply in {time.time()-receiver.tx_start_time} secs - set waiter") receiver.process(request) return make_reply(ReturnCode.OK) @@ -415,6 +417,10 @@ def _log_msg(cls, fl_ctx: FLContext, msg: str): def info(cls, fl_ctx: FLContext, msg: str): cls._logger.info(cls._log_msg(fl_ctx, msg)) + @classmethod + def warning(cls, fl_ctx: FLContext, msg: str): + cls._logger.warning(cls._log_msg(fl_ctx, msg)) + @classmethod def error(cls, fl_ctx: FLContext, msg: str): cls._logger.error(cls._log_msg(fl_ctx, msg)) @@ -511,7 +517,7 @@ def _send_request( return make_reply(ReturnCode.COMMUNICATION_ERROR) if num_tries > 0: - cls.info(fl_ctx, f"retry #{num_tries} sending request: {per_msg_timeout=}") + cls.debug(fl_ctx, f"retry #{num_tries} sending request: {per_msg_timeout=}") ack = engine.send_aux_request( targets=[target], @@ -528,7 +534,7 @@ def _send_request( # the reply is already the result - we are done! # this could happen when we didn't get positive ack for our first request, and the result was # already produced when we did the 2nd request (this request). - cls.info(fl_ctx, f"C1: received result in {time.time()-receiver.tx_start_time} seconds; {rc=}") + cls.debug(fl_ctx, f"C1: received result in {time.time()-receiver.tx_start_time} seconds; {rc=}") return ack # the ack is a status report - check status @@ -536,7 +542,7 @@ def _send_request( if status and status != STATUS_NOT_RECEIVED: # status should never be STATUS_NOT_RECEIVED, unless there is a bug in the receiving logic # STATUS_NOT_RECEIVED is only possible during "query" phase. - cls.info(fl_ctx, f"received status ack: {rc=} {status=}") + cls.debug(fl_ctx, f"received status ack: {rc=} {status=}") break if time.time() + cls._query_interval - receiver.tx_start_time >= tx_timeout: @@ -544,7 +550,7 @@ def _send_request( return make_reply(ReturnCode.COMMUNICATION_ERROR) # we didn't get a positive ack - wait a short time and re-send the request. - cls.info(fl_ctx, f"unsure the request was received ({rc=}): will retry in {cls._query_interval} secs") + cls.debug(fl_ctx, f"unsure the request was received ({rc=}): will retry in {cls._query_interval} secs") num_tries += 1 start = time.time() while time.time() - start < cls._query_interval: @@ -553,7 +559,7 @@ def _send_request( return make_reply(ReturnCode.TASK_ABORTED) time.sleep(0.1) - cls.info(fl_ctx, "request was received by the peer - will query for result") + cls.debug(fl_ctx, "request was received by the peer - will query for result") return cls._query_result(target, abort_signal, fl_ctx, receiver) @classmethod @@ -585,7 +591,7 @@ def _query_result( # we already received result sent by the target. # Note that we don't wait forever here - we only wait for _query_interval, so we could # check other condition and/or send query to ask for result. - cls.info(fl_ctx, f"C2: received result in {time.time()-receiver.tx_start_time} seconds") + cls.debug(fl_ctx, f"C2: received result in {time.time()-receiver.tx_start_time} seconds") return receiver.result if abort_signal and abort_signal.triggered: @@ -599,7 +605,7 @@ def _query_result( # send a query. The ack of the query could be the result itself, or a status report. # Note: the ack could be the result because we failed to receive the result sent by the target earlier. num_tries += 1 - cls.info(fl_ctx, f"query #{num_tries}: try to get result from {target}: {per_msg_timeout=}") + cls.debug(fl_ctx, f"query #{num_tries}: try to get result from {target}: {per_msg_timeout=}") ack = engine.send_aux_request( targets=[target], topic=TOPIC_RELIABLE_REQUEST, @@ -607,13 +613,18 @@ def _query_result( timeout=per_msg_timeout, fl_ctx=fl_ctx, ) + + # Ignore query result if result is already received + if receiver.result_ready.is_set(): + return receiver.result + last_query_time = time.time() ack, rc = _extract_result(ack, target) if ack and rc not in [ReturnCode.COMMUNICATION_ERROR]: op = ack.get_header(HEADER_OP) if op == OP_REPLY: # the ack is result itself! - cls.info(fl_ctx, f"C3: received result in {time.time()-receiver.tx_start_time} seconds") + cls.debug(fl_ctx, f"C3: received result in {time.time()-receiver.tx_start_time} seconds") return ack status = ack.get_header(HEADER_STATUS) @@ -625,6 +636,6 @@ def _query_result( cls.error(fl_ctx, f"peer {target} aborted processing!") return _error_reply(ReturnCode.EXECUTION_EXCEPTION, "Aborted") - cls.info(fl_ctx, f"will retry query in {cls._query_interval} secs: {rc=} {status=} {op=}") + cls.debug(fl_ctx, f"will retry query in {cls._query_interval} secs: {rc=} {status=} {op=}") else: - cls.info(fl_ctx, f"will retry query in {cls._query_interval} secs: {rc=}") + cls.debug(fl_ctx, f"will retry query in {cls._query_interval} secs: {rc=}") diff --git a/nvflare/app_common/utils/fl_model_utils.py b/nvflare/app_common/utils/fl_model_utils.py index ac7201e414..b2a8795e2c 100644 --- a/nvflare/app_common/utils/fl_model_utils.py +++ b/nvflare/app_common/utils/fl_model_utils.py @@ -19,6 +19,7 @@ from nvflare.apis.fl_context import FLContext from nvflare.apis.shareable import Shareable from nvflare.app_common.abstract.fl_model import FLModel, FLModelConst, MetaKey, ParamsType +from nvflare.app_common.abstract.model import ModelLearnable, ModelLearnableKey, make_model_learnable from nvflare.app_common.app_constant import AppConstants from nvflare.fuel.utils.validation_utils import check_object_type @@ -188,6 +189,18 @@ def from_dxo(dxo: DXO) -> FLModel: meta=meta, ) + @staticmethod + def to_model_learnable(fl_model: FLModel) -> ModelLearnable: + return make_model_learnable(weights=fl_model.params, meta_props=fl_model.meta) + + @staticmethod + def from_model_learnable(model_learnable: ModelLearnable) -> FLModel: + return FLModel( + params_type=ParamsType.FULL, + params=model_learnable[ModelLearnableKey.WEIGHTS], + meta=model_learnable[ModelLearnableKey.META], + ) + @staticmethod def get_meta_prop(model: FLModel, key: str, default=None): check_object_type("model", model, FLModel) diff --git a/nvflare/app_common/workflows/base_model_controller.py b/nvflare/app_common/workflows/base_model_controller.py index f0b865ac8b..6ffa7436f1 100644 --- a/nvflare/app_common/workflows/base_model_controller.py +++ b/nvflare/app_common/workflows/base_model_controller.py @@ -202,7 +202,7 @@ def _prepare_task( name=task_name, data=data_shareable, operator=operator, - props={AppConstants.TASK_PROP_CALLBACK: callback}, + props={AppConstants.TASK_PROP_CALLBACK: callback, AppConstants.META_DATA: data.meta}, timeout=timeout, before_task_sent_cb=self._prepare_task_data, result_received_cb=self._process_result, @@ -221,6 +221,7 @@ def _process_result(self, client_task: ClientTask, fl_ctx: FLContext) -> None: # Turn result into FLModel result_model = FLModelUtils.from_shareable(result) + result_model.meta["props"] = client_task.task.props[AppConstants.META_DATA] result_model.meta["client_name"] = client_name if result_model.current_round is not None: @@ -235,7 +236,7 @@ def _process_result(self, client_task: ClientTask, fl_ctx: FLContext) -> None: try: callback(result_model) except Exception as e: - self.error(f"Unsuccessful callback {callback} for task {client_task.task.name}") + self.error(f"Unsuccessful callback {callback} for task {client_task.task.name}: {e}") else: self._results.append(result_model) @@ -338,6 +339,14 @@ def load_model(self): return model + def get_run_dir(self): + """Get current run directory.""" + return self.engine.get_workspace().get_run_dir(self.fl_ctx.get_job_id()) + + def get_app_dir(self): + """Get current app directory.""" + return self.engine.get_workspace().get_app_dir(self.fl_ctx.get_job_id()) + def save_model(self, model): if self.persistor: self.info("Start persist model on server.") diff --git a/nvflare/app_common/workflows/cross_site_eval.py b/nvflare/app_common/workflows/cross_site_eval.py new file mode 100644 index 0000000000..4d569e27a9 --- /dev/null +++ b/nvflare/app_common/workflows/cross_site_eval.py @@ -0,0 +1,187 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import os +import shutil +import time + +from nvflare.app_common.abstract.fl_model import FLModel +from nvflare.app_common.app_constant import AppConstants, DefaultCheckpointFileName, ModelName +from nvflare.app_common.utils.fl_model_utils import FLModelUtils +from nvflare.fuel.utils import fobs + +from .model_controller import ModelController + + +class CrossSiteEval(ModelController): + def __init__( + self, + *args, + cross_val_dir=AppConstants.CROSS_VAL_DIR, + submit_model_timeout=600, + validation_timeout: int = 6000, + server_models=[DefaultCheckpointFileName.GLOBAL_MODEL], + participating_clients=None, + num_clients: int = 2, + **kwargs, + ): + super().__init__(*args, **kwargs) + + self._cross_val_dir = cross_val_dir + self._submit_model_timeout = submit_model_timeout + self._validation_timeout = validation_timeout + self._server_models = server_models + self._participating_clients = participating_clients + self.num_clients = num_clients + + self._val_results = {} + self._client_models = {} + + self._cross_val_models_dir = None + self._cross_val_results_dir = None + + self._results_dir = AppConstants.CROSS_VAL_DIR + self._json_val_results = {} + self._json_file_name = "cross_val_results.json" + + def initialize(self, fl_ctx): + super().initialize(fl_ctx) + + # Create shareable dirs for models and results + cross_val_path = os.path.join(self.get_run_dir(), self._cross_val_dir) + self._cross_val_models_dir = os.path.join(cross_val_path, AppConstants.CROSS_VAL_MODEL_DIR_NAME) + self._cross_val_results_dir = os.path.join(cross_val_path, AppConstants.CROSS_VAL_RESULTS_DIR_NAME) + + # Cleanup/create the cross val models and results directories + if os.path.exists(self._cross_val_models_dir): + shutil.rmtree(self._cross_val_models_dir) + if os.path.exists(self._cross_val_results_dir): + shutil.rmtree(self._cross_val_results_dir) + + # Recreate new directories. + os.makedirs(self._cross_val_models_dir) + os.makedirs(self._cross_val_results_dir) + + self._participating_clients = self.sample_clients(self.num_clients) + + for c_name in self._participating_clients: + self._client_models[c_name.name] = None + self._val_results[c_name.name] = {} + + def run(self) -> None: + self.info("Start Cross-Site Evaluation.") + + data = FLModel(params={}) + data.meta[AppConstants.SUBMIT_MODEL_NAME] = ModelName.BEST_MODEL + # Create submit_model task and broadcast to all participating clients + self.send_model( + task_name=AppConstants.TASK_SUBMIT_MODEL, + data=data, + targets=self._participating_clients, + timeout=self._submit_model_timeout, + callback=self._receive_local_model_cb, + ) + + # Obtain server models and send to clients for validation + if self.persistor: + for server_model_name in self._server_models: + server_model_path = os.path.join(self.get_app_dir(), server_model_name) + server_model_learnable = self.persistor.get_model_from_location(server_model_path, self.fl_ctx) + server_model = FLModelUtils.from_model_learnable(server_model_learnable) + self._send_validation_task(server_model_name, server_model) + else: + for server_model_name in self._server_models: + try: + server_model = fobs.loadf(server_model_name) + self._send_validation_task(server_model_name, server_model) + except Exception as e: + self.exception(f"Unable to load server model {server_model_name}: {e}") + + # Wait for all standing tasks to complete, since we used non-blocking `send_model()` + while self.get_num_standing_tasks(): + if self.abort_signal.triggered: + self.info("Abort signal triggered. Finishing cross site validation.") + return + self.debug("Checking standing tasks to see if cross site validation finished.") + time.sleep(self._task_check_period) + + self.save_results() + + def _receive_local_model_cb(self, model: FLModel): + client_name = model.meta["client_name"] + + save_path = os.path.join(self._cross_val_models_dir, client_name) + fobs.dumpf(model, save_path) + + self.info(f"Saved client model {client_name} to {save_path}") + self._client_models[client_name] = save_path + + # Send this model to all clients to validate + self._send_validation_task(client_name, model) + + def _send_validation_task(self, model_name: str, model: FLModel): + self.info(f"Sending {model_name} model to all participating clients for validation.") + # Create validation task and broadcast to all participating clients. + model.meta[AppConstants.MODEL_OWNER] = model_name + + self.send_model( + task_name=AppConstants.TASK_VALIDATION, + data=model, + targets=self._participating_clients, + timeout=self._validation_timeout, + callback=self._receive_val_result_cb, + ) + + def _receive_val_result_cb(self, model: FLModel): + client_name = model.meta["client_name"] + model_owner = model.meta["props"].get(AppConstants.MODEL_OWNER, None) + + self.track_results(model_owner, client_name, model) + + file_path = os.path.join(self._cross_val_models_dir, client_name + "_" + model_owner) + fobs.dumpf(model, file_path) + + client_results = self._val_results.get(client_name, None) + if not client_results: + client_results = {} + self._val_results[client_name] = client_results + client_results[model_owner] = file_path + self.info(f"Saved validation result from client '{client_name}' on model '{model_owner}' in {file_path}") + + def track_results(self, model_owner, data_client, val_results: FLModel): + if not model_owner: + self.error("model_owner unknown. Validation result will not be saved to json") + if not data_client: + self.error("data_client unknown. Validation result will not be saved to json") + + if val_results: + try: + if data_client not in self._json_val_results: + self._json_val_results[data_client] = {} + self._json_val_results[data_client][model_owner] = val_results.metrics + + except Exception: + self.exception("Exception in handling validation result.") + else: + self.error("Validation result not found.", fire_event=False) + + def save_results(self): + cross_val_res_dir = os.path.join(self.get_run_dir(), self._results_dir) + if not os.path.exists(cross_val_res_dir): + os.makedirs(cross_val_res_dir) + + res_file_path = os.path.join(cross_val_res_dir, self._json_file_name) + with open(res_file_path, "w") as f: + json.dump(self._json_val_results, f) diff --git a/nvflare/cli.py b/nvflare/cli.py index 46491cb130..bf2c7800c6 100644 --- a/nvflare/cli.py +++ b/nvflare/cli.py @@ -31,6 +31,7 @@ create_poc_workspace_config, create_startup_kit_config, get_hidden_config, + print_hidden_config, save_config, ) @@ -116,14 +117,18 @@ def def_config_parser(sub_cmd): def handle_config_cmd(args): config_file_path, nvflare_config = get_hidden_config() - if not args.job_templates_dir or not os.path.isdir(args.job_templates_dir): - raise ValueError(f"job_templates_dir='{args.job_templates_dir}', it is not a directory") + if args.startup_kit_dir is None and args.poc_workspace_dir is None and args.job_templates_dir is None: + print(f"not specifying any directory. print existing config at {config_file_path}") + print_hidden_config(config_file_path, nvflare_config) + return nvflare_config = create_startup_kit_config(nvflare_config, args.startup_kit_dir) nvflare_config = create_poc_workspace_config(nvflare_config, args.poc_workspace_dir) nvflare_config = create_job_template_config(nvflare_config, args.job_templates_dir) save_config(nvflare_config, config_file_path) + print(f"new config at {config_file_path}") + print_hidden_config(config_file_path, nvflare_config) def parse_args(prog_name: str): diff --git a/nvflare/tool/poc/poc_commands.py b/nvflare/tool/poc/poc_commands.py index d1fcc5b274..f38394e7ab 100644 --- a/nvflare/tool/poc/poc_commands.py +++ b/nvflare/tool/poc/poc_commands.py @@ -13,7 +13,6 @@ # limitations under the License. import json import os -import pathlib import random import shutil import socket @@ -36,7 +35,7 @@ from nvflare.lighter.utils import load_yaml, update_project_server_name_config, update_storage_locations from nvflare.tool.api_utils import shutdown_system from nvflare.tool.poc.service_constants import FlareServiceConstants as SC -from nvflare.utils.cli_utils import hocon_to_string +from nvflare.utils.cli_utils import get_hidden_nvflare_config_path, get_or_create_hidden_nvflare_dir, hocon_to_string DEFAULT_WORKSPACE = "/tmp/nvflare/poc" DEFAULT_PROJECT_NAME = "example_project" @@ -396,7 +395,7 @@ def prepare_clients(clients, number_of_clients): def save_startup_kit_dir_config(workspace, project_name): - dst = get_hidden_nvflare_config_path() + dst = get_or_create_hidden_nvflare_config_path() config = None if os.path.isfile(dst): try: @@ -485,27 +484,17 @@ def _prepare_poc( return True -def get_home_dir(): - return Path.home() - - -def get_hidden_nvflare_config_path() -> str: +def get_or_create_hidden_nvflare_config_path() -> str: """ Get the path for the hidden nvflare configuration file. Returns: str: The path to the hidden nvflare configuration file. """ - home_dir = get_home_dir() - hidden_nvflare_dir = pathlib.Path(home_dir) / ".nvflare" - - try: - hidden_nvflare_dir.mkdir(exist_ok=True) - except OSError as e: - raise RuntimeError(f"Error creating the hidden nvflare directory: {e}") + hidden_nvflare_dir = get_or_create_hidden_nvflare_dir() - hidden_nvflare_config_file = hidden_nvflare_dir / "config.conf" - return str(hidden_nvflare_config_file) + hidden_nvflare_config_file = get_hidden_nvflare_config_path(str(hidden_nvflare_dir)) + return hidden_nvflare_config_file def prepare_poc_provision( @@ -1077,7 +1066,7 @@ def get_poc_workspace(): poc_workspace = os.getenv("NVFLARE_POC_WORKSPACE") if not poc_workspace: - src_path = get_hidden_nvflare_config_path() + src_path = get_or_create_hidden_nvflare_config_path() if os.path.isfile(src_path): from pyhocon import ConfigFactory as CF diff --git a/nvflare/utils/cli_utils.py b/nvflare/utils/cli_utils.py index 3c3c2c412f..8fb72d9fab 100644 --- a/nvflare/utils/cli_utils.py +++ b/nvflare/utils/cli_utils.py @@ -40,7 +40,7 @@ def get_hidden_nvflare_config_path(hidden_nvflare_dir: str) -> str: return str(hidden_nvflare_config_file) -def create_hidden_nvflare_dir(): +def get_or_create_hidden_nvflare_dir(): hidden_nvflare_dir = get_hidden_nvflare_dir() if not hidden_nvflare_dir.exists(): try: @@ -70,7 +70,7 @@ def find_startup_kit_location() -> str: def load_hidden_config() -> ConfigTree: - hidden_dir = create_hidden_nvflare_dir() + hidden_dir = get_or_create_hidden_nvflare_dir() hidden_nvflare_config_file = get_hidden_nvflare_config_path(str(hidden_dir)) nvflare_config = load_config(hidden_nvflare_config_file) return nvflare_config @@ -139,6 +139,7 @@ def create_job_template_config(nvflare_config: ConfigTree, job_templates_dir: Op return nvflare_config job_templates_dir = os.path.abspath(job_templates_dir) + check_dir(job_templates_dir) conf_str = f""" job_template {{ path = {job_templates_dir} @@ -243,7 +244,7 @@ def save_configs(app_configs: Dict[str, Tuple], keep_origin_format: bool = True) save_config(dst_config, dst_path, keep_origin_format) -def save_config(dst_config, dst_path, keep_origin_format: bool = True): +def save_config(dst_config: ConfigTree, dst_path, keep_origin_format: bool = True): if dst_path is None or dst_path.rindex(".") == -1: raise ValueError(f"configuration file path '{dst_path}' can't be None or has no extension") @@ -274,13 +275,20 @@ def save_config(dst_config, dst_path, keep_origin_format: bool = True): os.remove(dst_path) -def get_hidden_config(): - hidden_nvflare_config_file = get_hidden_nvflare_config_path(str(create_hidden_nvflare_dir())) +def get_hidden_config() -> (str, ConfigTree): + hidden_nvflare_config_file = get_hidden_nvflare_config_path(str(get_or_create_hidden_nvflare_dir())) conf = load_hidden_config() nvflare_config = CF.parse_string("{}") if not conf else conf return hidden_nvflare_config_file, nvflare_config +def print_hidden_config(dst_path: str, dst_config: ConfigTree): + original_ext = os.path.basename(dst_path).split(".")[1] + fmt = ConfigFormat.config_ext_formats().get(f".{original_ext}", None) + config_str = hocon_to_string(fmt, dst_config) + print(config_str) + + def find_in_list(arr: List, item) -> bool: if arr is None: return False