Skip to content

Commit

Permalink
#4: implement ailab_llama_search package
Browse files Browse the repository at this point in the history
  • Loading branch information
k-allagbe committed Mar 21, 2024
1 parent d6f856e commit 60c4b10
Show file tree
Hide file tree
Showing 7 changed files with 203 additions and 0 deletions.
25 changes: 25 additions & 0 deletions ailab-llama-search/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
# AI Lab - Llama Search Package

## Overview

A thin search layer over [llama-index](https://www.llamaindex.ai/): it builds a `VectorStoreIndex` backed by a PostgreSQL (pgvector) vector store with Azure OpenAI embeddings, and exposes `search`, `transform`, and `create_index_object` helpers for querying it and reshaping the returned node dictionaries.

## Installation

TODO

## Configuration

TODO

## Usage

TODO

## Exceptions

TODO

## Functions

TODO
38 changes: 38 additions & 0 deletions ailab-llama-search/ailab_llama_search/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
import logging
from typing import Optional

import dpath
from llama_index.core import VectorStoreIndex
from llama_index.embeddings.azure_openai import AzureOpenAIEmbedding
from llama_index.vector_stores.postgres import PGVectorStore


class AilabLlamaSearchError(Exception):
    """Generic Ailab llama search error.

    Root exception for this package; ``search`` raises it for an empty
    query so callers can catch a single package-specific type.
    """


def transform(node_dict: dict, paths: dict):
    """Project *node_dict* into a new dict via dpath lookups.

    Each entry of *paths* maps an output key to a dpath expression
    (e.g. ``"/nested/key"``) evaluated against *node_dict*.  When
    *paths* is empty or ``None``, *node_dict* is returned unchanged.
    """
    if paths:
        projected = {}
        for out_key, dpath_expr in paths.items():
            projected[out_key] = dpath.get(node_dict, dpath_expr)
        return projected
    return node_dict


def search(
    query: str,
    index: VectorStoreIndex,
    search_params: Optional[dict] = None,
    trans_paths: Optional[dict] = None,
):
    """Retrieve nodes matching *query* from *index* as plain dicts.

    Args:
        query: Natural-language search string; must be non-empty.
        index: The ``VectorStoreIndex`` to query.
        search_params: Optional kwargs forwarded to ``index.as_retriever``
            (e.g. ``{"similarity_top_k": 5}``).
        trans_paths: Optional key-to-dpath mapping handed to ``transform``
            to reshape each node dict; empty/None keeps the full dicts.

    Returns:
        A list of (possibly transformed) node dictionaries.

    Raises:
        AilabLlamaSearchError: If *query* is empty.
    """
    if not query:
        logging.error("Empty search query received")
        raise AilabLlamaSearchError("Search query cannot be empty")

    # Defaults are None instead of mutable `{}` literals, which would be
    # shared across every call of the function.
    retriever = index.as_retriever(**(search_params or {}))
    nodes = retriever.retrieve(query)
    return [transform(n.dict(), trans_paths or {}) for n in nodes]


def create_index_object(embed_model_params: dict, vector_store_params: dict):
    """Build a ``VectorStoreIndex`` over a Postgres pgvector store.

    *embed_model_params* are kwargs for ``AzureOpenAIEmbedding`` and
    *vector_store_params* are kwargs for ``PGVectorStore.from_params``.
    """
    model = AzureOpenAIEmbedding(**embed_model_params)
    store = PGVectorStore.from_params(**vector_store_params)
    return VectorStoreIndex.from_vector_store(store, model)
5 changes: 5 additions & 0 deletions ailab-llama-search/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
llama-index
llama-index-vector-stores-postgres
llama-index-embeddings-azure-openai
llama-index-storage-kvstore-postgres
python-dotenv
22 changes: 22 additions & 0 deletions ailab-llama-search/setup.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
from setuptools import find_packages, setup


def read_readme():
    """Return the full text of README.md for use as the long description."""
    with open("README.md", "r", encoding="utf-8") as fh:
        return fh.read()


def read_requirements():
    """Return the dependency specifiers listed in requirements.txt.

    Opens the file with an explicit UTF-8 encoding (the default encoding
    is platform-dependent) and drops blank lines so a trailing newline
    does not inject an empty entry into ``install_requires``.
    """
    with open("requirements.txt", encoding="utf-8") as req:
        return [line for line in req.read().splitlines() if line.strip()]


# Distribution metadata for the ailab-llama-search package.
setup(
    name="ailab-llama-search",
    version="0.1.0",
    packages=find_packages(),  # auto-discovers ailab_llama_search
    install_requires=read_requirements(),  # runtime deps from requirements.txt
    long_description=read_readme(),  # PyPI description sourced from README.md
    long_description_content_type="text/markdown",
)
79 changes: 79 additions & 0 deletions ailab-llama-search/tests/test_ailab_llama_search.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
import unittest
from unittest.mock import MagicMock, patch

from ailab_llama_search import (
AilabLlamaSearchError,
VectorStoreIndex,
create_index_object,
search,
transform,
)


class TestAilabLlamaTransform(unittest.TestCase):
    """Unit tests for the transform() path-mapping helper."""

    def test_transform(self):
        node = {
            "id": "123",
            "nested": {"key": "value", "list": [1, 2, 3]},
            "list": ["a", "b", "c"],
        }
        mapping = {
            "new_id": "/id",
            "nested_value": "/nested/key",
            "first_list_item": "/list/0",
        }
        self.assertEqual(
            transform(node, mapping),
            {
                "new_id": "123",
                "nested_value": "value",
                "first_list_item": "a",
            },
        )

    def test_transform_with_empty_or_none_path_map(self):
        node = {"id": "123", "nested": {"key": "value"}}
        for empty_mapping in ({}, None):
            self.assertEqual(transform(node, empty_mapping), node)

    def test_transform_error(self):
        # A path that resolves to nothing should surface dpath's error.
        with self.assertRaises((KeyError, ValueError)):
            transform({"id": "123"}, {"invalid_key": "/nonexistent/path"})


class TestAilabLlamaSearch(unittest.TestCase):
    """Unit tests for search() and create_index_object() using mocks."""

    def setUp(self):
        # A fake index whose retriever yields one node dict.
        node = MagicMock(dict=MagicMock(return_value={'id': 1, 'name': 'Test Node'}))
        self.mock_retriever = MagicMock()
        self.mock_retriever.retrieve.return_value = [node]
        self.mock_index = MagicMock(spec=VectorStoreIndex)
        self.mock_index.as_retriever.return_value = self.mock_retriever

    def test_search_with_empty_query_error(self):
        with self.assertRaises(AilabLlamaSearchError):
            search("", self.mock_index)

    @patch('ailab_llama_search.transform')
    def test_search_calls_transform_on_results(self, mock_transform):
        transformed = {'id': 1, 'name': 'Transformed Node'}
        mock_transform.return_value = transformed
        results = search("test query", self.mock_index)
        self.assertTrue(mock_transform.called)
        self.assertEqual(len(results), 1)
        self.assertEqual(results[0], transformed)

    @patch('ailab_llama_search.AzureOpenAIEmbedding')
    @patch('ailab_llama_search.PGVectorStore.from_params')
    @patch('ailab_llama_search.VectorStoreIndex.from_vector_store')
    def test_create_index_object_initializes_correctly(
        self, mock_from_vector_store, mock_from_params, mock_azure_openai_embedding
    ):
        # Wire the three patched factories to sentinels we can assert on.
        embed_model = MagicMock()
        mock_azure_openai_embedding.return_value = embed_model
        vector_store = MagicMock()
        mock_from_params.return_value = vector_store
        index_object = MagicMock()
        mock_from_vector_store.return_value = index_object

        result = create_index_object({'param1': 'value1'}, {'param2': 'value2'})

        mock_azure_openai_embedding.assert_called_once_with(param1='value1')
        mock_from_params.assert_called_once_with(param2='value2')
        mock_from_vector_store.assert_called_once_with(vector_store, embed_model)
        self.assertEqual(result, index_object)
33 changes: 33 additions & 0 deletions ailab-llama-search/tests/test_integration.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import json
import os
import time
import unittest

from ailab_llama_search import create_index_object, search
from dotenv import load_dotenv


class AilabLlamaSearchIntegrationTests(unittest.TestCase):
    """End-to-end tests against live embedding and vector-store services.

    Requires environment variables (loadable from a .env file):
      EMBED_MODEL_PARAMS   - JSON kwargs for the embedding model
      VECTOR_STORE_PARAMS  - JSON kwargs for the vector store
      TRANS_PATHS          - JSON mapping of result keys to dpath paths
    """

    def setUp(self):
        # Pull configuration from the environment; json.loads raises a
        # TypeError if a variable is missing (os.getenv returns None).
        load_dotenv()
        self.embed_model_params = json.loads(os.getenv("EMBED_MODEL_PARAMS"))
        self.vector_store_params = json.loads(os.getenv("VECTOR_STORE_PARAMS"))
        self.trans_paths = json.loads(os.getenv("TRANS_PATHS"))
        self.search_params = {"similarity_top_k": 5}
        self.index = create_index_object(
            self.embed_model_params, self.vector_store_params
        )

    def test_search(self):
        query = "steps and considerations of the sampling procedures for food safety"
        start_time = time.time()
        results = search(query, self.index, self.search_params, self.trans_paths)
        end_time = time.time()
        duration = (end_time - start_time) * 1000  # elapsed time in milliseconds
        n = self.search_params["similarity_top_k"]
        # NOTE(review): wall-clock 2s budget makes this test sensitive to
        # network/service latency — confirm it is acceptable for CI.
        self.assertLess(duration, 2000)
        # Expect exactly top-k results, each carrying every transformed key.
        self.assertEqual(len(results), n)
        for result in results:
            for key in self.trans_paths.keys():
                self.assertIn(key, result)
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,4 @@ llama-index-storage-index-store-postgres
llama-index-storage-kvstore-postgres
llama-index-readers-database
llama-index-readers-web
dpath

0 comments on commit 60c4b10

Please sign in to comment.