diff --git a/autollm/__init__.py b/autollm/__init__.py
index 7c32b9d8..d7c61a62 100644
--- a/autollm/__init__.py
+++ b/autollm/__init__.py
@@ -4,7 +4,7 @@
 and vector databases, along with various utility functions.
 """
 
-__version__ = '0.1.3'
+__version__ = '0.1.4'
 __author__ = 'safevideo'
 __license__ = 'AGPL-3.0'
diff --git a/autollm/auto/query_engine.py b/autollm/auto/query_engine.py
index 959fed8e..964847b0 100644
--- a/autollm/auto/query_engine.py
+++ b/autollm/auto/query_engine.py
@@ -3,7 +3,7 @@
 from llama_index import Document, ServiceContext, VectorStoreIndex
 from llama_index.embeddings.utils import EmbedType
 from llama_index.indices.query.base import BaseQueryEngine
-from llama_index.prompts.base import PromptTemplate
+from llama_index.prompts.base import BasePromptTemplate, PromptTemplate
 from llama_index.prompts.prompt_type import PromptType
 from llama_index.response_synthesizers import get_response_synthesizer
 from llama_index.schema import BaseNode
@@ -24,11 +24,11 @@ def create_query_engine(
         llm_api_base: Optional[str] = None,
         # service_context_params
         system_prompt: str = None,
-        query_wrapper_prompt: str = None,
+        query_wrapper_prompt: Union[str, BasePromptTemplate] = None,
         enable_cost_calculator: bool = True,
         embed_model: Union[str, EmbedType] = "default",  # ["default", "local"]
         chunk_size: Optional[int] = 512,
-        chunk_overlap: Optional[int] = 200,
+        chunk_overlap: Optional[int] = 100,
         context_window: Optional[int] = None,
         enable_title_extractor: bool = False,
         enable_summary_extractor: bool = False,
@@ -61,7 +61,7 @@ def create_query_engine(
         llm_temperature (float): The temperature to use for the LLM.
         llm_api_base (str): The API base to use for the LLM.
         system_prompt (str): The system prompt to use for the query engine.
-        query_wrapper_prompt (str): The query wrapper prompt to use for the query engine.
+        query_wrapper_prompt (Union[str, BasePromptTemplate]): The query wrapper prompt to use for the query engine.
         enable_cost_calculator (bool): Flag to enable cost calculator logging.
         embed_model (Union[str, EmbedType]): The embedding model to use for generating embeddings.
             "default" for OpenAI, "local" for HuggingFace or use full identifier (e.g., local:intfloat/multilingual-e5-large)
@@ -133,10 +133,15 @@ def create_query_engine(
         refine_prompt_template = PromptTemplate(refine_prompt, prompt_type=PromptType.REFINE)
     else:
         refine_prompt_template = None
+
+    # Convert query_wrapper_prompt to PromptTemplate if it is a string
+    if isinstance(query_wrapper_prompt, str):
+        query_wrapper_prompt = PromptTemplate(template=query_wrapper_prompt)
 
     response_synthesizer = get_response_synthesizer(
         service_context=service_context,
-        response_mode=response_mode,
+        text_qa_template=query_wrapper_prompt,
         refine_template=refine_prompt_template,
+        response_mode=response_mode,
         structured_answer_filtering=structured_answer_filtering)
 
     return vector_store_index.as_query_engine(
@@ -213,7 +218,7 @@ def from_defaults(
             llm_temperature: float = 0.1,
             # service_context_params
             system_prompt: str = None,
-            query_wrapper_prompt: str = None,
+            query_wrapper_prompt: Union[str, BasePromptTemplate] = None,
             enable_cost_calculator: bool = True,
             embed_model: Union[str, EmbedType] = "default",  # ["default", "local"]
             chunk_size: Optional[int] = 512,
@@ -246,7 +251,7 @@ def from_defaults(
             llm_temperature (float): The temperature to use for the LLM.
             llm_api_base (str): The API base to use for the LLM.
             system_prompt (str): The system prompt to use for the query engine.
-            query_wrapper_prompt (str): The query wrapper prompt to use for the query engine.
+            query_wrapper_prompt (Union[str, BasePromptTemplate]): The query wrapper prompt to use for the query engine.
             enable_cost_calculator (bool): Flag to enable cost calculator logging.
             embed_model (Union[str, EmbedType]): The embedding model to use for generating embeddings.
                 "default" for OpenAI, "local" for HuggingFace or use full identifier (e.g., local:intfloat/multilingual-e5-large)
diff --git a/autollm/auto/service_context.py b/autollm/auto/service_context.py
index ff45e794..9aad5294 100644
--- a/autollm/auto/service_context.py
+++ b/autollm/auto/service_context.py
@@ -65,11 +65,14 @@ def from_defaults(
         """
         if not system_prompt and not query_wrapper_prompt:
             system_prompt, query_wrapper_prompt = set_default_prompt_template()
-        # Convert system_prompt to ChatPromptTemplate if it is a string
+        # Convert query_wrapper_prompt to PromptTemplate if it is a string
         if isinstance(query_wrapper_prompt, str):
             query_wrapper_prompt = PromptTemplate(template=query_wrapper_prompt)
 
         callback_manager: CallbackManager = kwargs.get('callback_manager', CallbackManager())
+        kwargs.pop(
+            'callback_manager', None)  # Make sure callback_manager is not passed to ServiceContext twice
+
         if enable_cost_calculator:
             llm_model_name = llm.metadata.model_name if not "default" else "gpt-3.5-turbo"
             callback_manager.add_handler(CostCalculatingHandler(model_name=llm_model_name, verbose=True))
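
Usage sketch for reviewers (not part of the diff): with this change, query_wrapper_prompt may be passed either as a plain string, which create_query_engine wraps in PromptTemplate and forwards to the response synthesizer as text_qa_template, or as a ready-made BasePromptTemplate. The snippet below is a minimal sketch; it assumes the package-root re-export `from autollm import AutoQueryEngine` and the `documents` parameter as documented in the project README, plus an OPENAI_API_KEY in the environment. Document content and template wording are illustrative.

# Sketch only: exercising the new Union[str, BasePromptTemplate] signature.
from llama_index import Document
from llama_index.prompts.base import PromptTemplate

from autollm import AutoQueryEngine  # assumed package-root re-export

documents = [Document(text="autollm integrates LLMs and vector databases.")]

# Plain string: wrapped internally via PromptTemplate(template=...) and passed
# to get_response_synthesizer as text_qa_template.
query_engine = AutoQueryEngine.from_defaults(
    documents=documents,
    query_wrapper_prompt=(
        "Context information is below.\n"
        "---------------------\n"
        "{context_str}\n"
        "---------------------\n"
        "Answer the query using only the context above.\n"
        "Query: {query_str}\n"
        "Answer: "))

# Equivalent: a BasePromptTemplate instance is accepted unchanged.
qa_template = PromptTemplate(template="Context:\n{context_str}\nQuery: {query_str}\nAnswer: ")
query_engine = AutoQueryEngine.from_defaults(documents=documents, query_wrapper_prompt=qa_template)

print(query_engine.query("What does autollm integrate?"))

Forwarding the prompt to get_response_synthesizer as text_qa_template is what lets a caller-supplied wrapper prompt shape answer synthesis directly, rather than relying only on the ServiceContext default.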