From 8fdc1c2e259df49677fd8aa0fe3448429122df97 Mon Sep 17 00:00:00 2001 From: Janaka Abeywardhana Date: Wed, 25 Sep 2024 01:18:30 +0100 Subject: [PATCH 01/14] refactor: initial new rag pipeline code. --- source/docq/support/rag_pipeline.py | 157 +++++++++++++++ source/docq/support/store.py | 1 + web/__init__.py | 0 web/page_handlers/__init__.py | 0 web/page_handlers/ml_eng_tools/__init__.py | 0 web/page_handlers/ml_eng_tools/rag.py | 190 ++++++++++-------- .../ml_eng_tools/visualise_index.py | 7 +- web/utils/layout.py | 6 +- 8 files changed, 274 insertions(+), 87 deletions(-) create mode 100644 source/docq/support/rag_pipeline.py create mode 100644 web/__init__.py create mode 100644 web/page_handlers/__init__.py create mode 100644 web/page_handlers/ml_eng_tools/__init__.py diff --git a/source/docq/support/rag_pipeline.py b/source/docq/support/rag_pipeline.py new file mode 100644 index 00000000..edac9532 --- /dev/null +++ b/source/docq/support/rag_pipeline.py @@ -0,0 +1,157 @@ +"""Orchestrates the RAG pipeline.""" + +import inspect +from typing import Callable, Dict, List, Optional + +from docq.domain import Assistant +from llama_index.core.indices.base import BaseIndex +from llama_index.core.llms import LLM, ChatMessage, ChatResponse, MessageRole +from llama_index.core.schema import NodeWithScore +from llama_index.retrievers.bm25 import BM25Retriever + + +def search_stage( + user_query: str, + indices: List[BaseIndex], + reranker: Callable[[Dict[str, List[NodeWithScore]], Optional[List[str]]], List[NodeWithScore]] + | Callable[[Dict[str, List[NodeWithScore]]], List[NodeWithScore]], + query_preprocessor: Optional[Callable[[str], List[str]]] = None, + top_k: int = 10, +) -> List[NodeWithScore]: + """Search stage of the RAG pipeline. + + Args: + user_query (str): The user query. + indices (List[BaseIndex]): The list of indices to search. + reranker (Callable[[Dict[str, List[NodeWithScore]], Optional[List[str]]], List[NodeWithScore]]): The reranker to use. `func(search_results: List[List[NodeWithScore]], user_query: Optional[List[str]] = None) -> List[NodeWithScore]`. If not provided, a default reranker that uses X is used. + query_preprocessor (Optional[Callable[[str], List[str]]]): The preprocessor to use. `func(user_query: str) -> List[str]`. Defaults to no preprocessing. + top_k (int): The number of results to return per search. + + Returns: + List[NodeWithScore]: The reduced and reranked search results. + """ + if reranker is None: + raise ValueError("Reranker is required") + + _vector_retrievers = {} + _bm25_retrievers = {} + + # 1. Prepare retrievers + for i, index in enumerate(indices): + _vector_retrievers[f"vector_{index.index_id}_{i}"] = index.as_retriever(similarity_top_k=top_k) + _bm25_retrievers[f"bm25_{index.index_id}_{i}"] = BM25Retriever.from_defaults( + docstore=index.docstore, similarity_top_k=top_k + ) + + # 2. Preprocess user query if preprocessor is provided + processed_queries = query_preprocessor(user_query) if query_preprocessor else [user_query] + + # 3. Run the list of queries through each retriever + vector_results = {} + bm25_results = {} + + for i, query in enumerate(processed_queries): + for key, vr in _vector_retrievers.items(): + vector_results[f"{key}_query_{i}"] = vr.retrieve(query) + + for key, br in _bm25_retrievers.items(): + bm25_results[f"{key}_query_{i}"] = br.retrieve(query) + + # 4. 
Combine results
+    combined_results = {**vector_results, **bm25_results}
+
+    if callable(reranker):
+        reranker_params = inspect.signature(reranker).parameters
+        if len(reranker_params) == 2:
+            reranked_results = reranker(combined_results, processed_queries)
+        elif len(reranker_params) == 1:
+            reranked_results = reranker(combined_results)
+        else:
+            raise ValueError("Reranker function must accept either one or two arguments")
+    else:
+        raise TypeError("Reranker must be a callable")
+
+    return reranked_results
+
+
+def generation_stage(
+    user_query: str,
+    assistant: Assistant,
+    search_results: List[NodeWithScore],
+    message_history: List[ChatMessage],
+    llm: LLM,
+) -> ChatResponse:
+    """Generation stage of the RAG pipeline.
+
+    Args:
+        user_query (str): The user query.
+        assistant (Assistant): The assistant.
+        search_results (List[NodeWithScore]): The search results.
+        message_history (List[ChatMessage]): The message history.
+        llm (LLM): The LLM.
+
+    Returns:
+        ChatResponse: The response from the LLM.
+    """
+    # build system message
+    system_message = ChatMessage(role=MessageRole.SYSTEM, content=assistant.system_message_content)
+
+    # build query message
+    query_message = ChatMessage(
+        role=MessageRole.USER,
+        content=assistant.user_prompt_template_content.format(
+            context_str="\n".join([node.text for node in search_results]),
+            query_str=user_query,
+        ),
+    )
+
+    chat_messages = [system_message] + message_history + [query_message]
+
+    # Generate response
+    response = llm.chat(messages=chat_messages)
+
+    # TODO: Add source references to the response
+
+    return response
+
+
+def rag_pipeline(
+    user_query: str,
+    indices: List[BaseIndex],
+    assistant: Assistant,
+    message_history: List[ChatMessage],
+    llm: LLM,
+    reranker: Callable[[Dict[str, List[NodeWithScore]], Optional[List[str]]], List[NodeWithScore]]
+    | Callable[[Dict[str, List[NodeWithScore]]], List[NodeWithScore]],
+    query_preprocessor: Optional[Callable[[str], List[str]]] = None,
+    top_k: int = 10,
+) -> ChatResponse:
+    """Orchestrates the RAG pipeline.
+
+    Args:
+        user_query (str): The user query.
+        indices (List[BaseIndex]): The list of indices to search.
+        assistant (Assistant): The assistant.
+        message_history (List[ChatMessage]): The message history.
+        llm (LLM): The LLM.
+        reranker (Callable[[Dict[str, List[NodeWithScore]], Optional[List[str]]], List[NodeWithScore]]): The reranker to use. `func(search_results: Dict[str, List[NodeWithScore]], user_query: Optional[List[str]] = None) -> List[NodeWithScore]`. Required; the search stage raises a ValueError if no reranker is supplied.
+        query_preprocessor (Optional[Callable[[str], List[str]]]): The preprocessor to use. `func(user_query: str) -> List[str]`. Defaults to no preprocessing.
+        top_k (int): The number of results to return per search.
+
+    Returns:
+        ChatResponse: The response from the LLM.
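+
+    Example:
+        A minimal, illustrative sketch (the names below are placeholders, not part of this
+        module; it assumes an index, an Assistant, an LLM, and a reranker function matching
+        the signature above, e.g. a reciprocal rank fusion implementation, already exist):
+
+            response = rag_pipeline(
+                user_query="What is Docq?",
+                indices=[vector_index],
+                assistant=assistant,
+                message_history=[],
+                llm=llm,
+                reranker=my_reranker,
+                top_k=10,
+            )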
+ """ + # Search stage + search_results = search_stage( + user_query=user_query, indices=indices, reranker=reranker, query_preprocessor=query_preprocessor, top_k=top_k + ) + # Generation stage + response = generation_stage( + user_query=user_query, + search_results=search_results, + message_history=message_history, + assistant=assistant, + llm=llm, + ) + + return response diff --git a/source/docq/support/store.py b/source/docq/support/store.py index d3be164f..76d1eb7a 100644 --- a/source/docq/support/store.py +++ b/source/docq/support/store.py @@ -221,6 +221,7 @@ def _clean_public_chat_history() -> None: @tracer.start_as_current_span(name="_get_storage_context") def _get_storage_context(space: SpaceKey) -> StorageContext: + """Get the storage context for a Space. This loads all stores from the Space directory aka `persist_dir`.""" return StorageContext.from_defaults(persist_dir=get_index_dir(space)) diff --git a/web/__init__.py b/web/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/web/page_handlers/__init__.py b/web/page_handlers/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/web/page_handlers/ml_eng_tools/__init__.py b/web/page_handlers/ml_eng_tools/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/web/page_handlers/ml_eng_tools/rag.py b/web/page_handlers/ml_eng_tools/rag.py index edb8faf6..1c83cda0 100644 --- a/web/page_handlers/ml_eng_tools/rag.py +++ b/web/page_handlers/ml_eng_tools/rag.py @@ -1,38 +1,50 @@ """ML Eng - Visualise index.""" + import logging as log import os -from typing import List, Optional +from typing import List import streamlit as st from docq.config import SpaceType from docq.data_source.list import SpaceDataSources from docq.domain import Assistant, SpaceKey -from docq.manage_assistants import llama_index_chat_prompt_template_from_assistant from docq.manage_spaces import get_space_data_source, list_space from docq.model_selection.main import LlmUsageSettingsCollection, ModelCapability, get_saved_model_settings_collection -from docq.support.llm import ( - _get_default_storage_context, - _get_service_context, - _get_storage_context, +from docq.support.llama_index.node_post_processors import reciprocal_rank_fusion +from docq.support.llm import _get_service_context +from docq.support.rag_pipeline import rag_pipeline +from docq.support.store import ( + _get_path, + _map_space_type_to_datascope, + _StoreDir, ) -from docq.support.store import _get_path, _map_space_type_to_datascope, _StoreDir from llama_index.core import VectorStoreIndex from llama_index.core.base.response.schema import Response -from llama_index.core.indices import load_index_from_storage, load_indices_from_storage +from llama_index.core.indices import load_index_from_storage from llama_index.core.indices.base import BaseIndex +from llama_index.core.llms import ChatMessage, MessageRole from llama_index.core.schema import BaseNode, Document, NodeWithScore from llama_index.core.storage import StorageContext from llama_index.retrievers.bm25 import BM25Retriever -from ml_eng_tools.visualise_index import visualise_vector_store_index +from page_handlers.ml_eng_tools.visualise_index import visualise_vector_store_index from streamlit.delta_generator import DeltaGenerator from utils.layout import auth_required, render_page_title_and_favicon from utils.sessions import get_selected_org_id -render_page_title_and_favicon() +render_page_title_and_favicon(layout="wide") auth_required(requiring_selected_org_admin=True) above_tabs_container = 
st.container() -chat_tab, index_tab = st.tabs(["chat_tab", "index_tab"]) +left_col, right_col = st.columns(2) +chat_tab, index_tab = left_col.tabs(["chat_tab", "index_tab"]) + + +# select a live space +# create an experimental index +# chat with space over experimental index (no saved threads, clear the chat) +# iterate on the assist prompt +# visualise the retrieved chunks for a chat +# layout chat and assist prompt on the left. visualise the retrieved chunks on the right def _get_experiement_dir(space: SpaceKey, experiment_id: str) -> str: @@ -47,30 +59,21 @@ def _get_experiments_storage_context(space: SpaceKey, experiment_id: str) -> Sto return StorageContext.from_defaults(persist_dir=_get_experiement_dir(space, experiment_id)) -def _load_index( - space: SpaceKey, model_settings_collection: LlmUsageSettingsCollection, exp_id: Optional[str] = None +def _load_experiment_index_from_storage( + space: SpaceKey, model_settings_collection: LlmUsageSettingsCollection, exp_id: str ) -> BaseIndex: """Load index from storage.""" - storage_context = _get_storage_context(space) + storage_context = _get_experiments_storage_context(space, exp_id) # _get_storage_context(space) service_context = _get_service_context(model_settings_collection) return load_index_from_storage(storage_context=storage_context, service_context=service_context) -def _load_indices_from_storage( - space: SpaceKey, model_settings_collection: LlmUsageSettingsCollection -) -> List[BaseIndex]: - # set service context explicitly for multi model compatibility - sc = _get_service_context(model_settings_collection) - return load_indices_from_storage( - storage_context=_get_storage_context(space), service_context=sc, callback_manager=sc.callback_manager - ) - def _load_vector_store_index( - space: SpaceKey, model_settings_collection: LlmUsageSettingsCollection + space: SpaceKey, model_settings_collection: LlmUsageSettingsCollection, exp_id: str ) -> VectorStoreIndex: """Load index from storage.""" - storage_context = _get_storage_context(space) + storage_context = _get_experiments_storage_context(space, exp_id) service_context = _get_service_context(model_settings_collection) return VectorStoreIndex.from_vector_store(storage_context.vector_store, service_context=service_context) @@ -91,7 +94,7 @@ def _create_vector_index( log.debug("service context: ", sc.embed_model) index_ = VectorStoreIndex.from_documents( documents, - storage_context=_get_default_storage_context(), + storage_context=_get_experiments_storage_context(space, exp_id), service_context=sc, show_progress=True, kwargs=model_settings_collection.model_usage_settings[ModelCapability.EMBEDDING].additional_args, @@ -130,8 +133,14 @@ def render_documents(docs: list[Document]): st.write(doc) -selected_org_id = get_selected_org_id() experiment_id = "sasdfasdf" +selected_org_id = get_selected_org_id() +above_tabs_container.write(f"Experiment ID: `{experiment_id}` within Org ID: `{selected_org_id}`") +above_tabs_container.write( + f"All index creation and loading is contained withing a sub dir called `exp_{experiment_id}`" +) + + spaces = [] selected_space = None if selected_org_id: @@ -144,6 +153,7 @@ def render_documents(docs: list[Document]): format_func=lambda x: x[2], label_visibility="visible", index=0, + key="selected_space", ) @@ -158,15 +168,17 @@ def render_documents(docs: list[Document]): docs = load_documents_from_space(selected_space_key) + if not docs: + above_tabs_container.write("No documents found in the Space.") + st.stop() + index_ = None try: - index_ = 
load_index_from_storage( - storage_context=_get_experiments_storage_context(selected_space_key, experiment_id), - service_context=_get_service_context(saved_model_settings), - ) - # index_ = _load_vector_store_index(selected_space_key, saved_model_settings) - except: + # index_ = _load_vector_store_index(selected_space_key, saved_model_settings, experiment_id) + index_ = _load_experiment_index_from_storage(selected_space_key, saved_model_settings, experiment_id) + except Exception as e: above_tabs_container.write("No index. Index the space first.") + print(e) IndexSpaceButton = above_tabs_container.button("Index Space") @@ -175,41 +187,25 @@ def render_documents(docs: list[Document]): log.debug("models: ", saved_model_settings) index_ = _create_vector_index(docs, saved_model_settings, selected_space_key, experiment_id) # storage_context.persist(persist_dir=_get_experiement_dir(selected_space_key, experiment_id)) - index_.storage_context.persist(persist_dir=_get_experiement_dir(selected_space_key, experiment_id)) + index_.storage_context.persist() # persist to the experiment dir + # index_.storage_context.persist(persist_dir=_get_experiement_dir(selected_space_key, experiment_id)) with index_tab: - if index_: + if index_ and isinstance(index_, VectorStoreIndex): visualise_vector_store_index(index_) - # render_documents(docs) - - # index_ = _load_index_from_storage(selected_space_key, saved_model_settings) - # index_ = _get_storage_context(selected_space_key).vector_store - # index_ = VectorStoreIndex.from_vector_store(_get_storage_context(selected_space_key).vector_store) - - # indices = _load_indices_from_storage(selected_space_key, saved_model_settings) - # storage_context = _get_storage_context(selected_space_key) - - # print("space_indices: ", len(indices)) - # for each_index in indices: - # print("each_index: ", each_index.index_id) - - # _index = _load_index(space, saved_model_settings) - # if isinstance(_index, DocumentSummaryIndex): - # visualise_document_summary_index(_index) - # elif isinstance(_index, VectorStoreIndex): - # visualise_vector_store_index(_index) - # else: - # st.write("Visualiser not available for index type: ", _index.index_struct_cls.__name__) def prepare_chat(): ch = st.session_state.get(f"rag_test_chat_history_content_{selected_space_key.id_}", None) if not ch: - st.session_state[f"rag_test_chat_history_content_{selected_space_key.id_}"] = ["Hello, ask a question."] + st.session_state[f"rag_test_chat_history_content_{selected_space_key.id_}"] = [ + ChatMessage(role=MessageRole.ASSISTANT, content="Hello, ask me a question.") + ] -prepare_chat() +with chat_tab: + prepare_chat() def render_retrieval_query_ui(container: DeltaGenerator): @@ -255,6 +251,7 @@ def render_retrieval_query_ui(container: DeltaGenerator): def render_retreival_results(results: List[NodeWithScore]): + st.write("Retrieved Chunks:") for node in results: with st.expander(node.node_id): st.write(node) @@ -262,6 +259,10 @@ def render_retreival_results(results: List[NodeWithScore]): def handle_chat_input(): """Handle chat input.""" + if not index_: + st.error("Index not loaded. 
Please select a space and load the index first.") + return + space_indices = [index_] persona = Assistant( @@ -276,46 +277,72 @@ def handle_chat_input(): # text_qa_template=llama_index_chat_prompt_template_from_persona(persona).partial_format(history_str=""), # ) - from llama_index.core.retrievers import QueryFusionRetriever - from llama_index.core.retrievers.fusion_retriever import FUSION_MODES + # from llama_index.core.retrievers import QueryFusionRetriever + # from llama_index.core.retrievers.fusion_retriever import FUSION_MODES - vector_retriever = index_.as_retriever(similarity_top_k=10) + # vector_retriever = index_.as_retriever(similarity_top_k=10) - bm25_retriever = BM25Retriever.from_defaults(docstore=index_.docstore, similarity_top_k=10) - retriever = QueryFusionRetriever( - [vector_retriever, bm25_retriever], - similarity_top_k=5, - num_queries=4, # set this to 1 to disable query generation - mode=FUSION_MODES.RECIPROCAL_RANK, - use_async=False, - verbose=True, - llm=_get_service_context(saved_model_settings).llm, - # query_gen_prompt="...", # we could override the query generation prompt here - ) + # bm25_retriever = BM25Retriever.from_defaults(docstore=index_.docstore, similarity_top_k=10) + # retriever = QueryFusionRetriever( + # [vector_retriever, bm25_retriever], + # similarity_top_k=5, + # num_queries=4, # set this to 1 to disable query generation + # mode=FUSION_MODES.RECIPROCAL_RANK, + # use_async=False, + # verbose=True, + # llm=_get_service_context(saved_model_settings).llm, + # # query_gen_prompt="...", # we could override the query generation prompt here + # ) - from llama_index.core.query_engine import RetrieverQueryEngine + # from llama_index.core.query_engine import RetrieverQueryEngine + + # query_engine = RetrieverQueryEngine.from_args( + # retriever, + # service_context=_get_service_context(saved_model_settings), + # text_qa_template=llama_index_chat_prompt_template_from_assistant(persona).partial_format(history_str=""), # noqa: F821 + # ) - query_engine = RetrieverQueryEngine.from_args( - retriever, - service_context=_get_service_context(saved_model_settings), - text_qa_template=llama_index_chat_prompt_template_from_assistant(persona).partial_format(history_str=""), # noqa: F821 + chat_history: List[ChatMessage] = st.session_state.get( + f"rag_test_chat_history_content_{selected_space_key.id_}", [] ) + print("chat_history: ", chat_history) + query = st.session_state.get("chat_input_rag_test", None) if query: # query_embed = _get_service_context(saved_model_settings).embed_model.get_query_embedding(query) # query_bundle = QueryBundle(query_str=query, custom_embedding_strs=query, embedding=query_embed) # ret_results = query_engine.retrieve(query_bundle) - resp = query_engine.query(query) + # resp = query_engine.query(query) + + resp = rag_pipeline( + user_query=query, + indices=space_indices, + assistant=persona, + message_history=chat_history, + llm=_get_service_context(saved_model_settings).llm, + reranker=lambda results: reciprocal_rank_fusion(results), + query_preprocessor=None, + top_k=6, + ) if isinstance(resp, Response): - st.session_state[f"rag_test_chat_history_content_{selected_space_key.id_}"].extend([query, resp.response]) + # st.session_state[f"rag_test_chat_history_content_{selected_space_key.id_}"].extend([query, resp.response]) + st.session_state[f"rag_test_chat_history_content_{selected_space_key.id_}"].extend( + [ + ChatMessage(role=MessageRole.USER, content=query), + ChatMessage(role=MessageRole.ASSISTANT, content=resp.response), + ] + ) def 
render_chat(): # st.write(st.session_state.get("chat_input_rag_test", "Hello, ask a questions.")) - for chat in st.session_state.get(f"rag_test_chat_history_content_{selected_space_key.id_}", []): - st.write(chat) + chat_history: list[ChatMessage] = st.session_state.get( + f"rag_test_chat_history_content_{selected_space_key.id_}", [] + ) + for cm in chat_history: + st.write(f"{cm.role}: {cm.content}") st.chat_input( "Type your question here", @@ -333,4 +360,5 @@ def clear_chat(): with chat_tab: st.button("Clear Chat", on_click=clear_chat) -render_chat() + + render_chat() diff --git a/web/page_handlers/ml_eng_tools/visualise_index.py b/web/page_handlers/ml_eng_tools/visualise_index.py index f5ccc053..d0fe88a7 100644 --- a/web/page_handlers/ml_eng_tools/visualise_index.py +++ b/web/page_handlers/ml_eng_tools/visualise_index.py @@ -1,6 +1,5 @@ """ML Eng - Visualise index.""" import json -import token from typing import Optional, cast import streamlit as st @@ -8,14 +7,12 @@ from docq.domain import SpaceKey from docq.manage_spaces import list_space from docq.model_selection.main import LlmUsageSettingsCollection, get_saved_model_settings_collection -from docq.support.llm import _get_service_context, _get_storage_context -from docq.support.store import get_models_dir -from llama_index.core import Settings +from docq.support.llm import _get_service_context +from docq.support.store import _get_storage_context from llama_index.core.indices import DocumentSummaryIndex, VectorStoreIndex, load_index_from_storage from llama_index.core.indices.base import BaseIndex from llama_index.core.schema import TextNode from llama_index.embeddings.huggingface_optimum import OptimumEmbedding -from transformers import AutoTokenizer from utils.layout import auth_required, render_page_title_and_favicon from utils.sessions import get_selected_org_id diff --git a/web/utils/layout.py b/web/utils/layout.py index 4056618b..37a05147 100644 --- a/web/utils/layout.py +++ b/web/utils/layout.py @@ -38,6 +38,7 @@ ) from opentelemetry import trace from st_pages import hide_pages, translate_icon +from streamlit.commands.page_config import Layout from streamlit.components.v1 import html from streamlit.delta_generator import DeltaGenerator from streamlit.elements.image import AtomicImage @@ -360,7 +361,9 @@ def __always_hidden_pages() -> None: def render_page_title_and_favicon( - page_display_title: Optional[str] = None, browser_title: Optional[str] = None + page_display_title: Optional[str] = None, + browser_title: Optional[str] = None, + layout: Optional[Layout] = "centered", ) -> None: """Handle setting browser page title and favicon. Separately render in app page title with icon defined in show_pages(). 
@@ -405,6 +408,7 @@ def render_page_title_and_favicon( page_icon=favicon_path, page_title=browser_title if browser_title else f"{browser_title_prefix} - {_page_display_title}", menu_items={"About": about_menu_content}, + layout=layout, ) except StreamlitAPIException: pass From d6692225b3baa145f4d3b229f3bb7e9c0e4cff40 Mon Sep 17 00:00:00 2001 From: Janaka Abeywardhana Date: Sat, 5 Oct 2024 23:35:08 +0100 Subject: [PATCH 02/14] feat(ml tools): add Assistant selection and edit - refactor: replace ASSISTANT tuple type with the Assistant dataclass --- source/docq/domain.py | 29 ++-- source/docq/manage_assistants.py | 130 ++++++++++++--- source/docq/manage_spaces.py | 2 +- source/docq/support/rag_pipeline.py | 37 +++-- web/page_handlers/ml_eng_tools/rag.py | 222 +++++++++++++------------- web/utils/layout_assistants.py | 53 +++--- 6 files changed, 296 insertions(+), 177 deletions(-) diff --git a/source/docq/domain.py b/source/docq/domain.py index 7780b363..b8fbb0de 100644 --- a/source/docq/domain.py +++ b/source/docq/domain.py @@ -10,7 +10,7 @@ from enum import Enum from typing import Any, Optional, Self -from .config import OrganisationFeatureType, SpaceType +from docq.config import OrganisationFeatureType, SpaceType _SEPARATOR_FOR_STR = ":" _SEPARATOR_FOR_VALUE = "_" @@ -120,25 +120,36 @@ def create_instance(document_link: str, document_text: str, indexed_on: Optional ) raise e + +class AssistantType(Enum): + """Persona type.""" + + SIMPLE_CHAT = "Simple Chat" + AGENT = "Agent" + ASK = "Ask" + + @dataclass class Assistant: """A assistant at it's core is a system prompt and user prompt template that tunes the LLM to take on a certain persona and behave/respond a particular way.""" key: str """Unique ID for a Persona instance""" + scoped_id: str + """Scoped ID for a Persona instance.""" name: str """Friendly name for the persona""" + type: AssistantType + """Type of the persona""" + archived: bool + """Whether the persona is soft deleted or not""" system_message_content: str """Content of the system message. This is where the persona is defined.""" user_prompt_template_content: str """Template for the user prompt aka query. This template is used to generate the content for the user prompt/query that will be sent to the LLM (as a user message).""" llm_settings_collection_key: str """The key of the LLM settings collection to use for LLM calls by this assistant. 
""" - - -class AssistantType(Enum): - """Persona type.""" - - SIMPLE_CHAT = "Simple Chat" - AGENT = "Agent" - ASK = "Ask" + created_at: datetime + """The timestamp when the assistant record was created.""" + updated_at: datetime + """The timestamp when the assistant record was last updated.""" diff --git a/source/docq/manage_assistants.py b/source/docq/manage_assistants.py index b69d9130..90c776e8 100644 --- a/source/docq/manage_assistants.py +++ b/source/docq/manage_assistants.py @@ -2,7 +2,7 @@ import logging as log import sqlite3 from contextlib import closing -from datetime import datetime +from datetime import UTC, datetime from typing import List, Optional from llama_index.core.base.llms.types import ChatMessage, MessageRole @@ -126,8 +126,8 @@ updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ) """ -# id, name, type, archived, system_prompt_template, user_prompt_template, llm_settings_collection_key, created_at, updated_at, scoped_id -ASSISTANT = tuple[int, str, str, bool, str, str, str, datetime, datetime, str] +# # id, name, type, archived, system_prompt_template, user_prompt_template, llm_settings_collection_key, created_at, updated_at, scoped_id +# ASSISTANT = tuple[int, str, str, bool, str, str, str, datetime, datetime, str] def _init(org_id: Optional[int] = None) -> None: @@ -182,16 +182,71 @@ def get_assistant_fixed( """Get the personas.""" result = {} if assistant_type == AssistantType.SIMPLE_CHAT: - result = {key: Assistant(key=key, **persona, llm_settings_collection_key=llm_settings_collection_key) for key, persona in SIMPLE_CHAT_PERSONAS.items()} + result = { + key: Assistant( + key=key, + type=AssistantType.SIMPLE_CHAT, + archived=False, + **persona, + llm_settings_collection_key=llm_settings_collection_key, + created_at=datetime.now(tz=UTC), + updated_at=datetime.now(tz=UTC), + ) + for key, persona in SIMPLE_CHAT_PERSONAS.items() + } elif assistant_type == AssistantType.AGENT: result = {key: Assistant(key=key, **persona, llm_settings_collection_key=llm_settings_collection_key) for key, persona in AGENT_PERSONAS.items()} elif assistant_type == AssistantType.ASK: - result = {key: Assistant(key=key, **persona, llm_settings_collection_key=llm_settings_collection_key) for key, persona in ASK_PERSONAS.items()} + result = { + key: Assistant( + key=key, + type=AssistantType.ASK, + archived=False, + **persona, + llm_settings_collection_key=llm_settings_collection_key, + created_at=datetime.now(tz=UTC), + updated_at=datetime.now(tz=UTC), + ) + for key, persona in ASK_PERSONAS.items() + } else: result = { - **{key: Assistant(key=key, **persona, llm_settings_collection_key=llm_settings_collection_key) for key, persona in SIMPLE_CHAT_PERSONAS.items()}, - **{key: Assistant(key=key, **persona, llm_settings_collection_key=llm_settings_collection_key) for key, persona in AGENT_PERSONAS.items()}, - **{key: Assistant(key=key, **persona, llm_settings_collection_key=llm_settings_collection_key) for key, persona in ASK_PERSONAS.items()}, + **{ + key: Assistant( + key=key, + type=AssistantType.SIMPLE_CHAT, + archived=False, + **persona, + llm_settings_collection_key=llm_settings_collection_key, + created_at=datetime.now(tz=UTC), + updated_at=datetime.now(tz=UTC), + ) + for key, persona in SIMPLE_CHAT_PERSONAS.items() + }, + **{ + key: Assistant( + key=key, + type=AssistantType.AGENT, + archived=False, + **persona, + llm_settings_collection_key=llm_settings_collection_key, + created_at=datetime.now(tz=UTC), + updated_at=datetime.now(tz=UTC), + ) + for key, persona in AGENT_PERSONAS.items() + }, + 
**{ + key: Assistant( + key=key, + type=AssistantType.ASK, + archived=False, + **persona, + llm_settings_collection_key=llm_settings_collection_key, + created_at=datetime.now(tz=UTC), + updated_at=datetime.now(tz=UTC), + ) + for key, persona in ASK_PERSONAS.items() + }, } return result @@ -207,23 +262,29 @@ def get_assistant_or_default(assistant_scoped_id: Optional[str] = None, org_id: """ if assistant_scoped_id: assistant_data = get_assistant(assistant_scoped_id=assistant_scoped_id, org_id=org_id) - return Assistant( - key=str(assistant_data[0]), - name=assistant_data[1], - system_message_content=assistant_data[4], - user_prompt_template_content=assistant_data[5], - llm_settings_collection_key=assistant_data[6], - ) + return assistant_data + # return Assistant( + # key=str(assistant_data[0]), + # name=assistant_data[1], + # system_message_content=assistant_data[4], + # user_prompt_template_content=assistant_data[5], + # llm_settings_collection_key=assistant_data[6], + # ) else: key = "default" return Assistant( key=key, - llm_settings_collection_key="azure_openai_with_local_embedding", + scoped_id=f"global_{key}", + type=AssistantType.SIMPLE_CHAT, + archived=False, **SIMPLE_CHAT_PERSONAS[key], + llm_settings_collection_key="azure_openai_with_local_embedding", + created_at=datetime.now(tz=UTC), + updated_at=datetime.now(tz=UTC), ) -def list_assistants(org_id: Optional[int] = None, assistant_type: Optional[AssistantType] = None) -> list[ASSISTANT]: +def list_assistants(org_id: Optional[int] = None, assistant_type: Optional[AssistantType] = None) -> list[Assistant]: """List the assistants. Args: @@ -231,7 +292,7 @@ def list_assistants(org_id: Optional[int] = None, assistant_type: Optional[Assis assistant_type (Optional[AssistantType]): The assistant type. Returns: - list[ASSISTANT]: The list of assistants. This includes a compound ID that of ID + scope. This is to avoid ID clashes between global and org scope tables on gets. + list[Assistant]: The list of assistants. This includes a compound ID that of ID + scope. This is to avoid ID clashes between global and org scope tables on gets. """ scope = "global" if org_id: @@ -250,13 +311,28 @@ def list_assistants(org_id: Optional[int] = None, assistant_type: Optional[Assis ) as connection, closing(connection.cursor()) as cursor: cursor.execute(sql, params) rows = cursor.fetchall() + # return [ + # (row[0], row[1], row[2], row[3], row[4], row[5], row[6], row[7], row[8], f"{scope}_{row[0]}") + # for row in rows + # ] return [ - (row[0], row[1], row[2], row[3], row[4], row[5], row[6], row[7], row[8], f"{scope}_{row[0]}") + Assistant( + key=str(row[0]), + name=row[1], + type=row[2], + archived=row[3], + system_message_content=row[4], + user_prompt_template_content=row[5], + llm_settings_collection_key=row[6], + created_at=row[7], + updated_at=row[8], + scoped_id=f"{scope}_{row[0]}", + ) for row in rows ] -def get_assistant(assistant_scoped_id: str, org_id: Optional[int]) -> ASSISTANT: +def get_assistant(assistant_scoped_id: str, org_id: Optional[int]) -> Assistant: """Get the assistant. If just assistant_id then will try to get from global scope table. @@ -287,7 +363,19 @@ def get_assistant(assistant_scoped_id: str, org_id: Optional[int]) -> ASSISTANT: ) else: raise ValueError(f"No Assistant with: id = '{id_}' in global scope. 
scope= '{scope}'") - return (row[0], row[1], row[2], row[3], row[4], row[5], row[6], row[7], row[8], assistant_scoped_id) + # return (row[0], row[1], row[2], row[3], row[4], row[5], row[6], row[7], row[8], assistant_scoped_id) + return Assistant( + key=str(row[0]), + name=row[1], + type=AssistantType(row[2]), + archived=row[3], + system_message_content=row[4], + user_prompt_template_content=row[5], + llm_settings_collection_key=row[6], + created_at=row[7], + updated_at=row[8], + scoped_id=f"{scope}_{row[0]}", + ) def create_or_update_assistant( diff --git a/source/docq/manage_spaces.py b/source/docq/manage_spaces.py index a77c8f37..d6a35694 100644 --- a/source/docq/manage_spaces.py +++ b/source/docq/manage_spaces.py @@ -150,7 +150,7 @@ def list_space(org_id: int, space_type: Optional[str] = None) -> list[SPACE]: ) rows = cursor.fetchall() - print("spaces:", rows) + return [_format_space(row) for row in rows] diff --git a/source/docq/support/rag_pipeline.py b/source/docq/support/rag_pipeline.py index edac9532..ced89ea5 100644 --- a/source/docq/support/rag_pipeline.py +++ b/source/docq/support/rag_pipeline.py @@ -1,7 +1,7 @@ """Orchestrates the RAG pipeline.""" import inspect -from typing import Callable, Dict, List, Optional +from typing import Any, Callable, Dict, List, Optional, Tuple from docq.domain import Assistant from llama_index.core.indices.base import BaseIndex @@ -17,7 +17,8 @@ def search_stage( | Callable[[Dict[str, List[NodeWithScore]]], List[NodeWithScore]], query_preprocessor: Optional[Callable[[str], List[str]]] = None, top_k: int = 10, -) -> List[NodeWithScore]: + enable_debug: Optional[bool] = False, +) -> Tuple[List[NodeWithScore], Dict[str, Any]]: """Search stage of the RAG pipeline. Args: @@ -30,6 +31,7 @@ def search_stage( Returns: List[NodeWithScore]: The reduced and reranked search results. """ + debug: dict[str, Any] = {} if reranker is None: raise ValueError("Reranker is required") @@ -63,15 +65,23 @@ def search_stage( if callable(reranker): reranker_params = inspect.signature(reranker).parameters if len(reranker_params) == 2: - reranked_results = reranker(combined_results, processed_queries) + reranked_results = reranker(combined_results, processed_queries) # type: ignore elif len(reranker_params) == 1: - reranked_results = reranker(combined_results) + reranked_results = reranker(combined_results) # type: ignore else: raise ValueError("Reranker function must accept either one or two arguments") else: raise TypeError("Reranker must be a callable") - return reranked_results + if enable_debug: + for key, value in vector_results.items(): + debug[key] = value + for key, value in bm25_results.items(): + debug[key] = value + debug["reranked_results"] = reranked_results + debug["processed_queries"] = processed_queries + + return (reranked_results, debug) def generation_stage( @@ -80,7 +90,8 @@ def generation_stage( search_results: List[NodeWithScore], message_history: List[ChatMessage], llm: LLM, -) -> ChatResponse: + enable_debug: Optional[bool] = False, +) -> Tuple[ChatResponse, Dict[str, Any]]: """Generation stage of the RAG pipeline. Args: @@ -93,6 +104,7 @@ def generation_stage( Returns: ChatResponse: The response from the LLM. 
""" + debug: dict[str, Any] = {} # build system message system_message = ChatMessage(role=MessageRole.SYSTEM, content=assistant.system_message_content) @@ -110,9 +122,15 @@ def generation_stage( # Generate response response = llm.chat(messages=chat_messages) + if enable_debug: + debug["system_message"] = system_message + debug["user_prompt_template_content"] = assistant.user_prompt_template_content + debug["query_message"] = query_message + debug["search_results"] = search_results + # TODO: Add source references to the response - return response + return (response, debug) def rag_pipeline( @@ -142,11 +160,12 @@ def rag_pipeline( ChatResponse: The response from the LLM. """ # Search stage - search_results = search_stage( + search_results, debug = search_stage( user_query=user_query, indices=indices, reranker=reranker, query_preprocessor=query_preprocessor, top_k=top_k ) + # Generation stage - response = generation_stage( + response, debug = generation_stage( user_query=user_query, search_results=search_results, message_history=message_history, diff --git a/web/page_handlers/ml_eng_tools/rag.py b/web/page_handlers/ml_eng_tools/rag.py index 1c83cda0..f8861e16 100644 --- a/web/page_handlers/ml_eng_tools/rag.py +++ b/web/page_handlers/ml_eng_tools/rag.py @@ -1,43 +1,53 @@ """ML Eng - Visualise index.""" +import json import logging as log import os -from typing import List +from typing import Any, List, Tuple import streamlit as st from docq.config import SpaceType from docq.data_source.list import SpaceDataSources -from docq.domain import Assistant, SpaceKey +from docq.domain import SpaceKey +from docq.manage_assistants import list_assistants from docq.manage_spaces import get_space_data_source, list_space from docq.model_selection.main import LlmUsageSettingsCollection, ModelCapability, get_saved_model_settings_collection from docq.support.llama_index.node_post_processors import reciprocal_rank_fusion from docq.support.llm import _get_service_context -from docq.support.rag_pipeline import rag_pipeline +from docq.support.rag_pipeline import generation_stage, search_stage from docq.support.store import ( + _DataScope, _get_path, _map_space_type_to_datascope, _StoreDir, ) from llama_index.core import VectorStoreIndex -from llama_index.core.base.response.schema import Response from llama_index.core.indices import load_index_from_storage from llama_index.core.indices.base import BaseIndex -from llama_index.core.llms import ChatMessage, MessageRole +from llama_index.core.llms import ChatMessage, ChatResponse, MessageRole from llama_index.core.schema import BaseNode, Document, NodeWithScore from llama_index.core.storage import StorageContext -from llama_index.retrievers.bm25 import BM25Retriever from page_handlers.ml_eng_tools.visualise_index import visualise_vector_store_index -from streamlit.delta_generator import DeltaGenerator from utils.layout import auth_required, render_page_title_and_favicon +from utils.layout_assistants import ( + render_assistant_create_edit_ui, + render_assistants_selector_ui, + render_datascope_selector_ui, +) from utils.sessions import get_selected_org_id render_page_title_and_favicon(layout="wide") auth_required(requiring_selected_org_admin=True) +top_container = st.container() + above_tabs_container = st.container() -left_col, right_col = st.columns(2) -chat_tab, index_tab = left_col.tabs(["chat_tab", "index_tab"]) +above_left_col, above_right_col = above_tabs_container.columns(2) + +left_col, right_col = st.columns(2) +chat_tab, index_tab, assistant_tab = 
left_col.tabs(["Chat", "Vec Index", "Assistant"]) +stuff_tab, search_results_tab = right_col.tabs(["stuff", "search_results"]) # select a live space # create an experimental index @@ -68,7 +78,6 @@ def _load_experiment_index_from_storage( return load_index_from_storage(storage_context=storage_context, service_context=service_context) - def _load_vector_store_index( space: SpaceKey, model_settings_collection: LlmUsageSettingsCollection, exp_id: str ) -> VectorStoreIndex: @@ -135,10 +144,9 @@ def render_documents(docs: list[Document]): experiment_id = "sasdfasdf" selected_org_id = get_selected_org_id() -above_tabs_container.write(f"Experiment ID: `{experiment_id}` within Org ID: `{selected_org_id}`") -above_tabs_container.write( - f"All index creation and loading is contained withing a sub dir called `exp_{experiment_id}`" -) +with top_container: + st.write(f"Experiment ID: `{experiment_id}` within Org ID: `{selected_org_id}`") + st.write(f"All index creation and loading is contained withing a sub dir called `exp_{experiment_id}`") spaces = [] @@ -147,7 +155,7 @@ def render_documents(docs: list[Document]): spaces.extend(list_space(selected_org_id, SpaceType.SHARED.name)) spaces.extend(list_space(selected_org_id, SpaceType.THREAD.name)) # list_shared_spaces(org_id=selected_org_id) - selected_space = above_tabs_container.selectbox( + selected_space = above_left_col.selectbox( "Space", spaces, format_func=lambda x: x[2], @@ -156,6 +164,25 @@ def render_documents(docs: list[Document]): key="selected_space", ) +selected_assistant = None +with above_right_col: + datascope = render_datascope_selector_ui() + current_org_id = None + if datascope == _DataScope.ORG: + current_org_id = get_selected_org_id() + if current_org_id is None: + st.error("Please select an organisation") + st.stop() + st.write(f"Selected Organisation: {current_org_id}") + + assistants_data = list_assistants(org_id=current_org_id) + + selected_assistant = render_assistants_selector_ui(assistants_data=assistants_data) + + if selected_assistant: + with assistant_tab: + render_assistant_create_edit_ui(org_id=current_org_id, assistant_data=selected_assistant) + if selected_space and selected_org_id: log.debug("selected_space: ", selected_space) @@ -195,12 +222,11 @@ def render_documents(docs: list[Document]): visualise_vector_store_index(index_) - def prepare_chat(): - ch = st.session_state.get(f"rag_test_chat_history_content_{selected_space_key.id_}", None) + ch = st.session_state.get(f"rag_test_chat_history_content_{selected_space_key.id_}", []) if not ch: st.session_state[f"rag_test_chat_history_content_{selected_space_key.id_}"] = [ - ChatMessage(role=MessageRole.ASSISTANT, content="Hello, ask me a question.") + (ChatMessage(role=MessageRole.ASSISTANT, content="Hello, ask me a question."), None) ] @@ -208,53 +234,12 @@ def prepare_chat(): prepare_chat() -def render_retrieval_query_ui(container: DeltaGenerator): - persona = Assistant( - "Assistant", - "Assistant", - "You are a helpful AI assistant. 
Only use the provided context to answer the query.", - "{query_str}", - "", - ) - from llama_index.core.retrievers import QueryFusionRetriever - from llama_index.core.retrievers.fusion_retriever import FUSION_MODES - - vector_retriever = index_.as_retriever(similarity_top_k=4) - - bm25_retriever = BM25Retriever.from_defaults(docstore=index_.docstore, similarity_top_k=4) - retriever = QueryFusionRetriever( - [vector_retriever, bm25_retriever], - similarity_top_k=4, - num_queries=4, # set this to 1 to disable query generation - mode=FUSION_MODES.RECIPROCAL_RANK, - use_async=False, - verbose=True, - llm=_get_service_context(saved_model_settings).llm, - # query_gen_prompt="...", # we could override the query generation prompt here - ) - - container.write("Ask a question to retrieve documents from the index.") - container.text_area("Query", key="retrieval_query_rag_test", value="who are the cofounders fo Docq?") - retrieve_button = container.button("Retrieve") - - if retrieve_button: - query = st.session_state.get("retrieval_query_rag_test", None) - if query: - # query_embed = _get_service_context(saved_model_settings).embed_model.get_query_embedding(query) - # query_bundle = QueryBundle(query_str=query, custom_embedding_strs=query, embedding=query_embed) - # ret_results = engine.retrieve(query_bundle) - ret_results = retriever.retrieve(query) - # resp = engine.query(query) - # print("ret_results: ", ret_results) - - render_retreival_results(ret_results) - - -def render_retreival_results(results: List[NodeWithScore]): +def render_retrieval_results(results: List[NodeWithScore]): st.write("Retrieved Chunks:") for node in results: + node_json = json.loads(node.to_json()) with st.expander(node.node_id): - st.write(node) + st.write(node_json) def handle_chat_input(): @@ -265,84 +250,96 @@ def handle_chat_input(): space_indices = [index_] - persona = Assistant( - "Assistant", - "Assistant", - "You are a helpful AI assistant. Do not previous knowledge. Only use the provided context to answer the query.", - "{query_str}", - "", - ) - # engine = index_.as_query_engine( - # llm=_get_service_context(saved_model_settings).llm, - # text_qa_template=llama_index_chat_prompt_template_from_persona(persona).partial_format(history_str=""), - # ) - - # from llama_index.core.retrievers import QueryFusionRetriever - # from llama_index.core.retrievers.fusion_retriever import FUSION_MODES - - # vector_retriever = index_.as_retriever(similarity_top_k=10) - - # bm25_retriever = BM25Retriever.from_defaults(docstore=index_.docstore, similarity_top_k=10) - # retriever = QueryFusionRetriever( - # [vector_retriever, bm25_retriever], - # similarity_top_k=5, - # num_queries=4, # set this to 1 to disable query generation - # mode=FUSION_MODES.RECIPROCAL_RANK, - # use_async=False, - # verbose=True, - # llm=_get_service_context(saved_model_settings).llm, - # # query_gen_prompt="...", # we could override the query generation prompt here - # ) - - # from llama_index.core.query_engine import RetrieverQueryEngine + if selected_assistant: + persona = selected_assistant - # query_engine = RetrieverQueryEngine.from_args( - # retriever, - # service_context=_get_service_context(saved_model_settings), - # text_qa_template=llama_index_chat_prompt_template_from_assistant(persona).partial_format(history_str=""), # noqa: F821 + # persona = Assistant( + # "Assistant", + # "Assistant", + # "You are a helpful AI assistant. Do not use previous knowledge to answer queries. Only use the provided context to answer the query. 
If you cannot provide a reasonable answer based on the given context and message history then say 'Sorry, I cannot provide an answer to that question.'", + # "context: {context_str}\nquery: {query_str}", + # "", # ) - chat_history: List[ChatMessage] = st.session_state.get( + chat_messages: List[Tuple[ChatMessage, Any]] = st.session_state.get( f"rag_test_chat_history_content_{selected_space_key.id_}", [] ) - print("chat_history: ", chat_history) + chat_history = [cm[0] for cm in chat_messages] query = st.session_state.get("chat_input_rag_test", None) - if query: + if query and persona: # query_embed = _get_service_context(saved_model_settings).embed_model.get_query_embedding(query) # query_bundle = QueryBundle(query_str=query, custom_embedding_strs=query, embedding=query_embed) # ret_results = query_engine.retrieve(query_bundle) # resp = query_engine.query(query) - resp = rag_pipeline( + search_results, search_debug = search_stage( + user_query=query, indices=space_indices, reranker=reciprocal_rank_fusion, top_k=6 + ) + + resp, gen_debug = generation_stage( user_query=query, - indices=space_indices, assistant=persona, + search_results=search_results, message_history=chat_history, llm=_get_service_context(saved_model_settings).llm, - reranker=lambda results: reciprocal_rank_fusion(results), - query_preprocessor=None, - top_k=6, + enable_debug=True, ) - if isinstance(resp, Response): + # resp = rag_pipeline( + # user_query=query, + # indices=space_indices, + # assistant=persona, + # message_history=chat_history, + # llm=_get_service_context(saved_model_settings).llm, + # reranker=lambda results: reciprocal_rank_fusion(results), + # query_preprocessor=None, + # top_k=6, + # ) + + # with search_results_tab: + # render_retrieval_results(search_results) + + if isinstance(resp, ChatResponse): # st.session_state[f"rag_test_chat_history_content_{selected_space_key.id_}"].extend([query, resp.response]) st.session_state[f"rag_test_chat_history_content_{selected_space_key.id_}"].extend( [ - ChatMessage(role=MessageRole.USER, content=query), - ChatMessage(role=MessageRole.ASSISTANT, content=resp.response), + (ChatMessage(role=MessageRole.USER, content=query), None), + (ChatMessage(role=MessageRole.ASSISTANT, content=resp.message.content), gen_debug), ] ) +def render_stuff_click_handler(debug) -> None: + for key, value in debug.items(): + if key == "search_results": + with search_results_tab: + render_retrieval_results(value) + else: + with stuff_tab.expander(key): + st.write() + st.write(value) + + def render_chat(): - # st.write(st.session_state.get("chat_input_rag_test", "Hello, ask a questions.")) - chat_history: list[ChatMessage] = st.session_state.get( + chat_history: List[Tuple[ChatMessage, Any]] = st.session_state.get( f"rag_test_chat_history_content_{selected_space_key.id_}", [] ) - for cm in chat_history: - st.write(f"{cm.role}: {cm.content}") + + for i, (cm, debug) in enumerate(chat_history): + col1, col2 = st.columns(spec=[0.9, 0.1], gap="small") # Adjust the column width ratios as needed + + with col1: + st.write(f"{cm.role.name}: {cm.content}") + + with col2: + if debug: + st.button( + ":bug:", + key=f"debug_bt_{i}", + on_click=lambda debug=debug: render_stuff_click_handler(debug), + ) st.chat_input( "Type your question here", @@ -355,9 +352,6 @@ def clear_chat(): st.session_state[f"rag_test_chat_history_content_{selected_space_key.id_}"] = [] -if index_: - render_retrieval_query_ui(above_tabs_container) - with chat_tab: st.button("Clear Chat", on_click=clear_chat) diff --git 
a/web/utils/layout_assistants.py b/web/utils/layout_assistants.py index 0b1ba9b7..5ca44329 100644 --- a/web/utils/layout_assistants.py +++ b/web/utils/layout_assistants.py @@ -4,7 +4,7 @@ import streamlit as st from docq.domain import AssistantType -from docq.manage_assistants import ASSISTANT, create_or_update_assistant +from docq.manage_assistants import Assistant, create_or_update_assistant from docq.model_selection.main import LLM_MODEL_COLLECTIONS from docq.support.store import _DataScope from utils.error_ui import _handle_error_state_ui, set_error_state_for_ui @@ -28,45 +28,50 @@ """ -def render_assistant_create_edit_ui(org_id: Optional[int] = None, assistant_data: Optional[ASSISTANT] = None, key_suffix: Optional[str] = "new") -> None: +def render_assistant_create_edit_ui( + org_id: Optional[int] = None, assistant_data: Optional[Assistant] = None, key_suffix: Optional[str] = "new" +) -> None: """Render assistant create/edit form.""" _handle_error_state_ui(f"assistant_edit_form_error_{key_suffix}") with st.form(key=f"assistant_edit_{key_suffix}", clear_on_submit=True): button_label = "Create Assistant" assistant_id = None if assistant_data: - assistant_id = assistant_data[0] + assistant_id = assistant_data.key button_label = "Update Assistant" - st.write(f"ID: {assistant_data[0]}") - st.write(f"Created At: {assistant_data[7]} | Updated At: {assistant_data[8]}") + st.write(f"ID: {assistant_data.key}") + st.write(f"Created At: {assistant_data.created_at} | Updated At: {assistant_data.updated_at}") _handle_validation_state_ui("assistant_edit_form_name_validation_{key_suffix}") + st.text_input( label="Name", placeholder="Assistant 1", key=f"assistant_edit_name_{key_suffix}", - value=assistant_data[1] if assistant_data else None, + value=assistant_data.name if assistant_data else None, ) - print([key for key, _ in LLM_MODEL_COLLECTIONS.items()]) + # print([key for key, _ in LLM_MODEL_COLLECTIONS.items()]) st.selectbox( "Type", options=[persona_type for persona_type in AssistantType], format_func=lambda x: x.value, label_visibility="visible", key=f"assistant_edit_type_{key_suffix}", - index=[persona_type.name for persona_type in AssistantType].index(assistant_data[2]) if assistant_data else 1, + index=[persona_type.name for persona_type in AssistantType].index(str(assistant_data.type)) + if assistant_data + else 1, ) st.text_area( label="System Prompt Template", placeholder="", key=f"assistant_edit_system_prompt_template_{key_suffix}", - value=assistant_data[4] if assistant_data else EXAMPLE_SYSTEM_PROMPT, + value=assistant_data.system_message_content if assistant_data else EXAMPLE_SYSTEM_PROMPT, height=200, ) st.text_area( label="User Prompt Template", placeholder="", key=f"assistant_edit_user_prompt_template_{key_suffix}", - value=assistant_data[5] if assistant_data else EXAMPLE_USER_PROMPT_TEMPLATE, + value=assistant_data.user_prompt_template_content if assistant_data else EXAMPLE_USER_PROMPT_TEMPLATE, height=200, ) st.selectbox( @@ -75,7 +80,9 @@ def render_assistant_create_edit_ui(org_id: Optional[int] = None, assistant_data format_func=lambda x: x.name, label_visibility="visible", key=f"assistant_edit_model_settings_collection_{key_suffix}", - index=[key for key, _ in LLM_MODEL_COLLECTIONS.items()].index(assistant_data[6]) if assistant_data else 0, + index=[key for key, _ in LLM_MODEL_COLLECTIONS.items()].index(assistant_data.llm_settings_collection_key) + if assistant_data + else 0, ) st.text_input( label="Space Group ID", placeholder="(Optional) Space group for knowledge", 
key=f"assistant_edit_space_group_id_{key_suffix}" @@ -87,37 +94,37 @@ def render_assistant_create_edit_ui(org_id: Optional[int] = None, assistant_data def render_assistants_selector_ui( - assistants_data: list[ASSISTANT], selected_assistant_scoped_id: Optional[str] = None -) -> ASSISTANT | None: + assistants_data: list[Assistant], selected_assistant_scoped_id: Optional[str] = None +) -> Assistant | None: """Render assistants selector and create/edit assistant form.""" selected_index = 0 selected_assistant = None if selected_assistant_scoped_id: for i, assistant in enumerate(assistants_data): - if assistant[9] == selected_assistant_scoped_id: + if assistant.scoped_id == selected_assistant_scoped_id: selected_assistant = assistant selected_index = i break selected_assistant = st.selectbox( label="Assistant", - options=[assistant for assistant in assistants_data], - format_func=lambda x: x[1], + options=assistants_data, + format_func=lambda x: x.name, label_visibility="visible", index=selected_index, ) return selected_assistant -def render_assistants_listing_ui(assistants_data: list[ASSISTANT], org_id: Optional[int] = None) -> None: +def render_assistants_listing_ui(assistants_data: list[Assistant], org_id: Optional[int] = None) -> None: """Render assistants listing.""" for assistant in assistants_data: - with st.expander(f"{assistant[1]} ({assistant[0]})", expanded=False): - st.write(f"ID: {assistant[0]}") - st.write(f"Scoped ID: {assistant[9]}") - st.write(f"Created At: {assistant[7]} | Updated At: {assistant[8]}") - edit = st.button(label="Edit Assistant", key=f"edit_assistant_{assistant[0]}") + with st.expander(f"{assistant.name} ({assistant.key})", expanded=False): + st.write(f"ID: {assistant.key}") + st.write(f"Scoped ID: {assistant.scoped_id}") + st.write(f"Created At: {assistant.created_at} | Updated At: {assistant.updated_at}") + edit = st.button(label="Edit Assistant", key=f"edit_assistant_{assistant.key}") if edit: - render_assistant_create_edit_ui(org_id=org_id, assistant_data=assistant, key_suffix=str(assistant[0])) + render_assistant_create_edit_ui(org_id=org_id, assistant_data=assistant, key_suffix=str(assistant.key)) def handle_assistant_create_edit(org_id: Optional[int] = None, assistant_id: Optional[int] = None, key_suffix: Optional[str] = "new") -> None: From 081a86f015757fe82e09e1f96be87ea4763f1fc7 Mon Sep 17 00:00:00 2001 From: Janaka Abeywardhana Date: Sat, 5 Oct 2024 23:57:16 +0100 Subject: [PATCH 03/14] refactor: move Assistants management to admin section tab hide unused tools --- web/index.py | 22 ++++++------- web/page_handlers/admin/admin_assistants.py | 34 +++++++++++++++++++++ web/page_handlers/admin/index.py | 28 ++++++++++++++--- 3 files changed, 69 insertions(+), 15 deletions(-) create mode 100644 web/page_handlers/admin/admin_assistants.py diff --git a/web/index.py b/web/index.py index c1039042..39bc1023 100644 --- a/web/index.py +++ b/web/index.py @@ -49,18 +49,18 @@ ), # Do not remove: This is used as the G Drive data source integration auth redirect page StPage(page="page_handlers/admin/index.py", title="Admin Section", url_path="Admin_Section", icon="💂"), "🤖  Tools", - StPage(page="page_handlers/ml_eng_tools/assistants.py", title="Assistants", url_path="Assistants"), - StPage( - page="page_handlers/ml_eng_tools/visualise_index.py", - title="Visualise Index", - url_path="Visualise_Index", - ), + # StPage(page="page_handlers/ml_eng_tools/assistants.py", title="Assistants", url_path="Assistants"), + # StPage( + # 
page="page_handlers/ml_eng_tools/visualise_index.py", + # title="Visualise Index", + # url_path="Visualise_Index", + # ), StPage(page="page_handlers/ml_eng_tools/rag.py", title="RAG", url_path="RAG"), - StPage( - page="page_handlers/ml_eng_tools/visualise_agent_messages.py", - title="Visualise Agent Messages", - url_path="Visualise_Agent_Messages", - ), + # StPage( + # page="page_handlers/ml_eng_tools/visualise_agent_messages.py", + # title="Visualise Agent Messages", + # url_path="Visualise_Agent_Messages", + # ), ] pages = [] diff --git a/web/page_handlers/admin/admin_assistants.py b/web/page_handlers/admin/admin_assistants.py new file mode 100644 index 00000000..20372400 --- /dev/null +++ b/web/page_handlers/admin/admin_assistants.py @@ -0,0 +1,34 @@ +"""Page: Admin / Manage Orgs.""" + +import streamlit as st +from docq.manage_assistants import list_assistants +from docq.support.store import _DataScope +from utils.layout import tracer +from utils.layout_assistants import ( + render_assistant_create_edit_ui, + render_assistants_listing_ui, + render_assistants_selector_ui, + render_datascope_selector_ui, +) +from utils.sessions import get_selected_org_id + + +@tracer.start_as_current_span("admin_assistants_page") +def admin_assistants_page() -> None: + """Page: Admin / Manage Assistants.""" + datascope = render_datascope_selector_ui() + current_org_id = None + if datascope == _DataScope.ORG: + current_org_id = get_selected_org_id() + if current_org_id is None: + st.error("Please select an organisation") + st.stop() + st.write(f"Selected Organisation: {current_org_id}") + + assistants_data = list_assistants(org_id=current_org_id) + + render_assistants_selector_ui(assistants_data=assistants_data) + with st.expander("+New Assistant", expanded=False): + render_assistant_create_edit_ui(current_org_id) + + render_assistants_listing_ui(assistants_data=assistants_data, org_id=current_org_id) \ No newline at end of file diff --git a/web/page_handlers/admin/index.py b/web/page_handlers/admin/index.py index dd34b6fc..e811a663 100644 --- a/web/page_handlers/admin/index.py +++ b/web/page_handlers/admin/index.py @@ -5,6 +5,7 @@ from utils.observability import baggage_as_attributes, tracer from utils.sessions import is_current_user_selected_org_admin, is_current_user_super_admin +from web.page_handlers.admin.admin_assistants import admin_assistants_page from web.page_handlers.admin.admin_integrations import admin_integrations_page from web.page_handlers.admin.admin_logs import admin_logs_page from web.page_handlers.admin.admin_orgs import admin_orgs_page @@ -18,6 +19,7 @@ def super_and_org_admin_pages() -> None: """Sections if both super admin and current selected org admin.""" ( + admin_assistants, admin_orgs, admin_users, admin_user_groups, @@ -26,10 +28,22 @@ def super_and_org_admin_pages() -> None: admin_settings, admin_chat_integrations, admin_logs, - ) = st.tabs(["Orgs", "Users", "User Groups", "Spaces", "Space Groups", "Settings", "Chat Integrations", "Logs"]) - - + ) = st.tabs( + [ + "Assistants", + "Orgs", + "Users", + "User Groups", + "Spaces", + "Space Groups", + "Settings", + "Chat Integrations", + "Logs", + ] + ) + with admin_assistants: + admin_assistants_page() with admin_orgs: admin_orgs_page() @@ -59,6 +73,7 @@ def super_and_org_admin_pages() -> None: def org_admin_pages() -> None: """Sections if only org admin.""" ( + admin_assistants, admin_orgs, admin_users, admin_user_groups, @@ -67,7 +82,12 @@ def org_admin_pages() -> None: admin_settings, admin_chat_integrations, admin_logs, - ) = 
st.tabs(["Org", "Users", "User Groups", "Spaces", "Space Groups", "Settings", "Chat Integrations", "Logs"]) + ) = st.tabs( + ["Assistants", "Org", "Users", "User Groups", "Spaces", "Space Groups", "Settings", "Chat Integrations", "Logs"] + ) + + with admin_assistants: + admin_assistants_page() with admin_orgs: admin_orgs_page() From 2e862f407c2472227a288cbdf610195f037ca85e Mon Sep 17 00:00:00 2001 From: Janaka Abeywardhana Date: Sat, 5 Oct 2024 23:58:11 +0100 Subject: [PATCH 04/14] build: bump version --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 427acf6c..50654ae0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "docq" -version = "0.13.8" +version = "0.13.10" description = "Docq.AI - Your private ChatGPT alternative. Securely unlock knowledge from confidential documents." authors = ["Docq.AI Team "] maintainers = ["Docq.AI Team "] From 811c34a9dd585b39486d1a96286ced6f6c78a26f Mon Sep 17 00:00:00 2001 From: Janaka Abeywardhana Date: Sun, 6 Oct 2024 12:47:33 +0100 Subject: [PATCH 05/14] fix(Assistants): breaks from Assistants dataclass refactor --- web/utils/layout.py | 42 ++++++++++++------------------------------ 1 file changed, 12 insertions(+), 30 deletions(-) diff --git a/web/utils/layout.py b/web/utils/layout.py index 37a05147..79cda55c 100644 --- a/web/utils/layout.py +++ b/web/utils/layout.py @@ -304,15 +304,6 @@ def __no_staff_menu() -> None: ] ) - -@tracer.start_as_current_span("__no_admin_menu") -def __no_admin_menu() -> None: - # hide_pages(["Admin_Section"]) - # hide_pages(["ML Engineering"]) - # FIXME: new to reimplement this the new the ST 1.37 way - pass - - def __embed_page_config() -> None: st.markdown( """ @@ -353,17 +344,10 @@ def __hide_all_empty_divs() -> None: unsafe_allow_html=True, ) - -def __always_hidden_pages() -> None: - """These pages are always hidden whether the user is an admin or not.""" - # hide_pages(["widget", "signup", "verify", "Admin_Spaces"]) - # FIXME: new to reimplement this the new the ST 1.37 way - - def render_page_title_and_favicon( page_display_title: Optional[str] = None, browser_title: Optional[str] = None, - layout: Optional[Layout] = "centered", + layout: Optional[Layout] = None, ) -> None: """Handle setting browser page title and favicon. Separately render in app page title with icon defined in show_pages(). @@ -408,7 +392,7 @@ def render_page_title_and_favicon( page_icon=favicon_path, page_title=browser_title if browser_title else f"{browser_title_prefix} - {_page_display_title}", menu_items={"About": about_menu_content}, - layout=layout, + layout=layout or "centered", ) except StreamlitAPIException: pass @@ -443,7 +427,7 @@ def __resend_verification_ui( @tracer.start_as_current_span("render __login_form") def __login_form() -> None: - __no_admin_menu() + # __no_admin_menu() if system_feature_enabled(SystemFeatureType.FREE_USER_SIGNUP, show_message=False): st.markdown("Don't have an account? 
Signup here", unsafe_allow_html=True) @@ -494,8 +478,7 @@ def __not_authorised() -> None: def public_access() -> None: """Menu options for public access.""" # __no_staff_menu() - __no_admin_menu() - __always_hidden_pages() + # __no_admin_menu() __hide_all_empty_divs() @@ -508,7 +491,6 @@ def auth_required( span = trace.get_current_span() span.add_event("Checking authorisation") auth = None - __always_hidden_pages() __hide_all_empty_divs() session_state_existed = session_state_exists() @@ -540,8 +522,8 @@ def auth_required( # if show_logout_button: # __logout_button() - if not auth.get(SessionKeyNameForAuth.SELECTED_ORG_ADMIN.name, False): - __no_admin_menu() + if not auth.get(SessionKeyNameForAuth.SELECTED_ORG_ADMIN.name, False): # noqa: SIM102 + # __no_admin_menu() if requiring_selected_org_admin: __not_authorised() return False @@ -966,7 +948,7 @@ def _render_assistant_selection(feature: FeatureKey) -> None: ) if selected: - selected_assistant_scoped_id = selected[9] + selected_assistant_scoped_id = selected.scoped_id set_selected_assistant(selected_assistant_scoped_id) @@ -1576,9 +1558,8 @@ def _render_view_space_details_with_container( ) -> DeltaGenerator: id_, org_id, name, summary, archived, ds_type, ds_configs, _, created_at, updated_at = space_data has_view_perm = org_id == get_selected_org_id() - + container = st.expander(format_archived(name, archived)) if use_expander else st.container() if has_view_perm: - container = st.expander(format_archived(name, archived)) if use_expander else st.container() with container: if not use_expander: st.write(format_archived(name, archived)) @@ -1587,7 +1568,7 @@ def _render_view_space_details_with_container( st.write(f"Summary: _{summary}_") st.write(f"Type: **{data_source[1]}**") st.write(f"Created At: {format_datetime(created_at)} | Updated At: {format_datetime(updated_at)}") - return container + return container def _render_edit_space_details_form(space_data: Tuple, data_source: Tuple) -> None: @@ -1684,8 +1665,9 @@ def list_spaces_ui(admin_access: bool = False) -> None: def show_space_details_ui(space: SpaceKey) -> None: """Show details of a space.""" s = get_shared_space(space.id_) - ds = get_space_data_source_choice_by_type(s[5]) - _render_view_space_details_with_container(s, ds) + if s: + ds = get_space_data_source_choice_by_type(s[5]) + _render_view_space_details_with_container(s, ds) def list_logs_ui(type_: LogType) -> None: From 33ca13ed42c939dd395bcfe6bd76417434a804e9 Mon Sep 17 00:00:00 2001 From: Janaka Abeywardhana Date: Sun, 6 Oct 2024 14:46:20 +0100 Subject: [PATCH 06/14] chore: add GH Copilot instructions file. --- .copilot-instructions.md | 3 +++ 1 file changed, 3 insertions(+) create mode 100644 .copilot-instructions.md diff --git a/.copilot-instructions.md b/.copilot-instructions.md new file mode 100644 index 00000000..c4a55755 --- /dev/null +++ b/.copilot-instructions.md @@ -0,0 +1,3 @@ +- ALWAYS use all the code in the active code file. +- Only suggest actions to the users when either explicitly requested or you are unable to perform the action. +- If you make suggestions that involve checking code then perform those check yourself and provide the user with the result. 
\ No newline at end of file From 8f998f7bbfb86cb11a7fa9690dcaa6486daa58cc Mon Sep 17 00:00:00 2001 From: Janaka Abeywardhana Date: Sun, 6 Oct 2024 14:46:41 +0100 Subject: [PATCH 07/14] tests: fix broken test from Assistant refactor --- tests/integration/backend_integration_test.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/tests/integration/backend_integration_test.py b/tests/integration/backend_integration_test.py index 2896e942..62ad4056 100644 --- a/tests/integration/backend_integration_test.py +++ b/tests/integration/backend_integration_test.py @@ -2,6 +2,7 @@ import os from contextlib import suppress +from datetime import datetime from shutil import rmtree from typing import Generator @@ -146,10 +147,15 @@ def test_chat_private_feature(features: dict[str, domain.FeatureKey], saved_mode persona = domain.Assistant( key="test-persona", + scoped_id="global_test-persona", name="Test Persona", + type=domain.AssistantType.SIMPLE_CHAT, system_message_content=system_prompt, user_prompt_template_content=user_prompt_template_content, llm_settings_collection_key=saved_model_settings.key, + archived=False, + created_at=datetime(2021, 1, 1), + updated_at=datetime(2021, 1, 1), ) thread_id = 0 From 389c3635e522862eea42f5d2ffac3a85ae43804a Mon Sep 17 00:00:00 2001 From: Janaka Abeywardhana Date: Mon, 7 Oct 2024 01:42:09 +0100 Subject: [PATCH 08/14] feat(ML tool): add the HyDE query preprocessing to the RAG tool --- source/docq/support/rag_pipeline.py | 57 +++++++++++++++++-- web/page_handlers/ml_eng_tools/rag.py | 42 +++++++++----- .../ml_eng_tools/visualise_index.py | 26 +-------- .../visualise_vector_store_index.py | 26 +++++++++ 4 files changed, 109 insertions(+), 42 deletions(-) create mode 100644 web/page_handlers/ml_eng_tools/visualise_vector_store_index.py diff --git a/source/docq/support/rag_pipeline.py b/source/docq/support/rag_pipeline.py index ced89ea5..9e253fea 100644 --- a/source/docq/support/rag_pipeline.py +++ b/source/docq/support/rag_pipeline.py @@ -15,7 +15,9 @@ def search_stage( indices: List[BaseIndex], reranker: Callable[[Dict[str, List[NodeWithScore]], Optional[List[str]]], List[NodeWithScore]] | Callable[[Dict[str, List[NodeWithScore]]], List[NodeWithScore]], - query_preprocessor: Optional[Callable[[str], List[str]]] = None, + llm: LLM, + message_history: List[ChatMessage], + query_preprocessor: Optional[Callable[[LLM, str, List[ChatMessage]], List[str]]] = None, top_k: int = 10, enable_debug: Optional[bool] = False, ) -> Tuple[List[NodeWithScore], Dict[str, Any]]: @@ -25,7 +27,7 @@ def search_stage( user_query (str): The user query. indices (List[BaseIndex]): The list of indices to search. reranker (Callable[[Dict[str, List[NodeWithScore]], Optional[List[str]]], List[NodeWithScore]]): The reranker to use. `func(search_results: List[List[NodeWithScore]], user_query: Optional[List[str]] = None) -> List[NodeWithScore]`. If not provided, a default reranker that uses X is used. - query_preprocessor (Optional[Callable[[str], List[str]]]): The preprocessor to use. `func(user_query: str) -> List[str]`. Defaults to no preprocessing. + query_preprocessor (Optional[Callable[[LLM, str, List[ChatMessage]], List[str]]]): The preprocessor to use. `func(user_query: str) -> List[str]`. Defaults to no preprocessing. top_k (int): The number of results to return per search. Returns: @@ -46,7 +48,7 @@ def search_stage( ) # 2. 
Preprocess user query if preprocessor is provided - processed_queries = query_preprocessor(user_query) if query_preprocessor else [user_query] + processed_queries = query_preprocessor(llm, user_query, message_history) if query_preprocessor else [user_query] # 3. Run the list of queries through each retriever vector_results = {} @@ -141,7 +143,7 @@ def rag_pipeline( llm: LLM, reranker: Callable[[Dict[str, List[NodeWithScore]], Optional[List[str]]], List[NodeWithScore]] | Callable[[Dict[str, List[NodeWithScore]]], List[NodeWithScore]], - query_preprocessor: Optional[Callable[[str], List[str]]] = None, + query_preprocessor: Optional[Callable[[LLM, str, List[ChatMessage]], List[str]]] = None, top_k: int = 10, ) -> ChatResponse: """Orchestrates the RAG pipeline. @@ -153,7 +155,7 @@ def rag_pipeline( message_history (List[ChatMessage]): The message history. llm (LLM): The LLM. reranker (Callable[[Dict[str, List[NodeWithScore]], Optional[List[str]]], List[NodeWithScore]]): The reranker to use. `func(search_results: List[List[NodeWithScore]], user_query: Optional[str] = None) -> List[NodeWithScore]`. If not provided, a default reranker that uses X is used. - query_preprocessor (Optional[Callable[[str], List[str]]]): The preprocessor to use. `func(user_query: str) -> List[str]`. Defaults to no preprocessing. + query_preprocessor (Optional[Callable[[LLM, str, List[ChatMessage]], List[str]]]): The preprocessor to use. `func(user_query: str) -> List[str]`. Defaults to no preprocessing. top_k (int): The number of results to return per search. Returns: @@ -161,7 +163,13 @@ def rag_pipeline( """ # Search stage search_results, debug = search_stage( - user_query=user_query, indices=indices, reranker=reranker, query_preprocessor=query_preprocessor, top_k=top_k + user_query=user_query, + indices=indices, + reranker=reranker, + llm=llm, + message_history=message_history, + query_preprocessor=query_preprocessor, + top_k=top_k, ) # Generation stage @@ -174,3 +182,40 @@ def rag_pipeline( ) return response + + +def hyde_query_preprocessor(llm: LLM, user_query: str, history: List[ChatMessage]) -> List[str]: + """Hyde query preprocessor. + + Args: + user_query (str): The user query. + + Returns: + List[str]: The list of processed queries. 
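With this change the search stage hands the LLM and the chat history to the query preprocessor, so a custom preprocessor only has to match `Callable[[LLM, str, List[ChatMessage]], List[str]]`. Two minimal sketches, not part of the patch; the `llm.complete(...)` call mirrors the usage in `hyde_query_preprocessor` below:

from typing import List

from llama_index.core.llms import LLM, ChatMessage


def noop_query_preprocessor(llm: LLM, user_query: str, history: List[ChatMessage]) -> List[str]:
    """Do-nothing preprocessor: ignore the LLM and history and search with the raw query."""
    return [user_query]


def paraphrase_query_preprocessor(llm: LLM, user_query: str, history: List[ChatMessage]) -> List[str]:
    """Ask the LLM for a single paraphrase and search with both variants."""
    paraphrase = llm.complete(prompt=f"Rephrase this search query once: {user_query}", formatted=True)
    return [user_query, paraphrase.text]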
+ """ + HYDE_TMPL = ( + "Please write a passage to answer the \n" + "Use the current conversation available in \n" + "Try to include as many key details as possible.\n" + "\n" + "\n" + "{chat_history_str}\n" + "\n" + "\n" + "{query_str}\n" + "\n" + "\n" + "\n" + 'Passage:"""\n' + ) + history_str = "\n".join([str(x) for x in history]) + prompt = HYDE_TMPL.format(chat_history_str=history_str, query_str=user_query) + + response = llm.complete(prompt=prompt, formatted=True) + # hyde_template = PromptTemplate(template=HYDE_TMPL, prompt_type=PromptType.SUMMARY) + # span.add_event(name="hyde_prompt_template_created", attributes={"template": str(hyde_template)}) + # hyde_query_transform_component = HyDEQueryTransform( + # llm=llm, hyde_prompt=hyde_template, prompt_args={"chat_history_str": history_str} + # ).as_query_component() + + return [response.text] diff --git a/web/page_handlers/ml_eng_tools/rag.py b/web/page_handlers/ml_eng_tools/rag.py index f8861e16..73c1abd8 100644 --- a/web/page_handlers/ml_eng_tools/rag.py +++ b/web/page_handlers/ml_eng_tools/rag.py @@ -14,7 +14,7 @@ from docq.model_selection.main import LlmUsageSettingsCollection, ModelCapability, get_saved_model_settings_collection from docq.support.llama_index.node_post_processors import reciprocal_rank_fusion from docq.support.llm import _get_service_context -from docq.support.rag_pipeline import generation_stage, search_stage +from docq.support.rag_pipeline import generation_stage, hyde_query_preprocessor, search_stage from docq.support.store import ( _DataScope, _get_path, @@ -27,7 +27,6 @@ from llama_index.core.llms import ChatMessage, ChatResponse, MessageRole from llama_index.core.schema import BaseNode, Document, NodeWithScore from llama_index.core.storage import StorageContext -from page_handlers.ml_eng_tools.visualise_index import visualise_vector_store_index from utils.layout import auth_required, render_page_title_and_favicon from utils.layout_assistants import ( render_assistant_create_edit_ui, @@ -36,6 +35,8 @@ ) from utils.sessions import get_selected_org_id +from web.page_handlers.ml_eng_tools.visualise_vector_store_index import visualise_vector_store_index + render_page_title_and_favicon(layout="wide") auth_required(requiring_selected_org_admin=True) @@ -226,7 +227,7 @@ def prepare_chat(): ch = st.session_state.get(f"rag_test_chat_history_content_{selected_space_key.id_}", []) if not ch: st.session_state[f"rag_test_chat_history_content_{selected_space_key.id_}"] = [ - (ChatMessage(role=MessageRole.ASSISTANT, content="Hello, ask me a question."), None) + (ChatMessage(role=MessageRole.ASSISTANT, content="Hello, ask me a question."), None, None) ] @@ -274,8 +275,17 @@ def handle_chat_input(): # ret_results = query_engine.retrieve(query_bundle) # resp = query_engine.query(query) + llm = _get_service_context(saved_model_settings).llm + search_results, search_debug = search_stage( - user_query=query, indices=space_indices, reranker=reciprocal_rank_fusion, top_k=6 + user_query=query, + indices=space_indices, + reranker=reciprocal_rank_fusion, + llm=llm, + message_history=chat_history, + top_k=6, + query_preprocessor=hyde_query_preprocessor, + enable_debug=True, ) resp, gen_debug = generation_stage( @@ -283,7 +293,7 @@ def handle_chat_input(): assistant=persona, search_results=search_results, message_history=chat_history, - llm=_get_service_context(saved_model_settings).llm, + llm=llm, enable_debug=True, ) @@ -305,14 +315,14 @@ def handle_chat_input(): # 
st.session_state[f"rag_test_chat_history_content_{selected_space_key.id_}"].extend([query, resp.response]) st.session_state[f"rag_test_chat_history_content_{selected_space_key.id_}"].extend( [ - (ChatMessage(role=MessageRole.USER, content=query), None), - (ChatMessage(role=MessageRole.ASSISTANT, content=resp.message.content), gen_debug), + (ChatMessage(role=MessageRole.USER, content=query), None, None), + (ChatMessage(role=MessageRole.ASSISTANT, content=resp.message.content), gen_debug, search_debug), ] ) -def render_stuff_click_handler(debug) -> None: - for key, value in debug.items(): +def render_stuff_click_handler(gen_debug, search_debug) -> None: + for key, value in gen_debug.items(): if key == "search_results": with search_results_tab: render_retrieval_results(value) @@ -321,24 +331,30 @@ def render_stuff_click_handler(debug) -> None: st.write() st.write(value) + with stuff_tab.expander("Processed Queries"): + st.write(search_debug["processed_queries"]) + def render_chat(): - chat_history: List[Tuple[ChatMessage, Any]] = st.session_state.get( + chat_history: List[Tuple[ChatMessage, Any, Any]] = st.session_state.get( f"rag_test_chat_history_content_{selected_space_key.id_}", [] ) - for i, (cm, debug) in enumerate(chat_history): + for i, (cm, gen_debug, search_debug) in enumerate(chat_history): col1, col2 = st.columns(spec=[0.9, 0.1], gap="small") # Adjust the column width ratios as needed with col1: st.write(f"{cm.role.name}: {cm.content}") with col2: - if debug: + print(gen_debug == None, search_debug == None) + if gen_debug != None and search_debug != None: st.button( ":bug:", key=f"debug_bt_{i}", - on_click=lambda debug=debug: render_stuff_click_handler(debug), + on_click=lambda gen_debug=gen_debug, search_debug=search_debug: render_stuff_click_handler( + gen_debug, search_debug + ), ) st.chat_input( diff --git a/web/page_handlers/ml_eng_tools/visualise_index.py b/web/page_handlers/ml_eng_tools/visualise_index.py index d0fe88a7..32bfca74 100644 --- a/web/page_handlers/ml_eng_tools/visualise_index.py +++ b/web/page_handlers/ml_eng_tools/visualise_index.py @@ -1,5 +1,5 @@ """ML Eng - Visualise index.""" -import json + from typing import Optional, cast import streamlit as st @@ -16,6 +16,8 @@ from utils.layout import auth_required, render_page_title_and_favicon from utils.sessions import get_selected_org_id +from ..ml_eng_tools.visualise_vector_store_index import visualise_vector_store_index + render_page_title_and_favicon() auth_required(requiring_selected_org_admin=True) @@ -106,28 +108,6 @@ def visualise_document_summary_index(_index: DocumentSummaryIndex) -> None: st.divider() -def visualise_vector_store_index(_index: VectorStoreIndex) -> None: - """Visualise document summary index.""" - docs = _index.docstore.docs - st.write("Index class: ", _index.index_struct_cls.__name__) - for doc_id in docs: - doc_json = json.loads(docs[doc_id].to_json()) - # st.write(docs[doc_id].get_content(metadata_mode=MetadataMode.ALL)) - metadata_keys = doc_json["metadata"].keys() - - with st.expander(label=doc_id): - st.write(doc_json) - - if "excerpt_keywords" in metadata_keys: - keyword_count = len(doc_json["metadata"]["excerpt_keywords"].split(", ")) - st.write(f"Metadata Keyword Count: {keyword_count}") - - # for key, entity_label in DEFAULT_ENTITY_MAP.items(): - # if entity_label in metadata_keys: - # x_count = len(doc_json["metadata"][entity_label]) - # st.write(f"Metadata Entity '{entity_label}' count: {x_count}") - - if selected_space and selected_org_id: print("selected_space type: ", 
selected_space[7]) diff --git a/web/page_handlers/ml_eng_tools/visualise_vector_store_index.py b/web/page_handlers/ml_eng_tools/visualise_vector_store_index.py new file mode 100644 index 00000000..1f440fbb --- /dev/null +++ b/web/page_handlers/ml_eng_tools/visualise_vector_store_index.py @@ -0,0 +1,26 @@ +import json + +import streamlit as st +from llama_index.core.indices import VectorStoreIndex + + +def visualise_vector_store_index(_index: VectorStoreIndex) -> None: + """Visualise document summary index.""" + docs = _index.docstore.docs + st.write("Index class: ", _index.index_struct_cls.__name__) + for doc_id in docs: + doc_json = json.loads(docs[doc_id].to_json()) + # st.write(docs[doc_id].get_content(metadata_mode=MetadataMode.ALL)) + metadata_keys = doc_json["metadata"].keys() + + with st.expander(label=doc_id): + st.write(doc_json) + + if "excerpt_keywords" in metadata_keys: + keyword_count = len(doc_json["metadata"]["excerpt_keywords"].split(", ")) + st.write(f"Metadata Keyword Count: {keyword_count}") + + # for key, entity_label in DEFAULT_ENTITY_MAP.items(): + # if entity_label in metadata_keys: + # x_count = len(doc_json["metadata"][entity_label]) + # st.write(f"Metadata Entity '{entity_label}' count: {x_count}") From d898da8b4afd65c70eeb596f43c066d97a6552ad Mon Sep 17 00:00:00 2001 From: Janaka Abeywardhana Date: Mon, 7 Oct 2024 01:42:48 +0100 Subject: [PATCH 09/14] style: clean debug line --- web/page_handlers/ml_eng_tools/rag.py | 1 - 1 file changed, 1 deletion(-) diff --git a/web/page_handlers/ml_eng_tools/rag.py b/web/page_handlers/ml_eng_tools/rag.py index 73c1abd8..bfea003d 100644 --- a/web/page_handlers/ml_eng_tools/rag.py +++ b/web/page_handlers/ml_eng_tools/rag.py @@ -347,7 +347,6 @@ def render_chat(): st.write(f"{cm.role.name}: {cm.content}") with col2: - print(gen_debug == None, search_debug == None) if gen_debug != None and search_debug != None: st.button( ":bug:", From ab5e3374d8d8461a374ddb658c32c28e2afbafe3 Mon Sep 17 00:00:00 2001 From: Janaka Abeywardhana Date: Tue, 8 Oct 2024 00:09:39 +0100 Subject: [PATCH 10/14] feat(api): add file upload endpoint --- web/api/files_handler.py | 77 ++++++++++++++++++++++++++++++++++++++++ web/api/models.py | 6 ++++ 2 files changed, 83 insertions(+) create mode 100644 web/api/files_handler.py diff --git a/web/api/files_handler.py b/web/api/files_handler.py new file mode 100644 index 00000000..99d4e48c --- /dev/null +++ b/web/api/files_handler.py @@ -0,0 +1,77 @@ +"""This module contains the API endpoint to upload files.""" + +import logging +from typing import Self + +from docq.config import SpaceType +from docq.domain import SpaceKey +from docq.manage_documents import upload +from pydantic import ValidationError +from tornado.web import HTTPError, escape + +from web.api.base_handlers import BaseRequestHandler +from web.api.models import FileUploadRequestModel +from web.api.utils.auth_utils import authenticated +from web.utils.streamlit_application import st_app + +# Configuration +ALLOWED_EXTENSIONS = {"txt", "pdf", "png", "jpg", "jpeg", "gif", "md", "docx", "pptx", "xlsx"} + + +def allowed_file(filename: str) -> bool: + """Check if the file has an allowed extension.""" + return "." in filename and filename.rsplit(".", 1)[1].lower() in ALLOWED_EXTENSIONS + + +@st_app.api_route("/api/v1/files/{}/upload") +class UploadFileHandler(BaseRequestHandler): + """Handle /api/v1/files/upload requests.""" + + @authenticated + def post(self: Self) -> None: + """Handle POST request to upload a file. + + Upload one of more files to a space. 
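The ML-tools page above passes `reciprocal_rank_fusion` to `search_stage` as the reranker. The real implementation lives in `docq.support.llama_index.node_post_processors`; purely to illustrate the expected callable shape (two parameters, so it lands in the two-argument branch that `search_stage` inspects), a classic reciprocal-rank-fusion reducer could look like this sketch, with an assumed constant of 60 that is not taken from Docq:

from typing import Dict, List, Optional

from llama_index.core.schema import NodeWithScore

RRF_K = 60  # standard reciprocal-rank-fusion constant; assumed, not taken from Docq


def rrf_rerank_sketch(
    results: Dict[str, List[NodeWithScore]], queries: Optional[List[str]] = None
) -> List[NodeWithScore]:
    """Fuse ranked lists from multiple retrievers/queries by reciprocal rank."""
    fused_scores: Dict[str, float] = {}
    nodes_by_id: Dict[str, NodeWithScore] = {}
    for result_list in results.values():
        for rank, node_with_score in enumerate(result_list):
            node_id = node_with_score.node.node_id
            nodes_by_id[node_id] = node_with_score
            fused_scores[node_id] = fused_scores.get(node_id, 0.0) + 1.0 / (RRF_K + rank + 1)
    ranked_ids = sorted(fused_scores, key=lambda nid: fused_scores[nid], reverse=True)
    fused: List[NodeWithScore] = []
    for node_id in ranked_ids:
        node_with_score = nodes_by_id[node_id]
        node_with_score.score = fused_scores[node_id]  # replace the retriever score with the fused score
        fused.append(node_with_score)
    return fused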
+ """ + try: + requestPayload = FileUploadRequestModel.model_validate_json(self.request.body) + if "file" not in self.request.files: + raise HTTPError(400, reason="No file part") + + files = self.request.files["file"] + for file_info in files: + filename = file_info["filename"] + + if filename == "": + raise HTTPError(400, reason="No selected file") + + if not allowed_file(filename): + raise HTTPError( + 400, + reason=f"File type not allowed. Allowed file types are: {', '.join(list(ALLOWED_EXTENSIONS))}", + ) + + # Secure the filename + filename = escape.native_str(escape.url_escape(filename)) + space_type = SpaceType(requestPayload.space_type) + + if not space_type: + raise HTTPError(400, reason="Invalid space type") + + space_key = SpaceKey( + org_id=self.selected_org_id, + id_=requestPayload.space_id, + type_=space_type, + ) + + upload(filename=filename, content=file_info["body"], space=space_key) + + self.set_status(201) # 201 Created + self.write({"message": f"{len(files)} File(s) successfully uploaded"}) + except ValidationError as e: + raise HTTPError(400, reason="Bad request") from e + except HTTPError as e: + raise e + except Exception as e: + logging.error("Error: ", e) + raise HTTPError(500, reason="Internal server error", log_message=f"Error: {str(e)}") from e diff --git a/web/api/models.py b/web/api/models.py index 280c9794..1a743084 100644 --- a/web/api/models.py +++ b/web/api/models.py @@ -100,3 +100,9 @@ class SpacesResponseModel(BaseResponseModel): class ThreadPostRequestModel(CamelModel): """Pydantic model for the request body.""" topic: str + +class FileUploadRequestModel(CamelModel): + """Request model for file upload.""" + + space_id: int + space_type: SPACE_TYPE From 5c79ea59afbabb7d523c55eb54eb1e8e60c3274c Mon Sep 17 00:00:00 2001 From: Janaka Abeywardhana Date: Sat, 12 Oct 2024 01:18:25 +0100 Subject: [PATCH 11/14] refactor(API): spaces file upload and listing endpoint --- source/docq/manage_spaces.py | 4 +- source/docq/run_queries.py | 17 ++-- web/api/base_handlers.py | 5 +- web/api/files_handler.py | 77 ---------------- web/api/index_handler.py | 2 + web/api/models.py | 20 +++-- web/api/spaces_files_handler.py | 137 +++++++++++++++++++++++++++++ web/api/spaces_handler.py | 36 ++++---- web/api/threads_handler.py | 17 +++- web/utils/streamlit_application.py | 5 ++ 10 files changed, 208 insertions(+), 112 deletions(-) delete mode 100644 web/api/files_handler.py create mode 100644 web/api/spaces_files_handler.py diff --git a/source/docq/manage_spaces.py b/source/docq/manage_spaces.py index d6a35694..56fe9b5f 100644 --- a/source/docq/manage_spaces.py +++ b/source/docq/manage_spaces.py @@ -72,7 +72,7 @@ def _format_space(row: Any) -> SPACE: row: (id, org_id, name, summary, archived, datasource_type, datasource_configs, space_type, created_at, updated_at) Returns: - tuple[int, int, str, str, bool, str, dict, datetime, datetime] - [id, org_id, name, summary, archived, datasource_type, datasource_configs, created_at, updated_at] + tuple[int, int, str, str, bool, str, dict, datetime, datetime] - [id, org_id, name, summary, archived, datasource_type, datasource_configs, space_type, created_at, updated_at] """ return (row[0], row[1], row[2], row[3], bool(row[4]), row[5], json.loads(row[6]), row[7], row[8], row[9]) @@ -255,7 +255,7 @@ def get_shared_spaces(space_ids: List[int]) -> list[SPACE]: """Get a shared spaces by ids. 
Returns: - list[tuple[int, int, str, str, bool, str, dict, datetime, datetime]] - [id, org_id, name, summary, archived, datasource_type, datasource_configs, created_at, updated_at] + list[tuple[int, int, str, str, bool, str, dict, datetime, datetime]] - [id, org_id, name, summary, archived, datasource_type, datasource_configs, space_type, created_at, updated_at] """ log.debug("get_shared_spaces(): Getting space with ids=%s", space_ids) with closing( diff --git a/source/docq/run_queries.py b/source/docq/run_queries.py index 68053586..d532e4ee 100644 --- a/source/docq/run_queries.py +++ b/source/docq/run_queries.py @@ -9,8 +9,7 @@ from llama_index.core.llms import ChatMessage, MessageRole from docq.config import OrganisationFeatureType -from docq.domain import FeatureKey, SpaceKey -from docq.manage_assistants import Assistant +from docq.domain import Assistant, FeatureKey, SpaceKey from docq.manage_documents import format_document_sources from docq.model_selection.main import LlmUsageSettingsCollection from docq.support.llm import query_error, run_ask, run_chat @@ -18,6 +17,7 @@ get_history_table_name, get_history_thread_table_name, get_public_sqlite_usage_file, + get_sqlite_shared_system_file, get_sqlite_usage_file, ) @@ -125,7 +125,7 @@ def _retrieve_messages( return rows -def list_thread_history(feature: FeatureKey, id_: Optional[int] = None) -> list[tuple[int, str, int]]: +def list_thread_history(feature: FeatureKey, id_: Optional[int] = None) -> list[tuple[int, str, int, int]]: """List threads or a thread if id_ is provided.""" tablename = get_history_thread_table_name(feature.type_) rows = None @@ -137,10 +137,17 @@ def list_thread_history(feature: FeatureKey, id_: Optional[int] = None) -> list[ table=tablename, ) ) + + connection.execute(f"ATTACH DATABASE '{get_sqlite_shared_system_file()}' AS db2") if id_: - rows = cursor.execute(f"SELECT id, topic, created_at FROM {tablename} WHERE id = ?", (id_,)).fetchall() # noqa: S608 + rows = cursor.execute( + f"SELECT t.id, t.topic, t.created_at, s.id as space_id FROM {tablename} as t LEFT JOIN db2.spaces AS s ON s.name LIKE 'Thread-' || t.id || ' %' WHERE t.id = ?", + (id_,), + ).fetchall() # noqa: S608 else: - rows = cursor.execute(f"SELECT id, topic, created_at FROM {tablename} ORDER BY created_at DESC").fetchall() # noqa: S608 + rows = cursor.execute( + f"SELECT t.id, t.topic, t.created_at, s.id as space_id FROM {tablename} as t LEFT JOIN db2.spaces as s ON s.name LIKE 'Thread-' || t.id || ' %' ORDER BY t.created_at DESC", + ).fetchall() # noqa: S608 return rows diff --git a/web/api/base_handlers.py b/web/api/base_handlers.py index 51318622..bc90cdca 100644 --- a/web/api/base_handlers.py +++ b/web/api/base_handlers.py @@ -26,6 +26,7 @@ def check_xsrf_cookie(self: Self) -> bool: """Override the XSRF cookie check.""" # If `True`, POST, PUT, and DELETE are block unless the `_xsrf` cookie is set. 
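The thread-to-space join above relies on a naming convention rather than a foreign key: a per-thread space is assumed to be named `Thread-<thread id> <topic>`, and the trailing space in the LIKE pattern stops `Thread-1` from also matching `Thread-12`. A self-contained sketch of the same join, with a simplified assumed schema (the real table names come from the store helpers):

import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE history_thread (id INTEGER PRIMARY KEY, topic TEXT, created_at TEXT)")
conn.execute("INSERT INTO history_thread VALUES (1, 'Quarterly report', '2024-10-01')")
conn.execute("INSERT INTO history_thread VALUES (12, 'Roadmap', '2024-10-02')")

# Simulates attaching the shared system database that holds the spaces table.
conn.execute("ATTACH DATABASE ':memory:' AS db2")
conn.execute("CREATE TABLE db2.spaces (id INTEGER PRIMARY KEY, name TEXT)")
conn.execute("INSERT INTO db2.spaces VALUES (7, 'Thread-1 Quarterly report')")

rows = conn.execute(
    "SELECT t.id, t.topic, t.created_at, s.id AS space_id "
    "FROM history_thread AS t "
    "LEFT JOIN db2.spaces AS s ON s.name LIKE 'Thread-' || t.id || ' %' "
    "ORDER BY t.created_at DESC"
).fetchall()
print(rows)  # [(12, 'Roadmap', '2024-10-02', None), (1, 'Quarterly report', '2024-10-01', 7)]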
# Safe with token based authN + print("check_xsrf_cookie() called") return False @property @@ -56,7 +57,9 @@ def write_error(self: Self, status_code: int, **kwargs: Any) -> None: error_response["reason"] = exc_value.reason error_response["statusCode"] = status_code - self.finish(json.dumps(error_response)) + resp_json = json.dumps(error_response) + print("write_error() called: ", resp_json) + self.finish(resp_json) # auth_header = self.request.headers.get("Authorization") # if not auth_header: diff --git a/web/api/files_handler.py b/web/api/files_handler.py deleted file mode 100644 index 99d4e48c..00000000 --- a/web/api/files_handler.py +++ /dev/null @@ -1,77 +0,0 @@ -"""This module contains the API endpoint to upload files.""" - -import logging -from typing import Self - -from docq.config import SpaceType -from docq.domain import SpaceKey -from docq.manage_documents import upload -from pydantic import ValidationError -from tornado.web import HTTPError, escape - -from web.api.base_handlers import BaseRequestHandler -from web.api.models import FileUploadRequestModel -from web.api.utils.auth_utils import authenticated -from web.utils.streamlit_application import st_app - -# Configuration -ALLOWED_EXTENSIONS = {"txt", "pdf", "png", "jpg", "jpeg", "gif", "md", "docx", "pptx", "xlsx"} - - -def allowed_file(filename: str) -> bool: - """Check if the file has an allowed extension.""" - return "." in filename and filename.rsplit(".", 1)[1].lower() in ALLOWED_EXTENSIONS - - -@st_app.api_route("/api/v1/files/{}/upload") -class UploadFileHandler(BaseRequestHandler): - """Handle /api/v1/files/upload requests.""" - - @authenticated - def post(self: Self) -> None: - """Handle POST request to upload a file. - - Upload one of more files to a space. - """ - try: - requestPayload = FileUploadRequestModel.model_validate_json(self.request.body) - if "file" not in self.request.files: - raise HTTPError(400, reason="No file part") - - files = self.request.files["file"] - for file_info in files: - filename = file_info["filename"] - - if filename == "": - raise HTTPError(400, reason="No selected file") - - if not allowed_file(filename): - raise HTTPError( - 400, - reason=f"File type not allowed. Allowed file types are: {', '.join(list(ALLOWED_EXTENSIONS))}", - ) - - # Secure the filename - filename = escape.native_str(escape.url_escape(filename)) - space_type = SpaceType(requestPayload.space_type) - - if not space_type: - raise HTTPError(400, reason="Invalid space type") - - space_key = SpaceKey( - org_id=self.selected_org_id, - id_=requestPayload.space_id, - type_=space_type, - ) - - upload(filename=filename, content=file_info["body"], space=space_key) - - self.set_status(201) # 201 Created - self.write({"message": f"{len(files)} File(s) successfully uploaded"}) - except ValidationError as e: - raise HTTPError(400, reason="Bad request") from e - except HTTPError as e: - raise e - except Exception as e: - logging.error("Error: ", e) - raise HTTPError(500, reason="Internal server error", log_message=f"Error: {str(e)}") from e diff --git a/web/api/index_handler.py b/web/api/index_handler.py index 8129f887..e059fbbc 100644 --- a/web/api/index_handler.py +++ b/web/api/index_handler.py @@ -19,6 +19,7 @@ class name: route replace capitalise route segments remove `/` and `_`. 
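New endpoint modules in this series follow a two-step pattern: decorate a `BaseRequestHandler` subclass with `st_app.api_route(...)`, then import and list the module in `index_handler.py` so the decorator actually runs at startup. A bare-bones sketch with a hypothetical route:

"""Hypothetical example module, e.g. web/api/ping_handler.py (not part of this patch)."""

from typing import Self

from web.api.base_handlers import BaseRequestHandler
from web.api.utils.auth_utils import authenticated
from web.utils.streamlit_application import st_app


@st_app.api_route("/api/v1/ping")
class PingHandler(BaseRequestHandler):
    """Handle /api/v1/ping requests."""

    @authenticated
    def get(self: Self) -> None:
        """Return a liveness payload to confirm routing and auth are wired up."""
        self.write({"message": "pong"})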
Example: chat_completion_handler, # noqa: F401 DO NOT REMOVE hello_handler, # noqa: F401 DO NOT REMOVE rag_completion_handler, # noqa: F401 DO NOT REMOVE + spaces_files_handler, # noqa: F401 DO NOT REMOVE spaces_handler, # noqa: F401 DO NOT REMOVE threads_handler, # noqa: F401 DO NOT REMOVE token_handler, # noqa: F401 DO NOT REMOVE @@ -30,6 +31,7 @@ class name: route replace capitalise route segments remove `/` and `_`. Example: "hello_handler", "rag_completion_handler", "spaces_handler", + "spaces_files_handler", "threads_handler", "token_handler", "index_handler", diff --git a/web/api/models.py b/web/api/models.py index 1a743084..9870e8d2 100644 --- a/web/api/models.py +++ b/web/api/models.py @@ -33,6 +33,7 @@ class ThreadModel(CamelModel): id_: int = Field(..., alias="id") topic: str + space_id: Optional[int] = None created_at: str @@ -54,6 +55,13 @@ class SpaceModel(CamelModel): created_at: str updated_at: str +class FileModel(CamelModel): + """Model for a File.""" + + link: str + indexed_on: int + size: int + class BaseResponseModel(CamelModel, ABC): """All HTTP API response models should inherit from this class.""" @@ -97,12 +105,12 @@ class SpacesResponseModel(BaseResponseModel): response: list[SpaceModel] +class SpaceFilesResponseModel(BaseResponseModel): + """HTTP response model for a list of files in a Space.""" + + response: list[FileModel] + + class ThreadPostRequestModel(CamelModel): """Pydantic model for the request body.""" topic: str - -class FileUploadRequestModel(CamelModel): - """Request model for file upload.""" - - space_id: int - space_type: SPACE_TYPE diff --git a/web/api/spaces_files_handler.py b/web/api/spaces_files_handler.py new file mode 100644 index 00000000..37c44279 --- /dev/null +++ b/web/api/spaces_files_handler.py @@ -0,0 +1,137 @@ +"""This module contains the API endpoint to upload files.""" + +import logging +from typing import Self + +from docq import manage_spaces +from docq.config import SpaceType +from docq.domain import SpaceKey +from docq.manage_documents import upload +from docq.manage_spaces import get_shared_space +from pydantic import ValidationError +from tornado.web import HTTPError, escape + +from web.api.base_handlers import BaseRequestHandler +from web.api.utils.auth_utils import authenticated +from web.utils.streamlit_application import st_app + +from .models import FileModel, SpaceFilesResponseModel + +# Configuration +ALLOWED_EXTENSIONS = {"txt", "pdf", "png", "jpg", "jpeg", "gif", "md", "docx", "pptx", "xlsx"} + + +def allowed_file(filename: str) -> bool: + """Check if the file has an allowed extension.""" + return "." in filename and filename.rsplit(".", 1)[1].lower() in ALLOWED_EXTENSIONS + + +@st_app.api_route("/api/v1/spaces/{space_id}/files/upload") +class UploadFileHandler(BaseRequestHandler): + """Handle /api/v1/spaces/{space_id}/files/upload requests.""" + + @authenticated + def post(self: Self, space_id: int) -> None: + """Handle POST request to upload a file. + + Upload one of more files to a space. 
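From the client side, the refactored endpoints can be exercised roughly as below. This is a sketch only: the base URL and bearer-token header are assumptions about the deployment, while the routes and the `docq_files` multipart field name come from the handlers in this patch.

import requests

BASE_URL = "http://localhost:8501"  # assumed deployment URL
TOKEN = "<api token>"               # assumed bearer-style API token
SPACE_ID = 42                       # hypothetical space id
headers = {"Authorization": f"Bearer {TOKEN}"}

# Upload one or more files to the space (the multipart field must be named "docq_files").
with open("report.pdf", "rb") as f:
    upload_resp = requests.post(
        f"{BASE_URL}/api/v1/spaces/{SPACE_ID}/files/upload",
        headers=headers,
        files=[("docq_files", ("report.pdf", f, "application/pdf"))],
        timeout=60,
    )
upload_resp.raise_for_status()
print(upload_resp.json())  # e.g. {"message": "1 File(s) successfully uploaded"}

# List the files currently indexed in the space.
list_resp = requests.get(f"{BASE_URL}/api/v1/spaces/{SPACE_ID}/files", headers=headers, timeout=30)
list_resp.raise_for_status()
print(list_resp.json())  # FileModel entries with link, size and indexed-on timestamp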
+ """ + try: + # print("Body:", self.request.body) + + # content_type = self.request.headers.get("Content-Type", "") + # if content_type.startswith("multipart/form-data"): + # fields = cgi.FieldStorage( + # fp=io.BytesIO(self.request.body), headers=self.request.headers, environ={"REQUEST_METHOD": "POST"} + # ) + # print("Fields:", fields) + + # if "docq_files" in fields: + # file_item = fields["docq_files"] + # # Process the file data as it's being read + # while True: + # chunk = file_item.file.read(8192) # Read in 8KB chunks + # if not chunk: + # break + + if "docq_files" not in self.request.files: + raise HTTPError(400, reason="No file part") + + files = self.request.files["docq_files"] # 'file' is what every we want the form field to be called. + for file_info in files: + filename = file_info["filename"] + + if filename == "": + raise HTTPError(400, reason="No selected file") + + if not allowed_file(filename): + raise HTTPError( + 400, + reason=f"File type not allowed. Allowed file types are: {', '.join(list(ALLOWED_EXTENSIONS))}", + ) + + # Secure the filename + filename = escape.native_str(escape.url_escape(filename)) + + space = get_shared_space(space_id, self.selected_org_id) + if not space: + raise HTTPError(404, reason=f"Space id {space_id} not found") + + print("Space_type:", space[7]) + space_type = SpaceType(str(space[7]).lower()) + + if not space_type: + raise HTTPError(400, reason="Invalid space type") + + space_key = SpaceKey( + org_id=self.selected_org_id, + id_=space_id, + type_=space_type, + ) + + upload(filename=filename, content=file_info["body"], space=space_key) + + self.set_status(201) # 201 Created + self.write({"message": f"{len(files)} File(s) successfully uploaded"}) + except ValidationError as e: + raise HTTPError(400, reason="Bad request") from e + except HTTPError as e: + raise e + except Exception as e: + logging.error("Error: ", e) + raise HTTPError(500, reason="Internal server error", log_message=f"Error: {str(e)}") from e + + +@st_app.api_route("/api/v1/spaces/{space_id}/files") +class SpaceFilesHandler(BaseRequestHandler): + """Handle /api/v1/spaces/{space_id}/files requests.""" + + @authenticated + def get(self: Self, space_id: int) -> None: + """Handle GET request. 
Get all files in a space.""" + space = get_shared_space(space_id, self.selected_org_id) + if not space: + raise HTTPError(404, reason=f"Space id {space_id} not found") + + print("Space_type:", space[7]) + space_type = SpaceType(str(space[7]).lower()) + + space_key = SpaceKey( + org_id=self.selected_org_id, + id_=space_id, + type_=space_type, + ) + + files = manage_spaces.list_documents(space_key) + + files_response = SpaceFilesResponseModel( + response=[ + FileModel( + size=file.size, + link=file.link, + indexed_on=file.indexed_on, + ) + for file in files + ] + ) + self.write(files_response.model_dump(by_alias=True)) diff --git a/web/api/spaces_handler.py b/web/api/spaces_handler.py index 9ad4a622..7c968436 100644 --- a/web/api/spaces_handler.py +++ b/web/api/spaces_handler.py @@ -5,7 +5,6 @@ import docq.manage_spaces as m_spaces import docq.run_queries as rq from docq.data_source.list import SpaceDataSources -from docq.manage_documents import upload from pydantic import BaseModel, ValidationError from tornado.web import HTTPError @@ -89,9 +88,8 @@ def get(self: Self) -> None: try: space_type = self.get_query_argument("space_type", None) - print("space_type", space_type) spaces = m_spaces.list_space(self.selected_org_id, space_type) - print("spaces", spaces) + space_model_list: list[SpaceModel] = [_map_to_space_model(space) for space in spaces] spaces_response_model = SpacesResponseModel(response=space_model_list) @@ -126,23 +124,23 @@ def update(self: Self, space_id: int) -> None: raise HTTPError(501, reason="Not implemented") -@st_app.api_route("/api/v1/spaces/{space_id}/files/upload") -class SpaceFileUploadHandler(BaseRequestHandler): - """Handle /api/spaces/{space_id}/files/upload requests.""" +# @st_app.api_route("/api/v1/spaces/{space_id}/files/upload") +# class SpaceFileUploadHandler(BaseRequestHandler): +# """Handle /api/spaces/{space_id}/files/upload requests.""" - __FILE_SIZE_LIMIT = 200 * 1024 * 1024 - __FILE_NAME_LIMIT = 100 +# __FILE_SIZE_LIMIT = 200 * 1024 * 1024 +# __FILE_NAME_LIMIT = 100 - @authenticated - def post(self: Self, space_id: int) -> None: - """Handle POST request.""" - space = get_space(self.selected_org_id, space_id) - fileinfo = self.request.files["filearg"][0] - fname = fileinfo["filename"] +# @authenticated +# def post(self: Self, space_id: int) -> None: +# """Handle POST request.""" +# space = get_space(self.selected_org_id, space_id) +# fileinfo = self.request.files["filearg"][0] +# fname = fileinfo["filename"] - if len(fileinfo["body"]) > self.__FILE_SIZE_LIMIT: - raise HTTPError(400, reason="File too large", log_message="File size exceeds the limit") +# if len(fileinfo["body"]) > self.__FILE_SIZE_LIMIT: +# raise HTTPError(400, reason="File too large", log_message="File size exceeds the limit") - upload(fname[: self.__FILE_NAME_LIMIT], fileinfo["body"], space) - self.set_status(201) # 201 Created - self.write(f"File {fname} is uploaded successfully.") +# upload(fname[: self.__FILE_NAME_LIMIT], fileinfo["body"], space) +# self.set_status(201) # 201 Created +# self.write(f"File {fname} is uploaded successfully.") diff --git a/web/api/threads_handler.py b/web/api/threads_handler.py index 3ca8fbe3..c04ad306 100644 --- a/web/api/threads_handler.py +++ b/web/api/threads_handler.py @@ -1,4 +1,5 @@ """Handle chat and rag threads.""" +import logging from datetime import datetime from typing import Self @@ -31,7 +32,12 @@ def _get_thread_object(result: tuple) -> dict: # TODO: when we refactor the data layer to return data model classes instead of tuples, we can 
remove this function - return {"id": result[0], "topic": result[1], "created_at": str(result[2])} + return { + "id": result[0], + "topic": result[1], + "space_id": result[3], + "created_at": str(result[2]), + } @st_app.api_route("/api/v1/{feature}/threads") @@ -58,6 +64,9 @@ def get(self: Self, feature_: FEATURE) -> None: try: threads = rq.list_thread_history(feature) + # print("threads:") + # for thread in threads: + # print(thread[0], thread[1], thread[2], thread[3]) thread_response = ( [ThreadModel(**_get_thread_object(threads[i])) for i in range(len(threads))] if len(threads) > 0 else [] ) @@ -69,6 +78,7 @@ def get(self: Self, feature_: FEATURE) -> None: self.write(response) except ValidationError as e: + logging.error("ValidationError: ", e) raise HTTPError(status_code=400, reason="Bad request", log_message=str(e)) from e @authenticated @@ -119,7 +129,10 @@ def get(self: Self, feature_: FEATURE, thread_id: int) -> None: try: thread = rq.list_thread_history(feature, thread_id) - thread_response = ThreadModel(**_get_thread_object(thread[0])) if len(thread) > 0 else None + thread_space_id = get_thread_space(self.selected_org_id, thread_id).id_ + thread_response = ( + ThreadModel(**_get_thread_object(thread[0]), space_id=thread_space_id) if len(thread) > 0 else None + ) if not thread_response: raise HTTPError(404, reason="Thread not found.") diff --git a/web/utils/streamlit_application.py b/web/utils/streamlit_application.py index 66ff8016..07266c13 100644 --- a/web/utils/streamlit_application.py +++ b/web/utils/streamlit_application.py @@ -32,6 +32,11 @@ class StreamlitApplication: def get_singleton_instance(self: Self) -> Application: """Return the singleton instance of the Streamlit Tornado Application object.""" + # logging.getLogger("tornado.access").setLevel(logging.DEBUG) + # logging.getLogger("tornado.access").setLevel(logging.DEBUG) + # logging.getLogger("tornado.application").setLevel(logging.DEBUG) + # logging.getLogger("tornado.general").setLevel(logging.DEBUG) + # enable_pretty_logging() if not self.__singleton_instance: self.__singleton_instance = next(o for o in gc.get_referrers(Application) if o.__class__ is Application) return self.__singleton_instance From fd84c08df2cf06f4adf66b7bebc6ad3aea8019b6 Mon Sep 17 00:00:00 2001 From: Janaka Abeywardhana Date: Sat, 12 Oct 2024 01:36:50 +0100 Subject: [PATCH 12/14] fix: enum type conversion bug introduced when we refactored to intro the Assistant model --- source/docq/manage_assistants.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/source/docq/manage_assistants.py b/source/docq/manage_assistants.py index 90c776e8..a88962cc 100644 --- a/source/docq/manage_assistants.py +++ b/source/docq/manage_assistants.py @@ -367,7 +367,7 @@ def get_assistant(assistant_scoped_id: str, org_id: Optional[int]) -> Assistant: return Assistant( key=str(row[0]), name=row[1], - type=AssistantType(row[2]), + type=AssistantType(row[2].capitalize()), archived=row[3], system_message_content=row[4], user_prompt_template_content=row[5], From 4228a8af2ea8b163304679ddfe76d20ee4286a1b Mon Sep 17 00:00:00 2001 From: Janaka Abeywardhana Date: Sat, 12 Oct 2024 16:57:50 +0100 Subject: [PATCH 13/14] chore: clean up --- source/docq/manage_documents.py | 11 +++++++---- web/api/base_handlers.py | 2 +- web/api/spaces_files_handler.py | 3 ++- 3 files changed, 10 insertions(+), 6 deletions(-) diff --git a/source/docq/manage_documents.py b/source/docq/manage_documents.py index d6236c20..222d6fa8 100644 --- a/source/docq/manage_documents.py +++ 
b/source/docq/manage_documents.py @@ -10,10 +10,10 @@ from llama_index.core.schema import NodeWithScore from streamlit import runtime -from .data_source.main import DocumentMetadata -from .domain import SpaceKey -from .manage_spaces import reindex -from .support.store import get_upload_dir, get_upload_file +from docq.data_source.main import DocumentMetadata +from docq.domain import SpaceKey +from docq.manage_spaces import reindex +from docq.support.store import get_upload_dir, get_upload_file def upload(filename: str, content: bytes, space: SpaceKey) -> None: @@ -21,6 +21,9 @@ def upload(filename: str, content: bytes, space: SpaceKey) -> None: with open(get_upload_file(space, filename), "wb") as f: f.write(content) + # TODO: refactor to only kick off re-indexing the saved file not the whole space. + # TODO: add error handling and return success/failure status. + # TODO: to handle large files and resumable uploads, switch content to BinaryIO and then write chunks in a loop. reindex(space) diff --git a/web/api/base_handlers.py b/web/api/base_handlers.py index bc90cdca..91993992 100644 --- a/web/api/base_handlers.py +++ b/web/api/base_handlers.py @@ -26,7 +26,7 @@ def check_xsrf_cookie(self: Self) -> bool: """Override the XSRF cookie check.""" # If `True`, POST, PUT, and DELETE are block unless the `_xsrf` cookie is set. # Safe with token based authN - print("check_xsrf_cookie() called") + # print("check_xsrf_cookie() called") return False @property diff --git a/web/api/spaces_files_handler.py b/web/api/spaces_files_handler.py index 37c44279..31b2c261 100644 --- a/web/api/spaces_files_handler.py +++ b/web/api/spaces_files_handler.py @@ -77,7 +77,6 @@ def post(self: Self, space_id: int) -> None: if not space: raise HTTPError(404, reason=f"Space id {space_id} not found") - print("Space_type:", space[7]) space_type = SpaceType(str(space[7]).lower()) if not space_type: @@ -89,9 +88,11 @@ def post(self: Self, space_id: int) -> None: type_=space_type, ) + # save the file in the correct folder related to the space. upload(filename=filename, content=file_info["body"], space=space_key) self.set_status(201) # 201 Created + # TODO: add a response model. ideally should have success/fail status for each file. self.write({"message": f"{len(files)} File(s) successfully uploaded"}) except ValidationError as e: raise HTTPError(400, reason="Bad request") from e From d6cacf93786f278584df402e66c916a298f00b80 Mon Sep 17 00:00:00 2001 From: Janaka Abeywardhana Date: Sat, 12 Oct 2024 16:59:31 +0100 Subject: [PATCH 14/14] build: bump version --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 50654ae0..09fa7a8f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "docq" -version = "0.13.10" +version = "0.13.11" description = "Docq.AI - Your private ChatGPT alternative. Securely unlock knowledge from confidential documents." authors = ["Docq.AI Team "] maintainers = ["Docq.AI Team "]
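Taken together, the pipeline pieces added in this series (HyDE preprocessing, hybrid vector plus BM25 retrieval, rank fusion and generation) are wired up through `rag_pipeline` roughly as below. Index, assistant and LLM construction are elided; the placeholder variables are assumed to come from the existing Docq helpers used in `web/page_handlers/ml_eng_tools/rag.py`.

from docq.support.llama_index.node_post_processors import reciprocal_rank_fusion
from docq.support.rag_pipeline import hyde_query_preprocessor, rag_pipeline
from llama_index.core.llms import ChatMessage, MessageRole

# space_indices: List[BaseIndex] loaded for the selected space(s)
# assistant:     docq.domain.Assistant for the selected persona
# llm:           e.g. _get_service_context(saved_model_settings).llm
history = [ChatMessage(role=MessageRole.ASSISTANT, content="Hello, ask me a question.")]

response = rag_pipeline(
    user_query="What does the latest report say about churn?",
    indices=space_indices,
    assistant=assistant,
    message_history=history,
    llm=llm,
    reranker=reciprocal_rank_fusion,
    query_preprocessor=hyde_query_preprocessor,
    top_k=6,
)
print(response.message.content)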