forked from AllAboutAI-YT/easy-local-rag
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathPubMedBert_generate_embeddings.py
279 lines (217 loc) · 9.23 KB
/
PubMedBert_generate_embeddings.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
import os
import json
import pynvml
import time
from tqdm import tqdm
# import ollama
import torch
import threading
from transformers import AutoTokenizer, AutoModel
# Pick the compute device first: CUDA when torch reports it, otherwise CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Load the PubMedBERT tokenizer and encoder once, at import time, and place
# the encoder on the selected device.
tokenizer = AutoTokenizer.from_pretrained(
    "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext"
)
model = AutoModel.from_pretrained(
    "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext"
).to(device)
# ANSI escape sequences for colored terminal output.
RED = "\033[91m"
RESET_COLOR = "\033[0m"

# Constants
EMBEDDINGS_DIR = "Embeddings"
MOD_TIME_FILE = os.path.join(EMBEDDINGS_DIR, "mod_times.json")
PT_EMBEDDINGS_FILE = os.path.join(EMBEDDINGS_DIR, "vault_embeddings.pt")
TXT_EMBEDDINGS_FILE = os.path.join(EMBEDDINGS_DIR, "vault_embeddings.txt")
# PT_EMBEDDINGS_FILE already starts with EMBEDDINGS_DIR; joining the directory
# onto it again would point at "Embeddings/Embeddings/vault_embeddings.pt".
embeddings_file = PT_EMBEDDINGS_FILE
UPDATED_VAULT_JSON_FILE = "vault.json"
def convert_pt_to_txt(pt_file, txt_file):
    """Dump the embeddings stored in a ``.pt`` file as JSON lines.

    The object loaded from ``pt_file`` is iterated and each item is
    serialized with ``json.dumps`` onto its own line of ``txt_file``.
    An existing ``txt_file`` is overwritten.
    """
    records = torch.load(pt_file)
    serialized = (json.dumps(record) + "\n" for record in records)
    with open(txt_file, mode="w", encoding="utf-8") as out:
        out.writelines(serialized)
def convert_txt_to_pt(txt_file, pt_file):
    """Convert a JSON-lines embeddings file into a ``.pt`` file.

    Every non-blank line of ``txt_file`` is parsed as JSON; the parsed
    objects are collected in file order and written to ``pt_file`` with
    ``torch.save``.

    Raises:
        json.JSONDecodeError: if a non-blank line is not valid JSON.
    """
    with open(txt_file, "r", encoding="utf-8") as f:
        # Blank or whitespace-only lines (e.g. a trailing newline left by an
        # interrupted append) carry no record and must not abort the
        # conversion.
        embeddings = [json.loads(line) for line in f if line.strip()]
    torch.save(embeddings, pt_file)
def read_vault_data(json_file):
    """Return the parsed contents of the UTF-8 JSON file at ``json_file``."""
    with open(json_file, encoding="utf-8") as handle:
        raw_text = handle.read()
    return json.loads(raw_text)
def save_embeddings(new_embeddings):
    """Append ``new_embeddings`` to the store at the module-level ``embeddings_file``.

    If the file already exists, its contents are loaded, extended with
    ``new_embeddings`` and written back; otherwise ``new_embeddings`` is
    written as-is.

    Args:
        new_embeddings: iterable of embedding records to add.
    """
    os.makedirs(EMBEDDINGS_DIR, exist_ok=True)
    # embeddings_file is not guaranteed to sit directly inside EMBEDDINGS_DIR,
    # so also create the directory the file will actually be written to;
    # otherwise the first torch.save fails with a missing-directory error.
    parent_dir = os.path.dirname(embeddings_file)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    if os.path.exists(embeddings_file):
        # NOTE: torch.load unpickles the file; only load files this script
        # produced itself.
        updated_embeddings = torch.load(embeddings_file)
        updated_embeddings.extend(new_embeddings)
    else:
        updated_embeddings = new_embeddings

    torch.save(updated_embeddings, embeddings_file)
