rename ServerInstance.start to start_full_fastapi_server and add a start_generate_service contextmanager
renxida committed Feb 13, 2025
1 parent 37f3ee2 commit fa10955
Showing 5 changed files with 75 additions and 35 deletions.
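In short, ServerInstance.start() becomes start_full_fastapi_server(), and a new start_generate_service context manager yields a GenerateService directly, bypassing FastAPI. A minimal usage sketch follows, assuming the ServerConfig/ServerInstance wiring shown in the diff below; the import path and the run_generate_tests helper are illustrative, not part of this commit:

```python
# Sketch only: mirrors the usage shown in the new start_generate_service
# docstring. The import path and the run_generate_tests helper are assumptions
# for illustration, not part of this commit.
from app_tests.integration_tests.llm.server_management import (
    ServerConfig,
    ServerInstance,
)


def run_generate_tests(model_artifacts, request):
    server_config = ServerConfig(
        artifacts=model_artifacts,
        device_settings=model_artifacts.model_config.device_settings,
        prefix_sharing_algorithm=request.param.get("prefix_sharing", "none"),
    )
    instance = ServerInstance(server_config)

    # Old name: instance.start() -- now instance.start_full_fastapi_server()

    # New path: bypass FastAPI and drive the GenerateService directly.
    with instance.start_generate_service(model_artifacts, request) as service:
        ...  # run tests against `service`
```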
@@ -80,7 +80,7 @@ def server(model_artifacts, request):
)

server_instance = ServerInstance(server_config)
server_instance.start()
server_instance.start_full_fastapi_server()
process, port = server_instance.process, server_instance.port
yield process, port

58 changes: 56 additions & 2 deletions app_tests/integration_tests/llm/server_management.py
@@ -1,6 +1,6 @@
"""Handles server lifecycle and configuration."""
import socket
from contextlib import closing
from contextlib import closing, contextmanager
from dataclasses import dataclass
import subprocess
import time
@@ -11,6 +11,11 @@
from .device_settings import DeviceSettings
from .model_management import ModelArtifacts

import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


@dataclass
class ServerConfig:
@@ -70,7 +75,7 @@ def get_server_args(config: ServerConfig) -> list[str]:
f"--prefix_sharing_algorithm={config.prefix_sharing_algorithm}",
] + config.device_settings.server_flags

def start(self) -> None:
def start_full_fastapi_server(self) -> None:
"""Starts the server process."""
if self.process is not None:
raise RuntimeError("Server is already running")
@@ -83,6 +88,55 @@ def start(self) -> None:
self.process = subprocess.Popen(cmd)
self.wait_for_ready()

@contextmanager
def start_generate_service(
model_artifacts, request
) -> "shortfin_apps.llm.components.service.GenerateService":
"""
Like the full FastAPI server, but without FastAPI: yields a service object
that gives direct access to shortfin while bypassing FastAPI.
Use like so:
```
with instance.start_generate_service(model_artifacts, request) as service:
# run tests with service
...
```
"""

model_config = model_artifacts.model_config

server_config = ServerConfig(
artifacts=model_artifacts,
device_settings=model_config.device_settings,
prefix_sharing_algorithm=request.param.get("prefix_sharing", "none"),
)

from shortfin_apps.llm import server as server_module

argv = ServerInstance.get_server_args(server_config)
args = server_module.parse_args(argv)
server_module.sysman = server_module.configure(args)
sysman = server_module.sysman
services = server_module.services
# sysman.start()
try:
for service_name, service in sysman.services.items():
logging.info("Initializing service '%s': %r", service_name, service)
service.start()
except:
sysman.shutdown()
raise
yield sysman.services["default"]
try:
for service_name, service in services.items():
logging.info("Shutting down service '%s'", service_name)
service.shutdown()
finally:
sysman.shutdown()

def wait_for_ready(self, timeout: int = 30) -> None:
"""Waits for server to be ready and responding to health checks."""
if self.port is None:
2 changes: 1 addition & 1 deletion app_tests/integration_tests/llm/sglang/conftest.py
@@ -64,7 +64,7 @@ def start_server(request, model_artifacts):
)

server_instance = ServerInstance(server_config)
server_instance.start()
server_instance.start_full_fastapi_server()
process, port = server_instance.process, server_instance.port

yield process, port
11 changes: 10 additions & 1 deletion app_tests/integration_tests/llm/shortfin/batching.py
@@ -1,7 +1,7 @@
import logging
from pathlib import Path
import math

import pytest
import shortfin as sf
import shortfin.array as sfnp
from shortfin_apps.llm.components.service import (
@@ -54,6 +54,15 @@

from shortfin_apps.llm.components.messages import InferenceExecRequest, InferencePhase

pytestmark = pytest.mark.parametrize(
"model_artifacts,server",
[
["llama3.1_8b", {"prefix_sharing": "none"}],
["llama3.1_8b", {"prefix_sharing": "trie"}],
],
indirect=True,
)


def test_batch_sizes_same_inputs_same_outputs():
"""
37 changes: 7 additions & 30 deletions app_tests/integration_tests/llm/shortfin/conftest.py
@@ -16,6 +16,11 @@

import shortfin_apps.llm.server as server_module

import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


def pytest_addoption(parser):
parser.addoption(
@@ -61,7 +66,7 @@ def model_artifacts(tmp_path_factory, request, test_device):


@pytest.fixture(scope="module")
def server(model_artifacts, request):
def full_fastapi_server(model_artifacts, request):
"""Starts and manages the test server."""
model_config = model_artifacts.model_config

@@ -72,42 +77,14 @@ def server(model_artifacts, request)
)

server_instance = ServerInstance(server_config)
server_instance.start()
server_instance.start_full_fastapi_server()
process, port = server_instance.process, server_instance.port
yield process, port

process.terminate()
process.wait()


def generate_service(
model_artifacts, request
) -> "shortfin_apps.llm.components.service.GenerateService":
"""
like server, but no fastapi,
this yields a service object that gives access to shortfin while bypassing fastapi
"""

model_config = model_artifacts.model_config

server_config = ServerConfig(
artifacts=model_artifacts,
device_settings=model_config.device_settings,
prefix_sharing_algorithm=request.param.get("prefix_sharing", "none"),
)

argv = ServerInstance.get_server_args(server_config)
args = server_module.parse_args(argv)
server_module.sysman = server_module.configure(args)

# lifespan() is a context manager that calls GenerateService.start() and GenerateService.stop() on enter and exit
# GenerateService.start() will launch the batcher process.
# TODO: consider providing a version that yields the service without starting it
with server_module.lifespan(): # this would take care of cleanup automatically when pytest shuts down
yield server_module.service


@pytest.fixture(scope="module")
def encoded_prompt(model_artifacts: ModelArtifacts, request) -> list[int]:
tokenizer = Tokenizer.from_file(str(model_artifacts.tokenizer_path))
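For context, a hedged sketch of a test consuming the renamed full_fastapi_server fixture, which still yields a (process, port) pair; the test name, assertions, and the /health endpoint are assumptions for illustration, not part of this commit:

```python
# Illustrative only: uses the renamed fixture from shortfin/conftest.py.
# The /health endpoint is assumed from wait_for_ready's health-check comment.
import requests


def test_server_is_up(full_fastapi_server):
    process, port = full_fastapi_server
    assert process.poll() is None  # server process is still alive
    resp = requests.get(f"http://localhost:{port}/health", timeout=30)
    assert resp.status_code == 200
```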
