Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding Qdrant Component #447

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file added chroma/chroma.sqlite3
Binary file not shown.
2 changes: 1 addition & 1 deletion realtime_ai_character/audio/speech_to_text/google.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from realtime_ai_character.audio.speech_to_text.base import SpeechToText
from realtime_ai_character.logger import get_logger
from realtime_ai_character.utils import Singleton
from realtime_ai_character.singleton import Singleton

logger = get_logger(__name__)
config = types.SimpleNamespace(**{
Expand Down
2 changes: 1 addition & 1 deletion realtime_ai_character/audio/speech_to_text/whisper.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

from realtime_ai_character.audio.speech_to_text.base import SpeechToText
from realtime_ai_character.logger import get_logger
from realtime_ai_character.utils import Singleton
from realtime_ai_character.singleton import Singleton

DEBUG = False
logger = get_logger(__name__)
Expand Down
2 changes: 1 addition & 1 deletion realtime_ai_character/audio/text_to_speech/edge_tts.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from edge_tts import VoicesManager

from realtime_ai_character.logger import get_logger
from realtime_ai_character.utils import Singleton
from realtime_ai_character.singleton import Singleton
from realtime_ai_character.audio.text_to_speech.base import TextToSpeech

logger = get_logger(__name__)
Expand Down
2 changes: 1 addition & 1 deletion realtime_ai_character/audio/text_to_speech/elevenlabs.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import httpx

from realtime_ai_character.logger import get_logger
from realtime_ai_character.utils import Singleton
from realtime_ai_character.singleton import Singleton
from realtime_ai_character.audio.text_to_speech.base import TextToSpeech

logger = get_logger(__name__)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import google.auth.transport.requests

from realtime_ai_character.logger import get_logger
from realtime_ai_character.utils import Singleton
from realtime_ai_character.singleton import Singleton
from realtime_ai_character.audio.text_to_speech.base import TextToSpeech

logger = get_logger(__name__)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import httpx

from realtime_ai_character.logger import get_logger
from realtime_ai_character.utils import Singleton
from realtime_ai_character.singleton import Singleton
from realtime_ai_character.audio.text_to_speech.base import TextToSpeech

logger = get_logger(__name__)
Expand Down
16 changes: 8 additions & 8 deletions realtime_ai_character/character_catalog/catalog_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,9 @@
from langchain.text_splitter import CharacterTextSplitter

from realtime_ai_character.logger import get_logger
from realtime_ai_character.utils import Singleton, Character
from realtime_ai_character.database.chroma import get_chroma
from realtime_ai_character.utils import Character
from realtime_ai_character.singleton import Singleton
from realtime_ai_character.database import get_database
from readerwriterlock import rwlock
from realtime_ai_character.database.connection import get_db
from realtime_ai_character.models.character import Character as CharacterModel
Expand All @@ -24,25 +25,24 @@
class CatalogManager(Singleton):
def __init__(self, overwrite=True):
super().__init__()
self.db = get_chroma()
self.db = get_database()
self.sql_db = next(get_db())
self.sql_load_interval = 30
self.sql_load_lock = rwlock.RWLockFair()

if overwrite:
logger.info('Overwriting existing data in the chroma.')
self.db.delete_collection()
self.db = get_chroma()

self.db.clear_instance()
self.db = get_database()

self.characters = {}
self.author_name_cache = {}
self.load_characters_from_community(overwrite)
self.load_characters(overwrite)
if overwrite:
logger.info('Persisting data in the chroma.')
self.db.persist()
logger.info(
f"Total document load: {self.db._client.get_collection('llm').count()}")
self.run_load_sql_db_thread = True
self.load_sql_db_thread = threading.Thread(target=self.load_sql_db_loop)
self.load_sql_db_thread.daemon = True
Expand Down
19 changes: 19 additions & 0 deletions realtime_ai_character/database/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
import os
from typing import Optional

from realtime_ai_character.database.base import Database

def get_database(db: Optional[str] = None) -> 'Database':
    """Return the vector-database singleton selected by *db*.

    Args:
        db: Engine name, ``'QDRANT'`` or ``'CHROMA'``. When falsy, the
            ``DATABASE_USE`` environment variable is consulted, defaulting
            to ``'QDRANT'``.

    Returns:
        The shared ``Database`` instance for the selected engine.

    Raises:
        NotImplementedError: If *db* names an unknown engine.
    """
    if not db:
        db = os.getenv('DATABASE_USE', 'QDRANT')

    if db == 'QDRANT':
        # Imported lazily so the qdrant dependency is only required when used.
        from realtime_ai_character.database.qdrant import Qdrant
        Qdrant.initialize()
        return Qdrant.get_instance()
    elif db == 'CHROMA':
        # Imported lazily so the chroma dependency is only required when used.
        from realtime_ai_character.database.chroma import Chroma
        Chroma.initialize()
        return Chroma.get_instance()
    else:
        raise NotImplementedError(f'Unknown database engine: {db}')
22 changes: 22 additions & 0 deletions realtime_ai_character/database/base.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,25 @@
from abc import ABC, abstractmethod
from sqlalchemy.ext.declarative import declarative_base

Base = declarative_base()

class Database(ABC):
    """Abstract interface that every vector-store backend must implement."""

    @abstractmethod
    def delete_collection(self):
        """Drop the backing collection and all stored documents."""

    @abstractmethod
    def persist(self):
        """Flush any in-memory state to durable storage (may be a no-op)."""

    @abstractmethod
    def add_documents(self, docs):
        """Index the given documents in the store."""

    @abstractmethod
    def similarity_search(self, query):
        """Return the stored documents most similar to *query*."""

    @abstractmethod
    def generate_context(self, docs, character):
        """Build a prompt context string from *docs* for *character*."""
38 changes: 30 additions & 8 deletions realtime_ai_character/database/chroma.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
import os
from dotenv import load_dotenv
from langchain.vectorstores import Chroma
from langchain.vectorstores import Chroma as LangChainChroma
from langchain.embeddings import OpenAIEmbeddings
from realtime_ai_character.logger import get_logger
from .base import Database
from realtime_ai_character.singleton import Singleton
from realtime_ai_character.utils import Character

load_dotenv()
logger = get_logger(__name__)
Expand All @@ -12,11 +15,30 @@
embedding = OpenAIEmbeddings(openai_api_key=os.getenv("OPENAI_API_KEY"), deployment=os.getenv(
"OPENAI_API_EMBEDDING_DEPLOYMENT_NAME", "text-embedding-ada-002"), chunk_size=1)

class Chroma(Singleton, Database):
def __init__(self):
    """Create (or reopen) the persistent 'llm' Chroma collection.

    The collection lives under ./chroma.db and embeds documents with the
    module-level OpenAI embedding function.
    """
    super().__init__()
    self.db = LangChainChroma(
        collection_name='llm',
        embedding_function=embedding,
        persist_directory='./chroma.db'
    )
    # Report through the module logger instead of a bare print() so the
    # message honors the application's logging configuration.
    logger.info('Chroma collection "llm" holds %s documents',
                self.db._collection.count())

def delete_collection(self):
    """Drop the underlying Chroma collection and everything stored in it."""
    self.db.delete_collection()

def get_chroma():
    """Legacy factory: build a Chroma store over the persistent 'llm' collection."""
    return Chroma(
        collection_name='llm',
        embedding_function=embedding,
        persist_directory='./chroma.db'
    )
def persist(self):
    """Flush the collection to its on-disk persist directory."""
    self.db.persist()

def add_documents(self, docs):
    """Embed and index the given LangChain documents."""
    self.db.add_documents(docs)

def similarity_search(self, query):
    """Return the stored documents most similar to *query*."""
    return self.db.similarity_search(query)

def generate_context(self, docs, character: Character) -> str:
    """Join the page contents of the documents that belong to *character*."""
    relevant = [
        doc for doc in docs
        if doc.metadata['character_name'] == character.name
    ]
    logger.info(f'Found {len(relevant)} documents')
    return '\n'.join(doc.page_content for doc in relevant)
92 changes: 92 additions & 0 deletions realtime_ai_character/database/qdrant.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
import os
from dotenv import load_dotenv
from langchain.embeddings import OpenAIEmbeddings
from realtime_ai_character.logger import get_logger
from .base import Database
from realtime_ai_character.singleton import Singleton
from qdrant_client.http.models import Distance, VectorParams, Batch
from qdrant_client import QdrantClient
from realtime_ai_character.utils import Character
from itertools import islice
import uuid

load_dotenv()
logger = get_logger(__name__)

embedding = OpenAIEmbeddings(openai_api_key=os.getenv("OPENAI_API_KEY"))
if os.getenv('OPENAI_API_TYPE') == 'azure':
embedding = OpenAIEmbeddings(openai_api_key=os.getenv("OPENAI_API_KEY"), deployment=os.getenv(
"OPENAI_API_EMBEDDING_DEPLOYMENT_NAME", "text-embedding-ada-002"), chunk_size=1)

class Qdrant(Singleton, Database):
def __init__(self):
    """Create an in-memory Qdrant client and (re)create the 'llm' collection.

    NOTE(review): ':memory:' means all vectors are lost on process restart
    and persist() is a no-op — confirm this is intended for production.
    """
    super().__init__()
    self.db = QdrantClient(location=":memory:")
    # 1536 is the output dimensionality of text-embedding-ada-002.
    created = self.db.recreate_collection(
        collection_name="llm",
        vectors_config=VectorParams(size=1536, distance=Distance.COSINE),
    )
    # Log through the module logger instead of bare print()/dead list() calls.
    logger.info('Recreated qdrant collection "llm": %s', created)

def delete_collection(self):
    """Drop the 'llm' collection; returns the client's result flag."""
    return self.db.delete_collection(collection_name="llm")

def persist(self):
    """No-op: the in-memory Qdrant client has nothing to flush to disk."""

def add_documents(self, docs):
    """Embed and upsert LangChain documents; returns the point ids added."""
    contents = [doc.page_content for doc in docs]
    payload_meta = [doc.metadata for doc in docs]
    return self._add_texts(contents, payload_meta)

def _add_texts(self, texts, metadatas=None, ids=None, batch_size=64, **kwargs):
    """Upsert texts in batches of *batch_size*; return every point id used."""
    added_ids = []
    for batch_ids, points in self._generate_rest_batches(texts, metadatas, ids, batch_size):
        batch = Batch(
            ids=batch_ids,
            vectors=[point.vector for point in points],
            payloads=[point.payload for point in points],
        )
        self.db.upsert(collection_name="llm", points=batch, **kwargs)
        added_ids.extend(batch_ids)
    return added_ids

def _generate_rest_batches(self, texts, metadatas=None, ids=None, batch_size=64):
    """Yield ``(batch_ids, points)`` tuples of at most *batch_size* items.

    Each text is embedded with the module-level OpenAI embedding and wrapped
    in a qdrant ``PointStruct`` whose payload carries the text and metadata.

    Args:
        texts: Iterable of page-content strings.
        metadatas: Optional iterable of metadata dicts, parallel to *texts*.
        ids: Optional iterable of point ids; fresh uuid hexes otherwise.
        batch_size: Maximum number of points per yielded batch.
    """
    from qdrant_client.http import models as rest

    texts_iterator = iter(texts)
    metadatas_iterator = iter(metadatas or [])
    ids_iterator = iter(ids) if ids else None
    while (batch_texts := list(islice(texts_iterator, batch_size))):
        batch_metadatas = list(islice(metadatas_iterator, batch_size)) or None
        if ids_iterator is not None:
            batch_ids = list(islice(ids_iterator, batch_size))
        else:
            # Generate fresh uuids per batch. The original pre-built one id
            # per element by re-iterating `texts`, which eagerly consumed
            # one-shot iterators/generators and broke batching for them.
            batch_ids = [uuid.uuid4().hex for _ in batch_texts]
        batch_embeddings = [embedding.embed_query(text) for text in batch_texts]
        points = [
            rest.PointStruct(
                id=point_id,
                vector=vector,
                payload=self._build_payloads(text, metadata),
            )
            for point_id, vector, text, metadata in zip(
                batch_ids,
                batch_embeddings,
                batch_texts,
                batch_metadatas or [None] * len(batch_texts),
            )
        ]
        yield batch_ids, points

def _build_payloads(self, text, metadata):
    """Wrap a text/metadata pair into a qdrant payload dict.

    Raises:
        ValueError: If *text* is None.
    """
    if text is None:
        raise ValueError("Text is None. Please ensure all texts are valid.")
    return {
        "page_content": text,
        "metadata": metadata
    }

def similarity_search(self, query):
    """Embed *query* and return the nearest points from the 'llm' collection."""
    return self.db.search(
        collection_name="llm",
        query_vector=embedding.embed_query(query),
    )

def generate_context(self, docs, character: Character) -> str:
    """Concatenate the page contents of the points that belong to *character*."""
    matching = [
        point for point in docs
        if point.payload["metadata"]['character_name'] == character.name
    ]
    logger.info(f'Found {len(matching)} documents')
    return '\n'.join(point.payload["page_content"] for point in matching)
9 changes: 4 additions & 5 deletions realtime_ai_character/llm/anthropic_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from langchain.chat_models import ChatAnthropic
from langchain.schema import BaseMessage, HumanMessage

from realtime_ai_character.database.chroma import get_chroma
from realtime_ai_character.database import get_database
from realtime_ai_character.llm.base import AsyncCallbackAudioHandler, \
AsyncCallbackTextHandler, LLM, QuivrAgent, SearchAgent
from realtime_ai_character.logger import get_logger
Expand All @@ -25,7 +25,7 @@ def __init__(self, model):
"temperature": 0.5,
"streaming": True
}
self.db = get_chroma()
self.db = get_database()
self.search_agent = SearchAgent()
self.quivr_agent = QuivrAgent()

Expand Down Expand Up @@ -71,10 +71,9 @@ async def achat(self,

def _generate_context(self, query, character: Character) -> str:
docs = self.db.similarity_search(query)
docs = [d for d in docs if d.metadata['character_name'] == character.name]
logger.info(f'Found {len(docs)} documents')

context = self.db.generate_context(docs, character)

context = '\n'.join([d.page_content for d in docs])
return context

def _generate_memory_context(self, user_id: str, query: str) -> str:
Expand Down
9 changes: 4 additions & 5 deletions realtime_ai_character/llm/anyscale_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from langchain.chat_models import ChatOpenAI
from langchain.schema import BaseMessage, HumanMessage

from realtime_ai_character.database.chroma import get_chroma
from realtime_ai_character.database import get_database
from realtime_ai_character.llm.base import AsyncCallbackAudioHandler, AsyncCallbackTextHandler, \
LLM, SearchAgent
from realtime_ai_character.logger import get_logger
Expand All @@ -28,7 +28,7 @@ def __init__(self, model):
"temperature": 0.5,
"streaming": True
}
self.db = get_chroma()
self.db = get_database()
self.search_agent = None
self.search_agent = SearchAgent()

Expand Down Expand Up @@ -68,10 +68,9 @@ async def achat(self,

def _generate_context(self, query, character: Character) -> str:
docs = self.db.similarity_search(query)
docs = [d for d in docs if d.metadata['character_name'] == character.name]
logger.info(f'Found {len(docs)} documents')

context = self.db.generate_context(docs, character)

context = '\n'.join([d.page_content for d in docs])
return context

def _generate_memory_context(self, user_id: str, query: str) -> str:
Expand Down
9 changes: 4 additions & 5 deletions realtime_ai_character/llm/local_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from langchain.chat_models import ChatOpenAI
from langchain.schema import BaseMessage, HumanMessage

from realtime_ai_character.database.chroma import get_chroma
from realtime_ai_character.database import get_database
from realtime_ai_character.llm.base import (
AsyncCallbackAudioHandler,
AsyncCallbackTextHandler,
Expand All @@ -29,7 +29,7 @@ def __init__(self, url):

)
self.config = {"model": "Local LLM", "temperature": 0.5, "streaming": True}
self.db = get_chroma()
self.db = get_database()
self.search_agent = None
self.search_agent = SearchAgent()

Expand Down Expand Up @@ -73,8 +73,7 @@ async def achat(

def _generate_context(self, query, character: Character) -> str:
docs = self.db.similarity_search(query)
docs = [d for d in docs if d.metadata["character_name"] == character.name]
logger.info(f"Found {len(docs)} documents")

context = self.db.generate_context(docs, character)

context = "\n".join([d.page_content for d in docs])
return context
Loading