diff --git a/requirements.txt b/requirements.txt
index 26211e8..e8db7d3 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -7,7 +7,7 @@ grobid_tei_xml==0.1.3
 tqdm==4.66.2
 pyyaml==6.0.1
 pytest==8.1.1
-streamlit==1.33.0
+streamlit==1.36.0
 lxml
 Beautifulsoup4
 python-dotenv
diff --git a/streamlit_app.py b/streamlit_app.py
index 6df586c..6472b5e 100644
--- a/streamlit_app.py
+++ b/streamlit_app.py
@@ -42,8 +42,6 @@
     'Salesforce/SFR-Embedding-Mistral': 'Salesforce/SFR-Embedding-Mistral'
 }
 
-DISABLE_MEMORY = ['zephyr-7b-beta']
-
 if 'rqa' not in st.session_state:
     st.session_state['rqa'] = {}
 
@@ -108,36 +106,6 @@
     }
 )
 
-css_modify_left_column = '''
-'''
-css_modify_right_column = '''
-'''
-css_disable_scrolling_container = '''
-'''
-
-
-# st.markdown(css_lock_column_fixed, unsafe_allow_html=True)
-# st.markdown(css2, unsafe_allow_html=True)
-
 
 def new_file():
     st.session_state['loaded_embeddings'] = None
 
@@ -188,7 +156,7 @@ def init_qa(model, embeddings_name=None, api_key=None):
         )
         embeddings = HuggingFaceEmbeddings(
             model_name=OPEN_EMBEDDINGS[embeddings_name])
-        st.session_state['memory'] = ConversationBufferWindowMemory(k=4) if model not in DISABLE_MEMORY else None
+        # st.session_state['memory'] = ConversationBufferWindowMemory(k=4) if model not in DISABLE_MEMORY else None
     else:
         st.error("The model was not loaded properly. Try reloading. ")
         st.stop()
 
@@ -233,23 +201,27 @@ def get_file_hash(fname):
     return hash_md5.hexdigest()
 
 
-def play_old_messages():
+def play_old_messages(container):
     if st.session_state['messages']:
         for message in st.session_state['messages']:
             if message['role'] == 'user':
-                with st.chat_message("user"):
-                    st.markdown(message['content'])
+                container.chat_message("user").markdown(message['content'])
             elif message['role'] == 'assistant':
-                with st.chat_message("assistant"):
-                    if mode == "LLM":
-                        st.markdown(message['content'], unsafe_allow_html=True)
-                    else:
-                        st.write(message['content'])
+                if mode == "LLM":
+                    container.chat_message("assistant").markdown(message['content'], unsafe_allow_html=True)
+                else:
+                    container.chat_message("assistant").write(message['content'])
 
 
 # is_api_key_provided = st.session_state['api_key']
 
 with st.sidebar:
+    st.title("📝 Scientific Document Insights Q/A")
+    st.subheader("Upload a scientific article in PDF, ask questions, get insights.")
+    st.markdown(
+        ":warning: [Usage disclaimer](https://github.com/lfoppiano/document-qa?tab=readme-ov-file#disclaimer-on-data-security-and-privacy-%EF%B8%8F) :warning: ")
+
+    st.divider()
     st.session_state['model'] = model = st.selectbox(
         "Model:",
         options=OPENAI_MODELS + list(OPEN_MODELS.keys()),
@@ -305,22 +277,18 @@ def play_old_messages():
     # else:
     #     is_api_key_provided = st.session_state['api_key']
 
-    st.button(
-        'Reset chat memory.',
-        key="reset-memory-button",
-        on_click=clear_memory,
-        help="Clear the conversational memory. Currently implemented to retrain the 4 most recent messages.",
-        disabled=model in st.session_state['rqa'] and st.session_state['rqa'][model].memory is None)
+    # st.button(
+    #     'Reset chat memory.',
+    #     key="reset-memory-button",
+    #     on_click=clear_memory,
+    #     help="Clear the conversational memory. Currently implemented to retrain the 4 most recent messages.",
+    #     disabled=model in st.session_state['rqa'] and st.session_state['rqa'][model].memory is None)
 
 left_column, right_column = st.columns([1, 1])
+right_column = right_column.container(height=600, border=False)
+left_column = left_column.container(height=600, border=False)
 
 with right_column:
-    st.title("📝 Scientific Document Insights Q/A")
-    st.subheader("Upload a scientific article in PDF, ask questions, get insights.")
-
-    st.markdown(
-        ":warning: [Usage disclaimer](https://github.com/lfoppiano/document-qa?tab=readme-ov-file#disclaimer-on-data-security-and-privacy-%EF%B8%8F) :warning: ")
-
     uploaded_file = st.file_uploader(
         "Upload an article",
         type=("pdf", "txt"),
@@ -330,11 +298,14 @@ def play_old_messages():
         help="The full-text is extracted using Grobid."
     )
 
-question = st.chat_input(
-    "Ask something about the article",
-    # placeholder="Can you give me a short summary?",
-    disabled=not uploaded_file
-)
+    placeholder = st.empty()
+    messages = st.container(height=300, border=False)
+
+    question = st.chat_input(
+        "Ask something about the article",
+        # placeholder="Can you give me a short summary?",
+        disabled=not uploaded_file
+    )
 
 query_modes = {
     "llm": "LLM Q/A",
@@ -355,6 +326,10 @@ def play_old_messages():
              "relevant paragraphs to the question in the paper. "
             "Question coefficient attempt to estimate how effective the question will be answered."
    )
+    st.session_state['ner_processing'] = st.checkbox(
+        "Identify materials and properties.",
+        help='The LLM responses undergo post-processing to extract physical quantities, measurements, and materials mentions.'
+    )
 
     # Add a checkbox for showing annotations
     # st.session_state['show_annotations'] = st.checkbox("Show annotations", value=True)
@@ -372,11 +347,6 @@ def play_old_messages():
                                  help="Number of chunks to consider when answering a question",
                                  disabled=not uploaded_file)
 
-    st.session_state['ner_processing'] = st.checkbox("Identify materials and properties.")
-    st.markdown(
-        'The LLM responses undergo post-processing to extract physical quantities, measurements, and materials mentions.',
-        unsafe_allow_html=True)
-
     st.divider()
 
     st.header("Documentation")
@@ -403,7 +373,7 @@ def play_old_messages():
         st.error("Before uploading a document, you must enter the API key. ")
         st.stop()
 
-    with right_column:
+    with left_column:
         with st.spinner('Reading file, calling Grobid, and creating memory embeddings...'):
            binary = uploaded_file.getvalue()
            tmp_file = NamedTemporaryFile()
@@ -416,8 +386,6 @@
             st.session_state['loaded_embeddings'] = True
             st.session_state.messages = []
 
-# timestamp = datetime.utcnow()
-
 
 def rgb_to_hex(rgb):
     return "#{:02x}{:02x}{:02x}".format(*rgb)
 
@@ -439,41 +407,21 @@ def generate_color_gradient(num_elements):
 
 with right_column:
-    # css = '''
-    # '''
-    # st.markdown(css, unsafe_allow_html=True)
-
-    # st.markdown(
-    #     """
-    #     """,
-    #     unsafe_allow_html=True,
-    # )
-
     if st.session_state.loaded_embeddings and question and len(question) > 0 and st.session_state.doc_id:
         for message in st.session_state.messages:
-            with st.chat_message(message["role"]):
-                if message['mode'] == "llm":
-                    st.markdown(message["content"], unsafe_allow_html=True)
-                elif message['mode'] == "embeddings":
-                    st.write(message["content"])
-                if message['mode'] == "question_coefficient":
-                    st.markdown(message["content"], unsafe_allow_html=True)
+            if message['mode'] == "llm":
+                messages.chat_message(message["role"]).markdown(message["content"], unsafe_allow_html=True)
+            elif message['mode'] == "embeddings":
+                messages.chat_message(message["role"]).write(message["content"])
+            if message['mode'] == "question_coefficient":
+                messages.chat_message(message["role"]).markdown(message["content"], unsafe_allow_html=True)
 
         if model not in st.session_state['rqa']:
             st.error("The API Key for the " + model + " is missing. Please add it before sending any query. `")
             st.stop()
 
-        with st.chat_message("user"):
-            st.markdown(question)
-            st.session_state.messages.append({"role": "user", "mode": mode, "content": question})
+        messages.chat_message("user").markdown(question)
+        st.session_state.messages.append({"role": "user", "mode": mode, "content": question})
 
         text_response = None
         if mode == "embeddings":
@@ -484,12 +432,13 @@ def generate_color_gradient(num_elements):
                 context_size=context_size
             )
         elif mode == "llm":
-            with st.spinner("Generating LLM response..."):
-                _, text_response, coordinates = st.session_state['rqa'][model].query_document(
-                    question,
-                    st.session_state.doc_id,
-                    context_size=context_size
-                )
+            with placeholder:
+                with st.spinner("Generating LLM response..."):
+                    _, text_response, coordinates = st.session_state['rqa'][model].query_document(
+                        question,
+                        st.session_state.doc_id,
+                        context_size=context_size
+                    )
 
         elif mode == "question_coefficient":
             with st.spinner("Estimate question/context relevancy..."):
@@ -511,28 +460,28 @@ def generate_color_gradient(num_elements):
         if not text_response:
             st.error("Something went wrong. Contact Luca Foppiano (Foppiano.Luca@nims.co.jp) to report the issue.")
 
-        with st.chat_message("assistant"):
-            if mode == "llm":
-                if st.session_state['ner_processing']:
-                    with st.spinner("Processing NER on LLM response..."):
-                        entities = gqa.process_single_text(text_response)
-                        decorated_text = decorate_text_with_annotations(text_response.strip(), entities)
-                        decorated_text = decorated_text.replace('class="label material"', 'style="color:green"')
-                        decorated_text = re.sub(r'class="label[^"]+"', 'style="color:orange"', decorated_text)
-                        text_response = decorated_text
-                st.markdown(text_response, unsafe_allow_html=True)
-            else:
-                st.write(text_response)
-            st.session_state.messages.append({"role": "assistant", "mode": mode, "content": text_response})
+        if mode == "llm":
+            if st.session_state['ner_processing']:
+                with st.spinner("Processing NER on LLM response..."):
+                    entities = gqa.process_single_text(text_response)
+                    decorated_text = decorate_text_with_annotations(text_response.strip(), entities)
+                    decorated_text = decorated_text.replace('class="label material"', 'style="color:green"')
+                    decorated_text = re.sub(r'class="label[^"]+"', 'style="color:orange"', decorated_text)
+                    text_response = decorated_text
+            messages.chat_message("assistant").markdown(text_response, unsafe_allow_html=True)
+        else:
+            messages.chat_message("assistant").write(text_response)
+        st.session_state.messages.append({"role": "assistant", "mode": mode, "content": text_response})
 
     elif st.session_state.loaded_embeddings and st.session_state.doc_id:
-        play_old_messages()
+        play_old_messages(messages)
 
 with left_column:
     if st.session_state['binary']:
         pdf_viewer(
             input=st.session_state['binary'],
-            annotation_outline_size=1,
+            annotation_outline_size=2,
             annotations=st.session_state['annotations'],
-            render_text=True
+            render_text=True,
+            height=700
         )
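---

Reviewer note, not part of the patch: the layout change above hinges on fixed-height
containers, which is why requirements.txt moves to streamlit 1.36. Below is a minimal,
self-contained sketch of the pattern under those assumptions; the names (left, right,
messages, the "Echo" reply) are hypothetical and only illustrate the idea, they are not
code from the repository.

    import streamlit as st

    left_column, right_column = st.columns([1, 1])

    # container(height=...) pins each column to a fixed height and gives it its
    # own scrollbar, so the chat and the PDF viewer scroll independently instead
    # of stretching the page (what the removed CSS hacks tried to achieve).
    left = left_column.container(height=600, border=False)
    right = right_column.container(height=600, border=False)

    with right:
        # A nested fixed-height container holds the chat history.
        messages = st.container(height=300, border=False)
        # chat_input placed inside a container renders inline rather than
        # pinned to the bottom of the app.
        question = st.chat_input("Ask something about the article")

    if question:
        # chat_message() returns a container, so bubbles can be chained onto a
        # target container instead of using a `with` block, the same style the
        # patch adopts in play_old_messages(container).
        messages.chat_message("user").markdown(question)
        messages.chat_message("assistant").write(f"Echo: {question}")

    with left:
        st.write("(pdf_viewer would render here)")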