diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index 75afb3fa6..1c6138f3d 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -51,21 +51,19 @@ def retriever_mock() -> MagicMock: @pytest.fixture(scope="function") -@patch("neo4j_graphrag.retrievers.VectorRetriever._verify_version") -def vector_retriever( - _verify_version_mock: MagicMock, driver: MagicMock -) -> VectorRetriever: +@patch("neo4j_graphrag.retrievers.base.get_version") +def vector_retriever(mock_get_version: MagicMock, driver: MagicMock) -> VectorRetriever: + mock_get_version.return_value = ((5, 23, 0), False) return VectorRetriever(driver, "my-index") @pytest.fixture(scope="function") -@patch("neo4j_graphrag.retrievers.VectorCypherRetriever._verify_version") +@patch("neo4j_graphrag.retrievers.base.get_version") def vector_cypher_retriever( - _verify_version_mock: MagicMock, driver: MagicMock + mock_get_version: MagicMock, driver: MagicMock ) -> VectorCypherRetriever: - retrieval_query = """ - RETURN node.id AS node_id, node.text AS text, score - """ + mock_get_version.return_value = ((5, 23, 0), False) + retrieval_query = "RETURN node.id AS node_id, node.text AS text, score" return VectorCypherRetriever(driver, "my-index", retrieval_query) @@ -77,10 +75,11 @@ def hybrid_retriever(mock_get_version: MagicMock, driver: MagicMock) -> HybridRe @pytest.fixture(scope="function") -@patch("neo4j_graphrag.retrievers.Text2CypherRetriever._verify_version") +@patch("neo4j_graphrag.retrievers.base.get_version") def t2c_retriever( - _verify_version_mock: MagicMock, driver: MagicMock, llm: MagicMock + mock_get_version: MagicMock, driver: MagicMock, llm: MagicMock ) -> Text2CypherRetriever: + mock_get_version.return_value = ((5, 23, 0), False) return Text2CypherRetriever(driver, llm) diff --git a/tests/unit/retrievers/test_text2cypher.py b/tests/unit/retrievers/test_text2cypher.py index 05b1e5450..5f23164c8 100644 --- a/tests/unit/retrievers/test_text2cypher.py +++ b/tests/unit/retrievers/test_text2cypher.py @@ -30,42 +30,44 @@ def test_t2c_retriever_initialization(driver: MagicMock, llm: MagicMock) -> None: - with patch( - "neo4j_graphrag.retrievers.base.Retriever._verify_version" - ) as mock_verify: + with patch("neo4j_graphrag.retrievers.base.get_version") as mock_get_version: + mock_get_version.return_value = ((5, 23, 0), False) Text2CypherRetriever(driver, llm, neo4j_schema="dummy-text") - mock_verify.assert_called_once() + mock_get_version.assert_called_once() -@patch("neo4j_graphrag.retrievers.base.Retriever._verify_version") +@patch("neo4j_graphrag.retrievers.base.get_version") @patch("neo4j_graphrag.retrievers.text2cypher.get_schema") def test_t2c_retriever_schema_retrieval( get_schema_mock: MagicMock, - _verify_version_mock: MagicMock, + mock_get_version: MagicMock, driver: MagicMock, llm: MagicMock, ) -> None: + mock_get_version.return_value = ((5, 23, 0), False) Text2CypherRetriever(driver, llm) get_schema_mock.assert_called_once() -@patch("neo4j_graphrag.retrievers.base.Retriever._verify_version") +@patch("neo4j_graphrag.retrievers.base.get_version") @patch("neo4j_graphrag.retrievers.text2cypher.get_schema") def test_t2c_retriever_schema_retrieval_failure( get_schema_mock: MagicMock, - _verify_version_mock: MagicMock, + mock_get_version: MagicMock, driver: MagicMock, llm: MagicMock, ) -> None: + mock_get_version.return_value = ((5, 23, 0), False) get_schema_mock.side_effect = Neo4jError with pytest.raises(SchemaFetchError): Text2CypherRetriever(driver, llm) -@patch("neo4j_graphrag.retrievers.Text2CypherRetriever._verify_version") +@patch("neo4j_graphrag.retrievers.base.get_version") def test_t2c_retriever_invalid_neo4j_schema( - _verify_version_mock: MagicMock, driver: MagicMock, llm: MagicMock + mock_get_version: MagicMock, driver: MagicMock, llm: MagicMock ) -> None: + mock_get_version.return_value = ((5, 23, 0), False) with pytest.raises(RetrieverInitializationError) as exc_info: Text2CypherRetriever( driver=driver, @@ -77,10 +79,11 @@ def test_t2c_retriever_invalid_neo4j_schema( assert "Input should be a valid string" in str(exc_info.value) -@patch("neo4j_graphrag.retrievers.Text2CypherRetriever._verify_version") +@patch("neo4j_graphrag.retrievers.base.get_version") def test_t2c_retriever_invalid_search_query( - _verify_version_mock: MagicMock, driver: MagicMock, llm: MagicMock + mock_get_version: MagicMock, driver: MagicMock, llm: MagicMock ) -> None: + mock_get_version.return_value = ((5, 23, 0), False) with pytest.raises(SearchValidationError) as exc_info: retriever = Text2CypherRetriever( driver=driver, llm=llm, neo4j_schema="dummy-text" @@ -91,10 +94,11 @@ def test_t2c_retriever_invalid_search_query( assert "Input should be a valid string" in str(exc_info.value) -@patch("neo4j_graphrag.retrievers.Text2CypherRetriever._verify_version") +@patch("neo4j_graphrag.retrievers.base.get_version") def test_t2c_retriever_invalid_search_examples( - _verify_version_mock: MagicMock, driver: MagicMock, llm: MagicMock + mock_get_version: MagicMock, driver: MagicMock, llm: MagicMock ) -> None: + mock_get_version.return_value = ((5, 23, 0), False) with pytest.raises(RetrieverInitializationError) as exc_info: Text2CypherRetriever( driver=driver, @@ -107,13 +111,14 @@ def test_t2c_retriever_invalid_search_examples( assert "Initialization failed" in str(exc_info.value) -@patch("neo4j_graphrag.retrievers.Text2CypherRetriever._verify_version") +@patch("neo4j_graphrag.retrievers.base.get_version") def test_t2c_retriever_happy_path( - _verify_version_mock: MagicMock, + mock_get_version: MagicMock, driver: MagicMock, llm: MagicMock, neo4j_record: MagicMock, ) -> None: + mock_get_version.return_value = ((5, 23, 0), False) t2c_query = "MATCH (n) RETURN n;" query_text = "may thy knife chip and shatter" neo4j_schema = "dummy-schema" @@ -147,10 +152,11 @@ def test_t2c_retriever_happy_path( ) -@patch("neo4j_graphrag.retrievers.Text2CypherRetriever._verify_version") +@patch("neo4j_graphrag.retrievers.base.get_version") def test_t2c_retriever_cypher_error( - _verify_version_mock: MagicMock, driver: MagicMock, llm: MagicMock + mock_get_version: MagicMock, driver: MagicMock, llm: MagicMock ) -> None: + mock_get_version.return_value = ((5, 23, 0), False) t2c_query = "this is not a cypher query" neo4j_schema = "dummy-schema" examples = ["example-1", "example-2"] @@ -165,14 +171,15 @@ def test_t2c_retriever_cypher_error( assert "Failed to get search result" in str(e) -@patch("neo4j_graphrag.retrievers.Text2CypherRetriever._verify_version") +@patch("neo4j_graphrag.retrievers.base.get_version") def test_t2c_retriever_with_result_format_function( - _verify_version_mock: MagicMock, + mock_get_version: MagicMock, driver: MagicMock, llm: MagicMock, neo4j_record: MagicMock, result_formatter: MagicMock, ) -> None: + mock_get_version.return_value = ((5, 23, 0), False) retriever = Text2CypherRetriever( driver=driver, llm=llm, result_formatter=result_formatter ) @@ -197,13 +204,14 @@ def test_t2c_retriever_with_result_format_function( ) -@patch("neo4j_graphrag.retrievers.base.Retriever._verify_version") +@patch("neo4j_graphrag.retrievers.base.get_version") def test_t2c_retriever_initialization_with_custom_prompt( - _verify_version_mock: MagicMock, + mock_get_version: MagicMock, driver: MagicMock, llm: MagicMock, neo4j_record: MagicMock, ) -> None: + mock_get_version.return_value = ((5, 23, 0), False) prompt = "This is a custom prompt. {query_text}" retriever = Text2CypherRetriever(driver=driver, llm=llm, custom_prompt=prompt) driver.execute_query.return_value = ( @@ -216,13 +224,14 @@ def test_t2c_retriever_initialization_with_custom_prompt( llm.invoke.assert_called_once_with("This is a custom prompt. test") -@patch("neo4j_graphrag.retrievers.base.Retriever._verify_version") +@patch("neo4j_graphrag.retrievers.base.get_version") def test_t2c_retriever_initialization_with_custom_prompt_and_schema_and_examples( - _verify_version_mock: MagicMock, + mock_get_version: MagicMock, driver: MagicMock, llm: MagicMock, neo4j_record: MagicMock, ) -> None: + mock_get_version.return_value = ((5, 23, 0), False) prompt = "This is a custom prompt. {query_text}" neo4j_schema = "dummy-schema" examples = ["example-1", "example-2"] @@ -245,10 +254,11 @@ def test_t2c_retriever_initialization_with_custom_prompt_and_schema_and_examples llm.invoke.assert_called_once_with("This is a custom prompt. test") -@patch("neo4j_graphrag.retrievers.Text2CypherRetriever._verify_version") +@patch("neo4j_graphrag.retrievers.base.get_version") def test_t2c_retriever_invalid_custom_prompt_type( - _verify_version_mock: MagicMock, driver: MagicMock, llm: MagicMock + mock_get_version: MagicMock, driver: MagicMock, llm: MagicMock ) -> None: + mock_get_version.return_value = ((5, 23, 0), False) with pytest.raises(RetrieverInitializationError) as exc_info: Text2CypherRetriever( driver=driver, @@ -259,13 +269,14 @@ def test_t2c_retriever_invalid_custom_prompt_type( assert "Input should be a valid string" in str(exc_info.value) -@patch("neo4j_graphrag.retrievers.base.Retriever._verify_version") +@patch("neo4j_graphrag.retrievers.base.get_version") def test_t2c_retriever_with_custom_prompt_prompt_params( - _verify_version_mock: MagicMock, + mock_get_version: MagicMock, driver: MagicMock, llm: MagicMock, neo4j_record: MagicMock, ) -> None: + mock_get_version.return_value = ((5, 23, 0), False) prompt = "This is a custom prompt. {query_text} {examples_custom}" query = "test" examples = ["example A", "example B"] @@ -283,13 +294,14 @@ def test_t2c_retriever_with_custom_prompt_prompt_params( ) -@patch("neo4j_graphrag.retrievers.base.Retriever._verify_version") +@patch("neo4j_graphrag.retrievers.base.get_version") def test_t2c_retriever_with_custom_prompt_bad_prompt_params( - _verify_version_mock: MagicMock, + mock_get_version: MagicMock, driver: MagicMock, llm: MagicMock, neo4j_record: MagicMock, ) -> None: + mock_get_version.return_value = ((5, 23, 0), False) prompt = "This is a custom prompt. {query_text} {examples}" query = "test" examples = ["example A", "example B"] @@ -313,15 +325,16 @@ def test_t2c_retriever_with_custom_prompt_bad_prompt_params( ) -@patch("neo4j_graphrag.retrievers.base.Retriever._verify_version") +@patch("neo4j_graphrag.retrievers.base.get_version") @patch("neo4j_graphrag.retrievers.text2cypher.get_schema") def test_t2c_retriever_with_custom_prompt_and_schema( get_schema_mock: MagicMock, - _verify_version_mock: MagicMock, + mock_get_version: MagicMock, driver: MagicMock, llm: MagicMock, neo4j_record: MagicMock, ) -> None: + mock_get_version.return_value = ((5, 23, 0), False) prompt = "This is a custom prompt. {query_text} {schema}" query = "test"