Skip to content

Commit

Permalink
Merge pull request #13 from ai-cfia/12-duplicated-url-results-with-ai…
Browse files Browse the repository at this point in the history
…-lab-llamaindex-search-in-llamaindex-db

issue #8: fix duplicate urls
  • Loading branch information
k-allagbe authored Apr 29, 2024
2 parents 7b2b97c + 03fc62a commit e1a3c84
Show file tree
Hide file tree
Showing 8 changed files with 1,099 additions and 117 deletions.
5 changes: 3 additions & 2 deletions ailab-llamaindex-search/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@

## Overview

The `ailab-llamaindex-search` package facilitates querying our custom index built using LlamaIndex and PostgreSQL.
The `ailab-llamaindex-search` package facilitates querying our custom index
built using LlamaIndex and PostgreSQL.

## Installation

Expand Down Expand Up @@ -46,7 +47,7 @@ trans_paths = {
}

index = create_index_object(embed_model_params, vector_store_params)
search_results = search("your query", index, trans_paths=trans_paths)
search_results = search("your query", index, similarity_top_k=10, trans_paths=trans_paths)

for result in search_results:
print(result)
Expand Down
18 changes: 14 additions & 4 deletions ailab-llamaindex-search/ailab_llamaindex_search/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import dpath
from llama_index.core import VectorStoreIndex
from llama_index.core.schema import NodeWithScore
from llama_index.embeddings.azure_openai import AzureOpenAIEmbedding
from llama_index.vector_stores.postgres import PGVectorStore

Expand All @@ -10,6 +11,15 @@ class AilabLlamaIndexSearchError(Exception):
"""Generic Ailab LlamaIndex search error."""


def select_highest_scored_nodes_by_url(nodes: "list[NodeWithScore]"):
    """Deduplicate retrieval nodes by URL, keeping the best-scored node per URL.

    Nodes are expected to carry a ``metadata["url"]`` entry. When several
    nodes share the same URL, only the one with the highest score is kept;
    on a score tie the first-seen node wins. Result order follows the first
    occurrence of each URL in *nodes*.

    Args:
        nodes: Scored retrieval nodes (duplicates by URL allowed).

    Returns:
        A list with at most one node per distinct URL.
    """
    # url -> (comparable_score, node); a None score sorts below everything
    # so a node with a real score always replaces it.
    best: "dict[str, tuple[float, NodeWithScore]]" = {}
    for node in nodes:
        url: str = node.metadata["url"]
        score = node.score if node.score is not None else float("-inf")
        if url not in best or best[url][0] < score:
            best[url] = (score, node)
    return [node for _, node in best.values()]


def transform(node_dict: dict, paths: dict):
if not paths:
return node_dict
Expand All @@ -20,16 +30,16 @@ def transform(node_dict: dict, paths: dict):
def search(
    query: str,
    index: VectorStoreIndex,
    similarity_top_k: int = 10,
    trans_paths: "dict | None" = None,
):
    """Retrieve nodes matching *query*, deduplicated by URL and transformed.

    Args:
        query: Natural-language search query; must be non-empty.
        index: LlamaIndex vector store index to query.
        similarity_top_k: Target number of results before deduplication.
        trans_paths: Optional dpath mapping handed to ``transform`` for each
            node dict. Defaults to no transformation.

    Returns:
        A list of transformed node dicts, at most one per distinct URL.
        May contain more than ``similarity_top_k`` entries, since twice that
        many nodes are fetched and only URL-duplicates are dropped.

    Raises:
        AilabLlamaIndexSearchError: If *query* is empty.
    """
    # NOTE: default is None (not {}) to avoid the shared mutable-default
    # pitfall; normalize to an empty dict for transform().
    trans_paths = trans_paths or {}

    if not query:
        logging.error("Empty search query received")
        raise AilabLlamaIndexSearchError("search query cannot be empty.")

    # Over-fetch (2x) so dropping duplicate-URL nodes still leaves roughly
    # similarity_top_k distinct results.
    retriever = index.as_retriever(similarity_top_k=similarity_top_k * 2)
    nodes = retriever.retrieve(query)
    best_nodes = select_highest_scored_nodes_by_url(nodes)
    return [transform(node.dict(), trans_paths) for node in best_nodes]


def create_index_object(embed_model_params: dict, vector_store_params: dict):
Expand Down
100 changes: 83 additions & 17 deletions ailab-llamaindex-search/tests/test_ailab_llamaindex_search.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
import unittest
from unittest.mock import MagicMock, patch
from unittest.mock import MagicMock, call, patch

