diff --git a/nvflare/app_opt/tf/fedopt_ctl.py b/nvflare/app_opt/tf/fedopt_ctl.py index b04ac9af28..c39a5d2376 100644 --- a/nvflare/app_opt/tf/fedopt_ctl.py +++ b/nvflare/app_opt/tf/fedopt_ctl.py @@ -123,6 +123,8 @@ def update_model(self, global_model: FLModel, aggr_result: FLModel): global_params = global_model_tf.trainable_weights num_trainable_weights = len(global_params) + # Compute model diff: need to use model diffs as + # gradients to be applied by the optimizer. model_diff_params = {} w_idx = 0 @@ -140,11 +142,15 @@ def update_model(self, global_model: FLModel, aggr_result: FLModel): w_idx += 1 model_diff = self._to_tf_params_list(model_diff_params, negate=True) + + # Apply model diffs as gradients, using the optimizer. start = time.time() self.optimizer.apply_gradients(zip(model_diff, global_params)) secs = time.time() - start - + + # Convert updated global model weights to + # numpy format for FLModel. start = time.time() weights = global_model_tf.get_weights()