-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathdoc.py
147 lines (121 loc) · 4.85 KB
/
doc.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
import cohere
import os
import hnswlib
import json
import uuid
import requests
from typing import List, Dict
import validators
from unstructured.partition.html import partition_html
from unstructured.chunking.title import chunk_by_title
# get cohere api key from .env
from dotenv import load_dotenv
import os
# Pull settings (notably COHERE_API_KEY) from a local .env file into the
# process environment, then build the shared Cohere client used below for
# both embedding and reranking.
load_dotenv()
COHERE_API_KEY = os.environ.get("COHERE_API_KEY")
co = cohere.Client(COHERE_API_KEY)

# Browser-like User-Agent passed to partition_html() when fetching each
# source page in Documents.load().
headers = {
    "User-Agent": (
        "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) "
        "AppleWebKit/537.36 (KHTML, like Gecko) "
        "Chrome/118.0.0.0 Safari/537.36"
    )
}
class Documents:
    """
    A class representing a collection of documents.

    Each source URL is fetched and split into title-based chunks. The chunks
    are embedded with Cohere and stored in an hnswlib inner-product index, so
    that retrieve() can fetch nearest neighbours for a query and rerank them.

    Parameters:
    sources (list): A list of dictionaries representing the sources of the documents. Each dictionary should have 'title' and 'url' keys.
    retrieve_top_k (int): The number of documents to retrieve from the index during search (default 10).
    rerank_top_k (int): The number of documents to keep after reranking (default 3).

    Attributes:
    sources (list): A list of dictionaries representing the sources of the documents.
    docs (list): A list of dictionaries representing the document chunks, with 'title', 'text', and 'url' keys.
    docs_embs (list): A list of the associated embeddings for the documents.
    retrieve_top_k (int): The number of documents to retrieve during search.
    rerank_top_k (int): The number of documents to rerank after retrieval.
    docs_len (int): The number of document chunks in the collection (set by embed()).
    idx (hnswlib.Index): The index used for document retrieval (set by index()).

    Methods:
    load(): Loads the data from the sources and partitions the HTML content into chunks.
    embed(): Embeds the documents using the Cohere API.
    index(): Indexes the documents for efficient retrieval.
    retrieve(query): Retrieves documents based on the given query.
    """

    def __init__(
        self,
        sources: List[Dict[str, str]],
        retrieve_top_k: int = 10,
        rerank_top_k: int = 3,
    ):
        self.sources = sources
        self.docs = []
        self.docs_embs = []
        self.retrieve_top_k = retrieve_top_k
        self.rerank_top_k = rerank_top_k
        # The constructor does all the work: fetch, embed, then index.
        self.load()
        self.embed()
        self.index()

    def load(self) -> None:
        """
        Loads the documents from the sources and chunks the HTML content.

        Loading is best-effort. If a source cannot be fetched or parsed, it is
        skipped and a message is printed, so one bad URL does not abort the
        whole collection.
        """
        print("Loading documents...")
        for source in self.sources:
            try:
                elements = partition_html(url=source["url"], headers=headers)
            except Exception as exc:
                # Catch Exception, not a bare except, so that Ctrl-C and
                # SystemExit still propagate. Report the skip instead of
                # dropping the source silently.
                print(f"Skipping source {source.get('url')!r}: {exc}")
                continue
            for chunk in chunk_by_title(elements):
                self.docs.append(
                    {
                        "title": source["title"],
                        "text": str(chunk),
                        "url": source["url"],
                    }
                )

    def embed(self) -> None:
        """
        Embeds the documents using the Cohere API.

        Chunk texts are sent in batches of 90, presumably to stay under
        Cohere's per-request text limit. The vectors are appended to
        ``docs_embs`` in the same order as ``docs``. Also sets ``docs_len``.
        """
        print("Embedding documents...")
        batch_size = 90
        self.docs_len = len(self.docs)
        for start in range(0, self.docs_len, batch_size):
            texts = [doc["text"] for doc in self.docs[start : start + batch_size]]
            self.docs_embs.extend(
                co.embed(
                    texts=texts, model="embed-english-v3.0", input_type="search_document"
                ).embeddings
            )

    def index(self) -> None:
        """
        Indexes the documents for efficient retrieval.

        Builds an inner-product hnswlib index, with document i stored under
        label i. When there are no embeddings, an empty index is created and
        retrieve() will return no results.
        """
        print("Indexing documents...")
        # Take the dimensionality from the data rather than hard-coding it.
        # 1024 (the embed-english-v3.0 size) is only a fallback for an empty
        # collection.
        dim = len(self.docs_embs[0]) if self.docs_embs else 1024
        self.idx = hnswlib.Index(space="ip", dim=dim)
        # Keep max_elements positive even when nothing was loaded.
        self.idx.init_index(
            max_elements=max(len(self.docs_embs), 1), ef_construction=512, M=64
        )
        if self.docs_embs:
            self.idx.add_items(self.docs_embs, list(range(len(self.docs_embs))))
        print(f"Indexing complete with {self.idx.get_current_count()} documents.")

    def retrieve(self, query: str) -> List[Dict[str, str]]:
        """
        Retrieves documents based on the given query.

        Fetches up to ``retrieve_top_k`` nearest neighbours from the index.
        Those candidates are reranked with Cohere, and at most
        ``rerank_top_k`` of them are returned in rerank order.

        Parameters:
        query (str): The query to retrieve documents for.

        Returns:
        List[Dict[str, str]]: A list of dictionaries representing the retrieved documents, with 'title', 'text', and 'url' keys.
        An empty list is returned when the index holds no documents.
        """
        # knn_query raises if asked for more neighbours than the index holds,
        # so cap k at the element count.
        k = min(self.retrieve_top_k, self.idx.get_current_count())
        if k <= 0:
            return []

        query_emb = co.embed(
            texts=[query], model="embed-english-v3.0", input_type="search_query"
        ).embeddings
        doc_ids = self.idx.knn_query(query_emb, k=k)[0][0]
        docs_to_rerank = [self.docs[doc_id]["text"] for doc_id in doc_ids]

        rerank_results = co.rerank(
            query=query,
            documents=docs_to_rerank,
            top_n=min(self.rerank_top_k, len(docs_to_rerank)),
            model="rerank-english-v2.0",
        )
        # result.index points into docs_to_rerank, which was built in
        # doc_ids order. Reading `.results` works with both the v4 (Reranking)
        # and v5 (RerankResponse) SDK response objects.
        doc_ids_reranked = [doc_ids[result.index] for result in rerank_results.results]

        return [
            {
                "title": self.docs[doc_id]["title"],
                "text": self.docs[doc_id]["text"],
                "url": self.docs[doc_id]["url"],
            }
            for doc_id in doc_ids_reranked
        ]