Skip to content

Commit

Permalink
Removed node_label parameter from Neo4jMessageHistory
Browse files Browse the repository at this point in the history
  • Loading branch information
alexthomas93 committed Feb 19, 2025
1 parent c3c28c8 commit 8ae9423
Show file tree
Hide file tree
Showing 6 changed files with 10 additions and 35 deletions.
4 changes: 1 addition & 3 deletions examples/customize/llms/llm_with_neo4j_message_history.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,7 @@
database=DATABASE,
)

history = Neo4jMessageHistory(
session_id="123", driver=driver, node_label="Message", window=10
)
history = Neo4jMessageHistory(session_id="123", driver=driver, window=10)

for question in questions:
res: LLMResponse = llm.invoke(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,7 @@
llm=llm,
)

history = Neo4jMessageHistory(
session_id="123", driver=driver, node_label="Message", window=10
)
history = Neo4jMessageHistory(session_id="123", driver=driver, window=10)

questions = [
"Who starred in the Apollo 13 movies?",
Expand Down
15 changes: 5 additions & 10 deletions src/neo4j_graphrag/message_history.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ class Neo4jMessageHistory(MessageHistory):
driver = neo4j.GraphDatabase.driver(URI, auth=AUTH)
history = Neo4jMessageHistory(
session_id="123", driver=driver, node_label="Message", window=10
session_id="123", driver=driver, window=10
)
message = LLMMessage(role="user", content="Hello!")
Expand All @@ -147,33 +147,28 @@ def __init__(
self,
session_id: Union[str, int],
driver: neo4j.Driver,
node_label: str = "Session",
window: Optional[PositiveInt] = None,
) -> None:
validated_data = Neo4jMessageHistoryModel(
session_id=session_id,
driver_model=Neo4jDriverModel(driver=driver),
node_label=node_label,
window=window,
)
self._driver = validated_data.driver_model.driver
self._session_id = validated_data.session_id
self._node_label = validated_data.node_label
self._window = (
"" if validated_data.window is None else validated_data.window - 1
)
# Create session node
self._driver.execute_query(
query_=CREATE_SESSION_NODE_QUERY.format(node_label=self._node_label),
query_=CREATE_SESSION_NODE_QUERY.format(node_label="Session"),
parameters_={"session_id": self._session_id},
)

@property
def messages(self) -> List[LLMMessage]:
result = self._driver.execute_query(
query_=GET_MESSAGES_QUERY.format(
node_label=self._node_label, window=self._window
),
query_=GET_MESSAGES_QUERY.format(node_label="Session", window=self._window),
parameters_={"session_id": self._session_id},
)
messages = [
Expand All @@ -199,7 +194,7 @@ def add_message(self, message: LLMMessage) -> None:
message (LLMMessage): The message to add.
"""
self._driver.execute_query(
query_=ADD_MESSAGE_QUERY.format(node_label=self._node_label),
query_=ADD_MESSAGE_QUERY.format(node_label="Session"),
parameters_={
"role": message["role"],
"content": message["content"],
Expand All @@ -210,6 +205,6 @@ def add_message(self, message: LLMMessage) -> None:
def clear(self) -> None:
"""Clear the message history."""
self._driver.execute_query(
query_=CLEAR_SESSION_QUERY.format(node_label=self._node_label),
query_=CLEAR_SESSION_QUERY.format(node_label="Session"),
parameters_={"session_id": self._session_id},
)
7 changes: 0 additions & 7 deletions src/neo4j_graphrag/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,17 +256,10 @@ class Text2CypherRetrieverModel(BaseModel):
class Neo4jMessageHistoryModel(BaseModel):
session_id: Union[str, int]
driver_model: Neo4jDriverModel
node_label: str = "Session"
window: Optional[PositiveInt] = None

@field_validator("session_id")
def validate_session_id(cls, v: Union[str, int]) -> Union[str, int]:
if isinstance(v, str) and len(v) == 0:
raise ValueError("session_id cannot be empty")
return v

@field_validator("node_label")
def validate_node_label(cls, v: str) -> str:
if len(v) == 0:
raise ValueError("node_label cannot be empty")
return v
1 change: 0 additions & 1 deletion tests/e2e/test_graphrag_e2e.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,6 @@ def test_graphrag_happy_path_with_neo4j_message_history(
message_history = Neo4jMessageHistory(
driver=driver,
session_id="123",
node_label="Message",
)
message_history.clear()
message_history.add_messages(
Expand Down
14 changes: 3 additions & 11 deletions tests/unit/test_message_history.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,27 +69,19 @@ def test_in_memory_message_history_clear() -> None:

def test_neo4j_message_history_invalid_session_id(driver: MagicMock) -> None:
with pytest.raises(ValidationError) as exc_info:
Neo4jMessageHistory(session_id=1.5, driver=driver, node_label="123", window=1) # type: ignore[arg-type]
Neo4jMessageHistory(session_id=1.5, driver=driver, window=1) # type: ignore[arg-type]
assert "Input should be a valid string" in str(exc_info.value)


def test_neo4j_message_history_invalid_driver() -> None:
with pytest.raises(ValidationError) as exc_info:
Neo4jMessageHistory(session_id="123", driver=1.5, node_label="123", window=1) # type: ignore[arg-type]
Neo4jMessageHistory(session_id="123", driver=1.5, window=1) # type: ignore[arg-type]
assert "Input should be an instance of Driver" in str(exc_info.value)


def test_neo4j_message_history_invalid_node_label(driver: MagicMock) -> None:
with pytest.raises(ValidationError) as exc_info:
Neo4jMessageHistory(session_id="123", driver=driver, node_label=1.5, window=1) # type: ignore[arg-type]
assert "Input should be a valid string" in str(exc_info.value)


def test_neo4j_message_history_invalid_window(driver: MagicMock) -> None:
with pytest.raises(ValidationError) as exc_info:
Neo4jMessageHistory(
session_id="123", driver=driver, node_label="123", window=-1
)
Neo4jMessageHistory(session_id="123", driver=driver, window=-1)
assert "Input should be greater than 0" in str(exc_info.value)


Expand Down

0 comments on commit 8ae9423

Please sign in to comment.