Commit

Neo4j message history (neo4j#273)
* Added message history classes

* Updated Neo4jMessageHistoryModel

* Fixed spelling error

* Fixed tests

* Added test_graphrag_happy_path_with_neo4j_message_history

* Updated LLMs

* Added missing copyright headers

* Refactored graphrag

* Added docstrings to message history classes

* Added message history examples

* Updated docs

* Updated CHANGELOG

* Removed Neo4jMessageHistory __del__ method

* Makes the build_query and chat_summary_prompt methods in the GraphRAG class private

* Added a threading lock to InMemoryMessageHistory

* Removed node_label parameter from Neo4jMessageHistory

* Updated CLEAR_SESSION_QUERY

* Fixed CLEAR_SESSION_QUERY
alexthomas93 authored Feb 21, 2025
1 parent 4ce3b56 commit b893584
Showing 23 changed files with 908 additions and 81 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -8,6 +8,9 @@
- Support for effective_search_ratio parameter in vector and hybrid searches.
- Introduced upsert_vectors utility function for batch upserting embeddings to vector indexes.
- Introduced `extract_cypher` function to enhance Cypher query extraction and formatting in `Text2CypherRetriever`.
- Introduced Neo4jMessageHistory and InMemoryMessageHistory classes for managing LLM message histories.
- Added examples and documentation for using message history with Neo4j and in-memory storage.
- Updated LLM and GraphRAG classes to support new message history classes.

### Changed

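The new message history classes share a small interface (`add_message` plus a `messages` accessor), and a later commit bullet notes a threading lock was added to `InMemoryMessageHistory`. As a rough illustration of the in-memory variant, here is a hypothetical stand-in; `InMemoryHistorySketch` is not the package's actual class, which lives in `neo4j_graphrag.message_history`:

```python
import threading
from typing import Any, Dict, List


class InMemoryHistorySketch:
    """Hypothetical stand-in for InMemoryMessageHistory: role/content
    dicts stored in a list, with a lock guarding concurrent appends."""

    def __init__(self) -> None:
        self._lock = threading.Lock()
        self._messages: List[Dict[str, Any]] = []

    def add_message(self, message: Dict[str, Any]) -> None:
        with self._lock:
            self._messages.append(message)

    @property
    def messages(self) -> List[Dict[str, Any]]:
        with self._lock:
            return list(self._messages)


history = InMemoryHistorySketch()
history.add_message({"role": "user", "content": "What are some movies Tom Hanks starred in?"})
history.add_message({"role": "assistant", "content": "Forrest Gump, Cast Away, ..."})
print(len(history.messages))  # 2
```

Returning a copy from the `messages` property keeps callers from mutating the list outside the lock.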
9 changes: 9 additions & 0 deletions docs/source/api.rst
@@ -403,6 +403,15 @@ Database Interaction
.. autofunction:: neo4j_graphrag.schema.format_schema


***************
Message History
***************

.. autoclass:: neo4j_graphrag.message_history.InMemoryMessageHistory

.. autoclass:: neo4j_graphrag.message_history.Neo4jMessageHistory


******
Errors
******
1 change: 1 addition & 0 deletions docs/source/index.rst
@@ -148,6 +148,7 @@ Note that the below example is not the only way you can upsert data into your Ne


.. code:: python
from neo4j import GraphDatabase
from neo4j_graphrag.indexes import upsert_vectors
from neo4j_graphrag.types import EntityType
1 change: 1 addition & 0 deletions docs/source/user_guide_rag.rst
@@ -917,6 +917,7 @@ Populate a Vector Index
==========================

.. code:: python
from random import random
from neo4j import GraphDatabase
3 changes: 2 additions & 1 deletion examples/README.md
@@ -52,7 +52,7 @@ are listed in [the last section of this file](#customize).

- [End to end GraphRAG](./answer/graphrag.py)
- [GraphRAG with message history](./question_answering/graphrag_with_message_history.py)

- [GraphRAG with Neo4j message history](./question_answering/graphrag_with_neo4j_message_history.py)

## Customize

@@ -75,6 +75,7 @@ are listed in [the last section of this file](#customize).
- [Custom LLM](./customize/llms/custom_llm.py)

- [Message history](./customize/llms/llm_with_message_history.py)
- [Message history with Neo4j](./customize/llms/llm_with_neo4j_message_history.py)
- [System Instruction](./customize/llms/llm_with_system_instructions.py)


7 changes: 4 additions & 3 deletions examples/customize/llms/custom_llm.py
@@ -1,9 +1,10 @@
import random
import string
from typing import Any, Optional
from typing import Any, List, Optional, Union

from neo4j_graphrag.llm import LLMInterface, LLMResponse
from neo4j_graphrag.llm.types import LLMMessage
from neo4j_graphrag.message_history import MessageHistory


class CustomLLM(LLMInterface):
@@ -15,7 +16,7 @@ def __init__(
def invoke(
self,
input: str,
message_history: Optional[list[LLMMessage]] = None,
message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None,
system_instruction: Optional[str] = None,
) -> LLMResponse:
content: str = (
@@ -26,7 +27,7 @@ def invoke(
async def ainvoke(
self,
input: str,
message_history: Optional[list[LLMMessage]] = None,
message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None,
system_instruction: Optional[str] = None,
) -> LLMResponse:
raise NotImplementedError()
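Widening `message_history` to `Union[List[LLMMessage], MessageHistory]` means each consumer normalizes the input; the `isinstance` check added to `GraphRAG.search` later in this commit does exactly that. A minimal sketch of the pattern, with `HistorySketch` as a hypothetical stand-in for `MessageHistory`:

```python
from typing import Any, Dict, List, Optional, Union


class HistorySketch:
    """Hypothetical stand-in for MessageHistory with a .messages property."""

    def __init__(self, messages: List[Dict[str, Any]]) -> None:
        self._messages = messages

    @property
    def messages(self) -> List[Dict[str, Any]]:
        return self._messages


def normalize(
    history: Optional[Union[List[Dict[str, Any]], HistorySketch]],
) -> Optional[List[Dict[str, Any]]]:
    """Unwrap a history object into a plain list of messages,
    mirroring the normalization added to GraphRAG.search."""
    if isinstance(history, HistorySketch):
        return history.messages
    return history


raw = [{"role": "user", "content": "hi"}]
print(normalize(raw) == normalize(HistorySketch(raw)))  # True
```

Either form reaches the downstream prompt-building code as the same plain list.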
59 changes: 59 additions & 0 deletions examples/customize/llms/llm_with_neo4j_message_history.py
@@ -0,0 +1,59 @@
"""This example illustrates the message_history feature
of the LLMInterface by mocking a conversation between a user
and an LLM about Tom Hanks.
Neo4j is used as the database for storing the message history.
OpenAILLM can be replaced by any supported LLM from this package.
"""

import neo4j
from neo4j_graphrag.llm import LLMResponse, OpenAILLM
from neo4j_graphrag.message_history import Neo4jMessageHistory

# Define database credentials
URI = "neo4j+s://demo.neo4jlabs.com"
AUTH = ("recommendations", "recommendations")
DATABASE = "recommendations"
INDEX = "moviePlotsEmbedding"

# set the API key here or in the OPENAI_API_KEY env var
api_key = None

llm = OpenAILLM(model_name="gpt-4o", api_key=api_key)

questions = [
"What are some movies Tom Hanks starred in?",
"Is he also a director?",
"Wow, that's impressive. And what about his personal life, does he have children?",
]

driver = neo4j.GraphDatabase.driver(
URI,
auth=AUTH,
database=DATABASE,
)

history = Neo4jMessageHistory(session_id="123", driver=driver, window=10)

for question in questions:
res: LLMResponse = llm.invoke(
question,
message_history=history,
)
history.add_message(
{
"role": "user",
"content": question,
}
)
history.add_message(
{
"role": "assistant",
"content": res.content,
}
)

print("#" * 50, question)
print(res.content)
print("#" * 50)
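Both examples pass `window=10` to `Neo4jMessageHistory`, which caps how much history is replayed to the LLM. Assuming the window keeps only the most recent N messages (the real class would do this on the Neo4j side), the effect can be sketched as:

```python
from collections import deque
from typing import Any, Dict, List


def windowed(messages: List[Dict[str, Any]], window: int) -> List[Dict[str, Any]]:
    """Keep only the `window` most recent messages (illustrative only)."""
    # deque with maxlen drops the oldest items as new ones arrive
    return list(deque(messages, maxlen=window))


msgs = [{"role": "user", "content": f"question {i}"} for i in range(15)]
recent = windowed(msgs, 10)
print(len(recent), recent[0]["content"])  # 10 question 5
```

Bounding the window keeps long-running sessions from blowing past the LLM's context budget.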
87 changes: 87 additions & 0 deletions examples/question_answering/graphrag_with_neo4j_message_history.py
@@ -0,0 +1,87 @@
"""End to end example of building a RAG pipeline backed by a Neo4j database,
simulating a chat with message history which is also stored in Neo4j.
Requires OPENAI_API_KEY to be in the env var.
"""

import neo4j
from neo4j_graphrag.embeddings.openai import OpenAIEmbeddings
from neo4j_graphrag.generation import GraphRAG
from neo4j_graphrag.llm import OpenAILLM
from neo4j_graphrag.message_history import Neo4jMessageHistory
from neo4j_graphrag.retrievers import VectorCypherRetriever

# Define database credentials
URI = "neo4j+s://demo.neo4jlabs.com"
AUTH = ("recommendations", "recommendations")
DATABASE = "recommendations"
INDEX = "moviePlotsEmbedding"


driver = neo4j.GraphDatabase.driver(
URI,
auth=AUTH,
)

embedder = OpenAIEmbeddings()

retriever = VectorCypherRetriever(
driver,
index_name=INDEX,
retrieval_query="""
WITH node as movie, score
CALL(movie) {
MATCH (movie)<-[:ACTED_IN]-(p:Person)
RETURN collect(p.name) as actors
}
CALL(movie) {
MATCH (movie)<-[:DIRECTED]-(p:Person)
RETURN collect(p.name) as directors
}
RETURN movie.title as title, movie.plot as plot, movie.year as year, actors, directors
""",
embedder=embedder,
neo4j_database=DATABASE,
)

llm = OpenAILLM(model_name="gpt-4o", model_params={"temperature": 0})

rag = GraphRAG(
retriever=retriever,
llm=llm,
)

history = Neo4jMessageHistory(session_id="123", driver=driver, window=10)

questions = [
"Who starred in the Apollo 13 movies?",
"Who was its director?",
"In which year was this movie released?",
]

for question in questions:
result = rag.search(
question,
return_context=False,
message_history=history,
)

answer = result.answer
print("#" * 50, question)
print(answer)
print("#" * 50)

history.add_message(
{
"role": "user",
"content": question,
}
)
history.add_message(
{
"role": "assistant",
"content": answer,
}
)

driver.close()
25 changes: 15 additions & 10 deletions src/neo4j_graphrag/generation/graphrag.py
@@ -16,7 +16,7 @@

import logging
import warnings
from typing import Any, Optional
from typing import Any, List, Optional, Union

from pydantic import ValidationError

@@ -28,6 +28,7 @@
from neo4j_graphrag.generation.types import RagInitModel, RagResultModel, RagSearchModel
from neo4j_graphrag.llm import LLMInterface
from neo4j_graphrag.llm.types import LLMMessage
from neo4j_graphrag.message_history import MessageHistory
from neo4j_graphrag.retrievers.base import Retriever
from neo4j_graphrag.types import RetrieverResult

@@ -84,7 +85,7 @@ def __init__(
def search(
self,
query_text: str = "",
message_history: Optional[list[LLMMessage]] = None,
message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None,
examples: str = "",
retriever_config: Optional[dict[str, Any]] = None,
return_context: bool | None = None,
@@ -102,7 +103,8 @@ def search(
Args:
query_text (str): The user question.
message_history (Optional[list]): A collection previous messages, with each message having a specific role assigned.
message_history (Optional[Union[List[LLMMessage], MessageHistory]]): A collection of previous messages,
with each message having a specific role assigned.
examples (str): Examples added to the LLM prompt.
retriever_config (Optional[dict]): Parameters passed to the retriever.
search method; e.g.: top_k
@@ -127,7 +129,9 @@
)
except ValidationError as e:
raise SearchValidationError(e.errors())
query = self.build_query(validated_data.query_text, message_history)
if isinstance(message_history, MessageHistory):
message_history = message_history.messages
query = self._build_query(validated_data.query_text, message_history)
retriever_result: RetrieverResult = self.retriever.search(
query_text=query, **validated_data.retriever_config
)
@@ -147,12 +151,14 @@ def search(
result["retriever_result"] = retriever_result
return RagResultModel(**result)

def build_query(
self, query_text: str, message_history: Optional[list[LLMMessage]] = None
def _build_query(
self,
query_text: str,
message_history: Optional[List[LLMMessage]] = None,
) -> str:
summary_system_message = "You are a summarization assistant. Summarize the given text in no more than 300 words."
if message_history:
summarization_prompt = self.chat_summary_prompt(
summarization_prompt = self._chat_summary_prompt(
message_history=message_history
)
summary = self.llm.invoke(
@@ -162,10 +168,9 @@ def build_query(
return self.conversation_prompt(summary=summary, current_query=query_text)
return query_text

def chat_summary_prompt(self, message_history: list[LLMMessage]) -> str:
def _chat_summary_prompt(self, message_history: List[LLMMessage]) -> str:
message_list = [
": ".join([f"{value}" for _, value in message.items()])
for message in message_history
f"{message['role']}: {message['content']}" for message in message_history
]
history = "\n".join(message_list)
return f"""
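The rewritten `_chat_summary_prompt` flattens the history into `role: content` lines instead of joining the dict values positionally. In isolation, the new formatting step behaves like this (an excerpt of the list comprehension from the diff, not the full prompt template):

```python
message_history = [
    {"role": "user", "content": "Who starred in the Apollo 13 movies?"},
    {"role": "assistant", "content": "Tom Hanks, Kevin Bacon, ..."},
]
# Same list comprehension as the new _chat_summary_prompt body
message_list = [f"{message['role']}: {message['content']}" for message in message_history]
history = "\n".join(message_list)
print(history)
# user: Who starred in the Apollo 13 movies?
# assistant: Tom Hanks, Kevin Bacon, ...
```

Unlike the old `message.items()` join, this no longer depends on dict key order and only includes the role and content fields.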
