Skip to content

Commit

Permalink
Added extract_cypher function
Browse files Browse the repository at this point in the history
  • Loading branch information
alexthomas93 committed Feb 19, 2025
1 parent 4dae4eb commit 4bdb7e2
Show file tree
Hide file tree
Showing 2 changed files with 128 additions and 3 deletions.
45 changes: 44 additions & 1 deletion src/neo4j_graphrag/retrievers/text2cypher.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from __future__ import annotations

import logging
import re
from typing import Any, Callable, Dict, Optional

import neo4j
Expand Down Expand Up @@ -44,6 +45,48 @@
logger = logging.getLogger(__name__)


def extract_cypher(text: str) -> str:
"""Extract and format Cypher query from text, handling code blocks and special characters.
This function performs two main operations:
1. Extracts Cypher code from within triple backticks (```), if present
2. Automatically adds backtick quotes around multi-word identifiers:
- Node labels (e.g., ":Data Science" becomes ":`Data Science`")
- Property keys (e.g., "first name:" becomes "`first name`:")
- Relationship types (e.g., "[:WORKS WITH]" becomes "[:`WORKS WITH`]")
Args:
text (str): Raw text that may contain Cypher code, either within triple
backticks or as plain text.
Returns:
str: Properly formatted Cypher query with correct backtick quoting.
"""
# Extract Cypher code enclosed in triple backticks
pattern = r"```(.*?)```"
matches = re.findall(pattern, text, re.DOTALL)
cypher_query = matches[0] if matches else text
# Quote node labels in backticks if they contain spaces and are not already quoted
cypher_query = re.sub(
r":\s*(?!`\s*)(\s*)([a-zA-Z0-9_]+(?:\s+[a-zA-Z0-9_]+)+)(?!\s*`)(\s*)",
r":`\2`",
cypher_query,
)
# Quote property keys in backticks if they contain spaces and are not already quoted
cypher_query = re.sub(
r"([,{]\s*)(?!`)([a-zA-Z0-9_]+(?:\s+[a-zA-Z0-9_]+)+)(?!`)(\s*:)",
r"\1`\2`\3",
cypher_query,
)
# Quote relationship types in backticks if they contain spaces and are not already quoted
cypher_query = re.sub(
r"(\[\s*[a-zA-Z0-9_]*\s*:\s*)(?!`)([a-zA-Z0-9_]+(?:\s+[a-zA-Z0-9_]+)+)(?!`)(\s*(?:\]|-))",
r"\1`\2`\3",
cypher_query,
)
return cypher_query


