"""
Make a tool for accessing my personal notes, stored in a directory of text files.
"""
import logging
from typing import Optional
from llama_index import (
ServiceContext,
SimpleDirectoryReader,
StorageContext,
VectorStoreIndex,
)
from llama_index.tools import BaseTool
from llama_index.vector_stores import ChromaVectorStore
from pydantic import BaseModel
PATH_TO_NOTES = "demo-notes"
SHOULD_IGNORE_PERSISTED_INDEX = False
def __create_index(
    input_dir: str,
    storage_context: Optional[StorageContext] = None,
    service_context: Optional[ServiceContext] = None,
) -> VectorStoreIndex:
    """
    Creates an index from a directory of documents.
    """
    documents = SimpleDirectoryReader(
        input_dir=input_dir,
        # https://docs.llamaindex.ai/en/stable/module_guides/loading/simpledirectoryreader.html#reading-from-subdirectories
        recursive=True,
        # https://docs.llamaindex.ai/en/stable/module_guides/loading/simpledirectoryreader.html#restricting-the-files-loaded
        # Before including image files here, `mamba install pillow`.
        # Before including audio files here, `pip install openai-whisper`.
        required_exts=[".md", ".txt"],
    ).load_data()
    return VectorStoreIndex.from_documents(
        # https://docs.llamaindex.ai/en/stable/api_reference/indices/vector_store.html#llama_index.indices.vector_store.base.VectorStoreIndex.from_documents
        documents=documents,
        service_context=service_context,
        storage_context=storage_context,
        show_progress=True,
    )
def make_tool(service_context: ServiceContext) -> BaseTool:
    """
    Creates a tool for accessing my private information, or anything about me.
    These can be my notes, my calendar, my emails, etc.
    """
    # An index is a lightweight view to the database.
    notes_index = __get_index(service_context)
    notes_query_engine = notes_index.as_query_engine(
        service_context=service_context,
        similarity_top_k=5,
        # For a query engine hidden inside an Agent, streaming really doesn't make sense.
        # https://docs.llamaindex.ai/en/stable/module_guides/deploying/query_engine/streaming.html#streaming
        streaming=False,
    )

    # Convert it to a tool.
    from llama_index.tools import ToolMetadata
    from llama_index.tools.query_engine import QueryEngineTool

    class NotesQueryingToolSchema(BaseModel):
        input: str

    notes_query_engine_tool = QueryEngineTool(
        query_engine=notes_query_engine,
        metadata=ToolMetadata(
            name="look_up_notes",
            description="""Search the user's notes about a particular keyword.
Input should be the keyword that you want to search the user's notes with.""",
            fn_schema=NotesQueryingToolSchema,
        ),
    )

    # Sub Question Query Engine: breaks down the user's question into sub-questions.
    # https://docs.llamaindex.ai/en/stable/examples/query_engine/sub_question_query_engine.html
    from llama_index.query_engine import SubQuestionQueryEngine
    from llama_index.question_gen.llm_generators import LLMQuestionGenerator

    from sub_question_generating_prompt_in_keywords import (
        SUB_QUESTION_PROMPT_TEMPLATE_WITH_KEYWORDS,
    )

    sub_question_query_engine = SubQuestionQueryEngine.from_defaults(
        query_engine_tools=[notes_query_engine_tool],
        question_gen=LLMQuestionGenerator.from_defaults(
            service_context=service_context,
            prompt_template_str=SUB_QUESTION_PROMPT_TEMPLATE_WITH_KEYWORDS,
        ),
        service_context=service_context,
        verbose=True,
    )

    # Convert it to a tool.
    class AboutTheUserToolSchema(BaseModel):
        input: str

    sub_question_query_engine_tool = QueryEngineTool(
        query_engine=sub_question_query_engine,
        metadata=ToolMetadata(
            name="about_the_user",
            description="""Provides information about the user themselves, including the user's opinions on a given topic.
Input should be the topic about which you want to learn about the user. For example, you can ask:
"opinions about X", "food that I enjoy", "my financial standing", etc.""",
            fn_schema=AboutTheUserToolSchema,
        ),
    )
    return sub_question_query_engine_tool
def __get_index(service_context: ServiceContext):
    logger = logging.getLogger("__get_index")
    # https://www.trychroma.com/
    import chromadb
    from chromadb.config import Settings

    db = chromadb.PersistentClient(
        path="./chroma_db",
        settings=Settings(
            # https://docs.trychroma.com/telemetry#opting-out
            anonymized_telemetry=False
        ),
    )
    try:
        chroma_collection = db.get_collection("notes")
    except ValueError:
        logger.info("The Chroma DB collection does not exist. Creating.")
        should_create_index = True
    else:
        logger.info("Storage exists.")
        if SHOULD_IGNORE_PERSISTED_INDEX:
            db.delete_collection("notes")
            logger.info(
                "We were asked to ignore the persisted data, so the Chroma DB collection has been deleted."
            )
            should_create_index = True
        else:
            logger.info("We'll load from the Chroma DB collection.")
            should_create_index = False
    if should_create_index:
        chroma_collection = db.create_collection("notes")
        # Assign Chroma as the vector_store to the context.
        vector_store = ChromaVectorStore(chroma_collection=chroma_collection)
        storage_context = StorageContext.from_defaults(vector_store=vector_store)
        return __create_index(
            input_dir=PATH_TO_NOTES,
            storage_context=storage_context,
            service_context=service_context,
        )
        # If we were using file-based storage, we would have to call `persist` manually:
        # index.storage_context.persist(persist_dir=STORAGE_DIR)
        # But this doesn't apply to DBs like Chroma.
    # Otherwise, load the existing index.
    # Assign Chroma as the vector_store to the context.
    vector_store = ChromaVectorStore(chroma_collection=chroma_collection)
    storage_context = StorageContext.from_defaults(vector_store=vector_store)
    return VectorStoreIndex.from_vector_store(
        vector_store,
        service_context=service_context,
        storage_context=storage_context,
    )
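

# A minimal usage sketch, not part of the original module: build a default
# ServiceContext and inspect the resulting tool. This assumes the legacy
# llama_index 0.9-style API used above, an LLM configured via environment
# variables (e.g. OPENAI_API_KEY), and notes present under `demo-notes`.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    service_context = ServiceContext.from_defaults()
    tool = make_tool(service_context)
    # Print the tool's metadata to confirm it was built as expected; an agent
    # would receive this tool via its `tools` list instead.
    print(tool.metadata.name)
    print(tool.metadata.description)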