Skip to content

Support for MongoDB Vector Search #147

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 7 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions engine/clients/client_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
ElasticUploader,
)
from engine.clients.milvus import MilvusConfigurator, MilvusSearcher, MilvusUploader
from engine.clients.mongodb import MongoConfigurator, MongoSearcher, MongoUploader
from engine.clients.opensearch import (
OpenSearchConfigurator,
OpenSearchSearcher,
Expand All @@ -39,6 +40,7 @@
"opensearch": OpenSearchConfigurator,
"redis": RedisConfigurator,
"pgvector": PgVectorConfigurator,
"mongodb": MongoConfigurator,
}

ENGINE_UPLOADERS = {
Expand All @@ -49,6 +51,7 @@
"opensearch": OpenSearchUploader,
"redis": RedisUploader,
"pgvector": PgVectorUploader,
"mongodb": MongoUploader,
}

ENGINE_SEARCHERS = {
Expand All @@ -59,6 +62,7 @@
"opensearch": OpenSearchSearcher,
"redis": RedisSearcher,
"pgvector": PgVectorSearcher,
"mongodb": MongoSearcher,
}


Expand Down
5 changes: 5 additions & 0 deletions engine/clients/mongodb/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from engine.clients.mongodb.configure import MongoConfigurator
from engine.clients.mongodb.search import MongoSearcher
from engine.clients.mongodb.upload import MongoUploader

__all__ = ["MongoConfigurator", "MongoSearcher", "MongoUploader"]
33 changes: 33 additions & 0 deletions engine/clients/mongodb/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import os

from pymongo.mongo_client import MongoClient

# Connection / naming settings, all overridable via environment variables.
MONGO_PORT = int(os.getenv("MONGO_PORT", 27017))
MONGO_AUTH = os.getenv("MONGO_AUTH", "performance")  # password for the URI
MONGO_USER = os.getenv("MONGO_USER", "performance")
MONGO_READ_PREFERENCE = os.getenv("MONGO_READ_PREFERENCE", "primary")
# BUG FIX: this previously read the MONGO_READ_PREFERENCE env var by mistake,
# so the write concern could never be configured independently.
MONGO_WRITE_CONCERN = os.getenv("MONGO_WRITE_CONCERN", "1")
EMBEDDING_FIELD_NAME = os.getenv("EMBEDDING_FIELD_NAME", "embedding")
EMBEDDING_DISTANCE = os.getenv("EMBEDDING_DISTANCE", None)
ATLAS_DB_NAME = os.getenv("ATLAS_DB_NAME", "vector-db")
ATLAS_COLLECTION_NAME = os.getenv("ATLAS_COLLECTION_NAME", "vector-collection")
ATLAS_VECTOR_SEARCH_INDEX_NAME = os.getenv(
    "ATLAS_VECTOR_SEARCH_INDEX_NAME", "vector-index"
)

# Per-query timeout in milliseconds (90 seconds).
MONGO_QUERY_TIMEOUT = int(os.getenv("MONGO_QUERY_TIMEOUT", 90 * 1000))


def get_mongo_client(host, connection_params):
    """Build a MongoClient for *host* and verify connectivity with a ping.

    Credentials and connection options come from the module-level settings
    (MONGO_USER / MONGO_AUTH / MONGO_WRITE_CONCERN / MONGO_READ_PREFERENCE).
    A failed ping is only reported; the client is returned regardless.
    NOTE: ``connection_params`` is currently unused.
    """
    uri = (
        f"mongodb+srv://{MONGO_USER}:{MONGO_AUTH}@{host}"
        f"/?retryWrites=true&w={MONGO_WRITE_CONCERN}"
        f"&appName=vector-db-benchmark"
        f"&readPreference={MONGO_READ_PREFERENCE}"
    )
    mongo_client = MongoClient(uri)
    try:
        # Cheap round-trip to confirm the deployment is reachable.
        mongo_client.admin.command("ping")
    except Exception as e:
        print(f"Failed pinging the deployment... error {e}")
    return mongo_client
115 changes: 115 additions & 0 deletions engine/clients/mongodb/configure.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
import time

from benchmark.dataset import Dataset
from engine.base_client.configure import BaseConfigurator
from engine.base_client.distances import Distance
from engine.clients.mongodb.config import (
ATLAS_COLLECTION_NAME,
ATLAS_DB_NAME,
ATLAS_VECTOR_SEARCH_INDEX_NAME,
EMBEDDING_FIELD_NAME,
get_mongo_client,
)


