added support for TF 2.17
uvecw@student.kit.edu committed Jul 24, 2024
1 parent f7e1583 commit 16ca72c
Showing 1 changed file with 16 additions and 28 deletions.
44 changes: 16 additions & 28 deletions nvflare/app_opt/tf/scaffold.py
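Note on the motivation: TF 2.16+ bundles Keras 3, where Variable.name no longer carries the unique "layer/...:0" path, so dictionaries keyed by v.name can silently collapse entries. A minimal sketch of that failure mode (not part of the commit; TF >= 2.16 assumed), which the rewrite of init() below sidesteps by handling weights positionally:

import tensorflow as tf

m = tf.keras.Sequential([
    tf.keras.Input(shape=(4,)),
    tf.keras.layers.Dense(3),
    tf.keras.layers.Dense(2),
])
print([v.name for v in m.variables])       # e.g. ['kernel', 'bias', 'kernel', 'bias']
print(len({v.name for v in m.variables}))  # 2 distinct keys for 4 variables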
@@ -4,7 +4,7 @@
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
@@ -43,39 +43,34 @@ def __init__(self):
self.c_delta_para = None
self.global_keys = None

-# self.clip_norm = 1.0
-
def init(self, model):
self.c_global = tf.keras.models.clone_model(model)
self.c_local = tf.keras.models.clone_model(model)
# Initialize correction term with zeros
-c_init_para = {v.name: np.zeros_like(v.numpy()) for v in model.variables}
-self.c_global.set_weights([c_init_para[k] for k in c_init_para])
-self.c_local.set_weights([c_init_para[k] for k in c_init_para])
+c_init_para = [np.zeros(shape) for shape in [w.shape for w in model.get_weights()]]
+self.c_global.set_weights(c_init_para)
+self.c_local.set_weights(c_init_para)

# Generate a list of the flattened layers
layer_weights_dict = {layer.name: layer.get_weights() for layer in self.c_global.layers}
flattened_layer_weights_dict = flat_layer_weights_dict(layer_weights_dict)
self.global_keys = [key for key, _ in flattened_layer_weights_dict.items()]
print("Gloabl")
print(self.global_keys)
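The zero-initialization now works from get_weights() shapes instead of variable names. A minimal sketch of the pattern, with a purely illustrative stand-in for flat_layer_weights_dict (that helper lives elsewhere in nvflare.app_opt.tf; the key format below is an assumption):

import numpy as np
import tensorflow as tf

model = tf.keras.Sequential([tf.keras.Input(shape=(8,)), tf.keras.layers.Dense(4)])
c_global = tf.keras.models.clone_model(model)

# Positional zero weights -- no variable names involved.
c_global.set_weights([np.zeros(w.shape) for w in model.get_weights()])

# Illustrative flattening into stable string keys ("<layer>_<index>" is a guess).
flat = {
    f"{layer.name}_{i}": w
    for layer in c_global.layers
    for i, w in enumerate(layer.get_weights())
}
print(list(flat.keys()))  # e.g. ['dense_0', 'dense_1']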

def get_params(self):
self.cnt = 0
-c_global_para = self.c_global.variables
-c_local_para = self.c_local.variables
+c_global_para = self.c_global.trainable_variables
+c_local_para = self.c_local.trainable_variables
return c_global_para, c_local_para
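Switching get_params() to trainable_variables matters whenever the model carries non-trainable state: the SCAFFOLD arithmetic below zips these lists against model.trainable_variables, and mismatched lengths or ordering would pair the wrong tensors. A quick illustration of the difference (TF >= 2.16 assumed):

import tensorflow as tf

m = tf.keras.Sequential([
    tf.keras.Input(shape=(8,)),
    tf.keras.layers.Dense(4),
    tf.keras.layers.BatchNormalization(),
])
print(len(m.variables))            # 6 -- includes moving_mean / moving_variance
print(len(m.trainable_variables))  # 4 -- only kernel, bias, gamma, beta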

def model_update(self, model, curr_lr, c_global_para, c_local_para):
-net_para = model.variables
+net_para = model.trainable_variables  # access only trainable variables
trainable_var_names = [var.name for var in model.trainable_variables]
model_difference = tf.nest.map_structure(
-lambda a, b: tf.multiply(curr_lr, a - b),
+lambda a, b: tf.multiply(tf.cast(curr_lr, a.dtype), a - b),
c_global_para,
c_local_para,
)
new_weights = tf.nest.map_structure(lambda a, b: a - b, net_para, model_difference)
# print('the length of the weights are:',(new_weights))
for var, new_weight in zip(net_para, new_weights):
if var.name in trainable_var_names:
var.assign(new_weight)
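model_update() applies the SCAFFOLD drift correction x <- x - lr * (c_global - c_local) on top of the optimizer step; the added tf.cast keeps the scalar in each variable's dtype, guarding against float32/float64 or mixed-precision mismatches. A self-contained sketch of the same correction (the function name is mine, not the module's):

import tensorflow as tf

def apply_scaffold_correction(variables, c_global, c_local, lr):
    # x <- x - lr * (c_global - c_local), element-wise over matching lists.
    correction = tf.nest.map_structure(
        lambda cg, cl: tf.multiply(tf.cast(lr, cg.dtype), cg - cl), c_global, c_local
    )
    for var, delta in zip(variables, correction):
        var.assign(var - delta)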
@@ -90,16 +85,16 @@ def terms_update(
c_local_para,
model_global,
):
-c_new_para = self.c_local.variables
-self.c_delta_para = copy.deepcopy(self.c_local.variables)
-global_model_para = model_global.variables
-net_para = model.variables
+c_new_para = self.c_local.trainable_variables
+self.c_delta_para = copy.deepcopy(self.c_local.trainable_variables)
+global_model_para = model_global.trainable_variables
+net_para = model.trainable_variables
scaler = 1 / (self.cnt * curr_lr)

c_new_para_c_global = tf.nest.map_structure(lambda a, b: a - b, c_new_para, c_global_para)

global_model_para_net_para = tf.nest.map_structure(
-lambda a, b: tf.multiply(scaler, a - b),
+lambda a, b: tf.multiply(tf.cast(scaler, a.dtype), a - b),
global_model_para,
net_para,
)
@@ -112,9 +107,10 @@

c_delta_para_value = tf.nest.map_structure(lambda a, b: a - b, c_new_para, c_local_para)

-self.c_delta_para = {var.name: c_delta_para_value[i].numpy() for i, var in enumerate(net_para)}
+self.c_delta_para = {self.global_keys[i]: delta_val.numpy() for i, delta_val in enumerate(c_delta_para_value)}

-self.c_local.set_weights(c_new_para)
+for var, new_weight in zip(self.c_local.trainable_variables, c_new_para):
+    var.assign(new_weight)
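terms_update() implements the option-II control update from the SCAFFOLD paper (Karimireddy et al.): after cnt local steps at rate curr_lr, c_new = c_local - c_global + (x_global - x_local) / (cnt * curr_lr), and the reported delta is c_new - c_local, now keyed by global_keys rather than Variable.name. A condensed sketch of the arithmetic (helper name and signature are mine):

import tensorflow as tf

def scaffold_control_update(c_local, c_global, x_global, x_local, cnt, lr):
    # c_new = c_local - c_global + (x_global - x_local) / (cnt * lr)
    scale = 1.0 / (cnt * lr)
    c_new = tf.nest.map_structure(
        lambda cl, cg, xg, xl: cl - cg + tf.cast(scale, cl.dtype) * (xg - xl),
        c_local, c_global, x_global, x_local,
    )
    c_delta = tf.nest.map_structure(lambda new, old: new - old, c_new, c_local)
    return c_new, c_delta

The per-variable assign() loop replaces set_weights() at the end, likely because set_weights() expects arrays for every weight in the model while c_new_para now covers only the trainable ones.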

def load_global_controls(self, weights):
weights_values = [v for _, v in weights.items()]
@@ -124,14 +120,6 @@ def get_delta_controls(self):
if self.c_delta_para is None:
raise ValueError("c_delta_para hasn't been computed yet!")

-print(type(self.c_delta_para))
-
-# print(self.c_delta_para)
-
-c_delta_para_new = {self.global_keys[i]: value for i, (key, value) in enumerate(self.c_delta_para.items())}
-
-self.c_delta_para = c_delta_para_new

return self.c_delta_para
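With the key renaming moved into terms_update(), this accessor reduces to the guard plus return. A hypothetical single-round driver, to show how the pieces fit (the class name TFScaffoldHelper and the surrounding loop are assumptions; the local training step itself is elided):

helper = TFScaffoldHelper()
helper.init(model)
helper.load_global_controls(server_controls)        # {key: ndarray} from the server
c_global_para, c_local_para = helper.get_params()
for _ in range(num_local_steps):
    # ... one optimizer step on local data ...
    helper.model_update(model, lr, c_global_para, c_local_para)
helper.terms_update(model, lr, c_global_para, c_local_para, global_model)
c_delta = helper.get_delta_controls()               # {global_key: ndarray} back to the server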


