diff --git a/caikit_tgis_backend/tgis_backend.py b/caikit_tgis_backend/tgis_backend.py index aa6238b..ad67e26 100644 --- a/caikit_tgis_backend/tgis_backend.py +++ b/caikit_tgis_backend/tgis_backend.py @@ -190,6 +190,8 @@ def register_model_connection( """ Register a remote model connection. + If a local TGIS instance is maintained, do nothing. + If the model connection is already registered, do nothing. Otherwise create and register the model connection using the TGISBackend's @@ -198,7 +200,20 @@ def register_model_connection( If `fill_with_defaults == True`, missing keys in `conn_cfg` will be populated with defaults from the TGISBackend's config connection. """ + # Don't attempt registering a remote model if running local TGIS instance + if self.local_tgis: + log.debug( + " Running a local TGIS instance... won't register a " + "remote model connection" + ) + return + if model_id in self._model_connections: + log.debug( + " remote model connection for model %s already exists... " + "nothing to register", + model_id, + ) return # Model connection exists --> do nothing # Craft new connection config @@ -211,6 +226,10 @@ def register_model_connection( new_conn_cfg.update(conn_cfg) # Create model connection + error.value_check( + "", new_conn_cfg, "TGISConnection config is empty" + ) + model_conn = TGISConnection.from_config(model_id, new_conn_cfg) error.value_check("", model_conn is not None) @@ -219,6 +238,10 @@ def register_model_connection( if self._test_connections: model_conn = self._test_connection(model_conn) if model_conn is not None: + log.debug( + " Registering new remote model connection for %s", + model_id, + ) self._safely_update_state(model_id, model_conn, new_conn_cfg) def get_client(self, model_id: str) -> generation_pb2_grpc.GenerationServiceStub: diff --git a/tests/test_tgis_backend.py b/tests/test_tgis_backend.py index b255f36..29cf9a7 100644 --- a/tests/test_tgis_backend.py +++ b/tests/test_tgis_backend.py @@ -771,6 +771,24 @@ def test_tgis_backend_register_model_connection( assert tgis_be._base_connection_cfg == backup_base_cfg +def test_tgis_backend_register_model_connection_local(): + tgis_be = TGISBackend() + + # Confirm marked as local TGIS instance with no base connection config + assert tgis_be.local_tgis + assert not tgis_be._base_connection_cfg + assert not tgis_be._model_connections + assert not tgis_be._remote_models_cfg + + # Register action should do nothing + tgis_be.register_model_connection("should do nothing") + + # Confirm nothing was done + assert not tgis_be._base_connection_cfg + assert not tgis_be._model_connections + assert not tgis_be._remote_models_cfg + + ## Failure Tests ###############################################################