-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Refactored source code using postgres
- Loading branch information
Showing
1 changed file
with
135 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,135 @@ | ||
# -*- coding: utf-8 -*- | ||
"""Process | ||
Automatically generated by Colab. | ||
Original file is located at | ||
https://colab.research.google.com/drive/1Q1UxrmNHSKvNRqlksj7Lf0sObE5Zeq74 | ||
""" | ||
|
||
import psycopg2 | ||
import json | ||
import os | ||
from dotenv import load_dotenv | ||
from llama_cpp import Llama | ||
from huggingface_hub import hf_hub_download | ||
|
||
# Constants | ||
MODEL_NAME = "TheBloke/Mixtral-8x7B-Instruct-v0.1-GGUF" | ||
MODEL_PATH = hf_hub_download(repo_id=MODEL_NAME, filename="mixtral-8x7b-instruct-v0.1.Q5_K_M.gguf") | ||
BATCH_SIZE = 10 # Number of records to process in each batch | ||
TIMEOUT = 5000 # Statement timeout in milliseconds | ||
|
||
# Load Llama model | ||
model = Llama( | ||
model_path=MODEL_PATH, | ||
n_threads=64, | ||
n_batch=128, | ||
n_ctx=4096, | ||
n_gpu_layers=96, | ||
mlock=True, | ||
logits_all=True | ||
) | ||
|
||
# Load environment variables | ||
load_dotenv('.env') | ||
|
||
# Establish database connection | ||
def create_db_connection(): | ||
return psycopg2.connect( | ||
dbname=os.getenv("DB_NAME"), | ||
user=os.getenv("DB_USER"), | ||
password=os.getenv("DB_PASSWORD"), | ||
host=os.getenv("DB_HOST"), | ||
port=os.getenv("DB_PORT"), | ||
connect_timeout=10 # Set connection timeout to 10 seconds | ||
) | ||
|
||
# Set statement timeout for long-running queries | ||
def set_statement_timeout(cursor, timeout=TIMEOUT): | ||
cursor.execute(f"SET statement_timeout = {timeout};") | ||
|
||
# Read starting ID from environment | ||
def read_starting_id(): | ||
return int(os.getenv("STARTING_ID", 0)) | ||
|
||
# Save updated ID to the .env file | ||
def save_updated_id(id_value): | ||
with open('.env', 'r') as file: | ||
lines = file.readlines() | ||
|
||
updated_lines = [ | ||
f"STARTING_ID={id_value}\n" if line.startswith("STARTING_ID=") else line | ||
for line in lines | ||
] | ||
|
||
if not any(line.startswith("STARTING_ID=") for line in lines): | ||
updated_lines.append(f"STARTING_ID={id_value}\n") | ||
|
||
with open('.env', 'w') as file: | ||
file.writelines(updated_lines) | ||
|
||
# Process a triple and sentence using Llama model | ||
def process_triple(triple, sentence): | ||
prompt = f""" | ||
SYSTEM: You are a computational biologist tasked with evaluating scientific claims. | ||
Your role requires you to apply critical thinking and your expertise to interpret data and research findings accurately. | ||
Answer 'Yes' or 'No' to directly address the query posed. | ||
USER: ('Does the phrase "{triple}" receive at least indirect support from the statement: "{sentence}"?',). | ||
ASSISTANT: | ||
""" | ||
|
||
response = model(prompt=prompt, max_tokens=1, temperature=0, echo=False, logprobs=True) | ||
return response["choices"][0]["text"], json.dumps(response["choices"][0]["logprobs"]) | ||
|
||
# Process batches of records from the database | ||
def process_batches(cursor, starting_id): | ||
while True: | ||
cursor.execute(""" | ||
SELECT "id", "triple", "sentence" | ||
FROM public."tblBiomedicalFactcheck" | ||
WHERE "id" >= %s | ||
ORDER BY "id" | ||
LIMIT %s | ||
""", (starting_id, BATCH_SIZE)) | ||
|
||
records = cursor.fetchall() | ||
if not records: | ||
print("No more records to process.") | ||
break | ||
|
||
for record in records: | ||
record_id, triple, sentence = record | ||
answer, logprops_json = process_triple(triple, sentence) | ||
try: | ||
cursor.execute(""" | ||
UPDATE public."tblBiomedicalFactcheck" | ||
SET "answer" = %s, "logprops" = %s | ||
WHERE "id" = %s | ||
""", (answer, logprops_json, record_id)) | ||
except Exception as e: | ||
print(f"Error updating record ID {record_id}: {e}") | ||
conn.rollback() # Rollback in case of an error | ||
|
||
conn.commit() | ||
starting_id += BATCH_SIZE | ||
save_updated_id(starting_id) | ||
print(f"Processed up to ID: {starting_id}") | ||
|
||
# Main processing function | ||
def main(): | ||
starting_id = read_starting_id() | ||
conn = create_db_connection() | ||
cursor = conn.cursor() | ||
|
||
try: | ||
set_statement_timeout(cursor) | ||
process_batches(cursor, starting_id) | ||
finally: | ||
cursor.close() | ||
conn.close() | ||
print("Data processing completed successfully.") | ||
|
||
if __name__ == "__main__": | ||
main() |