Skip to content

Commit

Permalink
Fixes import error & adds delete_session_node option to Neo4jMessageH…
Browse files Browse the repository at this point in the history
…istory (neo4j#282)

* Moves LLMMessage to avoid a cicular import with LLM classes

* Updated more imports

* Updates docs

* Updated more imports

* Updated Neo4jMessageHistory to allow for optional session node deletion

* Updated LLMMessage deprecation warning
  • Loading branch information
alexthomas93 authored Feb 25, 2025
1 parent c944dca commit 69c9c68
Show file tree
Hide file tree
Showing 18 changed files with 93 additions and 41 deletions.
2 changes: 1 addition & 1 deletion docs/source/types.rst
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ LLMResponse
LLMMessage
===========

.. autoclass:: neo4j_graphrag.llm.types.LLMMessage
.. autoclass:: neo4j_graphrag.types.LLMMessage


RagResultModel
Expand Down
2 changes: 1 addition & 1 deletion examples/customize/llms/custom_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
from typing import Any, List, Optional, Union

from neo4j_graphrag.llm import LLMInterface, LLMResponse
from neo4j_graphrag.llm.types import LLMMessage
from neo4j_graphrag.message_history import MessageHistory
from neo4j_graphrag.types import LLMMessage


class CustomLLM(LLMInterface):
Expand Down
3 changes: 1 addition & 2 deletions src/neo4j_graphrag/generation/graphrag.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,9 @@
from neo4j_graphrag.generation.prompts import RagTemplate
from neo4j_graphrag.generation.types import RagInitModel, RagResultModel, RagSearchModel
from neo4j_graphrag.llm import LLMInterface
from neo4j_graphrag.llm.types import LLMMessage
from neo4j_graphrag.message_history import MessageHistory
from neo4j_graphrag.retrievers.base import Retriever
from neo4j_graphrag.types import RetrieverResult
from neo4j_graphrag.types import LLMMessage, RetrieverResult

logger = logging.getLogger(__name__)

Expand Down
2 changes: 1 addition & 1 deletion src/neo4j_graphrag/llm/anthropic_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,12 @@
from neo4j_graphrag.llm.base import LLMInterface
from neo4j_graphrag.llm.types import (
BaseMessage,
LLMMessage,
LLMResponse,
MessageList,
UserMessage,
)
from neo4j_graphrag.message_history import MessageHistory
from neo4j_graphrag.types import LLMMessage

if TYPE_CHECKING:
from anthropic.types.message_param import MessageParam
Expand Down
6 changes: 2 additions & 4 deletions src/neo4j_graphrag/llm/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,9 @@
from typing import Any, List, Optional, Union

from neo4j_graphrag.message_history import MessageHistory
from neo4j_graphrag.types import LLMMessage

from .types import (
LLMMessage,
LLMResponse,
)
from .types import LLMResponse


class LLMInterface(ABC):
Expand Down
2 changes: 1 addition & 1 deletion src/neo4j_graphrag/llm/cohere_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,13 @@
from neo4j_graphrag.llm.base import LLMInterface
from neo4j_graphrag.llm.types import (
BaseMessage,
LLMMessage,
LLMResponse,
MessageList,
SystemMessage,
UserMessage,
)
from neo4j_graphrag.message_history import MessageHistory
from neo4j_graphrag.types import LLMMessage

if TYPE_CHECKING:
from cohere import ChatMessages
Expand Down
2 changes: 1 addition & 1 deletion src/neo4j_graphrag/llm/mistralai_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,13 @@
from neo4j_graphrag.llm.base import LLMInterface
from neo4j_graphrag.llm.types import (
BaseMessage,
LLMMessage,
LLMResponse,
MessageList,
SystemMessage,
UserMessage,
)
from neo4j_graphrag.message_history import MessageHistory
from neo4j_graphrag.types import LLMMessage

try:
from mistralai import Messages, Mistral
Expand Down
2 changes: 1 addition & 1 deletion src/neo4j_graphrag/llm/ollama_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,11 @@

from neo4j_graphrag.exceptions import LLMGenerationError
from neo4j_graphrag.message_history import MessageHistory
from neo4j_graphrag.types import LLMMessage

from .base import LLMInterface
from .types import (
BaseMessage,
LLMMessage,
LLMResponse,
MessageList,
SystemMessage,
Expand Down
2 changes: 1 addition & 1 deletion src/neo4j_graphrag/llm/openai_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,12 @@
from pydantic import ValidationError

from neo4j_graphrag.message_history import MessageHistory
from neo4j_graphrag.types import LLMMessage

from ..exceptions import LLMGenerationError
from .base import LLMInterface
from .types import (
BaseMessage,
LLMMessage,
LLMResponse,
MessageList,
SystemMessage,
Expand Down
19 changes: 14 additions & 5 deletions src/neo4j_graphrag/llm/types.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,23 @@
from typing import Literal, TypedDict
import warnings
from typing import Any, Literal

from pydantic import BaseModel

from neo4j_graphrag.types import LLMMessage as _LLMMessage

class LLMResponse(BaseModel):
content: str

def __getattr__(name: str) -> Any:
if name == "LLMMessage":
warnings.warn(
"LLMMessage has been moved to neo4j_graphrag.types. Please update your imports.",
DeprecationWarning,
stacklevel=2,
)
return _LLMMessage
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")

class LLMMessage(TypedDict):
role: Literal["system", "user", "assistant"]

class LLMResponse(BaseModel):
content: str


Expand Down
3 changes: 2 additions & 1 deletion src/neo4j_graphrag/llm/vertexai_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,9 @@

from neo4j_graphrag.exceptions import LLMGenerationError
from neo4j_graphrag.llm.base import LLMInterface
from neo4j_graphrag.llm.types import BaseMessage, LLMMessage, LLMResponse, MessageList
from neo4j_graphrag.llm.types import BaseMessage, LLMResponse, MessageList
from neo4j_graphrag.message_history import MessageHistory
from neo4j_graphrag.types import LLMMessage

try:
from vertexai.generative_models import (
Expand Down
40 changes: 28 additions & 12 deletions src/neo4j_graphrag/message_history.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,17 +19,15 @@
import neo4j
from pydantic import PositiveInt

from neo4j_graphrag.llm.types import (
LLMMessage,
)
from neo4j_graphrag.types import (
LLMMessage,
Neo4jDriverModel,
Neo4jMessageHistoryModel,
)

CREATE_SESSION_NODE_QUERY = "MERGE (s:`{node_label}` {{id:$session_id}})"

CLEAR_SESSION_QUERY = (
DELETE_SESSION_AND_MESSAGES_QUERY = (
"MATCH (s:`{node_label}`) "
"WHERE s.id = $session_id "
"OPTIONAL MATCH p=(s)-[:LAST_MESSAGE]->(:Message)<-[:NEXT*0..]-(:Message) "
Expand All @@ -38,6 +36,14 @@
"DETACH DELETE node;"
)

DELETE_MESSAGES_QUERY = (
"MATCH (s:`{node_label}`)-[:LAST_MESSAGE]->(last_message:Message) "
"WHERE s.id = $session_id "
"MATCH p=(last_message)<-[:NEXT*0..]-(:Message) "
"UNWIND nodes(p) as node "
"DETACH DELETE node;"
)

GET_MESSAGES_QUERY = (
"MATCH (s:`{node_label}`)-[:LAST_MESSAGE]->(last_message) "
"WHERE s.id = $session_id MATCH p=(last_message)<-[:NEXT*0.."
Expand Down Expand Up @@ -82,8 +88,8 @@ class InMemoryMessageHistory(MessageHistory):
.. code-block:: python
from neo4j_graphrag.llm.types import LLMMessage
from neo4j_graphrag.message_history import InMemoryMessageHistory
from neo4j_graphrag.types import LLMMessage
history = InMemoryMessageHistory()
Expand Down Expand Up @@ -125,8 +131,8 @@ class Neo4jMessageHistory(MessageHistory):
.. code-block:: python
import neo4j
from neo4j_graphrag.llm.types import LLMMessage
from neo4j_graphrag.message_history import Neo4jMessageHistory
from neo4j_graphrag.types import LLMMessage
driver = neo4j.GraphDatabase.driver(URI, auth=AUTH)
Expand Down Expand Up @@ -204,9 +210,19 @@ def add_message(self, message: LLMMessage) -> None:
},
)

def clear(self) -> None:
"""Clear the message history."""
self._driver.execute_query(
query_=CLEAR_SESSION_QUERY.format(node_label="Session"),
parameters_={"session_id": self._session_id},
)
def clear(self, delete_session_node: bool = False) -> None:
"""Clear the message history.
Args:
delete_session_node (bool): Whether to delete the session node. Defaults to False.
"""
if delete_session_node:
self._driver.execute_query(
query_=DELETE_SESSION_AND_MESSAGES_QUERY.format(node_label="Session"),
parameters_={"session_id": self._session_id},
)
else:
self._driver.execute_query(
query_=DELETE_MESSAGES_QUERY.format(node_label="Session"),
parameters_={"session_id": self._session_id},
)
7 changes: 6 additions & 1 deletion src/neo4j_graphrag/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from __future__ import annotations

from enum import Enum
from typing import Any, Callable, Literal, Optional, Union
from typing import Any, Callable, Literal, Optional, TypedDict, Union

import neo4j
from pydantic import (
Expand Down Expand Up @@ -263,3 +263,8 @@ def validate_session_id(cls, v: Union[str, int]) -> Union[str, int]:
if isinstance(v, str) and len(v) == 0:
raise ValueError("session_id cannot be empty")
return v


class LLMMessage(TypedDict):
role: Literal["system", "user", "assistant"]
content: str
3 changes: 1 addition & 2 deletions tests/e2e/test_graphrag_e2e.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,9 @@
from neo4j_graphrag.generation.graphrag import GraphRAG
from neo4j_graphrag.generation.types import RagResultModel
from neo4j_graphrag.llm import LLMResponse
from neo4j_graphrag.llm.types import LLMMessage
from neo4j_graphrag.message_history import Neo4jMessageHistory
from neo4j_graphrag.retrievers import VectorCypherRetriever
from neo4j_graphrag.types import RetrieverResult, RetrieverResultItem
from neo4j_graphrag.types import LLMMessage, RetrieverResult, RetrieverResultItem

from tests.e2e.conftest import BiologyEmbedder
from tests.e2e.utils import build_data_objects, populate_neo4j
Expand Down
32 changes: 29 additions & 3 deletions tests/e2e/test_message_history_e2e.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import neo4j
from neo4j_graphrag.llm.types import LLMMessage
from neo4j_graphrag.message_history import Neo4jMessageHistory
from neo4j_graphrag.types import LLMMessage


def test_neo4j_message_history_add_message(driver: neo4j.Driver) -> None:
Expand Down Expand Up @@ -62,7 +62,7 @@ def test_neo4j_message_history_add_messages(driver: neo4j.Driver) -> None:
)


def test_neo4j_message_history_clear(driver: neo4j.Driver) -> None:
def test_neo4j_message_history_clear_messages(driver: neo4j.Driver) -> None:
driver.execute_query(query_="MATCH (n) DETACH DELETE n;")
message_history = Neo4jMessageHistory(session_id="123", driver=driver)
message_history.add_messages(
Expand All @@ -74,12 +74,38 @@ def test_neo4j_message_history_clear(driver: neo4j.Driver) -> None:
assert len(message_history.messages) == 2
message_history.clear()
assert len(message_history.messages) == 0
# Test that the session node is not deleted
results = driver.execute_query(
query_="MATCH (s:`Session`) WHERE s.id = '123' RETURN s"
)
assert len(results.records) == 1
assert results.records[0]["s"]["id"] == "123"
assert list(results.records[0]["s"].labels) == ["Session"]


def test_neo4j_message_history_clear_session_and_messages(driver: neo4j.Driver) -> None:
driver.execute_query(query_="MATCH (n) DETACH DELETE n;")
message_history = Neo4jMessageHistory(session_id="123", driver=driver)
message_history.add_messages(
[
LLMMessage(role="system", content="You are a helpful assistant."),
LLMMessage(role="user", content="Hello"),
]
)
assert len(message_history.messages) == 2
message_history.clear(delete_session_node=True)
assert len(message_history.messages) == 0
# Test that the session node is deleted
results = driver.execute_query(
query_="MATCH (s:`Session`) WHERE s.id = '123' RETURN s"
)
assert results.records == []


def test_neo4j_message_history_clear_no_messages(driver: neo4j.Driver) -> None:
driver.execute_query(query_="MATCH (n) DETACH DELETE n;")
message_history = Neo4jMessageHistory(session_id="123", driver=driver)
message_history.clear()
message_history.clear(delete_session_node=True)
results = driver.execute_query(
query_="MATCH (s:`Session`) WHERE s.id = '123' RETURN s"
)
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/llm/test_vertexai_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@

import pytest
from neo4j_graphrag.exceptions import LLMGenerationError
from neo4j_graphrag.llm.types import LLMMessage
from neo4j_graphrag.llm.vertexai_llm import VertexAILLM
from neo4j_graphrag.types import LLMMessage
from vertexai.generative_models import Content, Part


Expand Down
3 changes: 1 addition & 2 deletions tests/unit/test_graphrag.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,8 @@
from neo4j_graphrag.generation.prompts import RagTemplate
from neo4j_graphrag.generation.types import RagResultModel
from neo4j_graphrag.llm import LLMResponse
from neo4j_graphrag.llm.types import LLMMessage
from neo4j_graphrag.message_history import InMemoryMessageHistory
from neo4j_graphrag.types import RetrieverResult, RetrieverResultItem
from neo4j_graphrag.types import LLMMessage, RetrieverResult, RetrieverResultItem


def test_graphrag_prompt_template() -> None:
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/test_message_history.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@
from unittest.mock import MagicMock

import pytest
from neo4j_graphrag.llm.types import LLMMessage
from neo4j_graphrag.message_history import InMemoryMessageHistory, Neo4jMessageHistory
from neo4j_graphrag.types import LLMMessage
from pydantic import ValidationError


Expand Down

0 comments on commit 69c9c68

Please sign in to comment.