class MongoConfigurator(BaseConfigurator):
    """Prepares a MongoDB Atlas collection and its vector search index.

    Atlas drops indexes and collections asynchronously, so ``clean`` polls
    (with a 10 s pause per attempt) until the server no longer reports them.
    """

    # Benchmark distance enum -> Atlas "similarity" names.
    DISTANCE_MAPPING = {
        Distance.L2: "euclidean",
        Distance.COSINE: "cosine",
        Distance.DOT: "dotProduct",
    }

    def __init__(self, host, collection_params: dict, connection_params: dict):
        super().__init__(host, collection_params, connection_params)
        self.client = get_mongo_client(host, connection_params)
        self.db = self.client[ATLAS_DB_NAME]
        self.collection = self.db[ATLAS_COLLECTION_NAME]

    def clean(self):
        """Drop the search index, then the collection, waiting until both are gone."""
        index_exists = True
        try_count = 1

        while index_exists:
            print(
                f"Ensuring the search index named {ATLAS_VECTOR_SEARCH_INDEX_NAME} does not exist..."
            )
            try:
                self.collection.drop_search_index(ATLAS_VECTOR_SEARCH_INDEX_NAME)
            except Exception as e:
                # "IndexNotFound" is the desired end state; anything else is
                # logged and the drop retried on the next iteration.
                if "IndexNotFound" not in str(e):
                    print(e)

            stats = self.db.command("collstats", self.collection.name)
            index_details = stats.get("indexDetails", {})
            index_exists = False
            for index_name, details in index_details.items():
                if ATLAS_VECTOR_SEARCH_INDEX_NAME in index_name:
                    print(f"Still detected index. Stats: {details}")
                    index_exists = True
            try_count = try_count + 1
            # sleep for 10 seconds to avoid invalid state
            time.sleep(10)

        print(
            f"Finished ensuring the search index does not exist... after {try_count} tries"
        )

        print("Ensuring the collection does not exist...")

        collection_exists = True
        while collection_exists:
            try_count = try_count + 1
            try:
                self.db.drop_collection(ATLAS_COLLECTION_NAME)
            except Exception as e:
                if "not exist" not in str(e):
                    print(e)
            collection_exists = False
            collection_names = self.db.list_collection_names()
            for collection_name in collection_names:
                if ATLAS_COLLECTION_NAME in collection_name:
                    print(
                        f"Still detected collection named {ATLAS_COLLECTION_NAME}. Trying again..."
                    )
                    collection_exists = True
                    # sleep for 10 seconds to avoid invalid state
                    time.sleep(10)
        print(
            f"Finished ensuring the collection does not exist... after {try_count} tries"
        )

    def recreate(self, dataset: Dataset, collection_params):
        """Create the collection and a knnVector search index sized for *dataset*."""
        # Explicitly create a collection in a MongoDB database.
        print(f"Explicitly creating a collection named {ATLAS_COLLECTION_NAME}...")
        self.db.create_collection(ATLAS_COLLECTION_NAME)
        self.collection = self.db[ATLAS_COLLECTION_NAME]
        print(
            f"Creating the search index with vector mapping named {ATLAS_VECTOR_SEARCH_INDEX_NAME}..."
        )

        self.collection.create_search_index(
            {
                "definition": {
                    "mappings": {
                        "dynamic": True,
                        "fields": {
                            EMBEDDING_FIELD_NAME: {
                                "dimensions": dataset.config.vector_size,
                                "similarity": self.DISTANCE_MAPPING[
                                    dataset.config.distance
                                ],
                                "type": "knnVector",
                            }
                        },
                    }
                },
                "name": ATLAS_VECTOR_SEARCH_INDEX_NAME,
            }
        )
60 changes: 60 additions & 0 deletions engine/clients/mongodb/search.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
import copy
from typing import List, Tuple

from engine.base_client.distances import Distance
from engine.base_client.search import BaseSearcher
from engine.clients.mongodb.config import (
ATLAS_COLLECTION_NAME,
ATLAS_DB_NAME,
ATLAS_VECTOR_SEARCH_INDEX_NAME,
EMBEDDING_FIELD_NAME,
get_mongo_client,
)