def load_embeddings(vault_name):
    """Load the saved embeddings for ``vault_name`` if they are available.

    The embeddings are read from ``<EMBEDDINGS_DIR>/<vault_name>_embeddings.pt``
    and are only loaded when both that file and ``MOD_TIME_FILE`` exist. The
    modification-time file is checked for existence but never read.

    Returns:
        The object stored in the ``.pt`` file, or the tuple ``(None, None)``
        when either file is missing.

    NOTE(review): the two return shapes differ (single object vs. 2-tuple);
    callers have to handle both.
    """
    pt_path = os.path.join(EMBEDDINGS_DIR, f"{vault_name}_embeddings.pt")
    if not (os.path.exists(pt_path) and os.path.exists(MOD_TIME_FILE)):
        return None, None
    return torch.load(pt_path)
def save_embeddings_txt(new_embeddings):
    """Append embedding records to the JSON-lines file ``TXT_EMBEDDINGS_FILE``.

    ``EMBEDDINGS_DIR`` is created if needed and the file is opened in append
    mode, so it is created even when there is nothing to write.

    Args:
        new_embeddings: iterable of JSON-serializable records, one per output
            line. ``None`` is treated as "nothing to write" instead of
            raising ``TypeError`` on iteration.
    """
    os.makedirs(EMBEDDINGS_DIR, exist_ok=True)
    with open(TXT_EMBEDDINGS_FILE, "a", encoding="utf-8") as f:
        f.writelines(
            json.dumps(embedding) + "\n" for embedding in (new_embeddings or ())
        )
def check_gpu_temperature():
    """Return the current temperature of GPU index 0 as reported by NVML.

    NVML is initialized for the duration of the query only.

    Raises:
        pynvml.NVMLError: if a query after initialization fails (the NVML
            session is still shut down in that case), or if initialization
            itself fails.
    """
    pynvml.nvmlInit()
    try:
        handle = pynvml.nvmlDeviceGetHandleByIndex(0)
        return pynvml.nvmlDeviceGetTemperature(handle, pynvml.NVML_TEMPERATURE_GPU)
    finally:
        # Always pair nvmlInit with nvmlShutdown, even when a query raises.
        pynvml.nvmlShutdown()
def _write_checkpoint(pending, progress_log, chunk_id, file_path):
    """Persist ``pending`` embeddings, empty the list, then record progress."""
    save_embeddings_txt(pending)
    # Clear in place right after the append succeeds so that a failure while
    # writing the progress marker cannot cause the same records to be
    # appended a second time by a later checkpoint.
    pending.clear()
    os.makedirs(os.path.dirname(progress_log), exist_ok=True)
    with open(progress_log, "w", encoding="utf-8") as f:
        json.dump({"last_processed_chunk_id": chunk_id, "file_path": file_path}, f)


