Commit
prepared scaffold for TF2.17
LeoDuda committed Jul 31, 2024
1 parent 1528ce4 commit e69cfc1
Showing 1 changed file with 32 additions and 15 deletions.
47 changes: 32 additions & 15 deletions nvflare/app_opt/tf/scaffold.py
@@ -25,6 +25,34 @@

tf.debugging.enable_check_numerics()

def optimize_weights(model, c_delta_para_value):
    """
    Builds the control-variate delta dict, keyed by flattened layer name.

    Trainable variables receive the corresponding entry from `c_delta_para_value`;
    non-trainable variables keep their current weights.

    Args:
        model: The TensorFlow model containing the layers.
        c_delta_para_value: Delta values for the trainable variables.

    Returns:
        dict: Flattened layer names mapped to delta values (trainable) or current weights (non-trainable).
    """
    c_delta_para = {}
    layer_weights_dict = {layer.name: layer.get_weights() for layer in model.layers}
    trainable_layers_dict = {layer.name: layer.get_weights() for layer in model.layers if layer.trainable}
    flatten_layer_weights_dict = flat_layer_weights_dict(layer_weights_dict)
    flatten_trainable_layers_dict = flat_layer_weights_dict(trainable_layers_dict)

    trainable_layer_names = list(flatten_trainable_layers_dict.keys())
    j = 0
    for layer_name, weight in flatten_layer_weights_dict.items():
        if layer_name in trainable_layer_names:
            c_delta_para[layer_name] = c_delta_para_value[j].numpy()
            j += 1
        else:
            c_delta_para[layer_name] = weight

    return c_delta_para


def get_lr_values(optimizer):
"""
@@ -50,12 +78,7 @@ def init(self, model):
        c_init_para = [np.zeros(shape) for shape in [w.shape for w in model.get_weights()]]
        self.c_global.set_weights(c_init_para)
        self.c_local.set_weights(c_init_para)

        # Generate a list of the flattened layers
        layer_weights_dict = {layer.name: layer.get_weights() for layer in self.c_global.layers}
        flattened_layer_weights_dict = flat_layer_weights_dict(layer_weights_dict)
        self.global_keys = [key for key, _ in flattened_layer_weights_dict.items()]


    def get_params(self):
        self.cnt = 0
        c_global_para = self.c_global.trainable_variables
@@ -86,7 +109,6 @@ def terms_update(
        model_global,
    ):
        c_new_para = self.c_local.trainable_variables
        self.c_delta_para = dict()
        global_model_para = model_global.trainable_variables
        net_para = model.trainable_variables
        scaler = 1 / (self.cnt * curr_lr)
@@ -106,14 +128,9 @@
        )

        c_delta_para_value = tf.nest.map_structure(lambda a, b: a - b, c_new_para, c_local_para)
        trainable_variables = [var.name for var in self.c_local.trainable_variables]
        j = 0
        for i, var in enumerate(self.c_local.variables):
            if var.name in trainable_variables:
                self.c_delta_para[self.global_keys[i]] = c_delta_para_value[j].numpy()
                j = j + 1
            else:
                self.c_delta_para[self.global_keys[i]] = model.variables[i].numpy()
        self.c_delta_para = optimize_weights(model, c_delta_para_value)

        for var, new_weight in zip(self.c_local.trainable_variables, c_new_para):
            var.assign(new_weight)
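For context on the new optimize_weights helper: it walks the flattened weight dict produced by flat_layer_weights_dict and fills in delta values only for trainable entries. The snippet below is a minimal, self-contained sketch of how that flattened mapping is assumed to look; the _flatten stand-in and the tiny Dense model are illustrative assumptions, not code from this commit or the real nvflare helper.

# Sketch only: assumed layout of the flattened weight dict that
# optimize_weights iterates over. "_flatten" is a stand-in, not
# nvflare's flat_layer_weights_dict.
import tensorflow as tf


def _flatten(layer_weights_dict):
    # One key per weight array of each layer, e.g. "dense/0" (kernel), "dense/1" (bias).
    return {
        f"{name}/{i}": w
        for name, weights in layer_weights_dict.items()
        for i, w in enumerate(weights)
    }


model = tf.keras.Sequential([tf.keras.Input(shape=(3,)), tf.keras.layers.Dense(2, name="dense")])
layer_weights = {layer.name: layer.get_weights() for layer in model.layers}
print(list(_flatten(layer_weights).keys()))  # e.g. ['dense/0', 'dense/1']

# optimize_weights walks such keys in order: trainable entries take the next
# value from c_delta_para_value, non-trainable entries keep their current weights.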

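The c_delta_para_value passed into optimize_weights comes from the tf.nest.map_structure call in terms_update, which subtracts the previous local control variate from the updated one element-wise across the list of trainable variables. A small illustration of that call, using made-up values rather than code from this commit:

# tf.nest.map_structure applies the lambda pairwise across parallel lists,
# producing the per-tensor control-variate deltas.
import tensorflow as tf

c_new_para = [tf.constant([1.0, 2.0]), tf.constant([[3.0]])]
c_local_para = [tf.constant([0.5, 0.5]), tf.constant([[1.0]])]

c_delta_para_value = tf.nest.map_structure(lambda a, b: a - b, c_new_para, c_local_para)
print([t.numpy() for t in c_delta_para_value])  # [array([0.5, 1.5]), array([[2.]])]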