Skip to content

Commit

Permalink
Python embed endpoint (#31)
Browse files Browse the repository at this point in the history
* Add support for embedding endpoint

* Add tests and clean up embedding code

* Fix return type hint
  • Loading branch information
FullMetalMeowchemist authored Sep 6, 2023
1 parent 7a7d00a commit 4d3fc74
Show file tree
Hide file tree
Showing 3 changed files with 162 additions and 0 deletions.
75 changes: 75 additions & 0 deletions python/starpoint/embedding.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
import logging
from typing import Any, Dict, List, Optional
from uuid import UUID

import requests

from starpoint._utils import (
_build_header,
_check_collection_identifier_collision,
_validate_host,
)
from starpoint.enums import EmbeddingModel


# Module-level logger, named after this module.
LOGGER = logging.getLogger(__name__)

# Host used by EmbeddingClient when the caller does not supply one.
EMBEDDING_URL = "https://embedding.starpoint.ai"

# Endpoint paths; appended directly to the host when building request URLs.
EMBED_PATH = "/api/v1/embed"

# Error and warning messages
# Logged when a request raises requests.exceptions.SSLError.
SSL_ERROR_MSG = (
    "Request failed due to SSLError. Error is likely due to invalid API key. "
    "Please check if your API key is correct and still valid."
)


class EmbeddingClient(object):
    """Client for the embedding endpoints.

    Holds the validated host and the API key, and uses both for every
    request it sends.
    """

    def __init__(self, api_key: UUID, host: Optional[str] = None):
        """Store the API key and the validated host.

        Args:
            api_key: API key passed to ``_build_header`` for each request.
            host: Base URL of the embedding service. When ``None``,
                ``EMBEDDING_URL`` is used. The chosen value is passed through
                ``_validate_host`` and the result is stored as ``self.host``.
        """
        if host is None:
            host = EMBEDDING_URL

        self.host = _validate_host(host)
        self.api_key = api_key

    def embed(
        self,
        text: List[str],
        model: EmbeddingModel,
    ) -> Dict[str, List[Dict]]:
        """Takes some text and creates an embedding against a model in starpoint.

        Args:
            text: List of strings to create embeddings from.
            model: The ``EmbeddingModel`` member to embed with; its ``value``
                is sent as the ``model`` field of the request body.

        Returns:
            dict: The decoded JSON body of the response when the request
                succeeds. When the response status is not OK the failure is
                logged and an empty dict is returned instead of raising.

        Raises:
            requests.exceptions.SSLError: Re-raised unchanged after
                ``SSL_ERROR_MSG`` has been logged. Other exceptions raised
                while sending the request are not caught here.
        """
        request_data = dict(text=text, model=model.value)
        try:
            response = requests.post(
                url=f"{self.host}{EMBED_PATH}",
                json=request_data,
                headers=_build_header(
                    api_key=self.api_key,
                    additional_headers={"Content-Type": "application/json"},
                ),
            )
        except requests.exceptions.SSLError:
            # Log the hint about the API key, then let the caller see the
            # original exception with its traceback intact.
            LOGGER.error(SSL_ERROR_MSG)
            raise

        if not response.ok:
            # Lazy %-style arguments: the message is only formatted if the
            # record is actually emitted.
            LOGGER.error(
                "Request failed with status code %s and the following message:\n%s",
                response.status_code,
                response.text,
            )
            return {}
        return response.json()
6 changes: 6 additions & 0 deletions python/starpoint/enums.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
from enum import Enum


class EmbeddingModel(Enum):
    """Identifiers of the embedding models a caller can select.

    Each member's value is the same string as its name.
    """

    MINI6 = "MINI6"
    MINI12 = "MINI12"
81 changes: 81 additions & 0 deletions python/tests/test_embedding.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
from unittest.mock import MagicMock, patch
from uuid import UUID, uuid4

import pytest
from _pytest.monkeypatch import MonkeyPatch
from requests.exceptions import SSLError

from starpoint import embedding
from starpoint.enums import EmbeddingModel


@pytest.fixture(scope="session")
def api_uuid() -> UUID:
    """Random UUID used as the API key, created once per test session."""
    session_key = uuid4()
    return session_key


@pytest.fixture(scope="session")
@patch("starpoint._utils._check_host_health")
def mock_embedding_client(
    host_health_mock: MagicMock, api_uuid: UUID
) -> embedding.EmbeddingClient:
    """Session-wide client built with the default host and the shared API key.

    ``starpoint._utils._check_host_health`` is patched while the client is
    constructed.
    NOTE(review): presumably this keeps host validation from making a real
    network call -- confirm against ``starpoint._utils``.
    """
    client = embedding.EmbeddingClient(api_key=api_uuid)
    return client


def test_embedding_default_init(
    mock_embedding_client: embedding.EmbeddingClient, api_uuid: UUID
):
    """A client created without a host has the default URL and the given key."""
    client_host = mock_embedding_client.host
    client_key = mock_embedding_client.api_key

    assert client_host
    assert client_host == embedding.EMBEDDING_URL
    assert client_key == api_uuid


@patch("starpoint.embedding._validate_host")
def test_embedding_init_non_default_host(
    mock_host_validator: MagicMock, api_uuid: UUID
):
    """A caller-supplied host is validated, and the validator's result is stored."""
    test_host = "http://www.example.com"
    test_embedding_client = embedding.EmbeddingClient(api_key=api_uuid, host=test_host)

    mock_host_validator.assert_called_once_with(test_host)
    # Compare against return_value rather than calling the mock again: a second
    # call would change the mock's call record and make the assertions
    # order-dependent.
    assert test_embedding_client.host == mock_host_validator.return_value
    assert test_embedding_client.api_key == api_uuid


@patch("starpoint.embedding.requests")
def test_embedding_embed_not_200(
    requests_mock: MagicMock,
    mock_embedding_client: embedding.EmbeddingClient,
    monkeypatch: MonkeyPatch,
):
    """A non-OK response is logged once and ``embed`` returns an empty dict."""
    # Configure the response through return_value. Writing
    # `requests_mock.post().ok = False` would itself call post(), which makes
    # any later "post was called" assertion pass even if embed() never posts.
    requests_mock.post.return_value.ok = False

    expected_json = {}

    logger_mock = MagicMock()
    monkeypatch.setattr(embedding, "LOGGER", logger_mock)

    actual_json = mock_embedding_client.embed(["asdf"], EmbeddingModel.MINI6)

    # Exactly one request must come from embed() itself.
    requests_mock.post.assert_called_once()
    logger_mock.error.assert_called_once()
    assert actual_json == expected_json


@patch("starpoint.embedding.requests")
def test_embedding_embed_SSLError(
    requests_mock: MagicMock,
    mock_embedding_client: embedding.EmbeddingClient,
    monkeypatch: MonkeyPatch,
):
    """An SSLError from ``requests.post`` is logged once and then re-raised."""
    fake_logger = MagicMock()
    monkeypatch.setattr(embedding, "LOGGER", fake_logger)

    # The whole requests module is mocked, so the `except` clause in embed()
    # needs the real exception class put back on the mock.
    requests_mock.exceptions.SSLError = SSLError
    requests_mock.post.side_effect = SSLError("mock exception")

    with pytest.raises(SSLError, match="mock exception"):
        mock_embedding_client.embed(text=["asdf"], model=EmbeddingModel.MINI6)

    fake_logger.error.assert_called_once_with(embedding.SSL_ERROR_MSG)

0 comments on commit 4d3fc74

Please sign in to comment.