diff --git a/python/starpoint/embedding.py b/python/starpoint/embedding.py
new file mode 100644
index 0000000..346b461
--- /dev/null
+++ b/python/starpoint/embedding.py
@@ -0,0 +1,75 @@
+import logging
+from typing import Any, Dict, List, Optional
+from uuid import UUID
+
+import requests
+
+from starpoint._utils import (
+    _build_header,
+    _check_collection_identifier_collision,
+    _validate_host,
+)
+from starpoint.enums import EmbeddingModel
+
+
+LOGGER = logging.getLogger(__name__)
+
+# Host
+EMBEDDING_URL = "https://embedding.starpoint.ai"
+
+# Endpoints
+EMBED_PATH = "/api/v1/embed"
+
+# Error and warning messages
+SSL_ERROR_MSG = "Request failed due to SSLError. Error is likely due to invalid API key. Please check if your API is correct and still valid."
+
+
+class EmbeddingClient(object):
+    """Client for the embedding endpoints."""
+
+    def __init__(self, api_key: UUID, host: Optional[str] = None):
+        if host is None:
+            host = EMBEDDING_URL
+
+        self.host = _validate_host(host)
+        self.api_key = api_key
+
+    def embed(
+        self,
+        text: List[str],
+        model: EmbeddingModel,
+    ) -> Dict[str, List[Dict]]:
+        """Takes some text and creates an embedding against a model in starpoint.
+
+        Args:
+            text: List of strings to create embeddings from.
+            model: The EmbeddingModel used to create the embeddings.
+
+        Returns:
+            dict: Result with multiple lists of embeddings, matching the number of requested strings to
+                create embeddings from. An empty dict is returned if the response status is not OK.
+
+        Raises:
+            requests.exceptions.SSLError: Request failed with an SSL error, likely due to an invalid API key.
+        """
+        request_data = dict(text=text, model=model.value)
+        try:
+            response = requests.post(
+                url=f"{self.host}{EMBED_PATH}",
+                json=request_data,
+                headers=_build_header(
+                    api_key=self.api_key,
+                    additional_headers={"Content-Type": "application/json"},
+                ),
+            )
+        except requests.exceptions.SSLError:
+            LOGGER.error(SSL_ERROR_MSG)
+            raise
+
+        if not response.ok:
+            LOGGER.error(
+                f"Request failed with status code {response.status_code} "
+                f"and the following message:\n{response.text}"
+            )
+            return {}
+        return response.json()
diff --git a/python/starpoint/enums.py b/python/starpoint/enums.py
new file mode 100644
index 0000000..eef8972
--- /dev/null
+++ b/python/starpoint/enums.py
@@ -0,0 +1,6 @@
+from enum import Enum
+
+
+class EmbeddingModel(Enum):
+    MINI6 = "MINI6"
+    MINI12 = "MINI12"
diff --git a/python/tests/test_embedding.py b/python/tests/test_embedding.py
new file mode 100644
index 0000000..1d6eb70
--- /dev/null
+++ b/python/tests/test_embedding.py
@@ -0,0 +1,81 @@
+from unittest.mock import MagicMock, patch
+from uuid import UUID, uuid4
+
+import pytest
+from _pytest.monkeypatch import MonkeyPatch
+from requests.exceptions import SSLError
+
+from starpoint import embedding
+from starpoint.enums import EmbeddingModel
+
+
+@pytest.fixture(scope="session")
+def api_uuid() -> UUID:
+    return uuid4()
+
+
+@pytest.fixture(scope="session")
+@patch("starpoint._utils._check_host_health")
+def mock_embedding_client(
+    host_health_mock: MagicMock, api_uuid: UUID
+) -> embedding.EmbeddingClient:
+    return embedding.EmbeddingClient(api_uuid)
+
+
+def test_embedding_default_init(
+    mock_embedding_client: embedding.EmbeddingClient, api_uuid: UUID
+):
+    assert mock_embedding_client.host
+    assert mock_embedding_client.host == embedding.EMBEDDING_URL
+    assert mock_embedding_client.api_key == api_uuid
+
+
+@patch("starpoint.embedding._validate_host")
+def test_embedding_init_non_default_host(
+    mock_host_validator: MagicMock, api_uuid: UUID
+):
+    test_host = "http://www.example.com"
+    test_embedding_client = embedding.EmbeddingClient(api_key=api_uuid, host=test_host)
+
+    mock_host_validator.assert_called_once_with(test_host)
+    # This assert needs to be after assert_called_once_with to make sure it doesn't confound the result
+    assert test_embedding_client.host == mock_host_validator()
+    assert test_embedding_client.api_key == api_uuid
+
+
+@patch("starpoint.embedding.requests")
+def test_embedding_embed_not_200(
+    requests_mock: MagicMock,
+    mock_embedding_client: embedding.EmbeddingClient,
+    monkeypatch: MonkeyPatch,
+):
+    requests_mock.post().ok = False
+
+    expected_json = {}
+
+    logger_mock = MagicMock()
+    monkeypatch.setattr(embedding, "LOGGER", logger_mock)
+
+    actual_json = mock_embedding_client.embed(["asdf"], EmbeddingModel.MINI6)
+
+    requests_mock.post.assert_called()
+    logger_mock.error.assert_called_once()
+    assert actual_json == expected_json
+
+
+@patch("starpoint.embedding.requests")
+def test_embedding_embed_SSLError(
+    requests_mock: MagicMock,
+    mock_embedding_client: embedding.EmbeddingClient,
+    monkeypatch: MonkeyPatch,
+):
+    requests_mock.exceptions.SSLError = SSLError
+    requests_mock.post.side_effect = SSLError("mock exception")
+
+    logger_mock = MagicMock()
+    monkeypatch.setattr(embedding, "LOGGER", logger_mock)
+
+    with pytest.raises(SSLError, match="mock exception"):
+        mock_embedding_client.embed(["asdf"], EmbeddingModel.MINI6)
+
+    logger_mock.error.assert_called_once_with(embedding.SSL_ERROR_MSG)