Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Handle runtime context #57

Merged
merged 5 commits into from
Jun 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 61 additions & 0 deletions caikit_tgis_backend/tgis_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,21 @@
# Third Party
import grpc

# Since fastapi is optional in caikit, it may not be available
try:
# Third Party
import fastapi

HAVE_FASTAPI = True
except ImportError:
HAVE_FASTAPI = False
fastapi = None

# First Party
from caikit.core.exceptions import error_handler
from caikit.core.module_backends.backend_types import register_backend_type
from caikit.core.module_backends.base import BackendBase
from caikit.interfaces.runtime.data_model import RuntimeServerContextType
import alog

# Local
Expand Down Expand Up @@ -54,6 +65,10 @@ class TGISBackend(BackendBase):
TGIS_LOCAL_GRPC_PORT = 50055
TGIS_LOCAL_HTTP_PORT = 3000

# HTTP Header / gRPC Metadata key used to identify a route override in an
# inbound request context
ROUTE_INFO_HEADER_KEY = "x-route-info"

## Backend Interface ##

backend_type = "TGIS"
Expand Down Expand Up @@ -164,6 +179,24 @@ def stop(self):
log.debug("Unloading model %s on stop", model_id)
self.unload_model(model_id)

def handle_runtime_context(
    self,
    model_id: str,
    runtime_context: RuntimeServerContextType,
):
    """Apply a per-request connection override found in the runtime context

    The route info is looked up with get_route_info. When a truthy value is
    returned, it is registered as the "hostname" of the connection for
    model_id, passing fill_with_defaults=True. When no route info is found,
    nothing is registered.

    Args:
        model_id (str): The id of the model the request targets
        runtime_context (RuntimeServerContextType): The inbound request
            context to inspect for a route override
    """
    route_override = self.get_route_info(runtime_context)
    if not route_override:
        # No override in this request: leave existing connections untouched
        return

    log.debug(
        "<TGB10705560D> Registering remote model connection with context "
        "override: 'hostname: %s'",
        route_override,
    )
    connection_overrides = {"hostname": route_override}
    self.register_model_connection(
        model_id,
        connection_overrides,
        fill_with_defaults=True,
    )

## Backend user interface ##

def get_connection(
Expand Down Expand Up @@ -312,6 +345,34 @@ def model_loaded(self) -> bool:
self._managed_tgis is not None and self._managed_tgis.is_ready()
)

@classmethod
def get_route_info(
    cls,
    context: Optional[RuntimeServerContextType],
) -> Optional[str]:
    """Look up the route override carried by an inbound request context

    The value is read from the entry keyed by ROUTE_INFO_HEADER_KEY: in the
    invocation metadata for a grpc.ServicerContext, or in the headers for a
    fastapi.Request (only checked when fastapi could be imported).

    Args:
        context (Optional[RuntimeServerContextType]): The grpc or fastapi
            request context, or None when there is no context

    Returns:
        route_info (Optional[str]): The header/metadata value if present,
            otherwise None. None is also returned when context is None.

    Raises:
        TypeError: Passed to error.log_raise when context is neither None
            nor one of the supported context types
    """
    if context is None:
        return None

    header_key = cls.ROUTE_INFO_HEADER_KEY

    # gRPC: invocation metadata is a sequence of (key, value) pairs
    if isinstance(context, grpc.ServicerContext):
        metadata = dict(context.invocation_metadata())
        return metadata.get(header_key)

    # HTTP: fastapi is optional, so only test the type when it is available
    if HAVE_FASTAPI and isinstance(context, fastapi.Request):
        return context.headers.get(header_key)

    error.log_raise(
        "<TGB92615097E>",
        TypeError(f"context is of an unsupported type: {type(context)}"),
    )

## Implementation Details ##

def _test_connection(
self, model_conn: Optional[TGISConnection], timeout: Optional[float] = None
) -> Optional[TGISConnection]:
Expand Down
86 changes: 85 additions & 1 deletion tests/test_tgis_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,14 @@
# Standard
from copy import deepcopy
from dataclasses import asdict
from typing import Any, Dict, Optional
from typing import Any, Dict, Optional, Union
from unittest import mock
import os
import tempfile
import time

# Third Party
import fastapi
import grpc
import pytest
import tls_test_tools
Expand Down Expand Up @@ -96,6 +97,18 @@ def mock_tgis_fixture():
mock_tgis.stop()


