From af275783b1e2273b047cd819578cbf0a2017fc36 Mon Sep 17 00:00:00 2001 From: Alex Thomas Date: Thu, 22 Aug 2024 11:14:20 +0100 Subject: [PATCH] Fixed query parameter bug in GraphRAG class (#109) --- src/neo4j_genai/generation/graphrag.py | 14 +++++++------- tests/unit/test_graphrag.py | 20 ++++++++++++++++++++ 2 files changed, 27 insertions(+), 7 deletions(-) diff --git a/src/neo4j_genai/generation/graphrag.py b/src/neo4j_genai/generation/graphrag.py index e1df04511..7a0527852 100644 --- a/src/neo4j_genai/generation/graphrag.py +++ b/src/neo4j_genai/generation/graphrag.py @@ -85,13 +85,13 @@ def search( DeprecationWarning, stacklevel=2, ) - elif isinstance(query, str): - warnings.warn( - "'query' is deprecated and will be removed in a future version, please use 'query_text' instead.", - DeprecationWarning, - stacklevel=2, - ) - query_text = query + elif isinstance(query, str): + warnings.warn( + "'query' is deprecated and will be removed in a future version, please use 'query_text' instead.", + DeprecationWarning, + stacklevel=2, + ) + query_text = query validated_data = RagSearchModel( query_text=query_text, diff --git a/tests/unit/test_graphrag.py b/tests/unit/test_graphrag.py index 731ef43b3..3c1e96228 100644 --- a/tests/unit/test_graphrag.py +++ b/tests/unit/test_graphrag.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from unittest.mock import MagicMock +from warnings import catch_warnings import pytest from neo4j_genai.exceptions import RagInitializationError, SearchValidationError @@ -21,6 +22,7 @@ from neo4j_genai.generation.types import RagResultModel from neo4j_genai.llm import LLMResponse from neo4j_genai.types import RetrieverResult, RetrieverResultItem +from pydantic import ValidationError def test_graphrag_prompt_template() -> None: @@ -99,3 +101,21 @@ def test_graphrag_search_error(retriever_mock: MagicMock, llm: MagicMock) -> Non with pytest.raises(SearchValidationError) as excinfo: rag.search(10) # type: ignore assert "Input should be a valid string" in str(excinfo) + + +def test_graphrag_search_query_deprecation_warning( + retriever_mock: MagicMock, llm: MagicMock +) -> None: + with catch_warnings(record=True) as warn_list: + rag = GraphRAG( + retriever=retriever_mock, + llm=llm, + ) + with pytest.raises(ValidationError): + rag.search(query="Some query text") + + assert len(warn_list) == 1 + assert ( + str(warn_list[0].message) + == "'query' is deprecated and will be removed in a future version, please use 'query_text' instead." + )