from ailab_llamaindex_search import (
AilabLlamaIndexSearchError,
VectorStoreIndex,
create_index_object,
search,
select_highest_scored_nodes_by_url,
transform,
)
from llama_index.core.schema import NodeWithScore, TextNode


class TestAilabLlamaTransform(unittest.TestCase):
Expand Down Expand Up @@ -45,35 +47,99 @@ class TestAilabLlamaSearch(unittest.TestCase):
def setUp(self):
self.mock_index = MagicMock(spec=VectorStoreIndex)
self.mock_retriever = MagicMock()
self.mock_retriever.retrieve.return_value = [MagicMock(dict=MagicMock(return_value={'id': 1, 'name': 'Test Node'}))]
self.mock_index.as_retriever.return_value = self.mock_retriever

def test_search_with_empty_query_error(self):
    """An empty query must be rejected before any retrieval happens."""
    self.assertRaises(AilabLlamaIndexSearchError, search, "", self.mock_index)

@patch('ailab_llamaindex_search.transform')
def test_search_calls_transform_on_results(self, mock_transform):
mock_transform.return_value = {'id': 1, 'name': 'Transformed Node'}
@patch("ailab_llamaindex_search.transform")
@patch("ailab_llamaindex_search.select_highest_scored_nodes_by_url")
def test_search_calls_the_right_functions(self, mock_select, mock_transform):
d1 = {"id_": "1", "metadata": {"url": "https://example.com"}}
d2 = {"id_": "2", "metadata": {"url": "https://example.com"}}
node1 = NodeWithScore(node=TextNode.from_dict(d1), score=0.8)
node2 = NodeWithScore(node=TextNode.from_dict(d2), score=0.9)
nodes = [node1, node2]
selected_nodes = [node2]
transformed_nodes = node2.dict()
self.mock_retriever.retrieve.return_value = nodes
mock_select.return_value = selected_nodes
mock_transform.side_effect = lambda node_dict, _: node_dict

results = search("test query", self.mock_index)
self.assertTrue(mock_transform.called)
self.assertEqual(len(results), 1)
self.assertEqual(results[0], {'id': 1, 'name': 'Transformed Node'})

@patch('ailab_llamaindex_search.AzureOpenAIEmbedding')
@patch('ailab_llamaindex_search.PGVectorStore.from_params')
@patch('ailab_llamaindex_search.VectorStoreIndex.from_vector_store')
def test_create_index_object_initializes_correctly(self, mock_from_vector_store, mock_from_params, mock_azure_openai_embedding):
mock_select.assert_called_once_with(nodes)
calls = [call(node.dict(), {}) for node in selected_nodes]
mock_transform.assert_has_calls(calls, any_order=True)
self.assertTrue(results[0] == transformed_nodes)

@patch("ailab_llamaindex_search.select_highest_scored_nodes_by_url")
@patch("ailab_llamaindex_search.transform")
def test_retriever_similarity_top_k_parameter(self, mock_transform, mock_select):
    """The retriever is built with twice the requested top-k (duplicate headroom)."""
    self.mock_index.as_retriever = MagicMock()
    requested = 10
    search("valid query", self.mock_index, similarity_top_k=requested)
    self.mock_index.as_retriever.assert_called_once_with(
        similarity_top_k=requested * 2
    )

@patch("ailab_llamaindex_search.AzureOpenAIEmbedding")
@patch("ailab_llamaindex_search.PGVectorStore.from_params")
@patch("ailab_llamaindex_search.VectorStoreIndex.from_vector_store")
def test_create_index_object_initializes_correctly(
self, mock_from_vector_store, mock_from_params, mock_azure_openai_embedding
):
mock_embed_model = MagicMock()
mock_azure_openai_embedding.return_value = mock_embed_model
mock_vector_store = MagicMock()
mock_from_params.return_value = mock_vector_store
mock_index_object = MagicMock()
mock_from_vector_store.return_value = mock_index_object
embed_model_params = {'param1': 'value1'}
vector_store_params = {'param2': 'value2'}
embed_model_params = {"param1": "value1"}
vector_store_params = {"param2": "value2"}
result = create_index_object(embed_model_params, vector_store_params)
mock_azure_openai_embedding.assert_called_once_with(**embed_model_params)
mock_from_params.assert_called_once_with(**vector_store_params)
mock_from_vector_store.assert_called_once_with(mock_vector_store, mock_embed_model)
mock_from_vector_store.assert_called_once_with(
mock_vector_store, mock_embed_model
)
self.assertEqual(result, mock_index_object)


