diff --git a/configs/base.yml b/configs/base.yml
index d91a5135..afce7127 100644
--- a/configs/base.yml
+++ b/configs/base.yml
@@ -1,6 +1,6 @@
 SERVER:
   HOST: '127.0.0.1'
-  PORT: 37000
+  PORT: 1234
   QUEUE_TYPE: "rayqueue"
 
 RAY:
diff --git a/examples/offline_inference.py b/examples/offline_inference.py
index 96d86b0e..dabb4f94 100644
--- a/examples/offline_inference.py
+++ b/examples/offline_inference.py
@@ -27,7 +27,7 @@
 ray_cluster_port=6379
 
 # Note: launch_ray_cluster will stop current ray cluster first, then init a new one.
-launch_ray_cluster(ray_cluster_port=ray_cluster_port)
+launch_ray_cluster(port=ray_cluster_port)
 connect_to_ray_cluster(port=ray_cluster_port)
 
 # Set manager args and engine args.
diff --git a/llumnix/__init__.py b/llumnix/__init__.py
index 71de719c..a6892514 100644
--- a/llumnix/__init__.py
+++ b/llumnix/__init__.py
@@ -15,8 +15,10 @@
 from vllm import *
 
 from llumnix.server_info import ServerInfo
-from llumnix.entrypoints.llumnix_utils import (launch_ray_cluster, connect_to_ray_cluster,
-                                               init_manager, init_llumlets)
+from llumnix.entrypoints.utils import (launch_ray_cluster,
+                                       connect_to_ray_cluster,
+                                       init_manager,
+                                       init_llumlets)
 from llumnix.arg_utils import EngineManagerArgs
 from llumnix.llm_engine_manager import LLMEngineManager
 from llumnix.llumlet.llumlet import Llumlet
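Note that the two renames above travel together: launch_ray_cluster now takes port instead of ray_cluster_port, and the helpers are re-exported from the renamed llumnix.entrypoints.utils module. A minimal sketch of the updated call sites, mirroring examples/offline_inference.py:

# Sketch mirroring examples/offline_inference.py after the parameter rename.
from llumnix import launch_ray_cluster, connect_to_ray_cluster

ray_cluster_port = 6379
# launch_ray_cluster stops the current Ray cluster first, then inits a new one.
launch_ray_cluster(port=ray_cluster_port)
connect_to_ray_cluster(port=ray_cluster_port)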
diff --git a/llumnix/arg_utils.py b/llumnix/arg_utils.py
index 404b3477..70a643cf 100644
--- a/llumnix/arg_utils.py
+++ b/llumnix/arg_utils.py
@@ -22,6 +22,84 @@
 from llumnix.config import LlumnixConfig, get_llumnix_config
 from llumnix.config.default import _C
 
+
+class LlumnixArgumentParser(argparse.ArgumentParser):
+    def __init__(self, *args, **kwargs):
+        self.cur_namespace = "llumnix"
+        super().__init__(*args, **kwargs)
+
+    def set_namespace(self, namespace: str):
+        self.cur_namespace = namespace
+
+    def add_argument(self, *args, **kwargs):
+        if self.cur_namespace == 'llumnix' and "--help" not in args:
+            assert 'default' not in kwargs or kwargs['default'] is None, \
+                f"Do not set the default value for '{args[0]}' in CLI, or set the default value to None. " \
+                f"The default value will be retrieved from config/default.py in get_llumnix_config."
+            if kwargs.get('action') == 'store_true':
+                kwargs['default'] = None
+        super().add_argument(*args, **kwargs)
+
+
+# All the default values of Llumnix arguments are set in default.py, so all the arguments here default to None.
+
+@dataclass
+class LlumnixEntrypointsArgs:
+    launch_ray_cluster: bool = None
+    ray_cluster_port: int = None
+    queue_type: str = None
+    request_output_queue_port: int = None
+    disable_log_requests_server: bool = None
+    log_request_timestamps: bool = None
+    config_file: str = None
+
+    def __post_init__(self):
+        for attr in dataclasses.fields(self):
+            if getattr(self, attr.name) is None:
+                setattr(self, attr.name, getattr(_C.SERVER, attr.name.upper()))
+
+    @classmethod
+    def from_llumnix_config(cls, cfg: LlumnixConfig = get_llumnix_config()) -> 'LlumnixEntrypointsArgs':
+        # Get the list of attributes of this dataclass.
+        attrs = [attr.name for attr in dataclasses.fields(cls)]
+        # Set the attributes from the parsed arguments.
+        # The default values of attributes are defined in default.py.
+        llumnix_entrypoints_args = cls(**{attr: getattr(cfg.SERVER, attr.upper()) for attr in attrs})
+        return llumnix_entrypoints_args
+
+    @classmethod
+    def check_args(cls, args: 'LlumnixEntrypointsArgs', parser: argparse.ArgumentParser):
+        # pylint: disable=protected-access
+        for action in parser._optionals._actions:
+            if hasattr(action, 'choices') and action.choices is not None and hasattr(args, action.dest):
+                assert getattr(args, action.dest) in action.choices, f"{action.dest} should be one of {action.choices}."
+
+    @staticmethod
+    def add_cli_args(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
+        parser.add_argument('--launch-ray-cluster',
+                            action='store_true',
+                            help='whether to launch a Ray cluster in the API server')
+        parser.add_argument("--ray-cluster-port",
+                            type=int,
+                            help='Ray cluster port')
+        parser.add_argument("--queue-type",
+                            type=str,
+                            choices=['rayqueue', 'zmq'],
+                            help='queue type for the request output queue')
+        parser.add_argument("--request-output-queue-port",
+                            type=int,
+                            help='port for the zmq request output queue')
+        parser.add_argument('--disable-log-requests-server',
+                            action='store_true',
+                            help='disable logging requests in the server')
+        parser.add_argument("--log-request-timestamps",
+                            action='store_true',
+                            help='whether to log request timestamps')
+        parser.add_argument("--config-file",
+                            type=str,
+                            help="path to the config file")
+        return parser
+
 
 @dataclass
 class EngineManagerArgs:
     disable_init_instance_by_manager: bool = None
@@ -106,6 +184,7 @@ def from_llumnix_config(cls, cfg: LlumnixConfig = get_llumnix_config()) -> 'Engi
         # Get the list of attributes of this dataclass.
         attrs = [attr.name for attr in dataclasses.fields(cls)]
         # Set the attributes from the parsed arguments.
+        # The default values of attributes are defined in default.py.
         engine_manager_args = cls(**{attr: getattr(cfg.MANAGER, attr.upper()) for attr in attrs})
         return engine_manager_args
 
diff --git a/llumnix/backends/vllm/llm_engine.py b/llumnix/backends/vllm/llm_engine.py
index 3e25b393..34a93524 100644
--- a/llumnix/backends/vllm/llm_engine.py
+++ b/llumnix/backends/vllm/llm_engine.py
@@ -39,7 +39,7 @@
 from llumnix.server_info import ServerInfo
 from llumnix.internal_config import MigrationConfig
 from llumnix.queue.queue_client_base import QueueClientBase
-from llumnix.queue.utils import get_output_queue_client, QueueType
+from llumnix.queue.utils import init_output_queue_client, QueueType
 
 logger = init_logger(__name__)
 
@@ -48,7 +48,7 @@ class AsyncPutQueueActor:
     def __init__(self, instance_id, output_queue_type: QueueType):
         self.instance_id = instance_id
         self.output_queue_type = output_queue_type
-        self.request_output_queue_client: QueueClientBase = get_output_queue_client(output_queue_type)
+        self.request_output_queue_client: QueueClientBase = init_output_queue_client(output_queue_type)
         self.engine_actor_handle = None
 
     async def put_nowait_to_servers(self,
@@ -225,7 +225,7 @@ def step(self) -> None:
             tot_blocks = set(tot_blocks)
             instance_info.num_blocks_last_running_request = len(tot_blocks)
         if request_outputs:
-            self.put_queue_args_queue.put((request_outputs, server_infos))
+            self.put_queue_args_queue.put_nowait((request_outputs, server_infos))
         self.instance_info = instance_info
         for request_output in request_outputs:
             if hasattr(request_output, 'request_timestamps'):
diff --git a/llumnix/backends/vllm/sequence.py b/llumnix/backends/vllm/sequence.py
index 4df053f1..3c41a5c6 100644
--- a/llumnix/backends/vllm/sequence.py
+++ b/llumnix/backends/vllm/sequence.py
@@ -25,6 +25,9 @@ def __init__(self, request_id, server_info, expected_steps: int, *args, **kwargs
     def prompt_len(self) -> int:
         return self.get_seqs()[0].get_prompt_len()
 
+    def is_finished(self) -> bool:
+        return self.get_seqs()[0].is_finished()
+
     @property
     def request_len(self) -> int:
         return self.get_seqs()[0].get_len()
diff --git a/llumnix/config/default.py b/llumnix/config/default.py
index 3ff053cc..17849463 100644
--- a/llumnix/config/default.py
+++ b/llumnix/config/default.py
@@ -21,34 +21,35 @@
 _C = LC()
 
 # -----------------------------------------------------------------------------
-# API SERVER CONFIGURATION
+# SERVER CONFIGURATION
 # -----------------------------------------------------------------------------
 _C.SERVER = LC()
 # Hostname for the server
 _C.SERVER.HOST = "localhost"
 # Port number for the server
 _C.SERVER.PORT = 8000
-# Queue type for request output queue
-_C.SERVER.QUEUE_TYPE = "rayqueue"
-# Port number for the request output queue
-_C.SERVER.REQUEST_OUTPUT_QUEUE_PORT = 1234
 # Path to SSL key file for secure connections
 _C.SERVER.SSL_KEYFILE = None
 # Path to SSL certificate file for secure connections
 _C.SERVER.SSL_CERTFILE = None
+# Queue type for request output queue
+_C.SERVER.QUEUE_TYPE = "rayqueue"
+# Port number for the request output queue
+_C.SERVER.REQUEST_OUTPUT_QUEUE_PORT = 1234
 # Disable logging requests in server
 _C.SERVER.DISABLE_LOG_REQUESTS_SERVER = False
 # Enable logging request timestamp
 _C.SERVER.LOG_REQUEST_TIMESTAMPS = False
+# Config file of Llumnix arguments
+_C.SERVER.CONFIG_FILE = None
 
 # -----------------------------------------------------------------------------
 # RAY CONFIGURATION
 # -----------------------------------------------------------------------------
-_C.RAY = LC()
-# Port number for the Ray cluster
-_C.RAY.RAY_CLUSTER_PORT = 6379
 # If True, launch Ray cluster in API server
-_C.RAY.LAUNCH_RAY_CLUSTER = False
+_C.SERVER.LAUNCH_RAY_CLUSTER = False
+# Port number for the Ray cluster
+_C.SERVER.RAY_CLUSTER_PORT = 6379
 
 # -----------------------------------------------------------------------------
 # MANAGER CONFIGURATION
diff --git a/llumnix/entrypoints/llumnix_utils.py b/llumnix/entrypoints/utils.py
similarity index 57%
rename from llumnix/entrypoints/llumnix_utils.py
rename to llumnix/entrypoints/utils.py
index 99b1e791..7fea885b 100644
--- a/llumnix/entrypoints/llumnix_utils.py
+++ b/llumnix/entrypoints/utils.py
@@ -15,7 +15,7 @@
 import sys
 import os
 import time
-from typing import List, Tuple
+from typing import List, Tuple, Dict
 import asyncio
 import socket
 import ray
@@ -27,6 +27,8 @@
 from llumnix.utils import random_uuid
 from llumnix.arg_utils import EngineManagerArgs
 from llumnix.queue.queue_type import QueueType
+from llumnix.server_info import ServerInfo, RequestTimestamps
+from llumnix.queue.utils import init_output_queue_server
 
 logger = init_logger(__name__)
 
@@ -36,12 +38,26 @@
 MAX_TASK_RETRIES = 300
 RETRIES_INTERVALS = 0.1
 
+
+class LlumnixEntrypointsContext:
+    def __init__(self):
+        self.engine_manager: LLMEngineManager = None
+        self.instances: Dict[str, Llumlet] = {}
+        self.request_output_queue: QueueServerBase = None
+        self.server_info: ServerInfo = None
+        self.request_streams: Dict[str, AsyncStream] = {}
+        self.manager_available = True
+        self.num_finished_requests = 0
+        self.instance_num_requests: Dict[str, int] = {}
+        self.log_requests: bool = None
+        self.log_request_timestamps: bool = None
+
 def get_ip_address():
     hostname = socket.gethostname()
     ip_address = socket.gethostbyname(hostname)
     return ip_address
 
-def launch_ray_cluster(ray_cluster_port: int) -> subprocess.CompletedProcess:
+def launch_ray_cluster(port: int) -> subprocess.CompletedProcess:
     head_node_ip = os.getenv('HEAD_NODE_IP')
     node_ip_address = get_ip_address()
     try:
@@ -56,18 +72,18 @@ def launch_ray_cluster(ray_cluster_port: int) -> subprocess.CompletedProcess:
         sys.exit(1)
     ray_start_command = None
     if 'HEAD_NODE' in os.environ:
-        ray_start_command = f"ray start --head --node-ip-address={node_ip_address} --port={ray_cluster_port}"
+        ray_start_command = f"ray start --head --node-ip-address={node_ip_address} --port={port}"
         try:
-            result = subprocess.run(['ray', 'start', '--head', f'--port={ray_cluster_port}'], check=True, text=True, capture_output=True)
+            result = subprocess.run(['ray', 'start', '--head', f'--port={port}'], check=True, text=True, capture_output=True)
         except subprocess.CalledProcessError as e:
             logger.info("'{}' failed with: \n{}".format(ray_start_command, e.stderr))
             sys.exit(1)
     else:
-        ray_start_command = f"ray start --address={head_node_ip}:{ray_cluster_port} --node-ip-address={node_ip_address}"
+        ray_start_command = f"ray start --address={head_node_ip}:{port} --node-ip-address={node_ip_address}"
         for attempt in range(MAX_RESTARTS):
             try:
                 # wait about 2 mins by default
-                result = subprocess.run(['ray', 'start', f'--address={head_node_ip}:{ray_cluster_port}'], check=True, text=True, capture_output=True)
+                result = subprocess.run(['ray', 'start', f'--address={head_node_ip}:{port}'], check=True, text=True, capture_output=True)
                 break
             except subprocess.CalledProcessError as e:
                 if attempt < MAX_RESTARTS:
@@ -83,6 +99,11 @@ def connect_to_ray_cluster(port: int, namespace="llumnix") -> None:
     head_node_ip = os.getenv('HEAD_NODE_IP')
     ray.init(address=f"{head_node_ip}:{port}", ignore_reinit_error=True, namespace=namespace)
 
+def setup_ray_cluster(cfg):
+    if cfg.SERVER.LAUNCH_RAY_CLUSTER:
+        launch_ray_cluster(cfg.SERVER.RAY_CLUSTER_PORT)
+    connect_to_ray_cluster(port=cfg.SERVER.RAY_CLUSTER_PORT)
+
 def is_gpu_available() -> bool:
     try:
         subprocess.check_output(['nvidia-smi'])
@@ -129,7 +150,7 @@ def init_manager(engine_manager_args: EngineManagerArgs) -> LLMEngineManager:
     return engine_manager
 
 def init_llumlets(engine_manager_args: EngineManagerArgs, engine_args, node_id: str,
-                  output_queue_type: QueueType) -> Tuple[List[str], List[Llumlet]]:
+                  request_output_queue_type: QueueType) -> Tuple[List[str], List[Llumlet]]:
     engine_config = engine_args.create_engine_config()
     parallel_config = engine_config.parallel_config
     instance_ids: List[str] = []
@@ -141,7 +162,7 @@ def init_llumlets(engine_manager_args: EngineManagerArgs, engine_args, node_id:
         instance_id = instance_ids[idx]
         if not engine_manager_args.profiling_result_file_path:
             llumlet = Llumlet.from_args(
-                output_queue_type,
+                request_output_queue_type,
                 engine_manager_args.disable_fixed_node_init_instance,
                 False,
                 node_id,
@@ -153,7 +174,7 @@
             )
         else:
             llumlet = Llumlet.from_args(
-                output_queue_type,
+                request_output_queue_type,
                 engine_manager_args.disable_fixed_node_init_instance,
                 False,
                 node_id,
@@ -170,20 +191,20 @@
 
 def init_llumnix_components(engine_manager_args: EngineManagerArgs,
                             engine_args,
                             node_id: str,
-                            output_queue_type: QueueType):
+                            request_output_queue_type: QueueType,
+                            ip: str,
+                            request_output_queue_port: str):
     engine_manager = init_manager(engine_manager_args)
     if engine_manager_args.disable_init_instance_by_manager:
-        instance_ids, llumlets = init_llumlets(engine_manager_args, engine_args, node_id, output_queue_type)
+        instance_ids, llumlets = init_llumlets(engine_manager_args, engine_args, node_id, request_output_queue_type)
     else:
         instance_ids, llumlets = retry_manager_method_sync(
-            engine_manager.init_llumlets.remote, 'init_llumlets', engine_args, node_id, output_queue_type)
+            engine_manager.init_llumlets.remote, 'init_llumlets', engine_args, node_id, request_output_queue_type)
 
     available_instance_ids = []
     dead_instance_ids = []
     available_llumlets = []
-
     ready_tasks = [llumlet.is_ready.remote() for llumlet in llumlets]
-
     for idx, task in enumerate(ready_tasks):
         try:
             ray.get(task)
@@ -191,14 +212,89 @@
             available_instance_ids.append(instance_ids[idx])
             available_llumlets.append(llumlets[idx])
         except ray.exceptions.RayActorError:
             dead_instance_ids.append(instance_ids[idx])
-
     if len(dead_instance_ids) > 0:
         retry_manager_method_sync(engine_manager.scale_down.remote, 'scale_down', dead_instance_ids)
-
     if len(available_instance_ids) > 0:
         retry_manager_method_sync(engine_manager.scale_up.remote, 'scale_up',
                                   available_instance_ids, available_llumlets)
     logger.info("Init Llumnix components done, {} instances are ready, instance_ids: {}."
                 .format(len(available_instance_ids), available_instance_ids))
-    return engine_manager, available_instance_ids, available_llumlets
+
+    request_output_queue = init_output_queue_server(ip, request_output_queue_port, request_output_queue_type)
+
+    return engine_manager, available_instance_ids, available_llumlets, request_output_queue
+
+def setup_llumnix(engine_manager_args, engine_args, cfg):
+    ip = get_ip_address()
+    node_id = ray.get_runtime_context().get_node_id()
+    engine_manager, instance_ids, llumlets, request_output_queue = \
+        init_llumnix_components(engine_manager_args,
+                                engine_args,
+                                node_id,
+                                cfg.SERVER.QUEUE_TYPE,
+                                ip,
+                                cfg.SERVER.REQUEST_OUTPUT_QUEUE_PORT)
+    server_id = random_uuid()
+    server_info = ServerInfo(server_id,
+                             cfg.SERVER.QUEUE_TYPE,
+                             request_output_queue,
+                             ip,
+                             cfg.SERVER.REQUEST_OUTPUT_QUEUE_PORT)
+    instances: Dict[str, Llumlet] = {}
+    instance_num_requests: Dict[str, int] = {}
+    for idx, ins_id in enumerate(instance_ids):
+        instances[ins_id] = llumlets[idx]
+        instance_num_requests[ins_id] = 0
+    log_requests = not cfg.SERVER.DISABLE_LOG_REQUESTS_SERVER
+    log_request_timestamps = cfg.SERVER.LOG_REQUEST_TIMESTAMPS
+    logger.info("log_requests: {}, log_request_timestamps: {}".format(log_requests, log_request_timestamps))
+
+    context = LlumnixEntrypointsContext()
+    context.engine_manager = engine_manager
+    context.instances = instances
+    context.request_output_queue = request_output_queue
+    context.server_info = server_info
+    context.instance_num_requests = instance_num_requests
+    context.log_requests = log_requests
+    context.log_request_timestamps = log_request_timestamps
+
+    return context
+
+async def _background_process_outputs(llumnix_context):
+    while True:
+        request_outputs = await llumnix_context.request_output_queue.get()
+        for request_output in request_outputs:
+            if hasattr(request_output, 'request_timestamps'):
+                request_output.request_timestamps.api_server_background_process_get_queue_timestamp = time.time()
+        for request_output in request_outputs:
+            request_id = request_output.request_id
+            # A request could be dispatched twice when the manager is dead; the first dispatch frees the request_streams entry when it finishes.
+            if request_id not in llumnix_context.request_streams:
+                continue
+            llumnix_context.request_streams[request_id].put(request_output)
+            if request_output.finished:
+                llumnix_context.request_streams[request_id].finish()
+                del llumnix_context.request_streams[request_id]
+
+def init_per_token_latency_breakdown_dict() -> Dict[str, List[float]]:
+    per_token_latency_breakdown_dict = {
+        'step_latency_engine': [],
+        'process_model_outputs_latency': [],
+        'step_postprocess_latency': [],
+        'across_async_put_queue_thread_latency': [],
+        'across_async_put_queue_actor_latency': [],
+        'queue_rpc_latency': [],
+        'background_process_get_queue_latency': [],
+        'generate_benchmark_return_output_latency': []
+    }
+    return per_token_latency_breakdown_dict
+
+def record_per_token_latency_breakdown(per_token_latency_breakdown_dict: Dict[str, List[float]], request_timestamps: RequestTimestamps):
+    per_token_latency_breakdown_dict['step_latency_engine'].append(request_timestamps.step_latency_engine)
+    per_token_latency_breakdown_dict['process_model_outputs_latency'].append(request_timestamps.process_model_outputs_latency)
+    per_token_latency_breakdown_dict['step_postprocess_latency'].append(request_timestamps.step_postprocess_latency)
+    per_token_latency_breakdown_dict['across_async_put_queue_thread_latency'].append(request_timestamps.across_async_put_queue_thread_latency)
+    per_token_latency_breakdown_dict['across_async_put_queue_actor_latency'].append(request_timestamps.across_async_put_queue_actor_latency)
+    per_token_latency_breakdown_dict['queue_rpc_latency'].append(request_timestamps.queue_rpc_latency)
+    per_token_latency_breakdown_dict['background_process_get_queue_latency'].append(request_timestamps.background_process_get_queue_latency)
+    per_token_latency_breakdown_dict['generate_benchmark_return_output_latency'].append(request_timestamps.generate_benchmark_return_output_latency)
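The two latency helpers above move out of api_server.py so that any entrypoint can collect the per-token latency breakdown. A hedged sketch of a consumer, modeled on the generate_benchmark endpoint (results_generator is assumed to be the AsyncStream returned by manager_generate):

# Sketch: collecting per-token latency stages from a result stream.
from llumnix.entrypoints.utils import (init_per_token_latency_breakdown_dict,
                                       record_per_token_latency_breakdown)

async def collect_breakdown(results_generator):
    breakdown = init_per_token_latency_breakdown_dict()  # one list per latency stage
    async for request_output in results_generator:
        # Timestamps are only attached when log_request_timestamps is enabled.
        if hasattr(request_output, 'request_timestamps'):
            record_per_token_latency_breakdown(breakdown, request_output.request_timestamps)
    return breakdown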
diff --git a/llumnix/entrypoints/vllm/api_server.py b/llumnix/entrypoints/vllm/api_server.py
index f2c12220..da2c2b10 100644
--- a/llumnix/entrypoints/vllm/api_server.py
+++ b/llumnix/entrypoints/vllm/api_server.py
@@ -11,125 +11,52 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Dict, AsyncGenerator
+from typing import AsyncGenerator
 from contextlib import asynccontextmanager
-import argparse
 import time
 import asyncio
 import json
-import copy
 
 from fastapi import FastAPI, Request
 from fastapi.responses import JSONResponse, Response, StreamingResponse
 import uvicorn
-import ray
 
 from vllm.sampling_params import SamplingParams
-from vllm.engine.arg_utils import AsyncEngineArgs
-from vllm.engine.async_llm_engine import AsyncStream
-
-from llumnix.arg_utils import EngineManagerArgs
-from llumnix.server_info import ServerInfo, RequestTimestamps
-from llumnix.entrypoints.llumnix_utils import (get_ip_address,
-                                               launch_ray_cluster, connect_to_ray_cluster,
-                                               is_gpu_available, init_llumnix_components)
+
+from llumnix.arg_utils import LlumnixArgumentParser
+from llumnix.entrypoints.utils import (setup_ray_cluster,
+                                       setup_llumnix,
+                                       is_gpu_available,
+                                       LlumnixEntrypointsContext,
+                                       _background_process_outputs,
+                                       init_per_token_latency_breakdown_dict,
+                                       record_per_token_latency_breakdown)
+from llumnix.entrypoints.vllm.utils import (add_cli_args,
+                                            get_args,
+                                            manager_generate,
+                                            manager_abort,
+                                            manager_is_ready)
 from llumnix.logger import init_logger
 from llumnix.utils import random_uuid
-from llumnix.backends.vllm.utils import check_engine_args
-from llumnix.queue.queue_server_base import QueueServerBase
-from llumnix.queue.utils import get_output_queue_server
 from llumnix.config import get_llumnix_config, LlumnixConfig
 
-logger = init_logger("llumnix.api_server")
+# A code file with __main__ should set the logger name explicitly to inherit the llumnix logger configuration.
+logger = init_logger("llumnix.entrypoints.vllm.api_server")
 
-engine_manager = None
-instances = {}
-instance_num_requests: Dict[str, int] = {}
-# request_output_queue could be None if initialzed in lifespan.
-request_output_queue: QueueServerBase = None
-server_info = None
 TIMEOUT_KEEP_ALIVE = 5  # seconds.
-request_streams: Dict[str, AsyncStream] = {}
-log_requests = None
-log_request_timestamps = None
-num_finished_requests = 0
-WAIT_MANAGER_INTERVAL = 5
-manager_available = True
-
-
-async def _background_process_outputs():
-    while True:
-        request_outputs = await request_output_queue.get()
-        for request_output in request_outputs:
-            if hasattr(request_output, 'request_timestamps'):
-                request_output.request_timestamps.api_server_background_process_get_queue_timestamp = time.time()
-        for request_output in request_outputs:
-            request_id = request_output.request_id
-            # Request could be dispatched twice when manager is dead, the first request will free the request_streams when finished.
-            if request_id not in request_streams:
-                continue
-            request_streams[request_id].put(request_output)
-            if request_output.finished:
-                request_streams[request_id].finish()
-                del request_streams[request_id]
+
+llumnix_context: LlumnixEntrypointsContext = None
+
 
 # pylint: disable=unused-argument
 @asynccontextmanager
 async def lifespan(fastapi_app: FastAPI):
-    asyncio.create_task(request_output_queue.run_server_loop())
-    asyncio.create_task(_background_process_outputs())
+    asyncio.create_task(llumnix_context.request_output_queue.run_server_loop())
+    asyncio.create_task(_background_process_outputs(llumnix_context))
     yield
-    request_output_queue.cleanup()
+    llumnix_context.request_output_queue.cleanup()
 
 app = FastAPI(lifespan=lifespan)
 
-async def manager_generate(prompt, sampling_params, request_id) -> AsyncStream:
-    if sampling_params.n > 1 or sampling_params.use_beam_search:
-        raise ValueError("Unsupported feature: multiple sequence decoding")
-    results_generator = AsyncStream(request_id)
-    request_streams[request_id] = results_generator
-    # This request's outputs will be put to the request_output_queue of this api server no matter which instance it's running in.
-    # If manager is unavailable, request will be directly added to the llumlet held by api server.
-    global manager_available
-    try:
-        server_info_copy = copy.deepcopy(server_info)
-        if log_request_timestamps:
-            # Hack request timestamps in server_info for latency breakdown.
-            server_info_copy.request_timestamps = RequestTimestamps()
-            server_info_copy.request_timestamps.api_server_manager_generate_timestamp = time.time()
-        # await to catch exception
-        await engine_manager.generate.remote(request_id, server_info_copy, prompt, sampling_params)
-        manager_available = True
-    except ray.exceptions.RayActorError:
-        # Do not re-generate the request to avoid duplicate requests.
-        if manager_available:
-            manager_available = False
-            return results_generator
-        try:
-            if instance_num_requests:
-                instance_id = min(instance_num_requests, key=instance_num_requests.get)
-                instance_num_requests[instance_id] += 1
-                await instances[instance_id].generate.remote(request_id, server_info, prompt, sampling_params)
-                logger.info("Manager is unavailable, directly pass request {} to instance {}".format(request_id, instance_id))
-            else:
-                logger.info("Manager is unavailable, but there is no instance behind this api server, "
-                            "sleep {}s, waiting for manager restarts".format(WAIT_MANAGER_INTERVAL))
-                await asyncio.sleep(WAIT_MANAGER_INTERVAL)
-                return await asyncio.create_task(manager_generate(prompt, sampling_params, request_id))
-        except (ray.exceptions.RayActorError, KeyError):
-            if instance_id in instances:
-                logger.info("[manager_generate] instance {} is dead".format(instance_id))
-                del instances[instance_id]
-                del instance_num_requests[instance_id]
-            return await asyncio.create_task(manager_generate(prompt, sampling_params, request_id))
-    return results_generator
-
-async def manager_abort(request_id: str) -> None:
-    try:
-        logger.info("abort request: {}.".format(request_id))
-        await engine_manager.abort.remote(request_id)
-    except ray.exceptions.RayActorError:
-        logger.info("Manager is unavailable")
-
 
 @app.get("/health")
 async def health() -> Response:
@@ -152,7 +79,8 @@ async def generate(request: Request) -> Response:
     sampling_params = SamplingParams(**request_dict)
     request_id = random_uuid()
 
-    results_generator = await manager_generate(prompt, sampling_params, request_id)
+    # manager_generate and manager_abort replace the vLLM async engine generate and abort APIs.
+    results_generator = await manager_generate(prompt, sampling_params, request_id, llumnix_context)
 
     # Streaming case
     async def stream_results() -> AsyncGenerator[bytes, None]:
@@ -172,7 +100,7 @@ async def stream_results() -> AsyncGenerator[bytes, None]:
     async for request_output in results_generator:
         if await request.is_disconnected():
             # Abort the request if the client disconnects.
-            await manager_abort(request_id)
+            await manager_abort(request_id, llumnix_context)
             return Response(status_code=499)
         final_output = request_output
 
@@ -182,31 +110,16 @@
     ret = {"text": text_outputs}
     return JSONResponse(ret)
 
-def init_per_token_latency_breakdown_dict() -> Dict[str, int]:
-    per_token_latency_breakdown_dict = {
-        'step_latency_engine': [],
-        'process_model_outputs_latency': [],
-        'step_postprocess_latency': [],
-        'across_async_put_queue_thread_latency': [],
-        'across_async_put_queue_actor_latency': [],
-        'queue_rpc_latency': [],
-        'background_process_get_queue_latency': [],
-        'generate_benchmark_return_output_latency': []
-    }
-    return per_token_latency_breakdown_dict
-
-def record_per_token_latency_breakdown(per_token_latency_breakdown_dict: Dict[str, int], request_timestamps: RequestTimestamps):
-    per_token_latency_breakdown_dict['step_latency_engine'].append(request_timestamps.step_latency_engine)
-    per_token_latency_breakdown_dict['process_model_outputs_latency'].append(request_timestamps.process_model_outputs_latency)
-    per_token_latency_breakdown_dict['step_postprocess_latency'].append(request_timestamps.step_postprocess_latency)
-    per_token_latency_breakdown_dict['across_async_put_queue_thread_latency'].append(request_timestamps.across_async_put_queue_thread_latency)
-    per_token_latency_breakdown_dict['across_async_put_queue_actor_latency'].append(request_timestamps.across_async_put_queue_actor_latency)
-    per_token_latency_breakdown_dict['queue_rpc_latency'].append(request_timestamps.queue_rpc_latency)
-    per_token_latency_breakdown_dict['background_process_get_queue_latency'].append(request_timestamps.background_process_get_queue_latency)
-    per_token_latency_breakdown_dict['generate_benchmark_return_output_latency'].append(request_timestamps.generate_benchmark_return_output_latency)
-
 @app.post("/generate_benchmark")
 async def generate_benchmark(request: Request) -> Response:
+    """Generate completion for the request.
+
+    The request should be a JSON object with the following fields:
+    - prompt: the prompt to use for the generation.
+    - stream: whether to stream the results or not.
+    - other fields: the sampling parameters (see `SamplingParams` for details).
+    """
+    # Adds some benchmark-related code compared to the generate API.
     request_dict = await request.json()
     prompt = request_dict.pop("prompt")
     _ = request_dict.pop("stream", False)
@@ -215,7 +128,7 @@
 
     start = time.time()
 
-    results_generator = await manager_generate(prompt, sampling_params, request_id)
+    results_generator = await manager_generate(prompt, sampling_params, request_id, llumnix_context)
 
     # Non-streaming case
     final_output = None
@@ -224,21 +137,21 @@
     async for request_output in results_generator:
         if await request.is_disconnected():
             # Abort the request if the client disconnects.
-            await manager_abort(request_id)
+            await manager_abort(request_id, llumnix_context)
             return Response(status_code=499)
         now = time.time()
         per_token_latency.append([now, (now - start)*1000])
+        start = now
+        final_output = request_output
         if hasattr(request_output, 'request_timestamps'):
             request_output.request_timestamps.api_server_generate_benchmark_timestamp_end = now
             record_per_token_latency_breakdown(per_token_latency_breakdown_dict, request_output.request_timestamps)
-        start = now
-        final_output = request_output
+
     assert final_output is not None
 
-    global num_finished_requests
-    if log_requests:
+    if llumnix_context.log_requests:
+        llumnix_context.num_finished_requests += 1
         logger.info("Finished request {}.".format(request_id))
-        num_finished_requests += 1
-        logger.info("num_finished_requests {}.".format(num_finished_requests))
+        logger.info("num_finished_requests {}.".format(llumnix_context.num_finished_requests))
 
     generation = final_output.outputs[0].text
     num_output_tokens = len(final_output.outputs[0].token_ids)
@@ -259,85 +172,29 @@
 
 @app.get("/is_ready")
 async def is_ready():
-    ready_status = await engine_manager.is_ready.remote()
-    return ready_status
-
-class LlumnixArgumentParser(argparse.ArgumentParser):
-    def __init__(self, *args, **kwargs):
-        self.cur_namespace = "llumnix"
-        super().__init__(*args, **kwargs)
-
-    def set_namespace(self, namespace: str):
-        self.cur_namespace = namespace
+    return await manager_is_ready(llumnix_context)
 
-    def add_argument(self, *args, **kwargs):
-        if self.cur_namespace == 'llumnix' and "--help" not in args:
-            assert 'default' not in kwargs or kwargs['default'] is None, \
-                f"Do not set the default value for '{args[0]}' in CLI, or set default value to None. " \
-                f"The default value will be retrieved from config/default.py in get_llumnix_config."
-
-            if kwargs.get('action') == 'store_true':
-                kwargs['default'] = None
-
-        super().add_argument(*args, **kwargs)
 
 if __name__ == "__main__":
     parser: LlumnixArgumentParser = LlumnixArgumentParser()
-    parser.set_namespace("llumnix")
     parser.add_argument("--host", type=str)
     parser.add_argument("--port", type=int)
     parser.add_argument("--ssl-keyfile", type=str)
     parser.add_argument("--ssl-certfile", type=str)
-    parser.add_argument('--disable-log-requests-server', action='store_true', help='disable logging requests in server')
-    parser.add_argument("--ray-cluster-port", type=int)
-    parser.add_argument('--launch-ray-cluster', action='store_true', help='if launch ray cluster in api server')
-    parser.add_argument("--queue-type", type=str, choices=['rayqueue', 'zmq'], help='queue type for request output queue')
-    parser.add_argument("--request-output-queue-port", type=int, help='port for zmq')
-    parser.add_argument("--log-request-timestamps", action='store_true', help='if log request timestamps')
-    parser.add_argument("--config-file", help="path to config file")
-    parser = EngineManagerArgs.add_cli_args(parser)
-
-    parser.set_namespace("vllm")
-    parser = AsyncEngineArgs.add_cli_args(parser)
-
-    cli_args = parser.parse_args()
-    cfg: LlumnixConfig = get_llumnix_config(cli_args.config_file, cli_args)
-    engine_manager_args = EngineManagerArgs.from_llumnix_config(cfg)
-    EngineManagerArgs.check_args(engine_manager_args, parser)
-    engine_args = AsyncEngineArgs.from_cli_args(cli_args)
-    check_engine_args(engine_args, engine_manager_args)
-
-    logger.info("engine_args: {}".format(engine_args))
-
-    if cfg.RAY.LAUNCH_RAY_CLUSTER:
-        # Launch the ray cluster for multi-node serving.
-        launch_ray_cluster(cfg.RAY.RAY_CLUSTER_PORT)
+    cli_args = add_cli_args(parser)
+    cfg: LlumnixConfig = get_llumnix_config(cli_args.config_file, cli_args)
+    _, engine_manager_args, engine_args = get_args(cfg, parser, cli_args)
 
-    # Connect to a ray cluster.
-    connect_to_ray_cluster(port=cfg.RAY.RAY_CLUSTER_PORT)
+    # Launch or connect to the ray cluster for multi-node serving.
+    setup_ray_cluster(cfg)
 
     # if gpu is not available, it means that this node is head pod without any llumnix components
     if is_gpu_available():
-        # Launch the Llumnix componets on current node.
-        server_id = random_uuid()
-        ip = get_ip_address()
-        node_id = ray.get_runtime_context().get_node_id()
-        engine_manager, instance_ids, llumlets = \
-            init_llumnix_components(engine_manager_args, engine_args, node_id, cfg.SERVER.QUEUE_TYPE)
-        request_output_queue = get_output_queue_server(ip, cfg.SERVER.REQUEST_OUTPUT_QUEUE_PORT, cfg.SERVER.QUEUE_TYPE)
-        server_info = ServerInfo(server_id, cfg.SERVER.QUEUE_TYPE, request_output_queue, ip,
-                                 cfg.SERVER.REQUEST_OUTPUT_QUEUE_PORT)
-
-        for idx, ins_id in enumerate(instance_ids):
-            instances[ins_id] = llumlets[idx]
-            instance_num_requests[ins_id] = 0
-        log_requests = not cfg.SERVER.DISABLE_LOG_REQUESTS_SERVER
-        log_request_timestamps = cfg.SERVER.LOG_REQUEST_TIMESTAMPS
+        llumnix_context = setup_llumnix(engine_manager_args, engine_args, cfg)
 
         # Start the api server after all the components of llumnix are ready.
         logger.info("Start Api Server on '{}:{}'".format(cfg.SERVER.HOST, cfg.SERVER.PORT))
-        logger.info("log_requests: {}, log_request_timestamps: {}".format(log_requests, log_request_timestamps))
 
         uvicorn.run(app,
                     host=cfg.SERVER.HOST,
                     port=cfg.SERVER.PORT,
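With the argument handling and component setup factored out, the server __main__ reduces to a short wiring sequence. A condensed sketch of that flow (names taken from the diff; uvicorn options elided):

# Condensed sketch of the refactored __main__ flow in api_server.py.
parser = LlumnixArgumentParser()
parser.add_argument("--host", type=str)
parser.add_argument("--port", type=int)
cli_args = add_cli_args(parser)        # registers llumnix + vllm args, then parses
cfg = get_llumnix_config(cli_args.config_file, cli_args)
_, engine_manager_args, engine_args = get_args(cfg, parser, cli_args)
setup_ray_cluster(cfg)                 # launch and/or connect, driven by cfg.SERVER
if is_gpu_available():                 # a GPU-less head pod hosts no llumnix components
    llumnix_context = setup_llumnix(engine_manager_args, engine_args, cfg)
    uvicorn.run(app, host=cfg.SERVER.HOST, port=cfg.SERVER.PORT)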
logger.info("Start Api Server on '{}:{}'".format(cfg.SERVER.HOST, cfg.SERVER.PORT)) - logger.info("log_requests: {}, log_request_timestamps: {}".format(log_requests, log_request_timestamps)) uvicorn.run(app, host=cfg.SERVER.HOST, port=cfg.SERVER.PORT, diff --git a/llumnix/entrypoints/vllm/utils.py b/llumnix/entrypoints/vllm/utils.py new file mode 100644 index 00000000..64af2dfc --- /dev/null +++ b/llumnix/entrypoints/vllm/utils.py @@ -0,0 +1,98 @@ +import copy +import time +import asyncio +import ray + +from vllm.engine.arg_utils import AsyncEngineArgs +from vllm.engine.async_llm_engine import AsyncStream +from vllm import SamplingParams + +from llumnix.backends.vllm.utils import check_engine_args +from llumnix.arg_utils import LlumnixEntrypointsArgs, EngineManagerArgs +from llumnix.logger import init_logger +from llumnix.entrypoints.utils import LlumnixEntrypointsContext +from llumnix.server_info import RequestTimestamps + +logger = init_logger(__name__) + +WAIT_MANAGER_INTERVAL = 5 + + +def add_cli_args(parser): + parser.set_namespace("llumnix") + parser = LlumnixEntrypointsArgs.add_cli_args(parser) + parser = EngineManagerArgs.add_cli_args(parser) + parser.set_namespace("vllm") + parser = AsyncEngineArgs.add_cli_args(parser) + cli_args = parser.parse_args() + return cli_args + +def get_args(cfg, parser, cli_args): + llumnix_entrypoints_args = LlumnixEntrypointsArgs.from_llumnix_config(cfg) + LlumnixEntrypointsArgs.check_args(llumnix_entrypoints_args, parser) + engine_manager_args = EngineManagerArgs.from_llumnix_config(cfg) + EngineManagerArgs.check_args(engine_manager_args, parser) + engine_args = AsyncEngineArgs.from_cli_args(cli_args) + check_engine_args(engine_args, engine_manager_args) + + logger.info("llumnix_entrypoints_args: {}".format(llumnix_entrypoints_args)) + logger.info("engine_manager_args: {}".format(engine_manager_args)) + logger.info("engine_args: {}".format(engine_args)) + + return llumnix_entrypoints_args, engine_manager_args, engine_args + +async def manager_generate(prompt: str, + sampling_params: SamplingParams, + request_id: str, + llumnix_context: LlumnixEntrypointsContext) -> AsyncStream: + results_generator = AsyncStream(request_id) + llumnix_context.request_streams[request_id] = results_generator + + if sampling_params.n > 1 or sampling_params.use_beam_search: + raise ValueError("Unsupported feature: multiple sequence decoding") + # This request's outputs will be put to the request_output_queue of this api server no matter which instance it's running in. + # If manager is unavailable, request will be directly added to the llumlet held by api server. + try: + server_info_copy = copy.deepcopy(llumnix_context.server_info) + if llumnix_context.log_request_timestamps: + # Hack request timestamps in server_info for latency breakdown. + server_info_copy.request_timestamps = RequestTimestamps() + server_info_copy.request_timestamps.api_server_manager_generate_timestamp = time.time() + # await to catch exception + await llumnix_context.engine_manager.generate.remote(request_id, server_info_copy, prompt, sampling_params) + llumnix_context.manager_available = True + except ray.exceptions.RayActorError: + # Do not re-generate the request to avoid duplicate requests. 
+        if llumnix_context.manager_available:
+            llumnix_context.manager_available = False
+            return results_generator
+        try:
+            if llumnix_context.instance_num_requests:
+                instance_id = min(llumnix_context.instance_num_requests, key=llumnix_context.instance_num_requests.get)
+                llumnix_context.instance_num_requests[instance_id] += 1
+                await llumnix_context.instances[instance_id].generate.remote(request_id, server_info_copy, prompt, sampling_params)
+                logger.info("Manager is unavailable, directly pass request {} to instance {}".format(request_id, instance_id))
+            else:
+                logger.info("Manager is unavailable, but there is no instance behind this api server, "
+                            "sleep {}s, waiting for manager restarts".format(WAIT_MANAGER_INTERVAL))
+                await asyncio.sleep(WAIT_MANAGER_INTERVAL)
+                return await asyncio.create_task(manager_generate(prompt, sampling_params, request_id, llumnix_context))
+        except (ray.exceptions.RayActorError, KeyError):
+            if instance_id in llumnix_context.instances:
+                logger.info("[manager_generate] instance {} is dead".format(instance_id))
+                del llumnix_context.instances[instance_id]
+                del llumnix_context.instance_num_requests[instance_id]
+            return await asyncio.create_task(manager_generate(prompt, sampling_params, request_id, llumnix_context))
+
+    return results_generator
+
+async def manager_abort(request_id: str, llumnix_context: LlumnixEntrypointsContext) -> None:
+    try:
+        logger.info("abort request: {}.".format(request_id))
+        await llumnix_context.engine_manager.abort.remote(request_id)
+    except ray.exceptions.RayActorError:
+        logger.info("Manager is unavailable")
+
+async def manager_is_ready(llumnix_context: LlumnixEntrypointsContext):
+    ready_status = await llumnix_context.engine_manager.is_ready.remote()
+    return ready_status
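manager_generate now threads the LlumnixEntrypointsContext explicitly instead of touching module globals, which is what makes it reusable from tests and other servers. A hedged sketch of a caller, modeled on the /generate endpoint (handle and the disconnected callback are illustrative names, not part of the diff):

# Sketch: driving manager_generate/manager_abort from an endpoint-like caller.
from vllm.sampling_params import SamplingParams
from llumnix.entrypoints.vllm.utils import manager_generate, manager_abort
from llumnix.utils import random_uuid

async def handle(prompt, sampling_params: SamplingParams, llumnix_context, disconnected):
    request_id = random_uuid()
    results_generator = await manager_generate(prompt, sampling_params, request_id, llumnix_context)
    final_output = None
    async for request_output in results_generator:
        if await disconnected():                      # client went away
            await manager_abort(request_id, llumnix_context)
            return None
        final_output = request_output
    return final_output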
diff --git a/llumnix/llm_engine_manager.py b/llumnix/llm_engine_manager.py
index 35a6a574..7a972b7a 100644
--- a/llumnix/llm_engine_manager.py
+++ b/llumnix/llm_engine_manager.py
@@ -443,7 +443,6 @@ def from_args(cls,
                       os.getcwd(),
                       log_requests=not engine_manager_args.disable_log_requests_manager,
                       profiling_database=profiling_database)
-        logger.info("engine_manager_args: {}".format(engine_manager_args))
         return engine_manager
 
 # TODO(s5u13b): Significant duplication with llumlet_utils.init_llumlets. Consider reducing duplicate codes.
diff --git a/llumnix/queue/utils.py b/llumnix/queue/utils.py
index 3ed4b797..84e270fa 100644
--- a/llumnix/queue/utils.py
+++ b/llumnix/queue/utils.py
@@ -20,7 +20,7 @@
 from llumnix.queue.zmq_utils import get_open_zmq_ipc_path
 from llumnix.queue.queue_type import QueueType
 
-def get_output_queue_server(zmq_ip: str, zmq_port: int, queue_type: QueueType) -> QueueServerBase:
+def init_output_queue_server(zmq_ip: str, zmq_port: int, queue_type: QueueType) -> QueueServerBase:
     output_queue_server: QueueServerBase = None
     if queue_type == QueueType.ZMQ:
         rpc_path = get_open_zmq_ipc_path(zmq_ip, zmq_port)
@@ -29,7 +29,7 @@
         output_queue_server = RayQueueServer()
     return output_queue_server
 
-def get_output_queue_client(queue_type: QueueType) -> QueueClientBase:
+def init_output_queue_client(queue_type: QueueType) -> QueueClientBase:
     output_queue_client: QueueClientBase = None
     if queue_type == QueueType.ZMQ:
         output_queue_client= ZmqClient()
diff --git a/tests/e2e_test/test_e2e.py b/tests/e2e_test/test_e2e.py
index c19d581c..280b8d75 100644
--- a/tests/e2e_test/test_e2e.py
+++ b/tests/e2e_test/test_e2e.py
@@ -21,7 +21,7 @@
 from vllm import LLM, SamplingParams
 
 def generate_launch_command(result_filename: str = "", launch_ray_cluster: bool = True, HEAD_NODE_IP: str = "127.0.0.1",
-                            ip: str = "127.0.0.1", port: int = 37000, instances_num = 1, dispatch_policy: str = "load",
+                            ip: str = "127.0.0.1", port: int = 1234, instances_num = 1, dispatch_policy: str = "load",
                             migration_backend = "rpc", model = "facebook/opt-125m", max_model_len: int = 2048):
     command = (
         f"RAY_DEDUP_LOGS=0 HEAD_NODE_IP={HEAD_NODE_IP} HEAD_NODE=1 "
diff --git a/tests/unit_test/entrypoints/test_llumnix_utils.py b/tests/unit_test/entrypoints/test_utils.py
similarity index 81%
rename from tests/unit_test/entrypoints/test_llumnix_utils.py
rename to tests/unit_test/entrypoints/test_utils.py
index aad1b7de..d4787b58 100644
--- a/tests/unit_test/entrypoints/test_llumnix_utils.py
+++ b/tests/unit_test/entrypoints/test_utils.py
@@ -16,13 +16,13 @@
 import ray
 
 from llumnix.arg_utils import EngineManagerArgs
-from llumnix.entrypoints.llumnix_utils import (get_ip_address,
-                                               launch_ray_cluster,
-                                               init_manager,
-                                               retry_manager_method_sync,
-                                               retry_manager_method_async)
+from llumnix.entrypoints.utils import (get_ip_address,
+                                       launch_ray_cluster,
+                                       init_manager,
+                                       retry_manager_method_sync,
+                                       retry_manager_method_async)
 from llumnix.llm_engine_manager import MANAGER_ACTOR_NAME
-from llumnix.queue.utils import get_output_queue_server
+from llumnix.queue.utils import init_output_queue_server
 # pylint: disable=unused-import
 from tests.conftest import setup_ray_env
 
@@ -46,7 +46,7 @@ def test_init_manager(setup_ray_env):
 def test_init_zmq(setup_ray_env):
     ip = '127.0.0.1'
     port = 1234
-    request_output_queue = get_output_queue_server(ip, port, 'zmq')
+    request_output_queue = init_output_queue_server(ip, port, 'zmq')
     assert request_output_queue is not None
 
 def test_retry_manager_method_sync(setup_ray_env):
diff --git a/tests/unit_test/entrypoints/vllm/api_server_manager.py b/tests/unit_test/entrypoints/vllm/api_server_manager.py
index ee1bf473..0fd6f64d 100644
--- a/tests/unit_test/entrypoints/vllm/api_server_manager.py
+++ b/tests/unit_test/entrypoints/vllm/api_server_manager.py
@@ -23,8 +23,8 @@
 from llumnix.arg_utils import EngineManagerArgs
 from llumnix.server_info import ServerInfo, RequestTimestamps
 from llumnix.utils import random_uuid
-from llumnix.queue.utils import get_output_queue_server, get_output_queue_client, QueueType
-
+from llumnix.queue.utils import init_output_queue_server, init_output_queue_client, QueueType
+from llumnix.entrypoints.utils import LlumnixEntrypointsContext
 
 app = llumnix.entrypoints.vllm.api_server.app
 engine_manager = None
@@ -36,7 +36,7 @@ class MockLLMEngineManager:
     def __init__(self, output_queue_type: QueueType):
         self._num_generates = 0
         self._num_aborts = 0
-        self.request_output_queue = get_output_queue_client(output_queue_type)
+        self.request_output_queue = init_output_queue_client(output_queue_type)
 
     async def generate(self, request_id, server_info, *args, **kwargs):
         self._num_generates += 1
@@ -73,18 +73,17 @@
     output_queue_type = QueueType(args.output_queue_type)
     engine_manager = init_manager(output_queue_type)
-    llumnix.entrypoints.vllm.api_server.engine_manager = engine_manager
-
+    llumnix.entrypoints.vllm.api_server.llumnix_context = LlumnixEntrypointsContext()
+    llumnix.entrypoints.vllm.api_server.llumnix_context.engine_manager = engine_manager
     ip = '127.0.0.1'
     port = 1234
-    llumnix.entrypoints.vllm.api_server.request_output_queue = \
-        get_output_queue_server(ip, port, output_queue_type)
-
+    llumnix.entrypoints.vllm.api_server.llumnix_context.request_output_queue = \
+        init_output_queue_server(ip, port, output_queue_type)
     ray_queue_server = None
     if output_queue_type == QueueType.RAYQUEUE:
-        ray_queue_server = llumnix.entrypoints.vllm.api_server.request_output_queue
+        ray_queue_server = llumnix.entrypoints.vllm.api_server.llumnix_context.request_output_queue
     server_info = ServerInfo(random_uuid(), output_queue_type, ray_queue_server, ip, port)
-    llumnix.entrypoints.vllm.api_server.server_info = server_info
+    llumnix.entrypoints.vllm.api_server.llumnix_context.server_info = server_info
 
     uvicorn.run(
         app,
diff --git a/tests/unit_test/queue/utils.py b/tests/unit_test/queue/utils.py
index e1018491..d1e508fd 100644
--- a/tests/unit_test/queue/utils.py
+++ b/tests/unit_test/queue/utils.py
@@ -13,12 +13,12 @@
 from llumnix.utils import random_uuid
 from llumnix.server_info import ServerInfo
-from llumnix.queue.utils import get_output_queue_server, QueueType
+from llumnix.queue.utils import init_output_queue_server, QueueType
 
 def request_output_queue_server(output_queue_type: QueueType):
     ip = '127.0.0.1'
     port = 1234
-    output_queue = get_output_queue_server(ip, port, output_queue_type)
+    output_queue = init_output_queue_server(ip, port, output_queue_type)
     server_id = random_uuid()
     server_info = ServerInfo(server_id, output_queue_type, output_queue, ip, port)
     return output_queue, server_info
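The get_* to init_* rename makes it explicit that these helpers construct fresh queue endpoints rather than look up existing ones. A usage sketch mirroring the updated test helper (the ip/port values are the illustrative ones from the tests):

# Sketch mirroring tests/unit_test/queue/utils.py after the rename.
from llumnix.utils import random_uuid
from llumnix.server_info import ServerInfo
from llumnix.queue.utils import init_output_queue_server, init_output_queue_client, QueueType

ip, port = '127.0.0.1', 1234
queue_type = QueueType.ZMQ
output_queue = init_output_queue_server(ip, port, queue_type)    # server side (API server)
server_info = ServerInfo(random_uuid(), queue_type, output_queue, ip, port)
client = init_output_queue_client(queue_type)                    # client side (AsyncPutQueueActor)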