-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathrag.py
109 lines (83 loc) · 3.11 KB
/
rag.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
# import the Documents class from doc.py
from doc import Documents
import uuid
from typing import List, Dict
import cohere
# get cohere api key from .env
from dotenv import load_dotenv
import os
# Load variables from a local .env file into the process environment so the
# API key can be read with os.getenv below.
load_dotenv()
# os.getenv returns None when COHERE_API_KEY is not set.
COHERE_API_KEY = os.getenv("COHERE_API_KEY")
# Module-level Cohere client; the Rag methods below issue their chat calls
# through it.
co = cohere.Client(COHERE_API_KEY)
class Rag:
    """
    A retrieval-augmented chatbot built on the Cohere chat endpoint.

    Parameters:
        docs (Documents): The document collection searched when the model
            asks for retrieval.

    Attributes:
        conversation_id (str): Random UUID string identifying this
            conversation; sent with every streamed chat call.
        docs (Documents): The document collection passed to the constructor.

    Methods:
        search_query(message): Asks the model for search queries for a message.
        generate_response(message, doc, response): Streams the chatbot's reply.
        retrieve_docs(response): Retrieves documents for the search queries
            in a response.
    """

    def __init__(self, docs: "Documents"):
        self.docs = docs
        self.conversation_id = str(uuid.uuid4())

    def search_query(self, message: str):
        """
        Asks the model which search queries (if any) the message calls for.

        Parameters:
            message (str): The user's message.

        Returns:
            The query-generation response itself when it contains at least
            one search query; otherwise the tuple ``(False, response)``.
        """
        # Only generate search queries here; the actual answer is produced
        # later by generate_response.
        response = co.chat(message=message, search_queries_only=True)

        # If there are search queries, hand the response back as-is
        if response.search_queries:
            return response
        return False, response

    def generate_response(self, message: str, doc: "Documents", response):
        """
        Streams the chatbot's reply to the user's message.

        Parameters:
            message (str): The user's message.
            doc: The documents forwarded to the chat call as ``documents``
                when the response contains search queries (presumably the
                output of retrieve_docs -- confirm against the caller).
            response: The value returned by search_query: either the
                query-generation response or the ``(False, response)`` tuple.

        Yields:
            Event: A response event generated by the chatbot.
        """
        # search_query returns (False, response) when there is nothing to
        # search for; unwrap it so both return shapes are accepted here.
        if isinstance(response, tuple):
            response = response[-1]

        chat_kwargs = {
            "message": message,
            "conversation_id": self.conversation_id,
            "stream": True,
        }
        # Ground the answer in the documents only when the model asked for a
        # search; otherwise respond directly.
        if response.search_queries:
            chat_kwargs["documents"] = doc

        yield from co.chat(**chat_kwargs)

    def retrieve_docs(self, response) -> List[Dict[str, str]]:
        """
        Retrieves documents based on the search queries in the response.

        Parameters:
            response: The response object containing search queries.

        Returns:
            List[Dict[str, str]]: The documents retrieved for every query,
            concatenated in query order; empty when there are no queries.
        """
        # Get the query(s); a missing (None) query list means nothing to do
        queries = [
            search_query["text"]
            for search_query in (response.search_queries or [])
        ]

        # Retrieve documents for each query
        retrieved_docs = []
        for query in queries:
            retrieved_docs.extend(self.docs.retrieve(query))

        # # Uncomment this code block to display the chatbot's retrieved documents
        # print("DOCUMENTS RETRIEVED:")
        # for idx, doc in enumerate(retrieved_docs):
        #     print(f"doc_{idx}: {doc}")
        # print("\n")

        return retrieved_docs