From 4b9f64c234f35f61a5eef18d679121f84cb5487a Mon Sep 17 00:00:00 2001 From: Will Tai Date: Wed, 18 Sep 2024 16:34:18 +0100 Subject: [PATCH] Ruff and modify example --- examples/similarity_search_for_text_mistral.py | 4 +--- src/neo4j_graphrag/embeddings/mistral.py | 1 - tests/unit/embeddings/test_mistralai_embeddings.py | 6 +++--- tests/unit/llm/test_mistralaillm.py | 4 +++- 4 files changed, 7 insertions(+), 8 deletions(-) diff --git a/examples/similarity_search_for_text_mistral.py b/examples/similarity_search_for_text_mistral.py index 73b4cc3e5..db95b890b 100644 --- a/examples/similarity_search_for_text_mistral.py +++ b/examples/similarity_search_for_text_mistral.py @@ -3,9 +3,8 @@ from random import random from neo4j import GraphDatabase -from neo4j_graphrag.embeddings.base import Embedder from neo4j_graphrag.embeddings.mistral import MistralAIEmbeddings -from neo4j_graphrag.indexes import create_vector_index, drop_index_if_exists +from neo4j_graphrag.indexes import create_vector_index from neo4j_graphrag.retrievers import VectorRetriever URI = "neo4j://localhost:7687" @@ -20,7 +19,6 @@ embedder = MistralAIEmbeddings() # Creating the index -drop_index_if_exists(driver, INDEX_NAME) create_vector_index( driver, INDEX_NAME, diff --git a/src/neo4j_graphrag/embeddings/mistral.py b/src/neo4j_graphrag/embeddings/mistral.py index 0541df002..3993303e5 100644 --- a/src/neo4j_graphrag/embeddings/mistral.py +++ b/src/neo4j_graphrag/embeddings/mistral.py @@ -57,5 +57,4 @@ def embed_query(self, text: str, **kwargs: Any) -> list[float]: model=self.model, inputs=[text], ) - print("@@@", embeddings_batch_response.data[0].embedding) return embeddings_batch_response.data[0].embedding diff --git a/tests/unit/embeddings/test_mistralai_embeddings.py b/tests/unit/embeddings/test_mistralai_embeddings.py index 68402d6d2..1ff71076d 100644 --- a/tests/unit/embeddings/test_mistralai_embeddings.py +++ b/tests/unit/embeddings/test_mistralai_embeddings.py @@ -14,9 +14,7 @@ # limitations under the License. from unittest.mock import MagicMock, Mock, patch - import pytest - from neo4j_graphrag.embeddings.mistral import MistralAIEmbeddings @@ -31,7 +29,9 @@ def test_mistralai_embedder_happy_path(mock_mistralai: Mock) -> None: mock_mistral_instance = mock_mistralai.return_value embeddings_batch_response_mock = MagicMock() embeddings_batch_response_mock.data = [MagicMock(embedding=[1.0, 2.0])] - mock_mistral_instance.embeddings.create.return_value = embeddings_batch_response_mock + mock_mistral_instance.embeddings.create.return_value = ( + embeddings_batch_response_mock + ) embedder = MistralAIEmbeddings() res = embedder.embed_query("my text") diff --git a/tests/unit/llm/test_mistralaillm.py b/tests/unit/llm/test_mistralaillm.py index a601ae862..88fb9bb5e 100644 --- a/tests/unit/llm/test_mistralaillm.py +++ b/tests/unit/llm/test_mistralaillm.py @@ -30,7 +30,9 @@ def test_mistralai_embeddings_happy_path(mock_mistral: Mock) -> None: mock_mistral_instance = mock_mistral.return_value embeddings_batch_response_mock = MagicMock() embeddings_batch_response_mock.data = [MagicMock(embedding=[1.0, 2.0, 3.0])] - mock_mistral_instance.embeddings.create.return_value = embeddings_batch_response_mock + mock_mistral_instance.embeddings.create.return_value = ( + embeddings_batch_response_mock + ) embedder = MistralAIEmbeddings() res = embedder.embed_query("some text")