retrieval script #1

Draft · Vincent-Ustach wants to merge 36 commits into main from retrieval_script

Commits (36)
e7476df  initial commit (Vincent-Ustach, Jan 16, 2025)
a4eb270  store false for inference_bool (Vincent-Ustach, Jan 16, 2025)
3749b33  require task_desc_infile (Vincent-Ustach, Jan 16, 2025)
a719456  task_desc_infile pathlib path (Vincent-Ustach, Jan 16, 2025)
a149c13  inference_bool default True (Vincent-Ustach, Jan 16, 2025)
0dc40c3  inference_bool typo (Vincent-Ustach, Jan 16, 2025)
6d30f24  import DataArgs (Vincent-Ustach, Jan 16, 2025)
9df5b98  set model device data_args to None (Vincent-Ustach, Jan 16, 2025)
dda1bca  read disease desc (Vincent-Ustach, Jan 17, 2025)
43a3218  Merge branch 'refs/heads/main' into retrieval_script (Vincent-Ustach, Jan 17, 2025)
b18d3d6  reduce precision of model before loading to device (Vincent-Ustach, Jan 17, 2025)
c0e4881  copy retrieval script (Vincent-Ustach, Jan 17, 2025)
a589308  rename files (Vincent-Ustach, Jan 17, 2025)
824b3c2  if no model load no create_input_retrieval (Vincent-Ustach, Jan 17, 2025)
1205024  remove unused imports (Vincent-Ustach, Jan 17, 2025)
54943dd  if no model no create_input_retrieval (Vincent-Ustach, Jan 17, 2025)
3a92807  remove unused imports (Vincent-Ustach, Jan 17, 2025)
694b729  black formatting (Vincent-Ustach, Jan 17, 2025)
48f8fb9  refactor with startup and do methods for later api (Vincent-Ustach, Jan 17, 2025)
4844320  docstrings (Vincent-Ustach, Jan 17, 2025)
6a7b4f1  bug in calling startup_retrieval (Vincent-Ustach, Jan 17, 2025)
4f8368e  fastapi app (Vincent-Ustach, Jan 17, 2025)
4e5bb42  add more required env vars (Vincent-Ustach, Jan 17, 2025)
643cc6f  move utils to inference/retrieval_utils.py (Vincent-Ustach, Jan 17, 2025)
f4299b1  fix imports for retrieval_utils (Vincent-Ustach, Jan 17, 2025)
397d631  typo (Vincent-Ustach, Jan 17, 2025)
01c887f  update comment (Vincent-Ustach, Jan 17, 2025)
1b3b262  instruction_source_dataset passed as argument (Vincent-Ustach, Jan 18, 2025)
f4fa800  update docstring for app (Vincent-Ustach, Jan 18, 2025)
7c6619d  bug w repeated model.to_device() commands (Vincent-Ustach, Jan 19, 2025)
58ff2a6  return top k (Vincent-Ustach, Jan 19, 2025)
f30352f  by default return all records (Vincent-Ustach, Jan 19, 2025)
7a9edd7  fillna (Vincent-Ustach, Jan 19, 2025)
491d950  remove disease from input description. remove unused imports in app. (Vincent-Ustach, Jan 22, 2025)
6067989  remove disease from input description. remove unused imports in app. (Vincent-Ustach, Jan 22, 2025)
3c60c9a  delete drug script (Vincent-Ustach, Jan 22, 2025)
99 changes: 99 additions & 0 deletions procyon/app/main.py
@@ -0,0 +1,99 @@
import os
from typing import Optional
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, Field
import pandas as pd
from loguru import logger

# Import the key functions from the existing codebase
from procyon.inference.retrieval_utils import startup_retrieval, do_retrieval

app = FastAPI()

# Global variables to store the model, compute device, and data args
model = None
device = None
data_args = None


class RetrievalRequest(BaseModel):
task_desc: str = Field(description="The task description.")
disease_desc: str = Field(description="The disease description.")
instruction_source_dataset: str = Field(
description="Dataset source for instructions - either 'disgenet' or 'omim'"
)
k: Optional[int] = Field(default=None, description="Number of top results to return. If None, returns all results", ge=1)


@app.on_event("startup")
async def startup_event():
"""Initialize the model and required components on startup"""
global model, device, data_args

if not os.getenv("HF_TOKEN"):
raise EnvironmentError("HF_TOKEN environment variable not set")
if not os.getenv("CHECKPOINT_PATH"):
raise EnvironmentError("CHECKPOINT_PATH environment variable not set")
if not os.getenv("HOME_DIR"):
raise EnvironmentError("HOME_DIR environment variable not set")
if not os.getenv("DATA_DIR"):
raise EnvironmentError("DATA_DIR environment variable not set")
if not os.getenv("LLAMA3_PATH"):
raise EnvironmentError("LLAMA3_PATH environment variable not set")

# Use the existing startup_retrieval function
model, device, data_args = startup_retrieval(inference_bool=True)
logger.info("Model loaded and ready")


@app.post("/retrieve")
async def retrieve_proteins(request: RetrievalRequest):
"""Endpoint to perform protein retrieval"""
global model, device, data_args

if not all([model, device, data_args]):
raise HTTPException(status_code=500, detail="Model not initialized")

try:
# Use the existing do_retrieval function
results_df = do_retrieval(
model=model,
data_args=data_args,
device=device,
task_desc=request.task_desc,
disease_desc=request.disease_desc,
instruction_source_dataset=request.instruction_source_dataset,
)

results_df = results_df.fillna('')

# Return all results if k is None, otherwise return top k
if request.k is None:
return {"results": results_df.to_dict(orient="records")}
return {"results": results_df.head(request.k).to_dict(orient="records")}

except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
except Exception as e:
logger.error(f"Error during retrieval: {str(e)}")
raise HTTPException(status_code=500, detail=str(e))


if __name__ == "__main__":
"""
    This API endpoint allows users to perform protein retrieval for a given disease
    description using the pre-trained ProCyon model ProCyon-Full.
    Run this script directly with `python main.py` to start the FastAPI server on
    port 8000.
    The API will then be available at http://localhost:8000
An example request can be made using curl:
curl -X POST "http://localhost:8000/retrieve" \
-H "Content-Type: application/json" \
-d '{"task_desc": "Find proteins related to this disease",
"disease_desc": "Major depressive disorder",
"instruction_source_dataset": "disgenet",
"k": 1000}'
"""
import uvicorn

uvicorn.run(app, host="0.0.0.0", port=8000)
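
The endpoint can also be exercised from Python. Below is a minimal client sketch, assuming the server above is already running on localhost:8000; the payload fields mirror the RetrievalRequest model and the values are illustrative only.

import requests

# Hypothetical example payload; field names follow the RetrievalRequest model above
payload = {
    "task_desc": "Find proteins related to this disease",
    "disease_desc": "Major depressive disorder",
    "instruction_source_dataset": "disgenet",
    "k": 10,
}

# Retrieval runs a large model server-side, so allow a generous timeout
response = requests.post("http://localhost:8000/retrieve", json=payload, timeout=600)
response.raise_for_status()
for record in response.json()["results"]:
    print(record)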
191 changes: 191 additions & 0 deletions procyon/inference/retrieval_utils.py
@@ -0,0 +1,191 @@
import os
from pathlib import Path

from huggingface_hub import login as hf_login
from loguru import logger
import pandas as pd
from typing import Optional, Tuple, Union
import torch

from procyon.data.inference_utils import (
create_input_retrieval,
get_proteins_from_embedding,
)
from procyon.evaluate.framework.utils import move_inputs_to_device
from procyon.model.model_unified import UnifiedProCyon
from procyon.training.train_utils import DataArgs

CKPT_NAME = os.path.expanduser(os.getenv("CHECKPOINT_PATH"))


def startup_retrieval(
inference_bool: bool = True,
) -> Tuple[
Union[UnifiedProCyon, None], Union[torch.device, None], Union[DataArgs, None]
]:
"""
This function performs startup functions to initiate protein retrieval:
Logs into the huggingface hub and loads the pre-trained ProCyon model.
Args:
        inference_bool (bool): OPTIONAL; set to False to skip loading the model
            when inference is not needed.
Returns:
model (UnifiedProCyon): The pre-trained ProCyon model
device (torch.device): The compute device (GPU or CPU) on which the model is loaded
data_args (DataArgs): The data arguments defined by the pre-trained model
"""

logger.info("Logging into huggingface hub")
hf_login(token=os.getenv("HF_TOKEN"))
logger.info("Done logging into huggingface hub")

if inference_bool:
logger.info("Inference is enabled.")

# load the pre-trained ProCyon model
model, device, data_args = load_model_onto_device()
else:
logger.info("Inference is disabled.")
        # loading the model takes significant time and memory, so skip it when not needed
model = None
device = None
data_args = None

return model, device, data_args


def load_model_onto_device() -> Tuple[UnifiedProCyon, torch.device, DataArgs]:
"""
Load the pre-trained ProCyon model and move it to the compute device.
Returns:
model (UnifiedProCyon): The pre-trained ProCyon model
device (torch.device): The compute device (GPU or CPU) on which the model is loaded
data_args (DataArgs): The data arguments defined by the pre-trained model
"""
# Load the pre-trained ProCyon model
logger.info("Loading pretrained model")
    # CKPT_NAME points to the downloaded pre-trained ProCyon model checkpoint (e.g. ProCyon-Full)
data_args = torch.load(os.path.join(CKPT_NAME, "data_args.pt"))
model, _ = UnifiedProCyon.from_pretrained(checkpoint_dir=CKPT_NAME)
logger.info("Done loading pretrained model")

logger.info("Quantizing the model to a smaller precision")
model.bfloat16() # Quantize the model to a smaller precision
logger.info("Done quantizing the model to a smaller precision")

logger.info("Setting the model to evaluation mode")
model.eval()
logger.info("Done setting the model to evaluation mode")

logger.info("Applying pretrained model to device")
logger.info(f"Total memory allocated by PyTorch: {torch.cuda.memory_allocated()}")
# identify available devices on the machine
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
logger.info(f"Total memory allocated by PyTorch: {torch.cuda.memory_allocated()}")

logger.info("Done loading model and applying it to compute device")

return model, device, data_args


def do_retrieval(
model: UnifiedProCyon,
data_args: DataArgs,
device: torch.device,
instruction_source_dataset: str,
inference_bool: bool = True,
task_desc_infile: Path = None,
disease_desc_infile: Path = None,
task_desc: str = None,
disease_desc: str = None,
) -> Optional[pd.DataFrame]:
"""
This function performs protein retrieval for a given disease using the pre-trained ProCyon model.
Args:
model (UnifiedProCyon): The pre-trained ProCyon model
data_args (DataArgs): The data arguments defined by the pre-trained model
device (torch.device): The compute device (GPU or CPU) on which the model is loaded
        inference_bool (bool): OPTIONAL; set to False to skip the retrieval forward pass
task_desc_infile (Path): The path to the file containing the task description.
disease_desc_infile (Path): The path to the file containing the disease description.
task_desc (str): The task description.
disease_desc (str): The disease description.
instruction_source_dataset (str): Dataset source for instructions - either "disgenet" or "omim"
Returns:
df_dep (pd.DataFrame): The DataFrame containing the top protein retrieval results
"""
if instruction_source_dataset not in ["disgenet", "omim"]:
raise ValueError(
'instruction_source_dataset must be either "disgenet" or "omim"'
)

# TODO get rid of this IO if we always do protein retrieval!
# Load the pre-calculated protein target embeddings
logger.info("Load protein target embeddings")
all_protein_embeddings, all_protein_ids = torch.load(
os.path.join(CKPT_NAME, "protein_target_embeddings.pkl")
)
all_protein_embeddings = all_protein_embeddings.float()
logger.info(
f"shape of precalculated embeddings matrix: {all_protein_embeddings.shape}"
)


    logger.info("Processing task and disease descriptions")
if task_desc_infile is not None:
if task_desc is not None:
raise ValueError(
"Only one of task_desc_infile and task_desc can be provided."
)
# read the task description from a file
with open(task_desc_infile, "r") as f:
task_desc = f.read()
elif task_desc is None:
raise ValueError("Either task_desc_infile or task_desc must be provided.")

if disease_desc_infile is not None:
if disease_desc is not None:
raise ValueError(
"Only one of disease_desc_infile and disease_desc can be provided."
)
# read the disease description from a file
with open(disease_desc_infile, "r") as f:
disease_desc = f.read()
elif disease_desc is None:
raise ValueError("Either disease_desc_infile or disease_desc must be provided.")

task_desc = task_desc.replace("\n", " ")
disease_desc = disease_desc.replace("\n", " ")

logger.info("Done entering task description and prompt")

if inference_bool:
logger.info("Now performing protein retrieval")

# Create input for retrieval
input_simple = create_input_retrieval(
input_description=disease_desc,
data_args=data_args,
task_definition=task_desc,
instruction_source_dataset=instruction_source_dataset,
instruction_source_relation="all",
aaseq_type="protein",
icl_example_number=1, # 0, 1, 2
)

input_simple = move_inputs_to_device(input_simple, device=device)
with torch.no_grad():
model_out = model(
inputs=input_simple,
retrieval=True,
aaseq_type="protein",
)
# The script can run up to here without a GPU, but the following line requires a GPU
df_dep = get_proteins_from_embedding(
all_protein_embeddings, model_out, top_k=None
)

logger.info("Done performaing protein retrieval for example 1")

    return df_dep if inference_bool else None  # no results when inference is skipped
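
For callers that do not need the HTTP layer, the two functions above can be driven directly. A minimal sketch, assuming the required environment variables (HF_TOKEN, CHECKPOINT_PATH, etc.) are already set and using inline description strings rather than input files; the description strings are illustrative only:

from procyon.inference.retrieval_utils import startup_retrieval, do_retrieval

# Load the model once (slow), then reuse it across retrieval calls
model, device, data_args = startup_retrieval(inference_bool=True)

results_df = do_retrieval(
    model=model,
    data_args=data_args,
    device=device,
    instruction_source_dataset="disgenet",
    task_desc="Find proteins related to this disease",
    disease_desc="Major depressive disorder",
)
print(results_df.head(10))  # top retrieval hits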
4 changes: 4 additions & 0 deletions pyproject.toml
@@ -20,6 +20,7 @@ dependencies = [
"jupyter==1.0.0",
"kiwisolver==1.4.5",
"llvmlite==0.42.0",
"loguru==0.7.3",
"matplotlib==3.8.3",
"multiprocess==0.70.16",
"networkx==3.2.1",
@@ -50,6 +51,9 @@
"virtualenv==20.28.0",
"wandb==0.16.3",
"bert-score>=0.3.13",
"argparse>=1.4.0",
"huggingface-hub==0.23.4",
"loguru==0.7.3",
"fastapi>=0.109.0",
"uvicorn>=0.27.0",
]

[project.optional-dependencies]
77 changes: 77 additions & 0 deletions scripts/protein_retrieval_disease_pheno.py
@@ -0,0 +1,77 @@
from pathlib import Path
from typing import Union

import argparse
from loguru import logger
import pandas as pd

from procyon.inference.retrieval_utils import startup_retrieval, do_retrieval


def single_retrieval(
task_desc_infile: Path,
disease_desc_infile: Path,
instruction_source_dataset: str,
inference_bool: bool = True,
) -> Union[pd.DataFrame, None]:
"""
This function uses the pre-trained ProCyon model to perform one protein retrieval run
    for a given disease using DisGeNET or OMIM instruction data.
Args:
task_desc_infile (Path): The path to the file containing the task description.
disease_desc_infile (Path): The path to the file containing the disease description.
instruction_source_dataset (str): Dataset source for instructions - either "disgenet" or "omim"
        inference_bool (bool): OPTIONAL; set to False to skip inference and model loading
Returns:
Optional[pd.DataFrame]: DataFrame with results if inference_bool is True, None otherwise
"""
model, device, data_args = startup_retrieval(inference_bool)

results_df = do_retrieval(
model=model,
data_args=data_args,
device=device,
inference_bool=inference_bool,
task_desc_infile=task_desc_infile,
disease_desc_infile=disease_desc_infile,
instruction_source_dataset=instruction_source_dataset,
)
if results_df is not None:
logger.info(f"top results: {results_df.head(10).to_dict(orient='records')}")

logger.info("DONE WITH ALL WORK")
return results_df


if __name__ == "__main__":
parser = argparse.ArgumentParser()
    parser.add_argument(
        "--task_desc_infile",
        type=Path,
        help="Path to a file containing the task description.",
    )
    parser.add_argument(
        "--disease_desc_infile",
        type=Path,
        help="Path to a file containing the disease description.",
    )
    parser.add_argument(
        "--inference_bool",
        action="store_false",
        help="OPTIONAL flag; pass it to skip inference and model loading (stores False)",
        default=True,
    )
parser.add_argument(
"--instruction_source_dataset",
type=str,
choices=["disgenet", "omim"],
help="Dataset source for instructions - either 'disgenet' or 'omim'",
)
args = parser.parse_args()

    single_retrieval(
        task_desc_infile=args.task_desc_infile,
        disease_desc_infile=args.disease_desc_infile,
        instruction_source_dataset=args.instruction_source_dataset,
        inference_bool=args.inference_bool,
    )
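
An example invocation, where task.txt and disease.txt are hypothetical text files holding the task and disease descriptions: `python scripts/protein_retrieval_disease_pheno.py --task_desc_infile task.txt --disease_desc_infile disease.txt --instruction_source_dataset disgenet`. Note that passing --inference_bool flips the store_false flag to False and skips model loading; omit it to run inference.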