From 1c32a72003de340dd41dd1d9e099e91e61cac04a Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Tue, 18 Jun 2024 16:54:56 -0600 Subject: [PATCH 1/5] HandleRuntimeContext: Port get_route_info from caikit-nlp Since the context parsing needs to happen in the backend directly, this needs to live here now Signed-off-by: Gabe Goodhart --- caikit_tgis_backend/tgis_backend.py | 43 ++++++++++++++++++++++ tests/test_tgis_backend.py | 57 ++++++++++++++++++++++++++++- 2 files changed, 99 insertions(+), 1 deletion(-) diff --git a/caikit_tgis_backend/tgis_backend.py b/caikit_tgis_backend/tgis_backend.py index ad67e26..add5c8b 100644 --- a/caikit_tgis_backend/tgis_backend.py +++ b/caikit_tgis_backend/tgis_backend.py @@ -21,10 +21,21 @@ # Third Party import grpc +# Since fastapi is optional in caikit, it may not be available +try: + # Third Party + import fastapi + + HAVE_FASTAPI = True +except ImportError: + HAVE_FASTAPI = False + fastapi = None + # First Party from caikit.core.exceptions import error_handler from caikit.core.module_backends.backend_types import register_backend_type from caikit.core.module_backends.base import BackendBase +from caikit.interfaces.runtime.data_model import RuntimeServerContextType import alog # Local @@ -54,6 +65,10 @@ class TGISBackend(BackendBase): TGIS_LOCAL_GRPC_PORT = 50055 TGIS_LOCAL_HTTP_PORT = 3000 + # HTTP Header / gRPC Metadata key used to identify a route override in an + # inbound request context + ROUTE_INFO_HEADER_KEY = "x-route-info" + ## Backend Interface ## backend_type = "TGIS" @@ -312,6 +327,34 @@ def model_loaded(self) -> bool: self._managed_tgis is not None and self._managed_tgis.is_ready() ) + @classmethod + def get_route_info( + cls, + context: Optional[RuntimeServerContextType], + ) -> Optional[str]: + """Get the string value of the x-route-info header/metadata if present + + Args: + context (Optional[RuntimeServerContextType]): The grpc or fastapi + request context + + Returns: + route_info (Optional[str]): The header/metadata value if present, + otherwise None + """ + if context is None: + return context + if isinstance(context, grpc.ServicerContext): + return dict(context.invocation_metadata()).get(cls.ROUTE_INFO_HEADER_KEY) + if HAVE_FASTAPI and isinstance(context, fastapi.Request): + return context.headers.get(cls.ROUTE_INFO_HEADER_KEY) + error.log_raise( + "", + TypeError(f"context is of an unsupported type: {type(context)}"), + ) + + ## Implementation Details ## + def _test_connection( self, model_conn: Optional[TGISConnection], timeout: Optional[float] = None ) -> Optional[TGISConnection]: diff --git a/tests/test_tgis_backend.py b/tests/test_tgis_backend.py index 29cf9a7..f1b08f6 100644 --- a/tests/test_tgis_backend.py +++ b/tests/test_tgis_backend.py @@ -18,13 +18,14 @@ # Standard from copy import deepcopy from dataclasses import asdict -from typing import Any, Dict, Optional +from typing import Any, Dict, Optional, Union from unittest import mock import os import tempfile import time # Third Party +import fastapi import grpc import pytest import tls_test_tools @@ -96,6 +97,18 @@ def mock_tgis_fixture(): mock_tgis.stop() +class TestServicerContext: + """ + A dummy class for mimicking ServicerContext invocation metadata storage. + """ + + def __init__(self, metadata: dict[str, Union[str, bytes]]): + self.metadata = metadata + + def invocation_metadata(self): + return list(self.metadata.items()) + + ## Conn Config ################################################################# @@ -904,3 +917,45 @@ def test_tgis_backend_conn_testing_enabled(tgis_mock_insecure): conn = tgis_be.get_connection(model_id) conn.test_connection() conn.test_connection(timeout=1) + + +@pytest.mark.parametrize( + argnames=["context", "route_info"], + argvalues=[ + ( + fastapi.Request( + { + "type": "http", + "headers": [ + (TGISBackend.ROUTE_INFO_HEADER_KEY.encode(), b"sometext") + ], + } + ), + "sometext", + ), + ( + fastapi.Request( + {"type": "http", "headers": [(b"route-info", b"sometext")]} + ), + None, + ), + ( + TestServicerContext({TGISBackend.ROUTE_INFO_HEADER_KEY: "sometext"}), + "sometext", + ), + ( + TestServicerContext({"route-info": "sometext"}), + None, + ), + ("should raise ValueError", None), + (None, None), + # Uncertain how to create a grpc.ServicerContext object + ], +) +def test_get_route_info(context, route_info: Optional[str]): + if not isinstance(context, (fastapi.Request, grpc.ServicerContext, type(None))): + with pytest.raises(TypeError): + TGISBackend.get_route_info(context) + else: + actual_route_info = TGISBackend.get_route_info(context) + assert actual_route_info == route_info From 56d95dbadfa4fa3fbe192fad8ee38d368aed848a Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Tue, 18 Jun 2024 16:55:15 -0600 Subject: [PATCH 2/5] HandleRuntimeContext: Add fastapi as a test dependency for context parsing Signed-off-by: Gabe Goodhart --- tox.ini | 1 + 1 file changed, 1 insertion(+) diff --git a/tox.ini b/tox.ini index a94c982..ae6a5db 100644 --- a/tox.ini +++ b/tox.ini @@ -12,6 +12,7 @@ deps = grpcio-tools>=1.35.0,<2.0 wheel>=0.38.4 Flask>=2.2.3,<3 + fastapi[all]>=0.100,<1 passenv = LOG_LEVEL LOG_FILTERS From f671f4c67aca87e6461d971426cddfe05b44c6fb Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Tue, 18 Jun 2024 16:55:39 -0600 Subject: [PATCH 3/5] HandleRuntimeContext: Implement handle_runtime_context Signed-off-by: Gabe Goodhart --- caikit_tgis_backend/tgis_backend.py | 18 ++++++++++++++++++ tests/test_tgis_backend.py | 29 +++++++++++++++++++++++++++++ 2 files changed, 47 insertions(+) diff --git a/caikit_tgis_backend/tgis_backend.py b/caikit_tgis_backend/tgis_backend.py index add5c8b..6d683f3 100644 --- a/caikit_tgis_backend/tgis_backend.py +++ b/caikit_tgis_backend/tgis_backend.py @@ -179,6 +179,24 @@ def stop(self): log.debug("Unloading model %s on stop", model_id) self.unload_model(model_id) + def handle_runtime_context( + self, + model_id: str, + runtime_context: RuntimeServerContextType, + ): + """Handle the runtime context for a request for the given model""" + if route_info := self.get_route_info(runtime_context): + log.debug( + " Registering remote model connection with context " + "override: 'hostname: %s'", + route_info, + ) + self.register_model_connection( + model_id, + {"hostname": route_info}, + fill_with_defaults=True, + ) + ## Backend user interface ## def get_connection( diff --git a/tests/test_tgis_backend.py b/tests/test_tgis_backend.py index f1b08f6..f64310e 100644 --- a/tests/test_tgis_backend.py +++ b/tests/test_tgis_backend.py @@ -959,3 +959,32 @@ def test_get_route_info(context, route_info: Optional[str]): else: actual_route_info = TGISBackend.get_route_info(context) assert actual_route_info == route_info + + +def test_handle_runtime_context_with_route_info(): + """Test that with route info present, handle_runtime_context updates the + model connection + """ + route_info = "sometext" + context = fastapi.Request( + { + "type": "http", + "headers": [ + (TGISBackend.ROUTE_INFO_HEADER_KEY.encode(), route_info.encode("utf-8")) + ], + } + ) + + tgis_be = TGISBackend( + { + "connection": {"hostname": "foobar:1234"}, + "test_connections": False, + } + ) + assert not tgis_be._model_connections + + # Handle the connection and make sure model_connections is updated + model_id = "my-model" + tgis_be.handle_runtime_context(model_id, context) + assert model_id in tgis_be._model_connections + assert (conn := tgis_be.get_connection(model_id)) and conn.hostname == route_info From 7a7184985ff27525954ca7831a8d5f7c9d488fdd Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Tue, 18 Jun 2024 17:04:22 -0600 Subject: [PATCH 4/5] HandleRuntimeContext: Fix dict[...] -> Dict[...] for compatibility Signed-off-by: Gabe Goodhart --- tests/test_tgis_backend.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_tgis_backend.py b/tests/test_tgis_backend.py index f64310e..811a4f2 100644 --- a/tests/test_tgis_backend.py +++ b/tests/test_tgis_backend.py @@ -102,7 +102,7 @@ class TestServicerContext: A dummy class for mimicking ServicerContext invocation metadata storage. """ - def __init__(self, metadata: dict[str, Union[str, bytes]]): + def __init__(self, metadata: Dict[str, Union[str, bytes]]): self.metadata = metadata def invocation_metadata(self): From c6c52c5cb0d9abbaff818eb69c165be0135faed3 Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Wed, 19 Jun 2024 14:59:14 -0600 Subject: [PATCH 5/5] HandleRuntimeContext: Minor typo fix in test Signed-off-by: Gabe Goodhart --- tests/test_tgis_backend.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_tgis_backend.py b/tests/test_tgis_backend.py index 811a4f2..247b790 100644 --- a/tests/test_tgis_backend.py +++ b/tests/test_tgis_backend.py @@ -947,7 +947,7 @@ def test_tgis_backend_conn_testing_enabled(tgis_mock_insecure): TestServicerContext({"route-info": "sometext"}), None, ), - ("should raise ValueError", None), + ("should raise TypeError", None), (None, None), # Uncertain how to create a grpc.ServicerContext object ],