diff --git a/shortfin/python/shortfin_apps/llm/application.py b/shortfin/python/shortfin_apps/llm/application.py
new file mode 100644
index 000000000..c6bcebe17
--- /dev/null
+++ b/shortfin/python/shortfin_apps/llm/application.py
@@ -0,0 +1,22 @@
+# Copyright 2024 Advanced Micro Devices, Inc.
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+from fastapi import FastAPI
+
+from .lifecycle_hooks import lifespan
+from .routes import application_router, generation_router
+
+
+def add_routes(app: FastAPI):
+    app.include_router(application_router)
+    app.include_router(generation_router)
+    return app
+
+
+def get_app() -> FastAPI:
+    app = FastAPI(lifespan=lifespan)
+    app = add_routes(app)
+    return app
diff --git a/shortfin/python/shortfin_apps/llm/lifecycle_hooks.py b/shortfin/python/shortfin_apps/llm/lifecycle_hooks.py
new file mode 100644
index 000000000..89e4d7a5e
--- /dev/null
+++ b/shortfin/python/shortfin_apps/llm/lifecycle_hooks.py
@@ -0,0 +1,37 @@
+# Copyright 2024 Advanced Micro Devices, Inc.
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+from contextlib import asynccontextmanager
+import logging
+from typing import Any
+from fastapi import FastAPI
+
+from .components.manager import SystemManager
+
+sysman: SystemManager
+services: dict[str, Any] = {}
+
+
+@asynccontextmanager
+async def lifespan(app: FastAPI):
+    global sysman
+    global services
+
+    sysman.start()
+    try:
+        for service_name, service in services.items():
+            logging.info("Initializing service '%s': %r", service_name, service)
+            service.start()
+    except:
+        sysman.shutdown()
+        raise
+    yield
+    try:
+        for service_name, service in services.items():
+            logging.info("Shutting down service '%s'", service_name)
+            service.shutdown()
+    finally:
+        sysman.shutdown()
diff --git a/shortfin/python/shortfin_apps/llm/routes/__init__.py b/shortfin/python/shortfin_apps/llm/routes/__init__.py
new file mode 100644
index 000000000..8232061f9
--- /dev/null
+++ b/shortfin/python/shortfin_apps/llm/routes/__init__.py
@@ -0,0 +1,10 @@
+# Copyright 2024 Advanced Micro Devices, Inc.
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+from .application import application_router
+from .generate import generation_router
+
+__all__ = ["application_router", "generation_router"]
diff --git a/shortfin/python/shortfin_apps/llm/routes/application.py b/shortfin/python/shortfin_apps/llm/routes/application.py
new file mode 100644
index 000000000..6c632ae35
--- /dev/null
+++ b/shortfin/python/shortfin_apps/llm/routes/application.py
@@ -0,0 +1,14 @@
+# Copyright 2024 Advanced Micro Devices, Inc.
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+from fastapi import APIRouter, Response
+
+application_router = APIRouter()
+
+
+@application_router.get("/health")
+async def health() -> Response:
+    return Response(status_code=200)
diff --git a/shortfin/python/shortfin_apps/llm/routes/generate.py b/shortfin/python/shortfin_apps/llm/routes/generate.py
new file mode 100644
index 000000000..7d160b218
--- /dev/null
+++ b/shortfin/python/shortfin_apps/llm/routes/generate.py
@@ -0,0 +1,25 @@
+# Copyright 2024 Advanced Micro Devices, Inc.
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+from fastapi import APIRouter, Request
+
+from shortfin.interop.fastapi import FastAPIResponder
+
+from ..components.generate import ClientGenerateBatchProcess
+from ..components.io_struct import GenerateReqInput
+from ..lifecycle_hooks import services
+
+generation_router = APIRouter()
+
+
+@generation_router.post("/generate")
+@generation_router.put("/generate")
+async def generate_request(gen_req: GenerateReqInput, request: Request):
+    service = services["default"]
+    gen_req.post_init()
+    responder = FastAPIResponder(request)
+    ClientGenerateBatchProcess(service, gen_req, responder).launch()
+    return await responder.response
diff --git a/shortfin/python/shortfin_apps/llm/server.py b/shortfin/python/shortfin_apps/llm/server.py
index 24794eff2..7f9413997 100644
--- a/shortfin/python/shortfin_apps/llm/server.py
+++ b/shortfin/python/shortfin_apps/llm/server.py
@@ -4,8 +4,6 @@
 # See https://llvm.org/LICENSE.txt for license information.
 # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
-from typing import Any
-
 import argparse
 import logging
 from pathlib import Path
@@ -15,17 +13,11 @@
 
 # Import first as it does dep checking and reporting.
from shortfin import ProgramIsolation -from shortfin.interop.fastapi import FastAPIResponder - -from contextlib import asynccontextmanager - -from fastapi import FastAPI, Request, Response import uvicorn - -from .components.generate import ClientGenerateBatchProcess +from . import lifecycle_hooks +from .application import get_app from .components.config_struct import ModelParams, ServerParams -from .components.io_struct import GenerateReqInput from .components.manager import SystemManager from .components.service import GenerateService from .components.tokenizer import Tokenizer @@ -61,47 +53,6 @@ } -@asynccontextmanager -async def lifespan(app: FastAPI): - sysman.start() - try: - for service_name, service in services.items(): - logging.info("Initializing service '%s': %r", service_name, service) - service.start() - except: - sysman.shutdown() - raise - yield - try: - for service_name, service in services.items(): - logging.info("Shutting down service '%s'", service_name) - service.shutdown() - finally: - sysman.shutdown() - - -sysman: SystemManager -services: dict[str, Any] = {} -app = FastAPI(lifespan=lifespan) - - -@app.get("/health") -async def health() -> Response: - return Response(status_code=200) - - -async def generate_request(gen_req: GenerateReqInput, request: Request): - service = services["default"] - gen_req.post_init() - responder = FastAPIResponder(request) - ClientGenerateBatchProcess(service, gen_req, responder).launch() - return await responder.response - - -app.post("/generate")(generate_request) -app.put("/generate")(generate_request) - - def get_eos_from_tokenizer_config(json_path): import json @@ -142,13 +93,13 @@ def configure(args) -> SystemManager: ) sm.load_inference_module(args.vmfb) sm.load_inference_parameters(*args.parameters, parameter_scope="model") - services[sm.name] = sm + lifecycle_hooks.services[sm.name] = sm return sysman def main(argv, log_config=uvicorn.config.LOGGING_CONFIG): parser = argparse.ArgumentParser() - 
parser.add_argument("--host", type=str, default=None) + parser.add_argument("--host", type=str, default="0.0.0.0") parser.add_argument("--port", type=int, default=8000) parser.add_argument( "--root-path", @@ -243,11 +194,10 @@ def main(argv, log_config=uvicorn.config.LOGGING_CONFIG): args.tokenizer_json.stem + "_config.json" ) args.tokenizer_config_json = inferred_tokenizer_config_path - global sysman - sysman = configure(args) + lifecycle_hooks.sysman = configure(args) uvicorn.run( - app, + get_app(), host=args.host, port=args.port, log_config=log_config,