diff --git a/CHANGELOG.md b/CHANGELOG.md index 2c07a7d3..7b70d1a8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,7 +13,7 @@ - Introduced Neo4jMessageHistory and InMemoryMessageHistory classes for managing LLM message histories. - Added examples and documentation for using message history with Neo4j and in-memory storage. - Updated LLM and GraphRAG classes to support new message history classes. - +- Introduced a linear hybrid search ranker for HybridRetriever and HybridCypherRetriever, allowing customizable ranking with an `alpha` parameter. ### Changed - Refactored index-related functions for improved compatibility and functionality. diff --git a/src/neo4j_graphrag/exceptions.py b/src/neo4j_graphrag/exceptions.py index 30746593..bde7cfe8 100644 --- a/src/neo4j_graphrag/exceptions.py +++ b/src/neo4j_graphrag/exceptions.py @@ -124,3 +124,7 @@ class PdfLoaderError(Neo4jGraphRagError): class PromptMissingPlaceholderError(Neo4jGraphRagError): """Exception raised when a prompt is missing an expected placeholder.""" + + +class InvalidHybridSearchRankerError(Neo4jGraphRagError): + """Exception raised when an invalid ranker type for Hybrid Search is provided.""" diff --git a/src/neo4j_graphrag/neo4j_queries.py b/src/neo4j_graphrag/neo4j_queries.py index e6564b9b..4e6e4ef3 100644 --- a/src/neo4j_graphrag/neo4j_queries.py +++ b/src/neo4j_graphrag/neo4j_queries.py @@ -15,10 +15,11 @@ from __future__ import annotations import warnings -from typing import Any, Optional +from typing import Any, Optional, Union +from neo4j_graphrag.exceptions import InvalidHybridSearchRankerError from neo4j_graphrag.filters import get_metadata_filter -from neo4j_graphrag.types import EntityType, SearchType +from neo4j_graphrag.types import EntityType, SearchType, HybridSearchRanker NODE_VECTOR_INDEX_QUERY = ( "CALL db.index.vector.queryNodes" @@ -171,6 +172,45 @@ def _get_hybrid_query(neo4j_version_is_5_23_or_above: bool) -> str: return call_prefix + query_body +def _get_hybrid_query_linear(neo4j_version_is_5_23_or_above: bool, alpha: float) -> str: + """ + Construct a Cypher query for hybrid search using a linear combination approach with an alpha parameter. + + This query retrieves normalized scores from both the vector index and full-text index. It then + computes the final score as a weighted sum: + + ``` + final_score = alpha * (vector normalized score) + (1 - alpha) * (fulltext normalized score) + ``` + + If a node appears in only one index, the missing score is treated as 0. + + Args: + neo4j_version_is_5_23_or_above (bool): Whether the Neo4j version is 5.23 or above; determines the call syntax. + alpha (float): Weight for the vector index normalized score. The full-text score is weighted as (1 - alpha). + + Returns: + str: The constructed Cypher query string. + """ + call_prefix = "CALL () { " if neo4j_version_is_5_23_or_above else "CALL { " + + query_body = ( + f"{NODE_VECTOR_INDEX_QUERY} " + "WITH collect({node: node, score: score}) AS nodes, max(score) AS vector_index_max_score " + "UNWIND nodes AS n " + "WITH n.node AS node, (n.score / vector_index_max_score) AS rawScore " + "RETURN node, rawScore * $alpha AS score " + "UNION " + f"{FULL_TEXT_SEARCH_QUERY} " + "WITH collect({node: node, score: score}) AS nodes, max(score) AS ft_index_max_score " + "UNWIND nodes AS n " + "WITH n.node AS node, (n.score / ft_index_max_score) AS rawScore " + "RETURN node, rawScore * (1 - $alpha) AS score } " + "WITH node, sum(score) AS score ORDER BY score DESC LIMIT $top_k" + ) + return call_prefix + query_body + + def _get_filtered_vector_query( filters: dict[str, Any], node_label: str, @@ -223,6 +263,8 @@ def get_search_query( filters: Optional[dict[str, Any]] = None, neo4j_version_is_5_23_or_above: bool = False, use_parallel_runtime: bool = False, + ranker: Union[str, HybridSearchRanker] = HybridSearchRanker.NAIVE, + alpha: Optional[float] = None, ) -> tuple[str, dict[str, Any]]: """ Constructs a search query for vector or hybrid search, including optional pre-filtering @@ -243,6 +285,8 @@ def get_search_query( neo4j_version_is_5_23_or_above (Optional[bool]): Whether the Neo4j version is 5.23 or above. use_parallel_runtime (bool): Whether or not use the parallel runtime to run the query. Defaults to False. + ranker (HybridSearchRanker): Type of ranker to order the results from retrieval. + alpha (Optional[float]): Weight for the vector score when using the linear ranker. Only used when ranker is 'linear'. Defaults to 0.5 if not provided. Returns: tuple[str, dict[str, Any]]: A tuple containing the constructed query string and @@ -262,7 +306,14 @@ def get_search_query( if search_type == SearchType.HYBRID: if filters: raise Exception("Filters are not supported with hybrid search") - query = _get_hybrid_query(neo4j_version_is_5_23_or_above) + if ranker == HybridSearchRanker.NAIVE: + query = _get_hybrid_query(neo4j_version_is_5_23_or_above) + elif ranker == HybridSearchRanker.LINEAR and alpha: + query = _get_hybrid_query_linear( + neo4j_version_is_5_23_or_above, alpha=alpha + ) + else: + raise InvalidHybridSearchRankerError() params: dict[str, Any] = {} elif search_type == SearchType.VECTOR: if filters: diff --git a/src/neo4j_graphrag/retrievers/hybrid.py b/src/neo4j_graphrag/retrievers/hybrid.py index 4a2bcea4..2edd5449 100644 --- a/src/neo4j_graphrag/retrievers/hybrid.py +++ b/src/neo4j_graphrag/retrievers/hybrid.py @@ -16,7 +16,7 @@ import copy import logging -from typing import Any, Callable, Optional +from typing import Any, Callable, Optional, Union import neo4j from pydantic import ValidationError @@ -39,6 +39,7 @@ RawSearchResult, RetrieverResultItem, SearchType, + HybridSearchRanker, ) logger = logging.getLogger(__name__) @@ -142,6 +143,8 @@ def get_search_results( query_vector: Optional[list[float]] = None, top_k: int = 5, effective_search_ratio: int = 1, + ranker: Union[str, HybridSearchRanker] = HybridSearchRanker.NAIVE, + alpha: Optional[float] = None, ) -> RawSearchResult: """Get the top_k nearest neighbor embeddings for either provided query_vector or query_text. Both query_vector and query_text can be provided. @@ -162,6 +165,10 @@ def get_search_results( top_k (int, optional): The number of neighbors to return. Defaults to 5. effective_search_ratio (int): Controls the candidate pool size for the vector index by multiplying top_k to balance query accuracy and performance. Defaults to 1. + ranker (str, HybridSearchRanker): Type of ranker to order the results from retrieval. + alpha (Optional[float]): Weight for the vector score when using the linear ranker. + The fulltext index score is multiplied by (1 - alpha). + **Required** when using the linear ranker; must be between 0 and 1. Raises: SearchValidationError: If validation of the input arguments fail. @@ -176,6 +183,8 @@ def get_search_results( query_text=query_text, top_k=top_k, effective_search_ratio=effective_search_ratio, + ranker=ranker, + alpha=alpha, ) except ValidationError as e: raise SearchValidationError(e.errors()) from e @@ -191,13 +200,18 @@ def get_search_results( ) query_vector = self.embedder.embed_query(query_text) parameters["query_vector"] = query_vector - search_query, _ = get_search_query( search_type=SearchType.HYBRID, return_properties=self.return_properties, embedding_node_property=self._embedding_node_property, neo4j_version_is_5_23_or_above=self.neo4j_version_is_5_23_or_above, + ranker=validated_data.ranker, + alpha=validated_data.alpha, ) + + if "ranker" in parameters: + del parameters["ranker"] + sanitized_parameters = copy.deepcopy(parameters) if "query_vector" in sanitized_parameters: sanitized_parameters["query_vector"] = "..." @@ -301,6 +315,8 @@ def get_search_results( top_k: int = 5, effective_search_ratio: int = 1, query_params: Optional[dict[str, Any]] = None, + ranker: Union[str, HybridSearchRanker] = HybridSearchRanker.NAIVE, + alpha: Optional[float] = None, ) -> RawSearchResult: """Get the top_k nearest neighbor embeddings for either provided query_vector or query_text. Both query_vector and query_text can be provided. @@ -320,7 +336,10 @@ def get_search_results( effective_search_ratio (int): Controls the candidate pool size for the vector index by multiplying top_k to balance query accuracy and performance. Defaults to 1. query_params (Optional[dict[str, Any]]): Parameters for the Cypher query. Defaults to None. - + ranker (str, HybridSearchRanker): Type of ranker to order the results from retrieval. + alpha (Optional[float]): Weight for the vector score when using the linear ranker. + The fulltext index score is multiplied by (1 - alpha). + **Required** when using the linear ranker; must be between 0 and 1. Raises: SearchValidationError: If validation of the input arguments fail. EmbeddingRequiredError: If no embedder is provided. @@ -334,6 +353,8 @@ def get_search_results( query_text=query_text, top_k=top_k, effective_search_ratio=effective_search_ratio, + ranker=ranker, + alpha=alpha, query_params=query_params, ) except ValidationError as e: @@ -361,7 +382,13 @@ def get_search_results( search_type=SearchType.HYBRID, retrieval_query=self.retrieval_query, neo4j_version_is_5_23_or_above=self.neo4j_version_is_5_23_or_above, + ranker=validated_data.ranker, + alpha=validated_data.alpha, ) + + if "ranker" in parameters: + del parameters["ranker"] + sanitized_parameters = copy.deepcopy(parameters) if "query_vector" in sanitized_parameters: sanitized_parameters["query_vector"] = "..." diff --git a/src/neo4j_graphrag/types.py b/src/neo4j_graphrag/types.py index da2a5665..1c0b7454 100644 --- a/src/neo4j_graphrag/types.py +++ b/src/neo4j_graphrag/types.py @@ -14,6 +14,7 @@ # limitations under the License. from __future__ import annotations +import warnings from enum import Enum from typing import Any, Callable, Literal, Optional, TypedDict, Union @@ -25,6 +26,7 @@ field_validator, model_validator, ) +from typing_extensions import Self from neo4j_graphrag.utils.validation import validate_search_query_input @@ -137,11 +139,53 @@ class VectorCypherSearchModel(VectorSearchModel): query_params: Optional[dict[str, Any]] = None +class HybridSearchRanker(Enum): + """Enumerator of Hybrid search rankers.""" + + NAIVE = "naive" + LINEAR = "linear" + + class HybridSearchModel(BaseModel): query_text: str query_vector: Optional[list[float]] = None top_k: PositiveInt = 5 effective_search_ratio: PositiveInt = 1 + ranker: Union[str, HybridSearchRanker] = HybridSearchRanker.NAIVE + alpha: Optional[float] = None + + @field_validator("ranker", mode="before") + def validate_ranker(cls, v: Union[str, HybridSearchRanker]) -> HybridSearchRanker: + if isinstance(v, str): + try: + return HybridSearchRanker(v.lower()) + except ValueError: + allowed = ", ".join([r.value for r in HybridSearchRanker]) + raise ValueError( + f"Invalid ranker value. Allowed values are: {allowed}." + ) + elif isinstance(v, HybridSearchRanker): + return v + else: + allowed = ", ".join([r.value for r in HybridSearchRanker]) + raise ValueError(f"Invalid ranker type. Allowed values are: {allowed}.") + + @model_validator(mode="after") + def validate_alpha(self) -> Self: + ranker, alpha = self.ranker, self.alpha + if ranker == HybridSearchRanker.LINEAR: + if alpha is None: + raise ValueError("alpha must be provided when using the linear ranker") + if not (0.0 <= alpha <= 1.0): + raise ValueError("alpha must be between 0 and 1") + else: + if alpha is not None: + warnings.warn( + "alpha parameter is only used when ranker is 'linear'. Ignoring alpha.", + UserWarning, + ) + self.alpha = None + return self class HybridCypherSearchModel(HybridSearchModel): diff --git a/tests/e2e/test_hybrid_e2e.py b/tests/e2e/test_hybrid_e2e.py index 908f9214..d3da020b 100644 --- a/tests/e2e/test_hybrid_e2e.py +++ b/tests/e2e/test_hybrid_e2e.py @@ -176,3 +176,27 @@ def test_hybrid_retriever_return_properties(driver: Driver) -> None: assert len(results.items) == 5 for result in results.items: assert isinstance(result, RetrieverResultItem) + + +@pytest.mark.usefixtures("setup_neo4j_for_retrieval") +def test_hybrid_retriever_search_text_linear_ranker( + driver: Driver, random_embedder: Embedder +) -> None: + retriever = HybridRetriever( + driver, "vector-index-name", "fulltext-index-name", random_embedder + ) + + top_k = 5 + effective_search_ratio = 2 + results = retriever.search( + query_text="Find me a book about Fremen", + top_k=top_k, + effective_search_ratio=effective_search_ratio, + ranker="linear", + alpha=0.9, + ) + + assert isinstance(results, RetrieverResult) + assert len(results.items) == 5 + for result in results.items: + assert isinstance(result, RetrieverResultItem) diff --git a/tests/unit/retrievers/test_hybrid.py b/tests/unit/retrievers/test_hybrid.py index fac15405..d38f1ab1 100644 --- a/tests/unit/retrievers/test_hybrid.py +++ b/tests/unit/retrievers/test_hybrid.py @@ -20,10 +20,16 @@ from neo4j_graphrag.exceptions import ( EmbeddingRequiredError, RetrieverInitializationError, + SearchValidationError, ) from neo4j_graphrag.neo4j_queries import get_search_query from neo4j_graphrag.retrievers import HybridCypherRetriever, HybridRetriever -from neo4j_graphrag.types import RetrieverResult, RetrieverResultItem, SearchType +from neo4j_graphrag.types import ( + RetrieverResult, + RetrieverResultItem, + SearchType, + HybridSearchRanker, +) def test_vector_retriever_initialization(driver: MagicMock) -> None: @@ -605,3 +611,185 @@ def test_hybrid_cypher_search_sanitizes_text( database_=None, routing_=neo4j.RoutingControl.READ, ) + + +@patch("neo4j_graphrag.retrievers.base.get_version") +def test_hybrid_retriever_linear_without_alpha( + mock_get_version: MagicMock, driver: MagicMock +) -> None: + mock_get_version.return_value = ((5, 23, 0), False, False) + with pytest.raises(SearchValidationError) as exc_info: + HybridRetriever( + driver=driver, + vector_index_name="vector-index", + fulltext_index_name="fulltext-index", + neo4j_database="neo4j", + ).search(query_text="test query", ranker="linear") + assert "alpha must be provided" in str(exc_info.value) + + +@patch("neo4j_graphrag.retrievers.base.get_version") +def test_hybrid_cypher_retriever_linear_without_alpha( + mock_get_version: MagicMock, driver: MagicMock +) -> None: + mock_get_version.return_value = ((5, 23, 0), False, False) + with pytest.raises(SearchValidationError) as exc_info: + HybridCypherRetriever( + driver=driver, + vector_index_name="vector-index", + fulltext_index_name="fulltext-index", + neo4j_database="neo4j", + retrieval_query="", + ).search(query_text="test query", ranker="linear") + assert "alpha must be provided" in str(exc_info.value) + + +@patch("neo4j_graphrag.retrievers.HybridRetriever._fetch_index_infos") +@patch("neo4j_graphrag.retrievers.base.get_version") +def test_hybrid_search_linear_ranker_happy_path( + mock_get_version: MagicMock, + _fetch_index_infos_mock: MagicMock, + driver: MagicMock, + embedder: MagicMock, + neo4j_record: MagicMock, +) -> None: + mock_get_version.return_value = ((5, 23, 0), False, False) + embed_query_vector = [1.0 for _ in range(1536)] + embedder.embed_query.return_value = embed_query_vector + vector_index_name = "vector-index" + fulltext_index_name = "fulltext-index" + query_text = "may thy knife chip and shatter" + top_k = 5 + effective_search_ratio = 2 + ranker = HybridSearchRanker.LINEAR + alpha = 0.7 + + retriever = HybridRetriever( + driver, vector_index_name, fulltext_index_name, embedder + ) + retriever.neo4j_version_is_5_23_or_above = True + retriever._embedding_node_property = "embedding" + retriever.driver.execute_query.return_value = [ # type: ignore + [neo4j_record], + None, + None, + ] + search_query, _ = get_search_query( + SearchType.HYBRID, + embedding_node_property="embedding", + neo4j_version_is_5_23_or_above=retriever.neo4j_version_is_5_23_or_above, + ranker=ranker, + alpha=alpha, + ) + + records = retriever.search( + query_text=query_text, + top_k=top_k, + effective_search_ratio=effective_search_ratio, + ranker=ranker, + alpha=alpha, + ) + + retriever.driver.execute_query.assert_called_once_with( # type: ignore + search_query, + { + "vector_index_name": vector_index_name, + "top_k": top_k, + "effective_search_ratio": effective_search_ratio, + "query_text": query_text, + "fulltext_index_name": fulltext_index_name, + "query_vector": embed_query_vector, + "alpha": alpha, + }, + database_=None, + routing_=neo4j.RoutingControl.READ, + ) + embedder.embed_query.assert_called_once_with(query_text) + assert records == RetrieverResult( + items=[ + RetrieverResultItem(content="dummy-node", metadata={"score": 1.0}), + ], + metadata={"__retriever": "HybridRetriever"}, + ) + + +@patch("neo4j_graphrag.retrievers.base.get_version") +def test_hybrid_cypher_linear_ranker( + mock_get_version: MagicMock, + driver: MagicMock, + embedder: MagicMock, + neo4j_record: MagicMock, +) -> None: + mock_get_version.return_value = ((5, 23, 0), False, False) + embed_query_vector = [1.0 for _ in range(1536)] + embedder.embed_query.return_value = embed_query_vector + vector_index_name = "vector-index" + fulltext_index_name = "fulltext-index" + query_text = "may thy knife chip and shatter" + top_k = 5 + effective_search_ratio = 2 + ranker = HybridSearchRanker.LINEAR + alpha = 0.7 + retrieval_query = """ + RETURN node.id AS node_id, node.text AS text, score, {test: $param} AS metadata + """ + query_params = { + "param": "dummy-param", + } + retriever = HybridCypherRetriever( + driver, + vector_index_name, + fulltext_index_name, + retrieval_query, + embedder, + ) + retriever.neo4j_version_is_5_23_or_above = True + driver.execute_query.return_value = [ + [neo4j_record], + None, + None, + ] + search_query, _ = get_search_query( + search_type=SearchType.HYBRID, + retrieval_query=retrieval_query, + neo4j_version_is_5_23_or_above=retriever.neo4j_version_is_5_23_or_above, + ranker=ranker, + alpha=alpha, + ) + + records = retriever.search( + query_text=query_text, + top_k=top_k, + effective_search_ratio=effective_search_ratio, + query_params=query_params, + ranker=ranker, + alpha=alpha, + ) + + embedder.embed_query.assert_called_once_with(query_text) + + driver.execute_query.assert_called_once_with( + search_query, + { + "vector_index_name": vector_index_name, + "top_k": top_k, + "effective_search_ratio": effective_search_ratio, + "query_text": query_text, + "fulltext_index_name": fulltext_index_name, + "query_vector": embed_query_vector, + "param": "dummy-param", + "alpha": alpha, + }, + database_=None, + routing_=neo4j.RoutingControl.READ, + ) + + assert records == RetrieverResult( + items=[ + RetrieverResultItem( + content="", + metadata=None, + ), + ], + metadata={"__retriever": "HybridCypherRetriever"}, + ) diff --git a/tests/unit/test_neo4j_queries.py b/tests/unit/test_neo4j_queries.py index d901b21c..0fc4836b 100644 --- a/tests/unit/test_neo4j_queries.py +++ b/tests/unit/test_neo4j_queries.py @@ -16,9 +16,12 @@ from unittest.mock import patch import pytest + +from neo4j_graphrag.exceptions import InvalidHybridSearchRankerError from neo4j_graphrag.neo4j_queries import ( get_query_tail, get_search_query, + _get_hybrid_query_linear, ) from neo4j_graphrag.types import EntityType, SearchType @@ -249,3 +252,16 @@ def test_get_query_tail_ordering_no_retrieval_query() -> None: fallback_return=fallback, ) assert result.strip() == expected.strip() + + +def test_get_hybrid_query_linear_with_alpha() -> None: + query = _get_hybrid_query_linear(neo4j_version_is_5_23_or_above=True, alpha=0.7) + vector_substr = "rawScore * $alpha" + ft_substr = "rawScore * (1 - $alpha)" + assert vector_substr in query + assert ft_substr in query + + +def test_invalid_hybrid_search_ranker_error() -> None: + with pytest.raises(InvalidHybridSearchRankerError): + get_search_query(SearchType.HYBRID, ranker="invalid")