use components.lifecycle to handle server init using manager
renxida committed Feb 19, 2025
1 parent 888a98a commit 6d216fd
Showing 5 changed files with 103 additions and 93 deletions.
8 changes: 5 additions & 3 deletions shortfin/python/shortfin_apps/llm/application.py
@@ -4,11 +4,13 @@
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

"""
The shortfin GenerateService and SystemManager are configured and added as app.state.services and app.state.sysman respectively.
"""
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware

from .lifecycle_hooks import lifespan
from .routes import application_router, generation_router
from fastapi import FastAPI


def add_routes(app: FastAPI):
@@ -27,7 +29,7 @@ def add_middleware(app: FastAPI):
return app


def get_app() -> FastAPI:
def get_app(lifespan) -> FastAPI:
app = FastAPI(lifespan=lifespan)
app = add_routes(app)
app = add_middleware(app)
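Since `get_app` now receives the lifespan as a parameter instead of importing module-level hooks, any lifespan can be injected. A minimal sketch, assuming the package import path below; the no-op lifespan is hypothetical and not part of this commit:

```python
# Hypothetical test-side sketch: get_app accepts any lifespan, so a no-op
# stand-in avoids booting a real SystemManager during unit tests.
from contextlib import asynccontextmanager

from fastapi import FastAPI

from shortfin_apps.llm.application import get_app


@asynccontextmanager
async def noop_lifespan(app: FastAPI):
    # Nothing is started; tests can populate app.state manually.
    yield


app = get_app(noop_lifespan)
```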
93 changes: 93 additions & 0 deletions shortfin/python/shortfin_apps/llm/components/lifecycle.py
@@ -0,0 +1,93 @@
# Copyright 2024 Advanced Micro Devices, Inc.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

"""
Implements a context manager that configures a shortfin llm server from a namespace mirroring server.py's commandline args, and exposes a context manager interface such that we can do:
```python
def lifecycle(app: FastApi):
with lifecycle_manager(args) as man:
yield
```
"""


from .config_struct import ModelParams, ServerParams
from .manager import SystemManager
from .service import GenerateService
from .tokenizer import Tokenizer
from typing import TYPE_CHECKING
from fastapi import FastAPI


from contextlib import asynccontextmanager
import logging


def get_eos_from_tokenizer_config(json_path):
import json

with open(json_path, "rt") as f:
json_text = f.read()
config = json.loads(json_text)
return config["eos_token"]


class ShortfinLlmLifecycleManager:
def __init__(self, args):
# Load server configuration with priority: command line > config file > defaults
server_params = ServerParams.load(
args.server_config if hasattr(args, "server_config") else None
)
server_params.update_from_args(args)

# Setup system (configure devices, etc).
sysman = SystemManager(
device=args.device,
device_ids=server_params.device_ids,
async_allocs=server_params.amdgpu_async_allocations,
amdgpu_allocators=server_params.amdgpu_allocators,
)

# Setup each service we are hosting.
eos_token = get_eos_from_tokenizer_config(args.tokenizer_config_json)
tokenizer = Tokenizer.from_tokenizer_json_file(
args.tokenizer_json, eos_token=eos_token
)
model_params = ModelParams.load_json(args.model_config)
service = GenerateService(
name="default",
sysman=sysman,
tokenizer=tokenizer,
model_params=model_params,
server_params=server_params,
program_isolation=server_params.program_isolation,
)
service.load_inference_module(args.vmfb)
service.load_inference_parameters(*args.parameters, parameter_scope="model")
self.sysman = sysman
self.services = {"default": service}

def __enter__(self):
self.sysman.start()
for service_name, service in self.services.items():
logging.info("Initializing service '%s': %r", service_name, service)
service.start()
return self

def __exit__(self, exc_type, exc_value, traceback):
for service_name, service in self.services.items():
logging.info("Shutting down service '%s'", service_name)
service.shutdown()
self.sysman.shutdown()
return False

@asynccontextmanager
async def fastapi_lifespan(self, app: FastAPI):
with self:
app.state.sysman = self.sysman
app.state.services = self.services
yield
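The class separates construction (`__init__`), startup and shutdown (`__enter__`/`__exit__`), and FastAPI integration (`fastapi_lifespan`), so the same manager can also drive a server-less script. A minimal usage sketch, assuming a namespace carrying the fields the manager reads; every path below is a hypothetical placeholder, and `ServerParams.update_from_args` may consume further fields:

```python
from pathlib import Path
from types import SimpleNamespace

from shortfin_apps.llm.components.lifecycle import ShortfinLlmLifecycleManager

# Hypothetical artifact paths and device; a real run needs actual
# compiled files produced for the target model.
args = SimpleNamespace(
    device="hip",
    tokenizer_json=Path("/models/llama3/tokenizer.json"),
    tokenizer_config_json=Path("/models/llama3/tokenizer_config.json"),
    model_config=Path("/models/llama3/config.json"),
    vmfb=Path("/models/llama3/model.vmfb"),
    parameters=[Path("/models/llama3/params.irpa")],
    server_config=None,
)

manager = ShortfinLlmLifecycleManager(args)
with manager:
    # __enter__ started sysman and every service; __exit__ shuts them down.
    service = manager.services["default"]
```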
37 changes: 0 additions & 37 deletions shortfin/python/shortfin_apps/llm/lifecycle_hooks.py

This file was deleted.

3 changes: 1 addition & 2 deletions shortfin/python/shortfin_apps/llm/routes/generate.py
@@ -10,15 +10,14 @@

from ..components.generate import ClientGenerateBatchProcess
from ..components.io_struct import GenerateReqInput
from ..lifecycle_hooks import services

generation_router = APIRouter()


@generation_router.post("/generate")
@generation_router.put("/generate")
async def generate_request(gen_req: GenerateReqInput, request: Request):
service = services["default"]
service = request.app.state.services["default"]
gen_req.post_init()
responder = FastAPIResponder(request)
ClientGenerateBatchProcess(service, gen_req, responder).launch()
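The handler now resolves its service through `request.app.state` rather than the deleted module-level `services` dict, so services are scoped to the app instance that owns them. A hedged sketch of why that helps, with a hypothetical stand-in service and route that are not part of this commit:

```python
# Hypothetical sketch: app.state lets a test inject a stand-in service
# without patching module globals.
from fastapi import FastAPI, Request
from fastapi.testclient import TestClient


class FakeService:
    pass


app = FastAPI()
app.state.services = {"default": FakeService()}


@app.get("/which")
async def which(request: Request):
    return {"service": type(request.app.state.services["default"]).__name__}


print(TestClient(app).get("/which").json())  # {'service': 'FakeService'}
```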
55 changes: 4 additions & 51 deletions shortfin/python/shortfin_apps/llm/server.py
@@ -15,12 +15,8 @@
from shortfin import ProgramIsolation
import uvicorn

from . import lifecycle_hooks
from .application import get_app
from .components.config_struct import ModelParams, ServerParams
from .components.manager import SystemManager
from .components.service import GenerateService
from .components.tokenizer import Tokenizer
from .components.lifecycle import ShortfinLlmLifecycleManager


logger = logging.getLogger(__name__)
@@ -53,50 +49,6 @@
}


def get_eos_from_tokenizer_config(json_path):
import json

with open(json_path, "rt") as f:
json_text = f.read()
config = json.loads(json_text)
return config["eos_token"]


def configure(args) -> SystemManager:
# Load server configuration with priority: command line > config file > defaults
server_params = ServerParams.load(
args.server_config if hasattr(args, "server_config") else None
)
server_params.update_from_args(args)

# Setup system (configure devices, etc).
sysman = SystemManager(
device=args.device,
device_ids=server_params.device_ids,
async_allocs=server_params.amdgpu_async_allocations,
amdgpu_allocators=server_params.amdgpu_allocators,
)

# Setup each service we are hosting.
eos_token = get_eos_from_tokenizer_config(args.tokenizer_config_json)
tokenizer = Tokenizer.from_tokenizer_json_file(
args.tokenizer_json, eos_token=eos_token
)
model_params = ModelParams.load_json(args.model_config)
sm = GenerateService(
name="default",
sysman=sysman,
tokenizer=tokenizer,
model_params=model_params,
server_params=server_params,
program_isolation=server_params.program_isolation,
)
sm.load_inference_module(args.vmfb)
sm.load_inference_parameters(*args.parameters, parameter_scope="model")
lifecycle_hooks.services[sm.name] = sm
return sysman


def main(argv, log_config=uvicorn.config.LOGGING_CONFIG):
parser = argparse.ArgumentParser()
parser.add_argument("--host", type=str, default="0.0.0.0")
@@ -194,10 +146,11 @@ def main(argv, log_config=uvicorn.config.LOGGING_CONFIG):
args.tokenizer_json.stem + "_config.json"
)
args.tokenizer_config_json = inferred_tokenizer_config_path
lifecycle_hooks.sysman = configure(args)

lifecycle_manager = ShortfinLlmLifecycleManager(args)

uvicorn.run(
get_app(),
get_app(lifecycle_manager.fastapi_lifespan),
host=args.host,
port=args.port,
log_config=log_config,
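One detail worth calling out from the `main` hunk above: the tokenizer config path is inferred from the tokenizer path via `Path.stem`. A quick illustration with a hypothetical path, assuming the config sits next to the tokenizer file as the `stem + "_config.json"` expression suggests:

```python
from pathlib import Path

tokenizer_json = Path("/models/llama3/tokenizer.json")  # hypothetical
inferred = tokenizer_json.parent / (tokenizer_json.stem + "_config.json")
print(inferred)  # /models/llama3/tokenizer_config.json
```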
