diff --git a/lmdeploy/cli/serve.py b/lmdeploy/cli/serve.py
index e83607584..9a3860b71 100644
--- a/lmdeploy/cli/serve.py
+++ b/lmdeploy/cli/serve.py
@@ -116,6 +116,14 @@ def add_parser_api_server():
                              default=['*'],
                              help='A list of allowed http headers for cors')
         parser.add_argument('--proxy-url', type=str, default=None, help='The proxy url for api server.')
+        parser.add_argument('--max-concurrent-requests',
+                            type=int,
+                            default=None,
+                            help='The maximum number of requests the server handles '
+                            'concurrently. Once the limit is reached, additional client '
+                            'requests wait until an in-flight request finishes, while the '
+                            'server keeps processing the tasks already submitted to the '
+                            'engine. Defaults to None (no limit).')
         # common args
         ArgumentHelper.backend(parser)
         ArgumentHelper.log_level(parser)
@@ -314,7 +322,8 @@ def api_server(args):
                        ssl=args.ssl,
                        proxy_url=args.proxy_url,
                        max_log_len=args.max_log_len,
-                       disable_fastapi_docs=args.disable_fastapi_docs)
+                       disable_fastapi_docs=args.disable_fastapi_docs,
+                       max_concurrent_requests=args.max_concurrent_requests)

     @staticmethod
     def api_client(args):
diff --git a/lmdeploy/serve/openai/api_server.py b/lmdeploy/serve/openai/api_server.py
index b61772f40..2fb17e4c6 100644
--- a/lmdeploy/serve/openai/api_server.py
+++ b/lmdeploy/serve/openai/api_server.py
@@ -13,6 +13,7 @@
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import JSONResponse, Response, StreamingResponse
 from fastapi.security.http import HTTPAuthorizationCredentials, HTTPBearer
+from starlette.middleware.base import BaseHTTPMiddleware

 from lmdeploy.archs import get_task
 from lmdeploy.messages import GenerationConfig, LogitsProcessor, PytorchEngineConfig, TurbomindEngineConfig
@@ -908,6 +909,18 @@ async def startup_event():
             print(f'Service registration failed: {e}')


+class ConcurrencyLimitMiddleware(BaseHTTPMiddleware):
+
+    def __init__(self, app: FastAPI, max_concurrent_requests: int):
+        super().__init__(app)
+        self.semaphore = asyncio.Semaphore(max_concurrent_requests)
+
+    async def dispatch(self, request: Request, call_next):
+        async with self.semaphore:
+            response = await call_next(request)
+            return response
+
+
 def serve(model_path: str,
           model_name: Optional[str] = None,
           backend: Literal['turbomind', 'pytorch'] = 'turbomind',
@@ -925,6 +938,7 @@ def serve(model_path: str,
           proxy_url: Optional[str] = None,
           max_log_len: int = None,
           disable_fastapi_docs: bool = False,
+          max_concurrent_requests: Optional[int] = None,
           **kwargs):
     """An example to perform model inference through the command line
     interface.
@@ -969,6 +983,11 @@ def serve(model_path: str,
         proxy_url (str): The proxy url to register the api_server.
         max_log_len (int): Max number of prompt characters or prompt tokens
             being printed in log. Default: Unlimited
+        max_concurrent_requests (int): The maximum number of requests the
+            server handles concurrently. Once the limit is reached, additional
+            client requests wait until an in-flight request finishes, while
+            the server keeps processing the tasks already submitted to the
+            engine. Defaults to None (no limit).
""" if os.getenv('TM_LOG_LEVEL') is None: os.environ['TM_LOG_LEVEL'] = log_level @@ -993,6 +1012,10 @@ def serve(model_path: str, allow_methods=allow_methods, allow_headers=allow_headers, ) + # Set the maximum number of concurrent requests + if max_concurrent_requests is not None: + app.add_middleware(ConcurrencyLimitMiddleware, max_concurrent_requests=max_concurrent_requests) + if api_keys is not None: if isinstance(api_keys, str): api_keys = api_keys.split(',')