diff --git a/docs/getting_started.rst b/docs/getting_started.rst index cee26e8f9e..171a5e6c8c 100644 --- a/docs/getting_started.rst +++ b/docs/getting_started.rst @@ -304,7 +304,7 @@ we can install these in the Python virtual environment by running: .. code-block:: shell source nvflare-env/bin/activate - python3 -m pip install -r simulator-example/requirements.txt + python3 -m pip install -r simulator-example/hello-pt/requirements.txt If using the Dockerfile above to run in a container, these dependencies have already been installed. diff --git a/docs/programming_guide/controllers/model_controller.rst b/docs/programming_guide/controllers/model_controller.rst new file mode 100644 index 0000000000..26215dad90 --- /dev/null +++ b/docs/programming_guide/controllers/model_controller.rst @@ -0,0 +1,225 @@ +.. _model_controller: + +################### +ModelController API +################### + +The FLARE :mod:`ModelController` API provides an easy way for users to write and customize FLModel-based controller workflows. + +* Highly flexible with a simple API (run routine and basic communication and utility functions) +* :ref:`fl_model` for the communication data structure, everything else is pure Python +* Option to support pre-existing components and FLARE-specific functionalities + +.. note:: + + The ModelController API is a high-level API meant to simplify writing workflows. + If users prefer or need the full flexibility of the Controller with all the capabilities of FLARE functions, refer to the :ref:`controllers`. + + +Core Concepts +============= + +As an example, we can take a look at the popular federated learning workflow, "FedAvg", which has the following steps: + +#. FL server initializes an initial model +#. For each round (global iteration): + + #. FL server sends the global model to clients +  #. Each FL client starts with this global model and trains on their own data + #. Each FL client sends back their trained model + #. FL server aggregates all the models and produces a new global model + + +To implement this workflow using the ModelController, there are a few essential parts: + +* Import and subclass the :class:`nvflare.app_common.workflows.model_controller.ModelController`. +* Implement the ``run()`` routine for the workflow logic. +* Utilize ``send_model()`` / ``send_model_and_wait()`` for communication to send tasks with FLModel to target clients, and receive FLModel results. +* Customize the workflow using predefined utility functions and components, or implement your own logic. + + +Here is an example of the FedAvg workflow using the :class:`BaseFedAvg` base class: + +.. 
code-block:: python + + # BaseFedAvg subclasses ModelController and defines common functions and variables such as aggregate(), update_model(), self.start_round, self.num_rounds + class FedAvg(BaseFedAvg): + + # run routine that user must implement + def run(self) -> None: + self.info("Start FedAvg.") + + # load model (by default uses persistor, can provide custom method) + model = self.load_model() + model.start_round = self.start_round + model.total_rounds = self.num_rounds + + # for each round (global iteration) + for self.current_round in range(self.start_round, self.start_round + self.num_rounds): + self.info(f"Round {self.current_round} started.") + model.current_round = self.current_round + + # obtain self.num_clients clients + clients = self.sample_clients(self.num_clients) + + # send model to target clients with default train task, wait to receive results + results = self.send_model_and_wait(targets=clients, data=model) + + # use BaseFedAvg aggregate function + aggregate_results = self.aggregate( + results, aggregate_fn=self.aggregate_fn + ) # using default aggregate_fn with `WeightedAggregationHelper`. Can overwrite self.aggregate_fn with signature Callable[List[FLModel], FLModel] + + # update global model with aggregation results + model = self.update_model(model, aggregate_results) + + # save model (by default uses persistor, can provide custom method) + self.save_model(model) + + self.info("Finished FedAvg.") + + +Below is a comprehensive table overview of the :class:`ModelController` API: + + +.. list-table:: ModelController API + :widths: 25 35 50 + :header-rows: 1 + + * - API + - Description + - API Doc Link + * - run + - Run routine for workflow. + - :func:`run` + * - send_model_and_wait + - Send a task with data to targets (blocking) and wait for results. + - :func:`send_model_and_wait` + * - send_model + - Send a task with data to targets (non-blocking) with callback. + - :func:`send_model` + * - sample_clients + - Returns a list of num_clients clients. + - :func:`sample_clients` + * - save_model + - Save model with persistor. + - :func:`save_model` + * - load_model + - Load model from persistor. + - :func:`load_model` + + +Communication +============= + +The ModelController uses task-based communication where tasks are sent to targets, and targets execute the tasks and return results. +The :ref:`fl_model` is a standardized data structure object that is sent along with each task, and :ref:`fl_model` responses are received for the results. + +.. note:: + + The :ref:`fl_model` object can be any type of data depending on the specific task. + For example, in the "train" and "validate" tasks we send the model parameters along with the task so the target clients can train and validate the model. + However, in many other tasks that do not involve sending the model (e.g. "submit_model"), the :ref:`fl_model` can contain any type of data (e.g. metadata, metrics etc.) or may not be needed at all. + + +send_model_and_wait +------------------- +:func:`send_model_and_wait` is the core communication function which enables users to send tasks to targets, and wait for responses. + +The ``data`` is an :ref:`fl_model` object, and the ``task_name`` is the task for the target executors to execute (Client API executors by default support "train", "validate", and "submit_model", however executors can be written for any arbitrary task name). + +``targets`` can be chosen from client names obtained with ``sample_clients()``. 
+ +Returns the :ref:`fl_model` responses from the target clients once the task is completed (``min_responses`` have been received, or ``timeout`` time has passed). + +send_model +---------- +:func:`send_model` is the non-blocking version of +:func:`send_model_and_wait` with a user-defined callback when receiving responses. + +A callback with the signature ``Callable[[FLModel], None]`` can be passed in, which will be called when a response is received from each target. + +The task is standing until either ``min_responses`` have been received, or ``timeout`` time has passed. +Since this call is asynchronous, the Controller :func:`get_num_standing_tasks` method can be used to get the number of standing tasks for synchronization purposes. + + +Saving & Loading +================ + +persistor +--------- +The :func:`save_model` and :func:`load_model` +functions utilize the configured :class:`ModelPersistor` set in the ModelController ``persistor_id: str = "persistor"`` init argument. + +custom save & load +------------------ +Users can also choose to instead create their own custom save and load functions rather than use a persistor. + +For example we can use PyTorch's save and load functions for the model parameters, and save the FLModel metadata with :mod:`FOBS` separately to different filepaths. + +.. code-block:: python + + import torch + from nvflare.fuel.utils import fobs + + class MyController(ModelController): + ... + def save_model(self, model, filepath=""): + params = model.params + # PyTorch save + torch.save(params, filepath) + + # save FLModel metadata + model.params = {} + fobs.dumpf(model, filepath + ".metadata") + model.params = params + + def load_model(self, filepath=""): + # PyTorch load + params = torch.load(filepath) + + # load FLModel metadata + model = fobs.loadf(filepath + ".metadata") + model.params = params + return model + + +Note: for non-primitive data types such as ``torch.nn.Module`` (used for the initial PyTorch model), we must configure a corresponding FOBS decomposer for serialization and deserialization. +Read more at :github_nvflare_link:`Flare Object Serializer (FOBS) `. + +.. code-block:: python + + from nvflare.app_opt.pt.decomposers import TensorDecomposer + + fobs.register(TensorDecomposer) + + +Additional Functionalities +========================== + +In some cases, more advanced FLARE-specific functionalities may be of use. + +The :mod:`BaseModelController` class provides access to the engine ``self.engine`` and FLContext ``self.fl_ctx`` if needed. +Functions such as ``get_component()`` and ``build_component()`` can be used to load or dynamically build components. + +Furthermore, the underlying :mod:`Controller` class offers additional communication functions and task related utilities. +Many of our pre-existing workflows are based on this lower-level Controller API. +For more details refer to the :ref:`controllers` section. 
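+For illustration, a minimal sketch of loading a configured component from inside a ``run()`` routine might look like the following (``"my_aggregator"`` is a hypothetical component id used only for this example, not a FLARE default):
+
+.. code-block:: python
+
+    class MyWorkflow(ModelController):
+        def run(self) -> None:
+            # look up a component configured in the job by its id
+            # ("my_aggregator" is only an illustrative id)
+            aggregator = self.get_component("my_aggregator")
+            self.info(f"Loaded aggregator component: {aggregator}")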
+ +Examples +======== + +Examples of basic workflows using the ModelController API: + +* :github_nvflare_link:`Cyclic ` +* :github_nvflare_link:`BaseFedAvg ` +* :github_nvflare_link:`FedAvg ` + +Advanced examples: + +* :github_nvflare_link:`Scaffold ` +* :github_nvflare_link:`FedOpt ` +* :github_nvflare_link:`PTFedAvgEarlyStopping ` +* :github_nvflare_link:`Kaplan-Meier ` +* :github_nvflare_link:`Logistic Regression Newton Raphson ` +* :github_nvflare_link:`FedBPT ` diff --git a/docs/programming_guide/fl_model.rst b/docs/programming_guide/fl_model.rst index 6b2a9bad07..702af3a4de 100644 --- a/docs/programming_guide/fl_model.rst +++ b/docs/programming_guide/fl_model.rst @@ -3,7 +3,7 @@ FLModel ======= -We define a standard data structure :mod:`FLModel` +We define a standard data structure :mod:`FLModel` that captures the common attributes needed for exchanging learning results. This is particularly useful when NVFlare system needs to exchange learning @@ -14,4 +14,4 @@ information from received FLModel, run local training, and put the results in a new FLModel to be sent back. For a detailed explanation of each attributes, please refer to the API doc: -:mod:`FLModel` +:mod:`FLModel` diff --git a/docs/programming_guide/workflows_and_controllers.rst b/docs/programming_guide/workflows_and_controllers.rst index 9a75c9901d..8b8bd6ce24 100644 --- a/docs/programming_guide/workflows_and_controllers.rst +++ b/docs/programming_guide/workflows_and_controllers.rst @@ -7,16 +7,49 @@ A workflow has one or more controllers, each implementing a specific coordinatio CrossSiteValidation controller implements a strategy to let every client site evaluate every other site's model. You can put together a workflow that uses any number of controllers. -We have implemented several server controlled federated learning workflows (fed-average, cyclic controller, cross-site evaluation) with the server-side :ref:`controllers `. +We provide the FLModel-based :ref:`model_controller` which provides a straightforward way for users to write controllers. +We also have the original :ref:`Controller API ` with more FLARE-specific functionalities, which many of our existing workflows are based upon. + +We have implemented several server controlled federated learning workflows (fed-average, cyclic controller, cross-site evaluation) with the server-side controllers. In these workflows, FL clients get tasks assigned by the controller, execute the tasks, and submit results back to the server. In certain cases, if the server cannot be trusted, it should not be involved in communication with sensitive information. To address this concern, NVFlare introduces Client Controlled Workflows (CCWF) to facilitate peer-to-peer communication among clients. -Please refer to the following sections for more details. + +Controllers can be configured in ``config_fed_server.json`` in the workflows section: + +.. code-block:: json + + workflows = [ + { + id = "fedavg_ctl", + name = "FedAvg", + args { + min_clients = 2, + num_rounds = 3, + persistor_id = "persistor" + } + } + ] + +To configure controllers using the JobAPI, define the controller and send it to the server. +This code will automatically generate the server configuration for the controller: + +.. code-block:: python + + controller = FedAvg( + num_clients=2, + num_rounds=3, + persistor_id = "persistor" + ) + job.to(controller, "server") + +Please refer to the following sections for more details about the different types of controllers. .. 
toctree:: :maxdepth: 3 + controllers/model_controller controllers/controllers controllers/client_controlled_workflows diff --git a/examples/getting_started/pt/fedavg_script_executor_cifar10_all.py b/examples/getting_started/pt/fedavg_script_executor_cifar10_all.py new file mode 100644 index 0000000000..af15043785 --- /dev/null +++ b/examples/getting_started/pt/fedavg_script_executor_cifar10_all.py @@ -0,0 +1,43 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from src.net import Net + +from nvflare import FedAvg, FedJob, ScriptExecutor + +if __name__ == "__main__": + n_clients = 2 + num_rounds = 2 + train_script = "src/cifar10_fl.py" + + job = FedJob(name="cifar10_fedavg") + + # Define the controller workflow and send to server + controller = FedAvg( + num_clients=n_clients, + num_rounds=num_rounds, + ) + job.to_server(controller) + + # Define the initial global model and send to server + job.to_server(Net()) + + # Send executor to all clients + executor = ScriptExecutor( + task_script_path=train_script, task_script_args="" # f"--batch_size 32 --data_path /tmp/data/site-{i}" + ) + job.to_clients(executor) + + # job.export_job("/tmp/nvflare/jobs/job_config") + job.simulator_run("/tmp/nvflare/jobs/workdir", n_clients=n_clients) diff --git a/examples/getting_started/pt/swarm_script_executor_cifar10.py b/examples/getting_started/pt/swarm_script_executor_cifar10.py index 4a04f10f40..adf59091c5 100644 --- a/examples/getting_started/pt/swarm_script_executor_cifar10.py +++ b/examples/getting_started/pt/swarm_script_executor_cifar10.py @@ -47,19 +47,22 @@ executor = ScriptExecutor(task_script_path=train_script) job.to(executor, f"site-{i}", gpu=0, tasks=["train", "validate", "submit_model"]) - client_controller = SwarmClientController() - job.to(client_controller, f"site-{i}", tasks=["swarm_*"]) - - client_controller = CrossSiteEvalClientController() - job.to(client_controller, f"site-{i}", tasks=["cse_*"]) - # In swarm learning, each client acts also as an aggregator aggregator = InTimeAccumulateWeightedAggregator(expected_data_kind=DataKind.WEIGHTS) - job.to(aggregator, f"site-{i}") # In swarm learning, each client uses a model persistor and shareable_generator - job.to(PTFileModelPersistor(model=Net()), f"site-{i}") - job.to(SimpleModelShareableGenerator(), f"site-{i}") + persistor = PTFileModelPersistor(model=Net()) + shareable_generator = SimpleModelShareableGenerator() + + client_controller = SwarmClientController( + aggregator_id=job.as_id(aggregator), + persistor_id=job.as_id(persistor), + shareable_generator_id=job.as_id(shareable_generator), + ) + job.to(client_controller, f"site-{i}", tasks=["swarm_*"]) + + client_controller = CrossSiteEvalClientController() + job.to(client_controller, f"site-{i}", tasks=["cse_*"]) # job.export_job("/tmp/nvflare/jobs/job_config") job.simulator_run("/tmp/nvflare/jobs/workdir") diff --git a/examples/getting_started/tf/README.md b/examples/getting_started/tf/README.md index 
87014c00cf..beafde808e 100644 --- a/examples/getting_started/tf/README.md +++ b/examples/getting_started/tf/README.md @@ -1,22 +1,33 @@ # Simulated Federated Learning with CIFAR10 Using Tensorflow -This example shows how to develop and run classic Federated Learning -algorithms, namely FedAvg, FedProx, FedOpt and Scaffold on CIFAR10, -using `Tensorflow` backend. This example is analogous to [the example -using `Pytorch` +This example shows `Tensorflow`-based classic Federated Learning +algorithms, namely FedAvg and FedOpt on the CIFAR10 +dataset. This example is analogous to [the example using `Pytorch` backend](https://github.com/NVIDIA/NVFlare/tree/main/examples/advanced/cifar10/cifar10-sim) on the same dataset, where same experiments -will be conducted and analyzed. You should expect the same +were conducted and analyzed. You should expect the same experimental results when comparing this example with the `Pytorch` one. -In this example, client-side training logics are implemented using the -new `Script Executor` APIs, which alleviate the need of job -config files, simplifying client-side development. - +In this example, the latest Client APIs were used to implement +client-side training logic (details in file +[`cifar10_tf_fl_alpha_split.py`](src/cifar10_tf_fl_alpha_split.py)), +and the new +[`FedJob`](https://github.com/NVIDIA/NVFlare/blob/main/nvflare/job_config/fed_job.py#L106) +APIs were used to programmatically set up an +`nvflare` job to be exported or run by the simulator (details in file +[`tf_fl_script_executor_cifar10.py`](tf_fl_script_executor_cifar10.py)), +alleviating the need to write job config files and simplifying the +development process. + +Before continuing with the following sections, you can first refer to +the [getting started notebook](nvflare_tf_getting_started.ipynb) +included in this folder to learn more about the implementation +details, with an example walkthrough of FedAvg using a small +Tensorflow model. ## 1. Install requirements -Install required packages for training +Install required packages ``` pip install --upgrade pip pip install -r ./requirements.txt @@ -28,43 +39,62 @@ pip install -r ./requirements.txt ## 2. Run experiments -This example uses simulator to run all experiments. A script -`run_jobs.sh` is provided to run experiments described below all at -once: +This example uses the simulator to run all experiments. The script +[`tf_fl_script_executor_cifar10.py`](tf_fl_script_executor_cifar10.py) +is the main script used to launch different experiments with +different arguments (see sections below for details). A script +[`run_jobs.sh`](run_jobs.sh) is also provided to run all experiments +described below at once: ``` bash ./run_jobs.sh ``` +The CIFAR10 dataset will be downloaded when running any experiment for +the first time. `Tensorboard` summary logs will be generated during +any experiment, and you can use `Tensorboard` to visualize the +training and validation process as the experiment runs. Data split +files, summary logs and results will be saved in a workspace +directory, which defaults to `/tmp` and can be configured by setting +the `--workspace` argument of the `tf_fl_script_executor_cifar10.py` +script. 
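+For example, assuming the default `/tmp` workspace, you can point `Tensorboard` at the workspace to monitor a run (the exact sub-directory layout depends on the job being run):
+
+```
+tensorboard --logdir /tmp/nvflare/jobs
+```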
+ +> [!WARNING] +> If you are using GPU, make sure to set the following +> environment variables before running a training job, to prevent +> `Tensorflow` from allocating full GPU memory all at once: +> `export TF_FORCE_GPU_ALLOW_GROWTH=true && export +> TF_GPU_ALLOCATOR=cuda_malloc_async` The set-up of all experiments in this example are kept the same as [the example using `Pytorch` backend](https://github.com/NVIDIA/NVFlare/tree/main/examples/advanced/cifar10/cifar10-sim). Refer to the `Pytorch` example for more details. Similar to the Pytorch example, we here also use Dirichelet sampling on CIFAR10 data labels -to simulate data heterogeneity from different client sites, which is controlled by an alpha -value, ranging from 0 (not including 0) to 1. A high alpha value indicates less data -heterogeneity, while alpha = 0 here is used to disable Dirichelet -sampling, a.k.a. resulting in homogeneous data distribution among -different sites. +to simulate data heterogeneity among data splits for different client +sites, controlled by an alpha value, ranging from 0 (not including 0) +to 1. A high alpha value indicates less data heterogeneity, i.e., an +alpha value equal to 1.0 would result in homogeneous data distribution +among different splits. ### 2.1 Centralized training To simulate a centralized training baseline, we run FedAvg algorithm -with 1 client for 25 local epochs, for one single round. +with 1 client for 25 rounds, where each round consists of one single epoch. ``` python ./tf_fl_script_executor_cifar10.py \ --algo centralized \ --n_clients 1 \ - --num_rounds 1 \ + --num_rounds 25 \ --batch_size 64 \ - --epochs 25 \ + --epochs 1 \ --alpha 0.0 ``` -Note, here `--alpha 0.0` means that no heterogeneous data splits are being generated. +Note, here `--alpha 0.0` is a placeholder value used to disable data +splits for centralized training. ### 2.2 FedAvg with different data heterogeneity (alpha values) -Here we run FedAvg for 50 rounds, each with 4 local epochs. This +Here we run FedAvg for 50 rounds, each round with 4 local epochs. This corresponds roughly to the same number of iterations across clients as in the centralized baseline above (50*4 divided by 8 clients is 25): ``` @@ -81,64 +111,72 @@ for alpha in 1.0 0.5 0.3 0.1; do done ``` -### 2.3 Advanced FL algorithms (FedProx, FedOpt and SCAFFOLD) +### 2.3 Advanced FL algorithms (FedOpt) Next, let's try some different FL algorithms on a more heterogeneous split: -[FedProx](https://arxiv.org/abs/1812.06127) adds a regularizer to the loss: +[FedOpt](https://arxiv.org/abs/2003.00295) uses optimizers on server +side to update the global model from client-side gradients. 
Here we +use SGD with momentum and cosine learning rate decay: ``` python ./tf_fl_script_executor_cifar10.py \ - --algo fedprox \ + --algo fedopt \ --n_clients 8 \ --num_rounds 50 \ --batch_size 64 \ --epochs 4 \ - --fedprox_mu 1e-5 \ --alpha 0.1 ``` - -For [FedOpt](https://arxiv.org/abs/2003.00295), here we use SGD with momentum and cosine learning rate -decay to update server-side model: +[FedProx](https://arxiv.org/abs/1812.06127) adds a regularizer to the loss: ``` python ./tf_fl_script_executor_cifar10.py \ - --algo fedopt \ + --algo fedprox \ --n_clients 8 \ --num_rounds 50 \ --batch_size 64 \ --epochs 4 \ + --fedprox_mu 1e-5 \ --alpha 0.1 ``` - [SCAFFOLD](https://arxiv.org/abs/1910.06378) adds a correction term during local training following the [implementation](https://github.com/Xtra-Computing/NIID-Bench) as described in [Li et al.](https://arxiv.org/abs/2102.02079) - +``` +python ./tf_fl_script_executor_cifar10.py \ + --algo scaffold \ + --n_clients 8 \ + --num_rounds 50 \ + --batch_size 64 \ + --epochs 4 \ + --fedprox_mu 1e-5 \ + --alpha 0.1 +``` ## 3. Results -Now let's compare experimental results. For all experiments, you can -use `Tensorboard` to visualize the training and validation process as -the experiment is running. +Now let's compare experimental results. +### 3.1 Centralized training vs. FedAvg for homogeneous split +Let's first compare FedAvg with homogeneous data split +(i.e. `alpha=1.0`) and centralized training. As can be seen from the +figure and table below, FedAvg can achieve similar performance to +centralized training under homogeneous data split, i.e., when there is +no difference in data distributions among different clients. -### 3.1 Central vs. FedAvg -With a data split using `alpha=1.0`, i.e. a non-heterogeneous split, -we achieve the following final validation scores. -One can see that FedAvg can achieve similar performance to central training. - -| Config | Alpha | Val score | -| ----------- | ----------- | ----------- | -| cifar10_central | 1.0 | 0.8756 | -| cifar10_fedavg | 1.0 | 0.8525 | +| Config | Alpha | Val score | +|-----------------|-------|-----------| +| cifar10_central | n.a. | 0.8758 | +| cifar10_fedavg | 1.0 | 0.8839 | ![Central vs. FedAvg](./figs/fedavg-vs-centralized.png) ### 3.2 Impact of client data heterogeneity Here we compare the impact of data heterogeneity by varying the -`alpha` value, where lower values cause higher heterogeneity. This can -be observed in the resulting performance of the FedAvg algorithms. +`alpha` value, where lower values cause higher heterogeneity. As can +be observed in the table below, the performance of FedAvg decreases +as data heterogeneity becomes higher. | Config | Alpha | Val score | | ----------- | ----------- | ----------- | @@ -152,26 +190,20 @@ heterogeneity](./figs/fedavg-diff-alphas.png) ### 3.3 Impact of different FL algorithms -Lastly we compare the performance of different FL algorithms, with -`alpha` value fixed to 0.1, indicating a high client data -heterogeneity. We can observe from the figure below that, FedProx and +Lastly, we compare the performance of different FL algorithms, with +`alpha` value fixed to 0.1, i.e., a high client data heterogeneity. We can observe from the figure below that FedOpt and SCAFFOLD achieve better performance, with better convergence rates compared to FedAvg and FedProx with the same alpha setting. 
SCAFFOLD achieves that by adding a correction term when updating the client models, while FedOpt utilizes SGD with momentum to update the global model on the server. Both achieve better performance with the same number of training steps as FedAvg/FedProx. -If the distribution of the data is too uneven, problems can also arise with the SCAFFOLD implementation using TensorFlow. -However, these only become noticeable after too many training rounds and epochs. At a certain point, it no longer converges, and the model weights explode, causing weights and losses to become NaN. This problem only occurred with a distribution of alpha = 0.1 and is a well known [issue](https://discuss.ai.google.dev/t/failing-to-implement-scaffold-in-tensorfow/31665). -To counteract this problem, a so called [clip norm](https://www.tensorflow.org/api_docs/python/tf/keras/optimizers/Adam) was added. -This ensures that the gradient of each weight is clipped individually so that its norm does not exceed the specified value. Since this is a hyperparameter, clipnorm=1.5 was tested as the best value in our case. The specific use of the clip norm is based on this [code](https://github.com/google-research/public-data-in-dpfl/blob/7655a6b7165dc2cbdfe3d2e1080721223aa2c79b/scaffold_v2.py#L192C5-L192C14). -This allows the accuracy to converge until the end of the training and achieves an as good accuracy as FedOpt. | Config | Alpha | Val score | | ----------- | ----------- | ----------- | -| cifar10_fedavg | 0.1 | 0.7959 | +| cifar10_fedavg | 0.1 | 0.7903 | +| cifar10_fedopt | 0.1 | 0.8145 | | cifar10_fedprox | 0.1 | 0.7843 | -| cifar10_fedopt | 0.1 | 0.8059 | | cifar10_scaffold | 0.1 | 0.8164 | ![Impact of different FL algorithms](./figs/fedavg-diff-algos-new.png) diff --git a/examples/getting_started/tf/nvflare_tf_getting_started.ipynb b/examples/getting_started/tf/nvflare_tf_getting_started.ipynb index af87d0af9b..17374b9356 100644 --- a/examples/getting_started/tf/nvflare_tf_getting_started.ipynb +++ b/examples/getting_started/tf/nvflare_tf_getting_started.ipynb @@ -55,7 +55,7 @@ "outputs": [], "source": [ "! pip install --ignore-installed blinker\n", - "! pip install nvflare~=2.5.0rc tensorflow" + "! pip install -r ./requirements.txt" ] }, { @@ -410,6 +410,16 @@ "source": [ "! nvflare simulator -w /tmp/nvflare/jobs/workdir -n 2 -t 2 -gpu 0 /tmp/nvflare/jobs/job_config/cifar10_tf_fedavg" ] + }, + { + "cell_type": "markdown", + "id": "387662f4-7d05-4840-bcc7-a2523e03c2c2", + "metadata": {}, + "source": [ + "#### 8. Next steps\n", + "\n", + "Continue with the steps described in the [README.md](README.md) to run more experiments with a more complex model and more advanced FL algorithms. " + ] } ], "metadata": { @@ -428,7 +438,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.7" + "version": "3.10.12" } }, "nbformat": 4, diff --git a/examples/getting_started/tf/run_jobs.sh b/examples/getting_started/tf/run_jobs.sh index 0a9b8997f4..7195e71474 100755 --- a/examples/getting_started/tf/run_jobs.sh +++ b/examples/getting_started/tf/run_jobs.sh @@ -1,4 +1,18 @@ #!/usr/bin/env bash +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + export TF_FORCE_GPU_ALLOW_GROWTH=true export TF_GPU_ALLOCATOR=cuda_malloc_asyncp @@ -7,16 +21,20 @@ export TF_GPU_ALLOCATOR=cuda_malloc_asyncp # You can change GPU index if multiple GPUs are available GPU_INDX=0 +# You can change workspace - where results and artefact will be saved. +WORKSPACE=/tmp # Run centralized training job python ./tf_fl_script_executor_cifar10.py \ --algo centralized \ --n_clients 1 \ - --num_rounds 1 \ + --num_rounds 25 \ --batch_size 64 \ - --epochs 25 \ + --epochs 1 \ --alpha 0.0 \ - --gpu $GPU_INDX + --gpu $GPU_INDX \ + --workspace $WORKSPACE + # Run FedAvg with different alpha values for alpha in 1.0 0.5 0.3 0.1; do @@ -28,23 +46,12 @@ for alpha in 1.0 0.5 0.3 0.1; do --batch_size 64 \ --epochs 4 \ --alpha $alpha \ - --gpu $GPU_INDX + --gpu $GPU_INDX \ + --workspace $WORKSPACE done -# Run FedProx job. -python ./tf_fl_script_executor_cifar10.py \ - --algo fedprox \ - --n_clients 8 \ - --num_rounds 50 \ - --batch_size 64 \ - --epochs 4 \ - --fedprox_mu 1e-5 \ - --alpha 0.1 \ - --gpu $GPU_INDX - - # Run FedOpt job python ./tf_fl_script_executor_cifar10.py \ --algo fedopt \ @@ -53,14 +60,5 @@ python ./tf_fl_script_executor_cifar10.py \ --batch_size 64 \ --epochs 4 \ --alpha 0.1 \ - --gpu $GPU_INDX - -# Run scaffold job -python ./tf_fl_script_executor_cifar10.py \ - --algo scaffold \ - --n_clients 8 \ - --num_rounds 50 \ - --batch_size 64 \ - --epochs 4 \ - --alpha 0.1 \ - --gpu $GPU_INDX + --gpu $GPU_INDX \ + --workspace $WORKSPACE diff --git a/examples/getting_started/tf/src/cifar10_data_split.py b/examples/getting_started/tf/src/cifar10_data_split.py index 1eec679b9b..1dd05f6385 100644 --- a/examples/getting_started/tf/src/cifar10_data_split.py +++ b/examples/getting_started/tf/src/cifar10_data_split.py @@ -37,6 +37,7 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. + import json import os diff --git a/examples/getting_started/tf/src/cifar10_tf_fl.py b/examples/getting_started/tf/src/cifar10_tf_fl.py new file mode 100644 index 0000000000..5058a4025b --- /dev/null +++ b/examples/getting_started/tf/src/cifar10_tf_fl.py @@ -0,0 +1,79 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +import tensorflow as tf +from tensorflow.keras import datasets +from tf_net import TFNet + +# (1) import nvflare client API +import nvflare.client as flare + +PATH = "./tf_model.weights.h5" + + +def main(): + # (2) initializes NVFlare client API + flare.init() + + (train_images, train_labels), (test_images, test_labels) = datasets.cifar10.load_data() + + # Normalize pixel values to be between 0 and 1 + train_images, test_images = train_images / 255.0, test_images / 255.0 + + model = TFNet() + model.build(input_shape=(None, 32, 32, 3)) + model.compile( + optimizer="adam", loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True), metrics=["accuracy"] + ) + model.summary() + + # (3) gets FLModel from NVFlare + while flare.is_running(): + input_model = flare.receive() + print(f"current_round={input_model.current_round}") + + # (optional) print system info + system_info = flare.system_info() + print(f"NVFlare system info: {system_info}") + + # (4) loads model from NVFlare + for k, v in input_model.params.items(): + model.get_layer(k).set_weights(v) + + # (5) evaluate aggregated/received model + _, test_global_acc = model.evaluate(test_images, test_labels, verbose=2) + print( + f"Accuracy of the received model on round {input_model.current_round} on the {len(test_images)} test images: {test_global_acc * 100} %" + ) + + model.fit(train_images, train_labels, epochs=1, validation_data=(test_images, test_labels)) + + print("Finished Training") + + model.save_weights(PATH) + + _, test_acc = model.evaluate(test_images, test_labels, verbose=2) + print(f"Accuracy of the model on the {len(test_images)} test images: {test_acc * 100} %") + + # (6) construct trained FL model (A dict of {layer name: layer weights} from the keras model) + output_model = flare.FLModel( + params={layer.name: layer.get_weights() for layer in model.layers}, metrics={"accuracy": test_global_acc} + ) + # (7) send model back to NVFlare + flare.send(output_model) + + +if __name__ == "__main__": + main() diff --git a/examples/getting_started/tf/src/cifar10_tf_fl_alpha_split.py b/examples/getting_started/tf/src/cifar10_tf_fl_alpha_split.py index be81af4620..fbad375336 100644 --- a/examples/getting_started/tf/src/cifar10_tf_fl_alpha_split.py +++ b/examples/getting_started/tf/src/cifar10_tf_fl_alpha_split.py @@ -1,4 +1,4 @@ -# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,7 +14,6 @@ import argparse -import copy import numpy as np import tensorflow as tf @@ -23,10 +22,6 @@ # (1) import nvflare client API import nvflare.client as flare -from nvflare.app_opt.tf.fedprox_loss import TFFedProxLoss - -# (optional) metrics -from nvflare.client.tracking import SummaryWriter PATH = "./tf_model.weights.h5" @@ -65,7 +60,7 @@ def preprocess_dataset(dataset, is_training, batch_size=1): Tensorflow Dataset with pre-processings applied. 
""" - # Values from: https://github.com/NVIDIA/NVFlare/blob/fc2bc47889b980c8de37de5528e3d07e6b1a942e/examples/advanced/cifar10/pt/learners/cifar10_model_learner.py#L147 + # Values from: https://github.com/NVIDIA/NVFlare/blob/main/examples/advanced/cifar10/pt/learners/cifar10_model_learner.py#L147 mean_cifar10 = tf.constant([125.3, 123.0, 113.9], dtype=tf.float32) std_cifar10 = tf.constant([63.0, 62.1, 66.7], dtype=tf.float32) @@ -108,21 +103,13 @@ def main(): parser.add_argument("--batch_size", type=int, required=True) parser.add_argument("--epochs", type=int, required=True) parser.add_argument("--train_idx_path", type=str, required=True) - parser.add_argument( - "--fedprox_mu", - type=float, - default=0.0, - ) args = parser.parse_args() - # (2) initializes NVFlare client API - flare.init() - (train_images, train_labels), (test_images, test_labels) = datasets.cifar10.load_data() # Use alpha-split per-site data to simulate data heteogeniety, # only if if train_idx_path is not None. - + # if args.train_idx_path != "None": print(f"Loading train indices from {args.train_idx_path}") @@ -149,10 +136,12 @@ def main(): model = ModerateTFNet() model.build(input_shape=(None, 32, 32, 3)) - callbacks = [tf.keras.callbacks.TensorBoard(log_dir="./logs_keras", write_graph=False)] - - # Control whether FedProx is used. + # Tensorboard logs for each local training epoch + callbacks = [tf.keras.callbacks.TensorBoard(log_dir="./logs/epochs", write_graph=False)] + # Tensorboard logs for each aggregation run + tf_summary_writer = tf.summary.create_file_writer(logdir="./logs/rounds") + # Define loss function. loss = losses.SparseCategoricalCrossentropy(from_logits=True) model.compile(optimizer=tf.keras.optimizers.SGD(learning_rate=0.01, momentum=0.9), loss=loss, metrics=["accuracy"]) @@ -161,8 +150,6 @@ def main(): # (2) initializes NVFlare client API flare.init() - summary_writer = SummaryWriter() - tf_summary_writer = tf.summary.create_file_writer(logdir="./logs/validation") while flare.is_running(): # (3) receives FLModel from NVFlare input_model = flare.receive() @@ -176,20 +163,8 @@ def main(): for k, v in input_model.params.items(): model.get_layer(k).set_weights(v) - if args.fedprox_mu > 0: - - local_model_weights = model.trainable_variables - global_model_weights = copy.deepcopy(model.trainable_variables) - model.loss = TFFedProxLoss(local_model_weights, global_model_weights, args.fedprox_mu, loss) - elif args.fedprox_mu < 0.0: - - raise ValueError("mu should be no less than 0.0") - # (5) evaluate aggregated/received model _, test_global_acc = model.evaluate(x=test_ds, verbose=2) - summary_writer.add_scalar( - tag="global_model_accuracy", scalar=test_global_acc, global_step=input_model.current_round - ) with tf_summary_writer.as_default(): tf.summary.scalar("global_model_accuracy", test_global_acc, input_model.current_round) @@ -207,7 +182,7 @@ def main(): validation_data=test_ds, callbacks=callbacks, initial_epoch=start_epoch, - validation_freq=1, # args.epochs + validation_freq=1, ) print("Finished Training") @@ -216,8 +191,6 @@ def main(): _, test_acc = model.evaluate(x=test_ds, verbose=2) - summary_writer.add_scalar(tag="local_model_accuracy", scalar=test_acc, global_step=input_model.current_round) - with tf_summary_writer.as_default(): tf.summary.scalar("local_model_accuracy", test_acc, input_model.current_round) print(f"Accuracy of the model on the {len(test_images)} test images: {test_acc * 100} %") diff --git a/examples/getting_started/tf/src/cifar10_tf_fl_alpha_split_scaffold.py 
b/examples/getting_started/tf/src/cifar10_tf_fl_alpha_split_scaffold.py index 860f472973..24fba56990 100644 --- a/examples/getting_started/tf/src/cifar10_tf_fl_alpha_split_scaffold.py +++ b/examples/getting_started/tf/src/cifar10_tf_fl_alpha_split_scaffold.py @@ -14,6 +14,7 @@ import argparse +import copy import numpy as np import tensorflow as tf @@ -25,6 +26,7 @@ from nvflare.app_common.app_constant import AlgorithmConstants from nvflare.app_opt.tf.scaffold import ScaffoldCallback, TFScaffoldHelper, get_lr_values from nvflare.client.tracking import SummaryWriter +from nvflare.app_opt.tf.fedprox_loss import TFFedProxLoss PATH = "./tf_model.weights.h5" @@ -129,6 +131,7 @@ def main(): parser.add_argument("--epochs", type=int, required=True) parser.add_argument("--train_idx_path", type=str, required=True) parser.add_argument("--clip_norm", type=float, default=1.55, required=False) + parser.add_argument("--fedprox_mu", type=float, default=0.0) args = parser.parse_args() @@ -195,6 +198,16 @@ def main(): for k, v in input_model.params.items(): model.get_layer(k).set_weights(v) + if args.fedprox_mu > 0: + + local_model_weights = model.trainable_variables + global_model_weights = copy.deepcopy(model.trainable_variables) + model.loss = TFFedProxLoss(local_model_weights, global_model_weights, args.fedprox_mu, loss) + elif args.fedprox_mu < 0.0: + + raise ValueError("mu should be no less than 0.0") + + # (step 4) load regularization parameters from scaffold global_ctrl_weights = input_model.meta.get(AlgorithmConstants.SCAFFOLD_CTRL_GLOBAL) diff --git a/examples/getting_started/tf/src/tf_net.py b/examples/getting_started/tf/src/tf_net.py index 86b7a7f9a4..bdd8d29147 100644 --- a/examples/getting_started/tf/src/tf_net.py +++ b/examples/getting_started/tf/src/tf_net.py @@ -1,4 +1,4 @@ -# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
diff --git a/examples/getting_started/tf/tf_fl_script_executor_cifar10.py b/examples/getting_started/tf/tf_fl_script_executor_cifar10.py index 540fef39bd..6037bf0d69 100644 --- a/examples/getting_started/tf/tf_fl_script_executor_cifar10.py +++ b/examples/getting_started/tf/tf_fl_script_executor_cifar10.py @@ -30,8 +30,10 @@ CENTRALIZED_ALGO = "centralized" FEDAVG_ALGO = "fedavg" FEDOPT_ALGO = "fedopt" -FEDPROX_ALGO = "fedprox" SCAFFOLD_ALGO = "scaffold" +FEDPROX_ALGO = "fedprox" + + if __name__ == "__main__": @@ -71,6 +73,11 @@ type=float, default=1.0, ) + parser.add_argument( + "--workspace", + type=str, + default="/tmp", + ) parser.add_argument( "--gpu", type=int, @@ -80,30 +87,22 @@ args = parser.parse_args() multiprocessing.set_start_method("spawn") - supported_algos = ( - CENTRALIZED_ALGO, - FEDAVG_ALGO, - FEDOPT_ALGO, - FEDPROX_ALGO, - SCAFFOLD_ALGO, - ) + supported_algos = (CENTRALIZED_ALGO, FEDAVG_ALGO, FEDOPT_ALGO, SCAFFOLD_ALGO, FEDPROX_ALGO) if args.algo not in supported_algos: raise ValueError(f"--algo should be one of: {supported_algos}, got: {args.algo}") train_script = "src/cifar10_tf_fl_alpha_split.py" - train_split_root = f"/tmp/cifar10_splits/clients{args.n_clients}_alpha{args.alpha}" # avoid overwriting results + train_split_root = ( + f"{args.workspace}/cifar10_splits/clients{args.n_clients}_alpha{args.alpha}" # avoid overwriting results + ) # Prepare data splits if args.alpha > 0.0: # Do alpha splitting if alpha value > 0.0 print(f"preparing CIFAR10 and doing alpha split with alpha = {args.alpha}") - train_idx_paths = cifar10_split( - num_sites=args.n_clients, - alpha=args.alpha, - split_dir=train_split_root, - ) + train_idx_paths = cifar10_split(num_sites=args.n_clients, alpha=args.alpha, split_dir=train_split_root) print(train_idx_paths) else: @@ -131,7 +130,6 @@ num_clients=args.n_clients, num_rounds=args.num_rounds, ) - elif args.algo == FEDPROX_ALGO: from nvflare import FedAvg @@ -158,11 +156,11 @@ # Add clients for i, train_idx_path in enumerate(train_idx_paths): curr_task_script_args = task_script_args + f" --train_idx_path {train_idx_path}" - executor = ScriptExecutor( - task_script_path=train_script, - task_script_args=curr_task_script_args, - ) + executor = ScriptExecutor(task_script_path=train_script, task_script_args=curr_task_script_args) job.to(executor, f"site-{i+1}", gpu=args.gpu) - # job.export_job("/tmp/nvflare/jobs/job_config") - job.simulator_run(f"/tmp/nvflare/jobs/{job.name}") + # Can export current job to folder. + # job.export_job(f"{args.workspace}/nvflare/jobs/job_config") + + # Here we launch the job using simulator. + job.simulator_run(f"{args.workspace}/nvflare/jobs/{job.name}") diff --git a/examples/hello-world/hello-flower/README.md b/examples/hello-world/hello-flower/README.md new file mode 100644 index 0000000000..a0c303a832 --- /dev/null +++ b/examples/hello-world/hello-flower/README.md @@ -0,0 +1,28 @@ +# Flower App (PyTorch) in NVIDIA FLARE + +In this example, we run 2 Flower clients and Flower Server in parallel using NVFlare's simulator. + +## Preconditions + +To run Flower code in NVFlare, we created a job, including an app with the following custom folder content +```bash +$ tree jobs/hello-flwr-pt +. +├── client.py # <-- contains `ClientApp` +├── server.py # <-- contains `ServerApp` +├── task.py # <-- task-specific code (model, data) +``` +Note, this code is directly copied from Flower's [app-pytorch](https://github.com/adap/flower/tree/main/examples/app-pytorch) example. 
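+
+On the NVFlare side, the job's client config (`app/config/config_fed_client.json`) points the `FlowerExecutor` at the Flower `ClientApp`, and the server config (`app/config/config_fed_server.json`) points the `FlowerController` at the `ServerApp`. For example, the executor entry in the client config looks like:
+
+```json
+{
+    "tasks": ["*"],
+    "executor": {
+        "path": "nvflare.app_opt.flower.executor.FlowerExecutor",
+        "args": {
+            "client_app": "client:app"
+        }
+    }
+}
+```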
+ +## Install dependencies +To run this job with NVFlare, we first need to install the dependencies. +```bash +pip install -r requirements.txt +``` + +## Run a simulation + +Next, we run 2 Flower clients and Flower Server in parallel using NVFlare's simulator. +```bash +nvflare simulator jobs/hello-flwr-pt -n 2 -t 2 -w /tmp/nvflare/flwr +``` diff --git a/examples/hello-world/hello-flower/jobs/hello-flwr-pt/app/config/config_fed_client.json b/examples/hello-world/hello-flower/jobs/hello-flwr-pt/app/config/config_fed_client.json new file mode 100644 index 0000000000..e1e74ade3f --- /dev/null +++ b/examples/hello-world/hello-flower/jobs/hello-flwr-pt/app/config/config_fed_client.json @@ -0,0 +1,17 @@ +{ + "format_version": 2, + "executors": [ + { + "tasks": ["*"], + "executor": { + "path": "nvflare.app_opt.flower.executor.FlowerExecutor", + "args": { + "client_app": "client:app" + } + } + } + ], + "task_result_filters": [], + "task_data_filters": [], + "components": [] +} \ No newline at end of file diff --git a/examples/hello-world/hello-flower/jobs/hello-flwr-pt/app/config/config_fed_server.json b/examples/hello-world/hello-flower/jobs/hello-flwr-pt/app/config/config_fed_server.json new file mode 100644 index 0000000000..dfb4cab82d --- /dev/null +++ b/examples/hello-world/hello-flower/jobs/hello-flwr-pt/app/config/config_fed_server.json @@ -0,0 +1,16 @@ +{ + "format_version": 2, + "task_data_filters": [], + "task_result_filters": [], + "components": [ + ], + "workflows": [ + { + "id": "ctl", + "path": "nvflare.app_opt.flower.controller.FlowerController", + "args": { + "server_app": "server:app" + } + } + ] +} \ No newline at end of file diff --git a/examples/hello-world/hello-flower/jobs/hello-flwr-pt/app/custom/client.py b/examples/hello-world/hello-flower/jobs/hello-flwr-pt/app/custom/client.py new file mode 100644 index 0000000000..9e674b27c3 --- /dev/null +++ b/examples/hello-world/hello-flower/jobs/hello-flwr-pt/app/custom/client.py @@ -0,0 +1,53 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from flwr.client import ClientApp, NumPyClient +from task import DEVICE, Net, get_weights, load_data, set_weights, test, train + +# Load model and data (simple CNN, CIFAR-10) +net = Net().to(DEVICE) +trainloader, testloader = load_data() + + +# Define FlowerClient and client_fn +class FlowerClient(NumPyClient): + def fit(self, parameters, config): + set_weights(net, parameters) + results = train(net, trainloader, testloader, epochs=1, device=DEVICE) + return get_weights(net), len(trainloader.dataset), results + + def evaluate(self, parameters, config): + set_weights(net, parameters) + loss, accuracy = test(net, testloader) + return loss, len(testloader.dataset), {"accuracy": accuracy} + + +def client_fn(cid: str): + """Create and return an instance of Flower `Client`.""" + return FlowerClient().to_client() + + +# Flower ClientApp +app = ClientApp( + client_fn=client_fn, +) + + +# Legacy mode +if __name__ == "__main__": + from flwr.client import start_client + + start_client( + server_address="127.0.0.1:8080", + client=FlowerClient().to_client(), + ) diff --git a/examples/hello-world/hello-flower/jobs/hello-flwr-pt/app/custom/server.py b/examples/hello-world/hello-flower/jobs/hello-flwr-pt/app/custom/server.py new file mode 100644 index 0000000000..8083a6b802 --- /dev/null +++ b/examples/hello-world/hello-flower/jobs/hello-flwr-pt/app/custom/server.py @@ -0,0 +1,75 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from typing import List, Tuple + +from flwr.common import Metrics, ndarrays_to_parameters +from flwr.server import ServerApp, ServerConfig +from flwr.server.strategy import FedAvg +from task import Net, get_weights + + +# Define metric aggregation function +def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics: + examples = [num_examples for num_examples, _ in metrics] + + # Multiply accuracy of each client by number of examples used + train_losses = [num_examples * m["train_loss"] for num_examples, m in metrics] + train_accuracies = [num_examples * m["train_accuracy"] for num_examples, m in metrics] + val_losses = [num_examples * m["val_loss"] for num_examples, m in metrics] + val_accuracies = [num_examples * m["val_accuracy"] for num_examples, m in metrics] + + # Aggregate and return custom metric (weighted average) + return { + "train_loss": sum(train_losses) / sum(examples), + "train_accuracy": sum(train_accuracies) / sum(examples), + "val_loss": sum(val_losses) / sum(examples), + "val_accuracy": sum(val_accuracies) / sum(examples), + } + + +# Initialize model parameters +ndarrays = get_weights(Net()) +parameters = ndarrays_to_parameters(ndarrays) + + +# Define strategy +strategy = FedAvg( + fraction_fit=1.0, # Select all available clients + fraction_evaluate=0.0, # Disable evaluation + min_available_clients=2, + fit_metrics_aggregation_fn=weighted_average, + initial_parameters=parameters, +) + + +# Define config +config = ServerConfig(num_rounds=3) + + +# Flower ServerApp +app = ServerApp( + config=config, + strategy=strategy, +) + + +# Legacy mode +if __name__ == "__main__": + from flwr.server import start_server + + start_server( + server_address="0.0.0.0:8080", + config=config, + strategy=strategy, + ) diff --git a/examples/hello-world/hello-flower/jobs/hello-flwr-pt/app/custom/task.py b/examples/hello-world/hello-flower/jobs/hello-flwr-pt/app/custom/task.py new file mode 100644 index 0000000000..7a5c1a0514 --- /dev/null +++ b/examples/hello-world/hello-flower/jobs/hello-flwr-pt/app/custom/task.py @@ -0,0 +1,106 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from collections import OrderedDict +from logging import INFO + +import torch +import torch.nn as nn +import torch.nn.functional as F +from flwr.common.logger import log +from torch.utils.data import DataLoader +from torchvision.datasets import CIFAR10 +from torchvision.transforms import Compose, Normalize, ToTensor + +DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + + +class Net(nn.Module): + """Model (simple CNN adapted from 'PyTorch: A 60 Minute Blitz')""" + + def __init__(self) -> None: + super(Net, self).__init__() + self.conv1 = nn.Conv2d(3, 6, 5) + self.pool = nn.MaxPool2d(2, 2) + self.conv2 = nn.Conv2d(6, 16, 5) + self.fc1 = nn.Linear(16 * 5 * 5, 120) + self.fc2 = nn.Linear(120, 84) + self.fc3 = nn.Linear(84, 10) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.pool(F.relu(self.conv1(x))) + x = self.pool(F.relu(self.conv2(x))) + x = x.view(-1, 16 * 5 * 5) + x = F.relu(self.fc1(x)) + x = F.relu(self.fc2(x)) + return self.fc3(x) + + +def load_data(): + """Load CIFAR-10 (training and test set).""" + trf = Compose([ToTensor(), Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) + trainset = CIFAR10("./data", train=True, download=True, transform=trf) + testset = CIFAR10("./data", train=False, download=True, transform=trf) + return DataLoader(trainset, batch_size=32, shuffle=True), DataLoader(testset) + + +def train(net, trainloader, valloader, epochs, device): + """Train the model on the training set.""" + log(INFO, "Starting training...") + net.to(device) # move model to GPU if available + criterion = torch.nn.CrossEntropyLoss().to(device) + optimizer = torch.optim.SGD(net.parameters(), lr=0.001, momentum=0.9) + net.train() + for _ in range(epochs): + for images, labels in trainloader: + images, labels = images.to(device), labels.to(device) + optimizer.zero_grad() + loss = criterion(net(images), labels) + loss.backward() + optimizer.step() + + train_loss, train_acc = test(net, trainloader) + val_loss, val_acc = test(net, valloader) + + results = { + "train_loss": train_loss, + "train_accuracy": train_acc, + "val_loss": val_loss, + "val_accuracy": val_acc, + } + return results + + +def test(net, testloader): + """Validate the model on the test set.""" + net.to(DEVICE) + criterion = torch.nn.CrossEntropyLoss() + correct, loss = 0, 0.0 + with torch.no_grad(): + for images, labels in testloader: + outputs = net(images.to(DEVICE)) + labels = labels.to(DEVICE) + loss += criterion(outputs, labels).item() + correct += (torch.max(outputs.data, 1)[1] == labels).sum().item() + accuracy = correct / len(testloader.dataset) + return loss, accuracy + + +def get_weights(net): + return [val.cpu().numpy() for _, val in net.state_dict().items()] + + +def set_weights(net, parameters): + params_dict = zip(net.state_dict().keys(), parameters) + state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict}) + net.load_state_dict(state_dict, strict=True) diff --git a/examples/hello-world/hello-flower/jobs/hello-flwr-pt/meta.json b/examples/hello-world/hello-flower/jobs/hello-flwr-pt/meta.json new file mode 100644 index 0000000000..90bc82e8ed --- /dev/null +++ b/examples/hello-world/hello-flower/jobs/hello-flwr-pt/meta.json @@ -0,0 +1,10 @@ +{ + "name": "hello-flwr-pt", + "resource_spec": {}, + "min_clients" : 2, + "deploy_map": { + "app": [ + "@ALL" + ] + } +} diff --git a/examples/hello-world/hello-flower/requirements.txt b/examples/hello-world/hello-flower/requirements.txt new file mode 100644 index 0000000000..1d8990f84a --- /dev/null +++ 
b/examples/hello-world/hello-flower/requirements.txt @@ -0,0 +1,4 @@ +nvflare~=2.5.0rc +flwr[simulation]>=1.8.0 +torch==2.2.1 +torchvision==0.17.1 diff --git a/examples/tutorials/setup_poc.ipynb b/examples/tutorials/setup_poc.ipynb index d1a1aa0007..d07c579882 100644 --- a/examples/tutorials/setup_poc.ipynb +++ b/examples/tutorials/setup_poc.ipynb @@ -81,12 +81,12 @@ "If you prefer not to use environment variable, you can do the followings: \n", "\n", "```\n", - "! nvflare config -pw /tmp/nvflare/poc\n", + "! nvflare config -pw /tmp/nvflare/poc --job_templates_dir ../../job_templates\n", "\n", "```\n", "or \n", "```\n", - "! nvflare config -poc_workspace_dir /tmp/nvflare/poc\n", + "! nvflare config --poc_workspace_dir /tmp/nvflare/poc --job_templates_dir ../../job_templates\n", "```" ] }, @@ -99,7 +99,7 @@ }, "outputs": [], "source": [ - "! nvflare config -pw /tmp/nvflare/poc" + "! nvflare config -pw /tmp/nvflare/poc --job_templates_dir ../../job_templates" ] }, { diff --git a/integration/monai/examples/spleen_ct_segmentation_local/README.md b/integration/monai/examples/spleen_ct_segmentation_local/README.md index 42772aa9e9..44b28e2670 100644 --- a/integration/monai/examples/spleen_ct_segmentation_local/README.md +++ b/integration/monai/examples/spleen_ct_segmentation_local/README.md @@ -160,17 +160,17 @@ Experiment tracking for the FLARE-MONAI integration now uses `NVFlareStatsHandle In this example, the `spleen_ct_segmentation_local` job is configured to automatically log metrics to MLflow through the FL server. -- The `config_fed_client.json` contains the `NVFlareStatsHandler`, `MetricsSender`, and `MetricRelay` (with their respective pipes) to send the metrics to the server side as federated events. -- Then in `config_fed_server.json`, the `MLflowReceiver` is configured for the server to write the results to the default MLflow tracking server URI. +- The `config_fed_client.conf` contains the `NVFlareStatsHandler`, `MetricsSender`, and `MetricRelay` (with their respective pipes) to send the metrics to the server side as federated events. +- Then in `config_fed_server.conf`, the `MLflowReceiver` is configured for the server to write the results to the MLflow tracking server URI `http://127.0.0.1:5000`. -With this configuration the MLflow tracking server must be started before running the job: +We need to start MLflow tracking server before running this job: ``` mlflow server ``` > **_NOTE:_** The receiver on the server side can be easily configured to support other experiment tracking formats. - In addition to the `MLflowReceiver`, the `WandBReceiver` and `TBAnalyticsReceiver` can also be used in `config_fed_server.json` for Tensorboard and Weights & Biases experiment tracking streaming to the server. + In addition to the `MLflowReceiver`, the `WandBReceiver` and `TBAnalyticsReceiver` can also be used in `config_fed_server.conf` for Tensorboard and Weights & Biases experiment tracking streaming to the server. Next, we can submit the job. @@ -219,10 +219,16 @@ nvflare job submit -j jobs/spleen_ct_segementation_he ### 5.4 MLflow experiment tracking results -To view the results, you can access the MLflow dashboard in your browser using the default tracking uri `http://127.0.0.1:5000`. 
-
-> **_NOTE:_** To write the results to the server workspace instead of using the MLflow server, users can remove the `tracking_uri` argument from the `MLflowReceiver` configuration and instead view the results by running `mlflow ui --port 5000` in the directory that contains the `mlruns/` directory in the server workspace.
+To view the results, you can access the MLflow dashboard in your browser using the tracking URI `http://127.0.0.1:5000`.
 Once the training is started, you can see the experiment curves for the local clients in the current run on the MLflow dashboard.
-![MLflow dashboard](./mlflow.png) \ No newline at end of file
+![MLflow dashboard](./mlflow.png)
+
+
+> **_NOTE:_** If you prefer not to start the MLflow server before federated training,
> you can alternatively choose to write the metrics streaming results to the server's
> job workspace directory. Remove the `tracking_uri` argument from the `MLflowReceiver`
> configuration. After the job finishes, download the server job workspace and unzip it.
> You can view the results by running `mlflow ui --port 5000` in the directory containing
> the `mlruns/` directory within the server job workspace.
diff --git a/integration/monai/setup.py b/integration/monai/setup.py index a31011b3eb..8b41d0a099 100644 --- a/integration/monai/setup.py +++ b/integration/monai/setup.py @@ -24,14 +24,14 @@ release = os.environ.get("MONAI_NVFL_RELEASE") if release == "1": package_name = "monai-nvflare"
- version = "0.2.4"
+ version = "0.2.9"
 else: package_name = "monai-nvflare-nightly" today = datetime.date.today().timetuple() year = today[0] % 1000 month = today[1] day = today[2]
- version = f"0.2.3.{year:02d}{month:02d}{day:02d}"
+ version = f"0.2.9.{year:02d}{month:02d}{day:02d}"
 setup( name=package_name, @@ -57,5 +57,5 @@ long_description=long_description, long_description_content_type="text/markdown", python_requires=">=3.8,<3.11",
- install_requires=["monai>=1.3.0", "nvflare==2.4.0rc6"],
+ install_requires=["monai>=1.3.1", "nvflare~=2.5.0rc1"],
 )
diff --git a/integration/xgboost/encryption_plugins/.editorconfig b/integration/xgboost/encryption_plugins/.editorconfig new file mode 100644 index 0000000000..97a7bc133a --- /dev/null +++ b/integration/xgboost/encryption_plugins/.editorconfig @@ -0,0 +1,11 @@
+root = true
+
+[*]
+charset = utf-8
+indent_style = space
+indent_size = 2
+insert_final_newline = true
+
+[*.py]
+indent_style = space
+indent_size = 4
diff --git a/integration/xgboost/encryption_plugins/CMakeLists.txt b/integration/xgboost/encryption_plugins/CMakeLists.txt new file mode 100644 index 0000000000..f5d71dd61c --- /dev/null +++ b/integration/xgboost/encryption_plugins/CMakeLists.txt @@ -0,0 +1,41 @@
+cmake_minimum_required(VERSION 3.19)
+project(xgb_nvflare LANGUAGES CXX C VERSION 1.0)
+set(CMAKE_CXX_STANDARD 17)
+set(CMAKE_BUILD_TYPE Debug)
+
+option(GOOGLE_TEST "Build google tests" OFF)
+
+file(GLOB_RECURSE LIB_SRC "src/*.cc")
+
+add_library(nvflare SHARED ${LIB_SRC})
+set_target_properties(nvflare PROPERTIES
+ CXX_STANDARD 17
+ CXX_STANDARD_REQUIRED ON
+ POSITION_INDEPENDENT_CODE ON
+ ENABLE_EXPORTS ON
+)
+target_include_directories(nvflare PRIVATE ${xgb_nvflare_SOURCE_DIR}/src/include)
+
+if (APPLE)
+ add_link_options("LINKER:-object_path_lto,$_lto.o")
+ add_link_options("LINKER:-cache_path_lto,${CMAKE_BINARY_DIR}/LTOCache")
+endif ()
+
+#-- Unit Tests
+if(GOOGLE_TEST)
+ find_package(GTest REQUIRED)
+ enable_testing()
+ add_executable(nvflare_test)
+ target_link_libraries(nvflare_test PRIVATE nvflare)
+
+
+
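+ # Note: nvflare_test links against the plugin library built above and reuses its
+ # private headers; the actual test sources are added via add_subdirectory(tests) below.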
target_include_directories(nvflare_test PRIVATE ${xgb_nvflare_SOURCE_DIR}/src/include)
+
+ add_subdirectory(${xgb_nvflare_SOURCE_DIR}/tests)
+
+ add_test(
+ NAME TestNvflarePlugins
+ COMMAND nvflare_test
+ WORKING_DIRECTORY ${xgb_nvflare_BINARY_DIR})
+
+endif()
diff --git a/integration/xgboost/encryption_plugins/README.md b/integration/xgboost/encryption_plugins/README.md new file mode 100644 index 0000000000..57f2c4621e --- /dev/null +++ b/integration/xgboost/encryption_plugins/README.md @@ -0,0 +1,9 @@
+# Build Instructions
+
+cd NVFlare/integration/xgboost/encryption_plugins
+mkdir build
+cd build
+cmake ..
+make
+
+The built library is libnvflare.so
diff --git a/integration/xgboost/processor/src/README.md b/integration/xgboost/encryption_plugins/src/README.md similarity index 100% rename from integration/xgboost/processor/src/README.md rename to integration/xgboost/encryption_plugins/src/README.md
diff --git a/integration/xgboost/processor/src/dam/README.md b/integration/xgboost/encryption_plugins/src/dam/README.md similarity index 100% rename from integration/xgboost/processor/src/dam/README.md rename to integration/xgboost/encryption_plugins/src/dam/README.md
diff --git a/integration/xgboost/encryption_plugins/src/dam/dam.cc b/integration/xgboost/encryption_plugins/src/dam/dam.cc new file mode 100644 index 0000000000..9fdb7d8582 --- /dev/null +++ b/integration/xgboost/encryption_plugins/src/dam/dam.cc @@ -0,0 +1,274 @@
+/**
+ * Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ +#include +#include +#include "dam.h" + + +void print_hex(const uint8_t *buffer, std::size_t size) { + std::cout << std::hex; + for (int i = 0; i < size; i++) { + int c = buffer[i]; + std::cout << c << " "; + } + std::cout << std::endl << std::dec; +} + +void print_buffer(const uint8_t *buffer, std::size_t size) { + if (size <= 64) { + std::cout << "Whole buffer: " << size << " bytes" << std::endl; + print_hex(buffer, size); + return; + } + + std::cout << "First chunk, Total: " << size << " bytes" << std::endl; + print_hex(buffer, 32); + std::cout << "Last chunk, Offset: " << size-16 << " bytes" << std::endl; + print_hex(buffer+size-32, 32); +} + +size_t align(const size_t length) { + return ((length + 7)/8)*8; +} + +// DamEncoder ====== +void DamEncoder::AddBuffer(const Buffer &buffer) { + if (debug_) { + std::cout << "AddBuffer called, size: " << buffer.buf_size << std::endl; + } + if (encoded_) { + std::cout << "Buffer is already encoded" << std::endl; + return; + } + // print_buffer(buffer, buf_size); + entries_.emplace_back(kDataTypeBuffer, static_cast(buffer.buffer), buffer.buf_size); +} + +void DamEncoder::AddFloatArray(const std::vector &value) { + if (debug_) { + std::cout << "AddFloatArray called, size: " << value.size() << std::endl; + } + + if (encoded_) { + std::cout << "Buffer is already encoded" << std::endl; + return; + } + // print_buffer(reinterpret_cast(value.data()), value.size() * 8); + entries_.emplace_back(kDataTypeFloatArray, reinterpret_cast(value.data()), value.size()); +} + +void DamEncoder::AddIntArray(const std::vector &value) { + if (debug_) { + std::cout << "AddIntArray called, size: " << value.size() << std::endl; + } + + if (encoded_) { + std::cout << "Buffer is already encoded" << std::endl; + return; + } + // print_buffer(buffer, buf_size); + entries_.emplace_back(kDataTypeIntArray, reinterpret_cast(value.data()), value.size()); +} + +void DamEncoder::AddBufferArray(const std::vector &value) { + if (debug_) { + std::cout << "AddBufferArray called, size: " << value.size() << std::endl; + } + + if (encoded_) { + std::cout << "Buffer is already encoded" << std::endl; + return; + } + size_t size = 0; + for (auto &buf: value) { + size += buf.buf_size; + } + size += 8*value.size(); + entries_.emplace_back(kDataTypeBufferArray, reinterpret_cast(&value), size); +} + + +std::uint8_t * DamEncoder::Finish(size_t &size) { + encoded_ = true; + + size = CalculateSize(); + auto buf = static_cast(calloc(size, 1)); + auto pointer = buf; + auto sig = local_version_ ? 
kSignatureLocal : kSignature; + memcpy(pointer, sig, strlen(sig)); + memcpy(pointer+8, &size, 8); + memcpy(pointer+16, &data_set_id_, 8); + + pointer += kPrefixLen; + for (auto& entry : entries_) { + std::size_t len; + if (entry.data_type == kDataTypeBufferArray) { + auto buffers = reinterpret_cast *>(entry.pointer); + memcpy(pointer, &entry.data_type, 8); + pointer += 8; + auto array_size = static_cast(buffers->size()); + memcpy(pointer, &array_size, 8); + pointer += 8; + auto sizes = reinterpret_cast(pointer); + for (auto &item : *buffers) { + *sizes = static_cast(item.buf_size); + sizes++; + } + len = 8*buffers->size(); + auto buf_ptr = pointer + len; + for (auto &item : *buffers) { + if (item.buf_size > 0) { + memcpy(buf_ptr, item.buffer, item.buf_size); + } + buf_ptr += item.buf_size; + len += item.buf_size; + } + } else { + memcpy(pointer, &entry.data_type, 8); + pointer += 8; + memcpy(pointer, &entry.size, 8); + pointer += 8; + len = entry.size * entry.ItemSize(); + if (len) { + memcpy(pointer, entry.pointer, len); + } + } + pointer += align(len); + } + + if ((pointer - buf) != size) { + std::cout << "Invalid encoded size: " << (pointer - buf) << std::endl; + return nullptr; + } + + return buf; +} + +std::size_t DamEncoder::CalculateSize() { + std::size_t size = kPrefixLen; + + for (auto& entry : entries_) { + size += 16; // The Type and Len + auto len = entry.size * entry.ItemSize(); + size += align(len); + } + + return size; +} + + +// DamDecoder ====== + +DamDecoder::DamDecoder(std::uint8_t *buffer, std::size_t size, bool local_version, bool debug) { + local_version_ = local_version; + buffer_ = buffer; + buf_size_ = size; + pos_ = buffer + kPrefixLen; + debug_ = debug; + + if (size >= kPrefixLen) { + memcpy(&len_, buffer + 8, 8); + memcpy(&data_set_id_, buffer + 16, 8); + } else { + len_ = 0; + data_set_id_ = 0; + } +} + +bool DamDecoder::IsValid() const { + auto sig = local_version_ ? 
kSignatureLocal : kSignature; + return buf_size_ >= kPrefixLen && memcmp(buffer_, sig, strlen(sig)) == 0; +} + +Buffer DamDecoder::DecodeBuffer() { + auto type = *reinterpret_cast(pos_); + if (type != kDataTypeBuffer) { + std::cout << "Data type " << type << " doesn't match bytes" << std::endl; + return {}; + } + pos_ += 8; + + auto size = *reinterpret_cast(pos_); + pos_ += 8; + + if (size == 0) { + return {}; + } + + auto ptr = reinterpret_cast(pos_); + pos_ += align(size); + return{ ptr, static_cast(size)}; +} + +std::vector DamDecoder::DecodeIntArray() { + auto type = *reinterpret_cast(pos_); + if (type != kDataTypeIntArray) { + std::cout << "Data type " << type << " doesn't match Int Array" << std::endl; + return {}; + } + pos_ += 8; + + auto array_size = *reinterpret_cast(pos_); + pos_ += 8; + auto ptr = reinterpret_cast(pos_); + pos_ += align(8 * array_size); + return {ptr, ptr + array_size}; +} + +std::vector DamDecoder::DecodeFloatArray() { + auto type = *reinterpret_cast(pos_); + if (type != kDataTypeFloatArray) { + std::cout << "Data type " << type << " doesn't match Float Array" << std::endl; + return {}; + } + pos_ += 8; + + auto array_size = *reinterpret_cast(pos_); + pos_ += 8; + + auto ptr = reinterpret_cast(pos_); + pos_ += align(8 * array_size); + return {ptr, ptr + array_size}; +} + +std::vector DamDecoder::DecodeBufferArray() { + auto type = *reinterpret_cast(pos_); + if (type != kDataTypeBufferArray) { + std::cout << "Data type " << type << " doesn't match Bytes Array" << std::endl; + return {}; + } + pos_ += 8; + + auto num = *reinterpret_cast(pos_); + pos_ += 8; + + auto size_ptr = reinterpret_cast(pos_); + auto buf_ptr = pos_ + 8 * num; + size_t total_size = 8 * num; + auto result = std::vector(num); + for (int i = 0; i < num; i++) { + auto size = size_ptr[i]; + if (buf_size_ > 0) { + result[i].buf_size = size; + result[i].buffer = buf_ptr; + buf_ptr += size; + } + total_size += size; + } + + pos_ += align(total_size); + return result; +} diff --git a/integration/xgboost/encryption_plugins/src/include/base_plugin.h b/integration/xgboost/encryption_plugins/src/include/base_plugin.h new file mode 100644 index 0000000000..dddd5a7911 --- /dev/null +++ b/integration/xgboost/encryption_plugins/src/include/base_plugin.h @@ -0,0 +1,155 @@ +/** + * Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include // for uint8_t, uint32_t, int32_t, int64_t +#include // for string_view +#include // for pair +#include // for vector +#include +#include +#include + +#include "util.h" + +namespace nvflare { + +/** + * @brief Abstract interface for the encryption plugin + * + * All plugin implementations must inherit this class. + */ +class BasePlugin { +protected: + bool debug_ = false; + bool print_timing_ = false; + bool dam_debug_ = false; + +public: +/** + * @brief Constructor + * + * All inherited classes should call this constructor. 
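+ * For illustration only (hypothetical values): args such as
+ * {{"debug", "true"}, {"print_timing", "true"}, {"dam_debug", "false"}}
+ * switch on the corresponding logging flags, which this constructor parses with get_bool().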
+ * + * @param args Entries from federated_plugin in communicator environments. + */ + explicit BasePlugin( + std::vector> const &args) { + debug_ = get_bool(args, "debug"); + print_timing_ = get_bool(args, "print_timing"); + dam_debug_ = get_bool(args, "dam_debug"); + } + + /** + * @brief Destructor + */ + virtual ~BasePlugin() = default; + + /** + * @brief Identity for the plugin used for debug + * + * This is a string with instance address and process id. + */ + std::string Ident() { + std::stringstream ss; + ss << std::hex << std::uppercase << std::setw(sizeof(void*) * 2) << std::setfill('0') << + reinterpret_cast(this); + return ss.str() + "-" + std::to_string(getpid()); + } + + /** + * @brief Encrypt the gradient pairs + * + * @param in_gpair Input g and h pairs for each record + * @param n_in The array size (2xnum_of_records) + * @param out_gpair Pointer to encrypted buffer + * @param n_out Encrypted buffer size + */ + virtual void EncryptGPairs(float const *in_gpair, std::size_t n_in, + std::uint8_t **out_gpair, std::size_t *n_out) = 0; + + /** + * @brief Process encrypted gradient pairs + * + * @param in_gpair Encrypted gradient pairs + * @param n_bytes Buffer size of Encrypted gradient + * @param out_gpair Pointer to decrypted gradient pairs + * @param out_n_bytes Decrypted buffer size + */ + virtual void SyncEncryptedGPairs(std::uint8_t const *in_gpair, std::size_t n_bytes, + std::uint8_t const **out_gpair, + std::size_t *out_n_bytes) = 0; + + /** + * @brief Reset the histogram context + * + * @param cutptrs Cut-pointers for the flattened histograms + * @param cutptr_len cutptrs array size (number of features plus one) + * @param bin_idx An array (flattened matrix) of slot index for each record/feature + * @param n_idx The size of above array + */ + virtual void ResetHistContext(std::uint32_t const *cutptrs, std::size_t cutptr_len, + std::int32_t const *bin_idx, std::size_t n_idx) = 0; + + /** + * @brief Encrypt histograms for horizontal training + * + * @param in_histogram The array for the histogram + * @param len The array size + * @param out_hist Pointer to encrypted buffer + * @param out_len Encrypted buffer size + */ + virtual void BuildEncryptedHistHori(double const *in_histogram, std::size_t len, + std::uint8_t **out_hist, std::size_t *out_len) = 0; + + /** + * @brief Process encrypted histograms for horizontal training + * + * @param buffer Buffer for encrypted histograms + * @param len Buffer size of encrypted histograms + * @param out_hist Pointer to decrypted histograms + * @param out_len Size of above array + */ + virtual void SyncEncryptedHistHori(std::uint8_t const *buffer, std::size_t len, + double **out_hist, std::size_t *out_len) = 0; + + /** + * @brief Build histograms in encrypted space for vertical training + * + * @param ridx Pointer to a matrix of row IDs for each node + * @param sizes An array of sizes of each node + * @param nidx An array for each node ID + * @param len Number of nodes + * @param out_hist Pointer to encrypted histogram buffer + * @param out_len Buffer size + */ + virtual void BuildEncryptedHistVert(std::uint64_t const **ridx, + std::size_t const *sizes, + std::int32_t const *nidx, std::size_t len, + std::uint8_t **out_hist, std::size_t *out_len) = 0; + + /** + * @brief Decrypt histogram for vertical training + * + * @param hist_buffer Encrypted histogram buffer + * @param len Buffer size of encrypted histogram + * @param out Pointer to decrypted histograms + * @param out_len Size of above array + */ + virtual void 
SyncEncryptedHistVert(std::uint8_t *hist_buffer, std::size_t len, + double **out, std::size_t *out_len) = 0; +}; +} // namespace nvflare diff --git a/integration/xgboost/encryption_plugins/src/include/dam.h b/integration/xgboost/encryption_plugins/src/include/dam.h new file mode 100644 index 0000000000..8677a413b1 --- /dev/null +++ b/integration/xgboost/encryption_plugins/src/include/dam.h @@ -0,0 +1,143 @@ +/** + * Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once +#include +#include + +constexpr char kSignature[] = "NVDADAM1"; // DAM (Direct Accessible Marshalling) V1 +constexpr char kSignatureLocal[] = "NVDADAML"; // DAM Local version +constexpr int kPrefixLen = 24; + +constexpr int kDataTypeInt = 1; +constexpr int kDataTypeFloat = 2; +constexpr int kDataTypeString = 3; +constexpr int kDataTypeBuffer = 4; +constexpr int kDataTypeIntArray = 257; +constexpr int kDataTypeFloatArray = 258; +constexpr int kDataTypeBufferArray = 259; +constexpr int kDataTypeMap = 1025; + +/*! \brief A replacement for std::span */ +class Buffer { +public: + void *buffer; + size_t buf_size; + bool allocated; + + Buffer() : buffer(nullptr), buf_size(0), allocated(false) { + } + + Buffer(void *buffer, size_t buf_size, bool allocated=false) : + buffer(buffer), buf_size(buf_size), allocated(allocated) { + } + + Buffer(const Buffer &that): + buffer(that.buffer), buf_size(that.buf_size), allocated(false) { + } +}; + +class Entry { + public: + int64_t data_type; + const uint8_t * pointer; + int64_t size; + + Entry(int64_t data_type, const uint8_t *pointer, int64_t size) { + this->data_type = data_type; + this->pointer = pointer; + this->size = size; + } + + [[nodiscard]] std::size_t ItemSize() const + { + size_t item_size; + switch (data_type) { + case kDataTypeBuffer: + case kDataTypeString: + case kDataTypeBufferArray: + item_size = 1; + break; + default: + item_size = 8; + } + return item_size; + } +}; + +class DamEncoder { + private: + bool encoded_ = false; + bool local_version_ = false; + bool debug_ = false; + int64_t data_set_id_; + std::vector entries_; + + public: + explicit DamEncoder(int64_t data_set_id, bool local_version=false, bool debug=false) { + data_set_id_ = data_set_id; + local_version_ = local_version; + debug_ = debug; + + } + + void AddBuffer(const Buffer &buffer); + + void AddIntArray(const std::vector &value); + + void AddFloatArray(const std::vector &value); + + void AddBufferArray(const std::vector &value); + + std::uint8_t * Finish(size_t &size); + + private: + std::size_t CalculateSize(); +}; + +class DamDecoder { + private: + bool local_version_ = false; + std::uint8_t *buffer_ = nullptr; + std::size_t buf_size_ = 0; + std::uint8_t *pos_ = nullptr; + std::size_t remaining_ = 0; + int64_t data_set_id_ = 0; + int64_t len_ = 0; + bool debug_ = false; + + public: + explicit DamDecoder(std::uint8_t *buffer, std::size_t size, bool local_version=false, bool debug=false); + + [[nodiscard]] std::size_t Size() 
const { + return len_; + } + + [[nodiscard]] int64_t GetDataSetId() const { + return data_set_id_; + } + + [[nodiscard]] bool IsValid() const; + + Buffer DecodeBuffer(); + + std::vector DecodeIntArray(); + + std::vector DecodeFloatArray(); + + std::vector DecodeBufferArray(); +}; + +void print_buffer(const uint8_t *buffer, std::size_t size); diff --git a/integration/xgboost/encryption_plugins/src/include/data_set_ids.h b/integration/xgboost/encryption_plugins/src/include/data_set_ids.h new file mode 100644 index 0000000000..98eb20e838 --- /dev/null +++ b/integration/xgboost/encryption_plugins/src/include/data_set_ids.h @@ -0,0 +1,23 @@ +/** + * Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +constexpr int kDataSetGHPairs = 1; +constexpr int kDataSetAggregation = 2; +constexpr int kDataSetAggregationWithFeatures = 3; +constexpr int kDataSetAggregationResult = 4; +constexpr int kDataSetHistograms = 5; +constexpr int kDataSetHistogramResult = 6; diff --git a/integration/xgboost/encryption_plugins/src/include/delegated_plugin.h b/integration/xgboost/encryption_plugins/src/include/delegated_plugin.h new file mode 100644 index 0000000000..7b4f353b21 --- /dev/null +++ b/integration/xgboost/encryption_plugins/src/include/delegated_plugin.h @@ -0,0 +1,66 @@ +/** + * Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#pragma once +#include "base_plugin.h" + +namespace nvflare { + +// Plugin that delegates to other real plugins +class DelegatedPlugin : public BasePlugin { + + BasePlugin *plugin_{nullptr}; + +public: + explicit DelegatedPlugin(std::vector> const &args); + + ~DelegatedPlugin() override { + delete plugin_; + } + + void EncryptGPairs(const float* in_gpair, std::size_t n_in, std::uint8_t** out_gpair, std::size_t* n_out) override { + plugin_->EncryptGPairs(in_gpair, n_in, out_gpair, n_out); + } + + void SyncEncryptedGPairs(const std::uint8_t* in_gpair, std::size_t n_bytes, const std::uint8_t** out_gpair, + std::size_t* out_n_bytes) override { + plugin_->SyncEncryptedGPairs(in_gpair, n_bytes, out_gpair, out_n_bytes); + } + + void ResetHistContext(const std::uint32_t* cutptrs, std::size_t cutptr_len, const std::int32_t* bin_idx, + std::size_t n_idx) override { + plugin_->ResetHistContext(cutptrs, cutptr_len, bin_idx, n_idx); + } + + void BuildEncryptedHistHori(const double* in_histogram, std::size_t len, std::uint8_t** out_hist, + std::size_t* out_len) override { + plugin_->BuildEncryptedHistHori(in_histogram, len, out_hist, out_len); + } + + void SyncEncryptedHistHori(const std::uint8_t* buffer, std::size_t len, double** out_hist, + std::size_t* out_len) override { + plugin_->SyncEncryptedHistHori(buffer, len, out_hist, out_len); + } + + void BuildEncryptedHistVert(const std::uint64_t** ridx, const std::size_t* sizes, const std::int32_t* nidx, + std::size_t len, std::uint8_t** out_hist, std::size_t* out_len) override { + plugin_->BuildEncryptedHistVert(ridx, sizes, nidx, len, out_hist, out_len); + } + + void SyncEncryptedHistVert(std::uint8_t* hist_buffer, std::size_t len, double** out, std::size_t* out_len) override { + plugin_->SyncEncryptedHistVert(hist_buffer, len, out, out_len); + } +}; +} // namespace nvflare diff --git a/integration/xgboost/encryption_plugins/src/include/local_plugin.h b/integration/xgboost/encryption_plugins/src/include/local_plugin.h new file mode 100644 index 0000000000..2022322266 --- /dev/null +++ b/integration/xgboost/encryption_plugins/src/include/local_plugin.h @@ -0,0 +1,107 @@ +/** + * Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+#pragma once
+
+#include "base_plugin.h"
+#include "dam.h"
+
+namespace nvflare {
+
+// A base plugin for all plugins that handle encryption locally in C++
+class LocalPlugin : public BasePlugin {
+protected:
+ std::vector gh_pairs_;
+ std::vector encrypted_gh_;
+ std::vector histo_;
+ std::vector cuts_;
+ std::vector slots_;
+ std::vector buffer_;
+
+public:
+ explicit LocalPlugin(std::vector> const &args) :
+ BasePlugin(args) {}
+
+ ~LocalPlugin() override = default;
+
+ void EncryptGPairs(const float *in_gpair, std::size_t n_in, std::uint8_t **out_gpair,
+ std::size_t *n_out) override;
+
+ void SyncEncryptedGPairs(const std::uint8_t *in_gpair, std::size_t n_bytes, const std::uint8_t **out_gpair,
+ std::size_t *out_n_bytes) override;
+
+ void ResetHistContext(const std::uint32_t *cutptrs, std::size_t cutptr_len, const std::int32_t *bin_idx,
+ std::size_t n_idx) override;
+
+ void BuildEncryptedHistHori(const double *in_histogram, std::size_t len, std::uint8_t **out_hist,
+ std::size_t *out_len) override;
+
+ void SyncEncryptedHistHori(const std::uint8_t *buffer, std::size_t len, double **out_hist,
+ std::size_t *out_len) override;
+
+ void BuildEncryptedHistVert(const std::uint64_t **ridx, const std::size_t *sizes, const std::int32_t *nidx,
+ std::size_t len, std::uint8_t **out_hist, std::size_t *out_len) override;
+
+ void SyncEncryptedHistVert(std::uint8_t *hist_buffer, std::size_t len, double **out,
+ std::size_t *out_len) override;
+
+ // Methods that need to be implemented by local plugins
+
+ /*!
+ * \brief Encrypt a vector of floating-point numbers
+ * \param cleartext A vector of numbers in cleartext
+ * \return A buffer with serialized ciphertext
+ */
+ virtual Buffer EncryptVector(const std::vector &cleartext) = 0;
+
+ /*!
+ * \brief Decrypt a serialized ciphertext into an array of numbers
+ * \param ciphertext A serialized buffer of ciphertext
+ * \return An array of numbers
+ */
+ virtual std::vector DecryptVector(const std::vector &ciphertext) = 0;
+
+ /*!
+ * \brief Add the G&H pairs for a series of samples
+ * \param sample_ids A map of slot number and an array of sample IDs
+ * \return A map of the serialized encrypted sum of G and H for each slot
+ * The input and output maps must have the same size
+ */
+ virtual std::map AddGHPairs(const std::map> &sample_ids) = 0;
+
+ /*!
+ * \brief Free encrypted data buffer + * \param ciphertext The buffer for encrypted data + */ + virtual void FreeEncryptedData(Buffer &ciphertext) { + if (ciphertext.allocated && ciphertext.buffer != nullptr) { + free(ciphertext.buffer); + ciphertext.allocated = false; + } + ciphertext.buffer = nullptr; + ciphertext.buf_size = 0; + }; + +private: + + void BuildEncryptedHistVertActive(const std::uint64_t **ridx, const std::size_t *sizes, const std::int32_t *nidx, + std::size_t len, std::uint8_t **out_hist, std::size_t *out_len); + + void BuildEncryptedHistVertPassive(const std::uint64_t **ridx, const std::size_t *sizes, const std::int32_t *nidx, + std::size_t len, std::uint8_t **out_hist, std::size_t *out_len); + +}; + +} // namespace nvflare diff --git a/integration/xgboost/processor/src/include/nvflare_processor.h b/integration/xgboost/encryption_plugins/src/include/nvflare_plugin.h similarity index 77% rename from integration/xgboost/processor/src/include/nvflare_processor.h rename to integration/xgboost/encryption_plugins/src/include/nvflare_plugin.h index cb7076eaf4..87f47d622c 100644 --- a/integration/xgboost/processor/src/include/nvflare_processor.h +++ b/integration/xgboost/encryption_plugins/src/include/nvflare_plugin.h @@ -14,61 +14,65 @@ * limitations under the License. */ #pragma once + #include // for uint8_t, uint32_t, int32_t, int64_t #include // for string_view #include // for pair #include // for vector -const int kDataSetHGPairs = 1; -const int kDataSetAggregation = 2; -const int kDataSetAggregationWithFeatures = 3; -const int kDataSetAggregationResult = 4; -const int kDataSetHistograms = 5; -const int kDataSetHistogramResult = 6; - -// Opaque pointer type for the C API. -typedef void *FederatedPluginHandle; // NOLINT +#include "base_plugin.h" namespace nvflare { -// Plugin that uses Python tenseal and GRPC. -class TensealPlugin { + +// Plugin that uses Python TenSeal and GRPC. +class NvflarePlugin : public BasePlugin { // Buffer for storing encrypted gradient pairs. std::vector encrypted_gpairs_; // Buffer for histogram cut pointers (indptr of a CSC). std::vector cut_ptrs_; // Buffer for histogram index. std::vector bin_idx_; + std::vector gh_pairs_; bool feature_sent_{false}; // The feature index. std::vector features_; // Buffer for output histogram. 
std::vector encrypted_hist_; - std::vector hist_; + // A temporary buffer to hold return value + std::vector buffer_; + // Buffer for clear histogram + std::vector histo_; public: - TensealPlugin( - std::vector> const &args); + explicit NvflarePlugin(std::vector> const &args) : BasePlugin(args) {} + + ~NvflarePlugin() override = default; + // Gradient pairs void EncryptGPairs(float const *in_gpair, std::size_t n_in, - std::uint8_t **out_gpair, std::size_t *n_out); + std::uint8_t **out_gpair, std::size_t *n_out) override; + void SyncEncryptedGPairs(std::uint8_t const *in_gpair, std::size_t n_bytes, std::uint8_t const **out_gpair, - std::size_t *out_n_bytes); + std::size_t *out_n_bytes) override; // Histogram void ResetHistContext(std::uint32_t const *cutptrs, std::size_t cutptr_len, - std::int32_t const *bin_idx, std::size_t n_idx); + std::int32_t const *bin_idx, std::size_t n_idx) override; + void BuildEncryptedHistHori(double const *in_histogram, std::size_t len, - std::uint8_t **out_hist, std::size_t *out_len); + std::uint8_t **out_hist, std::size_t *out_len) override; + void SyncEncryptedHistHori(std::uint8_t const *buffer, std::size_t len, - double **out_hist, std::size_t *out_len); + double **out_hist, std::size_t *out_len) override; - void BuildEncryptedHistVert(std::size_t const **ridx, + void BuildEncryptedHistVert(std::uint64_t const **ridx, std::size_t const *sizes, std::int32_t const *nidx, std::size_t len, - std::uint8_t **out_hist, std::size_t *out_len); + std::uint8_t **out_hist, std::size_t *out_len) override; + void SyncEncryptedHistVert(std::uint8_t *hist_buffer, std::size_t len, - double **out, std::size_t *out_len); + double **out, std::size_t *out_len) override; }; } // namespace nvflare diff --git a/integration/xgboost/encryption_plugins/src/include/pass_thru_plugin.h b/integration/xgboost/encryption_plugins/src/include/pass_thru_plugin.h new file mode 100644 index 0000000000..3abeee4b56 --- /dev/null +++ b/integration/xgboost/encryption_plugins/src/include/pass_thru_plugin.h @@ -0,0 +1,41 @@ +/** + * Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#pragma once +#include "local_plugin.h" + +namespace nvflare { + // A pass-through plugin that doesn't encrypt any data + class PassThruPlugin : public LocalPlugin { + public: + explicit PassThruPlugin(std::vector> const &args) : + LocalPlugin(args) {} + + ~PassThruPlugin() override = default; + + // Horizontal in local plugin still goes through NVFlare, so it needs to be overwritten + void BuildEncryptedHistHori(const double *in_histogram, std::size_t len, std::uint8_t **out_hist, + std::size_t *out_len) override; + + void SyncEncryptedHistHori(const std::uint8_t *buffer, std::size_t len, double **out_hist, + std::size_t *out_len) override; + + Buffer EncryptVector(const std::vector &cleartext) override; + + std::vector DecryptVector(const std::vector &ciphertext) override; + + std::map AddGHPairs(const std::map> &sample_ids) override; + }; +} // namespace nvflare diff --git a/integration/xgboost/encryption_plugins/src/include/util.h b/integration/xgboost/encryption_plugins/src/include/util.h new file mode 100644 index 0000000000..bb8ba16d1a --- /dev/null +++ b/integration/xgboost/encryption_plugins/src/include/util.h @@ -0,0 +1,18 @@ +#pragma once +#include +#include + +std::vector> distribute_work(size_t num_jobs, size_t num_workers); + +uint32_t to_int(double d); + +double to_double(uint32_t i); + +std::string get_string(std::vector> const &args, + std::string_view const &key,std::string_view default_value = ""); + +bool get_bool(std::vector> const &args, + const std::string &key, bool default_value = false); + +int get_int(std::vector> const &args, + const std::string &key, int default_value = 0); diff --git a/integration/xgboost/encryption_plugins/src/plugins/delegated_plugin.cc b/integration/xgboost/encryption_plugins/src/plugins/delegated_plugin.cc new file mode 100644 index 0000000000..a026822799 --- /dev/null +++ b/integration/xgboost/encryption_plugins/src/plugins/delegated_plugin.cc @@ -0,0 +1,36 @@ +/** + * Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include "delegated_plugin.h" +#include "pass_thru_plugin.h" +#include "nvflare_plugin.h" + +namespace nvflare { + +DelegatedPlugin::DelegatedPlugin(std::vector> const &args): + BasePlugin(args) { + + auto name = get_string(args, "name"); + // std::cout << "==== Name is " << name << std::endl; + if (name == "pass-thru") { + plugin_ = new PassThruPlugin(args); + } else if (name == "nvflare") { + plugin_ = new NvflarePlugin(args); + } else { + throw std::invalid_argument{"Unknown plugin name: " + name}; + } +} + +} // namespace nvflare diff --git a/integration/xgboost/encryption_plugins/src/plugins/local_plugin.cc b/integration/xgboost/encryption_plugins/src/plugins/local_plugin.cc new file mode 100644 index 0000000000..99e304ea77 --- /dev/null +++ b/integration/xgboost/encryption_plugins/src/plugins/local_plugin.cc @@ -0,0 +1,366 @@ +/** + * Copyright 2014-2024 by XGBoost Contributors + */ +#include +#include +#include +#include "local_plugin.h" +#include "data_set_ids.h" + +namespace nvflare { + +void LocalPlugin::EncryptGPairs(const float *in_gpair, std::size_t n_in, std::uint8_t **out_gpair, std::size_t *n_out) { + if (debug_) { + std::cout << Ident() << " LocalPlugin::EncryptGPairs called with pairs size: " << n_in << std::endl; + } + + if (print_timing_) { + std::cout << "Encrypting " << n_in / 2 << " GH Pairs" << std::endl; + } + auto start = std::chrono::system_clock::now(); + + auto pairs = std::vector(in_gpair, in_gpair + n_in); + auto double_pairs = std::vector(pairs.cbegin(), pairs.cend()); + auto encrypted_data = EncryptVector(double_pairs); + + if (print_timing_) { + auto end = std::chrono::system_clock::now(); + auto secs = static_cast(std::chrono::duration_cast(end - start).count()) / 1000.0; + std::cout << "Encryption time: " << secs << " seconds" << std::endl; + } + + // Serialize with DAM so the buffers can be separated after all-gather + DamEncoder encoder(kDataSetGHPairs, true, dam_debug_); + encoder.AddBuffer(encrypted_data); + + std::size_t size; + auto buffer = encoder.Finish(size); + FreeEncryptedData(encrypted_data); + buffer_.resize(size); + std::copy_n(buffer, size, buffer_.begin()); + free(buffer); + + *out_gpair = buffer_.data(); + *n_out = buffer_.size(); + if (debug_) { + std::cout << "Encrypted GPairs:" << std::endl; + print_buffer(*out_gpair, *n_out); + } + + // Save pairs for future operations. 
This is only called on active site + gh_pairs_ = std::vector(double_pairs); +} + +void LocalPlugin::SyncEncryptedGPairs(const std::uint8_t *in_gpair, std::size_t n_bytes, + const std::uint8_t **out_gpair, std::size_t *out_n_bytes) { + if (debug_) { + std::cout << Ident() << " LocalPlugin::SyncEncryptedGPairs called with buffer:" << std::endl; + print_buffer(in_gpair, n_bytes); + } + + *out_n_bytes = n_bytes; + *out_gpair = in_gpair; + auto decoder = DamDecoder(const_cast(in_gpair), n_bytes, true, dam_debug_); + if (!decoder.IsValid()) { + std::cout << "LocalPlugin::SyncEncryptedGPairs called with wrong data" << std::endl; + return; + } + + auto encrypted_buffer = decoder.DecodeBuffer(); + if (debug_) { + std::cout << "Encrypted buffer size: " << encrypted_buffer.buf_size << std::endl; + } + + // The caller may free buffer so a copy is needed + auto pointer = static_cast(encrypted_buffer.buffer); + encrypted_gh_ = std::vector(pointer, pointer + encrypted_buffer.buf_size); + FreeEncryptedData(encrypted_buffer); +} + +void LocalPlugin::ResetHistContext(const std::uint32_t *cutptrs, std::size_t cutptr_len, const std::int32_t *bin_idx, + std::size_t n_idx) { + if (debug_) { + std::cout << Ident() << " LocalPlugin::ResetHistContext called with cutptrs size: " << cutptr_len << " bin_idx size: " + << n_idx << std::endl; + } + + cuts_ = std::vector(cutptrs, cutptrs + cutptr_len); + slots_ = std::vector(bin_idx, bin_idx + n_idx); +} + +void LocalPlugin::BuildEncryptedHistHori(const double *in_histogram, std::size_t len, std::uint8_t **out_hist, + std::size_t *out_len) { + if (debug_) { + std::cout << Ident() << " LocalPlugin::BuildEncryptedHistHori called with " << len << " entries" << std::endl; + print_buffer(reinterpret_cast(in_histogram), len); + } + + // don't have a local implementation yet, just encoded it and let NVFlare handle it. + DamEncoder encoder(kDataSetHistograms, false, dam_debug_); + auto histograms = std::vector(in_histogram, in_histogram + len); + encoder.AddFloatArray(histograms); + std::size_t size; + auto buffer = encoder.Finish(size); + buffer_.resize(size); + std::copy_n(buffer, size, buffer_.begin()); + free(buffer); + + *out_hist = buffer_.data(); + *out_len = buffer_.size(); + if (debug_) { + std::cout << "Output buffer" << std::endl; + print_buffer(*out_hist, *out_len); + } +} + +void LocalPlugin::SyncEncryptedHistHori(const std::uint8_t *buffer, std::size_t len, double **out_hist, + std::size_t *out_len) { + if (debug_) { + std::cout << Ident() << " LocalPlugin::SyncEncryptedHistHori called with buffer size: " << len << std::endl; + print_buffer(buffer, len); + } + auto remaining = len; + auto pointer = buffer; + + // The buffer is concatenated by AllGather. 
It may contain multiple DAM buffers + std::vector& result = histo_; + result.clear(); + while (remaining > kPrefixLen) { + DamDecoder decoder(const_cast(pointer), remaining, false, dam_debug_); + if (!decoder.IsValid()) { + std::cout << "Not DAM encoded histogram ignored at offset: " + << static_cast(pointer - buffer) << std::endl; + break; + } + + if (decoder.GetDataSetId() != kDataSetHistogramResult) { + throw std::runtime_error{"Invalid dataset: " + std::to_string(decoder.GetDataSetId())}; + } + + auto size = decoder.Size(); + auto histo = decoder.DecodeFloatArray(); + result.insert(result.end(), histo.cbegin(), histo.cend()); + + remaining -= size; + pointer += size; + } + + *out_hist = result.data(); + *out_len = result.size(); + + if (debug_) { + std::cout << "Output buffer" << std::endl; + print_buffer(reinterpret_cast(*out_hist), histo_.size() * sizeof(double)); + } +} + +void LocalPlugin::BuildEncryptedHistVert(const std::uint64_t **ridx, const std::size_t *sizes, const std::int32_t *nidx, + std::size_t len, std::uint8_t **out_hist, std::size_t *out_len) { + if (debug_) { + std::cout << Ident() << " LocalPlugin::BuildEncryptedHistVert called with number of nodes: " << len << std::endl; + } + + if (gh_pairs_.empty()) { + BuildEncryptedHistVertPassive(ridx, sizes, nidx, len, out_hist, out_len); + } else { + BuildEncryptedHistVertActive(ridx, sizes, nidx, len, out_hist, out_len); + } + + if (debug_) { + std::cout << "Encrypted histogram output:" << std::endl; + print_buffer(*out_hist, *out_len); + } +} + +void LocalPlugin::BuildEncryptedHistVertActive(const std::uint64_t **ridx, const std::size_t *sizes, const std::int32_t *, + std::size_t len, std::uint8_t **out_hist, std::size_t *out_len) { + + if (debug_) { + std::cout << Ident() << " LocalPlugin::BuildEncryptedHistVertActive called with " << len << " nodes" << std::endl; + } + + auto total_bin_size = cuts_.back(); + auto histo_size = total_bin_size * 2; + auto total_size = histo_size * len; + + histo_.clear(); + histo_.resize(total_size); + size_t start = 0; + for (std::size_t i = 0; i < len; i++) { + for (std::size_t j = 0; j < sizes[i]; j++) { + auto row_id = ridx[i][j]; + auto num = cuts_.size() - 1; + for (std::size_t f = 0; f < num; f++) { + int slot = slots_[f + num * row_id]; + if ((slot < 0) || (slot >= total_bin_size)) { + continue; + } + auto g = gh_pairs_[row_id * 2]; + auto h = gh_pairs_[row_id * 2 + 1]; + (histo_)[start + slot * 2] += g; + (histo_)[start + slot * 2 + 1] += h; + } + } + start += histo_size; + } + + // Histogram is in clear, can't send to all_gather. 
Just return empty DAM buffer + auto encoder = DamEncoder(kDataSetAggregationResult, true, dam_debug_); + encoder.AddBuffer(Buffer()); + std::size_t size; + auto buffer = encoder.Finish(size); + buffer_.resize(size); + std::copy_n(buffer, size, buffer_.begin()); + free(buffer); + *out_hist = buffer_.data(); + *out_len = size; +} + +void LocalPlugin::BuildEncryptedHistVertPassive(const std::uint64_t **ridx, const std::size_t *sizes, const std::int32_t *, + std::size_t len, std::uint8_t **out_hist, std::size_t *out_len) { + if (debug_) { + std::cout << Ident() << " LocalPlugin::BuildEncryptedHistVertPassive called with " << len << " nodes" << std::endl; + } + + auto num_slot = cuts_.back(); + auto total_size = num_slot * len; + + auto encrypted_histo = std::vector(total_size); + size_t offset = 0; + for (std::size_t i = 0; i < len; i++) { + auto num = cuts_.size() - 1; + auto row_id_map = std::map>(); + + // Empty slot leaks data so fill everything with empty vectors + for (int slot = 0; slot < num_slot; slot++) { + row_id_map.insert({slot, std::vector()}); + } + + for (std::size_t f = 0; f < num; f++) { + for (std::size_t j = 0; j < sizes[i]; j++) { + auto row_id = ridx[i][j]; + int slot = slots_[f + num * row_id]; + if ((slot < 0) || (slot >= num_slot)) { + continue; + } + auto &row_ids = row_id_map[slot]; + row_ids.push_back(static_cast(row_id)); + } + } + + if (print_timing_) { + std::size_t add_ops = 0; + for (auto &item: row_id_map) { + add_ops += item.second.size(); + } + std::cout << "Aggregating with " << add_ops << " additions" << std::endl; + } + auto start = std::chrono::system_clock::now(); + + auto encrypted_sum = AddGHPairs(row_id_map); + + if (print_timing_) { + auto end = std::chrono::system_clock::now(); + auto secs = static_cast(std::chrono::duration_cast(end - start).count()) / 1000.0; + std::cout << "Aggregation time: " << secs << " seconds" << std::endl; + } + + // Convert map back to array + for (int slot = 0; slot < num_slot; slot++) { + auto it = encrypted_sum.find(slot); + if (it != encrypted_sum.end()) { + encrypted_histo[offset + slot] = it->second; + } + } + + offset += num_slot; + } + + auto encoder = DamEncoder(kDataSetAggregationResult, true, dam_debug_); + encoder.AddBufferArray(encrypted_histo); + std::size_t size; + auto buffer = encoder.Finish(size); + for (auto &item: encrypted_histo) { + FreeEncryptedData(item); + } + buffer_.resize(size); + std::copy_n(buffer, size, buffer_.begin()); + free(buffer); + *out_hist = buffer_.data(); + *out_len = size; +} + +void LocalPlugin::SyncEncryptedHistVert(std::uint8_t *hist_buffer, std::size_t len, + double **out, std::size_t *out_len) { + if (debug_) { + std::cout << Ident() << " LocalPlugin::SyncEncryptedHistVert called with buffer size: " << len << " nodes" << std::endl; + print_buffer(hist_buffer, len); + } + + auto remaining = len; + auto pointer = hist_buffer; + + *out = nullptr; + *out_len = 0; + if (gh_pairs_.empty()) { + if (debug_) { + std::cout << Ident() << " LocalPlugin::SyncEncryptedHistVert Do nothing for passive worker" << std::endl; + } + // Do nothing for passive worker + return; + } + + // The buffer is concatenated by AllGather. 
It may contain multiple DAM buffers + auto first = true; + auto orig_size = histo_.size(); + while (remaining > kPrefixLen) { + DamDecoder decoder(pointer, remaining, true, dam_debug_); + if (!decoder.IsValid()) { + std::cout << "Not DAM encoded buffer ignored at offset: " + << static_cast((pointer - hist_buffer)) << std::endl; + break; + } + auto size = decoder.Size(); + if (first) { + if (histo_.empty()) { + std::cout << "No clear histogram." << std::endl; + return; + } + first = false; + } else { + auto encrypted_buf = decoder.DecodeBufferArray(); + + if (print_timing_) { + std::cout << "Decrypting " << encrypted_buf.size() << " pairs" << std::endl; + } + auto start = std::chrono::system_clock::now(); + + auto decrypted_histo = DecryptVector(encrypted_buf); + + if (print_timing_) { + auto end = std::chrono::system_clock::now(); + auto secs = static_cast(std::chrono::duration_cast(end - start).count()) / 1000.0; + std::cout << "Decryption time: " << secs << " seconds" << std::endl; + } + + if (decrypted_histo.size() != orig_size) { + std::cout << "Histo sizes are different: " << decrypted_histo.size() + << " != " << orig_size << std::endl; + } + histo_.insert(histo_.end(), decrypted_histo.cbegin(), decrypted_histo.cend()); + } + remaining -= size; + pointer += size; + } + + if (debug_) { + std::cout << Ident() << " Decrypted result size: " << histo_.size() << std::endl; + } + + // print_buffer(reinterpret_cast(result.data()), result.size()*8); + + *out = histo_.data(); + *out_len = histo_.size(); +} + +} // namespace nvflare diff --git a/integration/xgboost/encryption_plugins/src/plugins/nvflare_plugin.cc b/integration/xgboost/encryption_plugins/src/plugins/nvflare_plugin.cc new file mode 100644 index 0000000000..b062aecfa6 --- /dev/null +++ b/integration/xgboost/encryption_plugins/src/plugins/nvflare_plugin.cc @@ -0,0 +1,297 @@ +/** + * Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include +#include // for copy_n, transform +#include // for memcpy +#include // for invalid_argument +#include // for vector + +#include "nvflare_plugin.h" +#include "data_set_ids.h" +#include "dam.h" // for DamEncoder + +namespace nvflare { + +void NvflarePlugin::EncryptGPairs(float const *in_gpair, std::size_t n_in, + std::uint8_t **out_gpair, + std::size_t *n_out) { + if (debug_) { + std::cout << Ident() << " NvflarePlugin::EncryptGPairs called with pairs size: " << n_in<< std::endl; + } + + auto pairs = std::vector(in_gpair, in_gpair + n_in); + gh_pairs_ = std::vector(pairs.cbegin(), pairs.cend()); + + DamEncoder encoder(kDataSetGHPairs, false, dam_debug_); + encoder.AddFloatArray(gh_pairs_); + std::size_t size; + auto buffer = encoder.Finish(size); + if (!out_gpair) { + throw std::invalid_argument{"Invalid pointer to output gpair."}; + } + buffer_.resize(size); + std::copy_n(buffer, size, buffer_.begin()); + free(buffer); + *out_gpair = buffer_.data(); + *n_out = size; +} + +void NvflarePlugin::SyncEncryptedGPairs(std::uint8_t const *in_gpair, + std::size_t n_bytes, + std::uint8_t const **out_gpair, + std::size_t *out_n_bytes) { + if (debug_) { + std::cout << Ident() << " NvflarePlugin::SyncEncryptedGPairs called with buffer size: " << n_bytes << std::endl; + } + + // For NVFlare plugin, nothing needs to be done here + *out_n_bytes = n_bytes; + *out_gpair = in_gpair; +} + +void NvflarePlugin::ResetHistContext(std::uint32_t const *cutptrs, + std::size_t cutptr_len, + std::int32_t const *bin_idx, + std::size_t n_idx) { + if (debug_) { + std::cout << Ident() << " NvFlarePlugin::ResetHistContext called with cutptrs size: " << cutptr_len << " bin_idx size: " + << n_idx<< std::endl; + } + + cut_ptrs_.resize(cutptr_len); + std::copy_n(cutptrs, cutptr_len, cut_ptrs_.begin()); + bin_idx_.resize(n_idx); + std::copy_n(bin_idx, n_idx, this->bin_idx_.begin()); +} + +void NvflarePlugin::BuildEncryptedHistVert(std::uint64_t const **ridx, + std::size_t const *sizes, + std::int32_t const *nidx, + std::size_t len, + std::uint8_t** out_hist, + std::size_t* out_len) { + if (debug_) { + std::cout << Ident() << " NvflarePlugin::BuildEncryptedHistVert called with len: " << len << std::endl; + } + + std::int64_t data_set_id; + if (!feature_sent_) { + data_set_id = kDataSetAggregationWithFeatures; + feature_sent_ = true; + } else { + data_set_id = kDataSetAggregation; + } + + DamEncoder encoder(data_set_id, false, dam_debug_); + + // Add cuts pointers + std::vector cuts_vec(cut_ptrs_.cbegin(), cut_ptrs_.cend()); + encoder.AddIntArray(cuts_vec); + + auto num_features = cut_ptrs_.size() - 1; + auto num_samples = bin_idx_.size() / num_features; + if (debug_) { + std::cout << "Samples: " << num_samples << " Features: " << num_features << std::endl; + } + + std::vector bins; + if (data_set_id == kDataSetAggregationWithFeatures) { + if (features_.empty()) { // when is it not empty? + for (int64_t f = 0; f < num_features; f++) { + auto slot = bin_idx_[f]; + if (slot >= 0) { + // what happens if it's missing? 
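+          // A plausible reading of this block (the original comments above leave it open):
+          // the first sample's bin indices are used to detect which feature columns this
+          // party actually holds; a negative slot marks a feature with no local value, so it
+          // is skipped and only locally available feature indices are encoded and sent.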
+ features_.push_back(f); + } + } + } + encoder.AddIntArray(features_); + + for (int i = 0; i < num_samples; i++) { + for (auto f : features_) { + auto index = f + i * num_features; + if (index > bin_idx_.size()) { + throw std::out_of_range{"Index is out of range: " + + std::to_string(index)}; + } + auto slot = bin_idx_[index]; + bins.push_back(slot); + } + } + encoder.AddIntArray(bins); + } + + // Add nodes to build + std::vector node_vec(len); + for (std::size_t i = 0; i < len; i++) { + node_vec[i] = nidx[i]; + } + encoder.AddIntArray(node_vec); + + // For each node, get the row_id/slot pair + auto row_ids = std::vector>(len); + for (std::size_t i = 0; i < len; ++i) { + auto& rows = row_ids[i]; + rows.resize(sizes[i]); + for (std::size_t j = 0; j < sizes[i]; j++) { + rows[j] = static_cast(ridx[i][j]); + } + encoder.AddIntArray(rows); + } + + std::size_t n{0}; + auto buffer = encoder.Finish(n); + if (debug_) { + std::cout << "Finished size: " << n << std::endl; + } + + // XGBoost doesn't allow the change of allgatherV sizes. Make sure it's big + // enough to carry histograms + auto max_slot = cut_ptrs_.back(); + auto histo_size = 2 * max_slot * sizeof(double) * len + 1024*1024; // 1M is DAM overhead + auto buf_size = histo_size > n ? histo_size : n; + + // Copy to an array so the buffer can be freed, should change encoder to return vector + buffer_.resize(buf_size); + std::copy_n(buffer, n, buffer_.begin()); + free(buffer); + + *out_hist = buffer_.data(); + *out_len = buffer_.size(); +} + +void NvflarePlugin::SyncEncryptedHistVert(std::uint8_t *buffer, + std::size_t buf_size, + double **out, + std::size_t *out_len) { + if (debug_) { + std::cout << Ident() << " NvflarePlugin::SyncEncryptedHistVert called with buffer size: " << buf_size << std::endl; + } + + auto remaining = buf_size; + char *pointer = reinterpret_cast(buffer); + + // The buffer is concatenated by AllGather. It may contain multiple DAM buffers + std::vector &result = histo_; + result.clear(); + auto max_slot = cut_ptrs_.back(); + auto array_size = 2 * max_slot * sizeof(double); + + // A new histogram array? + auto slots = static_cast(malloc(array_size)); + while (remaining > kPrefixLen) { + DamDecoder decoder(reinterpret_cast(pointer), remaining, false, dam_debug_); + if (!decoder.IsValid()) { + std::cout << "Not DAM encoded buffer ignored at offset: " + << static_cast((pointer - reinterpret_cast(buffer))) << std::endl; + break; + } + auto size = decoder.Size(); + auto node_list = decoder.DecodeIntArray(); + if (debug_) { + std::cout << "Number of nodes: " << node_list.size() << " Histo size: " << 2*max_slot << std::endl; + } + for ([[maybe_unused]] auto node : node_list) { + std::memset(slots, 0, array_size); + auto feature_list = decoder.DecodeIntArray(); + // Convert per-feature histo to a flat one + for (auto f : feature_list) { + auto base = cut_ptrs_[f]; // cut pointer for the current feature + auto bins = decoder.DecodeFloatArray(); + auto n = bins.size() / 2; + for (int i = 0; i < n; i++) { + auto index = base + i; + // [Q] Build local histogram? Why does it need to be built here? 
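+          // Likely answer to the question above (an assumption, not confirmed in this PR):
+          // each DAM buffer carries per-feature (G, H) bins for every node, so accumulating
+          // bins[i] at cut_ptrs_[f] + i flattens the per-feature layout into the single flat
+          // per-node histogram that XGBoost expects before it is appended to the result.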
+ slots[2 * index] += bins[2 * i]; + slots[2 * index + 1] += bins[2 * i + 1]; + } + } + result.insert(result.end(), slots, slots + 2 * max_slot); + } + remaining -= size; + pointer += size; + } + free(slots); + + // result is a reference to a histo_ + *out_len = result.size(); + *out = result.data(); + if (debug_) { + std::cout << "Total histogram size: " << *out_len << std::endl; + } +} + +void NvflarePlugin::BuildEncryptedHistHori(double const *in_histogram, + std::size_t len, + std::uint8_t **out_hist, + std::size_t *out_len) { + if (debug_) { + std::cout << Ident() << " NvflarePlugin::BuildEncryptedHistHori called with histo size: " << len << std::endl; + } + + DamEncoder encoder(kDataSetHistograms, false, dam_debug_); + std::vector copy(in_histogram, in_histogram + len); + encoder.AddFloatArray(copy); + + std::size_t size{0}; + auto buffer = encoder.Finish(size); + buffer_.resize(size); + std::copy_n(buffer, size, buffer_.begin()); + free(buffer); + + *out_hist = this->buffer_.data(); + *out_len = this->buffer_.size(); +} + +void NvflarePlugin::SyncEncryptedHistHori(std::uint8_t const *buffer, + std::size_t len, + double **out_hist, + std::size_t *out_len) { + if (debug_) { + std::cout << Ident() << " NvflarePlugin::SyncEncryptedHistHori called with buffer size: " << len << std::endl; + } + + auto remaining = len; + auto pointer = buffer; + + // The buffer is concatenated by AllGather. It may contain multiple DAM buffers + std::vector& result = histo_; + result.clear(); + while (remaining > kPrefixLen) { + DamDecoder decoder(const_cast(pointer), remaining, false, dam_debug_); + if (!decoder.IsValid()) { + std::cout << "Not DAM encoded histogram ignored at offset: " + << static_cast(pointer - buffer) << std::endl; + break; + } + + if (decoder.GetDataSetId() != kDataSetHistogramResult) { + throw std::runtime_error{"Invalid dataset: " + std::to_string(decoder.GetDataSetId())}; + } + + auto size = decoder.Size(); + auto histo = decoder.DecodeFloatArray(); + result.insert(result.end(), histo.cbegin(), histo.cend()); + + remaining -= size; + pointer += size; + } + + *out_hist = result.data(); + *out_len = result.size(); +} + +} // namespace nvflare diff --git a/integration/xgboost/encryption_plugins/src/plugins/pass_thru_plugin.cc b/integration/xgboost/encryption_plugins/src/plugins/pass_thru_plugin.cc new file mode 100644 index 0000000000..4a29d0ed2b --- /dev/null +++ b/integration/xgboost/encryption_plugins/src/plugins/pass_thru_plugin.cc @@ -0,0 +1,130 @@ +/** + * Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+#include <algorithm>
+#include <iostream>
+
+#include "pass_thru_plugin.h"
+#include "data_set_ids.h"
+
+namespace nvflare {
+
+void PassThruPlugin::BuildEncryptedHistHori(const double *in_histogram, std::size_t len,
+                                            std::uint8_t **out_hist, std::size_t *out_len) {
+  if (debug_) {
+    std::cout << Ident() << " PassThruPlugin::BuildEncryptedHistHori called with " << len << " entries" << std::endl;
+  }
+
+  DamEncoder encoder(kDataSetHistogramResult, true, dam_debug_);
+  auto array = std::vector<double>(in_histogram, in_histogram + len);
+  encoder.AddFloatArray(array);
+  std::size_t size;
+  auto buffer = encoder.Finish(size);
+  buffer_.resize(size);
+  std::copy_n(buffer, size, buffer_.begin());
+  free(buffer);
+  *out_hist = buffer_.data();
+  *out_len = buffer_.size();
+}
+
+void PassThruPlugin::SyncEncryptedHistHori(const std::uint8_t *buffer, std::size_t len,
+                                           double **out_hist, std::size_t *out_len) {
+  if (debug_) {
+    std::cout << Ident() << " PassThruPlugin::SyncEncryptedHistHori called with buffer size: " << len << std::endl;
+  }
+
+  auto remaining = len;
+  auto pointer = buffer;
+
+  // The buffer is concatenated by AllGather. It may contain multiple DAM buffers
+  std::vector<double> &result = histo_;
+  result.clear();
+  while (remaining > kPrefixLen) {
+    DamDecoder decoder(const_cast<std::uint8_t *>(pointer), remaining, true, dam_debug_);
+    if (!decoder.IsValid()) {
+      std::cout << "Not DAM encoded histogram ignored at offset: "
+                << static_cast<int>(pointer - buffer) << std::endl;
+      break;
+    }
+    auto size = decoder.Size();
+    auto histo = decoder.DecodeFloatArray();
+    result.insert(result.end(), histo.cbegin(), histo.cend());
+
+    remaining -= size;
+    pointer += size;
+  }
+
+  *out_hist = result.data();
+  *out_len = result.size();
+}
+
+Buffer PassThruPlugin::EncryptVector(const std::vector<double> &cleartext) {
+  if (debug_ && cleartext.size() > 2) {
+    std::cout << "PassThruPlugin::EncryptVector called with cleartext size: " << cleartext.size() << std::endl;
+  }
+
+  size_t size = cleartext.size() * sizeof(double);
+  auto buf = static_cast<std::uint8_t *>(malloc(size));
+  std::copy_n(reinterpret_cast<const std::uint8_t *>(cleartext.data()), size, buf);
+
+  return {buf, size, true};
+}
+
+std::vector<double> PassThruPlugin::DecryptVector(const std::vector<Buffer> &ciphertext) {
+  if (debug_) {
+    std::cout << "PassThruPlugin::DecryptVector with ciphertext size: " << ciphertext.size() << std::endl;
+  }
+
+  std::vector<double> result;
+
+  for (auto const &v : ciphertext) {
+    size_t n = v.buf_size / sizeof(double);
+    auto p = static_cast<double *>(v.buffer);
+    for (int i = 0; i < n; i++) {
+      result.push_back(p[i]);
+    }
+  }
+
+  return result;
+}
+
+std::map<int, Buffer> PassThruPlugin::AddGHPairs(const std::map<int, std::vector<int>> &sample_ids) {
+  if (debug_) {
+    std::cout << "PassThruPlugin::AddGHPairs called with " << sample_ids.size() << " slots" << std::endl;
+  }
+
+  // Can't do this in real plugin. It needs to be broken into encrypted parts
+  auto gh_pairs = DecryptVector(std::vector<Buffer>{Buffer(encrypted_gh_.data(), encrypted_gh_.size())});
+
+  auto result = std::map<int, Buffer>();
+  for (auto const &entry : sample_ids) {
+    auto rows = entry.second;
+    double g = 0.0;
+    double h = 0.0;
+
+    for (auto row : rows) {
+      g += gh_pairs[2 * row];
+      h += gh_pairs[2 * row + 1];
+    }
+    // In real plugin, the sum should be still in encrypted state.
No need to do this step + auto encrypted_sum = EncryptVector(std::vector{g, h}); + // print_buffer(reinterpret_cast(encrypted_sum.buffer), encrypted_sum.buf_size); + result.insert({entry.first, encrypted_sum}); + } + + return result; +} + +} // namespace nvflare diff --git a/integration/xgboost/encryption_plugins/src/plugins/plugin_main.cc b/integration/xgboost/encryption_plugins/src/plugins/plugin_main.cc new file mode 100644 index 0000000000..4c1d43a6f8 --- /dev/null +++ b/integration/xgboost/encryption_plugins/src/plugins/plugin_main.cc @@ -0,0 +1,184 @@ +/** + * Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include // for shared_ptr +#include // for invalid_argument +#include // for string_view +#include // for vector +#include // for transform + +#include "delegated_plugin.h" + +// Opaque pointer type for the C API. +typedef void *FederatedPluginHandle; // NOLINT + +namespace nvflare { +namespace { +// The opaque type for the C handle. +using CHandleT = std::shared_ptr *; +// Actual representation used in C++ code base. +using HandleT = std::remove_pointer_t; + +std::string &GlobalErrorMsg() { + static thread_local std::string msg; + return msg; +} + +// Perform handle handling for C API functions. 
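+// The guard converts the opaque handle back to its shared_ptr representation, invokes the
+// given callable on the plugin instance, and maps any thrown exception to a non-zero return
+// code while recording the message for FederatedPluginErrorMsg(). All exported
+// FederatedPlugin* functions below wrap their plugin calls with this guard.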
+template auto CApiGuard(FederatedPluginHandle handle, Fn &&fn) { + auto pptr = static_cast(handle); + if (!pptr) { + return 1; + } + + try { + if constexpr (std::is_void_v>) { + fn(*pptr); + return 0; + } else { + return fn(*pptr); + } + } catch (std::exception const &e) { + GlobalErrorMsg() = e.what(); + return 1; + } +} +} // namespace +} // namespace nvflare + +#if defined(_MSC_VER) || defined(_WIN32) +#define NVF_C __declspec(dllexport) +#else +#define NVF_C __attribute__((visibility("default"))) +#endif // defined(_MSC_VER) || defined(_WIN32) + +extern "C" { +NVF_C char const *FederatedPluginErrorMsg() { + return nvflare::GlobalErrorMsg().c_str(); +} + +FederatedPluginHandle NVF_C FederatedPluginCreate(int argc, char const **argv) { + // std::cout << "==== FedreatedPluginCreate called with argc=" << argc << std::endl; + using namespace nvflare; + try { + auto pptr = new std::shared_ptr; + std::vector> args; + std::transform( + argv, argv + argc, std::back_inserter(args), [](char const *carg) { + // Split a key value pair in contructor argument: `key=value` + std::string_view arg{carg}; + auto idx = arg.find('='); + if (idx == std::string_view::npos) { + // `=` not found + throw std::invalid_argument{"Invalid argument:" + std::string{arg}}; + } + auto key = arg.substr(0, idx); + auto value = arg.substr(idx + 1); + return std::make_pair(key, value); + }); + *pptr = std::make_shared(args); + // std::cout << "==== Plugin created: " << pptr << std::endl; + return pptr; + } catch (std::exception const &e) { + // std::cout << "==== Create exception " << e.what() << std::endl; + GlobalErrorMsg() = e.what(); + return nullptr; + } +} + +int NVF_C FederatedPluginClose(FederatedPluginHandle handle) { + using namespace nvflare; + auto pptr = static_cast(handle); + if (!pptr) { + return 1; + } + + delete pptr; + + return 0; +} + +int NVF_C FederatedPluginEncryptGPairs(FederatedPluginHandle handle, + float const *in_gpair, size_t n_in, + uint8_t **out_gpair, size_t *n_out) { + using namespace nvflare; + return CApiGuard(handle, [&](HandleT const &plugin) { + plugin->EncryptGPairs(in_gpair, n_in, out_gpair, n_out); + return 0; + }); +} + +int NVF_C FederatedPluginSyncEncryptedGPairs(FederatedPluginHandle handle, + uint8_t const *in_gpair, + size_t n_bytes, + uint8_t const **out_gpair, + size_t *n_out) { + using namespace nvflare; + return CApiGuard(handle, [&](HandleT const &plugin) { + plugin->SyncEncryptedGPairs(in_gpair, n_bytes, out_gpair, n_out); + }); +} + +int NVF_C FederatedPluginResetHistContextVert(FederatedPluginHandle handle, + uint32_t const *cutptrs, + size_t cutptr_len, + int32_t const *bin_idx, + size_t n_idx) { + using namespace nvflare; + return CApiGuard(handle, [&](HandleT const &plugin) { + plugin->ResetHistContext(cutptrs, cutptr_len, bin_idx, n_idx); + }); +} + +int NVF_C FederatedPluginBuildEncryptedHistVert( + FederatedPluginHandle handle, uint64_t const **ridx, size_t const *sizes, + int32_t const *nidx, size_t len, uint8_t **out_hist, size_t *out_len) { + using namespace nvflare; + return CApiGuard(handle, [&](HandleT const &plugin) { + plugin->BuildEncryptedHistVert(ridx, sizes, nidx, len, out_hist, out_len); + }); +} + +int NVF_C FederatedPluginSyncEncryptedHistVert(FederatedPluginHandle handle, + uint8_t *in_hist, size_t len, + double **out_hist, + size_t *out_len) { + using namespace nvflare; + return CApiGuard(handle, [&](HandleT const &plugin) { + plugin->SyncEncryptedHistVert(in_hist, len, out_hist, out_len); + }); +} + +int NVF_C 
FederatedPluginBuildEncryptedHistHori(FederatedPluginHandle handle, + double const *in_hist, + size_t len, uint8_t **out_hist, + size_t *out_len) { + using namespace nvflare; + return CApiGuard(handle, [&](HandleT const &plugin) { + plugin->BuildEncryptedHistHori(in_hist, len, out_hist, out_len); + }); +} + +int NVF_C FederatedPluginSyncEncryptedHistHori(FederatedPluginHandle handle, + uint8_t const *in_hist, + size_t len, double **out_hist, + size_t *out_len) { + using namespace nvflare; + return CApiGuard(handle, [&](HandleT const &plugin) { + plugin->SyncEncryptedHistHori(in_hist, len, out_hist, out_len); + return 0; + }); +} +} // extern "C" diff --git a/integration/xgboost/encryption_plugins/src/plugins/util.cc b/integration/xgboost/encryption_plugins/src/plugins/util.cc new file mode 100644 index 0000000000..a0cbd922d4 --- /dev/null +++ b/integration/xgboost/encryption_plugins/src/plugins/util.cc @@ -0,0 +1,99 @@ +/** + * Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include +#include +#include "util.h" + + +constexpr double kScaleFactor = 1000000.0; + +std::vector> distribute_work(size_t num_jobs, size_t const num_workers) { + std::vector> result; + auto num = num_jobs / num_workers; + auto remainder = num_jobs % num_workers; + int start = 0; + for (int i = 0; i < num_workers; i++) { + auto stop = static_cast((start + num - 1)); + if (i < remainder) { + // If jobs cannot be evenly distributed, first few workers take an extra one + stop += 1; + } + + if (start <= stop) { + result.emplace_back(start, stop); + } + start = stop + 1; + } + + // Verify all jobs are distributed + int sum = 0; + for (auto &item: result) { + sum += item.second - item.first + 1; + } + + if (sum != num_jobs) { + std::cout << "Distribution error" << std::endl; + } + + return result; +} + +uint32_t to_int(double d) { + auto int_val = static_cast(d * kScaleFactor); + return static_cast(int_val); +} + +double to_double(uint32_t i) { + auto int_val = static_cast(i); + return static_cast(int_val / kScaleFactor); +} + +std::string get_string(std::vector> const &args, + std::string_view const &key, std::string_view const default_value) { + + auto it = find_if( + args.begin(), args.end(), + [key](const auto &p) { return p.first == key; }); + + if (it != args.end()) { + return std::string{it->second}; + } + + return std::string{default_value}; +} + +bool get_bool(std::vector> const &args, + const std::string &key, bool default_value) { + std::string value = get_string(args, key, ""); + if (value.empty()) { + return default_value; + } + std::transform(value.begin(), value.end(), value.begin(), [](unsigned char c) { return std::tolower(c); }); + auto true_values = std::set < std::string_view > {"true", "yes", "y", "on", "1"}; + return true_values.count(value) > 0; +} + +int get_int(std::vector> const &args, + const std::string &key, int default_value) { + + auto value = get_string(args, key, ""); + if (value.empty()) { + return 
default_value; + } + + return stoi(value, nullptr); +} diff --git a/integration/xgboost/encryption_plugins/tests/CMakeLists.txt b/integration/xgboost/encryption_plugins/tests/CMakeLists.txt new file mode 100644 index 0000000000..04580bdd59 --- /dev/null +++ b/integration/xgboost/encryption_plugins/tests/CMakeLists.txt @@ -0,0 +1,14 @@ +file(GLOB_RECURSE TEST_SOURCES "*.cc") + +target_sources(xgb_nvflare_test PRIVATE ${TEST_SOURCES}) + +target_include_directories(xgb_nvflare_test + PRIVATE + ${GTEST_INCLUDE_DIRS} + ${xgb_nvflare_SOURCE_DIR/tests} + ${xgb_nvflare_SOURCE_DIR}/src) + +message("Include Dir: ${GTEST_INCLUDE_DIRS}") +target_link_libraries(xgb_nvflare_test + PRIVATE + ${GTEST_LIBRARIES}) diff --git a/integration/xgboost/processor/tests/test_dam.cc b/integration/xgboost/encryption_plugins/tests/test_dam.cc similarity index 65% rename from integration/xgboost/processor/tests/test_dam.cc rename to integration/xgboost/encryption_plugins/tests/test_dam.cc index 5573d5440d..345978b110 100644 --- a/integration/xgboost/processor/tests/test_dam.cc +++ b/integration/xgboost/encryption_plugins/tests/test_dam.cc @@ -19,20 +19,45 @@ TEST(DamTest, TestEncodeDecode) { double float_array[] = {1.1, 1.2, 1.3, 1.4}; int64_t int_array[] = {123, 456, 789}; + char buf1[] = "short"; + char buf2[] = "very long"; + DamEncoder encoder(123); + auto b1 = Buffer(buf1, strlen(buf1)); + auto b2 = Buffer(buf2, strlen(buf2)); + encoder.AddBuffer(b1); + encoder.AddBuffer(b2); + + std::vector b{b1, b2}; + encoder.AddBufferArray(b); + auto f = std::vector(float_array, float_array + 4); encoder.AddFloatArray(f); + auto i = std::vector(int_array, int_array + 3); encoder.AddIntArray(i); + size_t size; auto buf = encoder.Finish(size); std::cout << "Encoded size is " << size << std::endl; - DamDecoder decoder(buf.data(), size); + // Decoding test + DamDecoder decoder(buf, size); EXPECT_EQ(decoder.IsValid(), true); EXPECT_EQ(decoder.GetDataSetId(), 123); + auto new_buf1 = decoder.DecodeBuffer(); + EXPECT_EQ(0, memcmp(new_buf1.buffer, buf1, new_buf1.buf_size)); + + auto new_buf2 = decoder.DecodeBuffer(); + EXPECT_EQ(0, memcmp(new_buf2.buffer, buf2, new_buf2.buf_size)); + + auto buf_vec = decoder.DecodeBufferArray(); + EXPECT_EQ(2, buf_vec.size()); + EXPECT_EQ(0, memcmp(buf_vec[0].buffer, buf1, buf_vec[0].buf_size)); + EXPECT_EQ(0, memcmp(buf_vec[1].buffer, buf2, buf_vec[1].buf_size)); + auto float_vec = decoder.DecodeFloatArray(); EXPECT_EQ(0, memcmp(float_vec.data(), float_array, float_vec.size()*8)); diff --git a/integration/xgboost/processor/tests/test_main.cc b/integration/xgboost/encryption_plugins/tests/test_main.cc similarity index 100% rename from integration/xgboost/processor/tests/test_main.cc rename to integration/xgboost/encryption_plugins/tests/test_main.cc diff --git a/integration/xgboost/processor/tests/test_tenseal.py b/integration/xgboost/encryption_plugins/tests/test_tenseal.py similarity index 100% rename from integration/xgboost/processor/tests/test_tenseal.py rename to integration/xgboost/encryption_plugins/tests/test_tenseal.py diff --git a/integration/xgboost/processor/CMakeLists.txt b/integration/xgboost/processor/CMakeLists.txt deleted file mode 100644 index 056fd365e2..0000000000 --- a/integration/xgboost/processor/CMakeLists.txt +++ /dev/null @@ -1,46 +0,0 @@ -cmake_minimum_required(VERSION 3.19) -project(proc_nvflare LANGUAGES CXX C VERSION 1.0) -set(CMAKE_CXX_STANDARD 17) -set(CMAKE_BUILD_TYPE Debug) - -option(GOOGLE_TEST "Build google tests" OFF) - -file(GLOB_RECURSE LIB_SRC "src/*.cc") - 
-add_library(proc_nvflare SHARED ${LIB_SRC}) -set_target_properties(proc_nvflare PROPERTIES - CXX_STANDARD 17 - CXX_STANDARD_REQUIRED ON - POSITION_INDEPENDENT_CODE ON - ENABLE_EXPORTS ON -) -target_include_directories(proc_nvflare PRIVATE ${proc_nvflare_SOURCE_DIR}/src/include) - -if (APPLE) - add_link_options("LINKER:-object_path_lto,$_lto.o") - add_link_options("LINKER:-cache_path_lto,${CMAKE_BINARY_DIR}/LTOCache") -endif () - -#-- Unit Tests -if(GOOGLE_TEST) - find_package(GTest REQUIRED) - enable_testing() - add_executable(proc_test) - target_link_libraries(proc_test PRIVATE proc_nvflare) - - - target_include_directories(proc_test PRIVATE ${proc_nvflare_SOURCE_DIR}/src/include - ${XGB_SRC}/src - ${XGB_SRC}/rabit/include - ${XGB_SRC}/include - ${XGB_SRC}/dmlc-core/include - ${XGB_SRC}/tests) - - add_subdirectory(${proc_nvflare_SOURCE_DIR}/tests) - - add_test( - NAME TestProcessor - COMMAND proc_test - WORKING_DIRECTORY ${proc_nvflare_BINARY_DIR}) - -endif() diff --git a/integration/xgboost/processor/README.md b/integration/xgboost/processor/README.md deleted file mode 100644 index e879081b84..0000000000 --- a/integration/xgboost/processor/README.md +++ /dev/null @@ -1,11 +0,0 @@ -# Build Instruction - -``` sh -cd NVFlare/integration/xgboost/processor -mkdir build -cd build -cmake .. -make -``` - -See [tests](./tests) for simple examples. \ No newline at end of file diff --git a/integration/xgboost/processor/src/dam/dam.cc b/integration/xgboost/processor/src/dam/dam.cc deleted file mode 100644 index 10625ab9b5..0000000000 --- a/integration/xgboost/processor/src/dam/dam.cc +++ /dev/null @@ -1,146 +0,0 @@ -/** - * Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -#include -#include -#include "dam.h" - -void print_buffer(uint8_t *buffer, int size) { - for (int i = 0; i < size; i++) { - auto c = buffer[i]; - std::cout << std::hex << (int) c << " "; - } - std::cout << std::endl << std::dec; -} - -// DamEncoder ====== -void DamEncoder::AddFloatArray(const std::vector &value) { - if (encoded) { - std::cout << "Buffer is already encoded" << std::endl; - return; - } - auto buf_size = value.size() * 8; - uint8_t *buffer = static_cast(malloc(buf_size)); - memcpy(buffer, value.data(), buf_size); - entries->push_back(new Entry(kDataTypeFloatArray, buffer, value.size())); -} - -void DamEncoder::AddIntArray(const std::vector &value) { - std::cout << "AddIntArray called, size: " << value.size() << std::endl; - if (encoded) { - std::cout << "Buffer is already encoded" << std::endl; - return; - } - auto buf_size = value.size()*8; - std::cout << "Allocating " << buf_size << " bytes" << std::endl; - uint8_t *buffer = static_cast(malloc(buf_size)); - memcpy(buffer, value.data(), buf_size); - // print_buffer(buffer, buf_size); - entries->push_back(new Entry(kDataTypeIntArray, buffer, value.size())); -} - -std::vector DamEncoder::Finish(size_t &size) { - encoded = true; - - size = calculate_size(); - std::vector buf(size); - auto pointer = buf.data(); - memcpy(pointer, kSignature, strlen(kSignature)); - memcpy(pointer + 8, &size, 8); - memcpy(pointer + 16, &data_set_id, 8); - - pointer += kPrefixLen; - for (auto entry : *entries) { - memcpy(pointer, &entry->data_type, 8); - pointer += 8; - memcpy(pointer, &entry->size, 8); - pointer += 8; - int len = 8*entry->size; - memcpy(pointer, entry->pointer, len); - free(entry->pointer); - pointer += len; - // print_buffer(entry->pointer, entry->size*8); - } - - if ((pointer - buf.data()) != size) { - throw std::runtime_error{"Invalid encoded size: " + - std::to_string(pointer - buf.data())}; - } - - return buf; -} - -std::size_t DamEncoder::calculate_size() { - auto size = kPrefixLen; - - for (auto entry : *entries) { - size += 16; // The Type and Len - size += entry->size * 8; // All supported data types are 8 bytes - } - - return size; -} - - -// DamDecoder ====== - -DamDecoder::DamDecoder(std::uint8_t const *buffer, std::size_t size) { - this->buffer = buffer; - this->buf_size = size; - this->pos = buffer + kPrefixLen; - if (size >= kPrefixLen) { - memcpy(&len, buffer + 8, 8); - memcpy(&data_set_id, buffer + 16, 8); - } else { - len = 0; - data_set_id = 0; - } -} - -bool DamDecoder::IsValid() { - return buf_size >= kPrefixLen && memcmp(buffer, kSignature, strlen(kSignature)) == 0; -} - -std::vector DamDecoder::DecodeIntArray() { - auto type = *reinterpret_cast(pos); - if (type != kDataTypeIntArray) { - std::cout << "Data type " << type << " doesn't match Int Array" - << std::endl; - return std::vector(); - } - pos += 8; - - auto len = *reinterpret_cast(pos); - pos += 8; - auto ptr = reinterpret_cast(pos); - pos += 8 * len; - return std::vector(ptr, ptr + len); -} - -std::vector DamDecoder::DecodeFloatArray() { - auto type = *reinterpret_cast(pos); - if (type != kDataTypeFloatArray) { - std::cout << "Data type " << type << " doesn't match Float Array" << std::endl; - return std::vector(); - } - pos += 8; - - auto len = *reinterpret_cast(pos); - pos += 8; - - auto ptr = reinterpret_cast(pos); - pos += 8*len; - return std::vector(ptr, ptr + len); -} diff --git a/integration/xgboost/processor/src/include/dam.h b/integration/xgboost/processor/src/include/dam.h deleted file mode 100644 index 7afdf983af..0000000000 --- 
a/integration/xgboost/processor/src/include/dam.h +++ /dev/null @@ -1,93 +0,0 @@ -/** - * Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#pragma once -#include -#include // for int64_t -#include // for size_t - -const char kSignature[] = "NVDADAM1"; // DAM (Direct Accessible Marshalling) V1 -const int kPrefixLen = 24; - -const int kDataTypeInt = 1; -const int kDataTypeFloat = 2; -const int kDataTypeString = 3; -const int kDataTypeIntArray = 257; -const int kDataTypeFloatArray = 258; - -const int kDataTypeMap = 1025; - -class Entry { - public: - int64_t data_type; - uint8_t * pointer; - int64_t size; - - Entry(int64_t data_type, uint8_t *pointer, int64_t size) { - this->data_type = data_type; - this->pointer = pointer; - this->size = size; - } -}; - -class DamEncoder { - private: - bool encoded = false; - int64_t data_set_id; - std::vector *entries = new std::vector(); - - public: - explicit DamEncoder(int64_t data_set_id) { - this->data_set_id = data_set_id; - } - - void AddIntArray(const std::vector &value); - - void AddFloatArray(const std::vector &value); - - std::vector Finish(size_t &size); - - private: - std::size_t calculate_size(); -}; - -class DamDecoder { - private: - std::uint8_t const *buffer = nullptr; - std::size_t buf_size = 0; - std::uint8_t const *pos = nullptr; - std::size_t remaining = 0; - int64_t data_set_id = 0; - int64_t len = 0; - - public: - explicit DamDecoder(std::uint8_t const *buffer, std::size_t size); - - size_t Size() { - return len; - } - - int64_t GetDataSetId() { - return data_set_id; - } - - bool IsValid(); - - std::vector DecodeIntArray(); - - std::vector DecodeFloatArray(); -}; - -void print_buffer(uint8_t *buffer, int size); diff --git a/integration/xgboost/processor/src/nvflare-plugin/nvflare_processor.cc b/integration/xgboost/processor/src/nvflare-plugin/nvflare_processor.cc deleted file mode 100644 index 3e742b14ef..0000000000 --- a/integration/xgboost/processor/src/nvflare-plugin/nvflare_processor.cc +++ /dev/null @@ -1,378 +0,0 @@ -/** - * Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -#include "nvflare_processor.h" - -#include "dam.h" // for DamEncoder -#include -#include // for copy_n, transform -#include // for memcpy -#include // for shared_ptr -#include // for invalid_argument -#include // for string_view -#include // for vector - -namespace nvflare { -namespace { -// The opaque type for the C handle. -using CHandleT = std::shared_ptr *; -// Actual representation used in C++ code base. -using HandleT = std::remove_pointer_t; - -std::string &GlobalErrorMsg() { - static thread_local std::string msg; - return msg; -} - -// Perform handle handling for C API functions. -template auto CApiGuard(FederatedPluginHandle handle, Fn &&fn) { - auto pptr = static_cast(handle); - if (!pptr) { - return 1; - } - - try { - if constexpr (std::is_void_v>) { - fn(*pptr); - return 0; - } else { - return fn(*pptr); - } - } catch (std::exception const &e) { - GlobalErrorMsg() = e.what(); - return 1; - } -} -} // namespace - -TensealPlugin::TensealPlugin( - std::vector> const &args) { - if (!args.empty()) { - throw std::invalid_argument{"Invaid arguments for the tenseal plugin."}; - } -} - -void TensealPlugin::EncryptGPairs(float const *in_gpair, std::size_t n_in, - std::uint8_t **out_gpair, - std::size_t *n_out) { - std::vector pairs(n_in); - std::copy_n(in_gpair, n_in, pairs.begin()); - DamEncoder encoder(kDataSetHGPairs); - encoder.AddFloatArray(pairs); - encrypted_gpairs_ = encoder.Finish(*n_out); - if (!out_gpair) { - throw std::invalid_argument{"Invalid pointer to output gpair."}; - } - *out_gpair = encrypted_gpairs_.data(); - *n_out = encrypted_gpairs_.size(); -} - -void TensealPlugin::SyncEncryptedGPairs(std::uint8_t const *in_gpair, - std::size_t n_bytes, - std::uint8_t const **out_gpair, - std::size_t *out_n_bytes) { - *out_n_bytes = n_bytes; - *out_gpair = in_gpair; -} - -void TensealPlugin::ResetHistContext(std::uint32_t const *cutptrs, - std::size_t cutptr_len, - std::int32_t const *bin_idx, - std::size_t n_idx) { - // fixme: this doesn't have to be copied multiple times. - this->cut_ptrs_.resize(cutptr_len); - std::copy_n(cutptrs, cutptr_len, cut_ptrs_.begin()); - this->bin_idx_.resize(n_idx); - std::copy_n(bin_idx, n_idx, this->bin_idx_.begin()); -} - -void TensealPlugin::BuildEncryptedHistVert(std::size_t const **ridx, - std::size_t const *sizes, - std::int32_t const *nidx, - std::size_t len, - std::uint8_t** out_hist, - std::size_t* out_len) { - std::int64_t data_set_id; - if (!feature_sent_) { - data_set_id = kDataSetAggregationWithFeatures; - feature_sent_ = true; - } else { - data_set_id = kDataSetAggregation; - } - - DamEncoder encoder(data_set_id); - - // Add cuts pointers - std::vector cuts_vec(cut_ptrs_.cbegin(), cut_ptrs_.cend()); - encoder.AddIntArray(cuts_vec); - - auto num_features = cut_ptrs_.size() - 1; - auto num_samples = bin_idx_.size() / num_features; - - if (data_set_id == kDataSetAggregationWithFeatures) { - if (features_.empty()) { // when is it not empty? - for (std::size_t f = 0; f < num_features; f++) { - auto slot = bin_idx_[f]; - if (slot >= 0) { - // what happens if it's missing? 
- features_.push_back(f); - } - } - } - encoder.AddIntArray(features_); - - std::vector bins; - for (int i = 0; i < num_samples; i++) { - for (auto f : features_) { - auto index = f + i * num_features; - if (index > bin_idx_.size()) { - throw std::out_of_range{"Index is out of range: " + - std::to_string(index)}; - } - auto slot = bin_idx_[index]; - bins.push_back(slot); - } - } - encoder.AddIntArray(bins); - } - - // Add nodes to build - std::vector node_vec(len); - std::copy_n(nidx, len, node_vec.begin()); - encoder.AddIntArray(node_vec); - - // For each node, get the row_id/slot pair - for (std::size_t i = 0; i < len; ++i) { - std::vector rows(sizes[i]); - std::copy_n(ridx[i], sizes[i], rows.begin()); - encoder.AddIntArray(rows); - } - - std::size_t n{0}; - encrypted_hist_ = encoder.Finish(n); - - *out_hist = encrypted_hist_.data(); - *out_len = encrypted_hist_.size(); -} - -void TensealPlugin::SyncEncryptedHistVert(std::uint8_t *buffer, - std::size_t buf_size, double **out, - std::size_t *out_len) { - auto remaining = buf_size; - char *pointer = reinterpret_cast(buffer); - - // The buffer is concatenated by AllGather. It may contain multiple DAM - // buffers - std::vector &result = hist_; - result.clear(); - auto max_slot = cut_ptrs_.back(); - auto array_size = 2 * max_slot * sizeof(double); - // A new histogram array? - double *slots = static_cast(malloc(array_size)); - while (remaining > kPrefixLen) { - DamDecoder decoder(reinterpret_cast(pointer), remaining); - if (!decoder.IsValid()) { - std::cout << "Not DAM encoded buffer ignored at offset: " - << static_cast( - (pointer - reinterpret_cast(buffer))) - << std::endl; - break; - } - auto size = decoder.Size(); - auto node_list = decoder.DecodeIntArray(); - for (auto node : node_list) { - std::memset(slots, 0, array_size); - auto feature_list = decoder.DecodeIntArray(); - // Convert per-feature histo to a flat one - for (auto f : feature_list) { - auto base = cut_ptrs_[f]; // cut pointer for the current feature - auto bins = decoder.DecodeFloatArray(); - auto n = bins.size() / 2; - for (int i = 0; i < n; i++) { - auto index = base + i; - // [Q] Build local histogram? Why does it need to be built here? 
- slots[2 * index] += bins[2 * i]; - slots[2 * index + 1] += bins[2 * i + 1]; - } - } - result.insert(result.end(), slots, slots + 2 * max_slot); - } - remaining -= size; - pointer += size; - } - free(slots); - - *out_len = result.size(); - *out = result.data(); -} - -void TensealPlugin::BuildEncryptedHistHori(double const *in_histogram, - std::size_t len, - std::uint8_t **out_hist, - std::size_t *out_len) { - DamEncoder encoder(kDataSetHistograms); - std::vector copy(in_histogram, in_histogram + len); - encoder.AddFloatArray(copy); - - std::size_t size{0}; - this->encrypted_hist_ = encoder.Finish(size); - - *out_hist = this->encrypted_hist_.data(); - *out_len = this->encrypted_hist_.size(); -} - -void TensealPlugin::SyncEncryptedHistHori(std::uint8_t const *buffer, - std::size_t len, double **out_hist, - std::size_t *out_len) { - DamDecoder decoder(reinterpret_cast(buffer), len); - if (!decoder.IsValid()) { - std::cout << "Not DAM encoded buffer, ignored" << std::endl; - } - - if (decoder.GetDataSetId() != kDataSetHistogramResult) { - throw std::runtime_error{"Invalid dataset: " + - std::to_string(decoder.GetDataSetId())}; - } - this->hist_ = decoder.DecodeFloatArray(); - *out_hist = this->hist_.data(); - *out_len = this->hist_.size(); -} -} // namespace nvflare - -#if defined(_MSC_VER) || defined(_WIN32) -#define NVF_C __declspec(dllexport) -#else -#define NVF_C __attribute__((visibility("default"))) -#endif // defined(_MSC_VER) || defined(_WIN32) - -extern "C" { -NVF_C char const *FederatedPluginErrorMsg() { - return nvflare::GlobalErrorMsg().c_str(); -} - -FederatedPluginHandle NVF_C FederatedPluginCreate(int argc, char const **argv) { - using namespace nvflare; - try { - CHandleT pptr = new std::shared_ptr; - std::vector> args; - std::transform( - argv, argv + argc, std::back_inserter(args), [](char const *carg) { - // Split a key value pair in contructor argument: `key=value` - std::string_view arg{carg}; - auto idx = arg.find('='); - if (idx == std::string_view::npos) { - // `=` not found - throw std::invalid_argument{"Invalid argument:" + std::string{arg}}; - } - auto key = arg.substr(0, idx); - auto value = arg.substr(idx + 1); - return std::make_pair(key, value); - }); - *pptr = std::make_shared(args); - return pptr; - } catch (std::exception const &e) { - GlobalErrorMsg() = e.what(); - return nullptr; - } -} - -int NVF_C FederatedPluginClose(FederatedPluginHandle handle) { - using namespace nvflare; - auto pptr = static_cast(handle); - if (!pptr) { - return 1; - } - - try { - delete pptr; - } catch (std::exception const &e) { - GlobalErrorMsg() = e.what(); - return 1; - } - return 0; -} - -int NVF_C FederatedPluginEncryptGPairs(FederatedPluginHandle handle, - float const *in_gpair, size_t n_in, - uint8_t **out_gpair, size_t *n_out) { - using namespace nvflare; - return CApiGuard(handle, [&](HandleT plugin) { - plugin->EncryptGPairs(in_gpair, n_in, out_gpair, n_out); - return 0; - }); -} - -int NVF_C FederatedPluginSyncEncryptedGPairs(FederatedPluginHandle handle, - uint8_t const *in_gpair, - size_t n_bytes, - uint8_t const **out_gpair, - size_t *n_out) { - using namespace nvflare; - return CApiGuard(handle, [&](HandleT plugin) { - plugin->SyncEncryptedGPairs(in_gpair, n_bytes, out_gpair, n_out); - }); -} - -int NVF_C FederatedPluginResetHistContextVert(FederatedPluginHandle handle, - uint32_t const *cutptrs, - size_t cutptr_len, - int32_t const *bin_idx, - size_t n_idx) { - using namespace nvflare; - return CApiGuard(handle, [&](HandleT plugin) { - 
plugin->ResetHistContext(cutptrs, cutptr_len, bin_idx, n_idx); - }); -} - -int NVF_C FederatedPluginBuildEncryptedHistVert( - FederatedPluginHandle handle, uint64_t const **ridx, size_t const *sizes, - int32_t const *nidx, size_t len, uint8_t **out_hist, size_t *out_len) { - using namespace nvflare; - return CApiGuard(handle, [&](HandleT plugin) { - plugin->BuildEncryptedHistVert(ridx, sizes, nidx, len, out_hist, out_len); - }); -} - -int NVF_C FederatedPluginSyncEnrcyptedHistVert(FederatedPluginHandle handle, - uint8_t *in_hist, size_t len, - double **out_hist, - size_t *out_len) { - using namespace nvflare; - return CApiGuard(handle, [&](HandleT plugin) { - plugin->SyncEncryptedHistVert(in_hist, len, out_hist, out_len); - }); -} - -int NVF_C FederatedPluginBuildEncryptedHistHori(FederatedPluginHandle handle, - double const *in_hist, - size_t len, uint8_t **out_hist, - size_t *out_len) { - using namespace nvflare; - return CApiGuard(handle, [&](HandleT plugin) { - plugin->BuildEncryptedHistHori(in_hist, len, out_hist, out_len); - }); -} - -int NVF_C FederatedPluginSyncEnrcyptedHistHori(FederatedPluginHandle handle, - uint8_t const *in_hist, - size_t len, double **out_hist, - size_t *out_len) { - using namespace nvflare; - return CApiGuard(handle, [&](HandleT plugin) { - plugin->SyncEncryptedHistHori(in_hist, len, out_hist, out_len); - return 0; - }); -} -} // extern "C" diff --git a/integration/xgboost/processor/tests/CMakeLists.txt b/integration/xgboost/processor/tests/CMakeLists.txt deleted file mode 100644 index 893d8738dc..0000000000 --- a/integration/xgboost/processor/tests/CMakeLists.txt +++ /dev/null @@ -1,14 +0,0 @@ -file(GLOB_RECURSE TEST_SOURCES "*.cc") - -target_sources(proc_test PRIVATE ${TEST_SOURCES}) - -target_include_directories(proc_test - PRIVATE - ${GTEST_INCLUDE_DIRS} - ${proc_nvflare_SOURCE_DIR/tests} - ${proc_nvflare_SOURCE_DIR}/src) - -message("Include Dir: ${GTEST_INCLUDE_DIRS}") -target_link_libraries(proc_test - PRIVATE - ${GTEST_LIBRARIES}) diff --git a/nvflare/apis/fl_constant.py b/nvflare/apis/fl_constant.py index 9b9b844774..52005ec305 100644 --- a/nvflare/apis/fl_constant.py +++ b/nvflare/apis/fl_constant.py @@ -42,7 +42,6 @@ class ReturnCode(object): EARLY_TERMINATION = "EARLY_TERMINATION" SERVER_NOT_READY = "SERVER_NOT_READY" SERVICE_UNAVAILABLE = "SERVICE_UNAVAILABLE" - EARLY_TERMINATION = "EARLY_TERMINATION" class MachineStatus(Enum): @@ -494,6 +493,7 @@ class SystemVarName: JOB_ID = "JOB_ID" # Job ID ROOT_URL = "ROOT_URL" # the URL of the Service Provider (server) SECURE_MODE = "SECURE_MODE" # whether the system is running in secure mode + JOB_CUSTOM_DIR = "JOB_CUSTOM_DIR" # custom dir of the job class RunnerTask: diff --git a/nvflare/apis/utils/reliable_message.py b/nvflare/apis/utils/reliable_message.py index 71f7365847..bb5e01792c 100644 --- a/nvflare/apis/utils/reliable_message.py +++ b/nvflare/apis/utils/reliable_message.py @@ -98,6 +98,9 @@ def __init__(self, topic, request_handler_f, executor, per_msg_timeout, tx_timeo self.replying = False def process(self, request: Shareable, fl_ctx: FLContext) -> Shareable: + if not ReliableMessage.is_available(): + return make_reply(ReturnCode.SERVICE_UNAVAILABLE) + self.tx_id = request.get_header(HEADER_TX_ID) op = request.get_header(HEADER_OP) peer_ctx = fl_ctx.get_peer_context() @@ -111,9 +114,14 @@ def process(self, request: Shareable, fl_ctx: FLContext) -> Shareable: self.tx_timeout = request.get_header(HEADER_TX_TIMEOUT) # start processing - ReliableMessage.debug(fl_ctx, f"started processing request of 
topic {self.topic}") - self.executor.submit(self._do_request, request, fl_ctx) - return _status_reply(STATUS_IN_PROCESS) # ack + ReliableMessage.info(fl_ctx, f"started processing request of topic {self.topic}") + try: + self.executor.submit(self._do_request, request, fl_ctx) + return _status_reply(STATUS_IN_PROCESS) # ack + except Exception as ex: + # it is possible that the RM is already closed (self.executor is shut down) + ReliableMessage.error(fl_ctx, f"failed to submit request: {secure_format_exception(ex)}") + return make_reply(ReturnCode.SERVICE_UNAVAILABLE) elif self.result: # we already finished processing - send the result back ReliableMessage.info(fl_ctx, "resend result back to requester") @@ -169,6 +177,8 @@ def _try_reply(self, fl_ctx: FLContext): # release the receiver kept by the ReliableMessage! ReliableMessage.release_request_receiver(self, fl_ctx) else: + # unsure whether the reply was sent successfully + # do not release the request receiver in case the requester asks for result in a query ReliableMessage.error( fl_ctx, f"failed to send reply in {time_spent} secs: {rc=}; will wait for requester to query" ) @@ -192,6 +202,8 @@ def _do_request(self, request: Shareable, fl_ctx: FLContext): class _ReplyReceiver: + """This class handles reliable message replies on the sending end""" + def __init__(self, tx_id: str, per_msg_timeout: float, tx_timeout: float): self.tx_id = tx_id self.tx_start_time = time.time() @@ -425,6 +437,21 @@ def warning(cls, fl_ctx: FLContext, msg: str): def error(cls, fl_ctx: FLContext, msg: str): cls._logger.error(cls._log_msg(fl_ctx, msg)) + @classmethod + def is_available(cls): + """Return whether the ReliableMessage service is available + + Returns: + + """ + if cls._shutdown_asked: + return False + + if not cls._enabled: + return False + + return True + @classmethod def debug(cls, fl_ctx: FLContext, msg: str): cls._logger.debug(cls._log_msg(fl_ctx, msg)) @@ -614,7 +641,7 @@ def _query_result( fl_ctx=fl_ctx, ) - # Ignore query result if result is already received + # Ignore query result if reply result is already received if receiver.result_ready.is_set(): return receiver.result diff --git a/nvflare/app_common/metrics_exchange/metrics_sender.py b/nvflare/app_common/metrics_exchange/metrics_sender.py index 9052104a78..f8a4d5861a 100644 --- a/nvflare/app_common/metrics_exchange/metrics_sender.py +++ b/nvflare/app_common/metrics_exchange/metrics_sender.py @@ -31,16 +31,28 @@ def __init__( read_interval: float = 0.1, heartbeat_interval: float = 5.0, heartbeat_timeout: float = 30.0, - topic: str = "metrics", pipe_channel_name=PipeChannelName.METRIC, ): + """MetricsSender is a special type of AnalyticsSender that uses `Pipe` to communicate. + + Args: + pipe_id (str): Identifier for obtaining the Pipe from NVFlare components. + read_interval (float): Interval for reading from the pipe. + heartbeat_interval (float): Interval for sending heartbeat to the peer. + heartbeat_timeout (float): Timeout for waiting for a heartbeat from the peer. + pipe_channel_name: the channel name for sending task requests. + + Note: + Users can use MetricsSender with `FilePipe`, `CellPipe`, or any other customize + `Pipe` class. 
+ + """ super().__init__() self._pipe_id = pipe_id self._read_interval = read_interval self._heartbeat_interval = heartbeat_interval self._heartbeat_timeout = heartbeat_timeout self._pipe_handler = None - self._topic = topic self._pipe_channel_name = pipe_channel_name def handle_event(self, event_type: str, fl_ctx: FLContext): @@ -64,5 +76,5 @@ def handle_event(self, event_type: str, fl_ctx: FLContext): def add(self, tag: str, value: Any, data_type: AnalyticsDataType, **kwargs): data = create_analytic_dxo(tag=tag, value=value, data_type=data_type, **kwargs) - req = Message.new_request(topic=self._topic, data=data) + req = Message.new_request(topic="_metrics_sender", data=data) self._pipe_handler.send_to_peer(req) diff --git a/nvflare/app_common/storages/filesystem_storage.py b/nvflare/app_common/storages/filesystem_storage.py index b1d5547a0f..bfa82654f0 100644 --- a/nvflare/app_common/storages/filesystem_storage.py +++ b/nvflare/app_common/storages/filesystem_storage.py @@ -29,6 +29,20 @@ def _write(path: str, content, mv_file=True): + """Create a file at the specified 'path' with the specified 'content'. + + Args: + path: the path of the file to be created + content: content for the file to be created. It could be either bytes, or path (str) to the source file that + contains the content. + mv_file: whether the destination file should be created simply by moving the source file. This is applicable + only when the 'content' is the path of the source file. If mv_file is False, the destination is created + by copying from the source file, and the source file will remain intact; If mv_file is True, the + destination file is created by "move" the source file, and the original source file will no longer exist. + + Returns: + + """ tmp_path = path + "_" + str(uuid.uuid4()) try: Path(os.path.dirname(path)).mkdir(parents=True, exist_ok=True) @@ -162,7 +176,7 @@ def clone_object(self, from_uri: str, to_uri: str, meta: dict, overwrite_existin from_full_uri = self._object_path(from_uri) from_data_path = os.path.join(from_full_uri, DATA) - _write(data_path, from_data_path) + _write(data_path, from_data_path, mv_file=False) meta_path = os.path.join(full_uri, META) try: diff --git a/nvflare/app_common/tie/__init__.py b/nvflare/app_common/tie/__init__.py new file mode 100644 index 0000000000..d9155f923f --- /dev/null +++ b/nvflare/app_common/tie/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/nvflare/app_common/tie/applet.py b/nvflare/app_common/tie/applet.py new file mode 100644 index 0000000000..228e08a0c8 --- /dev/null +++ b/nvflare/app_common/tie/applet.py @@ -0,0 +1,70 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from abc import ABC, abstractmethod + +from nvflare.apis.fl_component import FLComponent +from nvflare.apis.fl_context import FLContext + + +class Applet(ABC, FLComponent): + + """An Applet implements App (server or client) processing logic.""" + + def __init__(self): + FLComponent.__init__(self) + + def initialize(self, fl_ctx: FLContext): + """Called by Controller/Executor to initialize the applet. + This happens when the job is about to start. + + Args: + fl_ctx: FL context + + Returns: None + + """ + pass + + @abstractmethod + def start(self, app_ctx: dict): + """Called to start the execution of the applet. + + Args: + app_ctx: the contextual info to help the applet execution + + Returns: None + + """ + pass + + @abstractmethod + def stop(self, timeout=0.0) -> int: + """Called to stop the applet. + + Args: + timeout: the max amount of time (seconds) to stop the applet + + Returns: the exit code after stopped + + """ + pass + + @abstractmethod + def is_stopped(self) -> (bool, int): + """Called to check whether the applet is already stopped. + + Returns: whether the applet is stopped, and the exit code if stopped. + + """ + pass diff --git a/nvflare/app_common/tie/cli_applet.py b/nvflare/app_common/tie/cli_applet.py new file mode 100644 index 0000000000..cd17bf0eb7 --- /dev/null +++ b/nvflare/app_common/tie/cli_applet.py @@ -0,0 +1,109 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import time +from abc import ABC, abstractmethod + +from nvflare.security.logging import secure_format_exception + +from .applet import Applet +from .defs import Constant +from .process_mgr import CommandDescriptor, start_process + + +class CLIApplet(Applet, ABC): + def __init__(self): + """Constructor of CLIApplet, which runs the applet as a subprocess started with CLI command.""" + Applet.__init__(self) + self._proc_mgr = None + self._start_error = False + + @abstractmethod + def get_command(self, app_ctx: dict) -> CommandDescriptor: + """Subclass must implement this method to return the CLI command to be executed. + + Args: + app_ctx: the applet context that contains execution env info + + Returns: a CommandDescriptor that describes the CLI command + + """ + pass + + def start(self, app_ctx: dict): + """Start the execution of the applet. 
+ + Args: + app_ctx: the applet run context + + Returns: + + """ + cmd_desc = self.get_command(app_ctx) + if not cmd_desc: + raise RuntimeError("failed to get cli command from app context") + + fl_ctx = app_ctx.get(Constant.APP_CTX_FL_CONTEXT) + try: + self._proc_mgr = start_process(cmd_desc, fl_ctx) + except Exception as ex: + self.logger.error(f"exception starting applet '{cmd_desc.cmd}': {secure_format_exception(ex)}") + self._start_error = True + + def stop(self, timeout=0.0) -> int: + """Stop the applet + + Args: + timeout: amount of time to wait for the applet to stop by itself. If the applet does not stop on + its own within this time, we'll forcefully stop it by kill. + + Returns: exit code + + """ + mgr = self._proc_mgr + self._proc_mgr = None + + if not mgr: + raise RuntimeError("no process manager to stop") + + if timeout > 0: + # wait for the applet to stop by itself + start = time.time() + while time.time() - start < timeout: + rc = mgr.poll() + if rc is not None: + # already stopped + self.logger.info(f"applet stopped ({rc=}) after {time.time()-start} seconds") + break + time.sleep(0.1) + + rc = mgr.stop() + if rc is None: + self.logger.warning(f"killed the applet process after waiting {timeout} seconds") + return -9 + else: + return rc + + def is_stopped(self) -> (bool, int): + if self._start_error: + return True, Constant.EXIT_CODE_CANT_START + + mgr = self._proc_mgr + if mgr: + return_code = mgr.poll() + if return_code is None: + return False, 0 + else: + return True, return_code + else: + return True, 0 diff --git a/nvflare/app_common/tie/connector.py b/nvflare/app_common/tie/connector.py new file mode 100644 index 0000000000..8afa86aedb --- /dev/null +++ b/nvflare/app_common/tie/connector.py @@ -0,0 +1,264 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import threading +import time +from abc import ABC, abstractmethod +from typing import Optional + +from nvflare.apis.fl_component import FLComponent +from nvflare.apis.fl_context import FLContext +from nvflare.apis.shareable import Shareable +from nvflare.apis.signal import Signal +from nvflare.apis.utils.reliable_message import ReliableMessage +from nvflare.app_common.tie.applet import Applet +from nvflare.app_common.tie.defs import Constant +from nvflare.fuel.f3.cellnet.fqcn import FQCN +from nvflare.fuel.utils.validation_utils import check_object_type + + +class Connector(ABC, FLComponent): + """ + Connectors are used to integrate FLARE with an Applet (Server or Client) in run time. + Each type of applet requires an appropriate connector to integrate it with FLARE's Controller or Executor. + The Connector class defines commonly required methods for all Connector implementations. + """ + + def __init__(self): + """Constructor of Connector""" + FLComponent.__init__(self) + self.abort_signal = None + self.applet = None + self.engine = None + + def set_applet(self, applet: Applet): + """Set the applet that will be used to run app processing logic. 
+ Note that the connector is only responsible for starting the applet appropriately (in a separate thread or in a + separate process). + + Args: + applet: the applet to be set + + Returns: None + + """ + if not isinstance(applet, Applet): + raise TypeError(f"applet must be Applet but got {type(applet)}") + self.applet = applet + + def set_abort_signal(self, abort_signal: Signal): + """Called by Controller/Executor to set the abort_signal. + + The abort_signal is assigned by FLARE Controller/Executor. It is used by the Controller/Executor + to tell the connector that the job has been aborted. + + Args: + abort_signal: the abort signal assigned by the caller. + + Returns: None + + """ + check_object_type("abort_signal", abort_signal, Signal) + self.abort_signal = abort_signal + + def initialize(self, fl_ctx: FLContext): + """Called by the Controller/Executor to initialize the connector. + + Args: + fl_ctx: the FL context + + Returns: None + + """ + self.engine = fl_ctx.get_engine() + + @abstractmethod + def start(self, fl_ctx: FLContext): + """Called by Controller/Executor to start the connector. + If any error occurs, this method should raise an exception. + + Args: + fl_ctx: the FL context. + + Returns: None + + """ + pass + + @abstractmethod + def stop(self, fl_ctx: FLContext): + """Called by Controller/Executor to stop the connector. + If any error occurs, this method should raise an exception. + + Args: + fl_ctx: the FL context. + + Returns: None + + """ + pass + + @abstractmethod + def configure(self, config: dict, fl_ctx: FLContext): + """Called by Controller/Executor to configure the connector. + If any error occurs, this method should raise an exception. + + Args: + config: config data + fl_ctx: the FL context + + Returns: None + + """ + pass + + def _is_stopped(self) -> (bool, int): + """Called by the connector's monitor to know whether the connector is stopped. + Note that this method is not called by Controller/Executor. + + Returns: a tuple of: whether the connector is stopped, and return code (if stopped) + + Note that a non-zero return code is considered abnormal completion of the connector. + + """ + return self.is_applet_stopped() + + def _monitor(self, fl_ctx: FLContext, connector_stopped_cb): + while True: + if self.abort_signal.triggered: + # asked to abort + self.stop(fl_ctx) + return + + stopped, rc = self._is_stopped() + if stopped: + # connector already stopped - notify the caller + connector_stopped_cb(rc, fl_ctx) + return + + time.sleep(0.1) + + def monitor(self, fl_ctx: FLContext, connector_stopped_cb): + """Called by Controller/Executor to monitor the health of the connector. + + The monitor periodically checks the abort signal. Once set, it calls the connector's stop() method + to stop the running of the app. + + The monitor also periodically checks whether the connector is already stopped (by calling the is_stopped + method). If the connector is stopped, the monitor will call the specified connector_stopped_cb. + + Args: + fl_ctx: FL context + connector_stopped_cb: the callback function to be called when the connector is stopped. + + Returns: None + + """ + if not callable(connector_stopped_cb): + raise RuntimeError(f"connector_stopped_cb must be callable but got {type(connector_stopped_cb)}") + + # start the monitor in a separate daemon thread! + t = threading.Thread(target=self._monitor, args=(fl_ctx, connector_stopped_cb), daemon=True) + t.start() + + def start_applet(self, app_ctx: dict, fl_ctx: FLContext): + """Start the applet set to the connector. 
+ + Args: + app_ctx: the contextual info for running the applet + fl_ctx: FL context + + Returns: None + + """ + if not self.applet: + raise RuntimeError("applet has not been set!") + + app_ctx[Constant.APP_CTX_FL_CONTEXT] = fl_ctx + self.applet.start(app_ctx) + + def stop_applet(self, timeout=0.0) -> int: + """Stop the running of the applet + + Returns: exit code of the applet + + """ + return self.applet.stop(timeout) + + def is_applet_stopped(self) -> (bool, int): + """Check whether the applet is already stopped + + Returns: a tuple of (whether the applet is stopped, exit code) + + """ + applet = self.applet + if applet: + return applet.is_stopped() + else: + self.logger.warning("applet is not set with the connector") + return True, 0 + + def send_request( + self, + target: Optional[str], + op: str, + request: Shareable, + per_msg_timeout: float, + tx_timeout: float, + fl_ctx: Optional[FLContext], + ) -> Shareable: + """Send app request to the specified target via FLARE ReliableMessage. + + Args: + target: the destination of the request. If not specified, default to server. + op: the operation + request: operation data + per_msg_timeout: per-message timeout + tx_timeout: transaction timeout + fl_ctx: FL context. If not provided, this method will create a new FL context. + + Returns: + operation result + """ + request.set_header(Constant.MSG_KEY_OP, op) + if not target: + target = FQCN.ROOT_SERVER + + if not fl_ctx: + fl_ctx = self.engine.new_context() + + self.logger.debug(f"sending request with RM: {op=}") + return ReliableMessage.send_request( + target=target, + topic=Constant.TOPIC_APP_REQUEST, + request=request, + per_msg_timeout=per_msg_timeout, + tx_timeout=tx_timeout, + abort_signal=self.abort_signal, + fl_ctx=fl_ctx, + ) + + def process_app_request(self, op: str, req: Shareable, fl_ctx: FLContext, abort_signal: Signal) -> Shareable: + """Called by Controller/Executor to process a request from an applet on another site. + + Args: + op: the op code of the request + req: the request to be sent + fl_ctx: FL context + abort_signal: abort signal that could be triggered during the request processing + + Returns: processing result as Shareable object + + """ + pass diff --git a/nvflare/app_common/tie/controller.py b/nvflare/app_common/tie/controller.py new file mode 100644 index 0000000000..0ebb39b93e --- /dev/null +++ b/nvflare/app_common/tie/controller.py @@ -0,0 +1,565 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
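+
+# This module defines TieController, an abstract Controller that obtains a Connector and an
+# Applet from its subclass, configures and starts the participating clients, and monitors
+# the job status while the applet runs.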
+import threading
+import time
+from abc import ABC, abstractmethod
+
+from nvflare.apis.client import Client
+from nvflare.apis.controller_spec import ClientTask, Task
+from nvflare.apis.fl_context import FLContext
+from nvflare.apis.impl.controller import Controller
+from nvflare.apis.shareable import ReturnCode, Shareable, make_reply
+from nvflare.apis.signal import Signal
+from nvflare.apis.utils.reliable_message import ReliableMessage
+from nvflare.app_common.tie.connector import Connector
+from nvflare.fuel.utils.validation_utils import check_number_range, check_positive_number
+from nvflare.security.logging import secure_format_exception
+
+from .applet import Applet
+from .defs import Constant
+
+
+class _ClientStatus:
+    """
+    Objects of this class keep the processing status of each FL client during job execution.
+    """
+
+    def __init__(self):
+        # Set when the client's config reply is received and the reply return code is OK.
+        # If the client failed to reply or the return code is not OK, this value is not set.
+        self.configured_time = None
+
+        # Set when the client's start reply is received and the reply return code is OK.
+        # If the client failed to reply or the return code is not OK, this value is not set.
+        self.started_time = None
+
+        # operation of the last request from this client
+        self.last_op = None
+
+        # time of the last op request from this client
+        self.last_op_time = time.time()
+
+        # whether the app process is finished on this client
+        self.app_done = False
+
+
+class TieController(Controller, ABC):
+    def __init__(
+        self,
+        configure_task_name=Constant.CONFIG_TASK_NAME,
+        configure_task_timeout=Constant.CONFIG_TASK_TIMEOUT,
+        start_task_name=Constant.START_TASK_NAME,
+        start_task_timeout=Constant.START_TASK_TIMEOUT,
+        job_status_check_interval: float = Constant.JOB_STATUS_CHECK_INTERVAL,
+        max_client_op_interval: float = Constant.MAX_CLIENT_OP_INTERVAL,
+        progress_timeout: float = Constant.WORKFLOW_PROGRESS_TIMEOUT,
+    ):
+        """
+        Constructor
+
+        Args:
+            configure_task_name: name of the config task
+            configure_task_timeout: time to wait for clients' responses to the config task before timeout.
+            start_task_name: name of the start task
+            start_task_timeout: time to wait for clients' responses to the start task before timeout.
+            job_status_check_interval: how often to check client statuses of the job
+            max_client_op_interval: max amount of time allowed between app ops from a client
+            progress_timeout: the maximum amount of time allowed for the workflow to not make any progress.
+                In other words, at least one participating client must have made progress during this time.
+                Otherwise, the workflow will be considered to be in trouble and the job will be aborted.
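+
+        Example (illustrative only; a concrete subclass must also implement the abstract getters):
+
+            class MyAppController(TieController):
+                def __init__(self):
+                    super().__init__(
+                        configure_task_timeout=30,
+                        start_task_timeout=30,
+                        progress_timeout=1800.0,
+                    )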
+ """ + Controller.__init__(self) + self.configure_task_name = configure_task_name + self.start_task_name = start_task_name + self.start_task_timeout = start_task_timeout + self.configure_task_timeout = configure_task_timeout + self.max_client_op_interval = max_client_op_interval + self.progress_timeout = progress_timeout + self.job_status_check_interval = job_status_check_interval + + self.connector = None + self.participating_clients = None + self.status_lock = threading.Lock() + self.client_statuses = {} # client name => ClientStatus + self.abort_signal = None + + check_number_range("configure_task_timeout", configure_task_timeout, min_value=1) + check_number_range("start_task_timeout", start_task_timeout, min_value=1) + check_positive_number("job_status_check_interval", job_status_check_interval) + check_number_range("max_client_op_interval", max_client_op_interval, min_value=10.0) + check_number_range("progress_timeout", progress_timeout, min_value=5.0) + + @abstractmethod + def get_client_config_params(self, fl_ctx: FLContext) -> dict: + """Called by the TieController to get config parameters to be sent to FL clients. + Subclass of TieController must implement this method. + + Args: + fl_ctx: FL context + + Returns: a dict of config params + + """ + pass + + @abstractmethod + def get_connector_config_params(self, fl_ctx: FLContext) -> dict: + """Called by the TieController to get config parameters for configuring the connector. + Subclass of TieController must implement this method. + + Args: + fl_ctx: FL context + + Returns: a dict of config params + + """ + pass + + @abstractmethod + def get_connector(self, fl_ctx: FLContext) -> Connector: + """Called by the TieController to get the Connector to be used with the controller. + Subclass of TieController must implement this method. + + Args: + fl_ctx: FL context + + Returns: a Connector object + + """ + pass + + @abstractmethod + def get_applet(self, fl_ctx: FLContext) -> Applet: + """Called by the TieController to get the Applet to be used with the controller. + Subclass of TieController must implement this method. + + Args: + fl_ctx: FL context + + Returns: an Applet object + + """ + pass + + def start_controller(self, fl_ctx: FLContext): + """Start the controller. + It first tries to get the connector and applet to be used. + It then initializes the applet, set the applet to the connector, and initializes the connector. + It finally registers message handlers for APP_REQUEST and CLIENT_DONE. + If error occurs in any step, the job is stopped. + + Note: if a subclass overwrites this method, it must call super().start_controller()! 
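+
+        An override would typically look like this (illustrative):
+
+            def start_controller(self, fl_ctx: FLContext):
+                super().start_controller(fl_ctx)
+                ...  # subclass-specific setup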
+ + Args: + fl_ctx: the FL context + + Returns: None + + """ + all_clients = self._engine.get_clients() + self.participating_clients = [t.name for t in all_clients] + + for c in self.participating_clients: + self.client_statuses[c] = _ClientStatus() + + connector = self.get_connector(fl_ctx) + if not connector: + self.system_panic("cannot get connector", fl_ctx) + return None + + if not isinstance(connector, Connector): + self.system_panic( + f"invalid connector: expect Connector but got {type(connector)}", + fl_ctx, + ) + return None + + applet = self.get_applet(fl_ctx) + if not applet: + self.system_panic("cannot get applet", fl_ctx) + return + + if not isinstance(applet, Applet): + self.system_panic( + f"invalid applet: expect Applet but got {type(applet)}", + fl_ctx, + ) + return + + applet.initialize(fl_ctx) + connector.set_applet(applet) + connector.initialize(fl_ctx) + self.connector = connector + + engine = fl_ctx.get_engine() + engine.register_aux_message_handler( + topic=Constant.TOPIC_CLIENT_DONE, + message_handle_func=self._process_client_done, + ) + ReliableMessage.register_request_handler( + topic=Constant.TOPIC_APP_REQUEST, + handler_f=self._handle_app_request, + fl_ctx=fl_ctx, + ) + + def _trigger_stop(self, fl_ctx: FLContext, error=None): + # first trigger the abort_signal to tell all components (mainly the controller's control_flow and connector) + # that check this signal to abort. + if self.abort_signal: + self.abort_signal.trigger(value=True) + + # if there is error, call system_panic to terminate the job with proper status. + # if no error, the job will end normally. + if error: + self.system_panic(reason=error, fl_ctx=fl_ctx) + + def _is_stopped(self): + # check whether the abort signal is triggered + return self.abort_signal and self.abort_signal.triggered + + def _update_client_status(self, fl_ctx: FLContext, op=None, client_done=False): + """Update the status of the requesting client. + + Args: + fl_ctx: FL context + op: the app operation requested + client_done: whether the client is done + + Returns: None + + """ + with self.status_lock: + peer_ctx = fl_ctx.get_peer_context() + if not peer_ctx: + self.log_error(fl_ctx, "missing peer_ctx from fl_ctx") + return + if not isinstance(peer_ctx, FLContext): + self.log_error(fl_ctx, f"expect peer_ctx to be FLContext but got {type(peer_ctx)}") + return + client_name = peer_ctx.get_identity_name() + if not client_name: + self.log_error(fl_ctx, "missing identity from peer_ctx") + return + status = self.client_statuses.get(client_name) + if not status: + self.log_error(fl_ctx, f"no status record for client {client_name}") + assert isinstance(status, _ClientStatus) + if op: + status.last_op = op + if client_done: + status.app_done = client_done + status.last_op_time = time.time() + + def _process_client_done(self, topic: str, request: Shareable, fl_ctx: FLContext) -> Shareable: + """Process the ClientDone report for a client + + Args: + topic: topic of the message + request: request to be processed + fl_ctx: the FL context + + Returns: reply to the client + + """ + self.log_debug(fl_ctx, f"_process_client_done {topic}") + exit_code = request.get(Constant.MSG_KEY_EXIT_CODE) + + if exit_code == 0: + self.log_info(fl_ctx, f"app client is done with exit code {exit_code}") + elif exit_code == Constant.EXIT_CODE_CANT_START: + self.log_error(fl_ctx, f"app client failed to start (exit code {exit_code})") + self.system_panic("app client failed to start", fl_ctx) + else: + # Should we stop here? 
+ # Problem is that even if the exit_code is not 0, we can't say the job failed. + self.log_warning(fl_ctx, f"app client is done with exit code {exit_code}") + + self._update_client_status(fl_ctx, client_done=True) + return make_reply(ReturnCode.OK) + + def _handle_app_request(self, topic: str, request: Shareable, fl_ctx: FLContext) -> Shareable: + """Handle app request from applets on other sites + It calls the connector to process the app request. If the connector fails to process the request, the + job will be stopped. + + Args: + topic: message topic + request: the request data + fl_ctx: FL context + + Returns: processing result as a Shareable object + + """ + self.log_debug(fl_ctx, f"_handle_app_request {topic}") + op = request.get_header(Constant.MSG_KEY_OP) + if self._is_stopped(): + self.log_warning(fl_ctx, f"dropped app request ({op=}) since server is already stopped") + return make_reply(ReturnCode.SERVICE_UNAVAILABLE) + + # we assume app protocol to be very strict, we'll stop the control flow when any error occurs + process_error = "app request process error" + self._update_client_status(fl_ctx, op=op) + try: + reply = self.connector.process_app_request(op, request, fl_ctx, self.abort_signal) + except Exception as ex: + self.log_exception(fl_ctx, f"exception processing app request {op=}: {secure_format_exception(ex)}") + self._trigger_stop(fl_ctx, process_error) + return make_reply(ReturnCode.EXECUTION_EXCEPTION) + + self.log_info(fl_ctx, f"received reply for app request '{op=}'") + reply.set_header(Constant.MSG_KEY_OP, op) + return reply + + def _configure_clients(self, abort_signal: Signal, fl_ctx: FLContext): + self.log_info(fl_ctx, f"Configuring clients {self.participating_clients}") + + try: + config = self.get_client_config_params(fl_ctx) + except Exception as ex: + self.system_panic(f"exception get_client_config_params: {secure_format_exception(ex)}", fl_ctx) + return False + + if config is None: + self.system_panic("no config data is returned", fl_ctx) + return False + + shareable = Shareable() + shareable[Constant.MSG_KEY_CONFIG] = config + + task = Task( + name=self.configure_task_name, + data=shareable, + timeout=self.configure_task_timeout, + result_received_cb=self._process_configure_reply, + ) + + self.log_info(fl_ctx, f"sending task {self.configure_task_name} to clients {self.participating_clients}") + start_time = time.time() + self.broadcast_and_wait( + task=task, + targets=self.participating_clients, + min_responses=len(self.participating_clients), + fl_ctx=fl_ctx, + abort_signal=abort_signal, + ) + + time_taken = time.time() - start_time + self.log_info(fl_ctx, f"client configuration took {time_taken} seconds") + + failed_clients = [] + for c, cs in self.client_statuses.items(): + assert isinstance(cs, _ClientStatus) + if not cs.configured_time: + failed_clients.append(c) + + # if any client failed to configure, terminate the job + if failed_clients: + self.system_panic(f"failed to configure clients {failed_clients}", fl_ctx) + return False + + self.log_info(fl_ctx, f"successfully configured clients {self.participating_clients}") + return True + + def _start_clients(self, abort_signal: Signal, fl_ctx: FLContext): + self.log_info(fl_ctx, f"Starting clients {self.participating_clients}") + + task = Task( + name=self.start_task_name, + data=Shareable(), + timeout=self.start_task_timeout, + result_received_cb=self._process_start_reply, + ) + + self.log_info(fl_ctx, f"sending task {self.start_task_name} to clients {self.participating_clients}") + start_time = 
time.time() + self.broadcast_and_wait( + task=task, + targets=self.participating_clients, + min_responses=len(self.participating_clients), + fl_ctx=fl_ctx, + abort_signal=abort_signal, + ) + + time_taken = time.time() - start_time + self.log_info(fl_ctx, f"client starting took {time_taken} seconds") + + failed_clients = [] + for c, cs in self.client_statuses.items(): + assert isinstance(cs, _ClientStatus) + if not cs.started_time: + failed_clients.append(c) + + # if any client failed to start, terminate the job + if failed_clients: + self.system_panic(f"failed to start clients {failed_clients}", fl_ctx) + return False + + self.log_info(fl_ctx, f"successfully started clients {self.participating_clients}") + return True + + def control_flow(self, abort_signal: Signal, fl_ctx: FLContext): + """ + To ensure smooth app execution: + - ensure that all clients are online and ready to go before starting server + - ensure that server is started and ready to take requests before asking clients to start operation + - monitor the health of the clients + - if anything goes wrong, terminate the job + + Args: + abort_signal: abort signal that is used to notify components to abort + fl_ctx: FL context + + Returns: None + + """ + self.abort_signal = abort_signal + + # the connector uses the same abort signal! + self.connector.set_abort_signal(abort_signal) + + # wait for every client to become online and properly configured + self.log_info(fl_ctx, f"Waiting for clients to be ready: {self.participating_clients}") + + # configure all clients + if not self._configure_clients(abort_signal, fl_ctx): + self.system_panic("failed to configure all clients", fl_ctx) + return + + # configure and start the connector + try: + config = self.get_connector_config_params(fl_ctx) + self.connector.configure(config, fl_ctx) + self.log_info(fl_ctx, "starting connector ...") + self.connector.start(fl_ctx) + except Exception as ex: + error = f"failed to start connector: {secure_format_exception(ex)}" + self.log_error(fl_ctx, error) + self.system_panic(error, fl_ctx) + return + + self.connector.monitor(fl_ctx, self._app_stopped) + + # start all clients + if not self._start_clients(abort_signal, fl_ctx): + self.system_panic("failed to start all clients", fl_ctx) + return + + # monitor client health + # we periodically check job status until all clients are done or the system is stopped + self.log_info(fl_ctx, "Waiting for clients to finish ...") + while not self._is_stopped(): + done = self._check_job_status(fl_ctx) + if done: + break + time.sleep(self.job_status_check_interval) + + def _app_stopped(self, rc, fl_ctx: FLContext): + # This CB is called when app server is stopped + error = None + if rc != 0: + self.log_error(fl_ctx, f"App Server stopped abnormally with code {rc}") + error = "App server abnormal stop" + + # the app server could stop at any moment, we trigger the abort_signal in case it is checked by any + # other components + self._trigger_stop(fl_ctx, error) + + def _process_configure_reply(self, client_task: ClientTask, fl_ctx: FLContext): + result = client_task.result + client_name = client_task.client.name + + rc = result.get_return_code() + if rc == ReturnCode.OK: + self.log_info(fl_ctx, f"successfully configured client {client_name}") + cs = self.client_statuses.get(client_name) + if cs: + assert isinstance(cs, _ClientStatus) + cs.configured_time = time.time() + else: + self.log_error(fl_ctx, f"client {client_task.client.name} failed to configure: {rc}") + + def _process_start_reply(self, client_task: ClientTask, 
fl_ctx: FLContext): + result = client_task.result + client_name = client_task.client.name + + rc = result.get_return_code() + if rc == ReturnCode.OK: + self.log_info(fl_ctx, f"successfully started client {client_name}") + cs = self.client_statuses.get(client_name) + if cs: + assert isinstance(cs, _ClientStatus) + cs.started_time = time.time() + else: + self.log_error(fl_ctx, f"client {client_name} failed to start") + + def _check_job_status(self, fl_ctx: FLContext) -> bool: + """Check job status and determine whether the job is done. + + Args: + fl_ctx: FL context + + Returns: whether the job is considered done. + + """ + now = time.time() + + # overall_last_progress_time is the latest time that any client made progress. + overall_last_progress_time = 0.0 + clients_done = 0 + for client_name, cs in self.client_statuses.items(): + assert isinstance(cs, _ClientStatus) + + if cs.app_done: + self.log_info(fl_ctx, f"client {client_name} is Done") + clients_done += 1 + elif now - cs.last_op_time > self.max_client_op_interval: + self.system_panic( + f"client {client_name} didn't have any activity for {self.max_client_op_interval} seconds", + fl_ctx, + ) + return True + + if overall_last_progress_time < cs.last_op_time: + overall_last_progress_time = cs.last_op_time + + if clients_done == len(self.client_statuses): + # all clients are done - the job is considered done + return True + elif time.time() - overall_last_progress_time > self.progress_timeout: + # there has been no progress from any client for too long. + # this could be because the clients got stuck. + # consider the job done and abort the job. + self.system_panic(f"the job has no progress for {self.progress_timeout} seconds", fl_ctx) + return True + return False + + def process_result_of_unknown_task( + self, client: Client, task_name: str, client_task_id: str, result: Shareable, fl_ctx: FLContext + ): + self.log_warning(fl_ctx, f"ignored unknown task {task_name} from client {client.name}") + + def stop_controller(self, fl_ctx: FLContext): + """This is called by base controller to stop. + If a subclass overwrites this method, it must call super().stop_controller(fl_ctx). + + Args: + fl_ctx: + + Returns: + + """ + if self.connector: + self.log_info(fl_ctx, "Stopping server connector ...") + self.connector.stop(fl_ctx) + self.log_info(fl_ctx, "Server connector stopped") diff --git a/nvflare/app_common/tie/defs.py b/nvflare/app_common/tie/defs.py new file mode 100644 index 0000000000..b06bb69042 --- /dev/null +++ b/nvflare/app_common/tie/defs.py @@ -0,0 +1,43 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
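+
+# Default task names, timeouts/intervals, message topics, Shareable keys, and exit codes
+# shared by the tie Controller, Executor, and Connector. Most of these defaults can be
+# overridden via constructor arguments of TieController and TieExecutor.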
+ + +class Constant: + + # task name defaults + CONFIG_TASK_NAME = "config" + START_TASK_NAME = "start" + + # default component config values + CONFIG_TASK_TIMEOUT = 10 + START_TASK_TIMEOUT = 10 + + TASK_CHECK_INTERVAL = 0.5 + JOB_STATUS_CHECK_INTERVAL = 2.0 + MAX_CLIENT_OP_INTERVAL = 90.0 + WORKFLOW_PROGRESS_TIMEOUT = 3600.0 + + # message topics + TOPIC_APP_REQUEST = "tie.request" + TOPIC_CLIENT_DONE = "tie.client_done" + + # keys for Shareable between client and server + MSG_KEY_EXIT_CODE = "tie.exit_code" + MSG_KEY_OP = "tie.op" + MSG_KEY_CONFIG = "tie.config" + + EXIT_CODE_CANT_START = 101 + EXIT_CODE_FATAL_ERROR = 102 + + APP_CTX_FL_CONTEXT = "tie.fl_context" diff --git a/nvflare/app_common/tie/executor.py b/nvflare/app_common/tie/executor.py new file mode 100644 index 0000000000..f40bca9898 --- /dev/null +++ b/nvflare/app_common/tie/executor.py @@ -0,0 +1,197 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from abc import abstractmethod + +from nvflare.apis.event_type import EventType +from nvflare.apis.executor import Executor +from nvflare.apis.fl_constant import ReturnCode +from nvflare.apis.fl_context import FLContext +from nvflare.apis.shareable import Shareable, make_reply +from nvflare.apis.signal import Signal +from nvflare.app_common.tie.applet import Applet +from nvflare.app_common.tie.connector import Connector +from nvflare.fuel.f3.cellnet.fqcn import FQCN +from nvflare.security.logging import secure_format_exception + +from .defs import Constant + + +class TieExecutor(Executor): + def __init__( + self, + configure_task_name=Constant.CONFIG_TASK_NAME, + start_task_name=Constant.START_TASK_NAME, + ): + """Constructor + + Args: + configure_task_name: name of the config task + start_task_name: name of the start task + """ + Executor.__init__(self) + self.configure_task_name = configure_task_name + self.start_task_name = start_task_name + self.connector = None + self.engine = None + + # create the abort signal to be used for signaling the connector + self.abort_signal = Signal() + + @abstractmethod + def get_connector(self, fl_ctx: FLContext) -> Connector: + """Called by the TieExecutor to get the Connector to be used by this executor. + A subclass of TieExecutor must implement this method. + + Args: + fl_ctx: the FL context + + Returns: a Connector object + + """ + pass + + @abstractmethod + def get_applet(self, fl_ctx: FLContext) -> Applet: + """Called by the TieExecutor to get the Applet to be used by this executor. + A subclass of TieExecutor must implement this method. + + Args: + fl_ctx: the FL context + + Returns: an Applet object + + """ + pass + + def configure(self, config: dict, fl_ctx: FLContext): + """Called by the TieExecutor to configure the executor based on the config params received from the server. + A subclass of TieExecutor should implement this method. 
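+
+        An override typically just records the values it needs (illustrative; the key name
+        is hypothetical):
+
+            def configure(self, config: dict, fl_ctx: FLContext):
+                self.num_rounds = config.get("num_rounds")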
+ + Args: + config: the config data + fl_ctx: FL context + + Returns: None + + """ + pass + + def get_connector_config(self, fl_ctx: FLContext) -> dict: + """Called by the TieExecutor to get config params for the connector. + A subclass of TieExecutor should implement this method. + Note that this method is always called after the "configure" method, hence it's possible to dynamically + determine the connector's config based on the config params in the "configure" step. + + Args: + fl_ctx: the FL context + + Returns: a dict of config params + + """ + return {} + + def handle_event(self, event_type: str, fl_ctx: FLContext): + if event_type == EventType.START_RUN: + self.engine = fl_ctx.get_engine() + connector = self.get_connector(fl_ctx) + if not connector: + self.system_panic("cannot get connector", fl_ctx) + return + + if not isinstance(connector, Connector): + self.system_panic( + f"invalid connector: expect Connector but got {type(connector)}", + fl_ctx, + ) + return + + applet = self.get_applet(fl_ctx) + if not applet: + self.system_panic("cannot get applet", fl_ctx) + return + + if not isinstance(applet, Applet): + self.system_panic( + f"invalid applet: expect Applet but got {type(applet)}", + fl_ctx, + ) + return + + applet.initialize(fl_ctx) + connector.set_abort_signal(self.abort_signal) + connector.set_applet(applet) + connector.initialize(fl_ctx) + self.connector = connector + elif event_type == EventType.FATAL_SYSTEM_ERROR: + # notify server that the client is done + self._notify_client_done(Constant.EXIT_CODE_FATAL_ERROR, fl_ctx) + elif event_type == EventType.END_RUN: + self.abort_signal.trigger(True) + + def execute(self, task_name: str, shareable: Shareable, fl_ctx: FLContext, abort_signal: Signal) -> Shareable: + if task_name == self.configure_task_name: + config = shareable.get(Constant.MSG_KEY_CONFIG) + if config is None: + self.log_error(fl_ctx, f"missing {Constant.MSG_KEY_CONFIG} from config") + return make_reply(ReturnCode.BAD_TASK_DATA) + + self.configure(config, fl_ctx) + + # configure the connector + connector_config = self.get_connector_config(fl_ctx) + self.connector.configure(connector_config, fl_ctx) + return make_reply(ReturnCode.OK) + elif task_name == self.start_task_name: + # start the connector + try: + self.connector.start(fl_ctx) + except Exception as ex: + self.log_exception(fl_ctx, f"failed to start connector: {secure_format_exception(ex)}") + return make_reply(ReturnCode.EXECUTION_EXCEPTION) + + # start to monitor the connector and applet + self.connector.monitor(fl_ctx, self._notify_client_done) + return make_reply(ReturnCode.OK) + else: + self.log_error(fl_ctx, f"ignored unsupported {task_name}") + return make_reply(ReturnCode.TASK_UNSUPPORTED) + + def _notify_client_done(self, rc, fl_ctx: FLContext): + """This is called when app is done. + We send a message to the FL server telling it that this client is done. 
+ + Args: + rc: the return/exit code + fl_ctx: FL context + + Returns: None + + """ + if rc != 0: + self.log_error(fl_ctx, f"App stopped with RC {rc}") + else: + self.log_info(fl_ctx, "App Stopped") + + # tell server that this client is done + engine = fl_ctx.get_engine() + req = Shareable() + req[Constant.MSG_KEY_EXIT_CODE] = rc + engine.send_aux_request( + targets=[FQCN.ROOT_SERVER], + topic=Constant.TOPIC_CLIENT_DONE, + request=req, + timeout=0, # fire and forget + fl_ctx=fl_ctx, + optional=True, + ) diff --git a/nvflare/app_common/tie/process_mgr.py b/nvflare/app_common/tie/process_mgr.py new file mode 100644 index 0000000000..e8c300925f --- /dev/null +++ b/nvflare/app_common/tie/process_mgr.py @@ -0,0 +1,209 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +import shlex +import subprocess +import sys +import threading + +from nvflare.apis.fl_constant import FLContextKey +from nvflare.apis.fl_context import FLContext +from nvflare.apis.workspace import Workspace +from nvflare.fuel.utils.obj_utils import get_logger +from nvflare.fuel.utils.validation_utils import check_object_type, check_str + + +class CommandDescriptor: + def __init__( + self, + cmd: str, + cwd=None, + env=None, + log_file_name: str = "", + log_stdout: bool = True, + stdout_msg_prefix: str = None, + ): + """Constructor of CommandDescriptor. + A CommandDescriptor describes the requirements of the new process to be started. + + Args: + cmd: the command to be executed to start the new process + cwd: current work dir for the new process + env: system env for the new process + log_file_name: base name of the log file. + log_stdout: whether to output log messages to stdout. + stdout_msg_prefix: prefix to be prepended to log message when writing to stdout. + Since multiple processes could be running within the same terminal window, the prefix can help + differentiate log messages from these processes. + """ + check_str("cmd", cmd) + + if cwd: + check_str("cwd", cwd) + + if env: + check_object_type("env", env, dict) + + if log_file_name: + check_str("log_file_name", log_file_name) + + if stdout_msg_prefix: + check_str("stdout_msg_prefix", stdout_msg_prefix) + + self.cmd = cmd + self.cwd = cwd + self.env = env + self.log_file_name = log_file_name + self.log_stdout = log_stdout + self.stdout_msg_prefix = stdout_msg_prefix + + +class ProcessManager: + def __init__(self, cmd_desc: CommandDescriptor): + """Constructor of ProcessManager. + ProcessManager provides methods for managing the lifecycle of a subprocess (start, stop, poll), as well + as the handling of log file to be used by the subprocess. + + Args: + cmd_desc: the CommandDescriptor that describes the command of the new process to be started + + NOTE: the methods of ProcessManager are not thread safe. 
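+
+        Typical usage is through the start_process() convenience function defined in this
+        module (the command string and file names below are illustrative):
+
+            cmd_desc = CommandDescriptor(
+                cmd="python -u my_app.py",
+                log_file_name="my_app_log.txt",
+                stdout_msg_prefix="MYAPP",
+            )
+            mgr = start_process(cmd_desc, fl_ctx)
+            ...
+            rc = mgr.stop()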
+ + """ + check_object_type("cmd_desc", cmd_desc, CommandDescriptor) + self.process = None + self.cmd_desc = cmd_desc + self.log_file = None + self.msg_prefix = None + self.file_lock = threading.Lock() + self.logger = get_logger(self) + + def start( + self, + fl_ctx: FLContext, + ): + """Start the new process. + + Args: + fl_ctx: FLContext object. + + Returns: None + + """ + job_id = fl_ctx.get_job_id() + + if self.cmd_desc.stdout_msg_prefix: + site_name = fl_ctx.get_identity_name() + self.msg_prefix = f"[{self.cmd_desc.stdout_msg_prefix}@{site_name}]" + + if self.cmd_desc.log_file_name: + ws = fl_ctx.get_prop(FLContextKey.WORKSPACE_OBJECT) + if not isinstance(ws, Workspace): + self.logger.error( + f"FL context prop {FLContextKey.WORKSPACE_OBJECT} should be Workspace but got {type(ws)}" + ) + raise RuntimeError("bad FLContext object") + + run_dir = ws.get_run_dir(job_id) + log_file_path = os.path.join(run_dir, self.cmd_desc.log_file_name) + self.log_file = open(log_file_path, "a") + + env = os.environ.copy() + if self.cmd_desc.env: + env.update(self.cmd_desc.env) + + command_seq = shlex.split(self.cmd_desc.cmd) + self.process = subprocess.Popen( + command_seq, + stderr=subprocess.STDOUT, + cwd=self.cmd_desc.cwd, + env=env, + stdout=subprocess.PIPE, + ) + + log_writer = threading.Thread(target=self._write_log, daemon=True) + log_writer.start() + + def _write_log(self): + # write messages from the process's stdout pipe to log file and sys.stdout. + # note that depending on how the process flushes out its output, the messages may be buffered/delayed. + while True: + line = self.process.stdout.readline() + if not line: + break + + assert isinstance(line, bytes) + line = line.decode("utf-8") + # use file_lock to ensure file integrity since the log file could be closed by the self.stop() method! + with self.file_lock: + if self.log_file: + self.log_file.write(line) + self.log_file.flush() + + if self.cmd_desc.log_stdout: + assert isinstance(line, str) + if self.msg_prefix and not line.startswith("\r"): + line = f"{self.msg_prefix} {line}" + sys.stdout.write(line) + sys.stdout.flush() + + def poll(self): + """Perform a poll request on the process. + + Returns: None if the process is still running; an exit code (int) if process is not running. + + """ + if not self.process: + raise RuntimeError("there is no process to poll") + return self.process.poll() + + def stop(self) -> int: + """Stop the process. + If the process is still running, kill the process. If a log file is open, close the log file. + + Returns: the exit code of the process. If killed, returns -9. + + """ + rc = self.poll() + if rc is None: + # process is still alive + try: + self.process.kill() + rc = -9 + except: + # ignore kill error + pass + + # close the log file if any + with self.file_lock: + if self.log_file: + self.logger.debug("closed subprocess log file!") + self.log_file.close() + self.log_file = None + return rc + + +def start_process(cmd_desc: CommandDescriptor, fl_ctx: FLContext) -> ProcessManager: + """Convenience function for starting a subprocess. + + Args: + cmd_desc: the CommandDescriptor the describes the command to be executed + fl_ctx: FLContext object + + Returns: a ProcessManager object. + + """ + mgr = ProcessManager(cmd_desc) + mgr.start(fl_ctx) + return mgr diff --git a/nvflare/app_common/tie/py_applet.py b/nvflare/app_common/tie/py_applet.py new file mode 100644 index 0000000000..bd47ce6261 --- /dev/null +++ b/nvflare/app_common/tie/py_applet.py @@ -0,0 +1,240 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. 
All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import multiprocessing +import os +import sys +import threading +import time +from abc import ABC, abstractmethod + +from nvflare.apis.workspace import Workspace +from nvflare.fuel.utils.log_utils import add_log_file_handler, configure_logging +from nvflare.security.logging import secure_format_exception, secure_log_traceback + +from .applet import Applet +from .defs import Constant + + +class PyRunner(ABC): + + """ + A PyApplet must return a light-weight PyRunner object to run the Python code of the external app. + Since the runner could be running in a separate subprocess, the runner object must be pickleable! + """ + + @abstractmethod + def start(self, app_ctx: dict): + """Start the external app's Python code + + Args: + app_ctx: the app's execution context + + Returns: + + """ + pass + + @abstractmethod + def stop(self, timeout: float): + """Stop the external app's python code + + Args: + timeout: how long to wait for the app to stop before killing it + + Returns: None + + """ + pass + + @abstractmethod + def is_stopped(self) -> (bool, int): + """Check whether the app code is stopped + + Returns: a tuple of: whether the app is stopped, and exit code if stopped + + """ + pass + + +class _PyStarter: + """This class is used to start the Python code of the applet. It is used when running the applet in a thread + or in a separate process. + """ + + def __init__(self, runner: PyRunner, in_process: bool, workspace: Workspace, job_id: str): + self.runner = runner + self.in_process = in_process + self.workspace = workspace + self.job_id = job_id + self.error = None + self.started = True + self.stopped = False + self.exit_code = 0 + + def start(self, app_ctx: dict): + """Start the applet and wait for it to finish. + + Args: + app_ctx: the app's execution context + + Returns: None + + """ + try: + if not self.in_process: + # enable logging + run_dir = self.workspace.get_run_dir(self.job_id) + log_file_name = os.path.join(run_dir, "applet_log.txt") + configure_logging(self.workspace) + add_log_file_handler(log_file_name) + self.runner.start(app_ctx) + + # Note: run_func does not return until it runs to completion! + self.stopped = True + except Exception as e: + secure_log_traceback() + self.error = f"Exception starting applet: {secure_format_exception(e)}" + self.started = False + self.exit_code = Constant.EXIT_CODE_CANT_START + self.stopped = True + if not self.in_process: + # this is a separate process + sys.exit(self.exit_code) + + +class PyApplet(Applet, ABC): + def __init__(self, in_process: bool): + """Constructor of PyApplet, which runs the applet's Python code in a separate thread or subprocess. + + Args: + in_process: whether to run the applet code as separate thread within the same process or as a separate + subprocess. 
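+
+        A minimal subclass only needs to supply a PyRunner (names here are illustrative):
+
+            class MyApplet(PyApplet):
+                def get_runner(self, app_ctx: dict) -> PyRunner:
+                    # the returned runner must be pickleable if in_process is False
+                    return MyRunner()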
+ """ + Applet.__init__(self) + self.in_process = in_process + self.starter = None + self.process = None + self.runner = None + + @abstractmethod + def get_runner(self, app_ctx: dict) -> PyRunner: + """Subclass must implement this method to return a PyRunner. + The returned PyRunner must be pickleable since it could be run in a separate subprocess! + + Args: + app_ctx: the app context for the runner + + Returns: a PyRunner object + + """ + pass + + def start(self, app_ctx: dict): + """Start the execution of the applet. + + Args: + app_ctx: the app context + + Returns: + + """ + fl_ctx = app_ctx.get(Constant.APP_CTX_FL_CONTEXT) + engine = fl_ctx.get_engine() + workspace = engine.get_workspace() + job_id = fl_ctx.get_job_id() + runner = self.get_runner(app_ctx) + + if not isinstance(runner, PyRunner): + raise RuntimeError(f"runner must be a PyRunner but got {type(runner)}") + + self.runner = runner + self.starter = _PyStarter(runner, self.in_process, workspace, job_id) + if self.in_process: + self._start_in_thread(self.starter, app_ctx) + else: + self._start_in_process(self.starter, app_ctx) + + def _start_in_thread(self, starter, app_ctx: dict): + """Start the applet in a separate thread.""" + self.logger.info("Starting applet in another thread") + thread = threading.Thread(target=starter.start, args=(app_ctx,), daemon=True, name="applet") + thread.start() + if not self.starter.started: + self.logger.error(f"Cannot start applet: {self.starter.error}") + raise RuntimeError(self.starter.error) + + def _start_in_process(self, starter, app_ctx: dict): + """Start the applet in a separate process.""" + # must remove Constant.APP_CTX_FL_CONTEXT from ctx because it's not pickleable! + app_ctx.pop(Constant.APP_CTX_FL_CONTEXT, None) + self.logger.info("Starting applet in another process") + self.process = multiprocessing.Process(target=starter.start, args=(app_ctx,), daemon=True, name="applet") + self.process.start() + + def stop(self, timeout=0.0) -> int: + """Stop the applet + + Args: + timeout: amount of time to wait for the applet to stop by itself. If the applet does not stop on + its own within this time, we'll forcefully stop it by kill. 
+ + Returns: None + + """ + if not self.runner: + raise RuntimeError("PyRunner is not set") + + if self.in_process: + self.runner.stop(timeout) + return 0 + else: + p = self.process + self.process = None + if p: + assert isinstance(p, multiprocessing.Process) + if p.exitcode is None: + # the process is still running + if timeout > 0: + # wait for the applet to stop by itself + start = time.time() + while time.time() - start < timeout: + if p.exitcode is not None: + # already stopped + self.logger.info(f"applet stopped (rc={p.exitcode}) after {time.time()-start} secs") + return p.exitcode + time.sleep(0.1) + self.logger.info("stopped applet by killing the process") + p.kill() + return -9 + + def is_stopped(self) -> (bool, int): + if not self.runner: + raise RuntimeError("PyRunner is not set") + + if self.in_process: + if self.starter: + if self.starter.stopped: + self.logger.info("starter is stopped!") + return True, self.starter.exit_code + return self.runner.is_stopped() + else: + if self.process: + assert isinstance(self.process, multiprocessing.Process) + ec = self.process.exitcode + if ec is None: + return False, 0 + else: + return True, ec + else: + return True, 0 diff --git a/nvflare/app_common/workflows/base_model_controller.py b/nvflare/app_common/workflows/base_model_controller.py index 6ffa7436f1..50aefa62db 100644 --- a/nvflare/app_common/workflows/base_model_controller.py +++ b/nvflare/app_common/workflows/base_model_controller.py @@ -359,8 +359,8 @@ def save_model(self, model): else: self.error("persistor not configured, model will not be saved") - def sample_clients(self, num_clients=None): - clients = self.engine.get_clients() + def sample_clients(self, num_clients: int = None) -> List[str]: + clients = [client.name for client in self.engine.get_clients()] if num_clients: check_positive_int("num_clients", num_clients) @@ -375,7 +375,7 @@ def sample_clients(self, num_clients=None): f"num_clients ({num_clients}) is greater than the number of available clients. Returning all ({len(clients)}) available clients." ) - self.info(f"Sampled clients: {[client.name for client in clients]}") + self.info(f"Sampled clients: {clients}") return clients diff --git a/nvflare/app_common/workflows/model_controller.py b/nvflare/app_common/workflows/model_controller.py index bd46595d14..1c2570919e 100644 --- a/nvflare/app_common/workflows/model_controller.py +++ b/nvflare/app_common/workflows/model_controller.py @@ -103,7 +103,7 @@ def send_model( callback=callback, ) - def load_model(self): + def load_model(self) -> FLModel: """Load initial model from persistor. If persistor is not configured, returns empty FLModel. Returns: @@ -111,7 +111,7 @@ def load_model(self): """ return super().load_model() - def save_model(self, model: FLModel): + def save_model(self, model: FLModel) -> None: """Saves model with persistor. If persistor is not configured, does not save. Args: @@ -122,12 +122,12 @@ def save_model(self, model: FLModel): """ super().save_model(model) - def sample_clients(self, num_clients=None): + def sample_clients(self, num_clients: int = None) -> List[str]: """Returns a list of `num_clients` clients. Args: num_clients: number of clients to return. If None or > number available clients, returns all available clients. Defaults to None. - Returns: list of clients. + Returns: list of clients names. 
""" return super().sample_clients(num_clients) diff --git a/nvflare/app_opt/flower/__init__.py b/nvflare/app_opt/flower/__init__.py new file mode 100644 index 0000000000..d9155f923f --- /dev/null +++ b/nvflare/app_opt/flower/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/nvflare/app_opt/flower/applet.py b/nvflare/app_opt/flower/applet.py new file mode 100644 index 0000000000..e3b8bcd5e2 --- /dev/null +++ b/nvflare/app_opt/flower/applet.py @@ -0,0 +1,261 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import time + +from nvflare.apis.fl_context import FLContext +from nvflare.apis.workspace import Workspace +from nvflare.app_common.tie.applet import Applet +from nvflare.app_common.tie.cli_applet import CLIApplet +from nvflare.app_common.tie.defs import Constant as TieConstant +from nvflare.app_common.tie.process_mgr import CommandDescriptor, ProcessManager, start_process +from nvflare.app_opt.flower.defs import Constant +from nvflare.fuel.f3.drivers.net_utils import get_open_tcp_port +from nvflare.fuel.utils.grpc_utils import create_channel +from nvflare.security.logging import secure_format_exception + + +class FlowerClientApplet(CLIApplet): + def __init__( + self, + client_app: str, + ): + """Constructor of FlowerClientApplet, which extends CLIApplet. + + Args: + client_app: the client app specification of the Flower app + """ + CLIApplet.__init__(self) + self.client_app = client_app + + def get_command(self, ctx: dict) -> CommandDescriptor: + """Implementation of the get_command method required by the super class CLIApplet. + It returns the CLI command for starting Flower's client app, as well as the full path of the log file + for the client app. + + Args: + ctx: the applet run context + + Returns: CLI command for starting client app and name of log file. 
+ + """ + addr = ctx.get(Constant.APP_CTX_SERVER_ADDR) + fl_ctx = ctx.get(Constant.APP_CTX_FL_CONTEXT) + if not isinstance(fl_ctx, FLContext): + self.logger.error(f"expect APP_CTX_FL_CONTEXT to be FLContext but got {type(fl_ctx)}") + raise RuntimeError("invalid FLContext") + + engine = fl_ctx.get_engine() + ws = engine.get_workspace() + if not isinstance(ws, Workspace): + self.logger.error(f"expect workspace to be Workspace but got {type(ws)}") + raise RuntimeError("invalid workspace") + + job_id = fl_ctx.get_job_id() + custom_dir = ws.get_app_custom_dir(job_id) + app_dir = ws.get_app_dir(job_id) + cmd = f"flower-client-app --insecure --grpc-adapter --superlink {addr} --dir {custom_dir} {self.client_app}" + + # use app_dir as the cwd for flower's client app. + # this is necessary for client_api to be used with the flower client app for metrics logging + # client_api expects config info from the "config" folder in the cwd! + self.logger.info(f"starting flower client app: {cmd}") + return CommandDescriptor(cmd=cmd, cwd=app_dir, log_file_name="client_app_log.txt", stdout_msg_prefix="FLWR-CA") + + +class FlowerServerApplet(Applet): + def __init__( + self, + server_app: str, + database: str, + superlink_ready_timeout: float, + server_app_args: list = None, + ): + """Constructor of FlowerServerApplet. + + Args: + server_app: Flower's server app specification + database: database spec to be used by the server app + superlink_ready_timeout: how long to wait for the superlink process to become ready + server_app_args: an optional list that contains additional command args passed to flower server app + """ + Applet.__init__(self) + self._app_process_mgr = None + self._superlink_process_mgr = None + self.server_app = server_app + self.database = database + self.superlink_ready_timeout = superlink_ready_timeout + self.server_app_args = server_app_args + self._start_error = False + + def _start_process(self, name: str, cmd_desc: CommandDescriptor, fl_ctx: FLContext) -> ProcessManager: + self.logger.info(f"starting {name}: {cmd_desc.cmd}") + try: + return start_process(cmd_desc, fl_ctx) + except Exception as ex: + self.logger.error(f"exception starting applet: {secure_format_exception(ex)}") + self._start_error = True + + def start(self, app_ctx: dict): + """Start the applet. + + Flower requires two processes for server application: + superlink: this process is responsible for client communication + server_app: this process performs server side of training. + + We start the superlink first, and wait for it to become ready, then start the server app. + Each process will have its own log file in the job's run dir. The superlink's log file is named + "superlink_log.txt". The server app's log file is named "server_app_log.txt". + + Args: + app_ctx: the run context of the applet. 
+ + Returns: + + """ + # try to start superlink first + driver_port = get_open_tcp_port(resources={}) + if not driver_port: + raise RuntimeError("failed to get a port for Flower driver") + driver_addr = f"127.0.0.1:{driver_port}" + + server_addr = app_ctx.get(Constant.APP_CTX_SERVER_ADDR) + fl_ctx = app_ctx.get(Constant.APP_CTX_FL_CONTEXT) + if not isinstance(fl_ctx, FLContext): + self.logger.error(f"expect APP_CTX_FL_CONTEXT to be FLContext but got {type(fl_ctx)}") + raise RuntimeError("invalid FLContext") + + engine = fl_ctx.get_engine() + ws = engine.get_workspace() + if not isinstance(ws, Workspace): + self.logger.error(f"expect workspace to be Workspace but got {type(ws)}") + raise RuntimeError("invalid workspace") + + custom_dir = ws.get_app_custom_dir(fl_ctx.get_job_id()) + + db_arg = "" + if self.database: + db_arg = f"--database {self.database}" + + superlink_cmd = ( + f"flower-superlink --insecure {db_arg} " + f"--fleet-api-address {server_addr} --fleet-api-type grpc-adapter " + f"--driver-api-address {driver_addr}" + ) + + cmd_desc = CommandDescriptor(cmd=superlink_cmd, log_file_name="superlink_log.txt", stdout_msg_prefix="FLWR-SL") + + self._superlink_process_mgr = self._start_process(name="superlink", cmd_desc=cmd_desc, fl_ctx=fl_ctx) + if not self._superlink_process_mgr: + raise RuntimeError("cannot start superlink process") + + # wait until superlink's port is ready before starting server app + # note: the server app will connect to driver_addr, not server_addr + start_time = time.time() + create_channel( + server_addr=driver_addr, + grpc_options=None, + ready_timeout=self.superlink_ready_timeout, + test_only=True, + ) + self.logger.info(f"superlink is ready for server app in {time.time()-start_time} seconds") + + # start the server app + args_str = "" + if self.server_app_args: + args_str = " ".join(self.server_app_args) + + app_cmd = ( + f"flower-server-app --insecure --superlink {driver_addr} --dir {custom_dir} {args_str} {self.server_app}" + ) + cmd_desc = CommandDescriptor( + cmd=app_cmd, + log_file_name="server_app_log.txt", + stdout_msg_prefix="FLWR-SA", + ) + + self._app_process_mgr = self._start_process(name="server_app", cmd_desc=cmd_desc, fl_ctx=fl_ctx) + if not self._app_process_mgr: + # stop the superlink + self._superlink_process_mgr.stop() + self._superlink_process_mgr = None + raise RuntimeError("cannot start server_app process") + + @staticmethod + def _stop_process(p: ProcessManager) -> int: + if not p: + # nothing to stop + return 0 + else: + return p.stop() + + def stop(self, timeout=0.0) -> int: + """Stop the server applet's superlink and server app processes. + + Args: + timeout: how long to wait before forcefully stopping (kill) the process. + + Note: we always stop the process immediately - do not wait for the process to stop itself. + + Returns: + + """ + rc = self._stop_process(self._app_process_mgr) + self._app_process_mgr = None + + self._stop_process(self._superlink_process_mgr) + self._superlink_process_mgr = None + + # return the rc of the server app! + return rc + + @staticmethod + def _is_process_stopped(p: ProcessManager): + if p: + return_code = p.poll() + if return_code is None: + return False, 0 + else: + return True, return_code + else: + return True, 0 + + def is_stopped(self) -> (bool, int): + """Check whether the server applet is already stopped + + Returns: a tuple of: whether the applet is stopped, exit code if stopped. + + Note: if either superlink or server app is stopped, we treat the applet as stopped. 
+ + """ + if self._start_error: + return True, TieConstant.EXIT_CODE_CANT_START + + # check server app + app_stopped, app_rc = self._is_process_stopped(self._app_process_mgr) + if app_stopped: + self._app_process_mgr = None + + superlink_stopped, superlink_rc = self._is_process_stopped(self._superlink_process_mgr) + if superlink_stopped: + self._superlink_process_mgr = None + + if app_stopped or superlink_stopped: + self.stop() + + if app_stopped: + return True, app_rc + elif superlink_stopped: + return True, superlink_rc + else: + return False, 0 diff --git a/nvflare/app_opt/flower/connectors/__init__.py b/nvflare/app_opt/flower/connectors/__init__.py new file mode 100644 index 0000000000..d9155f923f --- /dev/null +++ b/nvflare/app_opt/flower/connectors/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/nvflare/app_opt/flower/connectors/flower_connector.py b/nvflare/app_opt/flower/connectors/flower_connector.py new file mode 100644 index 0000000000..deb31f3fd3 --- /dev/null +++ b/nvflare/app_opt/flower/connectors/flower_connector.py @@ -0,0 +1,144 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from abc import abstractmethod + +from nvflare.apis.fl_context import FLContext +from nvflare.apis.shareable import ReturnCode, Shareable, make_reply +from nvflare.apis.signal import Signal +from nvflare.app_common.tie.connector import Connector +from nvflare.app_opt.flower.defs import Constant +from nvflare.fuel.utils.validation_utils import check_positive_int, check_positive_number + + +class FlowerServerConnector(Connector): + """ + FlowerServerConnector specifies commonly required methods for server connector implementations. + """ + + def __init__(self): + Connector.__init__(self) + self.num_rounds = None + + def configure(self, config: dict, fl_ctx: FLContext): + """Called by Flower Controller to configure the site. + + Args: + config: config data + fl_ctx: FL context + + Returns: None + + """ + num_rounds = config.get(Constant.CONF_KEY_NUM_ROUNDS) + if num_rounds is None: + raise RuntimeError("num_rounds is not configured") + + check_positive_int(Constant.CONF_KEY_NUM_ROUNDS, num_rounds) + self.num_rounds = num_rounds + + @abstractmethod + def send_request_to_flower(self, request: Shareable, fl_ctx: FLContext) -> Shareable: + """Send request to the Flower server. 
+ Subclass must implement this method to send this request to the Flower server. + + Args: + request: the request received from FL client + fl_ctx: the FL context + + Returns: reply from the Flower server converted to Shareable + + """ + pass + + def process_app_request(self, op: str, request: Shareable, fl_ctx: FLContext, abort_signal: Signal) -> Shareable: + """This method is called by the FL Server when the request is received from a FL client. + + Args: + op: the op code of the request. + request: the request received from FL client + fl_ctx: FL context + abort_signal: abort signal that could be triggered during the process + + Returns: response from the Flower server converted to Shareable + + """ + stopped, ec = self._is_stopped() + if stopped: + self.log_warning(fl_ctx, f"dropped request '{op}' since connector is already stopped {ec=}") + return make_reply(ReturnCode.SERVICE_UNAVAILABLE) + + reply = self.send_request_to_flower(request, fl_ctx) + self.log_info(fl_ctx, f"received reply for '{op}'") + return reply + + +class FlowerClientConnector(Connector): + """ + FlowerClientConnector defines commonly required methods for client connector implementations. + """ + + def __init__(self, per_msg_timeout: float, tx_timeout: float): + """Constructor of FlowerClientConnector + + Args: + per_msg_timeout: per-msg timeout to be used when sending request to server via ReliableMessage + tx_timeout: tx timeout to be used when sending request to server via ReliableMessage + """ + check_positive_number("per_msg_timeout", per_msg_timeout) + check_positive_number("tx_timeout", tx_timeout) + + Connector.__init__(self) + self.per_msg_timeout = per_msg_timeout + self.tx_timeout = tx_timeout + self.stopped = False + self.num_rounds = None + + def configure(self, config: dict, fl_ctx: FLContext): + """Called by Flower Executor to configure the target. + + Args: + config: config data + fl_ctx: FL context + + Returns: None + + """ + num_rounds = config.get(Constant.CONF_KEY_NUM_ROUNDS) + if num_rounds is None: + raise RuntimeError("num_rounds is not configured") + + check_positive_int(Constant.CONF_KEY_NUM_ROUNDS, num_rounds) + self.num_rounds = num_rounds + + def _send_flower_request(self, request: Shareable) -> Shareable: + """Send Flower request to the FL server via FLARE message. + + Args: + request: shareable that contains flower msg + + Returns: operation result + + """ + op = "request" + reply = self.send_request( + op=op, + target=None, # server + request=request, + per_msg_timeout=self.per_msg_timeout, + tx_timeout=self.tx_timeout, + fl_ctx=None, + ) + if not isinstance(reply, Shareable): + raise RuntimeError(f"invalid reply for op {op}: expect Shareable but got {type(reply)}") + return reply diff --git a/nvflare/app_opt/flower/connectors/grpc_client_connector.py b/nvflare/app_opt/flower/connectors/grpc_client_connector.py new file mode 100644 index 0000000000..4d61dc1d38 --- /dev/null +++ b/nvflare/app_opt/flower/connectors/grpc_client_connector.py @@ -0,0 +1,139 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +import flwr.proto.grpcadapter_pb2 as pb2 +from flwr.proto.grpcadapter_pb2_grpc import GrpcAdapterServicer + +from nvflare.apis.fl_context import FLContext +from nvflare.apis.shareable import ReturnCode +from nvflare.app_opt.flower.connectors.flower_connector import FlowerClientConnector +from nvflare.app_opt.flower.defs import Constant +from nvflare.app_opt.flower.grpc_server import GrpcServer +from nvflare.app_opt.flower.utils import msg_container_to_shareable, reply_should_exit, shareable_to_msg_container +from nvflare.fuel.f3.drivers.net_utils import get_open_tcp_port +from nvflare.security.logging import secure_format_exception + + +class GrpcClientConnector(FlowerClientConnector, GrpcAdapterServicer): + def __init__( + self, + int_server_grpc_options=None, + per_msg_timeout=2.0, + tx_timeout=10.0, + client_shutdown_timeout=5.0, + ): + """Constructor of GrpcClientConnector. + GrpcClientConnector is used to connect Flare Client with the Flower Client App. + + Args: + int_server_grpc_options: internal grpc server options + per_msg_timeout: per-message timeout for using ReliableMessage + tx_timeout: transaction timeout for using ReliableMessage + client_shutdown_timeout: max time for shutting down Flare client + """ + FlowerClientConnector.__init__(self, per_msg_timeout, tx_timeout) + self.client_shutdown_timeout = client_shutdown_timeout + self.int_server_grpc_options = int_server_grpc_options + self.internal_grpc_server = None + self.stopped = False + self.internal_server_addr = None + self._training_stopped = False + self._client_name = None + + def initialize(self, fl_ctx: FLContext): + super().initialize(fl_ctx) + self._client_name = fl_ctx.get_identity_name() + + def _start_client(self, server_addr: str, fl_ctx: FLContext): + app_ctx = { + Constant.APP_CTX_CLIENT_NAME: self._client_name, + Constant.APP_CTX_SERVER_ADDR: server_addr, + Constant.APP_CTX_NUM_ROUNDS: self.num_rounds, + } + self.start_applet(app_ctx, fl_ctx) + + def _stop_client(self): + self._training_stopped = True + self.stop_applet(self.client_shutdown_timeout) + + def _is_stopped(self) -> (bool, int): + applet_stopped, ec = self.is_applet_stopped() + if applet_stopped: + return applet_stopped, ec + + if self._training_stopped: + return True, 0 + + return False, 0 + + def start(self, fl_ctx: FLContext): + if not self.num_rounds: + raise RuntimeError("cannot start - num_rounds is not set") + + # dynamically determine address on localhost + port = get_open_tcp_port(resources={}) + if not port: + raise RuntimeError("failed to get a port for Flower server") + self.internal_server_addr = f"127.0.0.1:{port}" + self.logger.info(f"Start internal server at {self.internal_server_addr}") + self.internal_grpc_server = GrpcServer(self.internal_server_addr, 10, self.int_server_grpc_options, self) + self.internal_grpc_server.start(no_blocking=True) + self.logger.info(f"Started internal grpc server at {self.internal_server_addr}") + self._start_client(self.internal_server_addr, fl_ctx) + self.logger.info("Started external Flower grpc client") + + def stop(self, fl_ctx: FLContext): + if self.stopped: + return + + self.stopped = True + self._stop_client() + + if self.internal_grpc_server: + self.logger.info("Stop internal grpc Server") + self.internal_grpc_server.shutdown() + + def _abort(self, reason: str): + # stop the gRPC client (the target) + self.abort_signal.trigger(True) + + # abort the FL client + with 
self.engine.new_context() as fl_ctx: + self.system_panic(reason, fl_ctx) + + def SendReceive(self, request: pb2.MessageContainer, context): + """Process request received from a Flower client. + + This implements the SendReceive method required by Flower gRPC server (LGS on FLARE Client). + 1. convert the request to a Shareable object. + 2. send the Shareable request to FLARE server. + 3. convert received Shareable result to MessageContainer and return to the Flower client + + Args: + request: the request received from the Flower client + context: gRPC context + + Returns: the reply MessageContainer object + + """ + try: + reply = self._send_flower_request(msg_container_to_shareable(request)) + rc = reply.get_return_code() + if rc == ReturnCode.OK: + return shareable_to_msg_container(reply) + else: + # server side already ended + self.logger.warning(f"Flower server has stopped with RC {rc}") + return reply_should_exit() + except Exception as ex: + self._abort(reason=f"_send_flower_request exception: {secure_format_exception(ex)}") diff --git a/nvflare/app_opt/flower/connectors/grpc_server_connector.py b/nvflare/app_opt/flower/connectors/grpc_server_connector.py new file mode 100644 index 0000000000..5a88bc365a --- /dev/null +++ b/nvflare/app_opt/flower/connectors/grpc_server_connector.py @@ -0,0 +1,109 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import flwr.proto.grpcadapter_pb2 as pb2 + +from nvflare.apis.fl_context import FLContext +from nvflare.apis.shareable import ReturnCode, Shareable, make_reply +from nvflare.app_opt.flower.connectors.flower_connector import FlowerServerConnector +from nvflare.app_opt.flower.defs import Constant +from nvflare.app_opt.flower.grpc_client import GrpcClient +from nvflare.app_opt.flower.utils import msg_container_to_shareable, shareable_to_msg_container +from nvflare.fuel.f3.drivers.net_utils import get_open_tcp_port + + +class GrpcServerConnector(FlowerServerConnector): + def __init__( + self, + int_client_grpc_options=None, + flower_server_ready_timeout=Constant.FLOWER_SERVER_READY_TIMEOUT, + ): + FlowerServerConnector.__init__(self) + self.int_client_grpc_options = int_client_grpc_options + self.flower_server_ready_timeout = flower_server_ready_timeout + self.internal_grpc_client = None + self._server_stopped = False + self._exit_code = 0 + + def _start_server(self, addr: str, fl_ctx: FLContext): + app_ctx = { + Constant.APP_CTX_SERVER_ADDR: addr, + Constant.APP_CTX_NUM_ROUNDS: self.num_rounds, + } + self.start_applet(app_ctx, fl_ctx) + + def _stop_server(self): + self._server_stopped = True + self._exit_code = self.stop_applet() + + def _is_stopped(self) -> (bool, int): + runner_stopped, ec = self.is_applet_stopped() + if runner_stopped: + self.logger.info("applet is stopped!") + return runner_stopped, ec + + if self._server_stopped: + self.logger.info("Flower grpc server is stopped!") + return True, self._exit_code + + return False, 0 + + def start(self, fl_ctx: FLContext): + # we dynamically create server address on localhost + port = get_open_tcp_port(resources={}) + if not port: + raise RuntimeError("failed to get a port for Flower grpc server") + + server_addr = f"127.0.0.1:{port}" + self.log_info(fl_ctx, f"starting grpc connector: {server_addr=}") + self._start_server(server_addr, fl_ctx) + + # start internal grpc client + self.internal_grpc_client = GrpcClient(server_addr, self.int_client_grpc_options) + self.internal_grpc_client.start(ready_timeout=self.flower_server_ready_timeout) + + def stop(self, fl_ctx: FLContext): + client = self.internal_grpc_client + self.internal_grpc_client = None + if client: + self.log_info(fl_ctx, "Stopping internal grpc client") + client.stop() + self._stop_server() + + def send_request_to_flower(self, request: Shareable, fl_ctx: FLContext) -> Shareable: + """Send the request received from FL client to Flower server. + + This is done by: + 1. convert the request to Flower-defined MessageContainer object + 2. Send the MessageContainer object to Flower server via the internal GRPC client (LGC) + 3. Convert the reply MessageContainer object received from the Flower server to Shareable + 4. 
Return the reply Shareable object + + Args: + request: the request received from FL client + fl_ctx: FL context + + Returns: response from Flower server converted to Shareable + + """ + stopped, _ = self.is_applet_stopped() + if stopped: + self.log_warning(fl_ctx, "dropped app request since applet is already stopped") + return make_reply(ReturnCode.SERVICE_UNAVAILABLE) + + result = self.internal_grpc_client.send_request(shareable_to_msg_container(request)) + + if isinstance(result, pb2.MessageContainer): + return msg_container_to_shareable(result) + else: + raise RuntimeError(f"bad result from Flower server: expect MessageContainer but got {type(result)}") diff --git a/nvflare/app_opt/flower/controller.py b/nvflare/app_opt/flower/controller.py new file mode 100644 index 0000000000..69498fc794 --- /dev/null +++ b/nvflare/app_opt/flower/controller.py @@ -0,0 +1,103 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from nvflare.apis.fl_context import FLContext +from nvflare.app_common.tie.controller import TieController +from nvflare.app_common.tie.defs import Constant as TieConstant +from nvflare.app_opt.flower.applet import FlowerServerApplet +from nvflare.app_opt.flower.connectors.grpc_server_connector import GrpcServerConnector +from nvflare.fuel.utils.validation_utils import check_object_type, check_positive_number + +from .defs import Constant + + +class FlowerController(TieController): + def __init__( + self, + num_rounds=1, + server_app: str = "server:app", + database: str = "", + server_app_args: list = None, + superlink_ready_timeout: float = 10.0, + configure_task_name=TieConstant.CONFIG_TASK_NAME, + configure_task_timeout=TieConstant.CONFIG_TASK_TIMEOUT, + start_task_name=TieConstant.START_TASK_NAME, + start_task_timeout=TieConstant.START_TASK_TIMEOUT, + job_status_check_interval: float = TieConstant.JOB_STATUS_CHECK_INTERVAL, + max_client_op_interval: float = TieConstant.MAX_CLIENT_OP_INTERVAL, + progress_timeout: float = TieConstant.WORKFLOW_PROGRESS_TIMEOUT, + int_client_grpc_options=None, + ): + """Constructor of FlowerController + + Args: + num_rounds: number of rounds. Not used in this version. 
+ server_app: the server app specification for Flower server app + database: database name + server_app_args: additional server app CLI args + superlink_ready_timeout: how long to wait for the superlink to become ready before starting server app + configure_task_name: name of the config task + configure_task_timeout: max time allowed for config task to complete + start_task_name: name of the start task + start_task_timeout: max time allowed for start task to complete + job_status_check_interval: how often to check job status + max_client_op_interval: max time allowed for missing client requests + progress_timeout: max time allowed for missing overall progress + int_client_grpc_options: internal grpc client options + """ + TieController.__init__( + self, + configure_task_name=configure_task_name, + configure_task_timeout=configure_task_timeout, + start_task_name=start_task_name, + start_task_timeout=start_task_timeout, + job_status_check_interval=job_status_check_interval, + max_client_op_interval=max_client_op_interval, + progress_timeout=progress_timeout, + ) + + check_positive_number("superlink_ready_timeout", superlink_ready_timeout) + + if server_app_args: + check_object_type("server_app_args", server_app_args, list) + + self.num_rounds = num_rounds + self.server_app = server_app + self.database = database + self.server_app_args = server_app_args + self.superlink_ready_timeout = superlink_ready_timeout + self.int_client_grpc_options = int_client_grpc_options + + def get_connector(self, fl_ctx: FLContext): + return GrpcServerConnector( + int_client_grpc_options=self.int_client_grpc_options, + ) + + def get_applet(self, fl_ctx: FLContext): + return FlowerServerApplet( + server_app=self.server_app, + database=self.database, + superlink_ready_timeout=self.superlink_ready_timeout, + server_app_args=self.server_app_args, + ) + + def get_client_config_params(self, fl_ctx: FLContext) -> dict: + return { + Constant.CONF_KEY_NUM_ROUNDS: self.num_rounds, + } + + def get_connector_config_params(self, fl_ctx: FLContext) -> dict: + return { + Constant.CONF_KEY_NUM_ROUNDS: self.num_rounds, + } diff --git a/nvflare/app_opt/flower/defs.py b/nvflare/app_opt/flower/defs.py new file mode 100644 index 0000000000..f9011c8ee8 --- /dev/null +++ b/nvflare/app_opt/flower/defs.py @@ -0,0 +1,52 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from nvflare.app_common.tie.defs import Constant as TieConstant +from nvflare.fuel.f3.drivers.net_utils import MAX_FRAME_SIZE + + +class Constant: + + # task name defaults + CONFIG_TASK_NAME = TieConstant.CONFIG_TASK_NAME + START_TASK_NAME = TieConstant.START_TASK_NAME + + # keys of config parameters + CONF_KEY_NUM_ROUNDS = "num_rounds" + + PARAM_KEY_HEADERS = "flower.headers" + PARAM_KEY_CONTENT = "flower.content" + PARAM_KEY_MSG_NAME = "flower.name" + + # default component config values + CONFIG_TASK_TIMEOUT = TieConstant.CONFIG_TASK_TIMEOUT + START_TASK_TIMEOUT = TieConstant.START_TASK_TIMEOUT + FLOWER_SERVER_READY_TIMEOUT = 10.0 + + TASK_CHECK_INTERVAL = TieConstant.TASK_CHECK_INTERVAL + JOB_STATUS_CHECK_INTERVAL = TieConstant.JOB_STATUS_CHECK_INTERVAL + MAX_CLIENT_OP_INTERVAL = TieConstant.MAX_CLIENT_OP_INTERVAL + WORKFLOW_PROGRESS_TIMEOUT = TieConstant.WORKFLOW_PROGRESS_TIMEOUT + + APP_CTX_SERVER_ADDR = "flower_server_addr" + APP_CTX_PORT = "flower_port" + APP_CTX_CLIENT_NAME = "flower_client_name" + APP_CTX_NUM_ROUNDS = "flower_num_rounds" + APP_CTX_FL_CONTEXT = TieConstant.APP_CTX_FL_CONTEXT + + +GRPC_DEFAULT_OPTIONS = [ + ("grpc.max_send_message_length", MAX_FRAME_SIZE), + ("grpc.max_receive_message_length", MAX_FRAME_SIZE), +] diff --git a/nvflare/app_opt/flower/executor.py b/nvflare/app_opt/flower/executor.py new file mode 100644 index 0000000000..f11e8ee00f --- /dev/null +++ b/nvflare/app_opt/flower/executor.py @@ -0,0 +1,59 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from nvflare.apis.fl_context import FLContext +from nvflare.app_common.tie.executor import TieExecutor +from nvflare.app_opt.flower.applet import FlowerClientApplet +from nvflare.app_opt.flower.connectors.grpc_client_connector import GrpcClientConnector + +from .defs import Constant + + +class FlowerExecutor(TieExecutor): + def __init__( + self, + client_app: str = "client:app", + start_task_name=Constant.START_TASK_NAME, + configure_task_name=Constant.CONFIG_TASK_NAME, + per_msg_timeout=10.0, + tx_timeout=100.0, + client_shutdown_timeout=5.0, + ): + TieExecutor.__init__( + self, + start_task_name=start_task_name, + configure_task_name=configure_task_name, + ) + + self.int_server_grpc_options = None + self.per_msg_timeout = per_msg_timeout + self.tx_timeout = tx_timeout + self.client_shutdown_timeout = client_shutdown_timeout + self.num_rounds = None + self.client_app = client_app + + def get_connector(self, fl_ctx: FLContext): + return GrpcClientConnector( + int_server_grpc_options=self.int_server_grpc_options, + per_msg_timeout=self.per_msg_timeout, + tx_timeout=self.tx_timeout, + ) + + def get_applet(self, fl_ctx: FLContext): + return FlowerClientApplet(self.client_app) + + def configure(self, config: dict, fl_ctx: FLContext): + self.num_rounds = config.get(Constant.CONF_KEY_NUM_ROUNDS) + + def get_connector_config(self, fl_ctx: FLContext) -> dict: + return {Constant.CONF_KEY_NUM_ROUNDS: self.num_rounds} diff --git a/nvflare/app_opt/flower/grpc_client.py b/nvflare/app_opt/flower/grpc_client.py new file mode 100644 index 0000000000..67f14fe72e --- /dev/null +++ b/nvflare/app_opt/flower/grpc_client.py @@ -0,0 +1,100 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import flwr.proto.grpcadapter_pb2 as pb2 +from flwr.proto.grpcadapter_pb2_grpc import GrpcAdapterStub + +from nvflare.app_opt.flower.defs import GRPC_DEFAULT_OPTIONS +from nvflare.fuel.utils.grpc_utils import create_channel +from nvflare.fuel.utils.obj_utils import get_logger + +from .utils import reply_should_exit + + +class GrpcClient: + """This class implements a gRPC Client that is capable of sending Flower requests to a Flower gRPC Server.""" + + def __init__(self, server_addr, grpc_options=None): + """Constructor + + Args: + server_addr: address of the gRPC server to connect to + grpc_options: gRPC options for the gRPC client + """ + if not grpc_options: + grpc_options = GRPC_DEFAULT_OPTIONS + + self.stub = None + self.channel = None + self.server_addr = server_addr + self.grpc_options = grpc_options + self.started = False + self.logger = get_logger(self) + + def start(self, ready_timeout=10): + """Start the gRPC client and wait for the server to be ready. 
+ + Args: + ready_timeout: how long to wait for the server to be ready + + Returns: None + + """ + if self.started: + return + + self.started = True + + self.channel = create_channel( + server_addr=self.server_addr, + grpc_options=self.grpc_options, + ready_timeout=ready_timeout, + test_only=False, + ) + self.stub = GrpcAdapterStub(self.channel) + + def send_request(self, request: pb2.MessageContainer): + """Send Flower request to gRPC server + + Args: + request: grpc request + + Returns: a pb2.MessageContainer object + + """ + self.logger.info(f"sending {len(request.grpc_message_content)} bytes: {request.grpc_message_name=}") + try: + result = self.stub.SendReceive(request) + except Exception as ex: + self.logger.warning(f"exception occurred communicating to Flower server: {ex}") + return reply_should_exit() + + if not isinstance(result, pb2.MessageContainer): + self.logger.error(f"expect reply to be pb2.MessageContainer but got {type(result)}") + return None + return result + + def stop(self): + """Stop the gRPC client + + Returns: None + + """ + ch = self.channel + self.channel = None # set to None in case another thread also tries to close. + if ch: + try: + ch.close() + except: + # ignore errors when closing the channel + pass diff --git a/nvflare/app_opt/flower/grpc_server.py b/nvflare/app_opt/flower/grpc_server.py new file mode 100644 index 0000000000..0eb09a9f7a --- /dev/null +++ b/nvflare/app_opt/flower/grpc_server.py @@ -0,0 +1,82 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import concurrent.futures as futures + +import grpc +from flwr.proto.grpcadapter_pb2_grpc import GrpcAdapterServicer, add_GrpcAdapterServicer_to_server + +from nvflare.app_opt.flower.defs import GRPC_DEFAULT_OPTIONS +from nvflare.fuel.utils.obj_utils import get_logger +from nvflare.fuel.utils.validation_utils import check_object_type, check_positive_int +from nvflare.security.logging import secure_format_exception + + +class GrpcServer: + """This class implements a gRPC Server that is capable of processing Flower requests.""" + + def __init__(self, addr, max_workers: int, grpc_options, servicer): + """Constructor + + Args: + addr: the listening address of the server + max_workers: max number of workers + grpc_options: gRPC options + servicer: the servicer that is capable of processing Flower requests + """ + if not grpc_options: + grpc_options = GRPC_DEFAULT_OPTIONS + + check_object_type("servicer", servicer, GrpcAdapterServicer) + check_positive_int("max_workers", max_workers) + self.grpc_server = grpc.server(futures.ThreadPoolExecutor(max_workers=max_workers), options=grpc_options) + add_GrpcAdapterServicer_to_server(servicer, self.grpc_server) + self.logger = get_logger(self) + + try: + # TBD: will be enhanced to support secure port + self.grpc_server.add_insecure_port(addr) + self.logger.info(f"Flower gRPC Server: added insecure port at {addr}") + except Exception as ex: + self.logger.error(f"cannot listen on {addr}: {secure_format_exception(ex)}") + + def start(self, no_blocking=False): + """Called to start the server + + Args: + no_blocking: whether blocking the current thread and wait for server termination + + Returns: None + + """ + self.logger.info("starting Flower gRPC Server") + self.grpc_server.start() + if no_blocking: + # don't wait for server termination + return + else: + self.grpc_server.wait_for_termination() + self.logger.info("Flower gRPC server terminated") + + def shutdown(self): + """Shut down the gRPC server gracefully. + + Returns: + + """ + self.logger.info("shutting down Flower gRPC server") + server = self.grpc_server + self.grpc_server = None # in case another thread calls shutdown at the same time + if server: + server.stop(grace=0.5) diff --git a/nvflare/app_opt/flower/mock/__init__.py b/nvflare/app_opt/flower/mock/__init__.py new file mode 100644 index 0000000000..d9155f923f --- /dev/null +++ b/nvflare/app_opt/flower/mock/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/nvflare/app_opt/flower/mock/applet.py b/nvflare/app_opt/flower/mock/applet.py new file mode 100644 index 0000000000..c2004035ec --- /dev/null +++ b/nvflare/app_opt/flower/mock/applet.py @@ -0,0 +1,76 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from nvflare.app_common.tie.cli_applet import CLIApplet, CommandDescriptor +from nvflare.app_common.tie.py_applet import PyApplet, PyRunner +from nvflare.app_opt.flower.defs import Constant + +from .flower_client import train + + +class MockClientApplet(CLIApplet): + def __init__(self): + CLIApplet.__init__(self) + + def get_command(self, app_ctx: dict) -> CommandDescriptor: + main_module = "nvflare.app_opt.flower.mock.flower_client" + addr = app_ctx.get(Constant.APP_CTX_SERVER_ADDR) + num_rounds = app_ctx.get(Constant.APP_CTX_NUM_ROUNDS) + client_name = app_ctx.get(Constant.APP_CTX_CLIENT_NAME) + + return CommandDescriptor( + cmd=f"python -m {main_module} -a {addr} -n {num_rounds} -c {client_name}", + log_file_name="flower_client_log.txt", + stdout_msg_prefix="FLWR-CA", + ) + + +class MockServerApplet(CLIApplet): + def __init__(self): + CLIApplet.__init__(self) + + def get_command(self, app_ctx: dict) -> CommandDescriptor: + main_module = "nvflare.app_opt.flower.mock.flower_server" + addr = app_ctx.get(Constant.APP_CTX_SERVER_ADDR) + num_rounds = app_ctx.get(Constant.APP_CTX_NUM_ROUNDS) + + return CommandDescriptor( + cmd=f"python -m {main_module} -a {addr} -n {num_rounds}", + log_file_name="flower_server_log.txt", + stdout_msg_prefix="FLWR-SA", + ) + + +class MockClientPyRunner(PyRunner): + def __init__(self): + self.stopped = False + + def start(self, app_ctx: dict): + addr = app_ctx.get(Constant.APP_CTX_SERVER_ADDR) + client_name = app_ctx.get(Constant.APP_CTX_CLIENT_NAME) + train(server_addr=addr, client_name=client_name) + self.stopped = True + + def stop(self, timeout: float): + pass + + def is_stopped(self) -> (bool, int): + return self.stopped, 0 + + +class MockClientPyApplet(PyApplet): + def __init__(self, in_process=True): + PyApplet.__init__(self, in_process) + + def get_runner(self, app_ctx: dict) -> PyRunner: + return MockClientPyRunner() diff --git a/nvflare/app_opt/flower/mock/controller.py b/nvflare/app_opt/flower/mock/controller.py new file mode 100644 index 0000000000..ea6999dc1d --- /dev/null +++ b/nvflare/app_opt/flower/mock/controller.py @@ -0,0 +1,25 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from nvflare.apis.fl_context import FLContext +from nvflare.app_opt.flower.controller import FlowerController +from nvflare.app_opt.flower.mock.applet import MockServerApplet + + +class MockController(FlowerController): + def __init__(self, num_rounds: int): + FlowerController.__init__(self, num_rounds=num_rounds) + + def get_applet(self, fl_ctx: FLContext): + return MockServerApplet() diff --git a/nvflare/app_opt/flower/mock/echo_servicer.py b/nvflare/app_opt/flower/mock/echo_servicer.py new file mode 100644 index 0000000000..e8a798e6d3 --- /dev/null +++ b/nvflare/app_opt/flower/mock/echo_servicer.py @@ -0,0 +1,49 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import flwr.proto.grpcadapter_pb2 as pb2 +from flwr.proto.grpcadapter_pb2_grpc import GrpcAdapterServicer + +from nvflare.fuel.utils.obj_utils import get_logger + + +class EchoServicer(GrpcAdapterServicer): + def __init__(self, num_rounds): + self.logger = get_logger(self) + self.num_rounds = num_rounds + self.server = None + self.stopped = False + + def set_server(self, s): + self.server = s + + def SendReceive(self, request: pb2.MessageContainer, context): + msg_name = request.grpc_message_name + headers = request.metadata + content = request.grpc_message_content + self.logger.info(f"got {msg_name=}: {headers=} {content=}") + + round_num = int(headers.get("round")) + if round_num >= self.num_rounds: + # stop the server + self.logger.info(f"got round number {round_num}: ask to shutdown server") + self.server.shutdown() + self.stopped = True + + headers["round"] = str(round_num + 1) + return pb2.MessageContainer( + metadata=headers, + grpc_message_name=msg_name, + grpc_message_content=content, + ) diff --git a/nvflare/app_opt/flower/mock/executor.py b/nvflare/app_opt/flower/mock/executor.py new file mode 100644 index 0000000000..87a239d46b --- /dev/null +++ b/nvflare/app_opt/flower/mock/executor.py @@ -0,0 +1,34 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from nvflare.apis.fl_context import FLContext +from nvflare.app_opt.flower.executor import FlowerExecutor +from nvflare.app_opt.flower.mock.applet import MockClientApplet, MockClientPyApplet + + +class MockExecutor(FlowerExecutor): + def __init__(self): + FlowerExecutor.__init__(self) + + def get_applet(self, fl_ctx: FLContext): + return MockClientApplet() + + +class MockPyExecutor(FlowerExecutor): + def __init__(self, in_process=True): + FlowerExecutor.__init__(self) + self.in_process = in_process + + def get_applet(self, fl_ctx: FLContext): + return MockClientPyApplet(self.in_process) diff --git a/nvflare/app_opt/flower/mock/flower_client.py b/nvflare/app_opt/flower/mock/flower_client.py new file mode 100644 index 0000000000..e2f4d7ab19 --- /dev/null +++ b/nvflare/app_opt/flower/mock/flower_client.py @@ -0,0 +1,108 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import logging +import os +import sys +import time + +import flwr.proto.grpcadapter_pb2 as pb2 + +from nvflare.app_opt.flower.grpc_client import GrpcClient +from nvflare.fuel.utils.time_utils import time_to_string + + +def log(msg: str): + for i in range(5): + print(f"\r{i}", end=" ") + sys.stdout.flush() + print("\nend") + print(f"{time_to_string(time.time())}: {msg}") + sys.stdout.flush() + + +def train(server_addr, client_name): + log(f"starting client {client_name} to connect to server at {server_addr}") + client = GrpcClient(server_addr=server_addr) + client.start() + + total_time = 0 + total_reqs = 0 + next_round = 0 + while True: + log(f"Test round {next_round}") + data = os.urandom(10) + + headers = { + "target": "server", + "round": str(next_round), + "origin": client_name, + } + req = pb2.MessageContainer( + grpc_message_name="abc", + grpc_message_content=data, + ) + req.metadata.update(headers) + + start = time.time() + result = client.send_request(req) + total_reqs += 1 + total_time += time.time() - start + if not isinstance(result, pb2.MessageContainer): + log(f"expect reply to be pb2.MessageContainer but got {type(result)}") + elif result.grpc_message_name != req.grpc_message_name: + log("ERROR: msg_name does not match request") + elif result.grpc_message_content != data: + log("ERROR: result does not match request") + else: + log("OK: result matches request!") + + result_headers = result.metadata + should_exit = result_headers.get("should-exit") + if should_exit: + log("got should-exit!") + break + + next_round = result_headers.get("round") + time.sleep(1.0) + + time_per_req = total_time / total_reqs + log(f"DONE: {total_reqs=} {total_time=} {time_per_req=}") + + +def main(): + logging.basicConfig() + logging.getLogger().setLevel(logging.INFO) + + parser = argparse.ArgumentParser() + parser.add_argument("--addr", "-a", type=str, help="server address", required=True) + parser.add_argument("--client_name", "-c", type=str, help="client name", required=True) + parser.add_argument("--num_rounds", "-n", type=int, help="number of 
rounds", required=True) + args = parser.parse_args() + + if not args.addr: + raise RuntimeError("missing server address '--addr/-a' in command") + + if not args.num_rounds: + raise RuntimeError("missing num rounds '--num_rounds/-n' in command") + + if args.num_rounds <= 0: + raise RuntimeError("bad num rounds '--num_rounds/-n' in command: must be > 0") + + train(args.addr, args.client_name) + + +if __name__ == "__main__": + main() diff --git a/nvflare/app_opt/flower/mock/flower_server.py b/nvflare/app_opt/flower/mock/flower_server.py new file mode 100644 index 0000000000..f09d6688b9 --- /dev/null +++ b/nvflare/app_opt/flower/mock/flower_server.py @@ -0,0 +1,48 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import logging + +from nvflare.app_opt.flower.grpc_server import GrpcServer +from nvflare.app_opt.flower.mock.echo_servicer import EchoServicer + + +def main(): + logging.basicConfig() + logging.getLogger().setLevel(logging.INFO) + + parser = argparse.ArgumentParser() + parser.add_argument("--max_workers", "-w", type=int, help="max number of workers", required=False, default=20) + parser.add_argument("--addr", "-a", type=str, help="server address", required=True) + parser.add_argument("--num_rounds", "-n", type=int, help="number of rounds", required=True) + args = parser.parse_args() + + if not args.addr: + raise RuntimeError("missing server address '--addr/-a' in command") + + print(f"starting server: {args.addr=} {args.max_workers=} {args.num_rounds=}") + servicer = EchoServicer(args.num_rounds) + server = GrpcServer( + args.addr, + max_workers=args.max_workers, + grpc_options=None, + servicer=servicer, + ) + servicer.set_server(server) + server.start() + + +if __name__ == "__main__": + main() diff --git a/nvflare/app_opt/flower/utils.py b/nvflare/app_opt/flower/utils.py new file mode 100644 index 0000000000..6b271d8749 --- /dev/null +++ b/nvflare/app_opt/flower/utils.py @@ -0,0 +1,68 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import flwr.proto.grpcadapter_pb2 as pb2 + +from nvflare.apis.shareable import Shareable + +from .defs import Constant + + +def msg_container_to_shareable(msg: pb2.MessageContainer) -> Shareable: + """Convert Flower-defined MessageContainer object to a Shareable object. + This function is typically used to in two cases: + 1. 
Convert Flower-client generated request to Shareable before sending it to FLARE Server via RM.
+    2. Convert Flower-server generated response to Shareable before sending it back to FLARE client.
+
+    Args:
+        msg: MessageContainer object to be converted
+
+    Returns: a Shareable object
+
+    """
+    s = Shareable()
+    headers = msg.metadata
+    if headers is not None:
+        # must convert msg.metadata to dict; otherwise it is not serializable.
+        headers = dict(msg.metadata)
+    s[Constant.PARAM_KEY_CONTENT] = msg.grpc_message_content
+    s[Constant.PARAM_KEY_HEADERS] = headers
+    s[Constant.PARAM_KEY_MSG_NAME] = msg.grpc_message_name
+    return s
+
+
+def shareable_to_msg_container(s: Shareable) -> pb2.MessageContainer:
+    """Convert a Shareable object to a Flower-defined MessageContainer.
+    This function is typically used in two cases:
+    1. Convert a Shareable object received from the FLARE client to MessageContainer before sending it to the Flower server.
+    2. Convert a Shareable object received from the FLARE server to MessageContainer before sending it to the Flower client.
+
+    Args:
+        s: the Shareable object to be converted
+
+    Returns: a MessageContainer object
+
+    """
+    m = pb2.MessageContainer(
+        grpc_message_name=s.get(Constant.PARAM_KEY_MSG_NAME),
+        grpc_message_content=s.get(Constant.PARAM_KEY_CONTENT),
+    )
+    headers = s.get(Constant.PARAM_KEY_HEADERS)
+    if headers:
+        # Note: headers is a dict, but m.metadata is a Google-defined MapContainer, which is a subclass of dict.
+        m.metadata.update(headers)
+    return m
+
+
+def reply_should_exit() -> pb2.MessageContainer:
+    return pb2.MessageContainer(metadata={"should-exit": "true"})
diff --git a/nvflare/app_opt/tf/fedopt_ctl.py b/nvflare/app_opt/tf/fedopt_ctl.py
index 2cf2d051b0..42814cf55a 100644
--- a/nvflare/app_opt/tf/fedopt_ctl.py
+++ b/nvflare/app_opt/tf/fedopt_ctl.py
@@ -72,7 +72,6 @@ def run(self):
         try:
             if "args" not in self.optimizer_args:
                 self.optimizer_args["args"] = {}
-            # self.optimizer_args["args"]["params"] = self.keras_model.parameters()
             self.optimizer = self.build_component(self.optimizer_args)
         except Exception as e:
             error_msg = f"Exception while constructing optimizer: {secure_format_exception(e)}"
diff --git a/nvflare/app_opt/xgboost/histogram_based_v2/defs.py b/nvflare/app_opt/xgboost/histogram_based_v2/defs.py
index b559306440..b32535ac15 100644
--- a/nvflare/app_opt/xgboost/histogram_based_v2/defs.py
+++ b/nvflare/app_opt/xgboost/histogram_based_v2/defs.py
@@ -109,6 +109,8 @@ class Constant:
     HEADER_KEY_HORIZONTAL = "xgb.horizontal"
     HEADER_KEY_ORIGINAL_BUF_SIZE = "xgb.original_buf_size"
     HEADER_KEY_IN_AGGR = "xgb.in_aggr"
+    HEADER_KEY_WORLD_SIZE = "xgb.world_size"
+    HEADER_KEY_SIZE_DICT = "xgb.size_dict"

     DUMMY_BUFFER_SIZE = 4

@@ -122,8 +124,6 @@
 class SplitMode:
     ROW = 0
     COL = 1
-    COL_SECURE = 2
-    ROW_SECURE = 3


 # Mapping of text training mode to split mode
@@ -132,10 +132,10 @@ class SplitMode:
     "horizontal": SplitMode.ROW,
     "v": SplitMode.COL,
     "vertical": SplitMode.COL,
-    "hs": SplitMode.ROW_SECURE,
-    "horizontal_secure": SplitMode.ROW_SECURE,
-    "vs": SplitMode.COL_SECURE,
-    "vertical_secure": SplitMode.COL_SECURE,
+    "hs": SplitMode.ROW,
+    "horizontal_secure": SplitMode.ROW,
+    "vs": SplitMode.COL,
+    "vertical_secure": SplitMode.COL,
 }

 SECURE_TRAINING_MODES = {"hs", "horizontal_secure", "vs", "vertical_secure"}
diff --git a/nvflare/app_opt/xgboost/histogram_based_v2/proto/federated_pb2.pyi b/nvflare/app_opt/xgboost/histogram_based_v2/proto/federated_pb2.pyi
index 7ad47596df..7dc3e6dde1 100644
---
a/nvflare/app_opt/xgboost/histogram_based_v2/proto/federated_pb2.pyi +++ b/nvflare/app_opt/xgboost/histogram_based_v2/proto/federated_pb2.pyi @@ -6,7 +6,7 @@ from typing import ClassVar as _ClassVar, Optional as _Optional, Union as _Union DESCRIPTOR: _descriptor.FileDescriptor class DataType(int, metaclass=_enum_type_wrapper.EnumTypeWrapper): - __slots__ = [] + __slots__ = () HALF: _ClassVar[DataType] FLOAT: _ClassVar[DataType] DOUBLE: _ClassVar[DataType] @@ -21,7 +21,7 @@ class DataType(int, metaclass=_enum_type_wrapper.EnumTypeWrapper): UINT64: _ClassVar[DataType] class ReduceOperation(int, metaclass=_enum_type_wrapper.EnumTypeWrapper): - __slots__ = [] + __slots__ = () MAX: _ClassVar[ReduceOperation] MIN: _ClassVar[ReduceOperation] SUM: _ClassVar[ReduceOperation] @@ -48,7 +48,7 @@ BITWISE_OR: ReduceOperation BITWISE_XOR: ReduceOperation class AllgatherRequest(_message.Message): - __slots__ = ["sequence_number", "rank", "send_buffer"] + __slots__ = ("sequence_number", "rank", "send_buffer") SEQUENCE_NUMBER_FIELD_NUMBER: _ClassVar[int] RANK_FIELD_NUMBER: _ClassVar[int] SEND_BUFFER_FIELD_NUMBER: _ClassVar[int] @@ -58,13 +58,13 @@ class AllgatherRequest(_message.Message): def __init__(self, sequence_number: _Optional[int] = ..., rank: _Optional[int] = ..., send_buffer: _Optional[bytes] = ...) -> None: ... class AllgatherReply(_message.Message): - __slots__ = ["receive_buffer"] + __slots__ = ("receive_buffer",) RECEIVE_BUFFER_FIELD_NUMBER: _ClassVar[int] receive_buffer: bytes def __init__(self, receive_buffer: _Optional[bytes] = ...) -> None: ... class AllgatherVRequest(_message.Message): - __slots__ = ["sequence_number", "rank", "send_buffer"] + __slots__ = ("sequence_number", "rank", "send_buffer") SEQUENCE_NUMBER_FIELD_NUMBER: _ClassVar[int] RANK_FIELD_NUMBER: _ClassVar[int] SEND_BUFFER_FIELD_NUMBER: _ClassVar[int] @@ -74,13 +74,13 @@ class AllgatherVRequest(_message.Message): def __init__(self, sequence_number: _Optional[int] = ..., rank: _Optional[int] = ..., send_buffer: _Optional[bytes] = ...) -> None: ... class AllgatherVReply(_message.Message): - __slots__ = ["receive_buffer"] + __slots__ = ("receive_buffer",) RECEIVE_BUFFER_FIELD_NUMBER: _ClassVar[int] receive_buffer: bytes def __init__(self, receive_buffer: _Optional[bytes] = ...) -> None: ... class AllreduceRequest(_message.Message): - __slots__ = ["sequence_number", "rank", "send_buffer", "data_type", "reduce_operation"] + __slots__ = ("sequence_number", "rank", "send_buffer", "data_type", "reduce_operation") SEQUENCE_NUMBER_FIELD_NUMBER: _ClassVar[int] RANK_FIELD_NUMBER: _ClassVar[int] SEND_BUFFER_FIELD_NUMBER: _ClassVar[int] @@ -94,13 +94,13 @@ class AllreduceRequest(_message.Message): def __init__(self, sequence_number: _Optional[int] = ..., rank: _Optional[int] = ..., send_buffer: _Optional[bytes] = ..., data_type: _Optional[_Union[DataType, str]] = ..., reduce_operation: _Optional[_Union[ReduceOperation, str]] = ...) -> None: ... class AllreduceReply(_message.Message): - __slots__ = ["receive_buffer"] + __slots__ = ("receive_buffer",) RECEIVE_BUFFER_FIELD_NUMBER: _ClassVar[int] receive_buffer: bytes def __init__(self, receive_buffer: _Optional[bytes] = ...) -> None: ... 
class BroadcastRequest(_message.Message): - __slots__ = ["sequence_number", "rank", "send_buffer", "root"] + __slots__ = ("sequence_number", "rank", "send_buffer", "root") SEQUENCE_NUMBER_FIELD_NUMBER: _ClassVar[int] RANK_FIELD_NUMBER: _ClassVar[int] SEND_BUFFER_FIELD_NUMBER: _ClassVar[int] @@ -112,7 +112,7 @@ class BroadcastRequest(_message.Message): def __init__(self, sequence_number: _Optional[int] = ..., rank: _Optional[int] = ..., send_buffer: _Optional[bytes] = ..., root: _Optional[int] = ...) -> None: ... class BroadcastReply(_message.Message): - __slots__ = ["receive_buffer"] + __slots__ = ("receive_buffer",) RECEIVE_BUFFER_FIELD_NUMBER: _ClassVar[int] receive_buffer: bytes def __init__(self, receive_buffer: _Optional[bytes] = ...) -> None: ... diff --git a/nvflare/app_opt/xgboost/histogram_based_v2/proto/federated_pb2_grpc.py b/nvflare/app_opt/xgboost/histogram_based_v2/proto/federated_pb2_grpc.py index 45eee5c8dd..549d0e4ffc 100644 --- a/nvflare/app_opt/xgboost/histogram_based_v2/proto/federated_pb2_grpc.py +++ b/nvflare/app_opt/xgboost/histogram_based_v2/proto/federated_pb2_grpc.py @@ -12,13 +12,16 @@ # See the License for the specific language governing permissions and # limitations under the License. +# -*- coding: utf-8 -*- +# Generated by the protocol buffer compiler. DO NOT EDIT! +# source: federated.proto +# Protobuf Python Version: 4.25.1 # Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! """Client and server classes corresponding to protobuf-defined services.""" import grpc import nvflare.app_opt.xgboost.histogram_based_v2.proto.federated_pb2 as federated__pb2 - class FederatedStub(object): """Missing associated documentation comment in .proto file.""" diff --git a/nvflare/app_opt/xgboost/histogram_based_v2/runners/xgb_client_runner.py b/nvflare/app_opt/xgboost/histogram_based_v2/runners/xgb_client_runner.py index 1b98829711..0d8e8bec1d 100644 --- a/nvflare/app_opt/xgboost/histogram_based_v2/runners/xgb_client_runner.py +++ b/nvflare/app_opt/xgboost/histogram_based_v2/runners/xgb_client_runner.py @@ -30,7 +30,9 @@ from nvflare.fuel.utils.obj_utils import get_logger from nvflare.utils.cli_utils import get_package_root -LOADER_PARAMS_LIBRARY_PATH = "LIBRARY_PATH" +PLUGIN_PARAM_KEY = "federated_plugin" +PLUGIN_KEY_NAME = "name" +PLUGIN_KEY_PATH = "path" class XGBClientRunner(AppRunner, FLComponent): @@ -135,7 +137,7 @@ def run(self, ctx: dict): self.logger.info(f"server address is {self._server_addr}") communicator_env = { - "xgboost_communicator": "federated", + "dmlc_communicator": "federated", "federated_server_address": f"{self._server_addr}", "federated_world_size": self._world_size, "federated_rank": self._rank, @@ -145,38 +147,35 @@ def run(self, ctx: dict): self.logger.info("XGBoost non-secure training") else: xgb_plugin_name = ConfigService.get_str_var( - name="xgb_plugin_name", conf=SystemConfigs.RESOURCES_CONF, default="nvflare" + name="xgb_plugin_name", conf=SystemConfigs.RESOURCES_CONF, default=None ) - - xgb_loader_params = ConfigService.get_dict_var( - name="xgb_loader_params", conf=SystemConfigs.RESOURCES_CONF, default={} + xgb_plugin_path = ConfigService.get_str_var( + name="xgb_plugin_path", conf=SystemConfigs.RESOURCES_CONF, default=None + ) + xgb_plugin_params: dict = ConfigService.get_dict_var( + name=PLUGIN_PARAM_KEY, conf=SystemConfigs.RESOURCES_CONF, default={} ) - # Library path is frequently used, add a scalar config var and overwrite what's in the dict - xgb_library_path = ConfigService.get_str_var(name="xgb_library_path", 
conf=SystemConfigs.RESOURCES_CONF) - if xgb_library_path: - xgb_loader_params[LOADER_PARAMS_LIBRARY_PATH] = xgb_library_path + # path and name can be overwritten by scalar configuration + if xgb_plugin_name: + xgb_plugin_params[PLUGIN_KEY_NAME] = xgb_plugin_name - lib_path = xgb_loader_params.get(LOADER_PARAMS_LIBRARY_PATH, None) - if not lib_path: - xgb_loader_params[LOADER_PARAMS_LIBRARY_PATH] = str(get_package_root() / "libs") + if xgb_plugin_path: + xgb_plugin_params[PLUGIN_KEY_PATH] = xgb_plugin_path - xgb_proc_params = ConfigService.get_dict_var( - name="xgb_proc_params", conf=SystemConfigs.RESOURCES_CONF, default={} - ) + # Set default plugin name + if not xgb_plugin_params.get(PLUGIN_KEY_NAME): + xgb_plugin_params[PLUGIN_KEY_NAME] = "cuda_paillier" - self.logger.info( - f"XGBoost secure mode: {self._training_mode} plugin_name: {xgb_plugin_name} " - f"proc_params: {xgb_proc_params} loader_params: {xgb_loader_params}" - ) + if not xgb_plugin_params.get(PLUGIN_KEY_PATH): + # This only works on Linux. Need to support other platforms + lib_ext = "so" + lib_name = f"lib{xgb_plugin_params[PLUGIN_KEY_NAME]}.{lib_ext}" + xgb_plugin_params[PLUGIN_KEY_PATH] = str(get_package_root() / "libs" / lib_name) - communicator_env.update( - { - "plugin_name": xgb_plugin_name, - "proc_params": xgb_proc_params, - "loader_params": xgb_loader_params, - } - ) + self.logger.info(f"XGBoost secure training: {self._training_mode} Params: {xgb_plugin_params}") + + communicator_env[PLUGIN_PARAM_KEY] = xgb_plugin_params with xgb.collective.CommunicatorContext(**communicator_env): # Load the data. Dmatrix must be created with column split mode in CommunicatorContext for vertical FL diff --git a/nvflare/app_opt/xgboost/histogram_based_v2/runners/xgb_server_runner.py b/nvflare/app_opt/xgboost/histogram_based_v2/runners/xgb_server_runner.py index 32e708c90e..e4a8796a38 100644 --- a/nvflare/app_opt/xgboost/histogram_based_v2/runners/xgb_server_runner.py +++ b/nvflare/app_opt/xgboost/histogram_based_v2/runners/xgb_server_runner.py @@ -29,8 +29,8 @@ def run(self, ctx: dict): self._world_size = ctx.get(Constant.RUNNER_CTX_WORLD_SIZE) xgb_federated.run_federated_server( + n_workers=self._world_size, port=self._port, - world_size=self._world_size, ) self._stopped = True diff --git a/nvflare/app_opt/xgboost/histogram_based_v2/sec/client_handler.py b/nvflare/app_opt/xgboost/histogram_based_v2/sec/client_handler.py index 5aad654824..ea5607d828 100644 --- a/nvflare/app_opt/xgboost/histogram_based_v2/sec/client_handler.py +++ b/nvflare/app_opt/xgboost/histogram_based_v2/sec/client_handler.py @@ -299,6 +299,10 @@ def _process_after_all_gather_v(self, fl_ctx: FLContext): self._process_after_all_gather_v_vertical(fl_ctx) def _process_after_all_gather_v_vertical(self, fl_ctx: FLContext): + reply = fl_ctx.get_prop(Constant.PARAM_KEY_REPLY) + size_dict = reply.get_header(Constant.HEADER_KEY_SIZE_DICT) + total_size = sum(size_dict.values()) + self.info(fl_ctx, f"{total_size=} {size_dict=}") rcv_buf = fl_ctx.get_prop(Constant.PARAM_KEY_RCV_BUF) # this rcv_buf is a list of replies from ALL clients! 
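+        # total_size is the sum of the original buffer sizes reported by all ranks; it is used
+        # below to size the dummy buffer returned by non-label clients and to pad the aggregated
+        # result so that the allgatherV reply matches the size XGBoost expects.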
rank = fl_ctx.get_prop(Constant.PARAM_KEY_RANK) @@ -309,7 +313,7 @@ def _process_after_all_gather_v_vertical(self, fl_ctx: FLContext): if not self.clear_ghs: # this is non-label client - don't care about the results - dummy = os.urandom(Constant.DUMMY_BUFFER_SIZE) + dummy = os.urandom(total_size) fl_ctx.set_prop(key=Constant.PARAM_KEY_RCV_BUF, value=dummy, private=True, sticky=False) self.info(fl_ctx, "non-label client: return dummy buffer back to XGB") return @@ -352,16 +356,45 @@ def _process_after_all_gather_v_vertical(self, fl_ctx: FLContext): self.info(fl_ctx, f"final aggr: {gid=} features={fid_list}") result = self.data_converter.encode_aggregation_result(final_result, fl_ctx) + + # XGBoost expects every work has a set of histograms. They are already combined here so + # just add zeros + zero_result = final_result + for result_list in zero_result.values(): + for item in result_list: + size = len(item.aggregated_hist) + item.aggregated_hist = [(0, 0)] * size + zero_buf = self.data_converter.encode_aggregation_result(zero_result, fl_ctx) + world_size = len(size_dict) + for _ in range(world_size - 1): + result += zero_buf + + # XGBoost checks that the size of allgatherv is not changed + padding_size = total_size - len(result) + if padding_size > 0: + result += b"\x00" * padding_size + elif padding_size < 0: + self.error(fl_ctx, f"The original size {total_size} is not big enough for data size {len(result)}") + fl_ctx.set_prop(key=Constant.PARAM_KEY_RCV_BUF, value=result, private=True, sticky=False) def _process_after_all_gather_v_horizontal(self, fl_ctx: FLContext): + reply = fl_ctx.get_prop(Constant.PARAM_KEY_REPLY) + world_size = reply.get_header(Constant.HEADER_KEY_WORLD_SIZE) encrypted_histograms = fl_ctx.get_prop(Constant.PARAM_KEY_RCV_BUF) rank = fl_ctx.get_prop(Constant.PARAM_KEY_RANK) if not isinstance(encrypted_histograms, CKKSVector): return self._abort(f"rank {rank}: expect a CKKSVector but got {type(encrypted_histograms)}", fl_ctx) histograms = encrypted_histograms.decrypt(secret_key=self.tenseal_context.secret_key()) + result = self.data_converter.encode_histograms_result(histograms, fl_ctx) + + # XGBoost expect every worker returns a histogram, all zeros are returned for other workers + zeros = [0.0] * len(histograms) + zero_buf = self.data_converter.encode_histograms_result(zeros, fl_ctx) + for _ in range(world_size - 1): + result += zero_buf fl_ctx.set_prop(key=Constant.PARAM_KEY_RCV_BUF, value=result, private=True, sticky=False) def handle_event(self, event_type: str, fl_ctx: FLContext): @@ -376,7 +409,7 @@ def handle_event(self, event_type: str, fl_ctx: FLContext): else: self.debug(fl_ctx, "Tenseal module not loaded, horizontal secure XGBoost is not supported") except Exception as ex: - self.debug(fl_ctx, f"Can't load tenseal context, horizontal secure XGBoost is not supported: {ex}") + self.error(fl_ctx, f"Can't load tenseal context, horizontal secure XGBoost is not supported: {ex}") self.tenseal_context = None elif event_type == EventType.END_RUN: self.tenseal_context = None diff --git a/nvflare/app_opt/xgboost/histogram_based_v2/sec/server_handler.py b/nvflare/app_opt/xgboost/histogram_based_v2/sec/server_handler.py index 53e936c7d4..47e44d17d6 100644 --- a/nvflare/app_opt/xgboost/histogram_based_v2/sec/server_handler.py +++ b/nvflare/app_opt/xgboost/histogram_based_v2/sec/server_handler.py @@ -39,6 +39,8 @@ def __init__(self): self.aggr_result_dict = None self.aggr_result_to_send = None self.aggr_result_lock = threading.Lock() + self.world_size = 0 + self.size_dict 
= None if tenseal_imported: decomposers.register() @@ -124,6 +126,10 @@ def _process_before_all_gather_v(self, fl_ctx: FLContext): else: self.info(fl_ctx, f"no aggr data from {rank=}") + if self.size_dict is None: + self.size_dict = {} + + self.size_dict[rank] = request.get_header(Constant.HEADER_KEY_ORIGINAL_BUF_SIZE) # only send a dummy to the Server fl_ctx.set_prop( key=Constant.PARAM_KEY_SEND_BUF, value=os.urandom(Constant.DUMMY_BUFFER_SIZE), private=True, sticky=False @@ -146,6 +152,7 @@ def _process_after_all_gather_v(self, fl_ctx: FLContext): horizontal = fl_ctx.get_prop(Constant.HEADER_KEY_HORIZONTAL) reply.set_header(Constant.HEADER_KEY_ENCRYPTED_DATA, True) reply.set_header(Constant.HEADER_KEY_HORIZONTAL, horizontal) + with self.aggr_result_lock: if not self.aggr_result_to_send: if not self.aggr_result_dict: @@ -159,6 +166,10 @@ def _process_after_all_gather_v(self, fl_ctx: FLContext): # reset aggr_result_dict for next gather self.aggr_result_dict = None + self.world_size = len(self.size_dict) + reply.set_header(Constant.HEADER_KEY_WORLD_SIZE, self.world_size) + reply.set_header(Constant.HEADER_KEY_SIZE_DICT, self.size_dict) + if horizontal: length = self.aggr_result_to_send.size() else: diff --git a/nvflare/app_opt/xgboost/histogram_based_v2/secure_data_loader.py b/nvflare/app_opt/xgboost/histogram_based_v2/secure_data_loader.py new file mode 100644 index 0000000000..6540eb519c --- /dev/null +++ b/nvflare/app_opt/xgboost/histogram_based_v2/secure_data_loader.py @@ -0,0 +1,50 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import xgboost as xgb + +from nvflare.app_opt.xgboost.data_loader import XGBDataLoader +from nvflare.app_opt.xgboost.histogram_based_v2.defs import TRAINING_MODE_MAPPING, SplitMode + + +class SecureDataLoader(XGBDataLoader): + def __init__(self, rank: int, folder: str): + """Reads CSV dataset and return XGB data matrix in vertical secure mode. 
+ + Args: + rank: Rank of the site + folder: Folder to find the CSV files + """ + self.rank = rank + self.folder = folder + + def load_data(self, client_id: str, training_mode: str): + + train_path = f"{self.folder}/site-{self.rank + 1}/train.csv" + valid_path = f"{self.folder}/site-{self.rank + 1}/valid.csv" + + if training_mode not in TRAINING_MODE_MAPPING: + raise ValueError(f"Invalid training_mode: {training_mode}") + + data_split_mode = TRAINING_MODE_MAPPING[training_mode] + + if self.rank == 0 or data_split_mode == SplitMode.ROW: + label = "&label_column=0" + else: + label = "" + + train_data = xgb.DMatrix(train_path + f"?format=csv{label}", data_split_mode=data_split_mode) + valid_data = xgb.DMatrix(valid_path + f"?format=csv{label}", data_split_mode=data_split_mode) + + return train_data, valid_data diff --git a/nvflare/fuel/f3/cellnet/cell_cipher.py b/nvflare/fuel/f3/cellnet/cell_cipher.py index da8ddba23b..88866f81d8 100644 --- a/nvflare/fuel/f3/cellnet/cell_cipher.py +++ b/nvflare/fuel/f3/cellnet/cell_cipher.py @@ -89,7 +89,7 @@ def _verify(k, m, s): ) -def _sym_enc(k, n, m): +def _sym_enc(k: bytes, n: bytes, m: bytes): cipher = ciphers.Cipher(ciphers.algorithms.AES(k), ciphers.modes.CBC(n)) encryptor = cipher.encryptor() padder = padding.PKCS7(PADDING_LENGTH).padder() @@ -97,7 +97,7 @@ def _sym_enc(k, n, m): return encryptor.update(padded_data) + encryptor.finalize() -def _sym_dec(k, n, m): +def _sym_dec(k: bytes, n: bytes, m: bytes): cipher = ciphers.Cipher(ciphers.algorithms.AES(k), ciphers.modes.CBC(n)) decryptor = cipher.decryptor() plain_text = decryptor.update(m) @@ -157,28 +157,6 @@ def get_latest_key(self): return last_value -class CellCipher: - def __init__(self, session_key_manager: SessionKeyManager): - self.session_key_manager = session_key_manager - - def encrypt(self, message): - key = self.session_key_manager.get_latest_key() - key_hash = get_hash(key) - nonce = os.urandom(NONCE_LENGTH) - return nonce + key_hash[-HASH_LENGTH:] + _sym_enc(key, nonce, message) - - def decrypt(self, message): - nonce, key_hash, message = ( - message[:NONCE_LENGTH], - message[NONCE_LENGTH:HEADER_LENGTH], - message[HEADER_LENGTH:], - ) - key = self.session_key_manager.get_key(key_hash) - if key is None: - raise SessionKeyUnavailable("No session key found for received message") - return _sym_dec(key, nonce, message) - - class SimpleCellCipher: def __init__(self, root_ca: Certificate, pri_key: asymmetric.rsa.RSAPrivateKey, cert: Certificate): self._root_ca = root_ca diff --git a/nvflare/fuel/f3/cellnet/core_cell.py b/nvflare/fuel/f3/cellnet/core_cell.py index d7e821cbef..dbb2b8c10f 100644 --- a/nvflare/fuel/f3/cellnet/core_cell.py +++ b/nvflare/fuel/f3/cellnet/core_cell.py @@ -942,6 +942,10 @@ def encrypt_payload(self, message: Message): if message.payload is None: message.payload = bytes(0) + elif isinstance(message.payload, memoryview) or isinstance(message.payload, bytearray): + message.payload = bytes(message.payload) + elif not isinstance(message.payload, bytes): + raise RuntimeError(f"Payload type of {type(message.payload)} is not supported.") payload_len = len(message.payload) message.add_headers( diff --git a/nvflare/fuel/utils/grpc_utils.py b/nvflare/fuel/utils/grpc_utils.py new file mode 100644 index 0000000000..3adc6b003b --- /dev/null +++ b/nvflare/fuel/utils/grpc_utils.py @@ -0,0 +1,42 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import grpc + + +def create_channel(server_addr, grpc_options, ready_timeout: float, test_only: bool): + """Creates a gRPC channel and waits for the server to be ready. + + Args: + server_addr: the gRPC server address to connect to + grpc_options: gRPC client connection options + ready_timeout: how long to wait for the server to be ready + test_only: whether to only test server readiness + + Returns: the gRPC channel created. But if test_only, the channel is closed and None is returned. + + If the server does not become ready within ready_timeout, a RuntimeError is raised. + + """ + channel = grpc.insecure_channel(server_addr, options=grpc_options) + + # wait for channel ready + try: + grpc.channel_ready_future(channel).result(timeout=ready_timeout) + except grpc.FutureTimeoutError: + raise RuntimeError(f"cannot connect to server after {ready_timeout} seconds") + + if test_only: + channel.close() + channel = None + return channel diff --git a/nvflare/job_config/fed_job.py b/nvflare/job_config/fed_job.py index fe17406687..883c83ad29 100644 --- a/nvflare/job_config/fed_job.py +++ b/nvflare/job_config/fed_job.py @@ -19,12 +19,14 @@ from nvflare.apis.executor import Executor from nvflare.apis.filter import Filter from nvflare.apis.impl.controller import Controller +from nvflare.apis.job_def import ALL_SITES, SERVER_SITE_NAME from nvflare.app_common.executors.script_executor import ScriptExecutor from nvflare.app_common.widgets.convert_to_fed_event import ConvertToFedEvent from nvflare.app_common.widgets.intime_model_selector import IntimeModelSelector from nvflare.app_common.widgets.validation_json_generator import ValidationJsonGenerator from nvflare.fuel.utils.class_utils import get_component_init_parameters from nvflare.fuel.utils.import_utils import optional_import +from nvflare.fuel.utils.validation_utils import check_positive_int from nvflare.job_config.fed_app_config import ClientAppConfig, FedAppConfig, ServerAppConfig from nvflare.job_config.fed_job_config import FedJobConfig @@ -103,6 +105,56 @@ def add_external_scripts(self, external_scripts: List): self.app.add_ext_script(_script) + +class ExecutorApp(FedApp): + def __init__(self): + """Wrapper around `ClientAppConfig`.""" + super().__init__() + self._create_client_app() + + def add_executor(self, executor, tasks=None): + if tasks is None: + tasks = ["*"] # Add executor for any task by default + self.app.add_executor(tasks, executor) + + def _create_client_app(self): + self.app = ClientAppConfig() + + component = ConvertToFedEvent(events_to_convert=["analytix_log_stats"], fed_event_prefix="fed.") + self.app.add_component("event_to_fed", component) + + +class ControllerApp(FedApp): + """Wrapper around `ServerAppConfig`.
+ + Args: + """ + + def __init__(self, key_metric="accuracy"): + super().__init__() + self.key_metric = key_metric + self._create_server_app() + + def add_controller(self, controller, id=None): + if id is None: + id = "controller" + self.app.add_workflow(self._gen_tracked_id(id), controller) + + def _create_server_app(self): + self.app: ServerAppConfig = ServerAppConfig() + + component = ValidationJsonGenerator() + self.app.add_component("json_generator", component) + + if self.key_metric: + component = IntimeModelSelector(key_metric=self.key_metric) + self.app.add_component("model_selector", component) + + # TODO: make different tracking receivers configurable + if torch_ok and tb_ok: + component = TBAnalyticsReceiver(events=["fed.analytix_log_stats"]) + self.app.add_component("receiver", component) + + class FedJob: def __init__(self, name="fed_job", min_clients=1, mandatory_clients=None, key_metric="accuracy") -> None: """FedJob allows users to generate job configurations in a Pythonic way. @@ -136,7 +188,7 @@ def to( filter_type: FilterType = None, id=None, ): - """assign an `obj` to a target (server or clients). + """assign an object to a target (server or clients). Args: obj: The object to be assigned. The obj will be given a default `id` if non is provided based on its type. @@ -218,6 +270,51 @@ def to( if self._components: self._add_referenced_components(obj, target) + def to_server( + self, + obj: Any, + filter_type: FilterType = None, + id=None, + ): + """assign an object to the server. + + Args: + obj: The object to be assigned. The obj will be given a default `id` if non is provided based on its type. + filter_type: The type of filter used. Either `FilterType.TASK_RESULT` or `FilterType.TASK_DATA`. + id: Optional user-defined id for the object. Defaults to `None` and ID will automatically be assigned. + + Returns: + + """ + if isinstance(obj, Executor): + raise ValueError("Use `job.to(executor, )` or `job.to_clients(executor)` for Executors.") + + self.to(obj=obj, target=SERVER_SITE_NAME, filter_type=filter_type, id=id) + + def to_clients( + self, + obj: Any, + tasks: List[str] = None, + filter_type: FilterType = None, + id=None, + ): + """assign an object to all clients. + + Args: + obj (Any): Object to be deployed. + tasks: In case object is an `Executor`, optional list of tasks the executor should handle. + Defaults to `None`. If `None`, all tasks will be handled using `[*]`. + filter_type: The type of filter used. Either `FilterType.TASK_RESULT` or `FilterType.TASK_DATA`. + id: Optional user-defined id for the object. Defaults to `None` and ID will automatically be assigned. 
+ + Returns: + + """ + if isinstance(obj, Controller): + raise ValueError('Use `job.to(controller, "server")` or `job.to_server(controller)` for Controllers.') + + self.to(obj=obj, target=ALL_SITES, tasks=tasks, filter_type=filter_type, id=id) + def as_id(self, obj: Any): id = str(uuid.uuid4()) self._components[id] = obj @@ -260,10 +357,30 @@ def _set_site_app(self, app: FedApp, target: str): self.job.add_fed_app(app_name, app_config) self.job.set_site_app(target, app_name) + def _set_all_app(self, client_app: ExecutorApp, server_app: ControllerApp): + if not isinstance(client_app, ExecutorApp): + raise ValueError(f"`client_app` needs to be of type `ExecutorApp` but was type {type(client_app)}") + if not isinstance(server_app, ControllerApp): + raise ValueError(f"`server_app` needs to be of type `ControllerApp` but was type {type(server_app)}") + + client_config = client_app.get_app_config() + server_config = server_app.get_app_config() + + app_config = FedAppConfig(server_app=server_config, client_app=client_config) + app_name = "app" + + self.job.add_fed_app(app_name, app_config) + self.job.set_site_app(ALL_SITES, app_name) + def _set_all_apps(self): if not self._deployed: - for target in self._deploy_map: - self._set_site_app(self._deploy_map[target], target) + if ALL_SITES in self._deploy_map: + if SERVER_SITE_NAME not in self._deploy_map: + raise ValueError('Missing server components! Deploy using `to(obj, "server") or `to_server(obj)`') + self._set_all_app(client_app=self._deploy_map[ALL_SITES], server_app=self._deploy_map[SERVER_SITE_NAME]) + else: + for target in self._deploy_map: + self._set_site_app(self._deploy_map[target], target) self._deployed = True @@ -271,10 +388,19 @@ def export_job(self, job_root): self._set_all_apps() self.job.generate_job_config(job_root) - def simulator_run(self, workspace, threads: int = None): + def simulator_run(self, workspace, n_clients: int = None, threads: int = None): self._set_all_apps() + if ALL_SITES in self.clients and not n_clients: + raise ValueError("Clients were not specified using to(). Please provide the number of clients to simulate.") + elif ALL_SITES in self.clients and n_clients: + check_positive_int("n_clients", n_clients) + self.clients = [f"site-{i}" for i in range(1, n_clients + 1)] + elif self.clients and n_clients: + raise ValueError("You already specified clients using `to()`. Don't use `n_clients` in simulator_run.") + n_clients = len(self.clients) + if threads is None: threads = n_clients @@ -290,56 +416,6 @@ def _validate_target(self, target): if not target: raise ValueError("Must provide a valid target name") - if any(c in SPECIAL_CHARACTERS for c in target): + if any(c in SPECIAL_CHARACTERS for c in target) and target != ALL_SITES: raise ValueError(f"target {target} name contains invalid character") pass - - -class ExecutorApp(FedApp): - def __init__(self): - """Wrapper around `ClientAppConfig`.""" - super().__init__() - self._create_client_app() - - def add_executor(self, executor, tasks=None): - if tasks is None: - tasks = ["*"] # Add executor for any task by default - self.app.add_executor(tasks, executor) - - def _create_client_app(self): - self.app = ClientAppConfig() - - component = ConvertToFedEvent(events_to_convert=["analytix_log_stats"], fed_event_prefix="fed.") - self.app.add_component("event_to_fed", component) - - -class ControllerApp(FedApp): - """Wrapper around `ServerAppConfig`. 
- - Args: - """ - - def __init__(self, key_metric="accuracy"): - super().__init__() - self.key_metric = key_metric - self._create_server_app() - - def add_controller(self, controller, id=None): - if id is None: - id = "controller" - self.app.add_workflow(self._gen_tracked_id(id), controller) - - def _create_server_app(self): - self.app: ServerAppConfig = ServerAppConfig() - - component = ValidationJsonGenerator() - self.app.add_component("json_generator", component) - - if self.key_metric: - component = IntimeModelSelector(key_metric=self.key_metric) - self.app.add_component("model_selector", component) - - # TODO: make different tracking receivers configurable - if torch_ok and tb_ok: - component = TBAnalyticsReceiver(events=["fed.analytix_log_stats"]) - self.app.add_component("receiver", component) diff --git a/nvflare/job_config/fed_job_config.py b/nvflare/job_config/fed_job_config.py index 509e35d382..f98bc9ce9f 100644 --- a/nvflare/job_config/fed_job_config.py +++ b/nvflare/job_config/fed_job_config.py @@ -65,6 +65,15 @@ def add_fed_app(self, app_name: str, fed_app: FedAppConfig): self.fed_apps[app_name] = fed_app def set_site_app(self, site_name: str, app_name: str): + """assign an app to a certain site. + + Args: + site_name: The target site name. + app_name: The app name. + + Returns: + + """ if app_name not in self.fed_apps.keys(): raise RuntimeError(f"fed_app {app_name} does not exist.") diff --git a/nvflare/private/aux_runner.py b/nvflare/private/aux_runner.py index 1b338bd037..be7bc8d436 100644 --- a/nvflare/private/aux_runner.py +++ b/nvflare/private/aux_runner.py @@ -152,7 +152,7 @@ def _wait_for_cell(self): ) start = time.time() - self.logger.info(f"waiting for cell for {self.cell_wait_timeout} seconds") + self.logger.debug(f"waiting for cell for {self.cell_wait_timeout} seconds") while True: cell = self.engine.get_cell() if cell: diff --git a/nvflare/private/fed/app/simulator/simulator_runner.py b/nvflare/private/fed/app/simulator/simulator_runner.py index 3ccdcd1435..3952ea5db5 100644 --- a/nvflare/private/fed/app/simulator/simulator_runner.py +++ b/nvflare/private/fed/app/simulator/simulator_runner.py @@ -149,7 +149,7 @@ def setup(self): for i in range(self.args.n_clients): self.client_names.append("site-" + str(i + 1)) - log_config_file_path = os.path.join(self.args.workspace, "startup", WorkspaceConstants.LOGGING_CONFIG) + log_config_file_path = os.path.join(self.args.workspace, "local", WorkspaceConstants.LOGGING_CONFIG) if not os.path.isfile(log_config_file_path): log_config_file_path = os.path.join(os.path.dirname(__file__), WorkspaceConstants.LOGGING_CONFIG) logging.config.fileConfig(fname=log_config_file_path, disable_existing_loggers=False) @@ -174,6 +174,10 @@ def setup(self): self._cleanup_workspace() init_security_content_service(self.args.workspace) + os.makedirs(os.path.join(self.simulator_root, "server")) + log_file = os.path.join(self.simulator_root, "server", WorkspaceConstants.LOG_FILE_NAME) + add_logfile_handler(log_file) + try: data_bytes, job_name, meta = self.validate_job_data() @@ -271,18 +275,29 @@ def _cleanup_workspace(self): with tempfile.TemporaryDirectory() as temp_dir: startup_dir = os.path.join(self.args.workspace, "startup") temp_start_up = os.path.join(temp_dir, "startup") + local_dir = os.path.join(self.args.workspace, "local") + temp_local_dir = os.path.join(temp_dir, "local") if os.path.exists(startup_dir): shutil.move(startup_dir, temp_start_up) + if os.path.exists(local_dir): + shutil.move(local_dir, temp_local_dir) + if 
os.path.exists(self.simulator_root): shutil.rmtree(self.simulator_root) + if os.path.exists(temp_start_up): shutil.move(temp_start_up, startup_dir) + if os.path.exists(temp_local_dir): + shutil.move(temp_local_dir, local_dir) def _setup_local_startup(self, log_config_file_path, workspace): local_dir = os.path.join(workspace, "local") startup = os.path.join(workspace, "startup") os.makedirs(local_dir, exist_ok=True) shutil.copyfile(log_config_file_path, os.path.join(local_dir, WorkspaceConstants.LOGGING_CONFIG)) + workspace_local = os.path.join(self.simulator_root, "local") + if os.path.exists(workspace_local): + shutil.copytree(workspace_local, local_dir, dirs_exist_ok=True) shutil.copytree(os.path.join(self.simulator_root, "startup"), startup) def validate_job_data(self): @@ -490,9 +505,6 @@ def start_server_app(self, args): args.workspace = os.path.join(self.simulator_root, "server") os.chdir(args.workspace) - log_file = os.path.join(self.simulator_root, "server", WorkspaceConstants.LOG_FILE_NAME) - add_logfile_handler(log_file) - args.server_config = os.path.join("config", JobConstants.SERVER_JOB_CONFIG) app_custom_folder = os.path.join(app_server_root, "custom") if os.path.isdir(app_custom_folder) and app_custom_folder not in sys.path: diff --git a/nvflare/private/fed/client/client_app_runner.py b/nvflare/private/fed/client/client_app_runner.py index e3a84794b8..4ac3d3cb09 100644 --- a/nvflare/private/fed/client/client_app_runner.py +++ b/nvflare/private/fed/client/client_app_runner.py @@ -86,7 +86,11 @@ def create_client_runner(self, app_root, args, config_folder, federated_client, client_config_file_name = os.path.join(app_root, args.client_config) args.set.append(f"secure_train={secure_train}") conf = ClientJsonConfigurator( - config_file_name=client_config_file_name, app_root=app_root, args=args, kv_list=args.set + workspace_obj=workspace, + config_file_name=client_config_file_name, + app_root=app_root, + args=args, + kv_list=args.set, ) if event_handlers: conf.set_component_build_authorizer(authorize_build_component, fl_ctx=fl_ctx, event_handlers=event_handlers) diff --git a/nvflare/private/fed/client/client_json_config.py b/nvflare/private/fed/client/client_json_config.py index 238d1c13e6..c0f3770ad9 100644 --- a/nvflare/private/fed/client/client_json_config.py +++ b/nvflare/private/fed/client/client_json_config.py @@ -17,6 +17,7 @@ from nvflare.apis.executor import Executor from nvflare.apis.fl_component import FLComponent from nvflare.apis.fl_constant import SystemConfigs, SystemVarName +from nvflare.apis.workspace import Workspace from nvflare.fuel.utils.argument_utils import parse_vars from nvflare.fuel.utils.config_service import ConfigService from nvflare.fuel.utils.json_scanner import Node @@ -37,7 +38,9 @@ def __init__(self): class ClientJsonConfigurator(FedJsonConfigurator): - def __init__(self, config_file_name: str, args, app_root: str, kv_list=None, exclude_libs=True): + def __init__( + self, workspace_obj: Workspace, config_file_name: str, args, app_root: str, kv_list=None, exclude_libs=True + ): """To init the ClientJsonConfigurator. 
Args: @@ -67,6 +70,7 @@ def __init__(self, config_file_name: str, args, app_root: str, kv_list=None, exc SystemVarName.WORKSPACE: args.workspace, SystemVarName.ROOT_URL: sp_url, SystemVarName.SECURE_MODE: self.cmd_vars.get("secure_train", True), + SystemVarName.JOB_CUSTOM_DIR: workspace_obj.get_app_custom_dir(args.job_id), } FedJsonConfigurator.__init__( diff --git a/nvflare/private/fed/server/server_app_runner.py b/nvflare/private/fed/server/server_app_runner.py index 5af8ea891a..d078adf31d 100644 --- a/nvflare/private/fed/server/server_app_runner.py +++ b/nvflare/private/fed/server/server_app_runner.py @@ -57,7 +57,11 @@ def start_server_app( server_config_file_name = os.path.join(app_root, args.server_config) conf = ServerJsonConfigurator( - config_file_name=server_config_file_name, app_root=app_root, args=args, kv_list=kv_list + workspace_obj=workspace, + config_file_name=server_config_file_name, + app_root=app_root, + args=args, + kv_list=kv_list, ) if event_handlers: fl_ctx = FLContext() diff --git a/nvflare/private/fed/server/server_json_config.py b/nvflare/private/fed/server/server_json_config.py index f58cb8a2f2..f93ef6805f 100644 --- a/nvflare/private/fed/server/server_json_config.py +++ b/nvflare/private/fed/server/server_json_config.py @@ -18,6 +18,7 @@ from nvflare.apis.fl_constant import SystemConfigs, SystemVarName from nvflare.apis.impl.controller import Controller from nvflare.apis.impl.wf_comm_server import WFCommServer +from nvflare.apis.workspace import Workspace from nvflare.fuel.utils.argument_utils import parse_vars from nvflare.fuel.utils.config_service import ConfigService from nvflare.fuel.utils.json_scanner import Node @@ -45,7 +46,9 @@ def __init__(self, id, controller: Controller): class ServerJsonConfigurator(FedJsonConfigurator): - def __init__(self, config_file_name: str, args, app_root: str, kv_list=None, exclude_libs=True): + def __init__( + self, workspace_obj: Workspace, config_file_name: str, args, app_root: str, kv_list=None, exclude_libs=True + ): """This class parses server config from json file. Args: @@ -70,6 +73,7 @@ def __init__(self, config_file_name: str, args, app_root: str, kv_list=None, exc SystemVarName.SITE_NAME: "server", SystemVarName.WORKSPACE: args.workspace, SystemVarName.SECURE_MODE: self.cmd_vars.get("secure_train", True), + SystemVarName.JOB_CUSTOM_DIR: workspace_obj.get_app_custom_dir(args.job_id), } FedJsonConfigurator.__init__( diff --git a/tests/integration_test/overseer_test.py b/tests/integration_test/overseer_test.py index cd9671bd4b..70dc2c9c9e 100644 --- a/tests/integration_test/overseer_test.py +++ b/tests/integration_test/overseer_test.py @@ -83,7 +83,7 @@ def test_overseer_overseer_down_and_up(self): oa_launcher.stop_overseer() time.sleep(10) oa_launcher.start_overseer() - time.sleep(10) + time.sleep(20) for client_agent in client_agent_list: psp = oa_launcher.get_primary_sp(client_agent) assert psp.name == "server00"
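The size-preserving padding added to ``_process_after_all_gather_v_vertical`` above is easiest to see in isolation. The following is a minimal, standalone sketch of that idea in plain Python; it is not FLARE code, and the helper name is made up for illustration.

.. code-block:: python

    def pad_to_original_size(result: bytes, total_size: int) -> bytes:
        # Keep the allgatherv reply at the exact byte size XGBoost originally allocated:
        # append zero bytes if the encoded result is smaller, fail if it is larger.
        padding_size = total_size - len(result)
        if padding_size < 0:
            raise ValueError(f"original size {total_size} is not big enough for data size {len(result)}")
        return result + b"\x00" * padding_size

    # a 10-byte payload padded up to a 16-byte buffer keeps its size contract
    assert len(pad_to_original_size(b"0123456789", 16)) == 16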
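The new ``create_channel`` helper in ``nvflare/fuel/utils/grpc_utils.py`` can be exercised on its own. Below is a usage sketch; the server address is a placeholder and assumes a gRPC server is actually listening there.

.. code-block:: python

    from nvflare.fuel.utils.grpc_utils import create_channel

    # Readiness probe only: the channel is closed internally and None is returned.
    create_channel("localhost:8002", grpc_options=None, ready_timeout=5.0, test_only=True)

    # Get a usable channel; a RuntimeError is raised if the server is not ready in time.
    channel = create_channel("localhost:8002", grpc_options=None, ready_timeout=5.0, test_only=False)
    channel.close()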
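For the ``FedJob`` additions (``to_server``, ``to_clients``, and the ``n_clients`` argument of ``simulator_run``), a hedged usage sketch follows. ``my_controller`` and ``my_executor`` are assumed placeholders for a server-side ``Controller`` and a client-side ``Executor`` defined elsewhere, and the workspace path is also made up.

.. code-block:: python

    from nvflare.job_config.fed_job import FedJob

    job = FedJob(name="fed_job", min_clients=2)

    # Controllers go to the server site; passing an Executor here raises ValueError.
    job.to_server(my_controller)

    # Executors go to all clients (ALL_SITES); passing a Controller here raises ValueError.
    job.to_clients(my_executor, tasks=["train"])

    # Since clients were addressed via ALL_SITES, the number of simulated clients must be
    # given explicitly; site-1..site-n client names are then generated automatically.
    job.simulator_run(workspace="/tmp/fed_job_workspace", n_clients=2)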