Skip to content

Commit

Permalink
Refactor common code into _test_connection() and _safely_update_state()
Browse files Browse the repository at this point in the history
Signed-off-by: Mynhardt Burger <Mynhardt.Burger@ibm.com>
  • Loading branch information
mynhardtburger committed Jun 7, 2024
1 parent a9ae9ca commit f52f7d2
Showing 1 changed file with 26 additions and 28 deletions.
54 changes: 26 additions & 28 deletions caikit_tgis_backend/tgis_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,19 +110,9 @@ def __init__(self, config: Optional[dict] = None):
model_id,
)
if self._test_connections:
try:
model_conn.test_connection(timeout=self._connect_timeout)
except grpc.RpcError as err:
log.warning(
"<TGB95244222W>",
"Unable to connect to model %s: %s",
model_id,
err,
exc_info=True,
)
model_conn = None
model_conn = self._test_connection(model_conn, self._connect_timeout)
if model_conn is not None:
self._model_connections[model_id] = model_conn
self._safely_update_state(model_id, model_conn)

# We manage a local TGIS instance if there are no remote connections
# specified as either a valid base connection or remote_connections
Expand Down Expand Up @@ -178,9 +168,9 @@ def get_connection(
if not model_conn and create and not self.local_tgis and conn_cfg:
model_conn = TGISConnection.from_config(model_id, conn_cfg)
if self._test_connections:
self._test_connection(model_conn)
model_conn = self._test_connection(model_conn)
if model_conn is not None:
self._safely_update_state(model_id, model_conn, conn_cfg)
self._safely_update_state(model_id, model_conn)

return model_conn

Expand All @@ -202,26 +192,27 @@ def register_model_connection(
with defaults from the TGISBackend's config connection.
"""
if model_id in self._model_connections:
# Model connection exists --> do nothing
return
return # Model connection exists --> do nothing

# Create model connection...
# Craft new connection config
new_conn_cfg = {}
if conn_cfg is None:
new_conn_conf = self._base_connection_cfg
model_conn = TGISConnection.from_config(model_id, new_conn_conf)
new_conn_cfg = self._base_connection_cfg
else:
new_conn_conf: Dict[str, Any] = (
self._base_connection_cfg if fill_with_defaults else {}
)
new_conn_conf.update(conn_cfg)
model_conn = TGISConnection.from_config(model_id, new_conn_conf)
if fill_with_defaults:
new_conn_cfg = self._base_connection_cfg
new_conn_cfg.update(conn_cfg)

# Create model connection
model_conn = TGISConnection.from_config(model_id, new_conn_cfg)

error.value_check("<TGB81270235E>", model_conn is not None)

# Register model connection
if self._test_connections:
self._test_connection(model_conn)
model_conn = self._test_connection(model_conn)
if model_conn is not None:
self._safely_update_state(model_id, model_conn, new_conn_conf)
self._safely_update_state(model_id, model_conn, new_conn_cfg)

def get_client(self, model_id: str) -> generation_pb2_grpc.GenerationServiceStub:
model_conn = self.get_connection(model_id)
Expand Down Expand Up @@ -291,12 +282,17 @@ def model_loaded(self) -> bool:
self._managed_tgis is not None and self._managed_tgis.is_ready()
)

def _test_connection(self, model_conn: Optional[TGISConnection]):
def _test_connection(
self, model_conn: Optional[TGISConnection], timeout: Optional[float] = None
) -> Optional[TGISConnection]:
"""
Returns the TGISConnection if successful, else returns None.
"""
if model_conn is None:
return

try:
model_conn.test_connection()
model_conn.test_connection(timeout)
except grpc.RpcError as err:
log.warning(
"<TGB10601575W>",
Expand All @@ -307,6 +303,8 @@ def _test_connection(self, model_conn: Optional[TGISConnection]):
)
model_conn = None

return model_conn

def _safely_update_state(
self,
model_id: str,
Expand Down

0 comments on commit f52f7d2

Please sign in to comment.