localrag_no_rewrite.py
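"""Local RAG chat over a plain-text vault using Ollama.

Embeds each line of vault.txt with the mxbai-embed-large model, retrieves the
top-k most similar lines for a query via cosine similarity, and passes them as
context to a chat model served through Ollama's OpenAI-compatible API.
"""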
import torch
import ollama
import os
from openai import OpenAI
import argparse
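# Assumed prerequisites (not enforced here): the `torch`, `ollama`, and `openai`
# Python packages, a running Ollama server on localhost:11434, and the
# `mxbai-embed-large` embedding model already pulled (`ollama pull mxbai-embed-large`).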
# ANSI escape codes for colors
PINK = "\033[95m"
CYAN = "\033[96m"
YELLOW = "\033[93m"
NEON_GREEN = "\033[92m"
RESET_COLOR = "\033[0m"
# Function to open a file and return its contents as a string
def open_file(filepath):
with open(filepath, "r", encoding="utf-8") as infile:
return infile.read()
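# Note: open_file is a small convenience helper; it is not called elsewhere in this script.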
# Function to get relevant context from the vault based on user input
def get_relevant_context(rewritten_input, vault_embeddings, vault_content, top_k=3):
if vault_embeddings.nelement() == 0: # Check if the tensor has any elements
return []
# Encode the rewritten input
input_embedding = ollama.embeddings(
model="mxbai-embed-large", prompt=rewritten_input
)["embedding"]
# Compute cosine similarity between the input and vault embeddings
cos_scores = torch.cosine_similarity(
torch.tensor(input_embedding).unsqueeze(0), vault_embeddings
)
# Adjust top_k if it's greater than the number of available scores
top_k = min(top_k, len(cos_scores))
# Sort the scores and get the top-k indices
top_indices = torch.topk(cos_scores, k=top_k)[1].tolist()
# Get the corresponding context from the vault
relevant_context = [vault_content[idx].strip() for idx in top_indices]
return relevant_context
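# Minimal usage sketch (names taken from the setup code further below): with
# vault_embeddings_tensor of shape (num_lines, embedding_dim) and vault_content
# holding the raw vault lines, this returns the 3 lines closest to the query:
#   context = get_relevant_context("What do my notes say about X?",
#                                   vault_embeddings_tensor, vault_content, top_k=3)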
# Function to interact with the Ollama model
def ollama_chat(
user_input,
system_message,
vault_embeddings,
vault_content,
ollama_model,
conversation_history,
):
# Get relevant context from the vault
relevant_context = get_relevant_context(
        user_input, vault_embeddings, vault_content, top_k=3
)
if relevant_context:
# Convert list to a single string with newlines between items
context_str = "\n".join(relevant_context)
print("Context Pulled from Documents: \n\n" + CYAN + context_str + RESET_COLOR)
else:
print(CYAN + "No relevant context found." + RESET_COLOR)
# Prepare the user's input by concatenating it with the relevant context
user_input_with_context = user_input
if relevant_context:
user_input_with_context = context_str + "\n\n" + user_input
# Append the user's input to the conversation history
conversation_history.append({"role": "user", "content": user_input_with_context})
# Create a message history including the system message and the conversation history
messages = [{"role": "system", "content": system_message}, *conversation_history]
# Send the completion request to the Ollama model
response = client.chat.completions.create(model=ollama_model, messages=messages)
# Append the model's response to the conversation history
conversation_history.append(
{"role": "assistant", "content": response.choices[0].message.content}
)
# Return the content of the response from the model
return response.choices[0].message.content
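# Note: ollama_chat mutates conversation_history in place, appending both the
# context-augmented user turn and the assistant's reply, so multi-turn context
# accumulates across calls.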
# Parse command-line arguments
parser = argparse.ArgumentParser(description="Ollama Chat")
parser.add_argument(
"--model", default="qwen", help="Ollama model to use (default: llama3)"
)
args = parser.parse_args()
# Configuration for the Ollama API client
client = OpenAI(base_url="http://localhost:11434/v1", api_key="ollama")
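# Ollama serves an OpenAI-compatible API under /v1; the api_key value is only a
# placeholder required by the OpenAI client constructor and is not checked locally.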
# Load the vault content
vault_content = []
if os.path.exists("vault.txt"):
with open("vault.txt", "r", encoding="utf-8") as vault_file:
vault_content = vault_file.readlines()
# Generate embeddings for the vault content using Ollama
vault_embeddings = []
for content in vault_content:
response = ollama.embeddings(model="mxbai-embed-large", prompt=content)
vault_embeddings.append(response["embedding"])
# Convert to tensor and print embeddings
vault_embeddings_tensor = torch.tensor(vault_embeddings)
torch.save(vault_embeddings_tensor, "embeddings.pt")
print("Embeddings for each line in the vault:")
print(vault_embeddings_tensor)
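# embeddings.pt is written out for later reuse, but this script recomputes the
# embeddings on every run rather than loading the saved tensor.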
# Conversation loop
conversation_history = []
system_message = "You are a helpful assistant that is an expert at extracting the most useful information from a given text"
while True:
user_input = input(
YELLOW
+ "Ask a question about your documents (or type 'quit' to exit): "
+ RESET_COLOR
)
if user_input.lower() == "quit":
break
response = ollama_chat(
user_input,
system_message,
vault_embeddings_tensor,
vault_content,
args.model,
conversation_history,
)
print(NEON_GREEN + "Response: \n\n" + response + RESET_COLOR)
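# Example invocation (the model name is illustrative; any locally pulled Ollama
# chat model works):
#   python localrag_no_rewrite.py --model llama3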