Fix for models with non-trainable variables
LeoDuda committed Jul 29, 2024
1 parent 32bdbe2 commit 4e0bd9c
Showing 1 changed file with 12 additions and 6 deletions.
nvflare/app_opt/tf/scaffold.py (12 additions, 6 deletions)
@@ -17,7 +17,7 @@
 import numpy as np
 import tensorflow as tf
 
-from .utils import flat_layer_weights_dict
+from nvflare.app_opt.tf.utils import flat_layer_weights_dict
 
 gpu_devices = tf.config.experimental.list_physical_devices("GPU")
 for device in gpu_devices:
@@ -52,7 +52,7 @@ def init(self, model):
         self.c_local.set_weights(c_init_para)
 
         # Generate a list of the flattened layers
-        layer_weights_dict = {layer.name: layer.get_weights() for layer in self.c_global.layers if layer.trainable}
+        layer_weights_dict = {layer.name: layer.get_weights() for layer in self.c_global.layers}
         flattened_layer_weights_dict = flat_layer_weights_dict(layer_weights_dict)
         self.global_keys = [key for key, _ in flattened_layer_weights_dict.items()]
 
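Note on the init() change above: with the old "if layer.trainable" filter, global_keys covered only trainable layers, while the model weights exchanged between clients and server cover every variable, so frozen layers ended up with no key and later no entry in the delta dict. A minimal sketch of that gap, using a hypothetical model with one frozen layer; the layer names are illustrative and not part of the commit:

    # Sketch only: a frozen layer drops out of a layer.trainable filter,
    # so keys built that way cannot cover the model's full variable list.
    import tensorflow as tf

    model = tf.keras.Sequential(
        [
            tf.keras.Input(shape=(8,)),
            tf.keras.layers.Dense(4, name="frozen_dense"),
            tf.keras.layers.Dense(2, name="head"),
        ]
    )
    model.layers[0].trainable = False  # freeze the first Dense layer

    kept = [layer.name for layer in model.layers if layer.trainable]
    print(kept)                  # ['head'] -> keys for 2 weights only
    print(len(model.variables))  # 4 -> kernel and bias of both layers
    # Keys for 2 of 4 variables leave frozen_dense with no entry in the
    # delta dict; enumerating all layers, as the fix does, covers them.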
@@ -86,7 +86,7 @@ def terms_update(
         model_global,
     ):
         c_new_para = self.c_local.trainable_variables
-        self.c_delta_para = copy.deepcopy(self.c_local.trainable_variables)
+        self.c_delta_para = dict()
         global_model_para = model_global.trainable_variables
         net_para = model.trainable_variables
         scaler = 1 / (self.cnt * curr_lr)
@@ -106,8 +106,14 @@
         )
 
         c_delta_para_value = tf.nest.map_structure(lambda a, b: a - b, c_new_para, c_local_para)
-
-        self.c_delta_para = {self.global_keys[i]: delta_val.numpy() for i, delta_val in enumerate(c_delta_para_value)}
+        trainable_variables = [var.name for var in self.c_local.trainable_variables]
+        j = 0
+        for i, var in enumerate(self.c_local.variables):
+            if var.name in trainable_variables:
+                self.c_delta_para[self.global_keys[i]] = c_delta_para_value[j].numpy()
+                j = j+1
+            else:
+                self.c_delta_para[self.global_keys[i]] = model.variables[i].numpy()
 
         for var, new_weight in zip(self.c_local.trainable_variables, c_new_para):
             var.assign(new_weight)
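The terms_update() change above initializes c_delta_para as an empty dict and fills it by walking every variable: c_delta_para_value only holds deltas for trainable variables, so the cursor j consumes them in order, and non-trainable slots fall back to the model's current value. A standalone restatement of that alignment logic, with hypothetical function and argument names that are not part of the commit:

    # Sketch of the j-cursor alignment used in the hunk above.
    def align_deltas(all_variables, trainable_deltas, trainable_names,
                     global_keys, model_variables):
        delta_para = {}
        j = 0  # cursor into trainable_deltas (trainable variables only)
        for i, var in enumerate(all_variables):
            if var.name in trainable_names:
                # Trainable slot: consume the next control-variate delta.
                delta_para[global_keys[i]] = trainable_deltas[j].numpy()
                j = j + 1
            else:
                # Frozen slot: no delta exists, so the commit stores the
                # model's current value for that variable instead.
                delta_para[global_keys[i]] = model_variables[i].numpy()
        return delta_para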
@@ -132,4 +138,4 @@ def __init__(self, scaffold_helper):
     def on_epoch_end(self, epoch, logs=None):
         curr_lr = self.model.optimizer.learning_rate
         self.scaffold_helper.model_update(self.model, curr_lr, self.c_global_para, self.c_local_para)
-        print(f"SCAFFOLD model updated at end of epoch {epoch + 1}")
\ No newline at end of file
+        print(f"SCAFFOLD model updated at end of epoch {epoch + 1}")
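Beyond fully frozen layers, the per-variable check also handles layers that are trainable yet own non-trainable variables, which no per-layer filter can separate. A quick hypothetical check with BatchNormalization, not taken from the commit:

    import tensorflow as tf

    model = tf.keras.Sequential(
        [
            tf.keras.Input(shape=(8,)),
            tf.keras.layers.BatchNormalization(name="bn"),
        ]
    )

    print(len(model.variables))            # 4: gamma, beta, moving_mean, moving_variance
    print(len(model.trainable_variables))  # 2: gamma, beta
    # The pre-fix dict comprehension produced entries for gamma and beta only,
    # leaving moving_mean and moving_variance with no entry in c_delta_para;
    # the j-cursor loop in the commit fills every key.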
