Skip to content

Commit

Permalink
Updated Text2Cypher unit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
alexthomas93 committed Jan 23, 2025
1 parent b5b08d3 commit b3ccd75
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 43 deletions.
21 changes: 10 additions & 11 deletions tests/unit/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,21 +51,19 @@ def retriever_mock() -> MagicMock:


@pytest.fixture(scope="function")
@patch("neo4j_graphrag.retrievers.VectorRetriever._verify_version")
def vector_retriever(
_verify_version_mock: MagicMock, driver: MagicMock
) -> VectorRetriever:
@patch("neo4j_graphrag.retrievers.base.get_version")
def vector_retriever(mock_get_version: MagicMock, driver: MagicMock) -> VectorRetriever:
mock_get_version.return_value = ((5, 23, 0), False)
return VectorRetriever(driver, "my-index")


@pytest.fixture(scope="function")
@patch("neo4j_graphrag.retrievers.VectorCypherRetriever._verify_version")
@patch("neo4j_graphrag.retrievers.base.get_version")
def vector_cypher_retriever(
_verify_version_mock: MagicMock, driver: MagicMock
mock_get_version: MagicMock, driver: MagicMock
) -> VectorCypherRetriever:
retrieval_query = """
RETURN node.id AS node_id, node.text AS text, score
"""
mock_get_version.return_value = ((5, 23, 0), False)
retrieval_query = "RETURN node.id AS node_id, node.text AS text, score"
return VectorCypherRetriever(driver, "my-index", retrieval_query)


Expand All @@ -77,10 +75,11 @@ def hybrid_retriever(mock_get_version: MagicMock, driver: MagicMock) -> HybridRe


@pytest.fixture(scope="function")
@patch("neo4j_graphrag.retrievers.Text2CypherRetriever._verify_version")
@patch("neo4j_graphrag.retrievers.base.get_version")
def t2c_retriever(
_verify_version_mock: MagicMock, driver: MagicMock, llm: MagicMock
mock_get_version: MagicMock, driver: MagicMock, llm: MagicMock
) -> Text2CypherRetriever:
mock_get_version.return_value = ((5, 23, 0), False)
return Text2CypherRetriever(driver, llm)


Expand Down
77 changes: 45 additions & 32 deletions tests/unit/retrievers/test_text2cypher.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,42 +30,44 @@


def test_t2c_retriever_initialization(driver: MagicMock, llm: MagicMock) -> None:
with patch(
"neo4j_graphrag.retrievers.base.Retriever._verify_version"
) as mock_verify:
with patch("neo4j_graphrag.retrievers.base.get_version") as mock_get_version:
mock_get_version.return_value = ((5, 23, 0), False)
Text2CypherRetriever(driver, llm, neo4j_schema="dummy-text")
mock_verify.assert_called_once()
mock_get_version.assert_called_once()


@patch("neo4j_graphrag.retrievers.base.Retriever._verify_version")
@patch("neo4j_graphrag.retrievers.base.get_version")
@patch("neo4j_graphrag.retrievers.text2cypher.get_schema")
def test_t2c_retriever_schema_retrieval(
get_schema_mock: MagicMock,
_verify_version_mock: MagicMock,
mock_get_version: MagicMock,
driver: MagicMock,
llm: MagicMock,
) -> None:
mock_get_version.return_value = ((5, 23, 0), False)
Text2CypherRetriever(driver, llm)
get_schema_mock.assert_called_once()


@patch("neo4j_graphrag.retrievers.base.Retriever._verify_version")
@patch("neo4j_graphrag.retrievers.base.get_version")
@patch("neo4j_graphrag.retrievers.text2cypher.get_schema")
def test_t2c_retriever_schema_retrieval_failure(
get_schema_mock: MagicMock,
_verify_version_mock: MagicMock,
mock_get_version: MagicMock,
driver: MagicMock,
llm: MagicMock,
) -> None:
mock_get_version.return_value = ((5, 23, 0), False)
get_schema_mock.side_effect = Neo4jError
with pytest.raises(SchemaFetchError):
Text2CypherRetriever(driver, llm)


