Skip to content

Commit

Permalink
add comments in the tf/fedopt_ctl
Browse files Browse the repository at this point in the history
  • Loading branch information
falibabaei committed Nov 22, 2024
1 parent 37f5bb9 commit fa3700f
Showing 1 changed file with 7 additions and 1 deletion.
8 changes: 7 additions & 1 deletion nvflare/app_opt/tf/fedopt_ctl.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,8 @@ def update_model(self, global_model: FLModel, aggr_result: FLModel):
global_params = global_model_tf.trainable_weights
num_trainable_weights = len(global_params)

# Compute model diff: need to use model diffs as
# gradients to be applied by the optimizer.
model_diff_params = {}

w_idx = 0
Expand All @@ -140,11 +142,15 @@ def update_model(self, global_model: FLModel, aggr_result: FLModel):
w_idx += 1

model_diff = self._to_tf_params_list(model_diff_params, negate=True)

# Apply model diffs as gradients, using the optimizer.
start = time.time()

self.optimizer.apply_gradients(zip(model_diff, global_params))
secs = time.time() - start


# Convert updated global model weights to
# numpy format for FLModel.
start = time.time()
weights = global_model_tf.get_weights()

Expand Down

0 comments on commit fa3700f

Please sign in to comment.