class TestSelectHighestScoredNodesByURL(unittest.TestCase):
    """Unit tests for select_highest_scored_nodes_by_url."""

    @staticmethod
    def _node(id_: str, url: str, score: float):
        # Build a scored retrieval node pointing at the given URL.
        return NodeWithScore(
            node=TextNode.from_dict({"id_": id_, "metadata": {"url": url}}),
            score=score,
        )

    def test_empty_input(self):
        self.assertEqual(select_highest_scored_nodes_by_url([]), [])

    def test_single_node(self):
        node = self._node("1", "https://example.com", 1.0)
        self.assertEqual(select_highest_scored_nodes_by_url([node]), [node])

    def test_multiple_nodes_one_url(self):
        low = self._node("1", "https://example.com", 1.0)
        high = self._node("2", "https://example.com", 2.0)
        self.assertEqual(select_highest_scored_nodes_by_url([low, high]), [high])

    def test_multiple_nodes_multiple_urls(self):
        low = self._node("1", "https://example.com", 1.0)
        high = self._node("2", "https://example.com", 2.0)
        other = self._node("3", "https://example2.com", 3.0)
        result = select_highest_scored_nodes_by_url([low, high, other])
        self.assertIn(high, result)
        self.assertIn(other, result)
        self.assertEqual(len(result), 2)

    def test_nodes_with_same_score(self):
        # On a tie, the first-seen node is kept.
        first = self._node("1", "https://example.com", 1.0)
        second = self._node("2", "https://example.com", 1.0)
        result = select_highest_scored_nodes_by_url([first, second])
        self.assertIn(first, result)
        self.assertEqual(len(result), 1)
5 changes: 1 addition & 4 deletions ailab-llamaindex-search/tests/test_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,13 @@ def setUp(self):
os.getenv("LLAMAINDEX_DB_VECTOR_STORE_PARAMS")
)
self.trans_paths = json.loads(os.getenv("LLAMAINDEX_DB_TRANS_PATHS"))
self.search_params = {"similarity_top_k": 5}
self.index = create_index_object(
self.embed_model_params, self.vector_store_params
)

def test_search(self):
query = "steps and considerations of the sampling procedures for food safety"
results = search(query, self.index, self.search_params, self.trans_paths)
n = self.search_params["similarity_top_k"]
self.assertEqual(len(results), n)
results = search(query, self.index, 10, self.trans_paths)
for result in results:
for key in self.trans_paths.keys():
self.assertIn(key, result)
Binary file added docs/img/pagination_caching_sequence.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
52 changes: 52 additions & 0 deletions docs/puml/pagination_caching_sequence.puml
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
@startuml search sequence
' Pagination & caching flow for the deduplicated LlamaIndex search endpoint.

actor user
participant ":Flask" as app
database "cache" as cache
participant ":Config" as config
participant ":VectorStoreIndex" as index
participant ":BaseRetriever" as retriever
entity "ada:EmbedModel" as ada
database "llamaindex_db" as data


user -> app: POST /search/llamaindex\nparams: query, top, skip
app -> cache: get results for query
' On a cache miss: over-fetch, drop duplicate-URL nodes, transform, cache.
alt no cached results
activate app
app -> config: get index
app -> index: get retriever
activate index
create retriever
index -> retriever: create(similarity_top_k=high)
index --> app: retriever
deactivate index
app -> retriever: retrieve(query)
activate retriever
retriever -> ada: get embeddings for query
retriever -> data: match embeddings
activate data
return matching nodes
return nodes
app -> app: filter out\nduplicate url nodes
app -> app: transform nodes
app -> cache: set results for query
end alt
' Pagination is applied last, against the cached full result set.
app -> user: slice results from skip to top
deactivate app

<style>
legend {
  Fontsize 12
  BackgroundColor white
  LineColor white
  HorizontalAlignment center
}
</style>
legend
<img:../img/logo.png{scale=0.25}>
cfia.ai-ia.acia@inspection.gc.ca
kotchikpaguy-landry.allagbe@inspection.gc.ca
2024-03-21
end legend
@enduml
Loading

0 comments on commit e1a3c84

Please sign in to comment.