diff --git a/nvflare/app_opt/tf/fedopt_ctl.py b/nvflare/app_opt/tf/fedopt_ctl.py index 7bfd051e8c..45ca4c50c0 100644 --- a/nvflare/app_opt/tf/fedopt_ctl.py +++ b/nvflare/app_opt/tf/fedopt_ctl.py @@ -106,69 +106,69 @@ def _to_tf_params_list(self, params: Dict, negate: bool = False): tf_params_list.append(tf.Variable(v)) return tf_params_list - def update_model(self, global_model: FLModel, aggr_result: FLModel): - """ - Override the default version of update_model - to perform update with Keras Optimizer on the - global model stored in memory in persistor, instead of - creating new temporary model on-the-fly. - - Creating a new model would not work for Keras - Optimizers, since an optimizer is bind to - specific set of Variables. - - """ - global_model_tf = self.persistor.model - global_params = global_model_tf.trainable_weights - num_trainable_weights = len(global_params) - - model_diff_params = {} - - w_idx = 0 - - for key, param in global_model.params.items(): - if w_idx >= num_trainable_weights: - break - - if param.shape == global_params[w_idx].shape: - model_diff_params[key] = ( - aggr_result.params[key] - param - if aggr_result.params_type == ParamsType.FULL - else aggr_result.params[key] - ) - w_idx += 1 + def update_model(self, global_model: FLModel, aggr_result: FLModel): + """ + Override the default version of update_model + to perform update with Keras Optimizer on the + global model stored in memory in persistor, instead of + creating new temporary model on-the-fly. - model_diff = self._to_tf_params_list(model_diff_params, negate=True) - start = time.time() + Creating a new model would not work for Keras + Optimizers, since an optimizer is bind to + specific set of Variables. - self.optimizer.apply_gradients(zip(model_diff, global_params)) - secs = time.time() - start + """ + global_model_tf = self.persistor.model + global_params = global_model_tf.trainable_weights + num_trainable_weights = len(global_params) - start = time.time() - weights = global_model_tf.get_weights() + model_diff_params = {} - new_weights = {} - for w_idx, key in enumerate(global_model.params): - if key in model_diff_params: - new_weights[key] = weights[w_idx] + w_idx = 0 - else: + for key, param in global_model.params.items(): + if w_idx >= num_trainable_weights: + break - new_weights[key] = ( - aggr_result.params[key] - if aggr_result.params_type == ParamsType.FULL - else global_model.params[key] + aggr_result.params[key] + if param.shape == global_params[w_idx].shape: + model_diff_params[key] = ( + aggr_result.params[key] - param + if aggr_result.params_type == ParamsType.FULL + else aggr_result.params[key] ) - secs_detach = time.time() - start - self.info( - f"FedOpt ({type(self.optimizer)}) server model update " - f"round {self.current_round}, " - f"{type(self.lr_scheduler)} " - f"lr: {self.optimizer.learning_rate(self.optimizer.iterations).numpy()}, " - f"update: {secs} secs., detach: {secs_detach} secs.", - ) - - global_model.params = new_weights - global_model.meta = aggr_result.meta - - return global_model + w_idx += 1 + + model_diff = self._to_tf_params_list(model_diff_params, negate=True) + start = time.time() + + self.optimizer.apply_gradients(zip(model_diff, global_params)) + secs = time.time() - start + + start = time.time() + weights = global_model_tf.get_weights() + + new_weights = {} + for w_idx, key in enumerate(global_model.params): + if key in model_diff_params: + new_weights[key] = weights[w_idx] + + else: + + new_weights[key] = ( + aggr_result.params[key] + if aggr_result.params_type == ParamsType.FULL + else global_model.params[key] + aggr_result.params[key] + ) + secs_detach = time.time() - start + self.info( + f"FedOpt ({type(self.optimizer)}) server model update " + f"round {self.current_round}, " + f"{type(self.lr_scheduler)} " + f"lr: {self.optimizer.learning_rate(self.optimizer.iterations).numpy()}, " + f"update: {secs} secs., detach: {secs_detach} secs.", + ) + + global_model.params = new_weights + global_model.meta = aggr_result.meta + + return global_model