Skip to content

Commit

Permalink
Merge branch 'main' into tf_algos
Browse files Browse the repository at this point in the history
  • Loading branch information
chesterxgchen authored Jul 25, 2024
2 parents 1cf7ec3 + ce02675 commit ed5249f
Show file tree
Hide file tree
Showing 14 changed files with 207 additions and 100 deletions.
4 changes: 2 additions & 2 deletions nvflare/app_opt/xgboost/histogram_based/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,8 @@ def __init__(self, xgb_params: dict, num_rounds=10, early_stopping_rounds=2, ver
"""Container for all XGBoost parameters.
Args:
xgb_params: This dict is passed to `xgboost.train()` as the first argument `params`.
It contains all the Booster parameters.
xgb_params: The Booster parameters. This dict is passed to `xgboost.train()`
as the argument `params`. It contains all the Booster parameters.
Please refer to XGBoost documentation for details:
https://xgboost.readthedocs.io/en/stable/python/python_api.html#module-xgboost.training
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,9 @@ class since the self object contains a sender that contains a Core Cell which ca
Constant.RUNNER_CTX_SERVER_ADDR: server_addr,
Constant.RUNNER_CTX_RANK: self.rank,
Constant.RUNNER_CTX_NUM_ROUNDS: self.num_rounds,
Constant.RUNNER_CTX_TRAINING_MODE: self.training_mode,
Constant.RUNNER_CTX_XGB_PARAMS: self.xgb_params,
Constant.RUNNER_CTX_XGB_OPTIONS: self.xgb_options,
Constant.RUNNER_CTX_MODEL_DIR: self._run_dir,
Constant.RUNNER_CTX_TB_DIR: self._app_dir,
}
Expand All @@ -96,9 +99,6 @@ def start(self, fl_ctx: FLContext):
if self.rank is None:
raise RuntimeError("cannot start - my rank is not set")

if not self.num_rounds:
raise RuntimeError("cannot start - num_rounds is not set")

# dynamically determine address on localhost
port = get_open_tcp_port(resources={})
if not port:
Expand Down
26 changes: 19 additions & 7 deletions nvflare/app_opt/xgboost/histogram_based_v2/adaptors/xgb_adaptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,9 @@ def __init__(self, in_process, per_msg_timeout: float, tx_timeout: float):
self.stopped = False
self.rank = None
self.num_rounds = None
self.training_mode = None
self.xgb_params = None
self.xgb_options = None
self.world_size = None
self.per_msg_timeout = per_msg_timeout
self.tx_timeout = tx_timeout
Expand All @@ -163,27 +166,36 @@ def configure(self, config: dict, fl_ctx: FLContext):
Returns: None
"""
ws = config.get(Constant.CONF_KEY_WORLD_SIZE)
ranks = config.get(Constant.CONF_KEY_CLIENT_RANKS)
ws = len(ranks)
if not ws:
raise RuntimeError("world_size is not configured")

check_positive_int(Constant.CONF_KEY_WORLD_SIZE, ws)
self.world_size = ws

rank = config.get(Constant.CONF_KEY_RANK)
me = fl_ctx.get_identity_name()
rank = ranks.get(me)
if rank is None:
raise RuntimeError("rank is not configured")

check_non_negative_int(Constant.CONF_KEY_RANK, rank)
self.rank = rank

num_rounds = config.get(Constant.CONF_KEY_NUM_ROUNDS)
if num_rounds is None:
raise RuntimeError("num_rounds is not configured")
if num_rounds is None or num_rounds <= 0:
raise RuntimeError("num_rounds is not configured or invalid value")

check_positive_int(Constant.CONF_KEY_NUM_ROUNDS, num_rounds)
self.num_rounds = num_rounds

self.training_mode = config.get(Constant.CONF_KEY_TRAINING_MODE)
if self.training_mode is None:
raise RuntimeError("training_mode is not configured")

self.xgb_params = config.get(Constant.CONF_KEY_XGB_PARAMS)
if not self.xgb_params:
raise RuntimeError("xgb_params is not configured")

self.xgb_options = config.get(Constant.CONF_KEY_XGB_OPTIONS, {})

def _send_request(self, op: str, req: Shareable) -> (bytes, Shareable):
"""Send XGB operation request to the FL server via FLARE message.
Expand Down
28 changes: 27 additions & 1 deletion nvflare/app_opt/xgboost/histogram_based_v2/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from nvflare.fuel.utils.validation_utils import check_number_range, check_object_type, check_positive_number, check_str
from nvflare.security.logging import secure_format_exception

from .defs import Constant
from .defs import TRAINING_MODE_MAPPING, Constant


class ClientStatus:
Expand Down Expand Up @@ -57,6 +57,9 @@ def __init__(
self,
adaptor_component_id: str,
num_rounds: int,
training_mode: str,
xgb_params: dict,
xgb_options: dict,
configure_task_name=Constant.CONFIG_TASK_NAME,
configure_task_timeout=Constant.CONFIG_TASK_TIMEOUT,
start_task_name=Constant.START_TASK_NAME,
Expand All @@ -69,9 +72,15 @@ def __init__(
"""
Constructor
For the meaning of XGBoost parameters, please refer to the documentation for train API,
https://xgboost.readthedocs.io/en/stable/python/python_api.html#xgboost.train
Args:
adaptor_component_id - the component ID of server target adaptor
num_rounds - number of rounds
training_mode - Split mode; one of h/horizontal, v/vertical, hs/horizontal_secure, vs/vertical_secure
xgb_params - The params argument for train method
xgb_options - All other arguments for train method are passed through this dictionary
configure_task_name - name of the config task
configure_task_timeout - time to wait for clients’ responses to the config task before timeout.
start_task_name - name of the start task
Expand All @@ -89,6 +98,9 @@ def __init__(
Controller.__init__(self)
self.adaptor_component_id = adaptor_component_id
self.num_rounds = num_rounds
self.training_mode = training_mode.lower()
self.xgb_params = xgb_params
self.xgb_options = xgb_options
self.configure_task_name = configure_task_name
self.start_task_name = start_task_name
self.start_task_timeout = start_task_timeout
Expand All @@ -104,6 +116,17 @@ def __init__(
self.client_statuses = {} # client name => ClientStatus
self.abort_signal = None

check_str("training_mode", training_mode)
valid_mode = TRAINING_MODE_MAPPING.keys()
if training_mode not in valid_mode:
raise ValueError(f"training_mode must be one of following values: {valid_mode}")

if not self.xgb_params:
raise ValueError("xgb_params can't be empty")

if not self.xgb_options:
self.xgb_options = {}

check_str("adaptor_component_id", adaptor_component_id)
check_number_range("configure_task_timeout", configure_task_timeout, min_value=1)
check_number_range("start_task_timeout", start_task_timeout, min_value=1)
Expand Down Expand Up @@ -427,6 +450,9 @@ def _configure_clients(self, abort_signal: Signal, fl_ctx: FLContext):

shareable[Constant.CONF_KEY_CLIENT_RANKS] = self.client_ranks
shareable[Constant.CONF_KEY_NUM_ROUNDS] = self.num_rounds
shareable[Constant.CONF_KEY_TRAINING_MODE] = self.training_mode
shareable[Constant.CONF_KEY_XGB_PARAMS] = self.xgb_params
shareable[Constant.CONF_KEY_XGB_OPTIONS] = self.xgb_options

task = Task(
name=self.configure_task_name,
Expand Down
30 changes: 29 additions & 1 deletion nvflare/app_opt/xgboost/histogram_based_v2/defs.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,12 @@ class Constant:
CONF_KEY_RANK = "rank"
CONF_KEY_WORLD_SIZE = "world_size"
CONF_KEY_NUM_ROUNDS = "num_rounds"
CONF_KEY_TRAINING_MODE = "training_mode"
CONF_KEY_XGB_PARAMS = "xgb_params"
CONF_KEY_XGB_OPTIONS = "xgb_options"

# default component config values
CONFIG_TASK_TIMEOUT = 10
CONFIG_TASK_TIMEOUT = 20
START_TASK_TIMEOUT = 10
XGB_SERVER_READY_TIMEOUT = 10.0

Expand Down Expand Up @@ -88,6 +91,9 @@ class Constant:
RUNNER_CTX_PORT = "port"
RUNNER_CTX_CLIENT_NAME = "client_name"
RUNNER_CTX_NUM_ROUNDS = "num_rounds"
RUNNER_CTX_TRAINING_MODE = "training_mode"
RUNNER_CTX_XGB_PARAMS = "xgb_params"
RUNNER_CTX_XGB_OPTIONS = "xgb_options"
RUNNER_CTX_WORLD_SIZE = "world_size"
RUNNER_CTX_RANK = "rank"
RUNNER_CTX_DATA_LOADER = "data_loader"
Expand All @@ -111,3 +117,25 @@ class Constant:
("grpc.max_send_message_length", MAX_FRAME_SIZE),
("grpc.max_receive_message_length", MAX_FRAME_SIZE),
]


class SplitMode:
    """Integer codes identifying how training data is split across clients.

    NOTE(review): these numeric values appear to be consumed downstream as
    fixed codes (e.g. by the XGBoost data-split configuration) — confirm
    before renumbering.
    """

    ROW = 0  # horizontal split: samples partitioned across clients
    COL = 1  # vertical split: features partitioned across clients
    COL_SECURE = 2  # vertical split with secure (encrypted) processing
    ROW_SECURE = 3  # horizontal split with secure (encrypted) processing


# Mapping of user-facing training-mode strings to split-mode codes.
# Both the short ("h", "v", "hs", "vs") and long spellings are accepted.
TRAINING_MODE_MAPPING = {
    "h": SplitMode.ROW,
    "horizontal": SplitMode.ROW,
    "v": SplitMode.COL,
    "vertical": SplitMode.COL,
    "hs": SplitMode.ROW_SECURE,
    "horizontal_secure": SplitMode.ROW_SECURE,
    "vs": SplitMode.COL_SECURE,
    "vertical_secure": SplitMode.COL_SECURE,
}

# Training modes that require secure processing. Derived from
# TRAINING_MODE_MAPPING so the two structures cannot drift out of sync.
SECURE_TRAINING_MODES = {
    mode
    for mode, split in TRAINING_MODE_MAPPING.items()
    if split in (SplitMode.ROW_SECURE, SplitMode.COL_SECURE)
}
8 changes: 1 addition & 7 deletions nvflare/app_opt/xgboost/histogram_based_v2/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,15 +115,9 @@ def execute(self, task_name: str, shareable: Shareable, fl_ctx: FLContext, abort
self.log_error(fl_ctx, f"missing {Constant.CONF_KEY_NUM_ROUNDS} from config")
return make_reply(ReturnCode.BAD_TASK_DATA)

world_size = len(ranks)

# configure the XGB client target via the adaptor
self.adaptor.configure(
{
Constant.CONF_KEY_RANK: my_rank,
Constant.CONF_KEY_NUM_ROUNDS: num_rounds,
Constant.CONF_KEY_WORLD_SIZE: world_size,
},
shareable,
fl_ctx,
)
return make_reply(ReturnCode.OK)
Expand Down
14 changes: 14 additions & 0 deletions nvflare/app_opt/xgboost/histogram_based_v2/fed_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,19 +11,25 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import uuid
from typing import Optional

from nvflare.apis.fl_context import FLContext
from nvflare.app_opt.xgboost.histogram_based_v2.adaptors.grpc_server_adaptor import GrpcServerAdaptor
from nvflare.app_opt.xgboost.histogram_based_v2.runners.xgb_server_runner import XGBServerRunner

from .controller import XGBController
from .defs import Constant
from .sec.server_handler import ServerSecurityHandler


class XGBFedController(XGBController):
def __init__(
self,
num_rounds: int,
training_mode: str,
xgb_params: dict,
xgb_options: Optional[dict] = None,
configure_task_name=Constant.CONFIG_TASK_NAME,
configure_task_timeout=Constant.CONFIG_TASK_TIMEOUT,
start_task_name=Constant.START_TASK_NAME,
Expand All @@ -39,6 +45,9 @@ def __init__(
self,
adaptor_component_id="",
num_rounds=num_rounds,
training_mode=training_mode,
xgb_params=xgb_params,
xgb_options=xgb_options,
configure_task_name=configure_task_name,
configure_task_timeout=configure_task_timeout,
start_task_name=start_task_name,
Expand All @@ -52,6 +61,11 @@ def __init__(
self.in_process = in_process

def get_adaptor(self, fl_ctx: FLContext):

engine = fl_ctx.get_engine()
handler = ServerSecurityHandler()
engine.add_component(str(uuid.uuid4()), handler)

runner = XGBServerRunner()
runner.initialize(fl_ctx)
adaptor = GrpcServerAdaptor(
Expand Down
16 changes: 8 additions & 8 deletions nvflare/app_opt/xgboost/histogram_based_v2/fed_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,18 +11,19 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import uuid

from nvflare.apis.fl_context import FLContext
from nvflare.app_opt.xgboost.histogram_based_v2.adaptors.grpc_client_adaptor import GrpcClientAdaptor
from nvflare.app_opt.xgboost.histogram_based_v2.runners.xgb_client_runner import XGBClientRunner

from .executor import XGBExecutor
from .sec.client_handler import ClientSecurityHandler


class FedXGBHistogramExecutor(XGBExecutor):
def __init__(
self,
early_stopping_rounds,
xgb_params: dict,
data_loader_id: str,
verbose_eval=False,
use_gpus=False,
Expand All @@ -39,8 +40,6 @@ def __init__(
per_msg_timeout=per_msg_timeout,
tx_timeout=tx_timeout,
)
self.early_stopping_rounds = early_stopping_rounds
self.xgb_params = xgb_params
self.data_loader_id = data_loader_id
self.verbose_eval = verbose_eval
self.use_gpus = use_gpus
Expand All @@ -50,12 +49,13 @@ def __init__(
self.in_process = in_process

def get_adaptor(self, fl_ctx: FLContext):

engine = fl_ctx.get_engine()
handler = ClientSecurityHandler()
engine.add_component(str(uuid.uuid4()), handler)

runner = XGBClientRunner(
data_loader_id=self.data_loader_id,
early_stopping_rounds=self.early_stopping_rounds,
xgb_params=self.xgb_params,
verbose_eval=self.verbose_eval,
use_gpus=self.use_gpus,
model_file_name=self.model_file_name,
metrics_writer_id=self.metrics_writer_id,
)
Expand Down
Loading

0 comments on commit ed5249f

Please sign in to comment.