"""
Author: Talip Ucar
email: ucabtuc@gmail.com or talip.ucar@astrazeneca.com
Description: A sample script to score antibody sequences for humanness.
"""
import copy
import logging
import os
from typing import Any, Dict
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pytorch_lightning as pl
import seaborn as sns
import torch
from src.selfpad_humanness import PADFintune
from utils_common.utils import set_dirs
from utils_common.arguments import get_arguments, get_config
from utils_finetune.load_data_eval import PADLoader
from torcheval.metrics.functional import binary_auprc
from torchmetrics import AUROC, Accuracy, F1Score, Precision, Recall
from tqdm import tqdm
# Configure the logging level
logging.basicConfig(level=logging.INFO)


def eval(data_loader, config: Dict[str, Any]):
    """
    Evaluates a fine-tuned Large Protein Model (LPM) with a classification head on the test set.

    Parameters
    ----------
    data_loader : IterableDataset
        PyTorch data loader.
    config : Dict[str, Any]
        Dictionary containing configuration options and arguments.

    Returns
    -------
    tuple
        Paired-sequence F1, recall, precision, ROC AUC, accuracy, and PR AUC scores.
    """
    # Set the random seed to make the evaluation deterministic
    pl.seed_everything(seed=config["seed"])

    # Initialize the model
    model = PADFintune(config)

    # Load the trained model weights
    model.load_models()

    transformer = model.transformer
    transformer.to(config["device"])
    transformer.eval()
    transformer.config["add_noise"] = False

    val_loss = []
    preds_l = []
    preds_raw_l = []
    labels_l = []
    embs_l = []
    species_l = []
    raw_seqs_l = []
    embs_ext_l = []
    x1_l = []

    test_loader = data_loader.test_loader

    # Attach progress bar to data_loader
    test_tqdm = tqdm(enumerate(test_loader), total=len(test_loader), leave=True)
    with torch.no_grad():
        # pass data through model
        for batch in test_tqdm:
            x1, x1aa, y1, raw_seqs = batch[1]
            labels = y1.reshape(-1)
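            # Forward pass: `preds` holds the two-class scores, while `h` and `h_ext` are the
            # pooled sequence representations collected below (exact semantics depend on the
            # SelfPAD transformer implementation).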
            preds, h, h_ext = transformer(x1, x1aa, agg_dim=model.agg_dim)

            # Generate labels and predictions
            preds_raw = preds.cpu()
            preds_idx = preds[:, 1].cpu()

            # Save the results
            preds_l.append(preds_idx)
            preds_raw_l.append(preds_raw)
            labels_l.append(labels)
            embs_l.append(h)
            embs_ext_l.append(h_ext)
            x1_l.append(x1)
            raw_seqs_l.extend(raw_seqs)

            # loss = model.ce_loss(preds, labels.to(model.device))
            # batch_loss = loss.item()
            # val_loss.append(batch_loss)
            # del loss
    preds_l = torch.cat(preds_l)
    preds_raw_l = torch.cat(preds_raw_l)
    labels_l = torch.cat(labels_l)
    embeddings = torch.cat(embs_l)
    embs_ext = torch.cat(embs_ext_l)
    x1_cat = torch.cat(x1_l)
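    # Single-chain evaluation: copy the per-chain positive-class scores, then binarize one copy
    # at the configured decision threshold; the un-thresholded copy feeds the AUC metrics.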
    preds_l_single = copy.deepcopy(preds_l)
    preds_raw_l_single = copy.deepcopy(preds_l)
    labels_l_single = copy.deepcopy(labels_l)

    preds_l_single[preds_l_single >= config["threshold"]] = 1
    preds_l_single[preds_l_single < config["threshold"]] = 0
    # Compute metrics
    f1 = 100 * model.f1_score(preds_l_single, labels_l_single)
    acc = 100 * model.acc_score(preds_l_single, labels_l_single)
    recall = 100 * model.recall_score(preds_l_single, labels_l_single)
    auc = 100 * model.roc_auc(preds_raw_l_single, labels_l_single)
    prec = 100 * model.precision_score(preds_l_single, labels_l_single)
    pr_auc = 100 * binary_auprc(preds_raw_l_single, labels_l_single)

    summary_dict_single = {
        "F1": f1.numpy(),
        "Recall": recall.numpy(),
        "Precision": prec.numpy(),
        "ROC AUC": auc.numpy(),
        "Accuracy": acc.numpy(),
        "PR AUC": pr_auc.numpy(),
    }

    summary_df_single = pd.DataFrame(
        dict([(k, pd.Series(v)) for k, v in summary_dict_single.items()])
    )

    print("=================================")
    print("Single chain performance")
    print(summary_df_single)

    # Results for paired sequence
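    # The first `cutoff` entries of the test set hold one chain per antibody (saved as "Heavy"
    # below) and the remaining entries the partner chain. Binary calls for the two chains are
    # averaged, so a pair is predicted human when at least one chain passes the threshold; the
    # raw positive-class scores of the two chains are averaged for the AUC metrics.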
    cutoff = config["cutoff"]
    preds_l1 = preds_l[:cutoff]
    preds_l2 = preds_l[cutoff:]

    thresh = config["threshold"]
    preds_l1[preds_l1 >= thresh] = 1
    preds_l1[preds_l1 < thresh] = 0
    preds_l2[preds_l2 >= thresh] = 1
    preds_l2[preds_l2 < thresh] = 0

    preds_l = (preds_l1 + preds_l2) / 2
    preds_raw_l = (preds_raw_l[:cutoff] + preds_raw_l[cutoff:]) / 2
    preds_raw_l = preds_raw_l[:, 1]

    preds_l[preds_l >= 0.5] = 1
    labels_l = labels_l[:cutoff]

    # Compute metrics
    f1 = 100 * model.f1_score(preds_l, labels_l)
    acc = 100 * model.acc_score(preds_l, labels_l)
    recall = 100 * model.recall_score(preds_l, labels_l)
    auc = 100 * model.roc_auc(preds_raw_l, labels_l)
    prec = 100 * model.precision_score(preds_l, labels_l)
    pr_auc = 100 * binary_auprc(preds_raw_l, labels_l)

    labels_l = labels_l.reshape(-1).cpu().numpy().tolist()
    # Save the scores as a csv file
    raw_seqs_l_h = raw_seqs_l[:cutoff]
    raw_seqs_l_l = raw_seqs_l[cutoff:]

    scores_df = pd.DataFrame({
        "Heavy": raw_seqs_l_h,
        "Light": raw_seqs_l_l,
        "Scores": preds_raw_l,
        "Prediction": preds_l,
        "TrueLabel": labels_l,
    })

    filename = f"{config['dataset']}_scores.csv"
    file_path = os.path.join(config["inference_dir"], filename)
    scores_df.to_csv(file_path)
print("=================================")
print(f"Raw scores are saved as csv file at: {file_path}")
return (
f1.numpy(),
recall.numpy(),
prec.numpy(),
auc.numpy(),
acc.numpy(),
pr_auc.numpy(),
)


def main():
    # Get parser / command line arguments for the pre-trained model
    args = get_arguments(humanness=False)

    # Get configuration file
    config = get_config(args)

    # Get parser / command line arguments for fine-tuning the model
    args = get_arguments(humanness=True)

    # Get configuration file
    config_finetune = get_config(args)

    # Update the config with options for fine-tuning
    config.update(**config_finetune)
    config["add_noise"] = False
    config["experiment"] = "humanness"
    config["num_workers"] = 0

    # Set directories (or create them if they don't exist)
    set_dirs(config)

    # Get the data loader
    data_loader = PADLoader(config, drop_last=True, is_training=True)
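    # eval() reads the test split via `data_loader.test_loader` and returns the
    # paired-sequence metrics summarized below.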
    f1, recall, prec, auc, acc, pr_auc = eval(data_loader, config=config)

    summary_dict = {
        "F1": [f1],
        "Recall": [recall],
        "Precision": [prec],
        "Accuracy": [acc],
        "ROC AUC": [auc],
        "PR AUC": [pr_auc],
    }

    # Save the results
    summary_df_paired = pd.DataFrame(
        dict([(k, pd.Series(v)) for k, v in summary_dict.items()])
    )

    filename = f"{config['dataset']}_summary.csv"
    file_path = os.path.join(config["inference_dir"], filename)
    summary_df_paired.to_csv(file_path)

    print("Paired sequence performance")
    print(summary_df_paired)
    print("=================================")
    print(f"The summary of evaluation metrics is saved as a csv file at: {file_path}")


if __name__ == "__main__":
    main()