First version of the Object API for vector search (#187)
First implementation of the object-level API for vector search. It uses client-provided object reader and embedding functions to implement a high-level object API.
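As a rough sketch of that flow (the reader interface here is hypothetical; only the embedding interface appears in the diffs below), a driver could read object batches and embed them:

# Hedged sketch only: `reader` and its iteration protocol are hypothetical
# stand-ins for the client-provided object reader; the load()/embed() calls
# match the ObjectEmbedding interface added in this commit.
import numpy as np

def embed_all(reader, embedding):
    embedding.load()  # called once per worker before embedding
    batches = [embedding.embed(objects, metadata) for objects, metadata in reader]
    return np.concatenate(batches, axis=0)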
1 parent ebcf35b, commit 5b1fa3b
Showing 19 changed files with 2,755 additions and 1 deletion.
436 changes: 436 additions & 0 deletions
apis/python/examples/object_api/soma_cell_similarity_search.ipynb
(Large diff not rendered by default.)
13 changes: 13 additions & 0 deletions
apis/python/src/tiledb/vector_search/embeddings/__init__.py
@@ -0,0 +1,13 @@
from .image_resnetv2_embedding import ImageResNetV2Embedding
from .object_embedding import ObjectEmbedding
from .random_embedding import RandomEmbedding
from .sentence_transformers_embedding import SentenceTransformersEmbedding
from .soma_geneptw_embedding import SomaGenePTwEmbedding

__all__ = [
    "ObjectEmbedding",
    "SomaGenePTwEmbedding",
    "ImageResNetV2Embedding",
    "RandomEmbedding",
    "SentenceTransformersEmbedding",
]
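With these exports in place, implementations can be imported directly from the embeddings package, for example:

from tiledb.vector_search.embeddings import RandomEmbedding, SentenceTransformersEmbedding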
47 changes: 47 additions & 0 deletions
apis/python/src/tiledb/vector_search/embeddings/image_resnetv2_embedding.py
@@ -0,0 +1,47 @@
from typing import Dict, OrderedDict

import numpy as np

# from tiledb.vector_search.embeddings import ObjectEmbedding

EMBED_DIM = 2048


# class ImageResNetV2Embedding(ObjectEmbedding):
class ImageResNetV2Embedding:
    def __init__(
        self,
    ):
        self.model = None

    def init_kwargs(self) -> Dict:
        return {}

    def dimensions(self) -> int:
        return EMBED_DIM

    def vector_type(self) -> np.dtype:
        return np.float32

    def load(self) -> None:
        import tensorflow as tf

        self.model = tf.keras.applications.ResNet50V2(include_top=False)

    def embed(self, objects: OrderedDict, metadata: OrderedDict) -> np.ndarray:
        from efficientnet.preprocessing import center_crop_and_resize
        from tensorflow.keras.applications.resnet_v2 import preprocess_input

        size = len(objects["image"])
        crop_size = 224
        images = np.zeros((size, crop_size, crop_size, 3), dtype=np.uint8)
        for image_id in range(size):
            images[image_id] = center_crop_and_resize(
                np.reshape(objects["image"][image_id], objects["shape"][image_id]),
                crop_size,
            ).astype(np.uint8)
        maps = self.model.predict(preprocess_input(images))
        # If the model already emits one vector per image, drop the singleton
        # spatial axes; otherwise global-average-pool the feature maps.
        # (`size` replaces the original `len(objects)`, which counted the dict's
        # attributes rather than the number of images.)
        if np.prod(maps.shape) == maps.shape[-1] * size:
            return np.squeeze(maps)
        else:
            return maps.mean(axis=1).mean(axis=1)
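A hedged usage sketch of this class (assumes the tensorflow and efficientnet packages are installed; the "image" and "shape" field names follow the embed() code above). Images arrive flattened, with their original shapes alongside:

from collections import OrderedDict

import numpy as np

embedding = ImageResNetV2Embedding()
embedding.load()  # initializes ResNet50V2 (downloads weights on first use)

shape = (300, 400, 3)
objects = OrderedDict(
    image=[np.random.randint(0, 256, int(np.prod(shape)), dtype=np.uint8) for _ in range(4)],
    shape=[shape] * 4,
)
vectors = embedding.embed(objects, OrderedDict())
print(vectors.shape)  # expected (4, 2048)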
58 changes: 58 additions & 0 deletions
apis/python/src/tiledb/vector_search/embeddings/object_embedding.py
@@ -0,0 +1,58 @@
from abc import ABC
from abc import abstractmethod
from typing import Dict, OrderedDict

import numpy as np


class ObjectEmbedding(ABC):
    """
    Abstract class that can be used to create embeddings for objects of a specific format.
    """

    @abstractmethod
    def init_kwargs(self) -> Dict:
        """
        Returns a dictionary of kwargs that can be used to re-initialize the ObjectEmbedding.
        This is used to serialize the ObjectEmbedding and pass it as an argument to UDF tasks.
        """
        raise NotImplementedError

    @abstractmethod
    def dimensions(self) -> int:
        """
        Returns the number of dimensions of the embedding vectors.
        """
        raise NotImplementedError

    @abstractmethod
    def vector_type(self) -> np.dtype:
        """
        Returns the datatype of the embedding vectors.
        """
        raise NotImplementedError

    @abstractmethod
    def load(self) -> None:
        """
        Loads the model so that it is ready to embed objects.
        This method is called once per worker to avoid loading the model multiple times.
        """
        raise NotImplementedError

    @abstractmethod
    def embed(self, objects: OrderedDict, metadata: OrderedDict) -> np.ndarray:
        """
        Creates embedding vectors for objects and returns them as a numpy array.
        There is no enforced restriction on the object format; ObjectReaders and
        ObjectEmbeddings should use compatible object and metadata formats.

        Parameters
        ----------
        objects: OrderedDict
            An OrderedDict containing the object data, with a structure similar to TileDB-Py read results.
        metadata: OrderedDict
            An OrderedDict containing the object metadata, with a structure similar to TileDB-Py read results.
        """
        raise NotImplementedError
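For illustration, a minimal concrete subclass might look like the following. The "text" field name is an assumption about the paired reader's output, and the toy character-frequency vectors are purely illustrative:

from typing import Dict, OrderedDict

import numpy as np

class CharFrequencyEmbedding(ObjectEmbedding):
    def __init__(self, dimensions: int = 64):
        self.dim_num = dimensions

    def init_kwargs(self) -> Dict:
        return {"dimensions": self.dim_num}

    def dimensions(self) -> int:
        return self.dim_num

    def vector_type(self) -> np.dtype:
        return np.float32

    def load(self) -> None:
        pass  # nothing to load for this toy model

    def embed(self, objects: OrderedDict, metadata: OrderedDict) -> np.ndarray:
        # One row per text object; bucket characters by code point.
        vectors = np.zeros((len(objects["text"]), self.dim_num), dtype=np.float32)
        for i, text in enumerate(objects["text"]):
            for ch in text:
                vectors[i, ord(ch) % self.dim_num] += 1.0
        return vectors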
31 changes: 31 additions & 0 deletions
apis/python/src/tiledb/vector_search/embeddings/random_embedding.py
@@ -0,0 +1,31 @@
from typing import Dict, OrderedDict

import numpy as np

# from tiledb.vector_search.embeddings import ObjectEmbedding

EMBED_DIM = 2048


# class RandomEmbedding(ObjectEmbedding):
class RandomEmbedding:
    def __init__(
        self,
    ):
        self.model = None

    def init_kwargs(self) -> Dict:
        return {}

    def dimensions(self) -> int:
        return EMBED_DIM

    def vector_type(self) -> np.dtype:
        return np.float32

    def load(self) -> None:
        pass

    def embed(self, objects: OrderedDict, metadata: OrderedDict) -> np.ndarray:
        # Batch size is the length of the first attribute in the read results.
        size = len(objects[list(objects.keys())[0]])
        return np.random.rand(size, EMBED_DIM).astype(self.vector_type())
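Because RandomEmbedding needs only numpy, it doubles as a quick smoke test of the interface:

from collections import OrderedDict

embedding = RandomEmbedding()
embedding.load()  # no-op
vectors = embedding.embed(OrderedDict(text=["a", "b", "c"]), OrderedDict())
print(vectors.shape, vectors.dtype)  # (3, 2048) float32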
63 changes: 63 additions & 0 deletions
apis/python/src/tiledb/vector_search/embeddings/sentence_transformers_embedding.py
@@ -0,0 +1,63 @@
from typing import Dict, Optional, OrderedDict

import numpy as np

# from tiledb.vector_search.embeddings import ObjectEmbedding


# class SentenceTransformersEmbedding(ObjectEmbedding):
class SentenceTransformersEmbedding:
    """
    Hugging Face SentenceTransformer model that can be used to map sentences / text to embeddings.

    :param model_name_or_path: If it is a filepath on disk, the model is loaded from that path. If it is not a path,
        a pre-trained SentenceTransformer model with that name is downloaded; if that fails, a model is constructed
        from the Hugging Face Hub with that name.
    :param device: Device (like "cuda", "cpu", "mps") that should be used for computation. If None, checks if a GPU
        can be used.
    :param cache_folder: Path to store models. Can also be set by the SENTENCE_TRANSFORMERS_HOME environment variable.
    :param dimensions: Number of dimensions of the embedding. If -1, the model is loaded eagerly to read its
        embedding dimension.
    """

    def __init__(
        self,
        model_name_or_path: Optional[str] = None,
        device: Optional[str] = None,
        cache_folder: Optional[str] = None,
        dimensions: Optional[int] = -1,
    ):
        self.model_name_or_path = model_name_or_path
        self.device = device
        self.cache_folder = cache_folder
        self.dim_num = dimensions
        self.model = None
        if self.dim_num == -1:
            self.load()
            self.dim_num = self.model.get_sentence_embedding_dimension()

    def init_kwargs(self) -> Dict:
        return {
            "model_name_or_path": self.model_name_or_path,
            "device": self.device,
            "cache_folder": self.cache_folder,
            "dimensions": self.dim_num,
        }

    def dimensions(self) -> int:
        return self.dim_num

    def vector_type(self) -> np.dtype:
        return np.float32

    def load(self) -> None:
        from sentence_transformers import SentenceTransformer

        self.model = SentenceTransformer(
            model_name_or_path=self.model_name_or_path,
            device=self.device,
            cache_folder=self.cache_folder,
        )

    def embed(self, objects: OrderedDict, metadata: OrderedDict) -> np.ndarray:
        return self.model.encode(objects["text"], normalize_embeddings=True)
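A hedged usage sketch (assumes the sentence-transformers package is installed; "all-MiniLM-L6-v2" is one public model name, with 384-dimensional embeddings):

from collections import OrderedDict

embedding = SentenceTransformersEmbedding(model_name_or_path="all-MiniLM-L6-v2")
# dimensions defaulted to -1, so the model was loaded eagerly in __init__
print(embedding.dimensions())  # 384 for this model
vectors = embedding.embed(
    OrderedDict(text=["vector search", "similarity search"]), OrderedDict()
)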
74 changes: 74 additions & 0 deletions
apis/python/src/tiledb/vector_search/embeddings/soma_geneptw_embedding.py
@@ -0,0 +1,74 @@
from typing import Any, Dict, Mapping, Optional, OrderedDict

import numpy as np

# from tiledb.vector_search.embeddings import ObjectEmbedding

EMBED_DIM = 1536  # embedding dim from GPT-3.5


# class SomaGenePTwEmbedding(ObjectEmbedding):
class SomaGenePTwEmbedding:
    def __init__(
        self,
        gene_embeddings_uri: str,
        soma_uri: str,
        config: Optional[Mapping[str, Any]] = None,
    ):
        self.gene_embeddings_uri = gene_embeddings_uri
        self.soma_uri = soma_uri
        self.config = config
        self.gene_embedding = None
        self.gene_names = None

    def init_kwargs(self) -> Dict:
        return {
            "gene_embeddings_uri": self.gene_embeddings_uri,
            "soma_uri": self.soma_uri,
            "config": self.config,
        }

    def dimensions(self) -> int:
        return EMBED_DIM

    def vector_type(self) -> np.dtype:
        return np.float32

    def load(self) -> None:
        import numpy as np
        import tiledbsoma

        import tiledb

        # Read the per-gene GenePT embeddings into a name -> vector mapping.
        gene_pt_embeddings = {}
        with tiledb.open(
            self.gene_embeddings_uri, "r", config=self.config
        ) as gene_pt_array:
            gene_pt = gene_pt_array[:]
            i = 0
            for gene in np.array(gene_pt["genes"], dtype=str):
                gene_pt_embeddings[str(gene)] = gene_pt["embeddings"][i]
                i += 1

        # Read the gene names of the SOMA experiment, in var order.
        context = tiledbsoma.SOMATileDBContext(tiledb_ctx=tiledb.Ctx(self.config))
        experiment = tiledbsoma.Experiment.open(self.soma_uri, "r", context=context)
        self.gene_names = (
            experiment.ms["RNA"]
            .var.read()
            .concat()
            .to_pandas()["feature_name"]
            .to_numpy()
        )

        # Build a gene-by-EMBED_DIM matrix aligned to the experiment's gene
        # order; genes without a GenePT embedding keep zero vectors.
        self.gene_embedding = np.zeros(shape=(len(self.gene_names), EMBED_DIM))
        for i, gene in enumerate(self.gene_names):
            if gene in gene_pt_embeddings:
                self.gene_embedding[i, :] = gene_pt_embeddings[gene]

    def embed(self, objects: OrderedDict, metadata: OrderedDict) -> np.ndarray:
        import numpy as np

        # Average the GenePT vectors of each cell's genes, weighted by expression.
        return np.array(
            np.dot(objects["data"], self.gene_embedding) / len(self.gene_names),
            dtype=np.float32,
        )
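To make the embed() math concrete, here it is on toy data: each cell vector is the expression-weighted sum of per-gene GenePT vectors, divided by the total gene count:

import numpy as np

n_cells, n_genes, dim = 2, 5, 1536
X = np.random.rand(n_cells, n_genes)   # cell-by-gene expression ("data")
G = np.random.rand(n_genes, dim)       # aligned gene-by-dim GenePT matrix
cell_vectors = np.dot(X, G) / n_genes  # shape (2, 1536), as in embed()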