@patch("neo4j_graphrag.retrievers.Text2CypherRetriever._verify_version")
@patch("neo4j_graphrag.retrievers.base.get_version")
def test_t2c_retriever_invalid_neo4j_schema(
_verify_version_mock: MagicMock, driver: MagicMock, llm: MagicMock
mock_get_version: MagicMock, driver: MagicMock, llm: MagicMock
) -> None:
mock_get_version.return_value = ((5, 23, 0), False)
with pytest.raises(RetrieverInitializationError) as exc_info:
Text2CypherRetriever(
driver=driver,
Expand All @@ -77,10 +79,11 @@ def test_t2c_retriever_invalid_neo4j_schema(
assert "Input should be a valid string" in str(exc_info.value)


@patch("neo4j_graphrag.retrievers.Text2CypherRetriever._verify_version")
@patch("neo4j_graphrag.retrievers.base.get_version")
def test_t2c_retriever_invalid_search_query(
_verify_version_mock: MagicMock, driver: MagicMock, llm: MagicMock
mock_get_version: MagicMock, driver: MagicMock, llm: MagicMock
) -> None:
mock_get_version.return_value = ((5, 23, 0), False)
with pytest.raises(SearchValidationError) as exc_info:
retriever = Text2CypherRetriever(
driver=driver, llm=llm, neo4j_schema="dummy-text"
Expand All @@ -91,10 +94,11 @@ def test_t2c_retriever_invalid_search_query(
assert "Input should be a valid string" in str(exc_info.value)


@patch("neo4j_graphrag.retrievers.Text2CypherRetriever._verify_version")
@patch("neo4j_graphrag.retrievers.base.get_version")
def test_t2c_retriever_invalid_search_examples(
_verify_version_mock: MagicMock, driver: MagicMock, llm: MagicMock
mock_get_version: MagicMock, driver: MagicMock, llm: MagicMock
) -> None:
mock_get_version.return_value = ((5, 23, 0), False)
with pytest.raises(RetrieverInitializationError) as exc_info:
Text2CypherRetriever(
driver=driver,
Expand All @@ -107,13 +111,14 @@ def test_t2c_retriever_invalid_search_examples(
assert "Initialization failed" in str(exc_info.value)


@patch("neo4j_graphrag.retrievers.Text2CypherRetriever._verify_version")
@patch("neo4j_graphrag.retrievers.base.get_version")
def test_t2c_retriever_happy_path(
_verify_version_mock: MagicMock,
mock_get_version: MagicMock,
driver: MagicMock,
llm: MagicMock,
neo4j_record: MagicMock,
) -> None:
mock_get_version.return_value = ((5, 23, 0), False)
t2c_query = "MATCH (n) RETURN n;"
query_text = "may thy knife chip and shatter"
neo4j_schema = "dummy-schema"
Expand Down Expand Up @@ -147,10 +152,11 @@ def test_t2c_retriever_happy_path(
)


@patch("neo4j_graphrag.retrievers.Text2CypherRetriever._verify_version")
@patch("neo4j_graphrag.retrievers.base.get_version")
def test_t2c_retriever_cypher_error(
_verify_version_mock: MagicMock, driver: MagicMock, llm: MagicMock
mock_get_version: MagicMock, driver: MagicMock, llm: MagicMock
) -> None:
mock_get_version.return_value = ((5, 23, 0), False)
t2c_query = "this is not a cypher query"
neo4j_schema = "dummy-schema"
examples = ["example-1", "example-2"]
Expand All @@ -165,14 +171,15 @@ def test_t2c_retriever_cypher_error(
assert "Failed to get search result" in str(e)


@patch("neo4j_graphrag.retrievers.Text2CypherRetriever._verify_version")
@patch("neo4j_graphrag.retrievers.base.get_version")
def test_t2c_retriever_with_result_format_function(
_verify_version_mock: MagicMock,
mock_get_version: MagicMock,
driver: MagicMock,
llm: MagicMock,
neo4j_record: MagicMock,
result_formatter: MagicMock,
) -> None:
mock_get_version.return_value = ((5, 23, 0), False)
retriever = Text2CypherRetriever(
driver=driver, llm=llm, result_formatter=result_formatter
)
Expand All @@ -197,13 +204,14 @@ def test_t2c_retriever_with_result_format_function(
)


@patch("neo4j_graphrag.retrievers.base.Retriever._verify_version")
@patch("neo4j_graphrag.retrievers.base.get_version")
def test_t2c_retriever_initialization_with_custom_prompt(
_verify_version_mock: MagicMock,
mock_get_version: MagicMock,
driver: MagicMock,
llm: MagicMock,
neo4j_record: MagicMock,
) -> None:
mock_get_version.return_value = ((5, 23, 0), False)
prompt = "This is a custom prompt. {query_text}"
retriever = Text2CypherRetriever(driver=driver, llm=llm, custom_prompt=prompt)
driver.execute_query.return_value = (
Expand All @@ -216,13 +224,14 @@ def test_t2c_retriever_initialization_with_custom_prompt(
llm.invoke.assert_called_once_with("This is a custom prompt. test")


@patch("neo4j_graphrag.retrievers.base.Retriever._verify_version")
@patch("neo4j_graphrag.retrievers.base.get_version")
def test_t2c_retriever_initialization_with_custom_prompt_and_schema_and_examples(
_verify_version_mock: MagicMock,
mock_get_version: MagicMock,
driver: MagicMock,
llm: MagicMock,
neo4j_record: MagicMock,
) -> None:
mock_get_version.return_value = ((5, 23, 0), False)
prompt = "This is a custom prompt. {query_text}"
neo4j_schema = "dummy-schema"
examples = ["example-1", "example-2"]
Expand All @@ -245,10 +254,11 @@ def test_t2c_retriever_initialization_with_custom_prompt_and_schema_and_examples
llm.invoke.assert_called_once_with("This is a custom prompt. test")


@patch("neo4j_graphrag.retrievers.Text2CypherRetriever._verify_version")
@patch("neo4j_graphrag.retrievers.base.get_version")
def test_t2c_retriever_invalid_custom_prompt_type(
_verify_version_mock: MagicMock, driver: MagicMock, llm: MagicMock
mock_get_version: MagicMock, driver: MagicMock, llm: MagicMock
) -> None:
mock_get_version.return_value = ((5, 23, 0), False)
with pytest.raises(RetrieverInitializationError) as exc_info:
Text2CypherRetriever(
driver=driver,
Expand All @@ -259,13 +269,14 @@ def test_t2c_retriever_invalid_custom_prompt_type(
assert "Input should be a valid string" in str(exc_info.value)


@patch("neo4j_graphrag.retrievers.base.Retriever._verify_version")
@patch("neo4j_graphrag.retrievers.base.get_version")
def test_t2c_retriever_with_custom_prompt_prompt_params(
_verify_version_mock: MagicMock,
mock_get_version: MagicMock,
driver: MagicMock,
llm: MagicMock,
neo4j_record: MagicMock,
) -> None:
mock_get_version.return_value = ((5, 23, 0), False)
prompt = "This is a custom prompt. {query_text} {examples_custom}"
query = "test"
examples = ["example A", "example B"]
Expand All @@ -283,13 +294,14 @@ def test_t2c_retriever_with_custom_prompt_prompt_params(
)


@patch("neo4j_graphrag.retrievers.base.Retriever._verify_version")
@patch("neo4j_graphrag.retrievers.base.get_version")
def test_t2c_retriever_with_custom_prompt_bad_prompt_params(
_verify_version_mock: MagicMock,
mock_get_version: MagicMock,
driver: MagicMock,
llm: MagicMock,
neo4j_record: MagicMock,
) -> None:
mock_get_version.return_value = ((5, 23, 0), False)
prompt = "This is a custom prompt. {query_text} {examples}"
query = "test"
examples = ["example A", "example B"]
Expand All @@ -313,15 +325,16 @@ def test_t2c_retriever_with_custom_prompt_bad_prompt_params(
)


@patch("neo4j_graphrag.retrievers.base.Retriever._verify_version")
@patch("neo4j_graphrag.retrievers.base.get_version")
@patch("neo4j_graphrag.retrievers.text2cypher.get_schema")
def test_t2c_retriever_with_custom_prompt_and_schema(
get_schema_mock: MagicMock,
_verify_version_mock: MagicMock,
mock_get_version: MagicMock,
driver: MagicMock,
llm: MagicMock,
neo4j_record: MagicMock,
) -> None:
mock_get_version.return_value = ((5, 23, 0), False)
prompt = "This is a custom prompt. {query_text} {schema}"
query = "test"

Expand Down

0 comments on commit b3ccd75

Please sign in to comment.