Skip to content

Commit

Permalink
revert the style changes
Browse files Browse the repository at this point in the history
  • Loading branch information
falibabaei committed Aug 6, 2024
1 parent 3bf6a27 commit a2d41c3
Show file tree
Hide file tree
Showing 8 changed files with 10 additions and 11 deletions.
7 changes: 3 additions & 4 deletions examples/getting_started/tf/src/cifar10_tf_fl_alpha_split.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,14 @@
import argparse
import copy


import numpy as np
import tensorflow as tf
from tensorflow.keras import datasets, losses
from nvflare.app_opt.tf.fedprox_loss import TFFedProxLoss
from tf_net import ModerateTFNet

# (1) import nvflare client API
import nvflare.client as flare
from nvflare.app_opt.tf.fedprox_loss import TFFedProxLoss

PATH = "./tf_model.weights.h5"

Expand Down Expand Up @@ -168,13 +167,13 @@ def main():
model.get_layer(k).set_weights(v)

if args.fedprox_mu > 0:

local_model_weights = model.trainable_variables
global_model_weights = copy.deepcopy(model.trainable_variables)
model.loss = TFFedProxLoss(local_model_weights, global_model_weights, args.fedprox_mu, loss)
elif args.fedprox_mu < 0.0:

raise ValueError("mu should be no less than 0.0")
raise ValueError("mu should be no less than 0.0")

# (5) evaluate aggregated/received model
_, test_global_acc = model.evaluate(x=test_ds, verbose=2)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion nvflare/app_opt/tf/scaffold.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
1 change: 0 additions & 1 deletion nvflight/build_wheel.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@

# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
Expand Down
2 changes: 1 addition & 1 deletion nvflight/prepare_setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,4 +183,4 @@ def prepare_setup(setup_dir: str):

for src in src_files:
shutil.copy(src, os.path.join(setup_dir, os.path.basename(src)))


2 changes: 1 addition & 1 deletion nvflight/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,4 +51,4 @@
packages=release_package,
package_data=package_data,
include_package_data=True,
)
)
3 changes: 2 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,4 +99,5 @@ def remove_dir(target_path):
include_package_data=True,
)

remove_dir(target_path=tmp_job_template_folder)
remove_dir(target_path=tmp_job_template_folder)

2 changes: 1 addition & 1 deletion versioneer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2106,4 +2106,4 @@ def scan_setup_py():
errors = do_setup()
errors += scan_setup_py()
if errors:
sys.exit(1)
sys.exit(1)

0 comments on commit a2d41c3

Please sign in to comment.