diff --git a/.github/workflows/unit_test.yml b/.github/workflows/unit_test.yml index 7bb0052d..0e7ab0c0 100644 --- a/.github/workflows/unit_test.yml +++ b/.github/workflows/unit_test.yml @@ -20,7 +20,7 @@ jobs: unit_tests: needs: cancel_previous_workflows runs-on: [self-hosted] - timeout-minutes: 30 + timeout-minutes: 45 steps: - name: Checkout uses: actions/checkout@v4 diff --git a/Makefile b/Makefile index c3618524..887e342f 100644 --- a/Makefile +++ b/Makefile @@ -21,8 +21,8 @@ install: .PHONY: lint lint: check_pylint_installed check_pytest_installed - @pylint --rcfile=.pylintrc -s n --jobs=128 ./llumnix - + @pylint --rcfile=.pylintrc -s n --jobs=128 ./llumnix + @pylint --rcfile=.pylintrc \ --disable=protected-access,super-init-not-called,unused-argument,redefined-outer-name,invalid-name \ -s n --jobs=128 ./tests @@ -53,7 +53,7 @@ proto-clean: .PHONY: test test: check_pytest_installed - @pytest -v --ignore=third_party/ --ignore=tests/e2e_test --disable-warnings + @pytest -v --ignore=third_party --ignore=tests/e2e_test --disable-warnings @python examlpes/offline_inference.py @pytest -v -x -s --tb=long ./tests/e2e_test/test_e2e.py @pytest -v -x -s --tb=long ./tests/e2e_test/test_bench.py @@ -61,7 +61,7 @@ test: check_pytest_installed .PHONY: unit_test unit_test: check_pytest_installed - @pytest -v --ignore=third_party/ --ignore=tests/e2e_test --disable-warnings + @pytest -v --ignore=third_party --ignore=tests/e2e_test --disable-warnings .PHONY: offline_test offline_test: diff --git a/README.md b/README.md index 09902f69..de06122d 100644 --- a/README.md +++ b/README.md @@ -47,7 +47,7 @@ Llumnix is easy to use with: ## Getting Started -If you are already utilizing vLLM for multi-instance LLM serving deployments, simply replace the vLLM serving deployment command `python -m vllm.entrypoints.api_server ...` for each instance with the command provided below: +If you are already utilizing vLLM for multi-instance LLM serving deployments, simply replace the vLLM serving deployment command `python -m entrypoints.vllm.api_server ...` for each instance with the command provided below: ``` python -m llumnix.entrypoints.vllm.api_server \ --host $HOST \ diff --git a/docs/Arguments.md b/docs/Arguments.md index 916755cf..f5be9210 100644 --- a/docs/Arguments.md +++ b/docs/Arguments.md @@ -6,17 +6,28 @@ Note: since Llumnix is still in alpha stage, the interface and arguments are *su ``` usage: -m llumnix.entrypoints.vllm.api_server [-h] + [--host HOST] + [--port PORT] + [--ssl-keyfile SSL_KEYFILE] + [--ssl-certfile SSL_CERTFILE] + [--log-level {debug,info,warning,error}] + [--launch-ray-cluster] + [--ray-cluster-port RAY_CLUSTER_PORT] + [--request-output-queue-type {rayqueue,zmq}] + [--request-output-queue-port REQUEST_OUTPUT_QUEUE_PORT] + [--disable-log-requests-server] + [--log-request-timestamps] [--config-file CONFIG_FILE] [--initial-instances INITIAL_INSTANCES] [--load-metric {remaining_steps,usage_ratio}] [--polling-interval POLLING_INTERVAL] [--dispatch-policy {balanced,load,queue,rr}] [--enable-migration] + [--enable-defrag] [--pair-migration-frequency PAIR_MIGRATION_FREQUENCY] [--pair-migration-policy {balanced,defrag_constrained,defrag_relaxed}] [--migrate-out-threshold MIGRATE_OUT_THRESHOLD] [--request-migration-policy {LCR,SR,LR,FCW,FCWSR}] - [--enable-defrag ENABLE_DEFRAG] [--enable-scaling] [--min-instances MIN_INSTANCES] [--max-instances MAX_INSTANCES] @@ -27,26 +38,69 @@ usage: -m llumnix.entrypoints.vllm.api_server [-h] [--disable-log-requests-manager] [--log-instance-info] [--log-filename 
LOG_FILENAME] + [--simulator-mode] [--profiling-result-file-path PROFILING_RESULT_FILE_PATH] [--gpu-type GPU_TYPE] - [--polling-interval POLLING_INTERVAL] [--migration-backend {gloo,nccl,rayrpc,grpc,kvtransfer}] [--migration-buffer-blocks MIGRATION_BUFFER_BLOCKS] - [--migration-backend-transfer-type {cuda_ipc,rdma,}] - [--migration-backend-kvtransfer-naming-url MIGRATION_BACKEND_KVTRANSFER_NAMING_URL] - [--migration-backend-server-address MIGRATION_BACKEND_SERVER_ADDRESS] - [--migration-backend-init-timeout MIGRATION_BACKEND_INIT_TIMEOUT] [--migration-num-layers MIGRATION_NUM_LAYERS] - [--last-stage-max-blocks LAST_STAGE_MAX_BLOCKS] + [--migration-backend-init-timeout MIGRATION_BACKEND_INIT_TIMEOUT] + [--migration-backend-transfer-type {cuda_ipc,rdma,}] + [--grpc-migration-backend-server-address GRPC_MIGRATION_BACKEND_SERVER_ADDRESS] + [--kvtransfer-migration-backend-naming-url KVTRANSFER_MIGRATION_BACKEND_NAMING_URL] [--max-stages MAX_STAGES] + [--last-stage-max-blocks LAST_STAGE_MAX_BLOCKS] [--enable-pd-disagg] [--num-dispatch-instances NUM_DISPATCH_INSTANCES] - [--log-request-timestamps] - + [--enable-port-increment] ``` +`--host` +- Hostname of the server. +- Default: "localhost" + +`--port` +- Port number of the server. +- Default: 8000 + +`--ssl-keyfile` +- Path to SSL key file. +- Default: None + +`--ssl-certfile` +- Path to SSL certificate file. +- Default: None + +`--log-level` +- Log level for the server. +- Possible choices: debug, info, warning, error +- Default: "info" + +`--launch-ray-cluster` +- If launch ray cluster. + +`--ray-cluster-port` +- Ray cluster port. +- Default: 6379 + +`--request-output-queue-type` +- Queue type for request output queue. +- Possible choices: rayqueue, zmq +- Default: "rayqueue" + +`--request-output-queue-port` +- Port number for the zmq request output queue. +- Default: 1234 + +`--disable-log-requests-server` +- Disable logging requests in server. + +`--log-request-timestamps` +- If log request timestamps. + `--config-file` -- Path to config file. +- Path to config file of arguments. +- Default: None `--initial-instances` - Number of instances created at initialization. @@ -69,6 +123,9 @@ usage: -m llumnix.entrypoints.vllm.api_server [-h] `--enable-migration` - Enable migrate requests between instances. +`--enable-defrag` +- Enable defragmentation through migration based on virtual usage. + `--pair-migration-frequency` - Pair migration frequency. - Default: 1 @@ -87,10 +144,6 @@ usage: -m llumnix.entrypoints.vllm.api_server [-h] - Possible choices: LCR, SR, LR, FCW, FCWSR - Default: "SR" -`--enable-defrag` -- Enable defragmentation through migration based on virtual usage. -- Default: False - `--enable-scaling` - Enable auto scaling. @@ -129,60 +182,60 @@ usage: -m llumnix.entrypoints.vllm.api_server [-h] - Log filename. - Default: "server.log" -`--profiling-result-file-path` -- Profiling result file path. -- Default: "" +`--simulator-mode` +- Enable simulator mode. -`--gpu-type` -- GPU type specified when using simulator. -- Default: "a10" +`--profiling-result-file-path` +- Profiling result file path when using simulator. +- Default: None `--migration-backend` - Communication backend of migration. - Possible choices: gloo, rayrpc, nccl, grpc, kvtransfer. [gloo, rayrpc, nccl] are available for vllm and [grpc, kvtransfer] are available for bladellm. - Default: "gloo" -`--migration-backend-transfer-type` -- Transfer type for migration backend kvTransfer. 
-- Possible choices: cuda_ipc, rdma -- Default: "rdma" - -`--migration-backend-server-address` -- Address of grpc server for migration backend -- Default: "127.0.0.1:50051" - -`--migration-backend-kvtransfer-naming-url` -- URL of naming server for kvtransfer migration backend -- Default: "file:/tmp/llumnix/naming/" - `--migration-buffer-blocks` - Number of buffer blocks in migration. - Default: 512 +`--migration-num-layers` +- number of kv-cache layers to transfer in each round during migration +- Default: 1 + `--migration-backend-init-timeout` - Timeout(s) for initializing migration backend. - Default: 10.0 -`--migration-num-layers` -- number of kv-cache layers to transfer in each round during migration -- Default: 1 +`--migration-backend-transfer-type` +- Transfer type for migration backend grpc and kvTransfer. +- Possible choices: cuda_ipc, rdma +- Default: "rdma" -`--last-stage-max-blocks` -- If the number of remaining blocks < last_stage_max_blocks, do last stage migration. -- Default: 4 +`--grpc-migration-backend-server-address` +- Address of grpc server for migration backend +- Default: "127.0.0.1:50051" + +`--kvtransfer-migration-backend-naming-url` +- URL of naming server for kvtransfer migration backend +- Default: "file:/tmp/llumnix/naming/" `--max-stages` - Drop migration if the number of stages > max_stages. - Default: 3 -`--log-request-timestamps` -- Enable logging request timestamps. +`--last-stage-max-blocks` +- If the number of remaining blocks < last_stage_max_blocks, do last stage migration. +- Default: 16 `--enable-pd-disagg` - Enable prefill decoding disaggregation. `--num-dispatch-instances` - Number of available instances for dispatch. +- Default: math.inf + +`--enable-port-increment` +- Enable port increment when desploying multiple servers. # Unsupported vLLM feature options diff --git a/docs/Quickstart.md b/docs/Quickstart.md index 4fcd605f..5798ba3f 100644 --- a/docs/Quickstart.md +++ b/docs/Quickstart.md @@ -34,7 +34,7 @@ After installation, you can follow this guide to use Llumnix for multi-instance ## Migrating from Existing Deployments -Inference engines like vLLM provide an API server user interface, e.g., `python -m vllm.entrypoints.api_server`. To deploy multiple instances, people start multiple such API servers, each corresponding to one instance, on multiple nodes / containers / k8s pods. +Inference engines like vLLM provide an API server user interface, e.g., `python -m entrypoints.vllm.api_server`. To deploy multiple instances, people start multiple such API servers, each corresponding to one instance, on multiple nodes / containers / k8s pods. Llumnix provides a similar user interface to enable seamless integration with such existing multi-instance deployments. You only need two simple steps to migrate from a deployed vLLM service to Llumnix: @@ -62,11 +62,25 @@ export HEAD_NODE=1 During the execution of serving deployment, Llumnix will: - Initiate the Ray cluster for distributed execution. -- Start Llumnix actor components, including LLMEngineManager, Llumlet, among others. +- Start Llumnix actor components, including Manager, Llumlet, among others. - Launch the vLLM engine instances. Following these steps, Llumnix acts as the request scheduling layer situated behind the multiple frontend API servers and above the multiple backend vLLM engine instances. This positioning allows Llumnix to significantly enhance serving performance through its dynamic, fine-grained, and KV-cache-aware request scheduling and rescheduling across instances. 
+## Centralized Deployment
+
+Llumnix also supports deploying multiple servers and instances at once by running `python -m llumnix.entrypoints.vllm.serve`, which is referred to as centralized deployment.
+
+```
+python -m llumnix.entrypoints.vllm.serve \
+    --config-file $CONFIG_PATH \
+    # vLLM arguments ...
+    # Llumnix arguments ...
+    ...
+```
+
+Centralized deployment assumes that the user has already launched a Ray cluster. Upon running the serve module, Llumnix will automatically connect to the existing Ray cluster, start the Llumnix components, and deploy multiple servers and instances to the Ray cluster until there are no more available GPUs or CPUs.
+
 ## Ray Cluster Notice
 
 When you include the --launch-ray-cluster option in Llumnix's serving deployment command, Llumnix automatically builds a Ray cluster during the execution of serving deployment. This action will overwrite any existing Ray cluster. If this behavior is not desired, simply omit the --launch-ray-cluster option, and Llumnix will initiate its actor components within the current Ray cluster.
@@ -84,7 +98,8 @@
 HEAD_NODE=1 python -m llumnix.entrypoints.vllm.api_server \
     --model $MODEL_PATH \
     --engine-use-ray \
     --worker-use-ray \
-    --max-model-len 4096
+    --max-model-len 4096 \
+    --migration-backend rayrpc \
 ```
 `CONFIG_PATH` is the path to the configuration file for Llumnix, and we give an example configuration file [here](../configs/base.yml). `MODEL_PATH` defines the location of your model. `INITIAL_INSTANCES` determines the number of instances to be launched on the current node,
diff --git a/examlpes/offline_inference.py b/examlpes/offline_inference.py
index 4b6bc5d3..5148a9e8 100644
--- a/examlpes/offline_inference.py
+++ b/examlpes/offline_inference.py
@@ -5,8 +5,9 @@
 import ray
 
 from llumnix import launch_ray_cluster, connect_to_ray_cluster, init_manager
-from llumnix import (SamplingParams, ServerInfo, EngineManagerArgs, LLMEngineManager, Llumlet,
-                     EngineArgs, QueueType, BackendType)
+from llumnix import (ManagerArgs, EngineArgs, Manager,
+                     Llumlet, ServerInfo, QueueType, BackendType,
+                     SamplingParams)
 from llumnix.utils import random_uuid
 from llumnix.queue.ray_queue_server import RayQueueServer
@@ -33,23 +34,18 @@ connect_to_ray_cluster(port=ray_cluster_port)
 
 # Set manager args and engine args.
-manager_args = EngineManagerArgs()
+manager_args = ManagerArgs()
 engine_args = EngineArgs(model="facebook/opt-125m", worker_use_ray=True, trust_remote_code=True, max_model_len=370)
 
-# Create a manager. If the manager is created first, and then the llumlets are created, manager.scale_up
-# need to be called to add the newly created llumlets to the management of the manager.
-manager: LLMEngineManager = init_manager(manager_args)
+# Create a manager. The manager is created first, and then the instances are created.
+manager: Manager = init_manager(manager_args)
 ray.get(manager.is_ready.remote())
 
-# Create llumlets.
+# Create instances.
 instance_ids: List[str] = None
-llumlets: List[Llumlet] = None
-instance_ids, llumlets = ray.get(manager.init_llumlets.remote(
-    engine_args, QueueType("rayqueue"), BackendType.VLLM, 1,
-))
-
-ray.get(manager.scale_up.remote(instance_ids, llumlets))
+instances: List[Llumlet] = None
+instance_ids, instances = ray.get(manager.init_instances.remote(QueueType("rayqueue"), BackendType.VLLM, engine_args))
 
 # The requests‘ outputs will be put to the request_output_queue no matter which instance it's running in.
server_id = random_uuid() diff --git a/llumnix/__init__.py b/llumnix/__init__.py index 3e6e04b4..fba69575 100644 --- a/llumnix/__init__.py +++ b/llumnix/__init__.py @@ -15,8 +15,8 @@ from llumnix.entrypoints.setup import (launch_ray_cluster, connect_to_ray_cluster, init_manager) -from llumnix.arg_utils import EngineManagerArgs -from llumnix.llm_engine_manager import LLMEngineManager +from llumnix.arg_utils import ManagerArgs +from llumnix.manager import Manager from llumnix.llumlet.llumlet import Llumlet from llumnix.queue.queue_type import QueueType from llumnix.backends.backend_interface import BackendType @@ -28,8 +28,8 @@ "launch_ray_cluster", "connect_to_ray_cluster", "init_manager", - "EngineManagerArgs", - "LLMEngineManager", + "ManagerArgs", + "Manager", "Llumlet", "QueueType", "BackendType", diff --git a/llumnix/arg_utils.py b/llumnix/arg_utils.py index d9138407..c3a7d9ff 100644 --- a/llumnix/arg_utils.py +++ b/llumnix/arg_utils.py @@ -21,6 +21,8 @@ from llumnix.internal_config import GlobalSchedulerConfig, MigrationConfig from llumnix.config import LlumnixConfig, get_llumnix_config from llumnix.config.default import _C +from llumnix.backends.backend_interface import BackendType +from llumnix.entrypoints.utils import LaunchMode class LlumnixArgumentParser(argparse.ArgumentParser): @@ -44,7 +46,12 @@ def add_argument(self, *args, **kwargs): # All the default values of llumnix arguments are set in default.py. So all the arguments here are set to None. @dataclass -class LlumnixEntrypointsArgs: +class EntrypointsArgs: + host: str = None + port: int = None + ssl_keyfile: str = None + ssl_certfile: str = None + log_level: str = None launch_ray_cluster: bool = None ray_cluster_port: int = None request_output_queue_type: str = None @@ -52,6 +59,7 @@ class LlumnixEntrypointsArgs: disable_log_requests_server: bool = None log_request_timestamps: bool = None config_file: str = None + disable_keep_serve_process_alive: bool = None def __post_init__(self): for attr in dataclasses.fields(self): @@ -59,16 +67,16 @@ def __post_init__(self): setattr(self, attr.name, getattr(_C.SERVER, attr.name.upper())) @classmethod - def from_llumnix_config(cls, cfg: LlumnixConfig = get_llumnix_config()) -> 'LlumnixEntrypointsArgs': + def from_llumnix_config(cls, cfg: LlumnixConfig = get_llumnix_config()) -> 'EntrypointsArgs': # Get the list of attributes of this dataclass. attrs = [attr.name for attr in dataclasses.fields(cls)] # Set the attributes from the parsed arguments. # The defalut values of attributes are defined in default.py. 
- llumnix_entrypoints_args = cls(**{attr: getattr(cfg.SERVER, attr.upper()) for attr in attrs}) - return llumnix_entrypoints_args + entrypoints_args = cls(**{attr: getattr(cfg.SERVER, attr.upper()) for attr in attrs}) + return entrypoints_args @classmethod - def check_args(cls, args: 'LlumnixEntrypointsArgs', parser: argparse.ArgumentParser): + def check_args(cls, args: 'EntrypointsArgs', parser: argparse.ArgumentParser): # pylint: disable=protected-access for action in parser._optionals._actions: if hasattr(action, 'choices') and action.choices is not None and hasattr(args, action.dest): @@ -78,17 +86,17 @@ def check_args(cls, args: 'LlumnixEntrypointsArgs', parser: argparse.ArgumentPar def add_cli_args(parser: argparse.ArgumentParser) -> argparse.ArgumentParser: parser.add_argument('--launch-ray-cluster', action='store_true', - help='if launch ray cluster in api server') + help='if launch ray cluster') parser.add_argument("--ray-cluster-port", type=int, help='ray cluster port') parser.add_argument("--request-output-queue-type", type=str, choices=['rayqueue', 'zmq'], - help='request output queue type for request output queue') + help='queue type for request output queue') parser.add_argument("--request-output-queue-port", type=int, - help='port for zmq') + help='port number for the zmq request output queue') parser.add_argument('--disable-log-requests-server', action='store_true', help='disable logging requests in server') @@ -97,18 +105,18 @@ def add_cli_args(parser: argparse.ArgumentParser) -> argparse.ArgumentParser: help='if log request timestamps') parser.add_argument("--config-file", type=str, - help="path to config file") + help="path to config file of arguments") + return parser @dataclass -class EngineManagerArgs: +class ManagerArgs: initial_instances: int = None load_metric: str = None polling_interval: float = None dispatch_policy: str = None - num_dispatch_instances: int = None enable_migration: bool = None enable_defrag: bool = None @@ -125,22 +133,26 @@ class EngineManagerArgs: scale_up_threshold: float = None scale_down_threshold: float = None - log_filename: str = None disable_log_requests_manager: bool = None log_instance_info: bool = None + log_filename: str = None + simulator_mode: bool = None profiling_result_file_path: str = None - migration_backend_kvtransfer_naming_url: str = None - migration_backend_server_address: str = None - migration_backend_init_timeout: float = None migration_backend: str = None migration_buffer_blocks: int = None - migration_backend_transfer_type: str = None migration_num_layers: int = None + migration_backend_init_timeout: float = None + migration_backend_transfer_type: str = None + grpc_migration_backend_server_address: str = None + kvtransfer_migration_backend_naming_url: str = None last_stage_max_blocks: int = None max_stages: int = None enable_pd_disagg: bool = None + num_dispatch_instances: int = None + + enable_port_increment: bool = None def __post_init__(self): # Check if all fields default to None @@ -152,7 +164,7 @@ def __post_init__(self): if getattr(self, attr.name) is None: setattr(self, attr.name, getattr(_C.MANAGER, attr.name.upper())) - def create_global_scheduler_configs( + def create_global_scheduler_config( self, ) -> Tuple[GlobalSchedulerConfig]: @@ -168,7 +180,7 @@ def create_global_scheduler_configs( self.scale_up_threshold, self.scale_down_threshold, self.enable_pd_disagg, - self.migration_backend,) + self.migration_backend) return global_scheduler_config def create_migration_config(self) -> MigrationConfig: @@ 
-180,21 +192,21 @@ def create_migration_config(self) -> MigrationConfig: self.max_stages, self.migration_backend_init_timeout, self.migration_backend_transfer_type, - self.migration_backend_server_address, - self.migration_backend_kvtransfer_naming_url) + self.grpc_migration_backend_server_address, + self.kvtransfer_migration_backend_naming_url) return migration_config @classmethod - def from_llumnix_config(cls, cfg: LlumnixConfig = get_llumnix_config()) -> 'EngineManagerArgs': + def from_llumnix_config(cls, cfg: LlumnixConfig = get_llumnix_config()) -> 'ManagerArgs': # Get the list of attributes of this dataclass. attrs = [attr.name for attr in dataclasses.fields(cls)] # Set the attributes from the parsed arguments. # The defalut values of attributes are defined in default.py. - engine_manager_args = cls(**{attr: getattr(cfg.MANAGER, attr.upper()) for attr in attrs}) - return engine_manager_args + manager_args = cls(**{attr: getattr(cfg.MANAGER, attr.upper()) for attr in attrs}) + return manager_args @classmethod - def check_args(cls, args: 'EngineManagerArgs', parser: argparse.ArgumentParser): + def check_args(cls, args: 'ManagerArgs', parser: argparse.ArgumentParser): # pylint: disable=protected-access for action in parser._optionals._actions: if hasattr(action, 'choices') and action.choices is not None and hasattr(args, action.dest): @@ -207,6 +219,9 @@ def check_args(cls, args: 'EngineManagerArgs', parser: argparse.ArgumentParser): ("When using kvTransfer as migration backend, " "do not set --migration-backend-transfer-type as empty.") + assert not args.simulator_mode or args.profiling_result_file_path is not None, \ + "Set profiling_result_file_path args when enable simulator mode" + @staticmethod def add_cli_args(parser: argparse.ArgumentParser) -> argparse.ArgumentParser: parser.add_argument('--initial-instances', @@ -230,13 +245,13 @@ def add_cli_args(parser: argparse.ArgumentParser) -> argparse.ArgumentParser: '* "queue" dispatch request to the instance with minimum waiting request queue length.\n' '* "flood" dispatch request to the instance with maximum requests dispatched.\n' '* "rr" dispatch requests with round-robin policy.\n') - parser.add_argument('--num-available-dispatch-instances', - type=int, - help='number of available instances for dispatching') parser.add_argument('--enable-migration', action='store_true', help='enable migrate requests between instances') + parser.add_argument('--enable-defrag', + type=bool, + help='enable defragmentation through migration based on virtual usage') parser.add_argument('--pair-migration-frequency', type=int, help='pair migration frequency') @@ -262,9 +277,6 @@ def add_cli_args(parser: argparse.ArgumentParser) -> argparse.ArgumentParser: '* "LR" migrate the running request longest.\n' '* "FCW" migrate the waiting request first come.\n' '* "FCWSR" migrate the waiting request first come and running request shortest.\n') - parser.add_argument('--enable-defrag', - type=bool, - help='enable defragmentation through migration based on virtual usage') parser.add_argument('--enable-scaling', action='store_true', @@ -300,41 +312,55 @@ def add_cli_args(parser: argparse.ArgumentParser) -> argparse.ArgumentParser: help='log filename') parser.add_argument('--profiling-result-file-path', type=str, - help='profiling result file path') + help='profiling result file path when using simulator') + parser.add_argument('--simulator-mode', + action='store_true', + help='enable simulator mode') parser.add_argument('--migration-backend', type=str, 
choices=['gloo','nccl','rayrpc','grpc','kvtransfer'], help='communication backend of migration, [gloo, rayrpc, nccl] are available for vllm \ and [grpc, kvtransfer] are available for bladellm') + parser.add_argument('--migration-buffer-blocks', + type=int, + help='number of buffer blocks in migration') + parser.add_argument('--migration-num-layers', + type=int, + help='number of kv-cache layers to transfer in each round during migration') + parser.add_argument('--migration-backend-init-timeout', + type=float, + help='timeout(s) for initializing migration backend') parser.add_argument('--migration-backend-transfer-type', type=str, choices=['cuda_ipc','rdma', ''], help='transfer type for migration backend grpc and kvTransfer') - parser.add_argument('--grpc-migration-backend-address', + parser.add_argument('--grpc-migration-backend-server-address', type=str, help='address of grpc server for migration backend') - parser.add_argument('--migration-backend-kvtransfer-naming-url', + parser.add_argument('--kvtransfer-migration-backend-naming-url', type=str, help='url of naming server for kvtransfer migration backend') - parser.add_argument('--migration-backend-init-timeout', - type=float, - help='timeout(s) for initializing migration backend') - parser.add_argument('--migration-buffer-blocks', - type=int, - help='number of buffer blocks in migration') - parser.add_argument('--migration-num-layers', + parser.add_argument('--max-stages', type=int, - help='number of kv-cache layers to transfer in each round during migration') + help='drop migration if the number of stages > max_stages') parser.add_argument('--last-stage-max-blocks', type=int, help='if the number pf remain blocks < last_stage_max_blocks, do last stage migration') - parser.add_argument('--max-stages', - type=int, - help='drop migration if the number of stages > max_stages') + parser.add_argument('--enable-pd-disagg', action='store_true', help='enable prefill decoding disaggregation') parser.add_argument('--num-dispatch-instances', type=int, help='number of available instances for dispatch') + + parser.add_argument('--enable-port-increment', + action='store_true', + help='enable port increment when desploying multiple servers') + return parser + +@dataclass +class LaunchArgs: + launch_mode: LaunchMode = None + backend_type: BackendType = None diff --git a/llumnix/backends/backend_interface.py b/llumnix/backends/backend_interface.py index 5e34c01f..d8631b84 100644 --- a/llumnix/backends/backend_interface.py +++ b/llumnix/backends/backend_interface.py @@ -18,24 +18,24 @@ from llumnix.llumlet.request import LlumnixRequest, RequestStatus from llumnix.server_info import ServerInfo + class EngineState(str, Enum): INIT = "INIT" CRASHED = "CRASHED" RUNNING = "RUNNING" STOPPED = "STOPPED" + class BackendType(str, Enum): VLLM = "VLLM" - SIM_VLLM = "SIM_VLLM" BLADELLM = "BLADELLM" + SIM_VLLM = "SIM_VLLM" @staticmethod def is_sim_backend(status: "BackendType") -> bool: - return status in [ - BackendType.SIM_VLLM, - ] + return status in [BackendType.SIM_VLLM] + -# TODO(KuilongCui): separate backend interface into two parts: DispatchBackendInterface and MigrationBackendInterface class BackendInterface(ABC): # Methods for inference @abstractmethod diff --git a/llumnix/backends/bladellm/llm_engine.py b/llumnix/backends/bladellm/llm_engine.py index 7dff0caf..79bd9abc 100644 --- a/llumnix/backends/bladellm/llm_engine.py +++ b/llumnix/backends/bladellm/llm_engine.py @@ -40,7 +40,7 @@ from llumnix.queue.queue_type import QueueType class 
AsyncBackQueueWrapper(APIWrapper): - def __init__(self, placement_group, instance_id, output_queue_type) -> None: + def __init__(self, placement_group, instance_id, request_output_queue_type) -> None: super().__init__(args=None, resp_queue=None) scheduling_strategy = PlacementGroupSchedulingStrategy( placement_group=placement_group, @@ -54,7 +54,7 @@ def __init__(self, placement_group, instance_id, output_queue_type) -> None: self.async_put_queue_actor = ray.remote( num_cpus=1, scheduling_strategy=scheduling_strategy - )(AsyncPutQueueActor).remote(instance_id, output_queue_type) + )(AsyncPutQueueActor).remote(instance_id, request_output_queue_type) self.put_queue_loop_thread.start() self.request_server_map = {} @@ -113,9 +113,9 @@ class AsyncLLMEngineLlumnixMixin: # pylint: disable=unused-argument def __init__(self, instance_id: str, - output_queue_type: QueueType, + placement_group: PlacementGroup, + request_output_queue_type: QueueType, migration_config: MigrationConfig, - placement_group: Optional[PlacementGroup], ) -> None: self.instance_id = instance_id @@ -123,7 +123,7 @@ def __init__(self, logger.info("engine ({}) current state {}".format(self.instance_id, self.state)) self.placement_group = placement_group - self.output_queue_type = output_queue_type + self.request_output_queue_type = request_output_queue_type @property def instance_info(self) -> InstanceInfo: @@ -134,7 +134,7 @@ def start(self, loop: asyncio.AbstractEventLoop): self._client = self.init_client_from_engine() self.trans_wrapper: AsyncBackQueueWrapper = AsyncBackQueueWrapper(self.placement_group, self.instance_id, - self.output_queue_type) + self.request_output_queue_type) self._scheduler.llumnix_metrics.engine_init_metrics(self) async def update_callback(self, resp_list, step_requests): @@ -150,7 +150,7 @@ async def _loop(self): await super()._loop() # pylint: disable=broad-except except Exception as e: - logger.error("Error in engine loop: {}".format(e)) + logger.error("error in engine loop: {}".format(e)) logger.error("exception traceback: {}".format(traceback.format_exc())) previous_state = self.state @@ -185,49 +185,53 @@ async def drop_request(self, req_id: int): class AsyncLLMEngineLlumnix(AsyncLLMEngineLlumnixMixin, AsyncLLMEngine): def __init__(self, instance_id: str, - output_queue_type: QueueType, + placement_group: PlacementGroup, + request_output_queue_type: QueueType, migration_config: MigrationConfig, - placement_group: Optional[PlacementGroup], *args, **kwargs, ) -> None: AsyncLLMEngine.__init__(self, *args, **kwargs) - AsyncLLMEngineLlumnixMixin.__init__(self, instance_id, output_queue_type, migration_config, placement_group) + AsyncLLMEngineLlumnixMixin.__init__(self, instance_id, placement_group, request_output_queue_type, migration_config) class PrefillAsyncLLMEngineLlumnix(AsyncLLMEngineLlumnixMixin, PrefillAsyncLLMEngine): def __init__(self, instance_id: str, - output_queue_type: QueueType, + placement_group: PlacementGroup, + request_output_queue_type: QueueType, migration_config: MigrationConfig, - placement_group: Optional[PlacementGroup], *args, **kwargs, ) -> None: PrefillAsyncLLMEngine.__init__(self, *args, **kwargs) - AsyncLLMEngineLlumnixMixin.__init__(self, instance_id, output_queue_type, migration_config, placement_group) + AsyncLLMEngineLlumnixMixin.__init__(self, instance_id, placement_group, request_output_queue_type, migration_config) class DecodeAsyncLLMEngineLlumnix(AsyncLLMEngineLlumnixMixin, DecodeAsyncLLMEngine): def __init__(self, instance_id: str, - output_queue_type: 
QueueType, + placement_group: PlacementGroup, + request_output_queue_type: QueueType, migration_config: MigrationConfig, - placement_group: Optional[PlacementGroup], *args, **kwargs, ) -> None: DecodeAsyncLLMEngine.__init__(self, *args, **kwargs) - AsyncLLMEngineLlumnixMixin.__init__(self, instance_id, output_queue_type, migration_config, placement_group) + AsyncLLMEngineLlumnixMixin.__init__(self, instance_id, placement_group, request_output_queue_type, migration_config) class BackendBladeLLM(BackendInterface): def __init__( self, instance_id: str, - output_queue_type: QueueType, + placement_group: PlacementGroup, + request_output_queue_type: QueueType, migration_config: MigrationConfig, - engine_args: ServingArgs, - placement_group: PlacementGroup = None, + engine_args: ServingArgs ) -> None: self.instance_id = instance_id self.engine_args = engine_args engine_cls = self._get_engine_cls() - self.engine = engine_cls(instance_id, output_queue_type, migration_config, placement_group, engine_args) + self.engine = engine_cls(instance_id, + placement_group, + request_output_queue_type, + migration_config, + engine_args) self._loop = asyncio.new_event_loop() self._engine_ready = threading.Event() diff --git a/llumnix/backends/profiling.py b/llumnix/backends/profiling.py index 452bd5cd..b79afcc1 100644 --- a/llumnix/backends/profiling.py +++ b/llumnix/backends/profiling.py @@ -196,7 +196,7 @@ def get_latency_mem(backend_type: BackendType, profiling_database: ProfilingData assert sim_parallel_config in profiling_result.para_dict.keys(), "sim parallel config not in database" latency_mem: LatencyMemData = profiling_result.para_dict[sim_parallel_config] return latency_mem - raise ValueError(f'Unsupported backend: {backend_type}') + raise ValueError(f'Unsupported simulator backend: {backend_type}') if __name__ == "__main__": import argparse diff --git a/llumnix/backends/utils.py b/llumnix/backends/utils.py index 8976128d..8659c016 100644 --- a/llumnix/backends/utils.py +++ b/llumnix/backends/utils.py @@ -11,7 +11,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Optional, Tuple, Dict, List +from typing import Dict, List import asyncio import time @@ -24,13 +24,14 @@ from llumnix.queue.utils import init_request_output_queue_client from llumnix.server_info import ServerInfo from llumnix.logger import init_logger -from llumnix.utils import get_placement_group_name, get_instance_name +from llumnix.utils import get_instance_name +from llumnix.internal_config import MigrationConfig logger = init_logger(__name__) class AsyncPutQueueActor: - def __init__(self, instance_id, request_output_queue_type: QueueType): + def __init__(self, instance_id: str, request_output_queue_type: QueueType): self.instance_id = instance_id self.request_output_queue_type = request_output_queue_type self.request_output_queue_client: QueueClientBase = init_request_output_queue_client(request_output_queue_type) @@ -56,87 +57,51 @@ async def put_nowait_to_servers(self, logger.info("server {} is dead".format(server_id)) if self.request_output_queue_type == QueueType.ZMQ: logger.info("request output queue ip: {}, port: {}".format(server_info.request_output_queue_ip, - server_info.request_output_queue_port)) + server_info.request_output_queue_port)) req_outputs = list(server_request_outputs.values())[idx] request_ids = [req_output.request_id for req_output in req_outputs] self.engine_actor_handle.abort_request.remote(request_ids) -def init_backend_engine(instance_id: str, request_output_queue_type: QueueType, - backend_type: BackendType, *args, **kwargs) -> BackendInterface: +def init_backend_engine(instance_id: str, + placement_group: PlacementGroup, + request_output_queue_type: QueueType, + migration_config: MigrationConfig, + backend_type: BackendType, + engine_args, + profiling_result_file_path: str = None) -> BackendInterface: if backend_type == BackendType.VLLM: # pylint: disable=import-outside-toplevel from llumnix.backends.vllm.llm_engine import BackendVLLM - backend_engine = BackendVLLM(instance_id, request_output_queue_type, *args, **kwargs) - elif backend_type == BackendType.SIM_VLLM: - # pylint: disable=import-outside-toplevel - from llumnix.backends.vllm.simulator import BackendSimVLLM - backend_engine = BackendSimVLLM(instance_id, request_output_queue_type, *args, **kwargs) + backend_engine = BackendVLLM(instance_id, + placement_group, + request_output_queue_type, + migration_config, + engine_args) elif backend_type == BackendType.BLADELLM: # pylint: disable=import-outside-toplevel from llumnix.backends.bladellm.llm_engine import BackendBladeLLM - backend_engine = BackendBladeLLM(instance_id, request_output_queue_type, *args, **kwargs) + backend_engine = BackendBladeLLM(instance_id, + placement_group, + request_output_queue_type, + migration_config, + engine_args) + elif backend_type == BackendType.SIM_VLLM: + # pylint: disable=import-outside-toplevel + from llumnix.backends.vllm.simulator import BackendSimVLLM + backend_engine = BackendSimVLLM(instance_id, + placement_group, + request_output_queue_type, + migration_config, + engine_args, + profiling_result_file_path) else: raise ValueError(f'Unsupported backend: {backend_type}') return backend_engine -def initialize_placement_group( - instance_id: str, - num_cpus: int = 1, - num_gpus: int = 1, - detached: bool = False -) -> Tuple[str, Optional[PlacementGroup]]: - """Initialize the distributed cluster probably with Ray. - - Args: - parallel_config: The configurations for parallel execution. - engine_use_ray: Whether to use Ray for async engine. - ray_address: The address of the Ray cluster. 
If None, uses - the default Ray cluster address. - - Returns: - A tuple of (`distributed_init_method`, `placement_group`). The - `distributed_init_method` is the address for initializing the - distributed backend. `placement_group` includes the specification - of the resources for each distributed worker. - """ - if ray is None: - raise ImportError( - "Ray is not installed. Please install Ray to use distributed " - "serving.") - - lifetime = "detached" if detached else None - # Create placement group for worker processes - current_placement_group = ray.util.get_current_placement_group() - if current_placement_group: - # We are in a placement group - bundles = current_placement_group.bundle_specs - # Verify that we can use the placement group. - gpu_bundles = 0 - for bundle in bundles: - bundle_gpus = bundle.get("GPU", 0) - if bundle_gpus > 1: - raise ValueError( - "Placement group bundle cannot have more than 1 GPU.") - if bundle_gpus: - gpu_bundles += 1 - if num_gpus > gpu_bundles: - raise ValueError( - "The number of required GPUs exceeds the total number of " - "available GPUs in the placement group.") - else: - num_gpus_in_cluster = ray.cluster_resources().get("GPU", 0) - if num_gpus > num_gpus_in_cluster: - raise ValueError( - "The number of required GPUs exceeds the total number of " - "available GPUs in the cluster.") - # Create a new placement group - # bundle_0: Llumlet + AsyncPutQueueActor + ProxyActor, bundle_1: Workers - placement_group_specs = ([{"CPU": num_cpus}] + [{"GPU": 1}] * num_gpus) - current_placement_group = ray.util.placement_group( - placement_group_specs, "STRICT_PACK", name=get_placement_group_name(instance_id), lifetime=lifetime) - # Wait until PG is ready - this will block until all - # requested resources are available, and will timeout - # if they cannot be provisioned. 
- ray.get(current_placement_group.ready(), timeout=1800) - - return current_placement_group +def get_engine_world_size(engine_args, backend_type: BackendType): + if backend_type == BackendType.VLLM: + engine_config = engine_args.create_engine_config() + world_size = engine_config.parallel_config.world_size + else: # BLADE_LLM + world_size = engine_args.tensor_parallel_size * engine_args.pipeline_parallel_size + return world_size diff --git a/llumnix/backends/vllm/executor.py b/llumnix/backends/vllm/executor.py index 6cb333da..7feeefcb 100644 --- a/llumnix/backends/vllm/executor.py +++ b/llumnix/backends/vllm/executor.py @@ -36,6 +36,7 @@ logger = init_logger(__name__) + class LlumnixRayGPUExecutor(RayGPUExecutorAsync): migration_config: MigrationConfig = None diff --git a/llumnix/backends/vllm/llm_engine.py b/llumnix/backends/vllm/llm_engine.py index d580696f..dec38700 100644 --- a/llumnix/backends/vllm/llm_engine.py +++ b/llumnix/backends/vllm/llm_engine.py @@ -50,9 +50,10 @@ class LLMEngineLlumnix(_AsyncLLMEngine): def __init__(self, instance_id: str, + placement_group: PlacementGroup, request_output_queue_type: QueueType, - placement_group: Optional[PlacementGroup], - *args, **kwargs) -> None: + *args, + **kwargs) -> None: super().__init__(*args, **kwargs) self.instance_id = instance_id self.step_counter = Counter() @@ -77,13 +78,13 @@ def __init__(self, @classmethod def from_engine_args( cls, - engine_args: EngineArgs, + instance_id: str, + placement_group: PlacementGroup, request_output_queue_type: QueueType, migration_config: MigrationConfig, - usage_context: UsageContext = UsageContext.ENGINE_CONTEXT, - instance_id: str = None, - placement_group: Optional[PlacementGroup] = None, - latency_mem: Optional[LatencyMemData] = None + engine_args: EngineArgs, + latency_mem: Optional[LatencyMemData] = None, + usage_context: UsageContext = UsageContext.ENGINE_CONTEXT ) -> "LLMEngineLlumnix": """Creates an LLM engine from the engine arguments.""" # Create the engine configs. @@ -105,8 +106,8 @@ def from_engine_args( # Create the LLM engine. 
engine = cls( instance_id=instance_id, - request_output_queue_type=request_output_queue_type, placement_group=placement_group, + request_output_queue_type=request_output_queue_type, **engine_config.to_dict(), executor_class=executor_class, log_stats=not engine_args.disable_log_stats, @@ -209,8 +210,10 @@ def add_request(self, request_id: str, server_info: ServerInfo, expected_steps: if hasattr(server_info, 'request_timestamps'): server_info.request_timestamps.engine_add_request_timestamp = time.time() self.scheduler.waiting[-1] = SequenceGroupLlumnix(request_id, server_info, expected_steps, [seq_group.get_seqs()[0]], - seq_group.sampling_params, seq_group.metrics.arrival_time, seq_group.lora_request, - seq_group.multi_modal_data) + sampling_params=seq_group.sampling_params, + arrival_time=seq_group.metrics.arrival_time, + lora_request=seq_group.lora_request, + multi_modal_data=seq_group.multi_modal_data) def _start_put_queue_loop(self): while True: @@ -237,16 +240,16 @@ class BackendVLLM(BackendInterface): def __init__( self, instance_id: str, + placement_group: PlacementGroup, request_output_queue_type: QueueType, migration_config: MigrationConfig, - placement_group: PlacementGroup, engine_args: EngineArgs, ) -> None: - self.engine: LLMEngineLlumnix = LLMEngineLlumnix.from_engine_args(engine_args=engine_args, - request_output_queue_type=request_output_queue_type, - migration_config=migration_config, - instance_id=instance_id, - placement_group=placement_group) + self.engine: LLMEngineLlumnix = LLMEngineLlumnix.from_engine_args(instance_id, + placement_group, + request_output_queue_type, + migration_config, + engine_args) self.engine.scheduler = SchedulerLlumnix(self.engine.scheduler_config, self.engine.cache_config, self.engine.lora_config) self.engine.scheduler.add_update_instance_info_callback(self.engine.update_instance_info) self.engine.output_processor.scheduler = self.engine.scheduler @@ -279,7 +282,7 @@ async def _start_engine_step_loop(self) -> None: await asyncio.sleep(NO_OUTPUTS_STEP_INTERVAL) # pylint: disable=broad-except except Exception as e: - logger.error("Error in engine loop: {}".format(e)) + logger.error("error in engine loop: {}".format(e)) logger.error("exception traceback: {}".format(traceback.format_exc())) self._run_workers("shutdown") diff --git a/llumnix/backends/vllm/migration_backend.py b/llumnix/backends/vllm/migration_backend.py index 39978b6e..368b8b49 100644 --- a/llumnix/backends/vllm/migration_backend.py +++ b/llumnix/backends/vllm/migration_backend.py @@ -57,7 +57,7 @@ def __init__(self, migration_config: MigrationConfig, cache_engine: CacheEngine, self.rpc_dtype = self.cache_engine.dtype else: self.rpc_dtype = torch.float32 - logger.warning("Detect numpy unsupported dtype: {}. Using torch.float32.".format(self.cache_engine.dtype)) + logger.warning("Detect numpy unsupported dtype: {}, using torch.float32.".format(self.cache_engine.dtype)) self.is_driver_worker = is_driver_worker self.gpu_cache = gpu_cache @@ -189,7 +189,7 @@ def init_group(world_size, rank, backend, group_name): try: init_group(world_size, rank, self.backend, group_name) except FunctionTimedOut: - logger.info("create migration backend fail (group_name: {}, world_size: {}, rank: {}, backbend: {})." + logger.info("create migration backend failed (group_name: {}, world_size: {}, rank: {}, backbend: {})." 
.format(group_name, world_size, rank, self.backend)) return False @@ -227,7 +227,7 @@ def warmup(self) -> bool: col.allreduce(self.dummy_cache[0], self.group_name) # pylint: disable=W0703 except Exception as e: - logger.info("warmup migration backend failed (group_name: {}, world_size: {}, rank: {}, backbend: {}), err: {}." + logger.error("warmup migration backend failed (group_name: {}, world_size: {}, rank: {}, backbend: {}), err: {}." .format(self.group_name, self.global_world_size, self.global_rank, self.backend, e)) return False @@ -276,7 +276,7 @@ def do_recv(self, src_handle, blocks: List[int]): self.migration_stream.synchronize() def get_migration_backend(migration_config: MigrationConfig, cache_engine: CacheEngine, worker_handle_list, scheduling_strategy, - is_driver_worker, gpu_cache, worker_rank, local_rank) -> MigrationBackendBase: + is_driver_worker, gpu_cache, worker_rank, local_rank) -> MigrationBackendBase: if cache_engine.num_gpu_blocks < migration_config.migration_buffer_blocks: logger.warning("migration_buffer_blocks({}) is larger than num_gpu_blocks({}), reducing it to num_gpu_blocks." .format(migration_config.migration_buffer_blocks, cache_engine.num_gpu_blocks)) diff --git a/llumnix/backends/vllm/scheduler.py b/llumnix/backends/vllm/scheduler.py index ea0991f7..874b5e1e 100644 --- a/llumnix/backends/vllm/scheduler.py +++ b/llumnix/backends/vllm/scheduler.py @@ -195,7 +195,7 @@ def free_dst_pre_alloc_cache(self, request_id: str = None) -> None: def free_src_request(self, backend_request: SequenceGroupLlumnix) -> None: seq = backend_request.get_seqs()[0] - logger.info("free request: {}, free seq: {}".format(backend_request.request_id, seq.seq_id)) + logger.info("free request: {}, seq: {}".format(backend_request.request_id, seq.seq_id)) self.free_seq(seq) def _get_instance_info(self, scheduled_seq_groups: List[SequenceGroupLlumnix]) -> InstanceInfo: diff --git a/llumnix/backends/vllm/simulator.py b/llumnix/backends/vllm/simulator.py index 85613edb..94ff6850 100644 --- a/llumnix/backends/vllm/simulator.py +++ b/llumnix/backends/vllm/simulator.py @@ -32,20 +32,20 @@ class BackendSimVLLM(BackendVLLM): def __init__( self, instance_id: str, + placement_group: PlacementGroup, request_output_queue_type: QueueType, migration_config: MigrationConfig, - placement_group: PlacementGroup, engine_args: EngineArgs, - profiling_result_file_path: str, + profiling_result_file_path: str ) -> None: # multi-instance args latency_mem = self._get_lantecy_mem(profiling_result_file_path, engine_args) - self.engine: LLMEngineLlumnix = LLMEngineLlumnix.from_engine_args(engine_args=engine_args, - request_output_queue_type=request_output_queue_type, - migration_config=migration_config, - instance_id=instance_id, - placement_group=placement_group, - latency_mem=latency_mem) + self.engine: LLMEngineLlumnix = LLMEngineLlumnix.from_engine_args(instance_id, + placement_group, + request_output_queue_type, + migration_config, + engine_args, + latency_mem) self.engine.scheduler = SchedulerLlumnix(self.engine.scheduler_config, self.engine.cache_config, self.engine.lora_config) self.engine.scheduler.add_update_instance_info_callback(self.engine.update_instance_info) self.engine.output_processor.scheduler = self.engine.scheduler diff --git a/llumnix/backends/vllm/utils.py b/llumnix/backends/vllm/utils.py index 7e49720a..80c63e6d 100644 --- a/llumnix/backends/vllm/utils.py +++ b/llumnix/backends/vllm/utils.py @@ -23,7 +23,7 @@ _modify_greedy_probs_inplace, _beam_search_sample from llumnix.logger import 
init_logger -from llumnix.arg_utils import EngineManagerArgs +from llumnix.arg_utils import ManagerArgs logger = init_logger(__name__) @@ -41,15 +41,15 @@ def detect_unsupported_feature(engine_args: EngineArgs) -> None: if unsupported_feature: raise ValueError(f'Unsupported feature: Llumnix does not support "{unsupported_feature}" currently.') -def check_engine_args(engine_args: AsyncEngineArgs, engine_manager_args: EngineManagerArgs) -> None: +def check_engine_args(engine_args: AsyncEngineArgs, manager_args: ManagerArgs) -> None: assert engine_args.engine_use_ray and engine_args.worker_use_ray, \ ("In Llumnix, engine and worker must be ray actor.") - migration_config = engine_manager_args.create_migration_config() + migration_config = manager_args.create_migration_config() engine_config = engine_args.create_engine_config() parallel_config = engine_config.parallel_config if parallel_config.world_size > 1 and migration_config.migration_backend == 'nccl': - logger.info("Llumnix does not support TP or PP enabled model when the migration backend is nccl, change migration backend to gloo.") - engine_manager_args.migration_backend = 'gloo' + logger.warning("Llumnix does not support TP or PP when the migration backend is nccl, change migration backend to gloo.") + manager_args.migration_backend = 'gloo' detect_unsupported_feature(engine_args) def _get_dtype_size(dtype: torch.dtype) -> int: diff --git a/llumnix/backends/vllm/worker.py b/llumnix/backends/vllm/worker.py index d18c993f..fd7dcca6 100644 --- a/llumnix/backends/vllm/worker.py +++ b/llumnix/backends/vllm/worker.py @@ -60,16 +60,16 @@ def reserve_memory_for_migration(self, migration_config: MigrationConfig, model_ if migration_config.migration_backend == "nccl" and parallel_config.world_size == 1: device = torch.device(f"cuda:{self.local_rank}") _, total_memory = torch.cuda.mem_get_info(device) - migrate_ratio = math.ceil(dummy_cache_size / total_memory * 10000) / 10000 - cache_config.gpu_memory_utilization -= migrate_ratio + migration_memory_ratio = math.ceil(dummy_cache_size / total_memory * 10000) / 10000 + cache_config.gpu_memory_utilization -= migration_memory_ratio if cache_config.gpu_memory_utilization <= 0: raise ValueError("Nccl migration backend take {:.4f} gpu memory, which is greater than gpu_memory_utilization {:.4f}. " - "try to increase gpu-memory-utilization or reduce migration-cache-blocks." - .format(migrate_ratio, cache_config.gpu_memory_utilization)) + "try to increase gpu-memory-utilization or reduce migration-buffer-blocks." + .format(migration_memory_ratio, cache_config.gpu_memory_utilization)) logger.info("nccl migration backend take {:.4f} gpu memory, left gpu_memory_utilization {:.4f} for kv cache." 
- .format(migrate_ratio, cache_config.gpu_memory_utilization)) + .format(migration_memory_ratio, cache_config.gpu_memory_utilization)) return dummy_cache_size diff --git a/llumnix/config/default.py b/llumnix/config/default.py index a1f48043..ec6d060e 100644 --- a/llumnix/config/default.py +++ b/llumnix/config/default.py @@ -28,20 +28,24 @@ _C.SERVER.HOST = "localhost" # Port number for the server _C.SERVER.PORT = 8000 -# Path to SSL key file for secure connections +# Path to SSL key file _C.SERVER.SSL_KEYFILE = None -# Path to SSL certificate file for secure connections +# Path to SSL certificate file _C.SERVER.SSL_CERTFILE = None +# Log level for the server +_C.SERVER.LOG_LEVEL = "info" # Queue type for request output queue _C.SERVER.REQUEST_OUTPUT_QUEUE_TYPE = "rayqueue" -# Port number for the request output queue +# Port number for the zmq request output queue _C.SERVER.REQUEST_OUTPUT_QUEUE_PORT = 1234 # Disable logging requests in server _C.SERVER.DISABLE_LOG_REQUESTS_SERVER = False # Enable logging request timestamp _C.SERVER.LOG_REQUEST_TIMESTAMPS = False -# Config file of Llumnix arguments +# Path to config file of arguments _C.SERVER.CONFIG_FILE = None +# Disable keep serve process alive +_C.SERVER.DISABLE_KEEP_SERVE_PROCESS_ALIVE = False # ----------------------------------------------------------------------------- # RAY CONFIGURATION @@ -55,19 +59,22 @@ # MANAGER CONFIGURATION # ----------------------------------------------------------------------------- _C.MANAGER = LC() +# Number of instances created at initialization +_C.MANAGER.INITIAL_INSTANCES = 1 +# Time interval(s) to update instance info and pair migration +_C.MANAGER.POLLING_INTERVAL = 0.05 # Disable logging requests in manager _C.MANAGER.DISABLE_LOG_REQUESTS_MANAGER = False # Enable logging instance info _C.MANAGER.LOG_INSTANCE_INFO = False # Log filename _C.MANAGER.LOG_FILENAME = "server.log" -# Profiling result file path -_C.MANAGER.PROFILING_RESULT_FILE_PATH = "" - -# Number of instances created at initialization -_C.MANAGER.INITIAL_INSTANCES = 1 -# Time interval(s) to update instance info and pair migration -_C.MANAGER.POLLING_INTERVAL = 0.05 +# Enable simulator mode +_C.MANAGER.SIMULATOR_MODE = False +# Profiling result file path when using simulator +_C.MANAGER.PROFILING_RESULT_FILE_PATH = None +# Enable port increment when deploying multiple servers +_C.MANAGER.ENABLE_PORT_INCREMENT = False # ----------------------------------------------------------------------------- # DISPATCH CONFIGURATION @@ -76,14 +83,14 @@ _C.MANAGER.LOAD_METRIC = 'remaining_steps' # Request dispatch policy _C.MANAGER.DISPATCH_POLICY = 'load' -# Number of available dispatch instances. 
math.inf indicates that all instances can be used for dispatching -_C.MANAGER.NUM_DISPATCH_INSTANCES = math.inf # ----------------------------------------------------------------------------- # MIGRATION CONFIGURATION # ----------------------------------------------------------------------------- # Enable migrate requests between instances _C.MANAGER.ENABLE_MIGRATION = False +# Enable defragmentation through migration based on virtual usage +_C.MANAGER.ENABLE_DEFRAG = False # Pair migration frequency _C.MANAGER.PAIR_MIGRATION_FREQUENCY = 1 # Pair migration policy @@ -92,8 +99,6 @@ _C.MANAGER.MIGRATE_OUT_THRESHOLD = 3.0 # Request migration policy _C.MANAGER.REQUEST_MIGRATION_POLICY = 'SR' -# Enable defragmentation through migration based on virtual usage -_C.MANAGER.ENABLE_DEFRAG = False # Drop migration if the number of stages > max_stages _C.MANAGER.MAX_STAGES = 3 # If the number of remain blocks < last_stage_max_blocks, do last stage migration @@ -101,23 +106,23 @@ # Communication backend of migration _C.MANAGER.MIGRATION_BACKEND = "gloo" -# Transfer type for migration backend kvTransfer -_C.MANAGER.MIGRATION_BACKEND_TRANSFER_TYPE = "rdma" -# Address of grpc server for migration backend -_C.MANAGER.MIGRATION_BACKEND_SERVER_ADDRESS = "127.0.0.1:50051" -# URL of naming server for kvtransfer migration backend -_C.MANAGER.MIGRATION_BACKEND_KVTRANSFER_NAMING_URL = "file:/tmp/llumnix/naming/" -# Timeout(s) for initializing migration backend -_C.MANAGER.MIGRATION_BACKEND_INIT_TIMEOUT = 10.0 # Number of cache blocks in migration _C.MANAGER.MIGRATION_BUFFER_BLOCKS = 512 # Number of kv-cache layers to transfer in each round during migration _C.MANAGER.MIGRATION_NUM_LAYERS = 1 +# Timeout(s) for initializing migration backend +_C.MANAGER.MIGRATION_BACKEND_INIT_TIMEOUT = 10.0 +# Transfer type for migration backend kvTransfer +_C.MANAGER.MIGRATION_BACKEND_TRANSFER_TYPE = "rdma" +# Address of grpc server for migration backend +_C.MANAGER.GRPC_MIGRATION_BACKEND_SERVER_ADDRESS = "127.0.0.1:50051" +# URL of naming server for kvtransfer migration backend +_C.MANAGER.KVTRANSFER_MIGRATION_BACKEND_NAMING_URL = "file:/tmp/llumnix/naming/" # ----------------------------------------------------------------------------- # SCALING CONFIGURATION # ----------------------------------------------------------------------------- -# Enable scaling instances based on load +# Enable auto scaling _C.MANAGER.ENABLE_SCALING = False # Minimum number of instances _C.MANAGER.MIN_INSTANCES = 1 @@ -137,3 +142,5 @@ # ----------------------------------------------------------------------------- # Enable prefill decoding disaggregation _C.MANAGER.ENABLE_PD_DISAGG = False +# Number of available instances for dispatch. 
math.inf indicates that all instances can be used for dispatching +_C.MANAGER.NUM_DISPATCH_INSTANCES = math.inf diff --git a/llumnix/entrypoints/bladellm/api_server.py b/llumnix/entrypoints/bladellm/api_server.py index 836b8757..537798f5 100644 --- a/llumnix/entrypoints/bladellm/api_server.py +++ b/llumnix/entrypoints/bladellm/api_server.py @@ -14,35 +14,34 @@ import asyncio from blade_llm.service.args import ServingArgs -from llumnix.config import get_llumnix_config, LlumnixConfig +from llumnix.config import get_llumnix_config from llumnix.backends.backend_interface import BackendType -from llumnix.arg_utils import LlumnixEntrypointsArgs, EngineManagerArgs, LlumnixArgumentParser -from llumnix.entrypoints.setup import setup_ray_cluster, setup_llumnix, is_gpu_available +from llumnix.arg_utils import (EntrypointsArgs, ManagerArgs, LlumnixArgumentParser, + LaunchArgs) +from llumnix.entrypoints.setup import setup_ray_cluster, setup_llumnix from llumnix.entrypoints.bladellm.client import LlumnixClientBladeLLM -from llumnix.entrypoints.setup import LlumnixEntrypointsContext from llumnix.entrypoints.bladellm.utils import get_args +from llumnix.entrypoints.utils import EntrypointsContext, LaunchMode, is_gpu_available + def setup_llumnix_api_server(bladellm_args: ServingArgs, loop: asyncio.AbstractEventLoop): # generate llumnix_parser for checking parameters with choices - llumnix_parser: LlumnixArgumentParser = LlumnixArgumentParser() - llumnix_parser = LlumnixEntrypointsArgs.add_cli_args(llumnix_parser) - llumnix_parser = EngineManagerArgs.add_cli_args(llumnix_parser) - llumnix_config: LlumnixConfig = get_llumnix_config(bladellm_args.llumnix_config) - _, engine_manager_args, engine_args = get_args(llumnix_config, llumnix_parser, bladellm_args) + llumnix_parser = LlumnixArgumentParser() + llumnix_parser = EntrypointsArgs.add_cli_args(llumnix_parser) + llumnix_parser = ManagerArgs.add_cli_args(llumnix_parser) + llumnix_config = get_llumnix_config(bladellm_args.llumnix_config) + entrypoints_args, manager_args, engine_args = get_args(llumnix_config, llumnix_parser, bladellm_args) + + assert not manager_args.simulator_mode, "Only support the simulator mode for vLLM." 
+ launch_args = LaunchArgs(launch_mode=LaunchMode.LOCAL, backend_type=BackendType.BLADELLM) - setup_ray_cluster(llumnix_config) + setup_ray_cluster(entrypoints_args) - llm_client = None + llumnix_client = None # if gpu is not available, it means that this node is head pod without any llumnix components if is_gpu_available(): - world_size = engine_args.tensor_parallel_size * engine_args.pipeline_parallel_size - instance_ids = None - if engine_args.enable_disagg: - instance_ids = [engine_args.disagg_options.inst_id] - - llumnix_context: LlumnixEntrypointsContext = \ - setup_llumnix(engine_manager_args, engine_args, llumnix_config, BackendType.BLADELLM, - world_size, instance_ids=instance_ids) - llm_client = LlumnixClientBladeLLM(bladellm_args, llumnix_context, loop) + llumnix_context: EntrypointsContext = \ + setup_llumnix(manager_args, entrypoints_args, engine_args, launch_args) + llumnix_client = LlumnixClientBladeLLM(bladellm_args, llumnix_context, loop) - return llm_client + return llumnix_client diff --git a/llumnix/entrypoints/bladellm/client.py b/llumnix/entrypoints/bladellm/client.py index 67a40af6..3eadd8fd 100644 --- a/llumnix/entrypoints/bladellm/client.py +++ b/llumnix/entrypoints/bladellm/client.py @@ -28,15 +28,18 @@ from blade_llm.service.communications.response import error_resp from llumnix.server_info import RequestTimestamps -from llumnix.entrypoints.setup import LlumnixEntrypointsContext +from llumnix.entrypoints.setup import EntrypointsContext from llumnix.logger import init_logger logger = init_logger(__name__) WAIT_MANAGER_INTERVAL = 5 +# TODO(KuilongCui): Update LlumnixClient of BladeLLM. + + class LlumnixClientBladeLLM(MultiProcessingLLMClient): - def __init__(self, args: ServingArgs, llumnix_context: LlumnixEntrypointsContext, loop: asyncio.AbstractEventLoop): + def __init__(self, args: ServingArgs, llumnix_context: EntrypointsContext, loop: asyncio.AbstractEventLoop): super().__init__(args, -1) self.entrypoint_id2llumnix_id = {} self.llumnix_id2entrypoint_id = {} @@ -56,7 +59,7 @@ async def background_process_outputs(self): continue await self.request_streams[request_id].put(request_output) if request_output.is_finished: - logger.info("Client Recv: {}".format(request_output)) + logger.debug("client recv request output: {}".format(request_output)) del self.entrypoint_id2llumnix_id[self.llumnix_id2entrypoint_id[request_id]] del self.llumnix_id2entrypoint_id[request_id] del self.request_streams[request_id] @@ -110,7 +113,7 @@ async def _manager_generate(self, request, request_id: str) -> LLMResponse: return await asyncio.create_task(self._manager_generate(request, request_id)) except (ray.exceptions.RayActorError, KeyError): if instance_id in self.llumnix_context.instances: - logger.info("[manager_generate] instance {} is dead".format(instance_id)) + logger.info("[_manager_generate] instance {} is dead".format(instance_id)) del self.llumnix_context.instances[instance_id] del self.llumnix_context.instance_num_requests[instance_id] return await asyncio.create_task(self._manager_generate(request, request_id)) diff --git a/llumnix/entrypoints/bladellm/utils.py b/llumnix/entrypoints/bladellm/utils.py index 3b9f8d14..3fa94cd6 100644 --- a/llumnix/entrypoints/bladellm/utils.py +++ b/llumnix/entrypoints/bladellm/utils.py @@ -15,7 +15,7 @@ from loguru import logger from blade_llm.service.args import ServingArgs -from llumnix.arg_utils import LlumnixEntrypointsArgs, EngineManagerArgs +from llumnix.arg_utils import EntrypointsArgs, ManagerArgs def
detect_unsupported_feature(engine_args: ServingArgs) -> None: unsupported_feature = None @@ -31,24 +31,24 @@ def detect_unsupported_feature(engine_args: ServingArgs) -> None: if unsupported_feature: raise ValueError(f'Llumnix does not support "{unsupported_feature}" for bladeLLM currently.') -def check_engine_args(engine_args: ServingArgs, engine_manager_args: EngineManagerArgs) -> None: - migration_config = engine_manager_args.create_migration_config() +def check_engine_args(engine_args: ServingArgs, manager_args: ManagerArgs) -> None: + migration_config = manager_args.create_migration_config() if (engine_args.tensor_parallel_size > 1 or engine_args.tensor_parallel_size > 1) and \ migration_config.migration_backend == 'nccl': - logger.info("Llumnix does not support TP or PP enabled model when the migration backend is nccl, \ - change migration backend to gloo.") - engine_manager_args.migration_backend = 'gloo' + logger.warning("Llumnix does not support TP or PP when the migration backend is nccl, \ + change migration backend to gloo.") + manager_args.migration_backend = 'gloo' detect_unsupported_feature(engine_args) def get_args(llumnix_cfg, llumnix_parser, engine_args): - llumnix_entrypoints_args = LlumnixEntrypointsArgs.from_llumnix_config(llumnix_cfg) - LlumnixEntrypointsArgs.check_args(llumnix_entrypoints_args, llumnix_parser) - engine_manager_args = EngineManagerArgs.from_llumnix_config(llumnix_cfg) - EngineManagerArgs.check_args(engine_manager_args, llumnix_parser) - check_engine_args(engine_args, engine_manager_args) - - logger.info("llumnix_entrypoints_args: {}", llumnix_entrypoints_args) - logger.info("engine_manager_args: {}", engine_manager_args) + entrypoints_args = EntrypointsArgs.from_llumnix_config(llumnix_cfg) + EntrypointsArgs.check_args(entrypoints_args, llumnix_parser) + manager_args = ManagerArgs.from_llumnix_config(llumnix_cfg) + ManagerArgs.check_args(manager_args, llumnix_parser) + check_engine_args(engine_args, manager_args) + + logger.info("entrypoints_args: {}", entrypoints_args) + logger.info("manager_args: {}", manager_args) logger.info("engine_args: {}", engine_args) - return llumnix_entrypoints_args, engine_manager_args, engine_args + return entrypoints_args, manager_args, engine_args diff --git a/llumnix/entrypoints/setup.py b/llumnix/entrypoints/setup.py index 378f1205..16f2a5f3 100644 --- a/llumnix/entrypoints/setup.py +++ b/llumnix/entrypoints/setup.py @@ -15,49 +15,27 @@ import sys import os import time -from typing import Dict -import asyncio -import socket +from typing import Dict, Optional, List, Tuple import ray -from llumnix.llm_engine_manager import LLMEngineManager, MANAGER_ACTOR_NAME +from llumnix.manager import Manager from llumnix.llumlet.llumlet import Llumlet from llumnix.logger import init_logger -from llumnix.utils import random_uuid -from llumnix.arg_utils import EngineManagerArgs +from llumnix.utils import random_uuid, get_manager_name +from llumnix.arg_utils import ManagerArgs, EntrypointsArgs, LaunchArgs from llumnix.queue.queue_type import QueueType -from llumnix.server_info import ServerInfo, RequestTimestamps +from llumnix.server_info import ServerInfo from llumnix.queue.utils import init_request_output_queue_server +from llumnix.entrypoints.utils import EntrypointsContext, get_ip_address, retry_manager_method_sync +from llumnix.entrypoints.utils import LaunchMode +from llumnix.backends.backend_interface import BackendType from llumnix.queue.queue_server_base import QueueServerBase +MAX_RAY_RESTARTS = 5 +RAY_RESTART_INTERVALS = 10 + 
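The `check_engine_args` change above keeps the existing fallback behavior: when tensor or pipeline parallelism is enabled and the `nccl` migration backend was requested, the configuration drops to `gloo` with a warning. Below is a minimal, self-contained sketch of that rule; the function and argument names are illustrative rather than the BladeLLM `ServingArgs` fields, and the sketch checks both TP and PP sizes, matching the intent of the warning message.

```python
# Sketch of the nccl -> gloo fallback used when TP/PP is enabled.
import logging

logger = logging.getLogger(__name__)

def resolve_migration_backend(tensor_parallel_size: int,
                              pipeline_parallel_size: int,
                              migration_backend: str) -> str:
    """Fall back to 'gloo' when TP/PP is enabled and 'nccl' was requested."""
    if (tensor_parallel_size > 1 or pipeline_parallel_size > 1) and migration_backend == "nccl":
        logger.warning("TP or PP is enabled; falling back from 'nccl' to 'gloo' migration backend.")
        return "gloo"
    return migration_backend

# Quick self-checks of the rule.
assert resolve_migration_backend(2, 1, "nccl") == "gloo"
assert resolve_migration_backend(1, 1, "nccl") == "nccl"
assert resolve_migration_backend(1, 4, "gloo") == "gloo"
```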
logger = init_logger(__name__) -MAX_RESTARTS = 30 -RESTART_INTERVALS = 1 -MAX_TASK_RETRIES = 300 -RETRIES_INTERVALS = 0.1 - - -class LlumnixEntrypointsContext: - def __init__(self, - manager: LLMEngineManager, - instances: Dict[str, Llumlet], - request_output_queue: QueueServerBase, - server_info: ServerInfo, - log_requests: bool, - log_request_timestamps: bool): - self.manager = manager - self.instances = instances - self.request_output_queue = request_output_queue - self.server_info = server_info - self.log_requests = log_requests - self.log_request_timestamps = log_request_timestamps - - -def get_ip_address(): - hostname = socket.gethostname() - ip_address = socket.gethostbyname(hostname) - return ip_address def launch_ray_cluster(port: int) -> subprocess.CompletedProcess: head_node_ip = os.getenv('HEAD_NODE_IP') @@ -66,11 +44,11 @@ def launch_ray_cluster(port: int) -> subprocess.CompletedProcess: # Stop the existing ray processes on the node first. subprocess.run(['ray', 'stop'], check=True, text=True, capture_output=True) except subprocess.CalledProcessError as e: - logger.info("'ray stop' failed with: \n{}".format(e.stderr)) + logger.error("'ray stop' failed with: \n{}".format(e.stderr)) sys.exit(1) # Need to specify the head node ip through environment variable currently. if head_node_ip is None: - logger.info("Environment variable 'HEAD_NODE_IP' should be set for ray cluster launch.") + logger.error("Environment variable 'HEAD_NODE_IP' should be set for ray cluster launch.") sys.exit(1) ray_start_command = None if 'HEAD_NODE' in os.environ: @@ -78,160 +56,108 @@ def launch_ray_cluster(port: int) -> subprocess.CompletedProcess: try: result = subprocess.run(['ray', 'start', '--head', f'--port={port}'], check=True, text=True, capture_output=True) except subprocess.CalledProcessError as e: - logger.info("'{}' failed with: \n{}".format(ray_start_command, e.stderr)) + logger.error("'{}' failed with: \n{}".format(ray_start_command, e.stderr)) sys.exit(1) else: ray_start_command = f"ray start --address={head_node_ip}:{port} --node-ip-address={node_ip_address}" - for attempt in range(MAX_RESTARTS): + for attempt in range(MAX_RAY_RESTARTS): try: # wait about 2 mins by default result = subprocess.run(['ray', 'start', f'--address={head_node_ip}:{port}'], check=True, text=True, capture_output=True) break except subprocess.CalledProcessError as e: - if attempt < MAX_RESTARTS: - print("Execute '{}' repeatedly until the head node starts...".format(ray_start_command)) - time.sleep(RESTART_INTERVALS) + if attempt < MAX_RAY_RESTARTS: + logger.warning("execute '{}' repeatedly until the head node starts".format(ray_start_command)) + time.sleep(RAY_RESTART_INTERVALS) else: - logger.info("'{}' failed after {} attempts with: \n{}".format(ray_start_command, attempt, e.stderr)) + logger.error("'{}' failed after {} attempts with: \n{}".format(ray_start_command, attempt, e.stderr)) sys.exit(1) logger.info("'{}' succeeed with: \n{}".format(ray_start_command, result.stdout)) return result -def connect_to_ray_cluster(port: int, namespace="llumnix") -> None: - head_node_ip = os.getenv('HEAD_NODE_IP') - ray.init(address=f"{head_node_ip}:{port}", ignore_reinit_error=True, namespace=namespace) - -def setup_ray_cluster(cfg): - if cfg.SERVER.LAUNCH_RAY_CLUSTER: - launch_ray_cluster(cfg.SERVER.RAY_CLUSTER_PORT) - connect_to_ray_cluster(port=cfg.SERVER.RAY_CLUSTER_PORT) - -def is_gpu_available() -> bool: - try: - subprocess.check_output(['nvidia-smi']) - return True - except (subprocess.CalledProcessError, 
FileNotFoundError): - return False - -def retry_manager_method_sync(ray_call, method_name, *args, **kwargs): - for attempt in range(MAX_TASK_RETRIES): - try: - ret = ray.get(ray_call(*args, **kwargs)) - break - except ray.exceptions.RayActorError: - if attempt < MAX_TASK_RETRIES - 1: - logger.info("Manager is unavailable, sleep {}s, and retry {} again...".format(RETRIES_INTERVALS, method_name)) - time.sleep(RETRIES_INTERVALS) - else: - logger.info("After {} times retries, manager is still unavailable".format(MAX_TASK_RETRIES)) - raise - return ret - -async def retry_manager_method_async(ray_call, method_name, *args, **kwargs): - for attempt in range(MAX_TASK_RETRIES): - try: - ret = await ray_call(*args, **kwargs) - break - except ray.exceptions.RayActorError: - if attempt < MAX_TASK_RETRIES - 1: - logger.info("Manager is unavailable, sleep {}s, and retry {} again...".format(RETRIES_INTERVALS, method_name)) - await asyncio.sleep(RETRIES_INTERVALS) - else: - logger.info("After {} times retries, manager is still unavailable".format(MAX_TASK_RETRIES)) - raise - return ret - -def init_manager(engine_manager_args: EngineManagerArgs) -> LLMEngineManager: +def connect_to_ray_cluster(head_node_ip: str = None, port: int = None, namespace="llumnix") -> None: + if head_node_ip is not None and port is not None: + ray.init(address=f"{head_node_ip}:{port}", ignore_reinit_error=True, namespace=namespace) + else: + ray.init(ignore_reinit_error=True, namespace=namespace) + +def setup_ray_cluster(entrypoints_args) -> None: + if entrypoints_args.launch_ray_cluster: + launch_ray_cluster(entrypoints_args.ray_cluster_port) + connect_to_ray_cluster(head_node_ip=os.getenv('HEAD_NODE_IP'), port=entrypoints_args.ray_cluster_port, namespace="llumnix") + +def init_manager(manager_args: ManagerArgs, + entrypoints_args: EntrypointsArgs = None, + engine_args = None, + launch_args: LaunchArgs = None, + ) -> Manager: # Only one instance create the manager actor, the other instances get the existing manager actor through ray. 
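The comment above describes the usual get-or-create pattern for a named, detached Ray actor, which the `try`/`except ValueError` below implements: the first caller creates the actor, and any caller that loses the race catches the `ValueError` raised for a duplicate name and fetches the existing handle instead. A standalone sketch of the pattern (the `SingletonActor` class and names are placeholders, not the llumnix `Manager`):

```python
# Get-or-create a named, detached Ray actor shared by multiple processes.
import ray

@ray.remote
class SingletonActor:
    def ping(self) -> str:
        return "pong"

def get_or_create_singleton(name: str = "manager", namespace: str = "llumnix"):
    try:
        # First caller creates the detached, named actor.
        actor = SingletonActor.options(
            name=name, namespace=namespace, lifetime="detached").remote()
    except ValueError:
        # Another process already created an actor with this name; reuse it.
        actor = ray.get_actor(name, namespace=namespace)
    return actor

if __name__ == "__main__":
    ray.init(namespace="llumnix", ignore_reinit_error=True)
    handle = get_or_create_singleton()
    print(ray.get(handle.ping.remote()))
```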
try: - manager = LLMEngineManager.from_args(engine_manager_args, None) - logger.info("Init LLMEngineManager on current node") + manager = Manager.from_args(manager_args=manager_args, + entrypoints_args=entrypoints_args, + engine_args=engine_args, + launch_args=launch_args) + logger.info("Init Manager on current node.") except ValueError: - manager = ray.get_actor(MANAGER_ACTOR_NAME, namespace='llumnix') - logger.info("Get existing LLMEngineManager") + manager = ray.get_actor(get_manager_name(), namespace='llumnix') + logger.info("Get existing Manager.") return manager -def init_llumnix_components(engine_manager_args: EngineManagerArgs, +def init_llumnix_components(manager_args: ManagerArgs, engine_args, request_output_queue_type: QueueType, - ip: str, request_output_queue_port: str, - *args, - **kwargs - ): - manager = init_manager(engine_manager_args) - instance_ids, llumlets = retry_manager_method_sync( - manager.init_llumlets.remote, 'init_llumlets', engine_args, request_output_queue_type, *args, **kwargs) - - available_instance_ids = [] - dead_instance_ids = [] - available_llumlets = [] - ready_tasks = [llumlet.is_ready.remote() for llumlet in llumlets] - for idx, task in enumerate(ready_tasks): - try: - ray.get(task) - available_instance_ids.append(instance_ids[idx]) - available_llumlets.append(llumlets[idx]) - except ray.exceptions.RayActorError: - dead_instance_ids.append(instance_ids[idx]) - if len(dead_instance_ids) > 0: - retry_manager_method_sync(manager.scale_down.remote, 'scale_down', dead_instance_ids) - if len(available_instance_ids) > 0: - retry_manager_method_sync(manager.scale_up.remote, 'scale_up', - available_instance_ids, available_llumlets) - logger.info("Init Llumnix components done, {} instances are ready, instance_ids: {}." - .format(len(available_instance_ids), available_instance_ids)) + backend_type: BackendType) -> Tuple[Manager, List[str], List[Llumlet], QueueServerBase]: + manager = init_manager(manager_args) + + instance_ids, instances = retry_manager_method_sync( + manager.init_instances.remote, 'init_instances', request_output_queue_type, backend_type, engine_args) + ip = get_ip_address() request_output_queue = init_request_output_queue_server(ip, request_output_queue_port, request_output_queue_type) - return manager, available_instance_ids, available_llumlets, request_output_queue + return manager, instance_ids, instances, request_output_queue + +def setup_entrypoints_context(entrypoints_args, manager, instance_ids, instances, request_output_queue) -> EntrypointsContext: + instances_dict: Dict[str, Llumlet] = {} + for idx, ins_id in enumerate(instance_ids): + instances_dict[ins_id] = instances[idx] -def setup_llumnix(engine_manager_args, engine_args, cfg, *args, **kwargs): - ip = get_ip_address() - manager, instance_ids, llumlets, request_output_queue = \ - init_llumnix_components(engine_manager_args, - engine_args, - cfg.SERVER.REQUEST_OUTPUT_QUEUE_TYPE, - ip, - cfg.SERVER.REQUEST_OUTPUT_QUEUE_PORT, - *args, - **kwargs) server_id = random_uuid() + ip = get_ip_address() server_info = ServerInfo(server_id, - cfg.SERVER.REQUEST_OUTPUT_QUEUE_TYPE, + QueueType(entrypoints_args.request_output_queue_type), request_output_queue, ip, - cfg.SERVER.REQUEST_OUTPUT_QUEUE_PORT) - instances: Dict[str, Llumlet] = {} - for idx, ins_id in enumerate(instance_ids): - instances[ins_id] = llumlets[idx] + entrypoints_args.request_output_queue_port) - log_requests = not cfg.SERVER.DISABLE_LOG_REQUESTS_SERVER - log_request_timestamps = cfg.SERVER.LOG_REQUEST_TIMESTAMPS + 
log_requests = not entrypoints_args.disable_log_requests_server + log_request_timestamps = entrypoints_args.log_request_timestamps logger.info("log_requests: {}, log_request_timestamps: {}".format(log_requests, log_request_timestamps)) - llumnix_entrypoints_context = LlumnixEntrypointsContext(manager, - instances, - request_output_queue, - server_info, - log_requests, - log_request_timestamps) - - return llumnix_entrypoints_context - -def init_per_token_latency_breakdown_dict() -> Dict[str, int]: - per_token_latency_breakdown_dict = { - 'step_latency_engine': [], - 'process_model_outputs_latency': [], - 'step_postprocess_latency': [], - 'across_async_put_queue_thread_latency': [], - 'across_async_put_queue_actor_latency': [], - 'queue_rpc_latency': [], - 'background_process_get_queue_latency': [], - 'generate_benchmark_return_output_latency': [] - } - return per_token_latency_breakdown_dict - -def record_per_token_latency_breakdown(per_token_latency_breakdown_dict: Dict[str, int], request_timestamps: RequestTimestamps): - for key in per_token_latency_breakdown_dict.keys(): - per_token_latency_breakdown_dict[key].append(getattr(request_timestamps, key)) + entrypoints_context = EntrypointsContext(manager, + instances_dict, + request_output_queue, + server_info, + log_requests, + log_request_timestamps) + + return entrypoints_context +def _setup_llumnix_local(manager_args, entrypoints_args, engine_args, launch_args) -> EntrypointsContext: + manager, instance_ids, instances, request_output_queue = \ + init_llumnix_components(manager_args, + engine_args, + QueueType(entrypoints_args.request_output_queue_type), + entrypoints_args.request_output_queue_port, + launch_args.backend_type) + + return setup_entrypoints_context(entrypoints_args, manager, instance_ids, instances, request_output_queue) + +def _setup_llumnix_global(manager_args, entrypoints_args, engine_args, launch_args) -> None: + _ = init_manager(manager_args, entrypoints_args, engine_args, launch_args) + +def setup_llumnix(manager_args, entrypoints_args, engine_args, launch_args) -> Optional[EntrypointsContext]: + if launch_args.launch_mode == LaunchMode.LOCAL: + return _setup_llumnix_local(manager_args, entrypoints_args, engine_args, launch_args) + + return _setup_llumnix_global(manager_args, entrypoints_args, engine_args, launch_args) diff --git a/llumnix/entrypoints/utils.py b/llumnix/entrypoints/utils.py new file mode 100644 index 00000000..31c3fa28 --- /dev/null +++ b/llumnix/entrypoints/utils.py @@ -0,0 +1,92 @@ +import socket +from enum import Enum +from typing import Dict +import subprocess +import asyncio +import time +import ray + +from llumnix.logger import init_logger + +MAX_TASK_RETRIES = 300 +RETRIES_INTERVALS = 0.1 + +logger = init_logger(__name__) + + +class LaunchMode(str, Enum): + LOCAL = "LOCAL" + GLOBAL = "GLOBAL" + +# Use "" type hint to avoid circular import. 
+class EntrypointsContext: + def __init__(self, + manager: "Manager", + instances: Dict[str, "Llumlet"], + request_output_queue: "QueueServerBase", + server_info: "ServerInfo", + log_requests: bool, + log_request_timestamps: bool): + self.manager = manager + self.instances = instances + self.request_output_queue = request_output_queue + self.server_info = server_info + self.log_requests = log_requests + self.log_request_timestamps = log_request_timestamps + +def get_ip_address(): + hostname = socket.gethostname() + ip_address = socket.gethostbyname(hostname) + return ip_address + +def is_gpu_available() -> bool: + try: + subprocess.check_output(['nvidia-smi']) + return True + except (subprocess.CalledProcessError, FileNotFoundError): + return False + +def retry_manager_method_sync(ray_call, method_name, *args, **kwargs): + for attempt in range(MAX_TASK_RETRIES): + try: + ret = ray.get(ray_call(*args, **kwargs)) + break + except ray.exceptions.RayActorError: + if attempt < MAX_TASK_RETRIES - 1: + logger.warning("manager is unavailable, sleep {}s, and retry {} again".format(RETRIES_INTERVALS, method_name)) + time.sleep(RETRIES_INTERVALS) + else: + logger.error("manager is still unavailable after {} times retries".format(MAX_TASK_RETRIES)) + raise + return ret + +async def retry_manager_method_async(ray_call, method_name, *args, **kwargs): + for attempt in range(MAX_TASK_RETRIES): + try: + ret = await ray_call(*args, **kwargs) + break + except ray.exceptions.RayActorError: + if attempt < MAX_TASK_RETRIES - 1: + logger.warning("manager is unavailable, sleep {}s, and retry {} again".format(RETRIES_INTERVALS, method_name)) + await asyncio.sleep(RETRIES_INTERVALS) + else: + logger.error("manager is still unavailable after {} times retries".format(MAX_TASK_RETRIES)) + raise + return ret + +def init_per_token_latency_breakdown_dict() -> Dict[str, int]: + per_token_latency_breakdown_dict = { + 'step_latency_engine': [], + 'process_model_outputs_latency': [], + 'step_postprocess_latency': [], + 'across_async_put_queue_thread_latency': [], + 'across_async_put_queue_actor_latency': [], + 'queue_rpc_latency': [], + 'background_process_get_queue_latency': [], + 'generate_benchmark_return_output_latency': [] + } + return per_token_latency_breakdown_dict + +def record_per_token_latency_breakdown(per_token_latency_breakdown_dict: Dict[str, int], request_timestamps: "RequestTimestamps"): + for key in per_token_latency_breakdown_dict.keys(): + per_token_latency_breakdown_dict[key].append(getattr(request_timestamps, key)) diff --git a/llumnix/entrypoints/vllm/api_server.py b/llumnix/entrypoints/vllm/api_server.py index 46cbf842..a1e1b955 100644 --- a/llumnix/entrypoints/vllm/api_server.py +++ b/llumnix/entrypoints/vllm/api_server.py @@ -22,19 +22,16 @@ from vllm.sampling_params import SamplingParams -from llumnix.arg_utils import LlumnixArgumentParser -from llumnix.entrypoints.setup import (setup_ray_cluster, - setup_llumnix, - is_gpu_available, - init_per_token_latency_breakdown_dict, - record_per_token_latency_breakdown) -from llumnix.entrypoints.vllm.arg_utils import (add_cli_args, - get_args) +from llumnix.arg_utils import LlumnixArgumentParser, LaunchArgs +from llumnix.entrypoints.setup import setup_ray_cluster, setup_llumnix +from llumnix.entrypoints.utils import init_per_token_latency_breakdown_dict, record_per_token_latency_breakdown +from llumnix.entrypoints.vllm.arg_utils import add_cli_args, get_args from llumnix.entrypoints.vllm.client import LlumnixClientVLLM from llumnix.logger import init_logger 
from llumnix.utils import random_uuid -from llumnix.config import get_llumnix_config, LlumnixConfig +from llumnix.config import get_llumnix_config from llumnix.backends.backend_interface import BackendType +from llumnix.entrypoints.utils import LaunchMode, is_gpu_available # Code file with __main__ should set the logger name to inherit the llumnix logger configuration. logger = init_logger("llumnix.entrypoints.vllm.api_server") @@ -147,8 +144,8 @@ async def generate_benchmark(request: Request) -> Response: if llumnix_client.log_requests: llumnix_client.num_finished_requests += 1 - logger.info("entrypoints finished request {}.".format(request_id)) - logger.info("num_finished_requests {}.".format(llumnix_client.num_finished_requests)) + logger.info("entrypoints finished request {}".format(request_id)) + logger.info("num_finished_requests {}".format(llumnix_client.num_finished_requests)) generation = final_output.outputs[0].text num_output_tokens = len(final_output.outputs[0].token_ids) @@ -168,7 +165,7 @@ async def generate_benchmark(request: Request) -> Response: @app.get("/is_ready") -async def is_ready(): +async def is_ready() -> bool: return await llumnix_client.is_ready() @@ -179,27 +176,29 @@ async def is_ready(): parser.add_argument("--port", type=int) parser.add_argument("--ssl-keyfile", type=str) parser.add_argument("--ssl-certfile", type=str) + parser.add_argument("--log-level", type=str, choices=["debug", "info", "warning", "error"]) cli_args = add_cli_args(parser) - cfg: LlumnixConfig = get_llumnix_config(cli_args.config_file, cli_args) - _, engine_manager_args, engine_args = get_args(cfg, parser, cli_args) + cfg = get_llumnix_config(cli_args.config_file, cli_args) + entrypoints_args, manager_args, engine_args = get_args(cfg, parser, cli_args) + + backend_type = BackendType.VLLM if not manager_args.simulator_mode else BackendType.SIM_VLLM + launch_args = LaunchArgs(launch_mode=LaunchMode.LOCAL, backend_type=backend_type) # Launch or connect to the ray cluster for multi-node serving. - setup_ray_cluster(cfg) + setup_ray_cluster(entrypoints_args) # if gpu is not available, it means that this node is head pod without any llumnix components if is_gpu_available(): - engine_config = engine_args.create_engine_config() - parallel_config = engine_config.parallel_config - llumnix_entrypoints_context = setup_llumnix(engine_manager_args, engine_args, cfg, BackendType.VLLM, parallel_config.world_size) - llumnix_client = LlumnixClientVLLM(llumnix_entrypoints_context) + entrypoints_context = setup_llumnix(manager_args, entrypoints_args, engine_args, launch_args) + llumnix_client = LlumnixClientVLLM(entrypoints_context) # Start the api server after all the components of llumnix are ready. 
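The launch below is plain uvicorn, with host, port, log level, and SSL options taken directly from the parsed entrypoint args. For reference, the smallest self-contained version of that serving pattern looks roughly like this; the app and endpoint are placeholders, not the llumnix api server.

```python
# Minimal FastAPI + uvicorn serving loop mirroring the launch pattern below.
from fastapi import FastAPI
import uvicorn

app = FastAPI()

@app.get("/is_ready")
async def is_ready() -> bool:
    return True

if __name__ == "__main__":
    uvicorn.run(app,
                host="0.0.0.0",
                port=8000,
                log_level="info",
                timeout_keep_alive=5)
```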
- logger.info("Start Api Server on '{}:{}'".format(cfg.SERVER.HOST, cfg.SERVER.PORT)) + logger.info("Start api server on '{}:{}'.".format(entrypoints_args.host, entrypoints_args.port)) uvicorn.run(app, - host=cfg.SERVER.HOST, - port=cfg.SERVER.PORT, - log_level="debug", + host=entrypoints_args.host, + port=entrypoints_args.port, + log_level=entrypoints_args.log_level, timeout_keep_alive=TIMEOUT_KEEP_ALIVE, - ssl_keyfile=cfg.SERVER.SSL_KEYFILE, - ssl_certfile=cfg.SERVER.SSL_CERTFILE) + ssl_keyfile=entrypoints_args.ssl_keyfile, + ssl_certfile=entrypoints_args.ssl_certfile) diff --git a/llumnix/entrypoints/vllm/api_server_actor.py b/llumnix/entrypoints/vllm/api_server_actor.py new file mode 100644 index 00000000..e2bf0fbe --- /dev/null +++ b/llumnix/entrypoints/vllm/api_server_actor.py @@ -0,0 +1,92 @@ +import threading +import traceback +import uvicorn + +import ray +from ray.util.placement_group import PlacementGroup +from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy + +from llumnix.arg_utils import EntrypointsArgs +from llumnix.entrypoints.utils import EntrypointsContext, get_ip_address +from llumnix.llumlet.llumlet import Llumlet +from llumnix.queue.utils import init_request_output_queue_server, QueueType +from llumnix.logger import init_logger + +logger = init_logger(__name__) + + +class FastAPIServerActor: + def __init__(self, entrypoints_args: EntrypointsArgs): + self.entrypoints_args = entrypoints_args + self.request_output_queue_port = self.entrypoints_args.request_output_queue_port + self.request_output_queue_type = QueueType(self.entrypoints_args.request_output_queue_type) + ip = get_ip_address() + self.request_output_queue = init_request_output_queue_server( + ip, self.request_output_queue_port, self.request_output_queue_type) + + def _setup_entrypoints_context(self, + manager: "ray.actor.ActorHandle", + instance_id: str, + instance: Llumlet): + # avoid circular import + # pylint: disable=import-outside-toplevel + from llumnix.entrypoints.setup import setup_entrypoints_context + self.entrypoints_context = setup_entrypoints_context( + self.entrypoints_args,manager, [instance_id], [instance], self.request_output_queue) + + def _run_uvicorn_server(self, + entrypoints_args: EntrypointsArgs, + entrypoints_context: EntrypointsContext): + # pylint: disable=import-outside-toplevel + import llumnix.entrypoints.vllm.api_server + from llumnix.entrypoints.vllm.client import LlumnixClientVLLM + llumnix.entrypoints.vllm.api_server.llumnix_client = LlumnixClientVLLM(entrypoints_context) + app = llumnix.entrypoints.vllm.api_server.app + + logger.info("Start api server on '{}:{}'.".format(entrypoints_args.host, entrypoints_args.port)) + uvicorn.run(app, + host=entrypoints_args.host, + port=entrypoints_args.port, + log_level=entrypoints_args.log_level, + timeout_keep_alive=llumnix.entrypoints.vllm.api_server.TIMEOUT_KEEP_ALIVE, + ssl_keyfile=entrypoints_args.ssl_keyfile, + ssl_certfile=entrypoints_args.ssl_certfile) + + def run(self, + manager: "ray.actor.ActorHandle", + instance_id: str, + instance: Llumlet): + self._setup_entrypoints_context(manager, instance_id, instance) + self.run_uvicorn_server_thread = threading.Thread( + target=self._run_uvicorn_server, args=(self.entrypoints_args, self.entrypoints_context), + daemon=True, name="run_uvicorn_server" + ) + self.run_uvicorn_server_thread.start() + + @classmethod + def from_args(cls, + server_name: str, + placement_group: PlacementGroup, + entrypoints_args: EntrypointsArgs): + try: + fastapi_server_class = 
ray.remote(num_cpus=1, + name=server_name, + namespace="llumnix", + lifetime="detached")(cls).options( + scheduling_strategy=PlacementGroupSchedulingStrategy( + placement_group=placement_group, + placement_group_bundle_index=0, + placement_group_capture_child_tasks=True + ) + ) + fastapi_server = fastapi_server_class.remote(entrypoints_args) + # pylint: disable=broad-except + except Exception as e: + logger.error("failed to initialize FastAPIServer: {}".format(e)) + logger.error("exception traceback: {}".format(traceback.format_exc())) + raise + + return fastapi_server + + def is_ready(self) -> bool: + return True diff --git a/llumnix/entrypoints/vllm/arg_utils.py b/llumnix/entrypoints/vllm/arg_utils.py index bb7daacd..6329b227 100644 --- a/llumnix/entrypoints/vllm/arg_utils.py +++ b/llumnix/entrypoints/vllm/arg_utils.py @@ -1,7 +1,7 @@ from vllm.engine.arg_utils import AsyncEngineArgs from llumnix.backends.vllm.utils import check_engine_args -from llumnix.arg_utils import LlumnixEntrypointsArgs, EngineManagerArgs +from llumnix.arg_utils import EntrypointsArgs, ManagerArgs from llumnix.logger import init_logger logger = init_logger(__name__) @@ -9,23 +9,23 @@ def add_cli_args(parser): parser.set_namespace("llumnix") - parser = LlumnixEntrypointsArgs.add_cli_args(parser) - parser = EngineManagerArgs.add_cli_args(parser) + parser = EntrypointsArgs.add_cli_args(parser) + parser = ManagerArgs.add_cli_args(parser) parser.set_namespace("vllm") parser = AsyncEngineArgs.add_cli_args(parser) cli_args = parser.parse_args() return cli_args def get_args(cfg, parser, cli_args): - llumnix_entrypoints_args = LlumnixEntrypointsArgs.from_llumnix_config(cfg) - LlumnixEntrypointsArgs.check_args(llumnix_entrypoints_args, parser) - engine_manager_args = EngineManagerArgs.from_llumnix_config(cfg) - EngineManagerArgs.check_args(engine_manager_args, parser) + entrypoints_args = EntrypointsArgs.from_llumnix_config(cfg) + EntrypointsArgs.check_args(entrypoints_args, parser) + manager_args = ManagerArgs.from_llumnix_config(cfg) + ManagerArgs.check_args(manager_args, parser) engine_args = AsyncEngineArgs.from_cli_args(cli_args) - check_engine_args(engine_args, engine_manager_args) + check_engine_args(engine_args, manager_args) - logger.info("llumnix_entrypoints_args: {}".format(llumnix_entrypoints_args)) - logger.info("engine_manager_args: {}".format(engine_manager_args)) + logger.info("entrypoints_args: {}".format(entrypoints_args)) + logger.info("manager_args: {}".format(manager_args)) logger.info("engine_args: {}".format(engine_args)) - return llumnix_entrypoints_args, engine_manager_args, engine_args + return entrypoints_args, manager_args, engine_args diff --git a/llumnix/entrypoints/vllm/client.py b/llumnix/entrypoints/vllm/client.py index d72cfc79..044c241f 100644 --- a/llumnix/entrypoints/vllm/client.py +++ b/llumnix/entrypoints/vllm/client.py @@ -7,10 +7,12 @@ from vllm import SamplingParams from llumnix.logger import init_logger -from llumnix.entrypoints.setup import LlumnixEntrypointsContext +from llumnix.entrypoints.setup import EntrypointsContext from llumnix.server_info import RequestTimestamps from llumnix.queue.queue_server_base import QueueServerBase from llumnix.server_info import ServerInfo +from llumnix.manager import Manager +from llumnix.llumlet.llumlet import Llumlet logger = init_logger(__name__) @@ -19,20 +21,20 @@ class LlumnixClientVLLM: def __init__(self, - llumnix_entrypoints_context: LlumnixEntrypointsContext): - self.manager: LLMEngineManager = llumnix_entrypoints_context.manager - 
self.instances: Dict[str, Llumlet] = llumnix_entrypoints_context.instances - self.request_output_queue: QueueServerBase = llumnix_entrypoints_context.request_output_queue - self.server_info: ServerInfo = llumnix_entrypoints_context.server_info - self.log_requests: bool = llumnix_entrypoints_context.log_requests - self.log_request_timestamps: bool = llumnix_entrypoints_context.log_request_timestamps + entrypoints_context: EntrypointsContext): + self.manager: Manager = entrypoints_context.manager + self.instances: Dict[str, Llumlet] = entrypoints_context.instances + self.request_output_queue: QueueServerBase = entrypoints_context.request_output_queue + self.server_info: ServerInfo = entrypoints_context.server_info + self.log_requests = entrypoints_context.log_requests + self.log_request_timestamps = entrypoints_context.log_request_timestamps self.request_streams: Dict[str, AsyncStream] = {} self.instance_num_requests: Dict[str, int] = {} for ins_id in self.instances.keys(): self.instance_num_requests[ins_id] = 0 - self.num_finished_requests: int = 0 - self.manager_available: bool = True + self.num_finished_requests = 0 + self.manager_available = True async def generate(self, prompt: str, @@ -43,6 +45,8 @@ async def generate(self, if sampling_params.n > 1 or sampling_params.use_beam_search: raise ValueError("Unsupported feature: multiple sequence decoding") + logger.info("[generate] entrypoints received request {}".format(request_id)) + results_generator = AsyncStream(request_id) self.request_streams[request_id] = results_generator server_info_copy = copy.deepcopy(self.server_info) @@ -85,10 +89,10 @@ async def _generate_by_instance(self, instance_id = min(self.instance_num_requests, key=self.instance_num_requests.get) self.instance_num_requests[instance_id] += 1 await self.instances[instance_id].generate.remote(request_id, server_info, prompt, sampling_params, *args, **kwargs) - logger.info("LLMEngineManager is unavailable temporarily, dispatch request {} to instance {}".format( + logger.warning("Manager is unavailable temporarily, dispatch request {} to instance {}".format( request_id, instance_id)) else: - logger.info("LLMEngineManager is unavailable temporarily, but there is no instance behind this api server, " + logger.warning("Manager is unavailable temporarily, but there is no instance behind this api server, " "sleep {}s, waiting for manager available".format(WAIT_MANAGER_INTERVAL)) await asyncio.sleep(WAIT_MANAGER_INTERVAL) return await asyncio.create_task(self.generate(prompt, sampling_params, request_id, *args, **kwargs)) @@ -110,7 +114,7 @@ async def abort(self, request_id: str) -> None: logger.info("abort request: {}.".format(request_id)) await self.manager.abort.remote(request_id) except ray.exceptions.RayActorError: - logger.info("manager is unavailable") + logger.warning("manager is unavailable") async def is_ready(self) -> bool: ready_status = await self.manager.is_ready.remote() diff --git a/llumnix/entrypoints/vllm/serve.py b/llumnix/entrypoints/vllm/serve.py new file mode 100644 index 00000000..a73f1ce9 --- /dev/null +++ b/llumnix/entrypoints/vllm/serve.py @@ -0,0 +1,43 @@ + +import time +from ray.util.queue import Queue as RayQueue + +from llumnix.entrypoints.vllm.arg_utils import add_cli_args, get_args +from llumnix.entrypoints.setup import connect_to_ray_cluster +from llumnix.config import get_llumnix_config +from llumnix.arg_utils import LlumnixArgumentParser, LaunchArgs +from llumnix.entrypoints.utils import LaunchMode +from llumnix.backends.backend_interface import 
BackendType +from llumnix.entrypoints.setup import setup_llumnix + + +if __name__ == "__main__": + parser: LlumnixArgumentParser = LlumnixArgumentParser() + + parser.add_argument("--host", type=str) + parser.add_argument("--port", type=int) + parser.add_argument("--ssl-keyfile", type=str) + parser.add_argument("--ssl-certfile", type=str) + parser.add_argument("--log-level", type=str) + parser.add_argument('--disable-keep-serve-process-alive', action='store_true') + + cli_args = add_cli_args(parser) + cfg = get_llumnix_config(cli_args.config_file, cli_args) + entrypoints_args, manager_args, engine_args = get_args(cfg, parser, cli_args) + + backend_type = BackendType.VLLM if not manager_args.simulator_mode else BackendType.SIM_VLLM + launch_args = LaunchArgs(launch_mode=LaunchMode.GLOBAL, backend_type=backend_type) + + # Assume that there is an existing ray cluster when using centralized deployment. + connect_to_ray_cluster() + + # magic actor to avoid fast api server actor initialization error + request_output_queue = RayQueue(actor_options={"namespace": "llumnix", + "name": "magic_ray_queue"}) + + setup_llumnix(manager_args, entrypoints_args, engine_args, launch_args) + + # keep the process alive to get the terminal output. + if not entrypoints_args.disable_keep_serve_process_alive: + while True: + time.sleep(100.0) diff --git a/llumnix/global_scheduler/dispatch_scheduler.py b/llumnix/global_scheduler/dispatch_scheduler.py index 059a4a1e..7adf0008 100644 --- a/llumnix/global_scheduler/dispatch_scheduler.py +++ b/llumnix/global_scheduler/dispatch_scheduler.py @@ -49,7 +49,7 @@ def dispatch(self) -> str: if self.num_requests % 100 == 0: logger.info("self.num_requests: {}".format(self.num_requests)) for instance_id, num_requests in self.instance_num_requests.items(): - logger.info("Instance {} num_dispatched_requests: {}".format(instance_id, num_requests)) + logger.info("instance {} num_dispatched_requests: {}".format(instance_id, num_requests)) return dispatch_instance_id def update_instance_infos(self, diff --git a/llumnix/global_scheduler/migration_policy.py b/llumnix/global_scheduler/migration_policy.py index c917cce7..eafe5cf3 100644 --- a/llumnix/global_scheduler/migration_policy.py +++ b/llumnix/global_scheduler/migration_policy.py @@ -22,10 +22,10 @@ logger = init_logger(__name__) + class PairMigrationConstraints(str, Enum): """Target of Migration.""" NO_CONSTRAINTS = "NO_CONSTRAINTS" - # Enable the prefill-decoding disaggregration. DECODING_2_DECODING = "DECODING_2_DECODING" PREFILL_2_DECODING = "PREFILL_2_DECODING" diff --git a/llumnix/global_scheduler/scaling_scheduler.py b/llumnix/global_scheduler/scaling_scheduler.py index 7607d88a..3c862f8a 100644 --- a/llumnix/global_scheduler/scaling_scheduler.py +++ b/llumnix/global_scheduler/scaling_scheduler.py @@ -21,9 +21,9 @@ logger = init_logger(__name__) + class InstanceType(str, Enum): NO_CONSTRAINTS = "NO_CONSTRAINTS" - # Specific to Prefill-Decoding disaggregation.
PREFILL = "PREFILL" DECODE = "DECODE" diff --git a/llumnix/internal_config.py b/llumnix/internal_config.py index b21d45d7..60c4b593 100644 --- a/llumnix/internal_config.py +++ b/llumnix/internal_config.py @@ -22,8 +22,8 @@ def __init__( max_stages: int, migration_backend_init_timeout: float, migration_backend_transfer_type: str = "", - migration_backend_server_address: str = "", - migration_backend_kvtransfer_naming_url: str = "", + grpc_migration_backend_server_address: str = "", + kvtransfer_migration_backend_naming_url: str = "", ) -> None: self.request_migration_policy = request_migration_policy self.migration_backend = migration_backend @@ -33,8 +33,8 @@ def __init__( self.last_stage_max_blocks = last_stage_max_blocks self.max_stages = max_stages self.migration_backend_init_timeout = migration_backend_init_timeout - self.migration_backend_server_address = migration_backend_server_address - self.migration_backend_kvtransfer_naming_url = migration_backend_kvtransfer_naming_url + self.grpc_migration_backend_server_address = grpc_migration_backend_server_address + self.kvtransfer_migration_backend_naming_url = kvtransfer_migration_backend_naming_url class GlobalSchedulerConfig: def __init__( diff --git a/llumnix/llumlet/llumlet.py b/llumnix/llumlet/llumlet.py index d656ed82..85c14ba1 100644 --- a/llumnix/llumlet/llumlet.py +++ b/llumnix/llumlet/llumlet.py @@ -17,13 +17,13 @@ import time import ray -from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy, NodeAffinitySchedulingStrategy +from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy from ray.util.placement_group import PlacementGroup from llumnix.logger import init_logger from llumnix.instance_info import InstanceInfo from llumnix.backends.backend_interface import BackendInterface, BackendType, EngineState -from llumnix.backends.utils import init_backend_engine +from llumnix.backends.utils import init_backend_engine, get_engine_world_size from llumnix.llumlet.migration_coordinator import MigrationCoordinator, MigrationStatus from llumnix.llumlet.local_migration_scheduler import LocalMigrationScheduler from llumnix.server_info import ServerInfo @@ -40,22 +40,23 @@ class Llumlet: def __init__(self, instance_id: str, + placement_group: PlacementGroup, request_output_queue_type: QueueType, - backend_type: BackendType, migration_config: MigrationConfig, - placement_group: PlacementGroup, - *args, - **kwargs) -> None: + backend_type: BackendType, + engine_args, + profiling_result_file_path: str = None) -> None: try: + logger.info("Llumlet backend type: {}".format(backend_type)) self.instance_id = instance_id - self.instance_name = get_instance_name(instance_id) - self.backend_engine: BackendInterface = init_backend_engine(self.instance_id, + self.actor_name = get_instance_name(instance_id) + self.backend_engine: BackendInterface = init_backend_engine(instance_id, + placement_group, request_output_queue_type, - backend_type, migration_config, - placement_group, - *args, - **kwargs) + backend_type, + engine_args, + profiling_result_file_path) self.migration_coordinator = MigrationCoordinator(self.backend_engine, migration_config.last_stage_max_blocks, migration_config.max_stages) @@ -66,62 +67,51 @@ def __init__(self, asyncio.create_task(self._check_engine_state_loop()) # pylint: disable=broad-except except Exception as e: - logger.error("Failed to initialize llumlet: {}".format(e)) + logger.error("failed to initialize Llumlet: {}".format(e)) logger.error("exception traceback: 
{}".format(traceback.format_exc())) + raise @classmethod def from_args(cls, - request_output_queue_type: QueueType, instance_id: str, - backend_type: BackendType, - world_size: int, - migration_config: MigrationConfig, placement_group: PlacementGroup, - *args, - **kwargs): + request_output_queue_type: QueueType, + migration_config: MigrationConfig, + backend_type: BackendType, + engine_args, + profiling_result_file_path: str = None): try: - assert backend_type in [backend_type.VLLM, backend_type.SIM_VLLM, backend_type.BLADELLM], \ + assert backend_type in [backend_type.VLLM, backend_type.BLADELLM, backend_type.SIM_VLLM], \ f'unimplemented backend {backend_type}' num_gpus = 0 if backend_type == backend_type.BLADELLM: + world_size = get_engine_world_size(engine_args, backend_type) num_gpus = world_size - instance_name = get_instance_name(instance_id) - if backend_type in [backend_type.VLLM, backend_type.BLADELLM]: - llumlet_class = ray.remote(num_cpus=1, - num_gpus=num_gpus, - name=instance_name, - namespace='llumnix', - max_concurrency=4, - lifetime="detached")(cls).options( - scheduling_strategy=PlacementGroupSchedulingStrategy( - placement_group=placement_group, - placement_group_bundle_index=0, - placement_group_capture_child_tasks=True, - ) - ) - else: # backend_type == backend_type.SIM_VLLM: - llumlet_class = ray.remote(num_cpus=1, - num_gpus=num_gpus, - name=instance_name, - namespace='llumnix', - max_concurrency=4, - lifetime="detached")(cls).options( - scheduling_strategy=NodeAffinitySchedulingStrategy( - node_id=ray.get_runtime_context().get_node_id(), - soft=False, - ) + # TODO(s5u13b): Check the max_concurrency. + llumlet_class = ray.remote(num_cpus=1, + num_gpus=num_gpus, + name=get_instance_name(instance_id), + namespace='llumnix', + max_concurrency=4, + lifetime="detached")(cls).options( + scheduling_strategy=PlacementGroupSchedulingStrategy( + placement_group=placement_group, + placement_group_bundle_index=0, + placement_group_capture_child_tasks=True ) + ) llumlet = llumlet_class.remote(instance_id, + placement_group, request_output_queue_type, - backend_type, migration_config, - placement_group, - *args, - **kwargs) + backend_type, + engine_args, + profiling_result_file_path) # pylint: disable=broad-except except Exception as e: - logger.error("Failed to initialize llumlet: {}".format(e)) + logger.error("failed to initialize Llumlet: {}".format(e)) logger.error("exception traceback: {}".format(traceback.format_exc())) + raise return llumlet @@ -129,11 +119,11 @@ async def _check_engine_state_loop(self): while True: await asyncio.sleep(CHECK_ENGINE_STATE_INTERVAL) if self.backend_engine.state == EngineState.CRASHED: - logger.warning("llumlet ({}) detected backend engine crashed. Stopping...".format(self.instance_id)) + logger.error("Llumlet ({}) detected backend engine crashed. 
Stopping...".format(self.instance_id)) # pylint: disable=protected-access self.backend_engine._stop_event.set() await asyncio.sleep(0) - self_actor = ray.get_actor(self.instance_name) + self_actor = ray.get_actor(self.actor_name) ray.kill(self_actor) async def migrate_out(self, dst_instance_name: str) -> List[str]: diff --git a/llumnix/llumlet/migration_coordinator.py b/llumnix/llumlet/migration_coordinator.py index bc356f48..dfebf828 100644 --- a/llumnix/llumlet/migration_coordinator.py +++ b/llumnix/llumlet/migration_coordinator.py @@ -89,7 +89,8 @@ async def _migrate_out_multistage(self, migrate_out_request: LlumnixRequest) -> "MigrationStatus": """Migrate out requests to a specified instance, return migrated request id. Args: - migrate_in_ray_actor: instance actor name, used to get ray actor handle + migrate_in_ray_actor: instance actor name, used to get ray actor handle. + migrate_out_request: request to migrate out. """ try: stage_count = 0 diff --git a/llumnix/llumlet/request.py b/llumnix/llumlet/request.py index d6c7dac5..15bb3c85 100644 --- a/llumnix/llumlet/request.py +++ b/llumnix/llumlet/request.py @@ -16,6 +16,7 @@ from llumnix.server_info import ServerInfo + class RequestInferenceType(str, Enum): PREFILL = "prefill" DECODE = "decode" diff --git a/llumnix/llm_engine_manager.py b/llumnix/manager.py similarity index 56% rename from llumnix/llm_engine_manager.py rename to llumnix/manager.py index 83f89c37..8b0d63f4 100644 --- a/llumnix/llm_engine_manager.py +++ b/llumnix/manager.py @@ -14,12 +14,15 @@ import asyncio import time import csv +import copy import os from typing import Dict, List, Tuple, Union, Iterable from collections import defaultdict import traceback from functools import partial import ray +from ray.util.state import list_placement_groups, list_actors +from ray.util.placement_group import PlacementGroup from llumnix.llumlet.llumlet import Llumlet from llumnix.logger import init_logger @@ -27,70 +30,92 @@ from llumnix.global_scheduler.migration_scheduler import PairMigrationConstraints from llumnix.global_scheduler.migration_filter import CustomFilter from llumnix.instance_info import InstanceInfo -from llumnix.internal_config import GlobalSchedulerConfig -from llumnix.arg_utils import EngineManagerArgs -from llumnix.backends.profiling import ProfilingDatabase +from llumnix.arg_utils import ManagerArgs, EntrypointsArgs, LaunchArgs from llumnix.server_info import ServerInfo from llumnix.backends.backend_interface import BackendType -from llumnix.utils import (random_uuid, - clear_gloo_backend_state, - remove_placement_group, - get_instance_name, - INSTANCE_NAME_PREFIX) +from llumnix.utils import (random_uuid, clear_gloo_backend_state, remove_placement_group, + get_instance_name, get_manager_name, INSTANCE_NAME_PREFIX, + SERVER_NAME_PREFIX, get_placement_group_name, run_async_func_sync, + kill_server, kill_instance, initialize_placement_group, + get_server_name, get_actor_data_from_ray_internal_kv, + put_actor_data_to_ray_internal_kv) +from llumnix.entrypoints.utils import LaunchMode +from llumnix.backends.utils import get_engine_world_size from llumnix.queue.queue_type import QueueType -from llumnix.backends.utils import initialize_placement_group +from llumnix.entrypoints.vllm.api_server_actor import FastAPIServerActor logger = init_logger(__name__) -MANAGER_ACTOR_NAME = 'manager' -CLEAR_REQUEST_INSTANCE_INTERVAL = 3600 -NO_INSTANCE_RETRY_INTERVAL = 5.0 -WAIT_ALL_MIGRATIONS_DONE_INTERVAL = 1.0 +CLEAR_REQUEST_INSTANCE_INTERVAL = 600.0 +NO_INSTANCE_RETRY_INTERVAL = 
0.1 +WAIT_ALL_MIGRATIONS_DONE_INTERVAL = 0.1 +AUTO_SCALE_UP_INTERVAL = 1.0 +WAIT_PLACEMENT_GROUP_TIMEOUT = 5.0 +CHECK_DEPLOYMENT_STATES_INTERVAL = 30.0 +WATCH_DEPLOYMENT_INTERVAL = 40.0 -# TODO(s5u13b): Fix the logger when manager failover. # TODO(s5u13b): Handle exception of ray operations. +# TODO(s5u13b): Add exeception handling wrapper. +# TODO(s5u13b): Reorganize constant variables. -class LLMEngineManager: +class Manager: def __init__(self, - engine_manager_args: EngineManagerArgs, - global_scheduler_config: GlobalSchedulerConfig, + manager_args: ManagerArgs, work_dir: str, - log_requests: bool = True, - profiling_database: ProfilingDatabase = None) -> None: + entrypoints_args: EntrypointsArgs = None, + engine_args = None, + launch_args: LaunchArgs = None + ) -> None: os.chdir(work_dir) - self.actor_name = MANAGER_ACTOR_NAME - self.engine_manager_args = engine_manager_args - self.profiling_database = profiling_database + self.actor_name = get_manager_name() + self.manager_args = manager_args + # engine_args and entrypoints_args are used in global deployment. + self.entrypoints_args = entrypoints_args + self.engine_args = engine_args + self.launch_args = launch_args + + # launch args + if launch_args is not None: + self.launch_mode: LaunchMode = launch_args.launch_mode + self.backend_type: BackendType = launch_args.backend_type + + # migration args + self.enable_migration = manager_args.enable_migration + self.pair_migration_frequency = manager_args.pair_migration_frequency + self.enable_pd_disagg = manager_args.enable_pd_disagg + + # scaling args + self.enable_scaling = manager_args.enable_scaling + self.max_instances = manager_args.max_instances + self.min_instances = manager_args.min_instances + self.scaling_interval = manager_args.scaling_interval + self.scaling_policy = manager_args.scaling_policy + self.scale_up_threshold = manager_args.scale_up_threshold + self.scale_down_threshold = manager_args.scale_down_threshold + + self.polling_interval = manager_args.polling_interval + + global_scheduler_config = manager_args.create_global_scheduler_config() + self.global_scheduler = GlobalScheduler(global_scheduler_config) - self.log_requests = log_requests + # log args + self.log_requests = not manager_args.disable_log_requests_manager + self.log_instance_info = manager_args.log_instance_info + if self.log_instance_info: + self._init_instance_info_csv(manager_args) + self.instance_last_logged_empty = {} + # instance states self.num_instances = 0 - self.enable_migration = engine_manager_args.enable_migration - self.enable_scaling = engine_manager_args.enable_scaling - self.max_instances = engine_manager_args.max_instances - self.min_instances = engine_manager_args.min_instances - - self.enable_pd_disagg = global_scheduler_config.enable_pd_disagg - self.instances: Dict[str, Llumlet] = {} self.instance_migrating: Dict[str, bool] = {} self.pending_rebuild_migration_instances = 0 - self.global_scheduler = GlobalScheduler(global_scheduler_config) - - self.polling_interval = engine_manager_args.polling_interval - asyncio.create_task(self._update_instance_info_loop(self.polling_interval)) - - # args - self.pair_migration_frequency = engine_manager_args.pair_migration_frequency - self.scaling_interval = engine_manager_args.scaling_interval # request states self.request_instance: Dict[str, str] = {} - self.clear_request_intance_interval = CLEAR_REQUEST_INSTANCE_INTERVAL - asyncio.create_task(self._clear_request_instance_loop(self.clear_request_intance_interval)) - # migrate states + # migration 
states self.num_instance_info_updates = 0 self.migrating = False @@ -99,20 +124,26 @@ def __init__(self, self.scale_down_time = -1 self.scaling_up = False self.scaling_down = False - self.last_check_scale_time = time.time() + 100 - - self.log_instance_info = engine_manager_args.log_instance_info - if self.log_instance_info: - self._init_instance_info_csv(engine_manager_args) - self.instance_last_logged_empty = {} + self.last_check_scale_time = time.time() + # tasks # When manager starts, it automatically connects to all existing instances. - asyncio.run_coroutine_threadsafe(self._connect_to_instances(), asyncio.get_event_loop()) + run_async_func_sync(self._connect_to_instances()) + asyncio.create_task(self._update_instance_info_loop(self.polling_interval)) + asyncio.create_task(self._clear_request_instance_loop(CLEAR_REQUEST_INSTANCE_INTERVAL)) + + value = get_actor_data_from_ray_internal_kv("manager", "port_offset") + self.port_offset = 0 if value is None else int(value) + if hasattr(self, "launch_mode") and self.launch_mode == LaunchMode.GLOBAL: + assert self.entrypoints_args is not None and self.engine_args is not None + self.last_timeout_instance_id = None + asyncio.create_task(self._auto_scale_up_loop(AUTO_SCALE_UP_INTERVAL)) + asyncio.create_task(self._check_deployment_states_loop(CHECK_DEPLOYMENT_STATES_INTERVAL)) async def generate(self, request_id: str, server_info: ServerInfo, *args, **kwargs,) -> None: while self.num_instances == 0: - logger.info("[generate] no instance available temporarily, sleep {}s, " - "and retry generate request {} again".format(NO_INSTANCE_RETRY_INTERVAL, request_id)) + logger.warning("[generate] no instance available temporarily, sleep {}s, " + "and regenerate request {}".format(NO_INSTANCE_RETRY_INTERVAL, request_id)) await asyncio.sleep(NO_INSTANCE_RETRY_INTERVAL) instance_id, request_expected_steps = self.global_scheduler.dispatch() @@ -121,7 +152,7 @@ async def generate(self, request_id: str, server_info: ServerInfo, *args, **kwar server_info.request_timestamps.manager_generate_timestamp = time.time() await self.instances[instance_id].generate.remote(request_id, server_info, request_expected_steps, *args, **kwargs) if self.log_requests: - logger.info("[generate] received request {}.".format(request_id)) + logger.info("[generate] manager received request {}".format(request_id)) logger.info("[generate] dispath request {} to instance {}".format(request_id, instance_id)) self.request_instance[request_id] = instance_id except (ray.exceptions.RayActorError, KeyError): @@ -159,30 +190,6 @@ def abort_done_callback(instance_id: str, request_ids: List[str], fut): tasks.append(task) await asyncio.gather(*tasks, return_exceptions=True) - async def _get_request_instance(self) -> None: - def get_request_instance_done_callback(instance_id: str, fut): - ret = fut.result()[0] - if not isinstance(ret, ray.exceptions.RayActorError): - instance_requests.append(ret) - instance_ids.append(instance_id) - else: - logger.info("[_get_request_instance] instance {} is dead".format(instance_id)) - self.scale_down(instance_id) - - instance_requests = [] - instance_ids = [] - tasks = [] - for instance_id, instance_actor_handle in self.instances.items(): - task = asyncio.gather(instance_actor_handle.get_instance_info.remote(), return_exceptions=True) - task.add_done_callback(partial(get_request_instance_done_callback, instance_id)) - tasks.append(task) - await asyncio.gather(*tasks, return_exceptions=True) - logger.info("[_get_request_instance] instance_ids: 
{}".format(instance_ids)) - logger.info("[_get_request_instance] instance_requests: {}".format(instance_requests)) - for (instance_id, requests) in zip(instance_ids, instance_requests): - for request_id in requests: - self.request_instance[request_id] = instance_id - async def _update_instance_info_loop(self, interval: float) -> None: def update_instance_info_done_callback(instance_id: str, fut): ret = fut.result()[0] @@ -191,9 +198,8 @@ def update_instance_info_done_callback(instance_id: str, fut): instance_infos.append(ret) self.global_scheduler.update_instance_infos([ret]) else: + logger.info("[_update_instance_info_loop] instance {} is dead".format(instance_id)) self.scale_down(instance_id) - logger.info("[_update_instance_info_loop] dead instances: {}.".format(ret)) - logger.info("[_update_instance_info_loop] dead instances: {}.".format(self.instances)) while True: try: @@ -201,7 +207,7 @@ def update_instance_info_done_callback(instance_id: str, fut): tasks = [] instance_infos = [] for instance_id, instance in self.instances.items(): - # Use asyncio.gather to wrap ray remote call to add done callback. + # Use asyncio.gather to wrap ray remote call to add done callback, asyncio.create_task will get error. task = asyncio.gather(instance.get_instance_info.remote(), return_exceptions=True) task.add_done_callback(partial(update_instance_info_done_callback, instance_id)) tasks.append(task) @@ -218,13 +224,6 @@ def update_instance_info_done_callback(instance_id: str, fut): logger.error("[_update_instance_info_loop] unexpected exception occurs: {}".format(e)) logger.error("[_update_instance_info_loop] exception traceback: {}".format(traceback.format_exc())) - async def _clear_request_instance_loop(self, interval: float): - await self._get_request_instance() - # Clear the request_instance at a certain interval to prevent memory leaking. - while True: - await asyncio.sleep(interval) - self.request_instance = {} - async def _push_migrations(self) -> None: if self.enable_pd_disagg: asyncio.create_task(self._migrate(PairMigrationConstraints.PREFILL_2_DECODING)) @@ -276,7 +275,6 @@ def migrate_done_callback_wrapper(migrate_instance_pair: Tuple[str, str], fut) - self.instance_migrating[migrate_out_instance_id] = True self.instance_migrating[migrate_in_instance_id] = True migrate_in_instance_name = get_instance_name(migrate_in_instance_id) - # Use asyncio.gather to wrap ray remote call to add done callback. 
task = asyncio.gather(self.instances[migrate_out_instance_id].migrate_out.remote(migrate_in_instance_name), return_exceptions=True) task.add_done_callback(partial(migrate_done_callback_wrapper, migrate_instance_pair)) @@ -287,12 +285,54 @@ def migrate_done_callback_wrapper(migrate_instance_pair: Tuple[str, str], fut) - logger.error("[_migrate] unexpected exception occurs: {}".format(e)) logger.error("[_migrate] exception traceback: {}".format(traceback.format_exc())) - async def rebuild_migrate_backend(self) -> None: + async def _auto_scale_up_loop(self, interval: float) -> None: + while True: + try: + new_pg = None + if self.last_timeout_instance_id is not None: + last_timeout_pg_name = get_placement_group_name(self.last_timeout_instance_id) + last_timeout_pg_states = list_placement_groups(filters=[("name", "=", last_timeout_pg_name)]) + if len(last_timeout_pg_states) > 0: + new_instance_id = self.last_timeout_instance_id + # pending, created(without server and instance) or rescheduling + new_pg = ray.util.get_placement_group(last_timeout_pg_name) + # reset + self.last_timeout_instance_id = None + pending_pg_states = list_placement_groups(filters=[("state", "=", "PENDING")]) + pending_pg_states.extend(list_placement_groups(filters=[("state", "=", "RESCHEDULING")])) + for pending_pg_state in pending_pg_states: + instance_id = pending_pg_state["name"].split("_")[-1] + if new_pg is not None and instance_id == new_instance_id: + continue + self.scale_down(instance_id) + if new_pg is None: + new_instance_id = random_uuid() + new_pg = self._init_placement_group(get_placement_group_name(new_instance_id), self.engine_args, self.backend_type, + init_server=True, block=False) + try: + await asyncio.wait_for(new_pg.ready(), WAIT_PLACEMENT_GROUP_TIMEOUT) + except asyncio.TimeoutError: + logger.info("[_auto_scale_up_loop] waiting for new placement group ready timeout") + # After timeout, the new placement group might be pending, + # created(without server and instance), rescheduling. + self.last_timeout_instance_id = new_instance_id + await asyncio.sleep(interval) + continue + self._init_server_and_instance(new_instance_id, new_pg) + logger.info("[_auto_scale_up_loop] deploy server and instance to new placement group done, " + "instance_id: {}".format(new_instance_id)) + # pylint: disable=broad-except + except Exception as e: + logger.error("[_auto_scale_up_loop] unexpected exception occurs: {}".format(e)) + logger.error("[_auto_scale_up_loop] exception traceback: {}".format(traceback.format_exc())) + + # TODO(KuilongCui): Add comments for this function. + async def _rebuild_migration_backend(self) -> None: # Wait for all instances to finish migration while any(self.instance_migrating.values()): await asyncio.sleep(WAIT_ALL_MIGRATIONS_DONE_INTERVAL) - # During rebuilding migration backend, disable migrate + # During rebuilding migration backend, disable migration. 
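The `_auto_scale_up_loop` hunk above creates the placement group without blocking and then bounds the wait with `asyncio.wait_for`, keeping the group around for a later retry instead of creating a new one on every iteration. A standalone sketch of that provisioning step (not part of the patch; the bundle shape, name prefix, and timeout are assumed values, not the exact llumnix configuration):

```python
# Illustrative sketch of non-blocking placement group provisioning; not llumnix code.
import asyncio
import ray
from ray.util.placement_group import placement_group

WAIT_PLACEMENT_GROUP_TIMEOUT = 180  # seconds, assumed value

async def try_provision(instance_id: str):
    # Create the placement group without waiting for it to be scheduled.
    pg = placement_group([{"CPU": 3}, {"GPU": 1}], strategy="STRICT_PACK",
                         name=f"pg_{instance_id}", lifetime="detached")
    try:
        # pg.ready() returns an ObjectRef, which can be awaited with a timeout.
        await asyncio.wait_for(pg.ready(), WAIT_PLACEMENT_GROUP_TIMEOUT)
    except asyncio.TimeoutError:
        # The group may still be pending or rescheduling; remember it and retry later.
        return None
    return pg
```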
origin_config = self.enable_migration self.enable_migration = False @@ -307,8 +347,8 @@ async def run_task(alive_instances: List[str], task_name: str, *args, **kwargs): if isinstance(ret, ray.exceptions.RayActorError): dead_instances.add(instance_name) if len(dead_instances) > 0: - self.scale_down(dead_instances, rebuild_migrate_backend=False) - if self.engine_manager_args.migration_backend == 'gloo': + self.scale_down(dead_instances, rebuild_migration_backend=False) + if self.manager_args.migration_backend == 'gloo': clear_gloo_backend_state() return dead_instances @@ -316,7 +356,7 @@ async def run_task(alive_instances: List[str], task_name: str, *args, **kwargs): pending_task = self.pending_rebuild_migration_instances group_name = None - if self.engine_manager_args.migration_backend == 'gloo': + if self.manager_args.migration_backend == 'gloo': clear_gloo_backend_state() while len(alive_instances) > 0 and self.pending_rebuild_migration_instances > 0: @@ -342,16 +382,20 @@ async def run_task(alive_instances: List[str], task_name: str, *args, **kwargs): src_filter=lambda instance_info: instance_info.instance_id in alive_instances, dst_filter=lambda instance_info: instance_info.instance_id in alive_instances) - logger.info("[rebuild_migrate_backend] rebuild {} migrate backend done, group_name: {}, alive instance ({}): {}" - .format(self.engine_manager_args.migration_backend, group_name, len(alive_instances), alive_instances)) + logger.info("[rebuild_migration_backend] rebuild {} migration backend done, group_name: {}, alive instance ({}): {}" + .format(self.manager_args.migration_backend, group_name, len(alive_instances), alive_instances)) # Restore migrate config self.enable_migration = origin_config - def scale_up(self, instance_id: Union[str, Iterable[str]], llumlet_actor_handles: List["ray.actor.ActorHandle"]) -> None: + def scale_up(self, + instance_id: Union[str, Iterable[str]], + instance_actor_handle: Union["ray.actor.ActorHandle", List["ray.actor.ActorHandle"]]) -> None: if isinstance(instance_id, str): instance_id = [instance_id,] + instance_actor_handle = [instance_actor_handle,] instance_ids = list(instance_id) + instance_actor_handles = list(instance_actor_handle) indeed_update = False no_pending_instance = (self.pending_rebuild_migration_instances == 0) @@ -359,7 +403,7 @@ def scale_up(self, instance_id: Union[str, Iterable[str]], llumlet_actor_handles for idx, ins_id in enumerate(instance_ids): if ins_id not in self.instances: indeed_update = True - self.instances[ins_id] = llumlet_actor_handles[idx] + self.instances[ins_id] = instance_actor_handles[idx] self.instance_migrating[ins_id] = False if self.log_instance_info: self.instance_last_logged_empty[ins_id] = False @@ -369,15 +413,15 @@ def scale_up(self, instance_id: Union[str, Iterable[str]], llumlet_actor_handles # When scaling up, we need to rebuild the migration backend. But if initially self.pending_rebuild_migration_instances != 0, # a coroutine is already handling the changes in the number of instances in the cluster and it will account for the changes - # caused by this scale-up (see rebuild_migrate_backend for details). Therefore, we simply return in this case. Specifically, - # for RPC, the Ray actor handle is used for the migration cache, so there is no need to rebuild the group. - if self.enable_migration and self.engine_manager_args.migration_backend in ['gloo', 'nccl'] \ + # caused by this scale-up (see rebuild_migration_backend for details). Therefore, we simply return in this case. 
+ # Specifically, for RayRPC migration backend, the Ray actor handle is used for the migration cache, so there is no need to rebuild the group. + if self.enable_migration and self.manager_args.migration_backend in ['gloo', 'nccl'] \ and indeed_update and no_pending_instance: - asyncio.create_task(self.rebuild_migrate_backend()) + asyncio.create_task(self._rebuild_migration_backend()) return self.num_instances - def scale_down(self, instance_id: Union[str, Iterable[str]], rebuild_migrate_backend: bool = True) -> None: + def scale_down(self, instance_id: Union[str, Iterable[str]], rebuild_migration_backend: bool = True) -> None: if isinstance(instance_id, str): instance_id = [instance_id,] instance_ids = list(instance_id) @@ -386,37 +430,44 @@ def scale_down(self, instance_id: Union[str, Iterable[str]], rebuild_migrate_bac no_pending_instance = self.pending_rebuild_migration_instances == 0 for ins_id in instance_ids: + self._clear_instance_ray_resources(ins_id) if ins_id in self.instances: indeed_update = True if ins_id in self.instances: del self.instances[ins_id] else: - logger.warning("[scale_down] instance {} is not in self.instances".format(ins_id)) + logger.debug("[scale_down] instance {} is not in self.instances".format(ins_id)) if ins_id in self.instance_migrating: del self.instance_migrating[ins_id] else: - logger.warning("[scale_down] instance {} is not in self.instance_migrating".format(ins_id)) - if not remove_placement_group(ins_id): - logger.warning("[scale_down] failed to remove placement group of instance {}".format(ins_id)) + logger.debug("[scale_down] instance {} is not in self.instance_migrating".format(ins_id)) if self.log_instance_info: if ins_id in self.instance_last_logged_empty: del self.instance_last_logged_empty[ins_id] else: - logger.warning("[scale_down] instance {} is not in self.instance_last_logged_empty".format(ins_id)) + logger.debug("[scale_down] instance {} is not in self.instance_last_logged_empty".format(ins_id)) self.pending_rebuild_migration_instances += 1 self.global_scheduler.scale_down(instance_ids) self.num_instances = len(self.instances) - if self.enable_migration and self.engine_manager_args.migration_backend in ['gloo', 'nccl']: + if self.enable_migration and self.manager_args.migration_backend in ['gloo', 'nccl']: if len(self.instances) == 0: self.pending_rebuild_migration_instances = 0 - if self.engine_manager_args.migration_backend == 'gloo': + if self.manager_args.migration_backend == 'gloo': clear_gloo_backend_state() - elif indeed_update and no_pending_instance and rebuild_migrate_backend: - asyncio.create_task(self.rebuild_migrate_backend()) + elif indeed_update and no_pending_instance and rebuild_migration_backend: + asyncio.create_task(self._rebuild_migration_backend()) return self.num_instances + def _clear_instance_ray_resources(self, instance_id: str): + if not remove_placement_group(instance_id): + logger.debug("[clear_instance_ray_resources] failed to remove placement group {}".format(instance_id)) + if not kill_server(instance_id): + logger.debug("[clear_instance_ray_resources] failed to kill server {}".format(instance_id)) + if not kill_instance(instance_id): + logger.debug("[clear_instance_ray_resources] failed to kill instance {}".format(instance_id)) + async def _connect_to_instances(self): def connect_to_instances_done_callback(instance_id: str, instance_actor_handle: "ray.actor.ActorHandle", fut): ret = fut.result()[0] @@ -425,8 +476,7 @@ def connect_to_instances_done_callback(instance_id: str, instance_actor_handle: 
scale_up_instance_actor_handles.append(instance_actor_handle) logger.info("[_connect_to_instances] connect to instance {}.".format(instance_id)) else: - logger.info("[_connect_to_instances] connect to instance {} abort, " - "which may be not ready or alive, err: {}".format(instance_id, e)) + logger.warning("[_connect_to_instances] connect to instance {} failed, exception: {}".format(instance_id, ret)) # Must set all_namespaces=True even though the namespace is set to llumnix. actor_names_dict = ray.util.list_named_actors(all_namespaces=True) @@ -446,6 +496,145 @@ def connect_to_instances_done_callback(instance_id: str, instance_actor_handle: # The only function that can add instance actor handles to manager. self.scale_up(scale_up_instance_ids, scale_up_instance_actor_handles) + @classmethod + def from_args(cls, + manager_args: ManagerArgs, + entrypoints_args: EntrypointsArgs = None, + engine_args = None, + launch_args: LaunchArgs = None, + ) -> "Manager": + manager_class = ray.remote(num_cpus=1, + max_restarts=-1, + name=get_manager_name(), + namespace="llumnix", + lifetime="detached")(cls) + manager = manager_class.remote(manager_args, + os.getcwd(), + entrypoints_args, + engine_args, + launch_args) + + return manager + + def _init_placement_group(self, + placement_group_name: str, + engine_args, + backend_type: BackendType, + init_server: bool = False, + block: bool = True) -> PlacementGroup: + if not BackendType.is_sim_backend(backend_type): + # num_cpus=3, for Llumlet + AsyncPutQueueActor + ProxyActor + # num_gpus=world_size, for world_size Workers + world_size = get_engine_world_size(engine_args, backend_type) + placement_group = initialize_placement_group(placement_group_name, + num_cpus=3+int(init_server), num_gpus=world_size, detached=True, block=block) + else: + # num_cpus=2, for Llumlet + AsyncPutQueueActor + placement_group = initialize_placement_group(placement_group_name, + num_cpus=2+int(init_server), num_gpus=0, detached=True, block=block) + + return placement_group + + def _init_server(self, + server_name: str, + placement_group: PlacementGroup, + entrypoints_args: EntrypointsArgs) -> FastAPIServerActor: + entrypoints_args = copy.deepcopy(entrypoints_args) + if self.manager_args.enable_port_increment: + entrypoints_args.port += self.port_offset + entrypoints_args.request_output_queue_port += self.port_offset + self.port_offset += 1 + put_actor_data_to_ray_internal_kv("manager", "port_offset", self.port_offset) + fastapi_server = FastAPIServerActor.from_args(server_name, placement_group, entrypoints_args) + return fastapi_server + + def _init_instance(self, + instance_id: str, + placement_group: PlacementGroup, + request_output_queue_type: QueueType, + backend_type: BackendType, + engine_args + ) -> Llumlet: + instance = Llumlet.from_args( + instance_id, + placement_group, + request_output_queue_type, + self.manager_args.create_migration_config(), + backend_type, + engine_args, + self.manager_args.profiling_result_file_path) + + return instance + + def init_instances(self, + request_output_queue_type: QueueType, + backend_type: BackendType, + engine_args + ) -> Tuple[List[str], List[Llumlet]]: + instance_ids: List[str] = [] + instances: List[Llumlet] = [] + for _ in range(self.manager_args.initial_instances): + instance_id = random_uuid() + placement_group = self._init_placement_group(get_placement_group_name(instance_id), engine_args, backend_type) + instance = self._init_instance(instance_id, placement_group, request_output_queue_type, backend_type, engine_args) +
instance_ids.append(instance_id) + instances.append(instance) + + self.scale_up(instance_ids, instances) + + return instance_ids, instances + + def _init_server_and_instance(self, + instance_id: str, + placement_group: PlacementGroup): + async def done_scale_up(): + try: + manager = ray.get_actor(get_manager_name(), namespace="llumnix") + await instance.is_ready.remote() + await server.run.remote(manager, instance_id, instance) + self.scale_up(instance_id, instance) + # pylint: disable=broad-except + except Exception as e: + logger.error("[_init_server_and_instance] unexpected exception occurs: {}".format(e)) + logger.error("[_init_server_and_instance] exception traceback: {}".format(traceback.format_exc())) + self._clear_instance_ray_resources(instance_id) + + request_output_queue_type = QueueType(self.entrypoints_args.request_output_queue_type) + instance = self._init_instance(instance_id, placement_group, request_output_queue_type, self.backend_type, self.engine_args) + server = self._init_server(get_server_name(instance_id), placement_group, self.entrypoints_args) + asyncio.create_task(done_scale_up()) + + async def _check_deployment_states_loop(self, interval: float) -> None: + async def watch_deployment(instance_id: str): + await asyncio.sleep(WATCH_DEPLOYMENT_INTERVAL) + curr_pgs, curr_servers, curr_instances = self.get_curr_deployment() + if instance_id in curr_pgs and (instance_id not in curr_servers or instance_id not in curr_instances): + logger.warning("[_check_deployment_states_loop] instance {} deployment states incorrect, " + "states: (pg {}, server {}, instance {})" + .format(instance_id, instance_id in curr_pgs, instance_id in curr_servers, instance_id in curr_instances)) + self.scale_down(instance_id) + + while True: + try: + curr_pgs, curr_servers, curr_instances = self.get_curr_deployment() + assert len(curr_pgs) >= max(len(curr_servers), len(curr_instances)) + tasks = [] + for instance_id in curr_pgs: + if instance_id not in curr_servers or instance_id not in curr_instances: + tasks.append(asyncio.create_task(watch_deployment(instance_id))) + await asyncio.gather(*tasks, return_exceptions=True) + await asyncio.sleep(interval) + # pylint: disable=broad-except + except Exception as e: + logger.error("[_check_deployment_states_loop] unexpected exception occurs: {}".format(e)) + logger.error("[_check_deployment_states_loop] exception traceback: {}".format(traceback.format_exc())) + + async def is_ready(self) -> bool: + """Called by api server, return true when all the instances have been successfully created.""" + tasks = [instance.is_ready.remote() for instance in self.instances.values()] + is_ready_list = await asyncio.gather(*tasks) + return all(is_ready_list) + async def _check_instance_error(self, migrate_instance_pairs: Tuple[str, str]) -> List[bool]: def check_instance_error_done_callback(idx: int, instance_id: str, fut): ret = fut.result()[0] @@ -466,87 +655,61 @@ def check_instance_error_done_callback(idx: int, instance_id: str, fut): return results - @classmethod - def from_args(cls, - engine_manager_args: EngineManagerArgs, - profiling_database: ProfilingDatabase=None) -> "LLMEngineManager": - global_scheduler_config = engine_manager_args.create_global_scheduler_configs() - manager_class = ray.remote(num_cpus=0, - max_restarts=-1, - name=MANAGER_ACTOR_NAME, - namespace='llumnix', - lifetime="detached" - )(cls) - manager = manager_class.remote(engine_manager_args, - global_scheduler_config, - os.getcwd(), - log_requests=not 
engine_manager_args.disable_log_requests_manager, - profiling_database=profiling_database) + def get_curr_deployment(self) -> Tuple[Dict[str, PlacementGroup], Dict[str, FastAPIServerActor], Dict[str, Llumlet]]: + curr_pgs: Dict[str, PlacementGroup] = {} + curr_servers: Dict[str, PlacementGroup] = {} + curr_instances: Dict[str, Llumlet] = {} - return manager + created_pg_states = list_placement_groups(filters=[("state", "=", "CREATED")]) + for created_pg_state in created_pg_states: + instance_id = created_pg_state["name"].split("_")[-1] + curr_pgs[instance_id] = ray.util.get_placement_group(created_pg_state["name"]) - def init_llumlets(self, - engine_args, - request_output_queue_type: QueueType, - backend_type: BackendType, - world_size: int, - *args, - **kwargs) -> Tuple[List[str], List[Llumlet]]: - engine_manager_args = self.engine_manager_args - instance_ids: List[str] = [] - llumlets: List[Llumlet] = [] - if 'instance_ids' in kwargs and kwargs['instance_ids'][0]: - instance_ids = kwargs['instance_ids'] - for _ in range(engine_manager_args.initial_instances): - instance_id = random_uuid() - if not engine_manager_args.profiling_result_file_path: - # num_cpus=3, for Llumlet + AsyncPutQueueActor + ProxyActor, num_gpus=world_size, for Workers - placement_group = initialize_placement_group(instance_id, num_cpus=3, num_gpus=world_size, detached=True) - llumlet = Llumlet.from_args( - request_output_queue_type, - instance_id, - backend_type, - world_size, - engine_manager_args.create_migration_config(), - placement_group, - engine_args, - *args, - **kwargs - ) - else: - assert backend_type == backend_type.VLLM, f'unimplemented backend SIM_{backend_type}' - # num_cpus=1, for Llumlet + AsyncPutQueueActor - logger.info("[init_llumlets] use simulator backend") - placement_group = initialize_placement_group(instance_id, num_cpus=2, num_gpus=0, detached=True) - llumlet = Llumlet.from_args( - request_output_queue_type, - instance_id, - BackendType.SIM_VLLM, - world_size, - engine_manager_args.create_migration_config(), - placement_group, - engine_args, - engine_manager_args.profiling_result_file_path, - *args, - **kwargs - ) - instance_ids.append(instance_id) - llumlets.append(llumlet) + alive_actor_states = list_actors(filters=[("state", "=", "ALIVE")]) + for alive_actor_state in alive_actor_states: + if alive_actor_state["name"].startswith(SERVER_NAME_PREFIX): + instance_id = alive_actor_state["name"].split("_")[-1] + curr_servers[instance_id] = ray.get_actor(alive_actor_state["name"], namespace="llumnix") + elif alive_actor_state["name"].startswith(INSTANCE_NAME_PREFIX): + instance_id = alive_actor_state["name"].split("_")[-1] + curr_instances[instance_id] = ray.get_actor(alive_actor_state["name"], namespace="llumnix") - return instance_ids, llumlets + return curr_pgs, curr_servers, curr_instances - def get_actor_name(self) -> str: - return self.actor_name + async def _get_request_instance(self) -> None: + def get_request_instance_done_callback(instance_id: str, fut): + ret = fut.result()[0] + if not isinstance(ret, ray.exceptions.RayActorError): + instance_requests.append(ret) + instance_ids.append(instance_id) + else: + logger.info("[_get_request_instance] instance {} is dead".format(instance_id)) + self.scale_down(instance_id) - async def is_ready(self) -> bool: - """Called by api server, return true when all the instances have been successfully created.""" - tasks = [llumlet.is_ready.remote() for llumlet in self.instances.values()] - is_ready_list = await asyncio.gather(*tasks) - return 
all(is_ready_list) + instance_requests = [] + instance_ids = [] + tasks = [] + for instance_id, instance_actor_handle in self.instances.items(): + task = asyncio.gather(instance_actor_handle.get_instance_info.remote(), return_exceptions=True) + task.add_done_callback(partial(get_request_instance_done_callback, instance_id)) + tasks.append(task) + await asyncio.gather(*tasks, return_exceptions=True) + logger.debug("[_get_request_instance] instance_ids: {}".format(instance_ids)) + logger.debug("[_get_request_instance] instance_requests: {}".format(instance_requests)) + for (instance_id, requests) in zip(instance_ids, instance_requests): + for request_id in requests: + self.request_instance[request_id] = instance_id + + async def _clear_request_instance_loop(self, interval: float): + await self._get_request_instance() + # Clear the request_instance at a certain interval to prevent memory leaks. + while True: + await asyncio.sleep(interval) + self.request_instance = {} - def _init_instance_info_csv(self, engine_manager_args: EngineManagerArgs) -> None: + def _init_instance_info_csv(self, manager_args: ManagerArgs) -> None: # pylint: disable=consider-using-with - self.instance_info_file = open(engine_manager_args.log_filename + '_instance.csv', 'w', encoding='utf-8') + self.instance_info_file = open(manager_args.log_filename + '_instance.csv', 'w', encoding='utf-8') self.instance_info_csv = csv.writer(self.instance_info_file) self.instance_info_csv.writerow([ 'timestamp', diff --git a/llumnix/queue/ray_queue_server.py b/llumnix/queue/ray_queue_server.py index 6cff2607..b8648157 100644 --- a/llumnix/queue/ray_queue_server.py +++ b/llumnix/queue/ray_queue_server.py @@ -52,4 +52,8 @@ async def run_server_loop(self): pass def cleanup(self): - pass + try: + ray.kill(self.queue) + # pylint: disable=broad-except, unused-variable + except Exception as e: + pass diff --git a/llumnix/utils.py b/llumnix/utils.py index 301fb565..08769ca8 100644 --- a/llumnix/utils.py +++ b/llumnix/utils.py @@ -12,13 +12,72 @@ # limitations under the License. import uuid +import asyncio +import threading +from typing import Any, Union import ray +from ray.util.placement_group import PlacementGroup +from ray.experimental.internal_kv import ( + _internal_kv_get, + _internal_kv_initialized, + _internal_kv_put, +) +from llumnix.logger import init_logger + +logger = init_logger(__name__) + +MANAGER_NAME = "manager" PLACEMENT_GROUP_NAME_PREFIX = "pg_" SERVER_NAME_PREFIX = "server_" INSTANCE_NAME_PREFIX = "instance_" +def initialize_placement_group( + placement_group_name: str, + num_cpus: int, + num_gpus: int, + detached: bool = False, + block: bool = True ) -> PlacementGroup: + """Initialize a placement group with Ray. + + Args: + placement_group_name: The name of placement group. + num_cpus: The number of cpus in placement group. + num_gpus: The number of gpus in placement group. + detached: Whether the lifetime of the placement group is detached. + block: If True, the function will block until the placement group is ready. + + Returns: + `placement_group`. `placement_group` includes the specification + of the resources for each distributed worker. + """ + if ray is None: + raise ImportError( + "Ray is not installed.
Please install Ray to use distributed " + "serving.") + + lifetime = "detached" if detached else None + + num_gpus_in_cluster = ray.cluster_resources().get("GPU", 0) + if num_gpus > num_gpus_in_cluster: + raise ValueError( + "The number of required GPUs exceeds the total number of " + "available GPUs in the cluster.") + # Create a new placement group + # bundle_0: Llumlet + AsyncPutQueueActor + ProxyActor, bundle_1: Workers + placement_group_specs = ([{"CPU": num_cpus}] + [{"GPU": 1}] * num_gpus) + current_placement_group = ray.util.placement_group( + placement_group_specs, "STRICT_PACK", name=placement_group_name, lifetime=lifetime) + # Wait until PG is ready - this will block until all + # requested resources are available, and will timeout + # if they cannot be provisioned. + if block: + ray.get(current_placement_group.ready(), timeout=1800) + + return current_placement_group + def random_uuid() -> str: return str(uuid.uuid4().hex) @@ -44,20 +103,82 @@ def clear_gloo_backend_state(): # gloo_queue may not have been created yet; just ignore this error. pass +def get_manager_name() -> str: + return MANAGER_NAME + def get_placement_group_name(instance_id: str) -> str: return f"{PLACEMENT_GROUP_NAME_PREFIX}{instance_id}" +def get_server_name(instance_id: str) -> str: + return f"{SERVER_NAME_PREFIX}{instance_id}" + def get_instance_name(instance_id: str) -> str: return f"{INSTANCE_NAME_PREFIX}{instance_id}" def remove_placement_group(instance_id: str) -> bool: try: placement_group = ray.util.get_placement_group(get_placement_group_name(instance_id)) - if not placement_group: - return False # asynchronous api ray.util.remove_placement_group(placement_group) + logger.info("remove placement group {}".format(instance_id)) + # pylint: disable=broad-except + except Exception: + return False + return True + +def kill_server(instance_id: str) -> bool: + try: + server = ray.get_actor(get_server_name(instance_id), namespace="llumnix") + ray.kill(server) + logger.info("kill server {}".format(instance_id)) + # pylint: disable=broad-except + except Exception: + return False + return True + +def kill_instance(instance_id: str) -> bool: + try: + instance = ray.get_actor(get_instance_name(instance_id), namespace="llumnix") + ray.kill(instance) + logger.info("kill instance {}".format(instance_id)) # pylint: disable=broad-except except Exception: return False return True + +def run_async_func_sync(func): + def run_task(): + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + future = loop.create_task(func) + loop.run_until_complete(future) + loop.close() + thread = threading.Thread(target=run_task) + thread.start() + thread.join() + +def _make_key(actor_name: str, data_name: str): + """Generate a binary key for the given actor name and data. + + Args: + actor_name: The name of the actor + data_name: The data member of the actor + + Returns: + The key to use for storing the value. + """ + return (actor_name.encode("ascii") + b"."
+ data_name.encode("ascii")) + +def get_actor_data_from_ray_internal_kv(actor_name: str, data_name: str) -> Union[str, None]: + value = None + if _internal_kv_initialized(): + value = _internal_kv_get(_make_key(actor_name, data_name)) + if value is not None: + value = value.decode() + logger.info("get {}.{} from ray internal key value store, value: {}".format(actor_name, data_name, value)) + return value + +def put_actor_data_to_ray_internal_kv(actor_name: str, data_name: str, value: Any): + if _internal_kv_initialized(): + _internal_kv_put(_make_key(actor_name, data_name), f"{value}".encode(), overwrite=True) + logger.debug("put {}.{} to ray internal key value store, value: {}".format(actor_name, data_name, value)) diff --git a/requirements/requirements_bladellm.txt b/requirements/requirements_bladellm.txt index 3c66e6c7..ecaa9301 100644 --- a/requirements/requirements_bladellm.txt +++ b/requirements/requirements_bladellm.txt @@ -1,4 +1,4 @@ -ray >= 2.9.0 +ray[default] >= 2.9.0 pyarrow # Required for Ray data. aiohttp pandas diff --git a/requirements/requirements_vllm.txt b/requirements/requirements_vllm.txt index f9fbe6a6..8af54fbd 100644 --- a/requirements/requirements_vllm.txt +++ b/requirements/requirements_vllm.txt @@ -1,5 +1,5 @@ vllm == 0.4.2 -ray >= 2.9.0 +ray[default] >= 2.9.0 pyarrow # Required for Ray data. aiohttp scipy diff --git a/tests/e2e_test/test_bench.py b/tests/e2e_test/test_bench.py index 0719a524..5567db1a 100644 --- a/tests/e2e_test/test_bench.py +++ b/tests/e2e_test/test_bench.py @@ -23,7 +23,8 @@ # pylint: disable=unused-import from tests.conftest import ray_env from .utils import (generate_launch_command, generate_bench_command, to_markdown_table, - wait_for_llumnix_service_ready, shutdown_llumnix_service) + wait_for_llumnix_service_ready, shutdown_llumnix_service, + generate_serve_command) BENCH_TEST_TIMEOUT_MINS = 30 @@ -63,21 +64,36 @@ def get_markdown_data(key: str, head_name: str): @pytest.mark.asyncio @pytest.mark.skipif(torch.cuda.device_count() < 1, reason="at least 1 gpus required for simple benchmark") @pytest.mark.parametrize("model", ['/mnt/model/Qwen-7B']) -async def test_simple_benchmark(ray_env, shutdown_llumnix_service, model): - device_count = torch.cuda.device_count() +@pytest.mark.parametrize("launch_mode", ['global', 'local']) +async def test_simple_benchmark(ray_env, shutdown_llumnix_service, model, launch_mode): ip = "127.0.0.1" base_port = 37037 ip_ports = [] - for i in range(device_count): - port = base_port+i - ip_port = f"{ip}:{port}" - ip_ports.append(ip_port) - launch_command = generate_launch_command(result_filename=str(base_port+i)+".out", - launch_ray_cluster=False, - ip=ip, - port=port, - model=model) - subprocess.run(launch_command, shell=True, check=True) + if launch_mode == 'local': + device_count = torch.cuda.device_count() + for i in range(device_count): + port = base_port+i + ip_port = f"{ip}:{port}" + ip_ports.append(ip_port) + launch_command = generate_launch_command(result_filename=str(base_port+i)+".out", + launch_ray_cluster=False, + ip=ip, + port=port, + model=model) + subprocess.run(launch_command, shell=True, check=True) + else: # global + device_count = torch.cuda.device_count() + for i in range(device_count): + port = base_port+i + ip_port = f"{ip}:{port}" + ip_ports.append(ip_port) + serve_command = generate_serve_command(result_filename=str(base_port)+".out", + ip=ip, + port=base_port, + model=model) + # pylint: disable=subprocess-run-check + subprocess.run('ray start --head', shell=True, 
stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL) + subprocess.run(serve_command, shell=True, check=True) wait_for_llumnix_service_ready(ip_ports) @@ -113,7 +129,8 @@ def run_bench_command(command): process.kill() assert False, "bench_test timed out after {} minutes.".format(BENCH_TEST_TIMEOUT_MINS) - with open("performance.txt", "w", encoding="utf-8") as f: - f.write(parse_log_file()) + if launch_mode == 'local': + with open("performance.txt", "w", encoding="utf-8") as f: + f.write(parse_log_file()) await asyncio.sleep(3) diff --git a/tests/e2e_test/utils.py b/tests/e2e_test/utils.py index 7b454c2c..da71f32a 100644 --- a/tests/e2e_test/utils.py +++ b/tests/e2e_test/utils.py @@ -56,6 +56,41 @@ def generate_launch_command(result_filename: str = "", ) return command +def generate_serve_command(result_filename: str = "", + ip: str = "127.0.0.1", + port: int = 37000, + dispatch_policy: str = "load", + migration_backend = "gloo", + model = "facebook/opt-125m", + max_model_len: int = 4096, + log_instance_info: bool = False, + request_migration_policy: str = 'SR', + max_num_batched_tokens: int = 16000): + command = ( + f"RAY_DEDUP_LOGS=0 " + f"nohup python -u -m llumnix.entrypoints.vllm.serve " + f"--host {ip} " + f"--port {port} " + f"{'--log-filename manager ' if log_instance_info else ''}" + f"{'--log-instance-info ' if log_instance_info else ''}" + f"--enable-migration " + f"--model {model} " + f"--engine-use-ray " + f"--worker-use-ray " + f"--max-model-len {max_model_len} " + f"--dispatch-policy {dispatch_policy} " + f"--trust-remote-code " + f"--request-migration-policy {request_migration_policy} " + f"--migration-backend {migration_backend} " + f"--migration-buffer-blocks 32 " + f"--tensor-parallel-size 1 " + f"--request-output-queue-port {1234+port} " + f"--max-num-batched-tokens {max_num_batched_tokens} " + f"--enable-port-increment " + f"{'> instance_'+result_filename if len(result_filename)> 0 else ''} 2>&1 &" + ) + return command + def wait_for_llumnix_service_ready(ip_ports, timeout=120): start_time = time.time() while True: @@ -112,6 +147,7 @@ def generate_bench_command(ip_ports: str, def shutdown_llumnix_service_func(): subprocess.run('pkill -f llumnix.entrypoints.vllm.api_server', shell=True, check=False) subprocess.run('pkill -f benchmark_serving.py', shell=True, check=False) + subprocess.run('pkill -f llumnix.entrypoints.vllm.serve', shell=True, check=False) @pytest.fixture def shutdown_llumnix_service(): diff --git a/tests/unit_test/backends/vllm/test_llm_engine.py b/tests/unit_test/backends/vllm/test_llm_engine.py index 9b01c8af..86d5ff61 100644 --- a/tests/unit_test/backends/vllm/test_llm_engine.py +++ b/tests/unit_test/backends/vllm/test_llm_engine.py @@ -28,7 +28,7 @@ from llumnix.backends.vllm.sequence import LlumnixRequest from llumnix.queue.queue_type import QueueType from llumnix.server_info import ServerInfo -from llumnix.backends.utils import initialize_placement_group +from llumnix.utils import initialize_placement_group, get_placement_group_name from tests.conftest import ray_env from .utils import create_dummy_prompt, initialize_scheduler @@ -88,21 +88,26 @@ def test_llm_engine_process_model_outputs(): ret, _ = llm_engine._process_model_outputs(sampler_outputs, scheduled_seq_groups,[], metas) assert len(ret) == 1 -def test_llm_engine_from_engine_args(): +def test_llm_engine_from_engine_args(ray_env): engine_args = EngineArgs(model="facebook/opt-125m", worker_use_ray=True) - llm_engine = MockEngine.from_engine_args(engine_args, 
request_output_queue_type=QueueType.RAYQUEUE, - instance_id="0", migration_config=None) + placement_group = initialize_placement_group(get_placement_group_name("0"), num_cpus=3, num_gpus=1, detached=True) + llm_engine = MockEngine.from_engine_args(engine_args=engine_args, request_output_queue_type=QueueType.RAYQUEUE, + instance_id="0", migration_config=None, placement_group=placement_group) assert llm_engine.executor_class == LlumnixRayGPUExecutor +def test_llm_engine_from_engine_args_sim(ray_env): latency_data = LatencyMemData({},{},{}) - llm_engine = MockEngine.from_engine_args(engine_args, request_output_queue_type=QueueType.RAYQUEUE, - instance_id="0", migration_config=None, latency_mem=latency_data) + engine_args = EngineArgs(model="facebook/opt-125m", worker_use_ray=True) + placement_group = initialize_placement_group(get_placement_group_name("0"), num_cpus=2, num_gpus=1, detached=True) + llm_engine = MockEngine.from_engine_args(engine_args=engine_args, request_output_queue_type=QueueType.RAYQUEUE, + instance_id="0", migration_config=None, latency_mem=latency_data, + placement_group=placement_group) assert llm_engine.executor_class == SimGPUExecutor def test_llm_engine_add_requset(ray_env): engine_args = EngineArgs(model="facebook/opt-125m", worker_use_ray=True) - placement_group = initialize_placement_group(instance_id="0", num_cpus=3, num_gpus=1, detached=True) - llm_engine = LLMEngineLlumnix.from_engine_args(engine_args, + placement_group = initialize_placement_group(get_placement_group_name("0"), num_cpus=3, num_gpus=1, detached=True) + llm_engine = LLMEngineLlumnix.from_engine_args(engine_args=engine_args, request_output_queue_type=QueueType.RAYQUEUE, instance_id="0", placement_group=placement_group, diff --git a/tests/unit_test/backends/vllm/test_migration.py b/tests/unit_test/backends/vllm/test_migration.py index c0c808cb..4f28d753 100644 --- a/tests/unit_test/backends/vllm/test_migration.py +++ b/tests/unit_test/backends/vllm/test_migration.py @@ -16,7 +16,6 @@ from unittest.mock import MagicMock import pytest import ray -from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy from vllm import EngineArgs, SamplingParams from vllm.utils import random_uuid @@ -28,7 +27,7 @@ from llumnix.internal_config import MigrationConfig from llumnix.llumlet.request import RequestInferenceType, RequestStatus from llumnix.queue.queue_type import QueueType -from llumnix.backends.utils import initialize_placement_group +from llumnix.utils import initialize_placement_group, get_placement_group_name from tests.unit_test.queue.utils import request_output_queue_server # pylint: disable=unused-import @@ -46,15 +45,14 @@ ] def init_llumlet(request_output_queue_type, instance_id, migration_config, engine_args): - placement_group = initialize_placement_group(instance_id=instance_id, num_cpus=3, num_gpus=1, detached=True) + placement_group = initialize_placement_group(get_placement_group_name(instance_id), num_cpus=3, num_gpus=1, detached=True) llumlet = Llumlet.from_args( - request_output_queue_type, - instance_id, - BackendType.VLLM, - 1, - migration_config, - placement_group, - engine_args,) + instance_id=instance_id, + placement_group=placement_group, + request_output_queue_type=request_output_queue_type, + migration_config=migration_config, + backend_type=BackendType.VLLM, + engine_args=engine_args) return llumlet class MockBackendVLLM(BackendVLLM): @@ -70,7 +68,7 @@ def __init__(self): class MockLlumletDoNotSchedule(Llumlet): def __init__(self, *args, **kwargs): instance_id = 
kwargs["instance_id"] - placement_group = initialize_placement_group(instance_id=instance_id, num_cpus=3, num_gpus=1, detached=True) + placement_group = initialize_placement_group(get_placement_group_name(instance_id), num_cpus=3, num_gpus=1, detached=True) kwargs["placement_group"] = placement_group super().__init__(*args, **kwargs) # stop the schedule in engine step loop @@ -114,15 +112,13 @@ async def test_migration_correctness(ray_env, migration_backend, migration_reque request_output_queue_type = QueueType.RAYQUEUE que, server_info = request_output_queue_server(request_output_queue_type) asyncio.create_task(que.run_server_loop()) - scheduling_strategy = NodeAffinitySchedulingStrategy(node_id=ray.get_runtime_context().get_node_id(), soft=False) llumlet_0 = init_llumlet(request_output_queue_type, "0", migration_config, engine_args) llumlet_1 = init_llumlet(request_output_queue_type, "1", migration_config, engine_args) llumlet_2: Llumlet = MockLlumletDoNotSchedule.options( name='instance_2', - namespace='llumnix', - scheduling_strategy=scheduling_strategy).remote( + namespace='llumnix').remote( instance_id="2", request_output_queue_type=request_output_queue_type, backend_type=BackendType.VLLM, diff --git a/tests/unit_test/backends/vllm/test_migration_backend.py b/tests/unit_test/backends/vllm/test_migration_backend.py index 8c2cb4fa..f6b1d50d 100644 --- a/tests/unit_test/backends/vllm/test_migration_backend.py +++ b/tests/unit_test/backends/vllm/test_migration_backend.py @@ -19,9 +19,8 @@ from vllm.engine.arg_utils import EngineArgs from llumnix.backends.vllm.worker import MigrationWorker -from llumnix.arg_utils import EngineManagerArgs -from llumnix.utils import random_uuid -from llumnix.backends.utils import initialize_placement_group +from llumnix.arg_utils import ManagerArgs +from llumnix.utils import random_uuid, initialize_placement_group, get_placement_group_name # pylint: disable=unused-import from tests.conftest import ray_env @@ -41,7 +40,7 @@ def get_gpu_cache(self): @pytest.mark.parametrize("backend", ['rayrpc', 'gloo', 'nccl']) def test_migrate_cache(ray_env, backend): engine_config = EngineArgs(model='facebook/opt-125m', max_model_len=8, enforce_eager=True).create_engine_config() - migraiton_config = EngineManagerArgs(migration_buffer_blocks=3, migration_num_layers=5).create_migration_config() + migraiton_config = ManagerArgs(migration_buffer_blocks=3, migration_num_layers=5).create_migration_config() migraiton_config.migration_backend = backend worker0 = create_worker(rank=0, local_rank=0, engine_config=engine_config, @@ -59,7 +58,7 @@ def test_migrate_cache(ray_env, backend): ray.get(worker1.execute_method.remote('initialize_cache', num_gpu_blocks=num_gpu_blocks, num_cpu_blocks=0)) worker0_id = random_uuid() - placement_group0 = initialize_placement_group(instance_id=worker0_id, num_cpus=1, num_gpus=1, detached=True) + placement_group0 = initialize_placement_group(get_placement_group_name(worker0_id), num_cpus=1, num_gpus=1, detached=True) ray.get(worker0.execute_method.remote( 'init_migration', instance_id=worker0_id, @@ -68,7 +67,7 @@ def test_migrate_cache(ray_env, backend): placement_group=placement_group0)) worker1_id = random_uuid() - placement_group1 = initialize_placement_group(instance_id=worker1_id, num_cpus=1, num_gpus=1, detached=True) + placement_group1 = initialize_placement_group(get_placement_group_name(worker1_id), num_cpus=1, num_gpus=1, detached=True) ray.get(worker1.execute_method.remote( 'init_migration', instance_id=worker1_id, diff --git 
a/tests/unit_test/backends/vllm/test_simulator.py b/tests/unit_test/backends/vllm/test_simulator.py index 417be632..f71c5a95 100644 --- a/tests/unit_test/backends/vllm/test_simulator.py +++ b/tests/unit_test/backends/vllm/test_simulator.py @@ -12,7 +12,7 @@ from llumnix.backends.profiling import LatencyMemData from llumnix.internal_config import MigrationConfig from llumnix.queue.queue_type import QueueType -from llumnix.backends.utils import initialize_placement_group +from llumnix.utils import initialize_placement_group, get_placement_group_name # pylint: disable=unused-import from tests.conftest import ray_env @@ -82,11 +82,11 @@ class DummyActor: def __init__(self): pass dummy_actor = ray.remote(num_cpus=1, - name="instance_0", - namespace='llumnix', - max_concurrency=4)(DummyActor) + name="instance_0", + namespace='llumnix', + max_concurrency=4)(DummyActor) dummy_actor = dummy_actor.remote() - placement_group = initialize_placement_group("0", num_cpus=2, num_gpus=0, detached=True) + placement_group = initialize_placement_group(get_placement_group_name("0"), num_cpus=2, num_gpus=0, detached=True) sim_backend = MockBackendSim(instance_id="0", request_output_queue_type=request_output_queue_type, migration_config=migration_config, diff --git a/tests/unit_test/backends/vllm/test_worker.py b/tests/unit_test/backends/vllm/test_worker.py index fae20162..15e8e6d6 100644 --- a/tests/unit_test/backends/vllm/test_worker.py +++ b/tests/unit_test/backends/vllm/test_worker.py @@ -21,9 +21,9 @@ from vllm.config import EngineConfig from vllm.executor.ray_gpu_executor import RayWorkerWrapper -from llumnix.arg_utils import EngineManagerArgs +from llumnix.arg_utils import ManagerArgs from llumnix.utils import random_uuid -from llumnix.backends.utils import initialize_placement_group +from llumnix.utils import initialize_placement_group, get_placement_group_name # pylint: disable=unused-import from tests.conftest import ray_env @@ -60,7 +60,7 @@ def create_worker(rank: int, local_rank: int, engine_config: EngineConfig, @pytest.mark.parametrize("backend", ['rayrpc', 'gloo', 'nccl']) def test_reserve_memory_for_migration(ray_env, backend): engine_config = EngineArgs(model='facebook/opt-125m', max_model_len=8, enforce_eager=True).create_engine_config() - migration_config = EngineManagerArgs(migration_buffer_blocks=1).create_migration_config() + migration_config = ManagerArgs(migration_buffer_blocks=1).create_migration_config() migration_config.migration_backend = backend worker = create_worker(rank=0, local_rank=0, engine_config=engine_config) ray.get(worker.execute_method.remote('init_device')) @@ -81,12 +81,12 @@ def test_reserve_memory_for_migration(ray_env, backend): @pytest.mark.parametrize("backend", ['rayrpc', 'gloo', 'nccl']) def test_rebuild_migration_backend(ray_env, backend): engine_config = EngineArgs(model='facebook/opt-125m', max_model_len=8, enforce_eager=True).create_engine_config() - migration_config = EngineManagerArgs(migration_buffer_blocks=1).create_migration_config() + migration_config = ManagerArgs(migration_buffer_blocks=1).create_migration_config() migration_config.migration_backend = backend worker0 = create_worker(rank=0, local_rank=0, engine_config=engine_config) worker0_id = random_uuid() - placement_group0 = initialize_placement_group(instance_id=worker0_id, num_cpus=1, num_gpus=1, detached=True) + placement_group0 = initialize_placement_group(get_placement_group_name(worker0_id), num_cpus=1, num_gpus=1, detached=True) ray.get(worker0.execute_method.remote('init_device')) 
ray.get(worker0.execute_method.remote('initialize_cache', num_gpu_blocks=8, num_cpu_blocks=0)) ray.get(worker0.execute_method.remote( @@ -102,7 +102,7 @@ def test_rebuild_migration_backend(ray_env, backend): worker1 = create_worker(rank=0, local_rank=0, engine_config=engine_config) worker1_id = random_uuid() - placement_group1 = initialize_placement_group(instance_id=worker1_id, num_cpus=1, num_gpus=1, detached=True) + placement_group1 = initialize_placement_group(get_placement_group_name(worker1_id), num_cpus=1, num_gpus=1, detached=True) ray.get(worker1.execute_method.remote('init_device')) ray.get(worker1.execute_method.remote('initialize_cache', num_gpu_blocks=8, num_cpu_blocks=0)) ray.get(worker1.execute_method.remote( diff --git a/tests/unit_test/entrypoints/test_utils.py b/tests/unit_test/entrypoints/test_utils.py index 9705cb57..413906ee 100644 --- a/tests/unit_test/entrypoints/test_utils.py +++ b/tests/unit_test/entrypoints/test_utils.py @@ -15,14 +15,11 @@ import pytest import ray -from llumnix.arg_utils import EngineManagerArgs -from llumnix.entrypoints.setup import (get_ip_address, - launch_ray_cluster, - init_manager, - retry_manager_method_sync, - retry_manager_method_async) -from llumnix.llm_engine_manager import MANAGER_ACTOR_NAME +from llumnix.arg_utils import ManagerArgs +from llumnix.entrypoints.setup import launch_ray_cluster, init_manager +from llumnix.entrypoints.utils import get_ip_address, retry_manager_method_sync, retry_manager_method_async from llumnix.queue.utils import init_request_output_queue_server +from llumnix.utils import get_manager_name # pylint: disable=unused-import from tests.conftest import ray_env @@ -36,10 +33,10 @@ def test_launch_ray_cluster(): assert result.returncode == 0 def test_init_manager(ray_env): - engine_manager_args = EngineManagerArgs() - manager = init_manager(engine_manager_args) + manager_args = ManagerArgs() + manager = init_manager(manager_args) assert manager is not None - manager_actor_handle = ray.get_actor(MANAGER_ACTOR_NAME, namespace='llumnix') + manager_actor_handle = ray.get_actor(get_manager_name(), namespace='llumnix') assert manager_actor_handle is not None assert manager == manager_actor_handle @@ -50,14 +47,14 @@ def test_init_zmq(ray_env): assert request_output_queue is not None def test_retry_manager_method_sync(ray_env): - engine_manager_args = EngineManagerArgs() - manager = init_manager(engine_manager_args) + manager_args = ManagerArgs() + manager = init_manager(manager_args) ret = retry_manager_method_sync(manager.is_ready.remote, 'is_ready') assert ret is True @pytest.mark.asyncio async def test_retry_manager_method_async(ray_env): - engine_manager_args = EngineManagerArgs() - manager = init_manager(engine_manager_args) + manager_args = ManagerArgs() + manager = init_manager(manager_args) ret = await retry_manager_method_async(manager.is_ready.remote, 'is_ready') assert ret is True diff --git a/tests/unit_test/entrypoints/vllm/api.py b/tests/unit_test/entrypoints/vllm/api.py new file mode 100644 index 00000000..a72f4fd6 --- /dev/null +++ b/tests/unit_test/entrypoints/vllm/api.py @@ -0,0 +1,14 @@ +from fastapi.responses import JSONResponse, Response +import ray + +import llumnix.entrypoints.vllm.api_server + +manager = None +llumnix_client = llumnix.entrypoints.vllm.api_server.llumnix_client +app = llumnix.entrypoints.vllm.api_server.app + + +@app.get("/stats") +def stats() -> Response: + """Get the statistics of the engine.""" + return JSONResponse(ray.get(manager.testing_stats.remote())) diff --git 
a/tests/unit_test/entrypoints/vllm/api_server_manager.py b/tests/unit_test/entrypoints/vllm/api_server.py similarity index 58% rename from tests/unit_test/entrypoints/vllm/api_server_manager.py rename to tests/unit_test/entrypoints/vllm/api_server.py index bafbd599..78e6294a 100644 --- a/tests/unit_test/entrypoints/vllm/api_server_manager.py +++ b/tests/unit_test/entrypoints/vllm/api_server.py @@ -14,26 +14,21 @@ import argparse import uvicorn import ray -from fastapi.responses import JSONResponse, Response from vllm.outputs import CompletionOutput, RequestOutput import llumnix.entrypoints.vllm.api_server -import llumnix.llm_engine_manager -from llumnix.arg_utils import EngineManagerArgs +import llumnix.manager from llumnix.server_info import ServerInfo, RequestTimestamps -from llumnix.utils import random_uuid +from llumnix.utils import random_uuid, get_manager_name from llumnix.queue.utils import init_request_output_queue_server, init_request_output_queue_client, QueueType -from llumnix.entrypoints.setup import LlumnixEntrypointsContext +from llumnix.entrypoints.setup import EntrypointsContext from llumnix.entrypoints.vllm.client import LlumnixClientVLLM -app = llumnix.entrypoints.vllm.api_server.app -manager = None -MANAGER_ACTOR_NAME = llumnix.llm_engine_manager.MANAGER_ACTOR_NAME +import tests.unit_test.entrypoints.vllm.api -@ray.remote(num_cpus=0) -class MockLLMEngineManager: +class MockManager: def __init__(self, request_output_queue_type: QueueType): self._num_generates = 0 self._num_aborts = 0 @@ -52,16 +47,40 @@ async def abort(self, request_id): def testing_stats(self): return {"num_aborted_requests": self._num_aborts} + @classmethod + def from_args(cls, request_output_queue_type: QueueType): + manager_class = ray.remote(num_cpus=1, + name=get_manager_name(), + namespace='llumnix', + lifetime='detached')(cls) + manager = manager_class.remote(request_output_queue_type) + return manager + +def setup_entrypoints_context(request_output_queue_type: QueueType): + manager = ray.get_actor(get_manager_name(), namespace="llumnix") + tests.unit_test.entrypoints.vllm.api.manager = manager + ip = '127.0.0.1' + port = 1234 + request_output_queue = init_request_output_queue_server(ip, port, request_output_queue_type) + server_info = ServerInfo(random_uuid(), request_output_queue_type, request_output_queue, ip, port) + entrypoints_context = EntrypointsContext(manager, + {'0': None}, + request_output_queue, + server_info, + None, + None) + return entrypoints_context + +def run_uvicorn_server(host: str, port: int, entrypoints_context: EntrypointsContext): + llumnix.entrypoints.vllm.api_server.llumnix_client = LlumnixClientVLLM(entrypoints_context) + app = tests.unit_test.entrypoints.vllm.api.app -def init_manager(request_output_queue_type: QueueType): - manager = MockLLMEngineManager.options(name=MANAGER_ACTOR_NAME, - namespace='llumnix').remote(request_output_queue_type) - return manager - -@app.get("/stats") -def stats() -> Response: - """Get the statistics of the engine.""" - return JSONResponse(ray.get(manager.testing_stats.remote())) + uvicorn.run( + app, + host=host, + port=port, + log_level="debug", + timeout_keep_alive=llumnix.entrypoints.vllm.api_server.TIMEOUT_KEEP_ALIVE) if __name__ == "__main__": @@ -69,29 +88,10 @@ def stats() -> Response: parser.add_argument("--host", type=str, default="localhost") parser.add_argument("--port", type=int, default=8000) parser.add_argument("--request-output-queue-type", type=str, choices=["zmq", "rayqueue"]) - parser = 
EngineManagerArgs.add_cli_args(parser) - args = parser.parse_args() + entrypoints_args = parser.parse_args() - request_output_queue_type = QueueType(args.request_output_queue_type) - manager = init_manager(request_output_queue_type) - ip = '127.0.0.1' - port = 1234 - request_output_queue = init_request_output_queue_server(ip, port, request_output_queue_type) - ray_queue_server = None - if request_output_queue_type == QueueType.RAYQUEUE: - ray_queue_server = request_output_queue - server_info = ServerInfo(random_uuid(), request_output_queue_type, ray_queue_server, ip, port) - llumnix_context = LlumnixEntrypointsContext(manager, - {'0': None}, - request_output_queue, - server_info, - None, - None) - llumnix.entrypoints.vllm.api_server.llumnix_client = LlumnixClientVLLM(llumnix_context) + request_output_queue_type = QueueType(entrypoints_args.request_output_queue_type) + manager = MockManager.from_args(request_output_queue_type) + entrypoints_context = setup_entrypoints_context(request_output_queue_type) - uvicorn.run( - app, - host=args.host, - port=args.port, - log_level="debug", - timeout_keep_alive=llumnix.entrypoints.vllm.api_server.TIMEOUT_KEEP_ALIVE) + run_uvicorn_server(entrypoints_args.host, entrypoints_args.port, entrypoints_context) diff --git a/tests/unit_test/entrypoints/vllm/api_server_actor.py b/tests/unit_test/entrypoints/vllm/api_server_actor.py new file mode 100644 index 00000000..95ae5eef --- /dev/null +++ b/tests/unit_test/entrypoints/vllm/api_server_actor.py @@ -0,0 +1,91 @@ +# Copyright (c) 2024, Alibaba Group; +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import argparse +import time +import threading +import ray +from ray.util.queue import Queue as RayQueue + +from llumnix.queue.utils import init_request_output_queue_client, QueueType +from llumnix.utils import get_manager_name + +from tests.unit_test.entrypoints.vllm.api_server import (MockManager, setup_entrypoints_context, + run_uvicorn_server) + +ENTRYPOINTS_ACTOR_NAME = "entrypoints" + + +class MockManagerServer(MockManager): + def __init__(self, entrypoints_args): + self._num_generates = 0 + self._num_aborts = 0 + self.request_output_queue = init_request_output_queue_client( + QueueType(entrypoints_args.request_output_queue_type)) + self.server = self.init_server(entrypoints_args) + ray.get(self.server.run.remote()) + + def init_server(self, entrypoints_args): + server = FastAPIServerActor.options(name=ENTRYPOINTS_ACTOR_NAME, + namespace='llumnix').remote(entrypoints_args) + return server + + # pylint: disable=arguments-renamed + @classmethod + def from_args(cls, entrypoints_args): + manager_class = ray.remote(num_cpus=1, + name=get_manager_name(), + namespace='llumnix', + lifetime='detached')(cls) + manager = manager_class.remote(entrypoints_args) + return manager + + +@ray.remote(num_cpus=1, lifetime="detached") +class FastAPIServerActor: + def __init__(self, entrypoints_args): + self.host = entrypoints_args.host + self.port = entrypoints_args.port + self.request_output_queue_type = QueueType(entrypoints_args.request_output_queue_type) + + def _setup_entrypoints_context(self): + self.entrypoints_context = setup_entrypoints_context(self.request_output_queue_type) + + def _run_uvicorn_server(self): + run_uvicorn_server(self.host, self.port, self.entrypoints_context) + + def run(self): + self._setup_entrypoints_context() + self.run_uvicorn_server_thread = threading.Thread( + target=self._run_uvicorn_server, args=(), + daemon=True, name="run_uvicorn_server" + ) + self.run_uvicorn_server_thread.start() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--host", type=str, default="localhost") + parser.add_argument("--port", type=int, default=8000) + parser.add_argument("--request-output-queue-type", type=str, choices=["zmq", "rayqueue"]) + entrypoints_args = parser.parse_args() + + # magic actor, without this actor, FastAPIServer cannot initialize correctly. + # If this actor is placed globally, + # pylint will hangs if testing api_server_manager and api_server_service concurrently (--jobs > 1). 
+    request_output_queue = RayQueue()
+
+    manager = MockManagerServer.from_args(entrypoints_args)
+
+    while True:
+        time.sleep(100.0)
diff --git a/tests/unit_test/entrypoints/vllm/test_api_server.py b/tests/unit_test/entrypoints/vllm/test_api_server.py
index bf6689bf..103ea2ca 100644
--- a/tests/unit_test/entrypoints/vllm/test_api_server.py
+++ b/tests/unit_test/entrypoints/vllm/test_api_server.py
@@ -45,11 +45,16 @@ def _query_server_generate(prompt: str) -> dict:
 def _query_server_generate_benchmark(prompt: str) -> dict:
     return _query_server(prompt, interface='generate_benchmark')
 
-@pytest.fixture(params=["zmq", "rayqueue"])
+@pytest.fixture(params=[("zmq", "api_server"), ("rayqueue", "api_server"), ("zmq", "api_server_actor"), ("rayqueue", "api_server_actor")])
 def api_server(request):
-    request_output_queue_type = QueueType(request.param)
-    script_path = Path(__file__).parent.joinpath(
-        "api_server_manager.py").absolute()
+    request_output_queue_type = QueueType(request.param[0])
+    print(f"{request.param[0]}-{request.param[1]}")
+    if request.param[1] == "api_server":
+        script_path = Path(__file__).parent.joinpath(
+            "api_server.py").absolute()
+    else:
+        script_path = Path(__file__).parent.joinpath(
+            "api_server_actor.py").absolute()
     commands = [
         sys.executable, "-u",
diff --git a/tests/unit_test/global_scheduler/test_global_scheduler.py b/tests/unit_test/global_scheduler/test_global_scheduler.py
index 7079c96f..9a30b6c9 100644
--- a/tests/unit_test/global_scheduler/test_global_scheduler.py
+++ b/tests/unit_test/global_scheduler/test_global_scheduler.py
@@ -19,7 +19,7 @@ from llumnix.instance_info import InstanceInfo
 from llumnix.utils import random_uuid
 
-from .test_llm_engine_manager import get_instance_info_migrate_in, get_instance_info_migrate_out
+from .test_manager import get_instance_info_migrate_in, get_instance_info_migrate_out
 
 
 def init_global_scheduler():
diff --git a/tests/unit_test/global_scheduler/test_llm_engine_manager.py b/tests/unit_test/global_scheduler/test_manager.py
similarity index 52%
rename from tests/unit_test/global_scheduler/test_llm_engine_manager.py
rename to tests/unit_test/global_scheduler/test_manager.py
index 57de44ff..518424a2 100644
--- a/tests/unit_test/global_scheduler/test_llm_engine_manager.py
+++ b/tests/unit_test/global_scheduler/test_manager.py
@@ -19,9 +19,8 @@
 
 from vllm import EngineArgs
 
-from llumnix.utils import random_uuid, get_instance_name
-from llumnix.arg_utils import EngineManagerArgs
-from llumnix.llm_engine_manager import LLMEngineManager, MANAGER_ACTOR_NAME
+from llumnix.arg_utils import ManagerArgs, EntrypointsArgs, LaunchArgs
+from llumnix.manager import Manager
 from llumnix.instance_info import InstanceInfo
 from llumnix.server_info import ServerInfo
 from llumnix.queue.queue_type import QueueType
@@ -29,6 +28,10 @@
 from llumnix.backends.vllm.simulator import BackendSimVLLM
 from llumnix.backends.backend_interface import BackendType
 from llumnix.backends.profiling import LatencyMemData
+from llumnix.entrypoints.utils import LaunchMode
+from llumnix.utils import (get_placement_group_name, get_server_name, get_instance_name,
+                           remove_placement_group, INSTANCE_NAME_PREFIX, kill_server,
+                           kill_instance, random_uuid, get_manager_name)
 
 # pylint: disable=unused-import
 from tests.conftest import ray_env
@@ -105,26 +108,38 @@ def _get_lantecy_mem(self, *args, **kwargs):
 
 def init_manager():
     try:
-        engine_manager_args = EngineManagerArgs(migration_backend="rayrpc", enable_migration=True)
-        engine_manager_args.log_instance_info = False
-        manager = LLMEngineManager.from_args(engine_manager_args, None)
+        manager_args = ManagerArgs(migration_backend="rayrpc", enable_migration=True)
+        manager_args.log_instance_info = False
+        manager = Manager.from_args(manager_args=manager_args)
     except ValueError:
-        manager = ray.get_actor(MANAGER_ACTOR_NAME, namespace='llumnix')
+        manager = ray.get_actor(get_manager_name(), namespace='llumnix')
     ray.get(manager.is_ready.remote())
     return manager
 
-def init_llumlets(initial_instances):
+def init_manager_with_launch_mode(launch_mode, request_output_queue_type="rayqueue"):
+    manager_args = ManagerArgs(migration_backend="rayrpc", enable_port_increment=True)
+    entrypoints_args = EntrypointsArgs(host="127.0.0.1", port=8000, request_output_queue_type=request_output_queue_type)
+    engine_args = EngineArgs(model="facebook/opt-125m", worker_use_ray=True)
+    launch_args = LaunchArgs(launch_mode=launch_mode, backend_type=BackendType.VLLM)
+    manager = Manager.from_args(manager_args=manager_args,
+                                entrypoints_args=entrypoints_args,
+                                engine_args=engine_args,
+                                launch_args=launch_args)
+    ray.get(manager.is_ready.remote())
+    return manager, manager_args, entrypoints_args, engine_args, launch_args
+
+def init_instances(initial_instances):
     instance_ids = []
-    llumlets = []
+    instances = []
     for _ in range(initial_instances):
         instance_id = random_uuid()
         instance_name = get_instance_name(instance_id)
         llumlet = MockLlumlet.options(name=instance_name, namespace='llumnix').remote(instance_id)
         instance_ids.append(instance_id)
-        llumlets.append(llumlet)
-    ray.get([llumlet.is_ready.remote() for llumlet in llumlets])
-    return instance_ids, llumlets
+        instances.append(llumlet)
+    ray.get([instance.is_ready.remote() for instance in instances])
+    return instance_ids, instances
 
 @pytest.fixture
 def manager():
@@ -141,9 +156,23 @@ def llumlet():
     ray.get(llumlet.is_ready.remote())
     return llumlet
 
+def is_actor_exists(actor_name):
+    try:
+        ray.get_actor(actor_name, namespace='llumnix')
+        return True
+    except ValueError:
+        return False
+
+def is_placement_group_exists(pg_name):
+    try:
+        ray.util.get_placement_group(pg_name)
+        return True
+    except ValueError:
+        return False
+
 def test_init_manager(ray_env, manager):
     assert manager is not None
-    manager_actor_handle = ray.get_actor(MANAGER_ACTOR_NAME, namespace='llumnix')
+    manager_actor_handle = ray.get_actor(get_manager_name(), namespace='llumnix')
     assert manager_actor_handle is not None
     assert manager == manager_actor_handle
 
@@ -151,33 +180,33 @@ def test_init_llumlet(ray_env, llumlet):
     assert llumlet is not None
     ray.get(llumlet.is_ready.remote())
 
-def test_init_llumlets(ray_env, manager):
+def test_init_instances(ray_env, manager):
     engine_args = EngineArgs(model="facebook/opt-125m", worker_use_ray=True)
-    instance_ids, llumlets = ray.get(manager.init_llumlets.remote(engine_args, QueueType("rayqueue"), BackendType.VLLM, 1))
-    num_instances = ray.get(manager.scale_up.remote(instance_ids, llumlets))
-    engine_manager_args = EngineManagerArgs()
-    assert num_instances == engine_manager_args.initial_instances
+    _, instances = ray.get(manager.init_instances.remote(QueueType("rayqueue"), BackendType.VLLM, engine_args))
+    num_instances = len(instances)
+    manager_args = ManagerArgs()
+    assert num_instances == manager_args.initial_instances
 
-def test_init_llumlets_sim(ray_env, manager):
+def test_init_instances_sim(ray_env, manager):
     manager.profiling_result_file_path="//"
     # pylint: disable=import-outside-toplevel
     import llumnix.backends.vllm.simulator
     llumnix.backends.vllm.simulator.BackendSimVLLM = MockBackendSim
     engine_args = EngineArgs(model="facebook/opt-125m", worker_use_ray=True)
-    instance_ids, llumlets = ray.get(manager.init_llumlets.remote(engine_args, QueueType("rayqueue"), BackendType.VLLM, 1))
-    num_instances = ray.get(manager.scale_up.remote(instance_ids, llumlets))
-    engine_manager_args = EngineManagerArgs()
-    assert num_instances == engine_manager_args.initial_instances
+    _, instances = ray.get(manager.init_instances.remote(QueueType("rayqueue"), BackendType.SIM_VLLM, engine_args))
+    num_instances = len(instances)
+    manager_args = ManagerArgs()
+    assert num_instances == manager_args.initial_instances
 
 def test_scale_up_and_down(ray_env, manager):
     initial_instances = 4
-    instance_ids, llumlets = init_llumlets(initial_instances)
-    num_instances = ray.get(manager.scale_up.remote(instance_ids, llumlets))
+    instance_ids, instances = init_instances(initial_instances)
+    num_instances = ray.get(manager.scale_up.remote(instance_ids, instances))
     assert num_instances == initial_instances
-    instance_ids_1, llumlets_1 = init_llumlets(initial_instances)
+    instance_ids_1, instances_1 = init_instances(initial_instances)
     num_instances = ray.get(manager.scale_down.remote(instance_ids_1))
     assert num_instances == initial_instances
-    num_instances = ray.get(manager.scale_up.remote(instance_ids_1, llumlets_1))
+    num_instances = ray.get(manager.scale_up.remote(instance_ids_1, instances_1))
     assert num_instances == initial_instances * 2
     num_instances = ray.get(manager.scale_down.remote(instance_ids))
     assert num_instances == initial_instances
@@ -186,18 +215,18 @@ def test_scale_up_and_down(ray_env, manager):
 
 def test_connect_to_instances(ray_env):
     initial_instances = 4
-    instance_ids, llumlets = init_llumlets(initial_instances)
-    ray.get([llumlet.is_ready.remote() for llumlet in llumlets])
+    instance_ids, instances = init_instances(initial_instances)
+    ray.get([instance.is_ready.remote() for instance in instances])
     manager = init_manager()
-    instance_ids_1, llumlets_1 = init_llumlets(initial_instances)
-    num_instances = ray.get(manager.scale_up.remote(instance_ids_1, llumlets_1))
+    instance_ids_1, instances_1 = init_instances(initial_instances)
+    num_instances = ray.get(manager.scale_up.remote(instance_ids_1, instances_1))
     assert num_instances == initial_instances * 2
     num_instances = ray.get(manager.scale_down.remote(instance_ids))
     assert num_instances == initial_instances
 
 def test_generate_and_abort(ray_env, manager, llumlet):
     instance_id = ray.get(llumlet.get_instance_id.remote())
-    ray.get(manager.scale_up.remote(instance_id, [llumlet]))
+    ray.get(manager.scale_up.remote(instance_id, llumlet))
     request_id = random_uuid()
     num_requests = ray.get(llumlet.get_num_requests.remote())
     assert num_requests == 0
@@ -216,8 +245,8 @@
     assert num_requests == 0
 
 def test_get_request_instance(ray_env):
-    _, llumlets = init_llumlets(2)
-    llumlet, llumlet_1 = llumlets[0], llumlets[1]
+    _, instances = init_instances(2)
+    llumlet, llumlet_1 = instances[0], instances[1]
     manager = init_manager()
     request_id = random_uuid()
     request_id_1 = random_uuid()
@@ -252,37 +281,109 @@ def get_instance_info_migrate_out(instance_id):
     return instance_info
 
 def test_update_instance_info_loop_and_migrate(ray_env, manager):
-    num_llumlets = 5
-    instance_ids, llumlets = init_llumlets(num_llumlets)
+    num_instances = 5
+    instance_ids, instances = init_instances(num_instances)
 
-    for i in range(num_llumlets):
+    for i in range(num_instances):
         for _ in range(2*(i+1)):
-            ray.get(llumlets[i].generate.remote(random_uuid(), None, math.inf, None, None))
+            ray.get(instances[i].generate.remote(random_uuid(), None, math.inf, None, None))
 
     instance_info = InstanceInfo()
     instance_info.instance_type = InstanceType.NO_CONSTRAINTS
-    for i in range(num_llumlets):
+    for i in range(num_instances):
         instance_info.instance_id = instance_ids[i]
         instance_info.num_available_gpu_blocks = 40 - i * 10
         instance_info.num_running_requests = i
         instance_info.num_blocks_first_waiting_request = i
-        ray.get(llumlets[i].set_instance_info.remote(instance_info))
+        ray.get(instances[i].set_instance_info.remote(instance_info))
 
-    for i in range(num_llumlets):
-        num_migrate_out = ray.get(llumlets[i].get_num_migrate_out.remote())
+    for i in range(num_instances):
+        num_migrate_out = ray.get(instances[i].get_num_migrate_out.remote())
         assert num_migrate_out == 0
 
-    ray.get(manager.scale_up.remote(instance_ids, llumlets))
+    ray.get(manager.scale_up.remote(instance_ids, instances))
     time.sleep(2)
 
-    for i in range(num_llumlets):
-        num_migrate_out = ray.get(llumlets[i].get_num_migrate_out.remote())
-        num_migrate_in = ray.get(llumlets[i].get_num_migrate_in.remote())
+    for i in range(num_instances):
+        num_migrate_out = ray.get(instances[i].get_num_migrate_out.remote())
+        num_migrate_in = ray.get(instances[i].get_num_migrate_in.remote())
         if i == 0:
             assert num_migrate_in > 1 and num_migrate_out == 0
-        elif i == num_llumlets - 1:
+        elif i == num_instances - 1:
             assert num_migrate_in == 0 and num_migrate_out > 1
         else:
             assert num_migrate_in == 0 and num_migrate_out == 0
+
+def test_init_server_and_instance_and_clear_instance_ray_resources(ray_env):
+    manager, _, _, engine_args, _ = init_manager_with_launch_mode(LaunchMode.LOCAL)
+    instance_id = random_uuid()
+    pg = ray.get(manager._init_placement_group.remote(get_placement_group_name(instance_id),
+                                                      engine_args, BackendType.VLLM, init_server=True))
+    pg = ray.util.get_placement_group(get_placement_group_name(instance_id))
+    ray.get(pg.ready())
+    ray.get(manager._init_server_and_instance.remote(instance_id, pg))
+    # wait for scale up
+    time.sleep(5.0)
+    server = ray.get_actor(get_server_name(instance_id), namespace="llumnix")
+    ray.get(server.is_ready.remote())
+    instance = ray.get_actor(get_instance_name(instance_id), namespace="llumnix")
+    ray.get(instance.is_ready.remote())
+    num_instances = ray.get(manager.scale_up.remote(instance_id, instance))
+    assert num_instances == 1
+
+    # test clear_instance_ray_resources
+    ray.get(manager._clear_instance_ray_resources.remote(instance_id))
+    # wait for remove and kill
+    time.sleep(1.0)
+    pg_exists = is_placement_group_exists(get_placement_group_name(instance_id))
+    assert not pg_exists
+    server_exists = is_actor_exists(get_server_name(instance_id))
+    assert not server_exists
+    instance_exists = is_actor_exists(get_instance_name(instance_id))
+    assert not instance_exists
+
+@pytest.mark.parametrize("request_output_queue_type", ['rayqueue', 'zmq'])
+def test_auto_scale_up_loop_and_get_curr_deployment(ray_env, request_output_queue_type):
+    manager, _, _, _, _ = init_manager_with_launch_mode(LaunchMode.GLOBAL, request_output_queue_type)
+    time.sleep(30.0)
+    num_instances = ray.get(manager.scale_up.remote([], []))
+    assert num_instances == 4
+    curr_pgs, curr_servers, curr_instances = ray.get(manager.get_curr_deployment.remote())
+    assert len(curr_pgs) == 4 and len(curr_servers) == 4 and len(curr_instances) == 4
+
+    actor_names_dict = ray.util.list_named_actors(all_namespaces=True)
+    instance_ids = [actor_name_dict['name'].split("_")[-1] for actor_name_dict in actor_names_dict
+                    if actor_name_dict['name'].startswith(INSTANCE_NAME_PREFIX)]
+    assert len(instance_ids) == 4
+    ray.get(manager._clear_instance_ray_resources.remote(instance_ids[0]))
+    ray.get(manager._clear_instance_ray_resources.remote(instance_ids[1]))
+    time.sleep(30.0)
+    num_instances = ray.get(manager.scale_up.remote([], []))
+    assert num_instances == 4
+    curr_pgs, curr_servers, curr_instances = ray.get(manager.get_curr_deployment.remote())
+    assert len(curr_pgs) == 4 and len(curr_servers) == 4 and len(curr_instances) == 4
+
+@pytest.mark.parametrize("request_output_queue_type", ['rayqueue', 'zmq'])
+def test_check_deployment_states_loop_and_auto_scale_up_loop(ray_env, request_output_queue_type):
+    manager, _, _, _, _ = init_manager_with_launch_mode(LaunchMode.GLOBAL, request_output_queue_type)
+    time.sleep(30.0)
+    num_instances = ray.get(manager.scale_up.remote([], []))
+    assert num_instances == 4
+    curr_pgs, curr_servers, curr_instances = ray.get(manager.get_curr_deployment.remote())
+    assert len(curr_pgs) == 4 and len(curr_servers) == 4 and len(curr_instances) == 4
+
+    actor_names_dict = ray.util.list_named_actors(all_namespaces=True)
+    instance_ids = [actor_name_dict['name'].split("_")[-1] for actor_name_dict in actor_names_dict
+                    if actor_name_dict['name'].startswith(INSTANCE_NAME_PREFIX)]
+    assert len(instance_ids) == 4
+    remove_placement_group(instance_ids[0])
+    kill_server(instance_ids[1])
+    kill_instance(instance_ids[2])
+    # Wait for check deployment states, scale down instance and auto scale up.
+    time.sleep(120.0)
+    num_instances = ray.get(manager.scale_up.remote([], []))
+    assert num_instances == 4
+    curr_pgs, curr_servers, curr_instances = ray.get(manager.get_curr_deployment.remote())
+    assert len(curr_pgs) == 4 and len(curr_servers) == 4 and len(curr_instances) == 4
diff --git a/tests/unit_test/llumlet/test_engine_step_exception.py b/tests/unit_test/llumlet/test_engine_step_exception.py
index b5ea1749..a15dc52f 100644
--- a/tests/unit_test/llumlet/test_engine_step_exception.py
+++ b/tests/unit_test/llumlet/test_engine_step_exception.py
@@ -17,15 +17,13 @@
 import torch
 import pytest
 
-from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy
-
 from vllm.engine.arg_utils import EngineArgs
 
 from llumnix.backends.backend_interface import BackendType
 from llumnix.llumlet.llumlet import Llumlet
 from llumnix.internal_config import MigrationConfig
 from llumnix.queue.queue_type import QueueType
-from llumnix.backends.utils import initialize_placement_group
+from llumnix.utils import initialize_placement_group, get_placement_group_name
 
 # pylint: disable=unused-import
 from tests.conftest import ray_env
@@ -34,7 +32,7 @@ class MockLlumlet(Llumlet):
     def __init__(self, *args, **kwargs) -> None:
         instance_id = kwargs["instance_id"]
-        placement_group = initialize_placement_group(instance_id=instance_id, num_cpus=3, num_gpus=1, detached=True)
+        placement_group = initialize_placement_group(get_placement_group_name(instance_id), num_cpus=3, num_gpus=1, detached=True)
         kwargs["placement_group"] = placement_group
         super().__init__(*args, **kwargs)
         self.origin_step = self.backend_engine.engine.step_async
@@ -57,13 +55,17 @@ async def raise_error_step():
 def test_engine_step_exception(ray_env):
     engine_args = EngineArgs(model="facebook/opt-125m", max_model_len=8, worker_use_ray=True)
     migration_config = MigrationConfig("SR", "rayrpc", 16, 1, 4, 5, 20)
-    scheduling_strategy = NodeAffinitySchedulingStrategy(node_id=ray.get_runtime_context().get_node_id(), soft=False)
-    origin_free_memory, _ = torch.cuda.mem_get_info()
+    time.sleep(5.0)
+
+    device_count = torch.cuda.device_count()
+    origin_free_memory_list = []
+    for device_id in range(device_count):
+        origin_free_memory, _ = torch.cuda.mem_get_info(device_id)
+        origin_free_memory_list.append(origin_free_memory)
 
     actor_name = "instance_0"
-    llumlet = MockLlumlet.options(name=actor_name, namespace='llumnix',
-                                  scheduling_strategy=scheduling_strategy).remote(
+    llumlet = MockLlumlet.options(name=actor_name, namespace='llumnix').remote(
         instance_id="0",
         request_output_queue_type=QueueType.RAYQUEUE,
         backend_type=BackendType.VLLM,
@@ -76,8 +78,11 @@
     all_actor_names = [actor["name"] for actor in all_actors]
     assert actor_name in all_actor_names
 
-    cur_free_memory, _ = torch.cuda.mem_get_info()
-    assert cur_free_memory < origin_free_memory
+    cur_free_memory_list = []
+    for device_id in range(device_count):
+        cur_free_memory, _ = torch.cuda.mem_get_info(device_id)
+        cur_free_memory_list.append(cur_free_memory)
+    assert origin_free_memory_list != cur_free_memory_list
 
     ray.get(llumlet.set_error_step.remote(True))
     time.sleep(3)
@@ -86,5 +91,8 @@
     all_actor_names = [actor["name"] for actor in all_actors]
     assert actor_name not in all_actor_names
 
-    cur_free_memory, _ = torch.cuda.mem_get_info()
-    assert origin_free_memory == cur_free_memory
+    cur_free_memory_list = []
+    for device_id in range(device_count):
+        cur_free_memory, _ = torch.cuda.mem_get_info(device_id)
+        cur_free_memory_list.append(cur_free_memory)
+    assert origin_free_memory_list == cur_free_memory_list