diff --git a/examples/getting_started/tf/src/cifar10_tf_fl_alpha_split_scaffold.py b/examples/getting_started/tf/src/cifar10_tf_fl_alpha_split_scaffold.py index 24fba56990..f1eb461373 100644 --- a/examples/getting_started/tf/src/cifar10_tf_fl_alpha_split_scaffold.py +++ b/examples/getting_started/tf/src/cifar10_tf_fl_alpha_split_scaffold.py @@ -24,9 +24,9 @@ # (1) import nvflare client API import nvflare.client as flare from nvflare.app_common.app_constant import AlgorithmConstants +from nvflare.app_opt.tf.fedprox_loss import TFFedProxLoss from nvflare.app_opt.tf.scaffold import ScaffoldCallback, TFScaffoldHelper, get_lr_values from nvflare.client.tracking import SummaryWriter -from nvflare.app_opt.tf.fedprox_loss import TFFedProxLoss PATH = "./tf_model.weights.h5" @@ -199,14 +199,13 @@ def main(): model.get_layer(k).set_weights(v) if args.fedprox_mu > 0: - + local_model_weights = model.trainable_variables global_model_weights = copy.deepcopy(model.trainable_variables) model.loss = TFFedProxLoss(local_model_weights, global_model_weights, args.fedprox_mu, loss) elif args.fedprox_mu < 0.0: raise ValueError("mu should be no less than 0.0") - # (step 4) load regularization parameters from scaffold global_ctrl_weights = input_model.meta.get(AlgorithmConstants.SCAFFOLD_CTRL_GLOBAL) diff --git a/examples/getting_started/tf/tf_fl_script_executor_cifar10.py b/examples/getting_started/tf/tf_fl_script_executor_cifar10.py index 6037bf0d69..4f2d294f03 100644 --- a/examples/getting_started/tf/tf_fl_script_executor_cifar10.py +++ b/examples/getting_started/tf/tf_fl_script_executor_cifar10.py @@ -34,8 +34,6 @@ FEDPROX_ALGO = "fedprox" - - if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument( diff --git a/nvflare/app_opt/tf/fedopt_ctl.py b/nvflare/app_opt/tf/fedopt_ctl.py index 42814cf55a..0ec0d2420b 100644 --- a/nvflare/app_opt/tf/fedopt_ctl.py +++ b/nvflare/app_opt/tf/fedopt_ctl.py @@ -72,10 +72,6 @@ def run(self): try: if "args" not in self.optimizer_args: self.optimizer_args["args"] = {} -<<<<<<< HEAD - # self.optimizer_args["args"]["params"] = self.keras_model.parameters() -======= ->>>>>>> upstream/main self.optimizer = self.build_component(self.optimizer_args) except Exception as e: error_msg = f"Exception while constructing optimizer: {secure_format_exception(e)}"