Skip to content

Commit

Permalink
Update scaffold.py
Browse files Browse the repository at this point in the history
  • Loading branch information
falibabaei authored Aug 1, 2024
1 parent 2006c97 commit f5b6d03
Showing 1 changed file with 2 additions and 1 deletion.
3 changes: 2 additions & 1 deletion nvflare/app_opt/tf/scaffold.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,8 @@ def optimize_weights(model, c_delta_para_value):
"""
c_delta_para = {}
layer_weights_dict = {layer.name: layer.get_weights() for layer in model.layers}
trainable_layers_dict = {layer.name: layer.get_weights() for layer in model.layers if layer.trainable}
trainable_layers_dict = {layer.name: layer.trainable_weights for layer in model.layers if layer.trainable_weights}
flatten_layer_weights_dict= flat_layer_weights_dict(layer_weights_dict)
flatten_layer_weights_dict = flat_layer_weights_dict(layer_weights_dict)
flatten_trainable_layers_dict = flat_layer_weights_dict(trainable_layers_dict)

Expand Down

0 comments on commit f5b6d03

Please sign in to comment.