app.py
# @packages
from dataclasses import dataclass
from langchain.callbacks import get_openai_callback
from langchain.chains.conversation.memory import ConversationSummaryMemory
from langchain.chains import ConversationalRetrievalChain
from langchain.chat_models import ChatOpenAI
from langchain.document_loaders import (
    DirectoryLoader,
    Docx2txtLoader,
    PyPDFLoader,
    TextLoader,
    UnstructuredExcelLoader,
)
from langchain.document_loaders.csv_loader import CSVLoader
from langchain.embeddings import OpenAIEmbeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import Chroma
from typing import Literal
import os
import shutil
import streamlit as st
# @scripts
from helpers import web_scrape_site

# Get the API key for the LLM & embedding model (if required; this app currently uses OpenAI)
os.environ["OPENAI_API_KEY"] = st.secrets["OPENAI_API_KEY"]

# Set page configuration
st.set_page_config(page_title="Ask Chatbot")

@dataclass
class Message:
    """
    Class to contain & track messages.
    """

    origin: Literal["human", "AI"]
    message: str

def load_directory_documents(path_to_data):
    """
    Loads & extracts text data within a local directory for a custom knowledge base.
    Accepts the path_to_data.
    Loads any .txt, .pdf, .csv, .docx, or .xlsx files in the directory.
    Many loader classes are available, see docs: https://python.langchain.com/docs/integrations/document_loaders/
    Returns the text documents.
    """
    # Define loaders
    pdf_loader = DirectoryLoader(
        path_to_data, glob="./*.pdf", loader_cls=PyPDFLoader, use_multithreading=True
    )
    txt_loader = DirectoryLoader(path_to_data, glob="./*.txt", loader_cls=TextLoader)
    csv_loader = DirectoryLoader(path_to_data, glob="./*.csv", loader_cls=CSVLoader)
    word_loader = DirectoryLoader(
        path_to_data, glob="./*.docx", loader_cls=Docx2txtLoader
    )
    excel_loader = DirectoryLoader(
        path_to_data, glob="./*.xlsx", loader_cls=UnstructuredExcelLoader
    )
    loaders = [
        pdf_loader,
        txt_loader,
        csv_loader,
        word_loader,
        excel_loader,
    ]
    documents = []
    for loader in loaders:
        documents.extend(loader.load())
    if len(documents) == 0:
        # Terminate the app if no data is found
        st.write(f"No data found within: {path_to_data}")
        st.stop()
    # Display the loaded filenames
    filenames = ["Uploaded Documents:"]
    for doc in documents:
        filename = doc.metadata.get("source")
        # CSVLoader yields multiple documents per uploaded .csv, so dedupe by source file
        if filename not in filenames:
            filenames.append(filename)
    # Keep only the text after the last '/' character (the bare filename)
    cleaned_filenames = [filename.split("/")[-1] for filename in filenames]
    # Combine the filenames into a single string for the HTML
    filenames_combined = "<br>".join(cleaned_filenames)
    div = f"""
    <div class="chat-row">
        <div class="chat-bubble human-bubble">&#8203;{filenames_combined}</div>
    </div>
    """
    st.markdown(div, unsafe_allow_html=True)
    return documents
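
# Note (worth verifying for your layout): the "./*.ext" globs above match files in
# the top level of path_to_data only; to include subdirectories, a recursive
# pattern such as "**/*.pdf" would be needed.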

def get_chunks(documents):
    """
    Chunks the documents for vector embedding.
    Accepts a list of documents & returns the split text.
    """
    # Chunk data for embedding
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
    texts = text_splitter.split_documents(documents)
    return texts
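
# Illustrative note (the figures below are an assumption, not from the source): with
# chunk_size=1000 and chunk_overlap=200, a ~2,400-character document splits into
# roughly three chunks, each repeating ~200 characters of its neighbor so that
# sentences spanning a boundary remain retrievable from either chunk.
# RecursiveCharacterTextSplitter prefers paragraph & sentence separators, so
# actual chunk sizes vary.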

def embed_and_persist_vectors(texts, persist_dir):
    """
    Performs vector embeddings and persists the Chroma vector store to disk.
    Accepts the split texts & a path to store the vectors.
    Returns the persist_directory that was used.
    Note: Different embedding models output different vector dimensionalities,
    require different resources, and have different performance characteristics.
    Be sure to query the vector store with the same embedding model used here.
    """
    try:
        # Create a Chroma vector store and embed the split text
        vector_store = Chroma.from_documents(
            documents=texts, embedding=OpenAIEmbeddings(), persist_directory=persist_dir
        )
        # Persist the vector store to disk
        vector_store.persist()
        # Drop the reference once the store is flushed
        vector_store = None
    except Exception as e:
        print("An error occurred creating the vector store: ", e)
    return persist_dir
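
# Note (version-dependent assumption): with these legacy langchain imports,
# OpenAIEmbeddings defaults to OpenAI's text-embedding-ada-002 model, which
# produces 1536-dimensional vectors. A Chroma collection is tied to one
# dimensionality, so changing embedding models means re-embedding the data.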

def create_vector_store(persist_dir):
    """
    Loads & returns the Chroma vector store persisted on disk.
    Accepts the path where the vectors were stored.
    Note: If the knowledge base is unchanged, embedding & persisting the data can be skipped.
    Useful when embedding large amounts of data.
    """
    # Load the vector store persisted to disk
    vector_store = Chroma(
        persist_directory=persist_dir, embedding_function=OpenAIEmbeddings()
    )
    return vector_store
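
# Minimal usage sketch (illustrative only; the query string is an assumption):
#   store = create_vector_store("chroma-db_demo")
#   docs = store.similarity_search("What is metadata filtering?", k=5)
# similarity_search is the standard LangChain vector store call for checking
# retrieval quality before wiring the store into a chain.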

def load_and_process_data(path_to_data, persist_dir, remove_existing_persist_dir):
    """
    Executes functions to load & process data, perform vector embedding, and persist results.
    Accepts a path for the data to load, the persist directory,
    and a boolean to clear the current vector store.
    """
    # Clean the existing persist directory if requested
    if os.path.exists(persist_dir) and remove_existing_persist_dir:
        try:
            # Delete files & subdirectories within the directory
            absolute_path = os.path.abspath(persist_dir)
            shutil.rmtree(absolute_path)
            print(f"Deleted directory: {absolute_path}")
        except Exception as e:
            print(f"Error while deleting directory: {e}")
    if not os.path.exists(persist_dir):
        try:
            os.makedirs(persist_dir)
        except Exception as e:
            print(f"Error making directory: {e}")
    # Load text from the documents
    documents = load_directory_documents(path_to_data)
    print(f"Loaded {len(documents)} documents")
    # Split data for vector embedding
    texts = get_chunks(documents)
    print(f"Split into {len(texts)} chunks")
    # Perform vector embedding and persist the results
    persist_dir = embed_and_persist_vectors(texts, persist_dir)
    print(f"Persisted vectors to: {persist_dir}")

def load_css():
    """
    Retrieves the page styles.
    """
    with open("./static/styles.css", "r") as f:
        css = f"<style>{f.read()}</style>"
    st.markdown(css, unsafe_allow_html=True)

def init_web_scraping():
    """
    Executes helper functions for web scraping the text from a website.
    Default behavior will also search for & scrape any links ('a' tags) found on the site.
    """
    demo_website_url = "https://blog.langchain.dev/graph-based-metadata-filtering-for-improving-vector-search-in-rag-applications/"
    output_folder_name = "data"
    web_scrape_site(demo_website_url, output_folder_name)

def initialize_session_state():
    """
    Creates session state for the conversation with the LLM.
    """
    # Define chat history
    if "history" not in st.session_state:
        st.session_state.history = []
    # Define a token count
    if "token_count" not in st.session_state:
        st.session_state.token_count = 0
    # Define flags so one-time setup blocks run only once
    if "web_scraping" not in st.session_state:
        st.session_state.web_scraping = False
    if "load_and_process" not in st.session_state:
        st.session_state.load_and_process = False
    # Define a conversation chain
    if "conversation" not in st.session_state:
        # Path to the data to process
        path_to_data = "./data/"
        # Name for the local vector store
        db_name = "demo"
        # Directory to persist/load the vector store
        persist_dir = f"chroma-db_{db_name}"
        # Web scraping functionality
        web_scraping_actions = True
        if web_scraping_actions and not st.session_state.web_scraping:
            with st.spinner("Web scraping site..."):
                # Executes web scraping on the URL defined in init_web_scraping
                init_web_scraping()
            st.session_state.web_scraping = True
        # Load data functionality
        load_data = True
        remove_existing_persist_dir = True
        if load_data and not st.session_state.load_and_process:
            with st.spinner("Loading data and creating vector embeddings..."):
                # Loads data & creates vector embeddings
                load_and_process_data(
                    path_to_data, persist_dir, remove_existing_persist_dir
                )
            st.session_state.load_and_process = True
        # Create a vector store to serve the custom knowledge base
        vector_store = create_vector_store(persist_dir)
        # Define the Large Language Model (LLM) for the chatbot
        llm = ChatOpenAI(temperature=0.2, model_name="gpt-4")
        # Define the conversational retrieval chain
        st.session_state.conversation = ConversationalRetrievalChain.from_llm(
            llm=llm,
            # Define a retriever for the knowledge base context
            retriever=vector_store.as_retriever(search_kwargs={"k": 5}),
            # Create a Memory object
            memory=ConversationSummaryMemory(
                llm=llm, memory_key="chat_history", return_messages=True
            ),
        )
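
# Note: ConversationSummaryMemory uses the LLM itself to keep a rolling summary of
# the conversation instead of storing every turn verbatim; this keeps the prompt
# short for long chats at the cost of an extra LLM call per exchange.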

def on_click_callback():
    """
    Manages chat history in session state.
    """
    # Wrap the chain call in get_openai_callback to track token usage
    with get_openai_callback() as callback:
        # Get the prompt from session state
        human_prompt = st.session_state.human_prompt
        # Call the conversation chain defined in session state on the user prompt
        llm_response = st.session_state.conversation({"question": human_prompt})
        # Persist the prompt and llm_response in session state
        st.session_state.history.append(Message("human", human_prompt))
        st.session_state.history.append(Message("AI", llm_response["answer"]))
        # Persist the token count in session state
        st.session_state.token_count += callback.total_tokens
        # Clear the prompt value
        st.session_state.human_prompt = ""

def main():
    try:
        load_css()
        initialize_session_state()
        # Set up web page text
        st.title("Ask Chatbot 🤖")
        st.header("Let's Talk About Your Data (or Whatever) 💬")
        # Create a container for the chat between the user & LLM
        chat_placeholder = st.container()
        # Create a form for the user prompt
        prompt_placeholder = st.form("chat-form")
        # Create an empty placeholder for the token count
        token_placeholder = st.empty()
        # Display chat history within chat_placeholder
        with chat_placeholder:
            for chat in st.session_state.history:
                div = f"""
                <div class="chat-row {'' if chat.origin == 'AI' else 'row-reverse'}">
                    <img class="chat-icon" src="{'https://ask-chatbot.netlify.app/public/ai_icon.png' if chat.origin == 'AI' else 'https://ask-chatbot.netlify.app/public/user_icon.png'}" width=32 height=32>
                    <div class="chat-bubble {'ai-bubble' if chat.origin == 'AI' else 'human-bubble'}">&#8203;{chat.message}</div>
                </div>
                """
                st.markdown(div, unsafe_allow_html=True)
            for _ in range(3):
                st.markdown("")
        # Create the user prompt within prompt_placeholder
        with prompt_placeholder:
            st.markdown("**Chat**")
            cols = st.columns((6, 1))
            cols[0].text_input(
                "Chat",
                placeholder="Send a message",
                label_visibility="collapsed",
                key="human_prompt",
            )
            cols[1].form_submit_button(
                "Submit",
                type="primary",
                on_click=on_click_callback,
            )
        # Display the number of tokens used & the conversation context within token_placeholder
        token_placeholder.caption(
            f"""
            Used {st.session_state.token_count} tokens \n
            Debug LangChain conversation:
            {st.session_state.conversation.memory.buffer}
            """
        )
    except Exception as error:
        print("There has been an error: ", error)


if __name__ == "__main__":
    main()
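
# Assumed local setup (not stated in this file): st.secrets reads from
# .streamlit/secrets.toml, so provide OPENAI_API_KEY = "sk-..." there, then
# launch the app with:
#   streamlit run app.py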