Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Removed query argument from GraphRAG's .search method #145

Merged
merged 3 commits into from
Sep 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@

### Changed
- Moved the Embedder class to the neo4j_graphrag.embeddings directory for better organization alongside other custom embedders.
- Removed query argument from the GraphRAG class' `.search` method; users must now use `query_text`.
- Neo4jWriter component now runs a single query to merge node and set its embeddings if any.

## 0.6.3
Expand Down
18 changes: 0 additions & 18 deletions src/neo4j_graphrag/generation/graphrag.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
from __future__ import annotations

import logging
import warnings
from typing import Any, Optional

from pydantic import ValidationError
Expand Down Expand Up @@ -86,7 +85,6 @@ def search(
examples: str = "",
retriever_config: Optional[dict[str, Any]] = None,
return_context: bool = False,
query: Optional[str] = None,
) -> RagResultModel:
"""This method performs a full RAG search:
1. Retrieval: context retrieval
Expand All @@ -99,28 +97,12 @@ def search(
retriever_config (Optional[dict]): Parameters passed to the retriever
search method; e.g.: top_k
return_context (bool): Whether to append the retriever result to the final result (default: False)
query (Optional[str]): The user question. Will be deprecated in favor of query_text.

Returns:
RagResultModel: The LLM-generated answer

"""
try:
if query is not None:
if query_text:
warnings.warn(
"Both 'query' and 'query_text' are provided, 'query_text' will be used.",
DeprecationWarning,
stacklevel=2,
)
elif isinstance(query, str):
warnings.warn(
"'query' is deprecated and will be removed in a future version, please use 'query_text' instead.",
DeprecationWarning,
stacklevel=2,
)
query_text = query

validated_data = RagSearchModel(
query_text=query_text,
examples=examples,
Expand Down
20 changes: 0 additions & 20 deletions tests/unit/test_graphrag.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from unittest.mock import MagicMock
from warnings import catch_warnings

import pytest
from neo4j_graphrag.exceptions import RagInitializationError, SearchValidationError
Expand All @@ -22,7 +21,6 @@
from neo4j_graphrag.generation.types import RagResultModel
from neo4j_graphrag.llm import LLMResponse
from neo4j_graphrag.types import RetrieverResult, RetrieverResultItem
from pydantic import ValidationError


def test_graphrag_prompt_template() -> None:
Expand Down Expand Up @@ -101,21 +99,3 @@ def test_graphrag_search_error(retriever_mock: MagicMock, llm: MagicMock) -> Non
with pytest.raises(SearchValidationError) as excinfo:
rag.search(10) # type: ignore
assert "Input should be a valid string" in str(excinfo)


def test_graphrag_search_query_deprecation_warning(
retriever_mock: MagicMock, llm: MagicMock
) -> None:
with catch_warnings(record=True) as warn_list:
rag = GraphRAG(
retriever=retriever_mock,
llm=llm,
)
with pytest.raises(ValidationError):
rag.search(query="Some query text")

assert len(warn_list) == 1
assert (
str(warn_list[0].message)
== "'query' is deprecated and will be removed in a future version, please use 'query_text' instead."
)