Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding input image new #13

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 26 additions & 1 deletion src/fundus_murag/assistant/gemini_fundus_assistant.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from loguru import logger
from vertexai.generative_models import (
ChatSession,
Content,
GenerationConfig,
GenerationResponse,
GenerativeModel,
Expand Down Expand Up @@ -115,11 +116,35 @@ def get_chat_messages(self) -> list[ChatMessage]:

return messages

def append_chat_message_to_history(self, role: str, content: str) -> None:
    """
    Append a predefined chat message to the history without waiting for a model's response.

    Args:
        role: The role of the message; must be 'user' or 'model' (the only
            roles accepted by Vertex AI ``Content`` objects in a chat history).
        content: The text content of the message to append.

    Raises:
        ValueError: If ``role`` is not 'user' or 'model'.
    """
    # Validate early so an invalid role never reaches the chat history,
    # where it would otherwise only fail on the next model call.
    if role not in ("user", "model"):
        raise ValueError(f"role must be 'user' or 'model', got {role!r}")

    # Lazily initialize the chat session: reset clears any stale state,
    # then a fresh session is started on the underlying model.
    if self._chat_session is None:
        self.reset_chat_session()
        self._chat_session = self._model.start_chat()

    # Wrap the raw text in a Part, then in a Content object carrying the role.
    part = Part.from_text(content)
    content_object = Content(parts=[part], role=role)

    # Manually append the Content object to the chat history; no request is
    # sent to the model here.
    self._chat_session.history.append(content_object)

    logger.info(f"Appended message to history: {role} - {content}")

def load_model(self, model_name: str, use_tools: bool) -> GenerativeModel:
model_name = model_name.lower()
if "/" in model_name:
model_name = model_name.split("/")[-1]

# GenerativeModel is imported from vertexai.generative_models, it is part of Google's Vertex AI SDK.
# This SDK is designed to interact with Google Cloud's AI services, not OpenAI's models
model = GenerativeModel(
model_name=model_name,
generation_config=GEMINI_GENERATION_CONFIG,
Expand Down
126 changes: 110 additions & 16 deletions src/fundus_murag/data/routers/search.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import base64

import numpy as np
from fastapi import APIRouter, HTTPException
from fastapi import APIRouter, File, HTTPException, UploadFile

from fundus_murag.data.dto import (
EmbeddingQuery,
Expand All @@ -11,22 +13,34 @@
RecordLexicalSearchQuery,
)
from fundus_murag.data.vector_db import VectorDB
from fundus_murag.ml.client import FundusMLClient

router = APIRouter(prefix="/search", tags=["search"])

vdb = VectorDB()

ml_client = FundusMLClient()


@router.post(
"/records/similarity_search",
"/records/image_similarity_search",
# "/records/image_similarity_search",
response_model=list[FundusRecordSemanticSearchResult],
summary="Perform a similarity search of records based on an image embedding.",
tags=["search"],
)
def fundus_record_image_similarity_search(query: EmbeddingQuery):
"""
Perform a similarity search of records based on an image embedding.

Args:
query (EmbeddingQuery): The embedding query parameters.

Returns:
List[FundusRecordSemanticSearchResult]: A list of search results.
"""
try:
query_embedding = np.array(query.query_embedding)
query_embedding = np.array(query.query_embedding).tolist()
return vdb._fundus_record_image_similarity_search(
query_embedding=query_embedding,
search_in_collections=query.search_in_collections,
Expand All @@ -48,9 +62,14 @@ def fundus_record_image_similarity_search(query: EmbeddingQuery):
def fundus_record_title_similarity_search(query: EmbeddingQuery):
"""
Perform a similarity search of records based on a title embedding.

Args:
query (EmbeddingQuery): The embedding query parameters.

Returns:
list[FundusRecordSemanticSearchResult]: A list of search results.
"""
try:
# query_embedding = np.array(query.query_embedding)
query_embedding = list(query.query_embedding)
return vdb._fundus_record_title_similarity_search(
query_embedding=query_embedding,
Expand Down Expand Up @@ -89,12 +108,21 @@ def fundus_record_title_lexical_search(query: RecordLexicalSearchQuery):


@router.post(
"/collections/lexical_search",
"/collections/title_lexical_search",
response_model=list[FundusCollection],
summary="Perform a lexical search on `FundusCollection`s using a query string.",
tags=["search"],
)
def fundus_collection_lexical_search(query: LexicalSearchQuery):
"""
Perform a lexical search on `FundusCollection`s based on title.

Args:
query (LexicalSearchQuery): The search parameters.

Returns:
List[FundusCollection]: A list of matching collections.
"""
try:
return vdb._fundus_collection_lexical_search(
query=query.query,
Expand All @@ -110,16 +138,24 @@ def fundus_collection_lexical_search(query: LexicalSearchQuery):


@router.post(
"/collections/description_similarity_search",
response_model=list[FundusCollectionSemanticSearchResult], # Updated response model
summary="Perform a semantic similarity search on `FundusCollection`s based on their title description.",
"/collections/title_similarity_search",
response_model=list[FundusCollectionSemanticSearchResult],
summary="Perform a semantic similarity search on `FundusCollection`s based on their title.",
tags=["search"],
)
def fundus_collection_description_similarity_search(query: EmbeddingQuery):
def fundus_collection_title_similarity_search(query: EmbeddingQuery):
"""
Perform a semantic similarity search on `FundusCollection`s based on their title.

Args:
query (EmbeddingQuery): The embedding query parameters.

Returns:
List[FundusCollectionSemanticSearchResult]: A list of search results.
"""
try:
# query_embedding = np.array(query.query_embedding)
query_embedding = list(query.query_embedding)
return vdb.fundus_collection_description_similarity_search(
return vdb.fundus_collection_title_similarity_search(
query_embedding=query_embedding,
top_k=query.top_k,
)
Expand All @@ -128,18 +164,76 @@ def fundus_collection_description_similarity_search(query: EmbeddingQuery):


@router.post(
    "/collections/description_similarity_search",
    response_model=list[FundusCollectionSemanticSearchResult],
    summary="Perform a semantic similarity search on `FundusCollection`s based on their description.",
    tags=["search"],
)
def fundus_collection_description_similarity_search(query: EmbeddingQuery):
    """
    Perform a semantic similarity search on `FundusCollection`s based on their description.

    Args:
        query (EmbeddingQuery): The embedding query parameters.

    Returns:
        List[FundusCollectionSemanticSearchResult]: A list of search results.

    Raises:
        HTTPException: 500 if the underlying vector DB search fails.
    """
    try:
        # The vector DB expects a plain list of floats.
        query_embedding = list(query.query_embedding)
        return vdb.fundus_collection_description_similarity_search(
            query_embedding=query_embedding,
            top_k=query.top_k,
        )
    except Exception as e:
        # Chain the original exception so the root cause stays in tracebacks.
        raise HTTPException(status_code=500, detail=str(e)) from e


@router.post(
    "/image_to_image_search",
    response_model=list[FundusRecordSemanticSearchResult],
    summary="Upload an image, process it, and perform a similarity search.",
    tags=["image"],
)
async def image_to_image_search(file: UploadFile = File(...)):
    """
    Upload an image, compute its embedding, and perform a similarity search.

    Args:
        file (UploadFile): Uploaded image file.

    Returns:
        List[FundusRecordSemanticSearchResult]: A list of search results.

    Raises:
        HTTPException: 500 if the ML server fails to produce an embedding or
            the search itself fails.
    """
    try:
        image_bytes = await file.read()
        image_base64 = base64.b64encode(image_bytes).decode("utf-8")

        # FundusMLClient computes the image embedding on the ML server.
        embedding = ml_client.compute_image_embedding(image_base64, return_tensor="np")

        if (
            embedding is None
            or not isinstance(embedding, (np.ndarray, list))
            or len(embedding) == 0
        ):
            raise HTTPException(
                status_code=500, detail="Failed to retrieve embedding from ML server"
            )

        # Flatten to a plain list of floats, as expected by the vector DB.
        query_embedding = np.array(embedding).flatten().tolist()

        # `search_in_collections=None` searches across all collections.
        search_query = EmbeddingQuery(
            query_embedding=query_embedding, search_in_collections=None
        )

        return vdb._fundus_record_image_similarity_search(
            query_embedding=search_query.query_embedding,
            search_in_collections=search_query.search_in_collections,
            top_k=search_query.top_k,
        )
    except HTTPException:
        # Re-raise deliberate HTTP errors (e.g. the embedding failure above)
        # untouched instead of re-wrapping them and losing their detail.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
Loading