From 1cf7ec399464b248847d5727aa910eccc991d5d8 Mon Sep 17 00:00:00 2001 From: falibabaei <66964597+falibabaei@users.noreply.github.com> Date: Wed, 24 Jul 2024 20:25:00 +0200 Subject: [PATCH] Update scaffold.py to get just trainable layer names --- nvflare/app_opt/tf/scaffold.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nvflare/app_opt/tf/scaffold.py b/nvflare/app_opt/tf/scaffold.py index 19d4948377..4b1ceabf51 100644 --- a/nvflare/app_opt/tf/scaffold.py +++ b/nvflare/app_opt/tf/scaffold.py @@ -52,7 +52,7 @@ def init(self, model): self.c_local.set_weights(c_init_para) # Generate a list of the flattened layers - layer_weights_dict = {layer.name: layer.get_weights() for layer in self.c_global.layers} + layer_weights_dict = {layer.name: layer.get_weights() for layer in self.c_global.layers if layer.trainable} flattened_layer_weights_dict = flat_layer_weights_dict(layer_weights_dict) self.global_keys = [key for key, _ in flattened_layer_weights_dict.items()]