NOTE(review): restored from a whitespace-collapsed copy of this patch. The
indentation of unchanged context lines is inferred; if a strict apply fails,
use `git apply --ignore-whitespace`. The test stub is now registered as a
virtual grpc.ServicerContext subclass, so the test file's index hash is stale.

diff --git a/caikit_tgis_backend/tgis_backend.py b/caikit_tgis_backend/tgis_backend.py
index ad67e26..6d683f3 100644
--- a/caikit_tgis_backend/tgis_backend.py
+++ b/caikit_tgis_backend/tgis_backend.py
@@ -21,10 +21,21 @@
 # Third Party
 import grpc
 
+# Since fastapi is optional in caikit, it may not be available
+try:
+    # Third Party
+    import fastapi
+
+    HAVE_FASTAPI = True
+except ImportError:
+    HAVE_FASTAPI = False
+    fastapi = None
+
 # First Party
 from caikit.core.exceptions import error_handler
 from caikit.core.module_backends.backend_types import register_backend_type
 from caikit.core.module_backends.base import BackendBase
+from caikit.interfaces.runtime.data_model import RuntimeServerContextType
 import alog
 
 # Local
@@ -54,6 +65,10 @@ class TGISBackend(BackendBase):
     TGIS_LOCAL_GRPC_PORT = 50055
     TGIS_LOCAL_HTTP_PORT = 3000
 
+    # HTTP Header / gRPC Metadata key used to identify a route override in an
+    # inbound request context
+    ROUTE_INFO_HEADER_KEY = "x-route-info"
+
     ## Backend Interface ##
 
     backend_type = "TGIS"
@@ -164,6 +179,24 @@ def stop(self):
             log.debug("Unloading model %s on stop", model_id)
             self.unload_model(model_id)
 
+    def handle_runtime_context(
+        self,
+        model_id: str,
+        runtime_context: RuntimeServerContextType,
+    ):
+        """Handle the runtime context for a request for the given model"""
+        if route_info := self.get_route_info(runtime_context):
+            log.debug(
+                " Registering remote model connection with context "
+                "override: 'hostname: %s'",
+                route_info,
+            )
+            self.register_model_connection(
+                model_id,
+                {"hostname": route_info},
+                fill_with_defaults=True,
+            )
+
     ## Backend user interface ##
 
     def get_connection(
@@ -312,6 +345,34 @@ def model_loaded(self) -> bool:
             self._managed_tgis is not None and self._managed_tgis.is_ready()
         )
 
+    @classmethod
+    def get_route_info(
+        cls,
+        context: Optional[RuntimeServerContextType],
+    ) -> Optional[str]:
+        """Get the string value of the x-route-info header/metadata if present
+
+        Args:
+            context (Optional[RuntimeServerContextType]): The grpc or fastapi
+                request context
+
+        Returns:
+            route_info (Optional[str]): The header/metadata value if present,
+                otherwise None
+        """
+        if context is None:
+            return context
+        if isinstance(context, grpc.ServicerContext):
+            return dict(context.invocation_metadata()).get(cls.ROUTE_INFO_HEADER_KEY)
+        if HAVE_FASTAPI and isinstance(context, fastapi.Request):
+            return context.headers.get(cls.ROUTE_INFO_HEADER_KEY)
+        error.log_raise(
+            "",
+            TypeError(f"context is of an unsupported type: {type(context)}"),
+        )
+
+    ## Implementation Details ##
+
     def _test_connection(
         self, model_conn: Optional[TGISConnection], timeout: Optional[float] = None
     ) -> Optional[TGISConnection]:
diff --git a/tests/test_tgis_backend.py b/tests/test_tgis_backend.py
index 29cf9a7..247b790 100644
--- a/tests/test_tgis_backend.py
+++ b/tests/test_tgis_backend.py
@@ -18,13 +18,14 @@
 # Standard
 from copy import deepcopy
 from dataclasses import asdict
-from typing import Any, Dict, Optional
+from typing import Any, Dict, Optional, Union
 from unittest import mock
 import os
 import tempfile
 import time
 
 # Third Party
+import fastapi
 import grpc
 import pytest
 import tls_test_tools
@@ -96,6 +97,22 @@ def mock_tgis_fixture():
     mock_tgis.stop()
 
 
+# Registered as a virtual subclass of grpc.ServicerContext so that isinstance
+# checks in TGISBackend.get_route_info accept it without this stub having to
+# implement every abstract method of the real interface
+@grpc.ServicerContext.register
+class TestServicerContext:
+    """
+    A dummy class for mimicking ServicerContext invocation metadata storage.
+    """
+
+    def __init__(self, metadata: Dict[str, Union[str, bytes]]):
+        self.metadata = metadata
+
+    def invocation_metadata(self):
+        return list(self.metadata.items())
+
+
 ## Conn Config #################################################################
 
 
@@ -904,3 +921,73 @@ def test_tgis_backend_conn_testing_enabled(tgis_mock_insecure):
     conn = tgis_be.get_connection(model_id)
     conn.test_connection()
     conn.test_connection(timeout=1)
+
+
+@pytest.mark.parametrize(
+    argnames=["context", "route_info"],
+    argvalues=[
+        (
+            fastapi.Request(
+                {
+                    "type": "http",
+                    "headers": [
+                        (TGISBackend.ROUTE_INFO_HEADER_KEY.encode(), b"sometext")
+                    ],
+                }
+            ),
+            "sometext",
+        ),
+        (
+            fastapi.Request(
+                {"type": "http", "headers": [(b"route-info", b"sometext")]}
+            ),
+            None,
+        ),
+        (
+            TestServicerContext({TGISBackend.ROUTE_INFO_HEADER_KEY: "sometext"}),
+            "sometext",
+        ),
+        (
+            TestServicerContext({"route-info": "sometext"}),
+            None,
+        ),
+        ("should raise TypeError", None),
+        (None, None),
+    ],
+)
+def test_get_route_info(context, route_info: Optional[str]):
+    if not isinstance(context, (fastapi.Request, grpc.ServicerContext, type(None))):
+        with pytest.raises(TypeError):
+            TGISBackend.get_route_info(context)
+    else:
+        actual_route_info = TGISBackend.get_route_info(context)
+        assert actual_route_info == route_info
+
+
+def test_handle_runtime_context_with_route_info():
+    """Test that with route info present, handle_runtime_context updates the
+    model connection
+    """
+    route_info = "sometext"
+    context = fastapi.Request(
+        {
+            "type": "http",
+            "headers": [
+                (TGISBackend.ROUTE_INFO_HEADER_KEY.encode(), route_info.encode("utf-8"))
+            ],
+        }
+    )
+
+    tgis_be = TGISBackend(
+        {
+            "connection": {"hostname": "foobar:1234"},
+            "test_connections": False,
+        }
+    )
+    assert not tgis_be._model_connections
+
+    # Handle the connection and make sure model_connections is updated
+    model_id = "my-model"
+    tgis_be.handle_runtime_context(model_id, context)
+    assert model_id in tgis_be._model_connections
+    assert (conn := tgis_be.get_connection(model_id)) and conn.hostname == route_info
diff --git a/tox.ini b/tox.ini
index a94c982..ae6a5db 100644
--- a/tox.ini
+++ b/tox.ini
@@ -12,6 +12,7 @@ deps =
     grpcio-tools>=1.35.0,<2.0
     wheel>=0.38.4
     Flask>=2.2.3,<3
+    fastapi[all]>=0.100,<1
 passenv =
     LOG_LEVEL
     LOG_FILTERS