Commit

remove the indentation
khadijeh.alibabaei committed Nov 15, 2024
1 parent 5521345 commit a183869
Showing 1 changed file with 59 additions and 59 deletions.
nvflare/app_opt/tf/fedopt_ctl.py — 118 changes: 59 additions & 59 deletions
@@ -106,69 +106,69 @@ def _to_tf_params_list(self, params: Dict, negate: bool = False):
            tf_params_list.append(tf.Variable(v))
        return tf_params_list

    def update_model(self, global_model: FLModel, aggr_result: FLModel):
        """
        Override the default version of update_model to perform the update with a
        Keras Optimizer on the global model stored in memory in the persistor,
        instead of creating a new temporary model on the fly.

        Creating a new model would not work for Keras Optimizers, since an
        optimizer is bound to a specific set of Variables.
        """
        global_model_tf = self.persistor.model
        global_params = global_model_tf.trainable_weights
        num_trainable_weights = len(global_params)

        model_diff_params = {}

        w_idx = 0

        for key, param in global_model.params.items():
            if w_idx >= num_trainable_weights:
                break

            if param.shape == global_params[w_idx].shape:
                model_diff_params[key] = (
                    aggr_result.params[key] - param
                    if aggr_result.params_type == ParamsType.FULL
                    else aggr_result.params[key]
                )
                w_idx += 1

        model_diff = self._to_tf_params_list(model_diff_params, negate=True)
        start = time.time()

        self.optimizer.apply_gradients(zip(model_diff, global_params))
        secs = time.time() - start

        start = time.time()
        weights = global_model_tf.get_weights()

        new_weights = {}
        for w_idx, key in enumerate(global_model.params):
            if key in model_diff_params:
                new_weights[key] = weights[w_idx]
            else:
                new_weights[key] = (
                    aggr_result.params[key]
                    if aggr_result.params_type == ParamsType.FULL
                    else global_model.params[key] + aggr_result.params[key]
                )

        secs_detach = time.time() - start
        self.info(
            f"FedOpt ({type(self.optimizer)}) server model update "
            f"round {self.current_round}, "
            f"{type(self.lr_scheduler)} "
            f"lr: {self.optimizer.learning_rate(self.optimizer.iterations).numpy()}, "
            f"update: {secs} secs., detach: {secs_detach} secs.",
        )

        global_model.params = new_weights
        global_model.meta = aggr_result.meta

        return global_model
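
For context, update_model treats the negated difference between the aggregated client result and the current global weights as a pseudo-gradient and lets the Keras optimizer apply it to the persistor's in-memory model. The snippet below is a minimal, self-contained sketch of that idea, not code from this repository: the toy model, the SGD optimizer with learning rate 1.0, and the fabricated client_avg_weights are illustrative assumptions.

    import tensorflow as tf

    # Toy stand-ins for the persistor's global model and the server-side optimizer.
    server_model = tf.keras.Sequential([tf.keras.Input(shape=(8,)), tf.keras.layers.Dense(4)])
    optimizer = tf.keras.optimizers.SGD(learning_rate=1.0)

    # Pretend these aggregated weights came back from the clients.
    client_avg_weights = [w + 0.1 for w in server_model.get_weights()]

    # Pseudo-gradient: global minus aggregated, i.e. the negated model difference.
    pseudo_grads = [
        tf.convert_to_tensor(g - c)
        for g, c in zip(server_model.get_weights(), client_avg_weights)
    ]

    # The optimizer can only update the Variables it is bound to, which is why
    # fedopt_ctl.py applies the step to the model kept in memory by the persistor
    # rather than to a freshly built copy.
    optimizer.apply_gradients(zip(pseudo_grads, server_model.trainable_weights))

With plain SGD and a learning rate of 1.0 this step reduces to FedAvg: after apply_gradients the global weights equal client_avg_weights (up to floating-point error); other Keras optimizers (Adam, SGD with momentum) give the adaptive server updates FedOpt is designed for.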
