diff --git a/requirements.txt b/requirements.txt
index 26211e8..e8db7d3 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -7,7 +7,7 @@ grobid_tei_xml==0.1.3
tqdm==4.66.2
pyyaml==6.0.1
pytest==8.1.1
-streamlit==1.33.0
+streamlit==1.36.0
lxml
Beautifulsoup4
python-dotenv
diff --git a/streamlit_app.py b/streamlit_app.py
index 6df586c..6472b5e 100644
--- a/streamlit_app.py
+++ b/streamlit_app.py
@@ -42,8 +42,6 @@
'Salesforce/SFR-Embedding-Mistral': 'Salesforce/SFR-Embedding-Mistral'
}
-DISABLE_MEMORY = ['zephyr-7b-beta']
-
if 'rqa' not in st.session_state:
st.session_state['rqa'] = {}
@@ -108,36 +106,6 @@
}
)
-css_modify_left_column = '''
-
-'''
-css_modify_right_column = '''
-
-'''
-css_disable_scrolling_container = '''
-
-'''
-
-
-# st.markdown(css_lock_column_fixed, unsafe_allow_html=True)
-# st.markdown(css2, unsafe_allow_html=True)
-
def new_file():
st.session_state['loaded_embeddings'] = None
@@ -188,7 +156,7 @@ def init_qa(model, embeddings_name=None, api_key=None):
)
embeddings = HuggingFaceEmbeddings(
model_name=OPEN_EMBEDDINGS[embeddings_name])
- st.session_state['memory'] = ConversationBufferWindowMemory(k=4) if model not in DISABLE_MEMORY else None
+ # st.session_state['memory'] = ConversationBufferWindowMemory(k=4) if model not in DISABLE_MEMORY else None
else:
st.error("The model was not loaded properly. Try reloading. ")
st.stop()
@@ -233,23 +201,27 @@ def get_file_hash(fname):
return hash_md5.hexdigest()
-def play_old_messages():
+def play_old_messages(container):
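+    """Replay the stored chat history into the given Streamlit container."""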
if st.session_state['messages']:
for message in st.session_state['messages']:
if message['role'] == 'user':
- with st.chat_message("user"):
- st.markdown(message['content'])
+ container.chat_message("user").markdown(message['content'])
elif message['role'] == 'assistant':
- with st.chat_message("assistant"):
- if mode == "LLM":
- st.markdown(message['content'], unsafe_allow_html=True)
- else:
- st.write(message['content'])
+ if mode == "LLM":
+ container.chat_message("assistant").markdown(message['content'], unsafe_allow_html=True)
+ else:
+ container.chat_message("assistant").write(message['content'])
# is_api_key_provided = st.session_state['api_key']
with st.sidebar:
+ st.title("📝 Scientific Document Insights Q/A")
+ st.subheader("Upload a scientific article in PDF, ask questions, get insights.")
+ st.markdown(
+ ":warning: [Usage disclaimer](https://github.com/lfoppiano/document-qa?tab=readme-ov-file#disclaimer-on-data-security-and-privacy-%EF%B8%8F) :warning: ")
+
+ st.divider()
st.session_state['model'] = model = st.selectbox(
"Model:",
options=OPENAI_MODELS + list(OPEN_MODELS.keys()),
@@ -305,22 +277,18 @@ def play_old_messages():
# else:
# is_api_key_provided = st.session_state['api_key']
- st.button(
- 'Reset chat memory.',
- key="reset-memory-button",
- on_click=clear_memory,
- help="Clear the conversational memory. Currently implemented to retrain the 4 most recent messages.",
- disabled=model in st.session_state['rqa'] and st.session_state['rqa'][model].memory is None)
+ # st.button(
+ # 'Reset chat memory.',
+ # key="reset-memory-button",
+ # on_click=clear_memory,
+ # help="Clear the conversational memory. Currently implemented to retrain the 4 most recent messages.",
+ # disabled=model in st.session_state['rqa'] and st.session_state['rqa'][model].memory is None)
left_column, right_column = st.columns([1, 1])
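+# Fixed-height containers let the PDF viewer (left) and the chat (right) scroll independently of each other.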
+right_column = right_column.container(height=600, border=False)
+left_column = left_column.container(height=600, border=False)
with right_column:
- st.title("📝 Scientific Document Insights Q/A")
- st.subheader("Upload a scientific article in PDF, ask questions, get insights.")
-
- st.markdown(
- ":warning: [Usage disclaimer](https://github.com/lfoppiano/document-qa?tab=readme-ov-file#disclaimer-on-data-security-and-privacy-%EF%B8%8F) :warning: ")
-
uploaded_file = st.file_uploader(
"Upload an article",
type=("pdf", "txt"),
@@ -330,11 +298,14 @@ def play_old_messages():
help="The full-text is extracted using Grobid."
)
-question = st.chat_input(
- "Ask something about the article",
- # placeholder="Can you give me a short summary?",
- disabled=not uploaded_file
-)
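+    # An empty placeholder to host the progress spinner, plus a fixed-height container for the chat history above the input box.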
+ placeholder = st.empty()
+ messages = st.container(height=300, border=False)
+
+ question = st.chat_input(
+ "Ask something about the article",
+ # placeholder="Can you give me a short summary?",
+ disabled=not uploaded_file
+ )
query_modes = {
"llm": "LLM Q/A",
@@ -355,6 +326,10 @@ def play_old_messages():
"relevant paragraphs to the question in the paper. "
"Question coefficient attempt to estimate how effective the question will be answered."
)
+ st.session_state['ner_processing'] = st.checkbox(
+ "Identify materials and properties.",
+ help='The LLM responses undergo post-processing to extract physical quantities, measurements, and mentions of materials.'
+ )
# Add a checkbox for showing annotations
# st.session_state['show_annotations'] = st.checkbox("Show annotations", value=True)
@@ -372,11 +347,6 @@ def play_old_messages():
help="Number of chunks to consider when answering a question",
disabled=not uploaded_file)
- st.session_state['ner_processing'] = st.checkbox("Identify materials and properties.")
- st.markdown(
- 'The LLM responses undergo post-processing to extract physical quantities, measurements, and materials mentions.',
- unsafe_allow_html=True)
-
st.divider()
st.header("Documentation")
@@ -403,7 +373,7 @@ def play_old_messages():
st.error("Before uploading a document, you must enter the API key. ")
st.stop()
- with right_column:
+ with left_column:
with st.spinner('Reading file, calling Grobid, and creating memory embeddings...'):
binary = uploaded_file.getvalue()
tmp_file = NamedTemporaryFile()
@@ -416,8 +386,6 @@ def play_old_messages():
st.session_state['loaded_embeddings'] = True
st.session_state.messages = []
- # timestamp = datetime.utcnow()
-
def rgb_to_hex(rgb):
return "#{:02x}{:02x}{:02x}".format(*rgb)
@@ -439,41 +407,21 @@ def generate_color_gradient(num_elements):
with right_column:
- # css = '''
- #
- # '''
- # st.markdown(css, unsafe_allow_html=True)
-
- # st.markdown(
- # """
- #
- # """,
- # unsafe_allow_html=True,
- # )
-
if st.session_state.loaded_embeddings and question and len(question) > 0 and st.session_state.doc_id:
for message in st.session_state.messages:
- with st.chat_message(message["role"]):
+ with messages.chat_message(message["role"]):
if message['mode'] == "llm":
- st.markdown(message["content"], unsafe_allow_html=True)
+ messages.chat_message(message["role"]).markdown(message["content"], unsafe_allow_html=True)
elif message['mode'] == "embeddings":
- st.write(message["content"])
+ messages.chat_message(message["role"]).write(message["content"])
if message['mode'] == "question_coefficient":
- st.markdown(message["content"], unsafe_allow_html=True)
+ messages.chat_message(message["role"]).markdown(message["content"], unsafe_allow_html=True)
if model not in st.session_state['rqa']:
st.error("The API Key for the " + model + " is missing. Please add it before sending any query. `")
st.stop()
- with st.chat_message("user"):
- st.markdown(question)
- st.session_state.messages.append({"role": "user", "mode": mode, "content": question})
+ messages.chat_message("user").markdown(question)
+ st.session_state.messages.append({"role": "user", "mode": mode, "content": question})
text_response = None
if mode == "embeddings":
@@ -484,12 +432,13 @@ def generate_color_gradient(num_elements):
context_size=context_size
)
elif mode == "llm":
- with st.spinner("Generating LLM response..."):
- _, text_response, coordinates = st.session_state['rqa'][model].query_document(
- question,
- st.session_state.doc_id,
- context_size=context_size
- )
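+        # Render the spinner inside the placeholder so progress feedback appears above the chat input.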
+ with placeholder:
+ with st.spinner("Generating LLM response..."):
+ _, text_response, coordinates = st.session_state['rqa'][model].query_document(
+ question,
+ st.session_state.doc_id,
+ context_size=context_size
+ )
elif mode == "question_coefficient":
with st.spinner("Estimate question/context relevancy..."):
@@ -511,28 +460,28 @@ def generate_color_gradient(num_elements):
if not text_response:
st.error("Something went wrong. Contact Luca Foppiano (Foppiano.Luca@nims.co.jp) to report the issue.")
- with st.chat_message("assistant"):
- if mode == "llm":
- if st.session_state['ner_processing']:
- with st.spinner("Processing NER on LLM response..."):
- entities = gqa.process_single_text(text_response)
- decorated_text = decorate_text_with_annotations(text_response.strip(), entities)
- decorated_text = decorated_text.replace('class="label material"', 'style="color:green"')
- decorated_text = re.sub(r'class="label[^"]+"', 'style="color:orange"', decorated_text)
- text_response = decorated_text
- st.markdown(text_response, unsafe_allow_html=True)
- else:
- st.write(text_response)
- st.session_state.messages.append({"role": "assistant", "mode": mode, "content": text_response})
+ if mode == "llm":
+ if st.session_state['ner_processing']:
+ with st.spinner("Processing NER on LLM response..."):
+ entities = gqa.process_single_text(text_response)
+ decorated_text = decorate_text_with_annotations(text_response.strip(), entities)
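+                # Swap annotation CSS classes for inline styles: materials in green, all other labels in orange.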
+ decorated_text = decorated_text.replace('class="label material"', 'style="color:green"')
+ decorated_text = re.sub(r'class="label[^"]+"', 'style="color:orange"', decorated_text)
+ text_response = decorated_text
+ messages.chat_message("assistant").markdown(text_response, unsafe_allow_html=True)
+ else:
+ messages.chat_message("assistant").write(text_response)
+ st.session_state.messages.append({"role": "assistant", "mode": mode, "content": text_response})
elif st.session_state.loaded_embeddings and st.session_state.doc_id:
- play_old_messages()
+ play_old_messages(messages)
with left_column:
if st.session_state['binary']:
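+        # Render the uploaded PDF with its annotations inside the scrollable left column.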
pdf_viewer(
input=st.session_state['binary'],
- annotation_outline_size=1,
+ annotation_outline_size=2,
annotations=st.session_state['annotations'],
- render_text=True
+ render_text=True,
+ height=700
)