Skip to content

Commit

Permalink
Add linear hybrid search ranker (neo4j#284)
Browse files Browse the repository at this point in the history
* Add linear hybrid search ranker

* Update CHANGELOG

* Make alpha mandatory for linear ranker

* Use query parameters for alpha to avoid Cypher injection

* Refactor Cypher query string for linear ranker

* Removed isinstance check for float in HybridSearchModel's alpha

* Update E2E test for linear ranker

* Remove delete of alpha from query parameters
  • Loading branch information
willtai authored Feb 26, 2025
1 parent eed1a04 commit 09440e0
Show file tree
Hide file tree
Showing 8 changed files with 362 additions and 8 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
- Introduced Neo4jMessageHistory and InMemoryMessageHistory classes for managing LLM message histories.
- Added examples and documentation for using message history with Neo4j and in-memory storage.
- Updated LLM and GraphRAG classes to support new message history classes.

- Introduced a linear hybrid search ranker for HybridRetriever and HybridCypherRetriever, allowing customizable ranking with an `alpha` parameter.
### Changed

- Refactored index-related functions for improved compatibility and functionality.
Expand Down
4 changes: 4 additions & 0 deletions src/neo4j_graphrag/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,3 +124,7 @@ class PdfLoaderError(Neo4jGraphRagError):

class PromptMissingPlaceholderError(Neo4jGraphRagError):
    """Raised when an expected placeholder is absent from a prompt."""


class InvalidHybridSearchRankerError(Neo4jGraphRagError):
    """Raised when Hybrid Search is given a ranker type that is not valid."""
57 changes: 54 additions & 3 deletions src/neo4j_graphrag/neo4j_queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,11 @@
from __future__ import annotations

import warnings
from typing import Any, Optional
from typing import Any, Optional, Union

from neo4j_graphrag.exceptions import InvalidHybridSearchRankerError
from neo4j_graphrag.filters import get_metadata_filter
from neo4j_graphrag.types import EntityType, SearchType
from neo4j_graphrag.types import EntityType, SearchType, HybridSearchRanker

NODE_VECTOR_INDEX_QUERY = (
"CALL db.index.vector.queryNodes"
Expand Down Expand Up @@ -171,6 +172,45 @@ def _get_hybrid_query(neo4j_version_is_5_23_or_above: bool) -> str:
return call_prefix + query_body


def _get_hybrid_query_linear(neo4j_version_is_5_23_or_above: bool, alpha: float) -> str:
    """
    Construct a Cypher query for hybrid search using a linear combination approach.

    The query reads candidates from both the vector index and the full-text index,
    normalizes each index's scores by that index's maximum score, weights them and
    sums the weighted scores per node:

    ```
    final_score = alpha * (vector normalized score) + (1 - alpha) * (fulltext normalized score)
    ```

    If a node appears in only one index, the missing score is treated as 0
    (only one term takes part in the sum).

    Args:
        neo4j_version_is_5_23_or_above (bool): Whether the Neo4j version is 5.23 or above;
            determines the call syntax (``CALL () {`` vs ``CALL {``).
        alpha (float): Weight for the vector index normalized score. The full-text score is
            weighted as (1 - alpha). The value is NOT interpolated into the query text:
            the query references the ``$alpha`` query parameter, so the same value must be
            supplied with the query parameters when the query is executed.

    Returns:
        str: The constructed Cypher query string.
    """
    call_prefix = "CALL () { " if neo4j_version_is_5_23_or_above else "CALL { "

    query_body = (
        f"{NODE_VECTOR_INDEX_QUERY} "
        "WITH collect({node: node, score: score}) AS nodes, max(score) AS vector_index_max_score "
        "UNWIND nodes AS n "
        "WITH n.node AS node, (n.score / vector_index_max_score) AS rawScore "
        "RETURN node, rawScore * $alpha AS score "
        # UNION ALL rather than UNION: UNION removes duplicate rows, so a node
        # whose weighted score is identical in both branches (e.g. the top hit of
        # both indexes with alpha = 0.5) would lose one of its two terms before
        # the sum below and be under-ranked.
        "UNION ALL "
        f"{FULL_TEXT_SEARCH_QUERY} "
        "WITH collect({node: node, score: score}) AS nodes, max(score) AS ft_index_max_score "
        "UNWIND nodes AS n "
        "WITH n.node AS node, (n.score / ft_index_max_score) AS rawScore "
        "RETURN node, rawScore * (1 - $alpha) AS score } "
        "WITH node, sum(score) AS score ORDER BY score DESC LIMIT $top_k"
    )
    return call_prefix + query_body


def _get_filtered_vector_query(
filters: dict[str, Any],
node_label: str,
Expand Down Expand Up @@ -223,6 +263,8 @@ def get_search_query(
filters: Optional[dict[str, Any]] = None,
neo4j_version_is_5_23_or_above: bool = False,
use_parallel_runtime: bool = False,
ranker: Union[str, HybridSearchRanker] = HybridSearchRanker.NAIVE,
alpha: Optional[float] = None,
) -> tuple[str, dict[str, Any]]:
"""
Constructs a search query for vector or hybrid search, including optional pre-filtering
Expand All @@ -243,6 +285,8 @@ def get_search_query(
neo4j_version_is_5_23_or_above (Optional[bool]): Whether the Neo4j version is 5.23 or above.
use_parallel_runtime (bool): Whether or not use the parallel runtime to run the query.
Defaults to False.
ranker (HybridSearchRanker): Type of ranker to order the results from retrieval.
alpha (Optional[float]): Weight for the vector score when using the linear ranker. Only used when ranker is 'linear'. Defaults to 0.5 if not provided.
Returns:
tuple[str, dict[str, Any]]: A tuple containing the constructed query string and
Expand All @@ -262,7 +306,14 @@ def get_search_query(
if search_type == SearchType.HYBRID:
if filters:
raise Exception("Filters are not supported with hybrid search")
query = _get_hybrid_query(neo4j_version_is_5_23_or_above)
if ranker == HybridSearchRanker.NAIVE:
query = _get_hybrid_query(neo4j_version_is_5_23_or_above)
elif ranker == HybridSearchRanker.LINEAR and alpha:
query = _get_hybrid_query_linear(
neo4j_version_is_5_23_or_above, alpha=alpha
)
else:
raise InvalidHybridSearchRankerError()
params: dict[str, Any] = {}
elif search_type == SearchType.VECTOR:
if filters:
Expand Down
33 changes: 30 additions & 3 deletions src/neo4j_graphrag/retrievers/hybrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

import copy
import logging
from typing import Any, Callable, Optional
from typing import Any, Callable, Optional, Union

import neo4j
from pydantic import ValidationError
Expand All @@ -39,6 +39,7 @@
RawSearchResult,
RetrieverResultItem,
SearchType,
HybridSearchRanker,
)

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -142,6 +143,8 @@ def get_search_results(
query_vector: Optional[list[float]] = None,
top_k: int = 5,
effective_search_ratio: int = 1,
ranker: Union[str, HybridSearchRanker] = HybridSearchRanker.NAIVE,
alpha: Optional[float] = None,
) -> RawSearchResult:
"""Get the top_k nearest neighbor embeddings for either provided query_vector or query_text.
Both query_vector and query_text can be provided.
Expand All @@ -162,6 +165,10 @@ def get_search_results(
top_k (int, optional): The number of neighbors to return. Defaults to 5.
effective_search_ratio (int): Controls the candidate pool size for the vector index by multiplying top_k to balance query
accuracy and performance. Defaults to 1.
ranker (str, HybridSearchRanker): Type of ranker to order the results from retrieval.
alpha (Optional[float]): Weight for the vector score when using the linear ranker.
The fulltext index score is multiplied by (1 - alpha).
**Required** when using the linear ranker; must be between 0 and 1.
Raises:
SearchValidationError: If validation of the input arguments fail.
Expand All @@ -176,6 +183,8 @@ def get_search_results(
query_text=query_text,
top_k=top_k,
effective_search_ratio=effective_search_ratio,
ranker=ranker,
alpha=alpha,
)
except ValidationError as e:
raise SearchValidationError(e.errors()) from e
Expand All @@ -191,13 +200,18 @@ def get_search_results(
)
query_vector = self.embedder.embed_query(query_text)
parameters["query_vector"] = query_vector

search_query, _ = get_search_query(
search_type=SearchType.HYBRID,
return_properties=self.return_properties,
embedding_node_property=self._embedding_node_property,
neo4j_version_is_5_23_or_above=self.neo4j_version_is_5_23_or_above,
ranker=validated_data.ranker,
alpha=validated_data.alpha,
)

if "ranker" in parameters:
del parameters["ranker"]

sanitized_parameters = copy.deepcopy(parameters)
if "query_vector" in sanitized_parameters:
sanitized_parameters["query_vector"] = "..."
Expand Down Expand Up @@ -301,6 +315,8 @@ def get_search_results(
top_k: int = 5,
effective_search_ratio: int = 1,
query_params: Optional[dict[str, Any]] = None,
ranker: Union[str, HybridSearchRanker] = HybridSearchRanker.NAIVE,
alpha: Optional[float] = None,
) -> RawSearchResult:
"""Get the top_k nearest neighbor embeddings for either provided query_vector or query_text.
Both query_vector and query_text can be provided.
Expand All @@ -320,7 +336,10 @@ def get_search_results(
effective_search_ratio (int): Controls the candidate pool size for the vector index by multiplying top_k to balance query
accuracy and performance. Defaults to 1.
query_params (Optional[dict[str, Any]]): Parameters for the Cypher query. Defaults to None.
ranker (str, HybridSearchRanker): Type of ranker to order the results from retrieval.
alpha (Optional[float]): Weight for the vector score when using the linear ranker.
The fulltext index score is multiplied by (1 - alpha).
**Required** when using the linear ranker; must be between 0 and 1.
Raises:
SearchValidationError: If validation of the input arguments fail.
EmbeddingRequiredError: If no embedder is provided.
Expand All @@ -334,6 +353,8 @@ def get_search_results(
query_text=query_text,
top_k=top_k,
effective_search_ratio=effective_search_ratio,
ranker=ranker,
alpha=alpha,
query_params=query_params,
)
except ValidationError as e:
Expand Down Expand Up @@ -361,7 +382,13 @@ def get_search_results(
search_type=SearchType.HYBRID,
retrieval_query=self.retrieval_query,
neo4j_version_is_5_23_or_above=self.neo4j_version_is_5_23_or_above,
ranker=validated_data.ranker,
alpha=validated_data.alpha,
)

if "ranker" in parameters:
del parameters["ranker"]

sanitized_parameters = copy.deepcopy(parameters)
if "query_vector" in sanitized_parameters:
sanitized_parameters["query_vector"] = "..."
Expand Down
44 changes: 44 additions & 0 deletions src/neo4j_graphrag/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
# limitations under the License.
from __future__ import annotations

import warnings
from enum import Enum
from typing import Any, Callable, Literal, Optional, TypedDict, Union

Expand All @@ -25,6 +26,7 @@
field_validator,
model_validator,
)
from typing_extensions import Self

from neo4j_graphrag.utils.validation import validate_search_query_input

Expand Down Expand Up @@ -137,11 +139,53 @@ class VectorCypherSearchModel(VectorSearchModel):
query_params: Optional[dict[str, Any]] = None


class HybridSearchRanker(Enum):
    """Ranking strategies that can be selected for a hybrid search."""

    # Member values are the lowercase strings callers may pass instead of
    # the enum member itself.
    NAIVE = "naive"
    LINEAR = "linear"


class HybridSearchModel(BaseModel):
    """Validated arguments for a hybrid (vector + full-text) search.

    Attributes:
        query_text: The text to search for.
        query_vector: Optional pre-computed embedding of the query.
        top_k: Positive number of results to return. Defaults to 5.
        effective_search_ratio: Positive multiplier; defaults to 1.
        ranker: Ranker used to order the results, given either as a
            ``HybridSearchRanker`` member or as its string value
            (case-insensitive). Defaults to ``HybridSearchRanker.NAIVE``.
        alpha: Weight of the vector score for the linear ranker. Required and
            constrained to [0, 1] when ``ranker`` is linear; otherwise it is
            ignored and reset to ``None`` with a warning.
    """

    query_text: str
    query_vector: Optional[list[float]] = None
    top_k: PositiveInt = 5
    effective_search_ratio: PositiveInt = 1
    ranker: Union[str, HybridSearchRanker] = HybridSearchRanker.NAIVE
    alpha: Optional[float] = None

    @field_validator("ranker", mode="before")
    @classmethod
    def validate_ranker(cls, v: Union[str, HybridSearchRanker]) -> HybridSearchRanker:
        """Coerce ``ranker`` to a ``HybridSearchRanker`` member.

        Args:
            v: The raw value supplied for ``ranker``.

        Returns:
            HybridSearchRanker: ``v`` itself when it already is a member,
            otherwise the member whose value equals ``v.lower()``.

        Raises:
            ValueError: If ``v`` is a string that matches no ranker value, or
                is neither a string nor a ``HybridSearchRanker``.
        """
        if isinstance(v, HybridSearchRanker):
            return v

        allowed = ", ".join(r.value for r in HybridSearchRanker)
        if not isinstance(v, str):
            raise ValueError(f"Invalid ranker type. Allowed values are: {allowed}.")

        try:
            return HybridSearchRanker(v.lower())
        except ValueError:
            # The enum's own lookup error adds nothing for the caller, so it
            # is deliberately suppressed in favour of the message below.
            raise ValueError(
                f"Invalid ranker value. Allowed values are: {allowed}."
            ) from None

    @model_validator(mode="after")
    def validate_alpha(self) -> Self:
        """Check that ``alpha`` is consistent with the selected ranker.

        Returns:
            Self: The model instance, with ``alpha`` reset to ``None`` when
            the ranker is not linear.

        Raises:
            ValueError: If the ranker is linear and ``alpha`` is missing or
                lies outside the closed interval [0, 1].
        """
        if self.ranker == HybridSearchRanker.LINEAR:
            if self.alpha is None:
                raise ValueError("alpha must be provided when using the linear ranker")
            if not (0.0 <= self.alpha <= 1.0):
                raise ValueError("alpha must be between 0 and 1")
        elif self.alpha is not None:
            # alpha has no meaning for other rankers: tell the caller and drop it.
            warnings.warn(
                "alpha parameter is only used when ranker is 'linear'. Ignoring alpha.",
                UserWarning,
            )
            self.alpha = None
        return self


class HybridCypherSearchModel(HybridSearchModel):
Expand Down
24 changes: 24 additions & 0 deletions tests/e2e/test_hybrid_e2e.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,3 +176,27 @@ def test_hybrid_retriever_return_properties(driver: Driver) -> None:
assert len(results.items) == 5
for result in results.items:
assert isinstance(result, RetrieverResultItem)


@pytest.mark.usefixtures("setup_neo4j_for_retrieval")
def test_hybrid_retriever_search_text_linear_ranker(
    driver: Driver, random_embedder: Embedder
) -> None:
    """Run a text search through HybridRetriever with the linear ranker.

    Checks that the search returns a RetrieverResult holding exactly five
    RetrieverResultItem entries.
    """
    hybrid_retriever = HybridRetriever(
        driver, "vector-index-name", "fulltext-index-name", random_embedder
    )

    search_output = hybrid_retriever.search(
        query_text="Find me a book about Fremen",
        top_k=5,
        effective_search_ratio=2,
        ranker="linear",
        alpha=0.9,
    )

    assert isinstance(search_output, RetrieverResult)
    assert len(search_output.items) == 5
    assert all(isinstance(item, RetrieverResultItem) for item in search_output.items)
Loading

0 comments on commit 09440e0

Please sign in to comment.