def generate_embeddings(
    vault_data,
    vault_name,
    start_idx=0,
    *,
    max_gpu_temp=51,
    cooldown_seconds=30,
    checkpoint_every=100,
):
    """Embed every chunk of ``vault_data`` and checkpoint the results to disk.

    Each chunk's text is tokenized (truncated to 512 tokens), run through the
    module-level ``model``, and reduced to one vector by averaging the last
    hidden state over dim 1. Records of the form ``{chunk_id: vector}`` are
    appended to the text store via ``save_embeddings_txt`` every
    ``checkpoint_every`` successful chunks and at the end of each entry, and
    a progress marker is written to
    ``<EMBEDDINGS_DIR>/<vault_name>_progress.json``.

    Args:
        vault_data: iterable of entries, each with a ``"file_name"`` and a
            ``"chunks"`` list whose items have ``"text"`` and ``"id"``.
        vault_name: name used to build the progress-log file name.
        start_idx: accepted for compatibility; currently unused.
        max_gpu_temp: pause while the GPU temperature is above this value.
        cooldown_seconds: how long to sleep between temperature checks.
        checkpoint_every: number of buffered embeddings that triggers a save.

    Returns:
        The list of embeddings that have not been written to disk. This is
        empty unless the final checkpoint failed.
    """
    pending = []
    # Use the same directory constant as the rest of the module; a differently
    # cased literal would name a separate folder on case-sensitive filesystems.
    progress_log = os.path.join(EMBEDDINGS_DIR, f"{vault_name}_progress.json")

    # The previous progress marker is only reported, not used to skip work.
    if os.path.exists(progress_log):
        with open(progress_log, "r", encoding="utf-8") as f:
            progress_data = json.load(f)
        last_processed_chunk_id = progress_data.get("last_processed_chunk_id")
        last_processed_file_path = progress_data.get("file_path")
        print(f"Last processed chunk ID: {last_processed_chunk_id}")
        print(f"File path: {last_processed_file_path}")
    else:
        print(f"Progress log '{progress_log}' does not exist.")

    # The NVML thermal check only makes sense when the model runs on CUDA;
    # on a CPU-only machine the NVML call would raise and abort the run.
    on_gpu = device.type == "cuda"
    last_chunk_id = None
    last_file_path = None

    for entry in tqdm(vault_data, desc="Generating embeddings"):
        file_path = entry["file_name"]
        for chunk in entry["chunks"]:
            content = chunk["text"]
            chunk_id = chunk["id"]

            while on_gpu and check_gpu_temperature() > max_gpu_temp:
                print("GPU temp is too high. Pausing Temporarily...\n")
                time.sleep(cooldown_seconds)

            # Best effort: a failing chunk is reported and skipped so the
            # rest of the vault is still processed.
            try:
                inputs = tokenizer(
                    content,
                    return_tensors="pt",
                    truncation=True,
                    padding=True,
                    max_length=512,
                ).to(device)
                with torch.no_grad():
                    outputs = model(**inputs)
                # Mean over dim 1 of the last hidden state gives one vector
                # per chunk.
                embedding = (
                    torch.mean(outputs.last_hidden_state, dim=1).squeeze().tolist()
                )
                pending.append({chunk_id: embedding})
                last_chunk_id, last_file_path = chunk_id, file_path

                if len(pending) >= checkpoint_every:
                    _write_checkpoint(pending, progress_log, chunk_id, file_path)
            except Exception as e:
                print(f"Error processing chunk {chunk_id} in file {file_path}: {e}")

        # Flush at the end of every entry so nothing stays buffered in memory
        # only, regardless of how many chunks the entry had.
        if pending:
            try:
                _write_checkpoint(
                    pending, progress_log, last_chunk_id, last_file_path
                )
            except Exception as e:
                print(f"Error saving checkpoint after file {file_path}: {e}")

    return pending
def load_embeddings_txt(embeddings_file):
    """Return the chunk IDs recorded in a JSON-lines embeddings file.

    Each non-blank line is expected to hold a JSON object whose first key is
    the chunk ID. Because JSON object keys are strings, the returned IDs are
    always ``str``. Blank lines are ignored. Malformed lines and lines that
    are not a non-empty JSON object are reported on stdout and skipped.

    Args:
        embeddings_file: path of the JSON-lines file.

    Returns:
        List of IDs in file order; empty if the file does not exist.
    """
    ids = []
    if not os.path.exists(embeddings_file):
        return ids

    with open(embeddings_file, "r", encoding="utf-8") as f:
        for line_number, line in enumerate(f, start=1):
            stripped = line.strip()
            if not stripped:
                continue
            try:
                data = json.loads(stripped)
            except json.JSONDecodeError as e:
                # e.pos indexes into the stripped text, so slice that same
                # text for the excerpt.
                start = max(0, e.pos - 20)
                end = min(len(stripped), e.pos + 20)
                print(
                    f"JSONDecodeError: {e.msg} at line {line_number}, column {e.pos}:\n{stripped[start:end]}"
                )
                continue
            if isinstance(data, dict) and data:
                ids.append(next(iter(data)))  # first key is the chunk ID
            else:
                print(
                    f"Skipping line {line_number}: expected a non-empty JSON object"
                )
    return ids
