From 4e0bd9cc30f915765779a497211cb8d8acf7e446 Mon Sep 17 00:00:00 2001 From: LeoDuda Date: Mon, 29 Jul 2024 15:35:33 +0200 Subject: [PATCH] Fix for models with non-trainable variables --- nvflare/app_opt/tf/scaffold.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/nvflare/app_opt/tf/scaffold.py b/nvflare/app_opt/tf/scaffold.py index 4b1ceabf51..7fbfe1ff5f 100644 --- a/nvflare/app_opt/tf/scaffold.py +++ b/nvflare/app_opt/tf/scaffold.py @@ -17,7 +17,7 @@ import numpy as np import tensorflow as tf -from .utils import flat_layer_weights_dict +from nvflare.app_opt.tf.utils import flat_layer_weights_dict gpu_devices = tf.config.experimental.list_physical_devices("GPU") for device in gpu_devices: @@ -52,7 +52,7 @@ def init(self, model): self.c_local.set_weights(c_init_para) # Generate a list of the flattened layers - layer_weights_dict = {layer.name: layer.get_weights() for layer in self.c_global.layers if layer.trainable} + layer_weights_dict = {layer.name: layer.get_weights() for layer in self.c_global.layers} flattened_layer_weights_dict = flat_layer_weights_dict(layer_weights_dict) self.global_keys = [key for key, _ in flattened_layer_weights_dict.items()] @@ -86,7 +86,7 @@ def terms_update( model_global, ): c_new_para = self.c_local.trainable_variables - self.c_delta_para = copy.deepcopy(self.c_local.trainable_variables) + self.c_delta_para = dict() global_model_para = model_global.trainable_variables net_para = model.trainable_variables scaler = 1 / (self.cnt * curr_lr) @@ -106,8 +106,14 @@ def terms_update( ) c_delta_para_value = tf.nest.map_structure(lambda a, b: a - b, c_new_para, c_local_para) - - self.c_delta_para = {self.global_keys[i]: delta_val.numpy() for i, delta_val in enumerate(c_delta_para_value)} + trainable_variables = [var.name for var in self.c_local.trainable_variables] + j = 0 + for i, var in enumerate(self.c_local.variables): + if var.name in trainable_variables: + self.c_delta_para[self.global_keys[i]] = c_delta_para_value[j].numpy() + j = j+1 + else: + self.c_delta_para[self.global_keys[i]] = model.variables[i].numpy() for var, new_weight in zip(self.c_local.trainable_variables, c_new_para): var.assign(new_weight) @@ -132,4 +138,4 @@ def __init__(self, scaffold_helper): def on_epoch_end(self, epoch, logs=None): curr_lr = self.model.optimizer.learning_rate self.scaffold_helper.model_update(self.model, curr_lr, self.c_global_para, self.c_local_para) - print(f"SCAFFOLD model updated at end of epoch {epoch + 1}") + print(f"SCAFFOLD model updated at end of epoch {epoch + 1}") \ No newline at end of file