class Text2CypherRetriever(Retriever):
"""
Allows for the retrieval of records from a Neo4j database using natural language.
Expand Down Expand Up @@ -168,7 +211,7 @@ def get_search_results(

try:
llm_result = self.llm.invoke(prompt)
t2c_query = llm_result.content
t2c_query = extract_cypher(llm_result.content)
logger.debug("Text2CypherRetriever Cypher query: %s", t2c_query)
records, _, _ = self.driver.execute_query(
query_=t2c_query,
Expand Down
86 changes: 84 additions & 2 deletions tests/unit/retrievers/test_text2cypher.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
)
from neo4j_graphrag.generation.prompts import Text2CypherTemplate
from neo4j_graphrag.llm import LLMResponse
from neo4j_graphrag.retrievers import Text2CypherRetriever
from neo4j_graphrag.retrievers.text2cypher import Text2CypherRetriever, extract_cypher
from neo4j_graphrag.types import RetrieverResult, RetrieverResultItem


Expand Down Expand Up @@ -204,9 +204,11 @@ def test_t2c_retriever_with_result_format_function(
)


@patch("neo4j_graphrag.retrievers.text2cypher.extract_cypher")
@patch("neo4j_graphrag.retrievers.base.get_version")
def test_t2c_retriever_initialization_with_custom_prompt(
mock_get_version: MagicMock,
mock_extract_cypher: MagicMock,
driver: MagicMock,
llm: MagicMock,
neo4j_record: MagicMock,
Expand All @@ -224,9 +226,11 @@ def test_t2c_retriever_initialization_with_custom_prompt(
llm.invoke.assert_called_once_with("This is a custom prompt. test")


@patch("neo4j_graphrag.retrievers.text2cypher.extract_cypher")
@patch("neo4j_graphrag.retrievers.base.get_version")
def test_t2c_retriever_initialization_with_custom_prompt_and_schema_and_examples(
mock_get_version: MagicMock,
mock_extract_cypher: MagicMock,
driver: MagicMock,
llm: MagicMock,
neo4j_record: MagicMock,
Expand Down Expand Up @@ -254,9 +258,11 @@ def test_t2c_retriever_initialization_with_custom_prompt_and_schema_and_examples
llm.invoke.assert_called_once_with("This is a custom prompt. test")


@patch("neo4j_graphrag.retrievers.text2cypher.extract_cypher")
@patch("neo4j_graphrag.retrievers.base.get_version")
def test_t2c_retriever_initialization_with_custom_prompt_and_schema_and_examples_for_prompt_params(
mock_get_version: MagicMock,
mock_extract_cypher: MagicMock,
driver: MagicMock,
llm: MagicMock,
neo4j_record: MagicMock,
Expand Down Expand Up @@ -286,9 +292,11 @@ def test_t2c_retriever_initialization_with_custom_prompt_and_schema_and_examples
)


@patch("neo4j_graphrag.retrievers.text2cypher.extract_cypher")
@patch("neo4j_graphrag.retrievers.base.get_version")
def test_t2c_retriever_initialization_with_custom_prompt_and_unused_schema_and_examples(
mock_get_version: MagicMock,
mock_extract_cypher: MagicMock,
driver: MagicMock,
llm: MagicMock,
neo4j_record: MagicMock,
Expand Down Expand Up @@ -321,9 +329,13 @@ def test_t2c_retriever_initialization_with_custom_prompt_and_unused_schema_and_e
)


@patch("neo4j_graphrag.retrievers.text2cypher.extract_cypher")
@patch("neo4j_graphrag.retrievers.base.get_version")
def test_t2c_retriever_invalid_custom_prompt_type(
mock_get_version: MagicMock, driver: MagicMock, llm: MagicMock
mock_get_version: MagicMock,
mock_extract_cypher: MagicMock,
driver: MagicMock,
llm: MagicMock,
) -> None:
mock_get_version.return_value = ((5, 23, 0), False, False)
with pytest.raises(RetrieverInitializationError) as exc_info:
Expand All @@ -336,9 +348,11 @@ def test_t2c_retriever_invalid_custom_prompt_type(
assert "Input should be a valid string" in str(exc_info.value)


@patch("neo4j_graphrag.retrievers.text2cypher.extract_cypher")
@patch("neo4j_graphrag.retrievers.base.get_version")
def test_t2c_retriever_with_custom_prompt_prompt_params(
mock_get_version: MagicMock,
mock_extract_cypher: MagicMock,
driver: MagicMock,
llm: MagicMock,
neo4j_record: MagicMock,
Expand All @@ -361,9 +375,11 @@ def test_t2c_retriever_with_custom_prompt_prompt_params(
)


@patch("neo4j_graphrag.retrievers.text2cypher.extract_cypher")
@patch("neo4j_graphrag.retrievers.base.get_version")
def test_t2c_retriever_with_custom_prompt_bad_prompt_params(
mock_get_version: MagicMock,
mock_extract_cypher: MagicMock,
driver: MagicMock,
llm: MagicMock,
neo4j_record: MagicMock,
Expand Down Expand Up @@ -392,11 +408,13 @@ def test_t2c_retriever_with_custom_prompt_bad_prompt_params(
)


@patch("neo4j_graphrag.retrievers.text2cypher.extract_cypher")
@patch("neo4j_graphrag.retrievers.base.get_version")
@patch("neo4j_graphrag.retrievers.text2cypher.get_schema")
def test_t2c_retriever_with_custom_prompt_and_schema(
get_schema_mock: MagicMock,
mock_get_version: MagicMock,
mock_extract_cypher: MagicMock,
driver: MagicMock,
llm: MagicMock,
neo4j_record: MagicMock,
Expand All @@ -419,3 +437,67 @@ def test_t2c_retriever_with_custom_prompt_and_schema(

get_schema_mock.assert_not_called()
llm.invoke.assert_called_once_with("""This is a custom prompt. test """)


@pytest.mark.parametrize(
"description, cypher_query, expected_output",
[
("No changes", "MATCH (n) RETURN n;", "MATCH (n) RETURN n;"),
(
"Surrounded by backticks",
"Cypher query: ```MATCH (n) RETURN n;```",
"MATCH (n) RETURN n;",
),
(
"Spaces in label",
"Cypher query: ```MATCH (n: Label With Spaces ) RETURN n;```",
"MATCH (n:`Label With Spaces`) RETURN n;",
),
(
"No spaces in label",
"Cypher query: ```MATCH (n: LabelWithNoSpaces ) RETURN n;```",
"MATCH (n: LabelWithNoSpaces ) RETURN n;",
),
(
"Backticks in label",
"Cypher query: ```MATCH (n: `LabelWithBackticks` ) RETURN n;```",
"MATCH (n: `LabelWithBackticks` ) RETURN n;",
),
(
"Spaces in property key",
"Cypher query: ```MATCH (n: { prop 1: 1, prop 2: 2 }) RETURN n;```",
"MATCH (n: { `prop 1`: 1, `prop 2`: 2 }) RETURN n;",
),
(
"No spaces in property key",
"Cypher query: ```MATCH (n: { prop1: 1, prop2: 2 }) RETURN n;```",
"MATCH (n: { prop1: 1, prop2: 2 }) RETURN n;",
),
(
"Backticks in property key",
"Cypher query: ```MATCH (n: { `prop 1`: 1, `prop 2`: 2 }) RETURN n;```",
"MATCH (n: { `prop 1`: 1, `prop 2`: 2 }) RETURN n;",
),
(
"Spaces in relationship type",
"Cypher query: ```MATCH (n)-[: Relationship With Spaces ]->(m) RETURN n, m;```",
"MATCH (n)-[:`Relationship With Spaces`]->(m) RETURN n, m;",
),
(
"No spaces in relationship type",
"Cypher query: ```MATCH (n)-[ : RelationshipWithNoSpaces ]->(m) RETURN n, m;```",
"MATCH (n)-[ : RelationshipWithNoSpaces ]->(m) RETURN n, m;",
),
(
"Backticks in relationship type",
"Cypher query: ```MATCH (n)-[ : `RelationshipWithBackticks` ]->(m) RETURN n, m;```",
"MATCH (n)-[ : `RelationshipWithBackticks` ]->(m) RETURN n, m;",
),
],
)
def test_extract_cypher(
description: str, cypher_query: str, expected_output: str
) -> None:
assert (
extract_cypher(cypher_query) == expected_output
), f"Failed test case: {description}"

0 comments on commit 4bdb7e2

Please sign in to comment.