Skip to content

Commit

Permalink
update the style
Browse files Browse the repository at this point in the history
  • Loading branch information
falibabaei committed Aug 4, 2024
1 parent fc8ea58 commit e31c032
Show file tree
Hide file tree
Showing 3 changed files with 2 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,9 @@
# (1) import nvflare client API
import nvflare.client as flare
from nvflare.app_common.app_constant import AlgorithmConstants
from nvflare.app_opt.tf.fedprox_loss import TFFedProxLoss
from nvflare.app_opt.tf.scaffold import ScaffoldCallback, TFScaffoldHelper, get_lr_values
from nvflare.client.tracking import SummaryWriter
from nvflare.app_opt.tf.fedprox_loss import TFFedProxLoss

PATH = "./tf_model.weights.h5"

Expand Down Expand Up @@ -199,14 +199,13 @@ def main():
model.get_layer(k).set_weights(v)

if args.fedprox_mu > 0:

local_model_weights = model.trainable_variables
global_model_weights = copy.deepcopy(model.trainable_variables)
model.loss = TFFedProxLoss(local_model_weights, global_model_weights, args.fedprox_mu, loss)
elif args.fedprox_mu < 0.0:

raise ValueError("mu should be no less than 0.0")


# (step 4) load regularization parameters from scaffold
global_ctrl_weights = input_model.meta.get(AlgorithmConstants.SCAFFOLD_CTRL_GLOBAL)
Expand Down
2 changes: 0 additions & 2 deletions examples/getting_started/tf/tf_fl_script_executor_cifar10.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,6 @@
FEDPROX_ALGO = "fedprox"




if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
Expand Down
4 changes: 0 additions & 4 deletions nvflare/app_opt/tf/fedopt_ctl.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,10 +72,6 @@ def run(self):
try:
if "args" not in self.optimizer_args:
self.optimizer_args["args"] = {}
<<<<<<< HEAD
# self.optimizer_args["args"]["params"] = self.keras_model.parameters()
=======
>>>>>>> upstream/main
self.optimizer = self.build_component(self.optimizer_args)
except Exception as e:
error_msg = f"Exception while constructing optimizer: {secure_format_exception(e)}"
Expand Down

0 comments on commit e31c032

Please sign in to comment.