Skip to content

Commit

Permalink
Update scaffold.py to get just trainable layer names
Browse files Browse the repository at this point in the history
  • Loading branch information
falibabaei authored Jul 24, 2024
1 parent 16ca72c commit 1cf7ec3
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion nvflare/app_opt/tf/scaffold.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def init(self, model):
self.c_local.set_weights(c_init_para)

# Generate a list of the flattened layers
layer_weights_dict = {layer.name: layer.get_weights() for layer in self.c_global.layers}
layer_weights_dict = {layer.name: layer.get_weights() for layer in self.c_global.layers if layer.trainable}
flattened_layer_weights_dict = flat_layer_weights_dict(layer_weights_dict)
self.global_keys = [key for key, _ in flattened_layer_weights_dict.items()]

Expand Down

0 comments on commit 1cf7ec3

Please sign in to comment.