# Ollama_RAG_chat_TTS.py
# forked from AllAboutAI-YT/easy-local-rag
import argparse
import json
import os
import torch
import ollama
from tqdm import tqdm
from gtts import gTTS
import pygame
import io
import re
import threading
import queue
"""
TODO: vault_embeddings.pt has been replaced with vault_embeddings.txt
"""
# Constants
EMBEDDINGS_DIR = "Embeddings"
# ANSI escape codes for terminal colors and styles
PINK = "\033[95m"
CYAN = "\033[96m"
YELLOW = "\033[93m"
NEON_GREEN = "\033[92m"
MAGENTA = "\033[35m"
BLUE = "\033[94m"
RED = "\033[91m"
BOLD = "\033[1m"
RESET_COLOR = "\033[0m"
"""
####### ####### #####
# # # #
# # #
# # #####
# # #
# # # #
# # #####
"""
# Function to convert text to speech using gTTS and play it with pygame.
# Note: gTTS has no playback-speed control; the British accent comes from
# the co.uk TLD.
def text_to_speech(text, volume=1.0):
    try:
        tts = gTTS(
            text=text,
            lang="en",
            tld="co.uk",
        )
audio_fp = io.BytesIO()
tts.write_to_fp(audio_fp)
audio_fp.seek(0)
pygame.mixer.init()
pygame.mixer.music.load(audio_fp)
# Set the volume (range 0.0 to 1.0)
pygame.mixer.music.set_volume(volume)
pygame.mixer.music.play()
while pygame.mixer.music.get_busy():
pygame.time.Clock().tick(10)
pygame.mixer.quit()
except Exception as e:
print(f"Error occurred during playback: {e}")
# Function to load vault content from JSON file
def load_vault_content(json_file):
vault_content = {}
with open(json_file, "r", encoding="utf-8") as f:
data = json.load(f)
for entry in data:
for chunk in entry["chunks"]:
vault_content[chunk["id"]] = chunk["text"].strip()
return vault_content
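# Expected shape of the vault JSON, inferred from the loop above; the ids
# and text below are illustrative:
#
#     [
#         {
#             "chunks": [
#                 {"id": "doc1_chunk0", "text": "First chunk of text..."},
#                 {"id": "doc1_chunk1", "text": "Second chunk of text..."}
#             ]
#         }
#     ]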
# Function to read file content (currently unused in this script)
def open_file(filepath):
with open(filepath, "r", encoding="utf-8") as f:
return f.read()
"""
###### ## ## ### ########
## ## ## ## ## ## ##
## ## ## ## ## ##
## ######### ## ## ##
## ## ## ######### ##
## ## ## ## ## ## ##
###### ## ## ## ## ##
"""
# Function to interact with the Ollama model
def ollama_chat(
user_input,
system_message,
vault_embeddings,
vault_content,
ollama_model,
conversation_history,
):
# Get relevant context from the vault
relevant_context = get_relevant_context(
user_input, vault_embeddings, vault_content, top_k=3
)
if relevant_context:
# Convert list to a single string with newlines between items
context_str = "\n\n".join(relevant_context)
print(
"Context Pulled from Documents: \n\n"
+ CYAN
+ context_str
+ RESET_COLOR
+ "\n\n"
)
else:
print("No relevant context found.")
# Prepare the user's input by concatenating it with the relevant context
user_input_with_context = user_input
if relevant_context:
user_input_with_context = context_str + "\n\n" + user_input
# Append the user's input to the conversation history
conversation_history.append({"role": "user", "content": user_input_with_context})
# Create a message history including the system message and the conversation history
messages = [{"role": "system", "content": system_message}, *conversation_history]
# Send the completion request to the Ollama model with stream=True
stream = ollama.chat(
model=ollama_model,
messages=messages,
stream=True,
keep_alive=-1,
)
    # Queue of completed sentences for the TTS worker
    q = queue.Queue()
    # Start the worker thread
    worker_thread = threading.Thread(target=process_queue, args=(q,))
    worker_thread.start()
    # Buffer accumulates streamed text until a full sentence is available
    buffer = []
    full_response = []
    for chunk in stream:
        process_chunk(chunk, q, buffer)
        full_response.append(chunk["message"]["content"])
        print(NEON_GREEN + chunk["message"]["content"], end="", flush=True)
    # Flush any trailing partial sentence left in the buffer
    leftover = "".join(buffer).strip()
    if leftover:
        q.put(leftover)
    # Wait for all queued sentences to be spoken, then stop the worker
    q.join()
    q.put(None)  # Sentinel value to stop the worker
    worker_thread.join()
    # Record the assistant's reply so follow-up turns have the full history
    conversation_history.append(
        {"role": "assistant", "content": "".join(full_response)}
    )
    print(RESET_COLOR + "\n")
# Function to load precomputed vault embeddings from disk
def load_embeddings():
embeddings_file = os.path.join(EMBEDDINGS_DIR, "all-minilm_vault_embeddings.pt")
if os.path.exists(embeddings_file):
embeddings = torch.load(embeddings_file)
return embeddings
return []
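# A minimal sketch (assumption: the real indexing script lives elsewhere in
# the repo) of how all-minilm_vault_embeddings.pt could be produced so that
# it matches the list-of-{chunk_id: embedding} layout expected below:
#
#     import json, os, torch, ollama
#
#     with open("Backup_vault.json", "r", encoding="utf-8") as f:
#         data = json.load(f)
#     vault_embeddings = []
#     for entry in data:
#         for chunk in entry["chunks"]:
#             emb = ollama.embeddings(model="all-minilm", prompt=chunk["text"])["embedding"]
#             vault_embeddings.append({chunk["id"]: emb})
#     os.makedirs("Embeddings", exist_ok=True)
#     torch.save(vault_embeddings, os.path.join("Embeddings", "all-minilm_vault_embeddings.pt"))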
"""
#### ###### ##### ##### ###### # ###### # # ## # # ##### #### #### # # ##### ###### # # #####
# # # # # # # # # # # # # ## # # # # # # ## # # # # # #
# ##### # # # ##### # ##### # # # # # # # # # # # # # # # ##### ## #
# ### # # ##### # # # # # ###### # # # # # # # # # # # # ## #
# # # # # # # # # # # # # # ## # # # # # # ## # # # # #
#### ###### # # # ###### ###### ###### ## # # # # # #### #### # # # ###### # # #
"""
# Function to get relevant context from the vault based on user input
def get_relevant_context(rewritten_input, vault_embeddings, vault_content, top_k=3):
if not vault_embeddings: # Check if the list is empty
return []
    # Encode the rewritten input with the same model used for the vault
    # (alternative: model="mxbai-embed-large")
input_embedding = ollama.embeddings(
model="all-minilm", prompt=rewritten_input, keep_alive=-1
)["embedding"]
# Create a tensor from input_embedding
input_embedding_tensor = torch.tensor(input_embedding).unsqueeze(0)
# Prepare embeddings and ids lists
embeddings_list = []
ids_list = []
for embedding_dict in vault_embeddings:
for chunk_id, embedding in embedding_dict.items():
embeddings_list.append(embedding)
ids_list.append(chunk_id)
# Create a tensor from the embeddings list
vault_embeddings_tensor = torch.tensor(embeddings_list)
# Compute cosine similarity between the input and vault embeddings
cos_scores = torch.cosine_similarity(
input_embedding_tensor, vault_embeddings_tensor
)
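    # Shape note: the input tensor is (1, D) and the vault tensor is (N, D);
    # broadcasting over dim=1 yields one similarity score per vault chunk.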
# Adjust top_k if it's greater than the number of available scores
top_k = min(top_k, len(cos_scores))
# Sort the scores and get the top-k indices
top_indices = torch.topk(cos_scores, k=top_k)[1].tolist()
# Get the corresponding chunk_ids from the top indices
top_chunk_ids = [ids_list[idx] for idx in top_indices]
    # Get the corresponding context from the vault content using chunk_ids
    relevant_context = [vault_content[chunk_id] for chunk_id in top_chunk_ids]
return relevant_context
# Function to process the queue
def process_queue(q):
while True:
sentence = q.get()
if sentence is None: # Sentinel value to stop the worker
break
text_to_speech(sentence)
q.task_done()
# Function to process each streamed chunk: accumulate text in the shared
# buffer and enqueue only completed sentences for speech
def process_chunk(chunk, q, buffer):
    buffer.append(chunk["message"]["content"])
    text = "".join(buffer)
    # Split on sentence-ending punctuation followed by whitespace
    sentences = re.split(r"(?<=[.!?])\s+", text)
    # The last element may be an incomplete sentence; keep it buffered
    buffer.clear()
    buffer.append(sentences.pop() if sentences else "")
    for sentence in sentences:
        if sentence.strip():
            q.put(sentence.strip())
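# For example, with buffered text "One. Two! Thr", the split yields
# ["One.", "Two!", "Thr"]; "One." and "Two!" are enqueued for speech and
# "Thr" stays buffered until more of the stream arrives.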
def main():
parser = argparse.ArgumentParser(description="Ollama Chat")
    parser.add_argument(
        "--model", default="phi3", help="Ollama model to use (default: phi3)"
    )
args = parser.parse_args()
    # Conversation history; the system message is added per request in ollama_chat
    conversation_history = []
vault_embeddings = load_embeddings()
# Example usage
json_file = "Backup_vault.json"
vault_content = load_vault_content(json_file)
while True:
user_input = input(
"\n"
+ RED
+ BOLD
+ "Enter your message (or 'exit' to quit):"
+ "\n"
+ RESET_COLOR
+ "\n"
)
if user_input.lower() == "exit" or user_input.lower() == "quit":
break
        system_message = "Use proper punctuation. Give precise and concise answers based on the given context and question."
# Interact with the Ollama model
if user_input:
ollama_chat(
user_input,
system_message,
vault_embeddings,
vault_content,
args.model,
conversation_history,
)
if __name__ == "__main__":
main()
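# Example invocation (assumes an Ollama server is running and the phi3 and
# all-minilm models have been pulled):
#
#     python Ollama_RAG_chat_TTS.py --model phi3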