class MongoSearcher(BaseSearcher):
    """Runs $vectorSearch aggregation queries against the Atlas collection."""

    search_params = {}
    client = None

    @classmethod
    def init_client(cls, host, distance, connection_params: dict, search_params: dict):
        """Connect to Atlas and keep a private copy of the search parameters."""
        cls.distance = distance
        cls.client = get_mongo_client(host, connection_params)
        cls.search_params = copy.deepcopy(search_params)

    @classmethod
    def search_one(cls, vector, meta_conditions, top) -> List[Tuple[int, float]]:
        """Return the ``top`` nearest neighbours of ``vector`` as (id, score) pairs.

        ``meta_conditions`` is currently ignored (no filtered-search support).
        """
        # BUG FIX: use .get() instead of .pop() — pop() removed the setting on
        # the first call, so every subsequent query silently fell back to 100.
        num_candidates = cls.search_params.get("numCandidates", 100)

        pipeline = [
            {
                "$vectorSearch": {
                    "index": ATLAS_VECTOR_SEARCH_INDEX_NAME,
                    "path": EMBEDDING_FIELD_NAME,
                    "queryVector": vector,
                    "numCandidates": num_candidates,
                    "limit": top,
                }
            },
            {
                "$project": {
                    "score": {"$meta": "vectorSearchScore"},
                }
            },
        ]

        results = cls.client[ATLAS_DB_NAME][ATLAS_COLLECTION_NAME].aggregate(pipeline)
        search_result = []
        for result in results:
            score = float(result["score"])
            # Atlas normalizes cosine and dotProduct similarities as
            #   score = (1 + sim(v1, v2)) / 2
            # so invert that to recover the raw similarity.
            # BUG FIX: the reversal belongs to COSINE and DOT (per the Atlas
            # scoring rules); it was previously applied to COSINE and L2,
            # corrupting euclidean scores and leaving dotProduct normalized.
            if cls.distance in (Distance.COSINE, Distance.DOT):
                score = (2.0 * score) - 1
            search_result.append((int(result["_id"]), score))

        return search_result
72 changes: 72 additions & 0 deletions engine/clients/mongodb/upload.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
from typing import List, Optional

from pymongo import InsertOne

from engine.base_client.upload import BaseUploader
from engine.clients.mongodb.config import (
ATLAS_COLLECTION_NAME,
ATLAS_DB_NAME,
ATLAS_VECTOR_SEARCH_INDEX_NAME,
EMBEDDING_FIELD_NAME,
get_mongo_client,
)


class MongoUploader(BaseUploader):
    """Bulk-inserts vectors and waits for the Atlas search index to build."""

    client = None
    upload_params = {}

    @classmethod
    def init_client(cls, host, distance, connection_params, upload_params):
        cls.client = get_mongo_client(host, connection_params)
        cls.upload_params = upload_params
        # Getting the database instance
        cls.db = cls.client[ATLAS_DB_NAME]
        # Handle to the target collection (created server-side on first write).
        cls.collection = cls.db[ATLAS_COLLECTION_NAME]

    @classmethod
    def upload_batch(
        cls, ids: List[int], vectors: List[list], metadata: Optional[List[dict]]
    ):
        """Insert one batch of (id, vector) documents; ``metadata`` is ignored."""
        requests = [
            InsertOne({"_id": doc_id, EMBEDDING_FIELD_NAME: embedding})
            for doc_id, embedding in zip(ids, vectors)
        ]
        cls.collection.bulk_write(requests)

    @classmethod
    def post_upload(cls, _distance):
        """Block until the vector search index reports it is queryable.

        NOTE(review): Atlas search index statuses are PENDING / BUILDING /
        READY / FAILED, so the "ACTIVE" comparison likely never matches and
        the loop effectively exits on ``queryable`` alone — confirm against
        the Atlas API before relying on the status check.
        """
        import time  # local import: only needed for this polling loop

        print("waiting for search index status to be Active")

        queryable = False
        status = "n/a"
        try_count = 1
        while status != "ACTIVE" and queryable is False:
            print(f"checking search indices. try: {try_count}...")
            for search_index in cls.collection.list_search_indexes():
                if search_index["name"] == ATLAS_VECTOR_SEARCH_INDEX_NAME:
                    print(
                        f"detected search index named {ATLAS_VECTOR_SEARCH_INDEX_NAME}. checking status..."
                    )
                    print(search_index)
                    queryable = search_index["queryable"]
                    status = search_index["status"]
            try_count = try_count + 1
            if status != "ACTIVE" and queryable is False:
                # Pause between polls instead of hammering the server while
                # the index builds (the original loop had no back-off at all).
                time.sleep(5)
        print(
            f"Finished waiting for search index status={status} and queryable={queryable}."
        )
        return {}

    @classmethod
    def get_memory_usage(cls):
        # Memory metrics are not collected for MongoDB.
        return {}
15 changes: 15 additions & 0 deletions experiments/configurations/mongodb-single-node.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
[
{
"name": "mongodb-default",
"engine": "mongodb",
"connection_params": {
"request_timeout": 10000
},
"collection_params": { "index_options": { } },
"search_params": [
{ "parallel": 1, "numCandidates": 128 }, { "parallel": 1, "numCandidates": 256 }, { "parallel": 1, "numCandidates": 512 }, { "parallel": 1, "numCandidates": 1024 }, { "parallel": 1, "numCandidates": 1536 }, { "parallel": 1, "numCandidates": 2048 },
{ "parallel": 100, "numCandidates": 128 }, { "parallel": 100, "numCandidates": 256 }, { "parallel": 100, "numCandidates": 512 }, { "parallel": 100, "numCandidates": 1536 }, { "parallel": 100, "numCandidates": 2048 }
],
"upload_params": { "parallel": 16 }
}
]