-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathalpha_DocumentContextManager.py
51 lines (40 loc) · 1.86 KB
/
alpha_DocumentContextManager.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
# No database is used: documents, embeddings, and metadata are kept in in-memory dicts.
import torch
from transformers import BertTokenizer, BertModel
from sklearn.metrics.pairwise import cosine_similarity
import numpy as np
import time
class DocumentContextManager:
    """In-memory store of documents, their BERT embeddings, and metadata.

    Three dicts, all keyed by ``doc_id``, hold the raw text, the embedding
    produced by :meth:`_embed_text`, and a small metadata record. Nothing is
    persisted; all state lives on the instance.
    """

    # Token limit handed to the tokenizer; longer texts are truncated.
    MAX_TOKENS = 512
    # Number of leading characters of the text kept as the stored summary.
    SUMMARY_CHARS = 100

    def __init__(self, model_name='bert-base-uncased'):
        """Create an empty store and load the tokenizer and model.

        Args:
            model_name: Checkpoint identifier passed to
                ``BertTokenizer.from_pretrained`` and
                ``BertModel.from_pretrained``. Defaults to the checkpoint
                that was previously hard-coded.
        """
        self.documents = {}
        self.embeddings = {}
        self.metadata = {}
        # Load pre-trained model and tokenizer
        self.tokenizer = BertTokenizer.from_pretrained(model_name)
        self.model = BertModel.from_pretrained(model_name)

    def add_document(self, doc_id, text, filename):
        """Embed ``text`` and register it under ``doc_id``.

        An existing entry with the same ``doc_id`` is overwritten.

        Args:
            doc_id: Key under which the document is stored.
            text: Document content to embed and store.
            filename: Recorded in the metadata as ``'filename'``.
        """
        # Embed before touching any dict: if embedding raises, the store is
        # left unchanged instead of holding a text with no embedding/metadata.
        embedding = self._embed_text(text)
        self.documents[doc_id] = text
        self.embeddings[doc_id] = embedding
        self.metadata[doc_id] = {
            'filename': filename,
            'upload_time': time.time(),  # Seconds since the epoch at upload
            'summary': text[:self.SUMMARY_CHARS],  # Leading characters only
        }

    def _embed_text(self, text):
        """Return the mean-pooled final hidden state of ``text`` as a NumPy array.

        The text is tokenized with truncation at ``MAX_TOKENS`` tokens, run
        through the model with gradients disabled, and ``last_hidden_state``
        is averaged over dimension 1.
        """
        inputs = self.tokenizer(
            text,
            return_tensors='pt',
            truncation=True,
            padding=True,
            max_length=self.MAX_TOKENS,
        )
        with torch.no_grad():
            outputs = self.model(**inputs)
        # Mean pooling to get a single vector for the document
        embeddings = outputs.last_hidden_state.mean(dim=1)
        return embeddings.cpu().numpy()

    def get_similar_documents(self, query, top_k=5):
        """Return the stored documents most similar to ``query``.

        Args:
            query: Text to embed and compare against every stored embedding.
            top_k: Maximum number of results. ``None`` returns every
                document; zero or a negative value returns an empty list.

        Returns:
            A list of ``(doc_id, text, embedding)`` tuples ordered by
            descending cosine similarity. Ties keep insertion order.
        """
        # Nothing to rank: skip model inference entirely.
        if not self.embeddings or (top_k is not None and top_k <= 0):
            return []

        query_embedding = self._embed_text(query)
        doc_ids = list(self.embeddings)
        # Row 0 of each stored embedding is the document vector; stacking them
        # lets a single cosine_similarity call score every document at once.
        doc_matrix = np.stack([self.embeddings[doc_id][0] for doc_id in doc_ids])
        scores = cosine_similarity(query_embedding, doc_matrix)[0]

        ranked = sorted(zip(doc_ids, scores), key=lambda pair: pair[1], reverse=True)
        return [
            (doc_id, self.documents[doc_id], self.embeddings[doc_id])
            for doc_id, _ in ranked[:top_k]
        ]

    def get_document_metadata(self, doc_id):
        """Return the metadata dict stored for ``doc_id``, or ``{}`` if unknown."""
        return self.metadata.get(doc_id, {})