diff --git a/CHANGELOG.md b/CHANGELOG.md index d5bb9f8b8..12ca29daa 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -24,6 +24,7 @@ ### Changed - Moved the Embedder class to the neo4j_graphrag.embeddings directory for better organization alongside other custom embedders. +- Removed query argument from the GraphRAG class' `.search` method; users must now use `query_text`. - Neo4jWriter component now runs a single query to merge node and set its embeddings if any. ## 0.6.3 diff --git a/src/neo4j_graphrag/generation/graphrag.py b/src/neo4j_graphrag/generation/graphrag.py index 3f402e976..f62302acd 100644 --- a/src/neo4j_graphrag/generation/graphrag.py +++ b/src/neo4j_graphrag/generation/graphrag.py @@ -15,7 +15,6 @@ from __future__ import annotations import logging -import warnings from typing import Any, Optional from pydantic import ValidationError @@ -86,7 +85,6 @@ def search( examples: str = "", retriever_config: Optional[dict[str, Any]] = None, return_context: bool = False, - query: Optional[str] = None, ) -> RagResultModel: """This method performs a full RAG search: 1. Retrieval: context retrieval @@ -99,28 +97,12 @@ def search( retriever_config (Optional[dict]): Parameters passed to the retriever search method; e.g.: top_k return_context (bool): Whether to append the retriever result to the final result (default: False) - query (Optional[str]): The user question. Will be deprecated in favor of query_text. Returns: RagResultModel: The LLM-generated answer """ try: - if query is not None: - if query_text: - warnings.warn( - "Both 'query' and 'query_text' are provided, 'query_text' will be used.", - DeprecationWarning, - stacklevel=2, - ) - elif isinstance(query, str): - warnings.warn( - "'query' is deprecated and will be removed in a future version, please use 'query_text' instead.", - DeprecationWarning, - stacklevel=2, - ) - query_text = query - validated_data = RagSearchModel( query_text=query_text, examples=examples, diff --git a/tests/unit/test_graphrag.py b/tests/unit/test_graphrag.py index 9bb8b65b5..58508d717 100644 --- a/tests/unit/test_graphrag.py +++ b/tests/unit/test_graphrag.py @@ -13,7 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. from unittest.mock import MagicMock -from warnings import catch_warnings import pytest from neo4j_graphrag.exceptions import RagInitializationError, SearchValidationError @@ -22,7 +21,6 @@ from neo4j_graphrag.generation.types import RagResultModel from neo4j_graphrag.llm import LLMResponse from neo4j_graphrag.types import RetrieverResult, RetrieverResultItem -from pydantic import ValidationError def test_graphrag_prompt_template() -> None: @@ -101,21 +99,3 @@ def test_graphrag_search_error(retriever_mock: MagicMock, llm: MagicMock) -> Non with pytest.raises(SearchValidationError) as excinfo: rag.search(10) # type: ignore assert "Input should be a valid string" in str(excinfo) - - -def test_graphrag_search_query_deprecation_warning( - retriever_mock: MagicMock, llm: MagicMock -) -> None: - with catch_warnings(record=True) as warn_list: - rag = GraphRAG( - retriever=retriever_mock, - llm=llm, - ) - with pytest.raises(ValidationError): - rag.search(query="Some query text") - - assert len(warn_list) == 1 - assert ( - str(warn_list[0].message) - == "'query' is deprecated and will be removed in a future version, please use 'query_text' instead." - )