def filter_vault_data(vault_data, embeddings_ids):
    """Return the vault entries restricted to chunks that are not yet embedded.

    Args:
        vault_data: iterable of entry dicts, each with a ``"chunks"`` list
            whose items carry an ``"id"``.
        embeddings_ids: iterable of IDs that already have an embedding.

    Returns:
        A new list of shallow-copied entries whose ``"chunks"`` contain only
        the chunks still to be embedded. Entries left with no chunks are
        dropped. The input entries are not modified.
    """
    # IDs read back from the JSON-lines store are always strings (JSON object
    # keys), while vault chunk IDs may be ints; compare on the string form so
    # both spellings of the same ID match. A set keeps each lookup O(1).
    done_ids = {str(embedded_id) for embedded_id in embeddings_ids}

    filtered_vault_data = []
    for entry in vault_data:
        remaining = [
            chunk for chunk in entry["chunks"] if str(chunk["id"]) not in done_ids
        ]
        if remaining:
            # Copy instead of assigning into ``entry`` so the caller's data
            # is left untouched.
            filtered_vault_data.append({**entry, "chunks": remaining})
    return filtered_vault_data
def clean_vault_json_file(vault_data):
    """Drop already-embedded chunks from ``vault_data`` and rewrite the vault file.

    The IDs recorded in ``TXT_EMBEDDINGS_FILE`` are loaded. When none are
    found, ``vault_data`` is returned unchanged and nothing is written.
    Otherwise the data is filtered with ``filter_vault_data``, written to
    ``UPDATED_VAULT_JSON_FILE`` and returned.
    """
    done_ids = load_embeddings_txt(TXT_EMBEDDINGS_FILE)
    if not done_ids:
        return vault_data

    remaining = filter_vault_data(vault_data, done_ids)
    with open(UPDATED_VAULT_JSON_FILE, "w", encoding="utf-8") as out:
        json.dump(remaining, out, indent=2)
    print(
        f"Filtered vault.json based on embeddings and saved as {UPDATED_VAULT_JSON_FILE}"
    )
    return remaining
def main():
    """Embed every not-yet-embedded chunk of ``vault.json``.

    Steps: load the vault, drop chunks that already have an embedding,
    generate embeddings for the rest (unless a saved set of matching size is
    found), convert the JSON-lines store to ``.pt``, and filter the vault
    file once more.
    """
    vault_path = "vault.json"
    vault_data = read_vault_data(vault_path)
    print("Loaded vault.json ")

    # Continue with only the chunks that still need an embedding.
    vault_data = clean_vault_json_file(vault_data)
    print("cleaned vault.json ")

    vault_name = os.path.splitext(os.path.basename(vault_path))[0]
    total_chunks = sum(len(file_dict["chunks"]) for file_dict in vault_data)

    saved_embeddings = load_embeddings(vault_name)
    # load_embeddings signals "nothing saved" with a (None, None) tuple; that
    # must not be mistaken for a saved collection of length 2.
    if isinstance(saved_embeddings, tuple):
        saved_embeddings = None

    if saved_embeddings is not None and total_chunks == len(saved_embeddings):
        print(f"Loaded saved embeddings for {vault_name}")
    else:
        print(f"Generating new embeddings for {vault_name}")
        # generate_embeddings writes its own checkpoints, so it is simply
        # called here; no separate saver is needed.
        generate_embeddings(vault_data, vault_name)

    if os.path.exists(TXT_EMBEDDINGS_FILE):
        convert_txt_to_pt(TXT_EMBEDDINGS_FILE, PT_EMBEDDINGS_FILE)
        print(f"Converted {TXT_EMBEDDINGS_FILE} to {PT_EMBEDDINGS_FILE}")
    else:
        print(f"{TXT_EMBEDDINGS_FILE} does not exist.")

    clean_vault_json_file(vault_data)
    print(f"Embeddings generation completed for {vault_name}")


if __name__ == "__main__":
    main()