class TestServicerContext:
    """
    A dummy class for mimicking ServicerContext invocation metadata storage.

    NOTE(review): this class does not derive from grpc.ServicerContext, so
    isinstance checks against that type will not match instances of it --
    confirm that this is intended wherever it is used as a context.
    """

    # The "Test" prefix matches pytest's default class collection pattern.
    # This is a helper, not a test case, so opt out of collection explicitly
    # instead of relying on pytest skipping (and warning about) classes that
    # define __init__.
    __test__ = False

    def __init__(self, metadata: Dict[str, Union[str, bytes]]):
        """Store the metadata mapping to be reported by invocation_metadata

        Args:
            metadata (Dict[str, Union[str, bytes]]): Metadata keys mapped to
                their values
        """
        self.metadata = metadata

    def invocation_metadata(self):
        """Return the stored metadata as a list of (key, value) tuples"""
        return list(self.metadata.items())


## Conn Config #################################################################


Expand Down Expand Up @@ -904,3 +917,74 @@ def test_tgis_backend_conn_testing_enabled(tgis_mock_insecure):
conn = tgis_be.get_connection(model_id)
conn.test_connection()
conn.test_connection(timeout=1)


def _grpc_servicer_context(metadata):
    """Build a stand-in context that passes isinstance checks against
    grpc.ServicerContext and reports the given metadata

    A mock created with spec= reports the spec class as its __class__, so
    isinstance(ctx, grpc.ServicerContext) is True without having to implement
    the full ServicerContext interface.
    """
    ctx = mock.Mock(spec=grpc.ServicerContext)
    ctx.invocation_metadata.return_value = list(metadata.items())
    return ctx


@pytest.mark.parametrize(
    argnames=["context", "route_info"],
    argvalues=[
        # HTTP request carrying the route info header
        (
            fastapi.Request(
                {
                    "type": "http",
                    "headers": [
                        (TGISBackend.ROUTE_INFO_HEADER_KEY.encode(), b"sometext")
                    ],
                }
            ),
            "sometext",
        ),
        # HTTP request with an unrelated header
        (
            fastapi.Request(
                {"type": "http", "headers": [(b"route-info", b"sometext")]}
            ),
            None,
        ),
        # gRPC context carrying the route info metadata key
        (
            _grpc_servicer_context({TGISBackend.ROUTE_INFO_HEADER_KEY: "sometext"}),
            "sometext",
        ),
        # gRPC context with an unrelated metadata key
        (
            _grpc_servicer_context({"route-info": "sometext"}),
            None,
        ),
        # Duck-typed contexts are not grpc.ServicerContext instances, so they
        # take the TypeError branch below and route_info is not compared
        (
            TestServicerContext({TGISBackend.ROUTE_INFO_HEADER_KEY: "sometext"}),
            None,
        ),
        (
            TestServicerContext({"route-info": "sometext"}),
            None,
        ),
        ("should raise TypeError", None),
        (None, None),
    ],
)
def test_get_route_info(context, route_info: Optional[str]):
    """Test that get_route_info reads the route info from supported context
    types and raises a TypeError for anything else
    """
    supported_types = (fastapi.Request, grpc.ServicerContext, type(None))
    if not isinstance(context, supported_types):
        with pytest.raises(TypeError):
            TGISBackend.get_route_info(context)
    else:
        actual_route_info = TGISBackend.get_route_info(context)
        assert actual_route_info == route_info


def test_handle_runtime_context_with_route_info():
    """Test that handle_runtime_context registers a connection for the model
    whose hostname is the route info carried by the request
    """
    model_id = "my-model"
    route_info = "sometext"

    # Build an HTTP request whose only header is the route info override
    route_header = (
        TGISBackend.ROUTE_INFO_HEADER_KEY.encode(),
        route_info.encode("utf-8"),
    )
    context = fastapi.Request({"type": "http", "headers": [route_header]})

    backend_config = {
        "connection": {"hostname": "foobar:1234"},
        "test_connections": False,
    }
    tgis_be = TGISBackend(backend_config)

    # Nothing is registered before the context is handled
    assert not tgis_be._model_connections

    # Handling the context must register a connection for the model that
    # points at the route info
    tgis_be.handle_runtime_context(model_id, context)
    assert model_id in tgis_be._model_connections
    conn = tgis_be.get_connection(model_id)
    assert conn and conn.hostname == route_info
1 change: 1 addition & 0 deletions tox.ini
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ deps =
grpcio-tools>=1.35.0,<2.0
wheel>=0.38.4
Flask>=2.2.3,<3
fastapi[all]>=0.100,<1
passenv =
LOG_LEVEL
LOG_FILTERS
Expand Down
Loading