Skip to content

Commit

Permalink
Removed query argument from GraphRAG's .search method (#145)
Browse files Browse the repository at this point in the history
* Removed query argument from GraphRAG's .search method

* Update CHANGELOG

* Removed test for deprecation warning
  • Loading branch information
willtai authored Sep 30, 2024
1 parent 5668bdc commit 7cb652d
Show file tree
Hide file tree
Showing 3 changed files with 1 addition and 38 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@

### Changed
- Moved the Embedder class to the neo4j_graphrag.embeddings directory for better organization alongside other custom embedders.
- Removed query argument from the GraphRAG class' `.search` method; users must now use `query_text`.
- Neo4jWriter component now runs a single query to merge node and set its embeddings if any.

## 0.6.3
Expand Down
18 changes: 0 additions & 18 deletions src/neo4j_graphrag/generation/graphrag.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
from __future__ import annotations

import logging
import warnings
from typing import Any, Optional

from pydantic import ValidationError
Expand Down Expand Up @@ -86,7 +85,6 @@ def search(
examples: str = "",
retriever_config: Optional[dict[str, Any]] = None,
return_context: bool = False,
query: Optional[str] = None,
) -> RagResultModel:
"""This method performs a full RAG search:
1. Retrieval: context retrieval
Expand All @@ -99,28 +97,12 @@ def search(
retriever_config (Optional[dict]): Parameters passed to the retriever
search method; e.g.: top_k
return_context (bool): Whether to append the retriever result to the final result (default: False)
query (Optional[str]): The user question. Will be deprecated in favor of query_text.
Returns:
RagResultModel: The LLM-generated answer
"""
try:
if query is not None:
if query_text:
warnings.warn(
"Both 'query' and 'query_text' are provided, 'query_text' will be used.",
DeprecationWarning,
stacklevel=2,
)
elif isinstance(query, str):
warnings.warn(
"'query' is deprecated and will be removed in a future version, please use 'query_text' instead.",
DeprecationWarning,
stacklevel=2,
)
query_text = query

validated_data = RagSearchModel(
query_text=query_text,
examples=examples,
Expand Down
20 changes: 0 additions & 20 deletions tests/unit/test_graphrag.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from unittest.mock import MagicMock
from warnings import catch_warnings

import pytest
from neo4j_graphrag.exceptions import RagInitializationError, SearchValidationError
Expand All @@ -22,7 +21,6 @@
from neo4j_graphrag.generation.types import RagResultModel
from neo4j_graphrag.llm import LLMResponse
from neo4j_graphrag.types import RetrieverResult, RetrieverResultItem
from pydantic import ValidationError


def test_graphrag_prompt_template() -> None:
Expand Down Expand Up @@ -101,21 +99,3 @@ def test_graphrag_search_error(retriever_mock: MagicMock, llm: MagicMock) -> Non
with pytest.raises(SearchValidationError) as excinfo:
rag.search(10) # type: ignore
assert "Input should be a valid string" in str(excinfo)


def test_graphrag_search_query_deprecation_warning(
retriever_mock: MagicMock, llm: MagicMock
) -> None:
with catch_warnings(record=True) as warn_list:
rag = GraphRAG(
retriever=retriever_mock,
llm=llm,
)
with pytest.raises(ValidationError):
rag.search(query="Some query text")

assert len(warn_list) == 1
assert (
str(warn_list[0].message)
== "'query' is deprecated and will be removed in a future version, please use 'query_text' instead."
)

0 comments on commit 7cb652d

Please sign in to comment.