diff --git a/document_qa/document_qa_engine.py b/document_qa/document_qa_engine.py
index 2a13043..fbe05e6 100644
--- a/document_qa/document_qa_engine.py
+++ b/document_qa/document_qa_engine.py
@@ -7,7 +7,6 @@
 from langchain.chains import create_extraction_chain
 from langchain.chains.question_answering import load_qa_chain, stuff_prompt, refine_prompts, map_reduce_prompt, \
     map_rerank_prompt
-from langchain.evaluation import PairwiseEmbeddingDistanceEvalChain, load_evaluator, EmbeddingDistance
 from langchain.prompts import SystemMessagePromptTemplate, HumanMessagePromptTemplate, ChatPromptTemplate
 from langchain.retrievers import MultiQueryRetriever
 from langchain.schema import Document
@@ -273,7 +272,7 @@ def query_storage_and_embeddings(self, query: str, doc_id, context_size=4) -> Li
         """
         db = self.data_storage.embeddings_dict[doc_id]
         retriever = db.as_retriever(search_kwargs={"k": context_size}, search_type="similarity_with_embeddings")
-        relevant_documents = retriever.get_relevant_documents(query)
+        relevant_documents = retriever.invoke(query)
 
         return relevant_documents
 
@@ -284,7 +283,7 @@ def analyse_query(self, query, doc_id, context_size=4):
         #     search_type="similarity_score_threshold"
         # )
         retriever = db.as_retriever(search_kwargs={"k": context_size}, search_type="similarity_with_embeddings")
-        relevant_documents = retriever.get_relevant_documents(query)
+        relevant_documents = retriever.invoke(query)
         relevant_document_coordinates = [doc.metadata['coordinates'].split(";") if 'coordinates' in doc.metadata
                                          else []
                                          for doc in relevant_documents]
@@ -338,7 +337,7 @@ def _run_query(self, doc_id, query, context_size=4) -> (List[Document], list):
     def _get_context(self, doc_id, query, context_size=4) -> (List[Document], list):
         db = self.data_storage.embeddings_dict[doc_id]
         retriever = db.as_retriever(search_kwargs={"k": context_size})
-        relevant_documents = retriever.get_relevant_documents(query)
+        relevant_documents = retriever.invoke(query)
         relevant_document_coordinates = [doc.metadata['coordinates'].split(";") if 'coordinates' in doc.metadata
                                          else []
                                          for doc in relevant_documents]
@@ -361,7 +360,7 @@ def get_full_context_by_document(self, doc_id):
     def _get_context_multiquery(self, doc_id, query, context_size=4):
         db = self.data_storage.embeddings_dict[doc_id].as_retriever(search_kwargs={"k": context_size})
         multi_query_retriever = MultiQueryRetriever.from_llm(retriever=db, llm=self.llm)
-        relevant_documents = multi_query_retriever.get_relevant_documents(query)
+        relevant_documents = multi_query_retriever.invoke(query)
         return relevant_documents
 
     def get_text_from_document(self, pdf_file_path, chunk_size=-1, perc_overlap=0.1, verbose=False):
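LangChain 0.2 makes every retriever a Runnable, which is why the deprecated
get_relevant_documents() calls above become invoke(). A minimal sketch of the
call shape, assuming an existing vector store db (the variable names are
illustrative, not from this patch):

    retriever = db.as_retriever(search_kwargs={"k": 4})
    # docs = retriever.get_relevant_documents(query)  # LangChain 0.1 style, deprecated
    docs = retriever.invoke(query)                    # LangChain 0.2 style, returns List[Document]

The return value is unchanged, so the surrounding code needs no further edits.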
", - "subSection": "", - "passage_id": "hauthors", - "coordinates": ";".join([node['coords'] if coordinates and node.has_attr('coords') else "" for node in - blocks_header['authors']]) - }) + # passages.append({ + # "text": f"authors: {biblio['authors']}", + # "type": passage_type, + # "section": "
", + # "subSection": "", + # "passage_id": "hauthors", + # "coordinates": ";".join([node['coords'] if coordinates and node.has_attr('coords') else "" for node in + # blocks_header['authors']]) + # }) passages.append({ "text": self.post_process(" ".join([node.text for node in blocks_header['title']])), diff --git a/requirements.txt b/requirements.txt index f0ef0f9..f85b6d4 100644 --- a/requirements.txt +++ b/requirements.txt @@ -16,14 +16,17 @@ dateparser # LLM chromadb==0.4.24 -tiktoken==0.6.0 -openai==1.16.2 -langchain==0.1.14 -langchain-core==0.1.40 +tiktoken==0.7.0 +openai==1.42.0 +langchain==0.2.14 +langchain-core==0.2.34 +langchain-openai==0.1.22 +langchain-huggingface==0.0.3 +langchain-community==0.2.12 typing-inspect==0.9.0 typing_extensions==4.11.0 pydantic==2.6.4 sentence_transformers==2.6.1 -streamlit-pdf-viewer==0.0.17 +streamlit-pdf-viewer==0.0.18-dev1 umap-learn plotly \ No newline at end of file diff --git a/streamlit_app.py b/streamlit_app.py index edbcea9..cbd3a48 100644 --- a/streamlit_app.py +++ b/streamlit_app.py @@ -6,10 +6,11 @@ import dotenv from grobid_quantities.quantities import QuantitiesAPI from langchain.memory import ConversationBufferWindowMemory -from langchain_community.chat_models.openai import ChatOpenAI -from langchain_community.embeddings.huggingface import HuggingFaceEmbeddings -from langchain_community.embeddings.openai import OpenAIEmbeddings +from langchain_community.callbacks import PromptLayerCallbackHandler +from langchain_community.chat_models import ChatOpenAI from langchain_community.llms.huggingface_endpoint import HuggingFaceEndpoint +from langchain_huggingface import HuggingFaceEmbeddings +from langchain_openai import OpenAIEmbeddings from streamlit_pdf_viewer import pdf_viewer from document_qa.ner_client_generic import NERClientGeneric @@ -97,6 +98,9 @@ if 'embeddings' not in st.session_state: st.session_state['embeddings'] = None +if 'scroll_to_first_annotation' not in st.session_state: + st.session_state['scroll_to_first_annotation'] = False + st.set_page_config( page_title="Scientific Document Insights Q/A", page_icon="📝", @@ -169,7 +173,8 @@ def init_qa(model, embeddings_name=None, api_key=None): repo_id=OPEN_MODELS[model], temperature=0.01, max_new_tokens=4092, - model_kwargs={"max_length": 8192} + model_kwargs={"max_length": 8192}, + callbacks=[PromptLayerCallbackHandler(pl_tags=[model, "document-qa"])] ) embeddings = HuggingFaceEmbeddings( model_name=OPEN_EMBEDDINGS[embeddings_name]) @@ -233,8 +238,8 @@ def play_old_messages(container): # is_api_key_provided = st.session_state['api_key'] with st.sidebar: - st.title("📝 Scientific Document Insights Q/A") - st.subheader("Upload a scientific article in PDF, ask questions, get insights.") + st.title("📝 Document Q/A") + st.markdown("Upload a scientific article in PDF, ask questions, get insights.") st.markdown( ":warning: [Usage disclaimer](https://github.com/lfoppiano/document-qa?tab=readme-ov-file#disclaimer-on-data-security-and-privacy-%EF%B8%8F) :warning: ") @@ -301,14 +306,14 @@ def play_old_messages(container): # help="Clear the conversational memory. 
diff --git a/streamlit_app.py b/streamlit_app.py
index edbcea9..cbd3a48 100644
--- a/streamlit_app.py
+++ b/streamlit_app.py
@@ -6,10 +6,11 @@ import dotenv
 
 from grobid_quantities.quantities import QuantitiesAPI
 from langchain.memory import ConversationBufferWindowMemory
-from langchain_community.chat_models.openai import ChatOpenAI
-from langchain_community.embeddings.huggingface import HuggingFaceEmbeddings
-from langchain_community.embeddings.openai import OpenAIEmbeddings
+from langchain_community.callbacks import PromptLayerCallbackHandler
+from langchain_community.chat_models import ChatOpenAI
 from langchain_community.llms.huggingface_endpoint import HuggingFaceEndpoint
+from langchain_huggingface import HuggingFaceEmbeddings
+from langchain_openai import OpenAIEmbeddings
 from streamlit_pdf_viewer import pdf_viewer
 
 from document_qa.ner_client_generic import NERClientGeneric
@@ -97,6 +98,9 @@
 if 'embeddings' not in st.session_state:
     st.session_state['embeddings'] = None
 
+if 'scroll_to_first_annotation' not in st.session_state:
+    st.session_state['scroll_to_first_annotation'] = False
+
 st.set_page_config(
     page_title="Scientific Document Insights Q/A",
     page_icon="📝",
@@ -169,7 +173,8 @@ def init_qa(model, embeddings_name=None, api_key=None):
             repo_id=OPEN_MODELS[model],
             temperature=0.01,
             max_new_tokens=4092,
-            model_kwargs={"max_length": 8192}
+            model_kwargs={"max_length": 8192},
+            callbacks=[PromptLayerCallbackHandler(pl_tags=[model, "document-qa"])]
         )
         embeddings = HuggingFaceEmbeddings(
             model_name=OPEN_EMBEDDINGS[embeddings_name])
@@ -233,8 +238,8 @@ def play_old_messages(container):
 # is_api_key_provided = st.session_state['api_key']
 
 with st.sidebar:
-    st.title("📝 Scientific Document Insights Q/A")
-    st.subheader("Upload a scientific article in PDF, ask questions, get insights.")
+    st.title("📝 Document Q/A")
+    st.markdown("Upload a scientific article in PDF, ask questions, get insights.")
 
     st.markdown(
         ":warning: [Usage disclaimer](https://github.com/lfoppiano/document-qa?tab=readme-ov-file#disclaimer-on-data-security-and-privacy-%EF%B8%8F) :warning: ")
@@ -301,14 +306,14 @@
 #                             help="Clear the conversational memory. <br> Currently implemented to retrain the 4 most recent messages.",
 #                             disabled=model in st.session_state['rqa'] and st.session_state['rqa'][model].memory is None)
 
-left_column, right_column = st.columns([1, 1])
+left_column, right_column = st.columns([5, 4])
 right_column = right_column.container(border=True)
 left_column = left_column.container(border=True)
 
 with right_column:
     uploaded_file = st.file_uploader(
-        "Upload an article",
-        type=("pdf", "txt"),
+        "Upload a scientific article",
+        type=("pdf"),
         on_change=new_file,
         disabled=st.session_state['model'] is not None and st.session_state['model'] not in
                  st.session_state['api_keys'],
@@ -343,6 +348,10 @@
             "relevant paragraphs to the question in the paper. "
             "Question coefficient attempt to estimate how effective the question will be answered."
         )
+        st.session_state['scroll_to_first_annotation'] = st.checkbox(
+            "Scroll to context",
+            help='The PDF viewer will automatically scroll to the first relevant passage in the document.'
+        )
         st.session_state['ner_processing'] = st.checkbox(
             "Identify materials and properties.",
             help='The LLM responses undergo post-processing to extract physical quantities, measurements, and materials mentions.'
@@ -415,7 +424,6 @@ def generate_color_gradient(num_elements):
 
 with right_column:
     if st.session_state.loaded_embeddings and question and len(question) > 0 and st.session_state.doc_id:
-        # messages.chat_message("user").markdown(question)
         st.session_state.messages.append({"role": "user", "mode": mode, "content": question})
 
         for message in st.session_state.messages:
@@ -491,5 +499,6 @@ def generate_color_gradient(num_elements):
             input=st.session_state['binary'],
             annotation_outline_size=2,
             annotations=st.session_state['annotations'],
-            render_text=True
+            render_text=True,
+            scroll_to_annotation=1 if (st.session_state['annotations'] and st.session_state['scroll_to_first_annotation']) else None
         )
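The new scroll behaviour is driven entirely by the sidebar checkbox. A
condensed sketch of the viewer call, with the session-state plumbing reduced to
plain variables (pdf_bytes, annotations, scroll_enabled are illustrative
names, not from this patch):

    from streamlit_pdf_viewer import pdf_viewer

    pdf_viewer(
        input=pdf_bytes,          # binary content of the uploaded PDF
        annotations=annotations,  # bounding boxes of the retrieved passages
        # 1 scrolls the viewer to the first annotation; None disables scrolling
        scroll_to_annotation=1 if (annotations and scroll_enabled) else None,
    )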