-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathconditional_generation.py
100 lines (78 loc) · 3.86 KB
/
conditional_generation.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
import os
import random
import numpy as np
import pandas as pd
import torch
from rdkit.Chem import AllChem
from rdkit import Chem
import utils
from dataset import data_prep
import encoding_utils as eutils
import analysis_utils as autils
import generation_utils as g_utils
import VAE
import warnings
warnings.filterwarnings("ignore")
# Load run configuration and make the whole pipeline reproducible by
# seeding every RNG source (torch CPU/CUDA, NumPy, stdlib random) with
# the same value and forcing deterministic cuDNN kernels.
config = utils.get_config(print_dict = False)
seed = config["seed"]
for _seed_fn in (torch.manual_seed,
                 torch.cuda.manual_seed,
                 torch.cuda.manual_seed_all,
                 np.random.seed,
                 random.seed):
    _seed_fn(seed)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
def seed_worker(seed):
    """DataLoader ``worker_init_fn``: seed NumPy and ``random`` per worker.

    Bug fix: the original stored ``torch.manual_seed(seed)`` — which returns
    a ``torch.Generator``, not an int — and passed that object to
    ``np.random.seed`` / ``random.seed``, raising ``TypeError``.  Per the
    PyTorch reproducibility recipe, each worker derives its seed from
    ``torch.initial_seed()`` (already unique per worker), truncated to the
    32-bit range NumPy accepts.

    Args:
        seed: worker id/seed supplied by the DataLoader (unused directly;
            kept for interface compatibility).
    """
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)
# Dedicated torch Generator, seeded for reproducible shuffling
# (presumably handed to the DataLoader inside data_prep — confirm there).
G = torch.Generator()
G.manual_seed(seed)
# loading of supervised learning dataset
dataset = pd.read_csv(config["original_dataset"])
# loading of unsupervised learning dataset
undataset = pd.read_csv(config["augmented_dataset"])
# data_prep builds the train/test loaders and also exposes module-level
# artifacts (supervised_input, index_to_smile, ordinalenc) used below.
train_dataloader, test_dataloader = data_prep(dataset, undataset)
# Pre-trained VAE, switched to inference mode for generation.
model = VAE.load_VAE(pretrained = True)
model.eval()
# Decoding tables: index -> SMILES token map and the fitted ordinal encoder.
index_to_smile = data_prep.index_to_smile
ordinalenc = data_prep.ordinalenc
############################### interpolation
# Pick latent points the model classifies confidently (prob > 0.8) as
# non-candidate (nc_idx) and candidate (c_idx) endpoints for interpolation.
nc_idx, c_idx = VAE.extract_high_prob(data_prep.supervised_input, model, threshold = 0.8)
# Spherical interpolation ("slerp-in") between endpoint pairs; decoded
# molecules are filtered and logged to config["intp_file"].
g_utils.interpolation_generation(lerp_type = "slerp-in", steps = config["lerp_steps"], threshold = config["lerp_threshold"], model = model,
                                 data_input = data_prep.supervised_input,
                                 nc_idx = nc_idx, c_idx = c_idx,
                                 index_to_smile = index_to_smile, ordinalenc = ordinalenc,
                                 dataset = dataset,
                                 batch = config["intp_batch"],
                                 current_batch = config["current_batch"],
                                 log_file = config["intp_file"], use_filter = True)
############################### bayesian optimization
from scipy.stats import multivariate_normal
import json
# Latent-space search domain for BO, derived from the encoded training set.
bo_domain = g_utils.obtain_domain(data_prep.supervised_input, model)
def molecular_optimization(batch, model, start_mol_idx, bo_domain, log_file):
    """Run latent-space Bayesian optimization seeded from one molecule.

    Args:
        batch: number of BO batches, forwarded to ``g_utils.multi_bo``.
        model: trained VAE exposing ``input_to_latent`` / ``latent_to_prob``.
        start_mol_idx: row index into the supervised dataset whose molecule
            seeds the optimization.
        bo_domain: latent search domain from ``g_utils.obtain_domain``.
        log_file: path of the BO log file (header is written here).
    """
    # Weight of the standard-normal prior penalty in the objective.
    p_weight = 0.0025
    # Encode the starting molecule into the 128-d latent space.
    x_mol, _ = model.input_to_latent(data_prep.supervised_input[start_mol_idx])
    x_mol = x_mol.detach().cpu().numpy()
    # Log-density of the latent point under N(0, I_128).
    # NOTE(review): the 1e-9 is added to x itself, not to the density —
    # looks like a misplaced epsilon, kept as-is; confirm intent.
    mvn = multivariate_normal(mean=np.zeros(128), cov=np.identity(128))
    prior_x = mvn.logpdf(x_mol + 1e-9).reshape(-1, 1)
    # Objective: log p(candidate | z) minus the weighted prior penalty.
    # (prior_x is already (-1, 1); the original re-reshaped it redundantly.)
    log_y = (torch.log(model.latent_to_prob(torch.Tensor(x_mol).cuda()))
             .detach().cpu().numpy() - p_weight * prior_x)
    # exist_ok avoids the check-then-create race of the original.
    os.makedirs('./log', exist_ok=True)
    # Single write pass; the original opened the file twice ('w' then 'a').
    with open(log_file, 'w') as bo_log:
        bo_log.write("Gaussian Process Starting Point" + "\n" +
                     "Initial bb1 smile: " + str(dataset['bb1_smile'][start_mol_idx]) + "\n" +
                     "Initial bb2 smile: " + str(dataset['bb2_smile'][start_mol_idx]) + "\n" +
                     "Initial reaction: " + str(dataset['reaction'][start_mol_idx]) + "\n" )
        bo_log.write("Initial X: " + "\n")
        json.dump(x_mol.tolist(), bo_log)
        bo_log.write("acquisition prior weight: " + str(p_weight) + "\n")
        bo_log.write("=============================================================================" + "\n\n")
    # Hand off to the multi-batch BO driver; results are appended to log_file.
    g_utils.multi_bo(batch=batch, start_x=x_mol, start_y=log_y,
                     model=model, bo_domain=bo_domain,
                     dataset=dataset,
                     index_to_smile=index_to_smile, ordinalenc=ordinalenc,
                     log_file=log_file, use_filter=True)
# Bug fix: the original call omitted the required `model` and `bo_domain`
# arguments of molecular_optimization, raising TypeError before BO started.
molecular_optimization(batch = config["bo_batch"], model = model,
                       start_mol_idx = config["initial_idx"],
                       bo_domain = bo_domain, log_file = config["bo_file"])