diff --git a/datasail/__main__.py b/datasail/__main__.py
index 94cdad8..4f19932 100644
--- a/datasail/__main__.py
+++ b/datasail/__main__.py
@@ -1,5 +1,13 @@
-from datasail.sail import sail
+import os
+
+#N_THREADS = "1"
+#os.environ["OPENBLAS_NUM_THREADS"] = N_THREADS
+#os.environ["OPENBLAS_MAX_THREADS"] = N_THREADS
+#os.environ["GOTO_NUM_THREADS"] = N_THREADS
+#os.environ["OMP_NUM_THREADS"] = N_THREADS
+os.environ["GRB_LICENSE_FILE"] = "/home/rjo21/gurobi_mickey.lic"
+from datasail.sail import sail
 
 
 if __name__ == '__main__':
     sail()
diff --git a/datasail/cluster/clustering.py b/datasail/cluster/clustering.py
index ec5043a..21735aa 100644
--- a/datasail/cluster/clustering.py
+++ b/datasail/cluster/clustering.py
@@ -73,7 +73,7 @@ def cluster(dataset: DataSet, **kwargs) -> DataSet:
     if len(dataset.cluster_names) > dataset.num_clusters:
         dataset = force_clustering(dataset, kwargs[KW_LINKAGE])
 
-    store_to_cache(dataset, **kwargs)
+    # store_to_cache(dataset, **kwargs)
 
     return dataset
 
@@ -212,6 +212,7 @@ def additional_clustering(
     )  # cluster the clusters into new, fewer, and bigger clusters
     labels = ca.fit_predict(cluster_matrix)
+    LOGGER.info("Clustering finished")
 
     return labels2clusters(labels, dataset, cluster_matrix, linkage)
diff --git a/datasail/cluster/foldseek.py b/datasail/cluster/foldseek.py
index 0517a6a..a6e4ebd 100644
--- a/datasail/cluster/foldseek.py
+++ b/datasail/cluster/foldseek.py
@@ -2,8 +2,12 @@
 import shutil
 from pathlib import Path
 from typing import Optional
+import pickle
 
 import numpy as np
+from pyarrow import compute, csv
+from collections import defaultdict
+from tqdm import tqdm
 
 from datasail.parsers import MultiYAMLParser
 from datasail.reader.utils import DataSet
@@ -23,25 +27,26 @@ def run_foldseek(dataset: DataSet, threads: int = 1, log_dir: Optional[Path] = N
         raise ValueError("Foldseek is not installed.")
 
     user_args = MultiYAMLParser(FOLDSEEK).get_user_arguments(dataset.args, [])
 
-    results_folder = Path("fs_results")
+    results_folder = Path("/scratch/SCRATCH_SAS/roman/DataSAIL/fs_results")
 
-    tmp = Path("tmp")
+    tmp = Path("/scratch/SCRATCH_SAS/roman/DataSAIL/fs_tmp")
     tmp.mkdir(parents=True, exist_ok=True)
-    for name, filepath in dataset.data.items():
-        shutil.copy(filepath, tmp)
+    ##for name in dataset.names:
+    ##    shutil.copy(dataset.data[name], tmp)
 
     cmd = f"mkdir {results_folder} && " \
           f"cd {results_folder} && " \
           f"foldseek " \
           f"easy-search " \
-          f"../tmp " \
-          f"../tmp " \
+          f"{str(tmp.resolve())} " \
+          f"{str(tmp.resolve())} " \
          f"aln.m8 " \
          f"tmp " \
-          f"--format-output 'query,target,fident' " \
+          f"--format-output 'query,target,fident,qlen,lddt' " \
          f"-e inf " \
          f"--threads {threads} " \
          f"{user_args}"  # && " \
+          # f"--exhaustive-search 1 " \
          # f"rm -rf ../tmp"
 
     if log_dir is None:
@@ -54,27 +59,91 @@ def run_foldseek(dataset: DataSet, threads: int = 1, log_dir: Optional[Path] = N
 
     LOGGER.info("Start FoldSeek clustering")
     LOGGER.info(cmd)
-    os.system(cmd)
-
-    if not (results_folder / "aln.m8").exists():
-        raise ValueError("Something went wrong with foldseek. The output file does not exist.")
-
-    namap = dict((n, i) for i, n in enumerate(dataset.names))
-    cluster_sim = np.zeros((len(dataset.names), len(dataset.names)))
-    with open(f"{results_folder}/aln.m8", "r") as data:
-        for line in data.readlines():
-            q1, q2, sim = line.strip().split("\t")[:3]
-            if "_" in q1 and "." in q1 and q1.rindex("_") > q1.index("."):
-                q1 = "_".join(q1.split("_")[:-1])
-            if "_" in q2 and "." in q2 and q2.rindex("_") > q2.index("."):
-                q2 = "_".join(q2.split("_")[:-1])
-            q1 = q1.replace(".pdb", "")
-            q2 = q2.replace(".pdb", "")
-            cluster_sim[namap[q1], namap[q2]] = sim
-            cluster_sim[namap[q2], namap[q1]] = sim
+    ##os.system(cmd)
+
+    ##if not (results_folder / "aln.m8").exists():
+    ##    raise ValueError("Something went wrong with foldseek. The output file does not exist.")
+
+    ##ds = read_with_pyarrow(f"{results_folder}/aln.m8")
+    #with open("/scratch/SCRATCH_SAS/roman/DataSAIL/pyarrow.pkl", "rb") as data:
+    #    ds = pickle.load(data)
+
+    #try:
+    #except Exception as e:
+    #    print("pickling failed due to:", e)
+    #    namap = dict((n, i) for i, n in enumerate(dataset.names))
+    ##cluster_sim = np.zeros((len(dataset.names), len(dataset.names)))
+    #with open(f"{results_folder}/aln.m8", "r") as data:
+    #    for line in data.readlines():
+    #        q1, q2, sim = line.strip().split("\t")[:3]
+    #        if "_" in q1 and "." in q1 and q1.rindex("_") > q1.index("."):
+    #            q1 = "_".join(q1.split("_")[:-1])
+    #        if "_" in q2 and "." in q2 and q2.rindex("_") > q2.index("."):
+    #            q2 = "_".join(q2.split("_")[:-1])
+    #        q1 = q1.replace(".pdb", "")
+    #        q2 = q2.replace(".pdb", "")
+    #        cluster_sim[namap[q1], namap[q2]] = sim
+    #        cluster_sim[namap[q2], namap[q1]] = sim
+    #    print("Additional names:", set(dataset.names).difference(set(ds.keys())))
+    #    print("Additional hits:", set(ds.keys()).difference(set(dataset.names)))
+    #    exit(0)
+    ##for i, name1 in enumerate(dataset.names):
+    ##    cluster_sim[i, i] = 1
+    ##    for j, name2 in enumerate(dataset.names[i + 1:]):
+    ##        if name2 in ds[name1]:
+    ##            cluster_sim[i, j] = ds[name1][name2][2] / ds[name1][name2][3]
+    ##        if name1 in ds[name2]:
+    ##            cluster_sim[j, i] = ds[name2][name1][2] / ds[name2][name1][3]
+    ##cluster_sim = (cluster_sim + cluster_sim.T) / 2
+
+    # with open("/scratch/SCRATCH_SAS/roman/DataSAIL/PLINDER/prot_sim_full_v12.pkl", "wb") as out:
+    #     pickle.dump(ds, out)
+    with open("/scratch/SCRATCH_SAS/roman/DataSAIL/PLINDER/eval/full_v0/prots.pkl", "rb") as f:
+        ds = pickle.load(f)
 
     shutil.rmtree(results_folder, ignore_errors=True)
+    shutil.rmtree(tmp, ignore_errors=True)
 
     dataset.cluster_names = dataset.names
     dataset.cluster_map = dict((n, n) for n in dataset.names)
-    dataset.cluster_similarity = cluster_sim
+    dataset.cluster_similarity = ds.cluster_similarity  ## cluster_sim
+
+
+def extract(tmp):
+    if len(tmp) == 1:
+        return tmp[0], "?"
+    else:
+        return "_".join(tmp[:-1]), tmp[-1]
+
+
+def inner_list():
+    return ["", "", 0, 0]
+
+
+def outer_dict():
+    return defaultdict(inner_list)
+
+
+def read_with_pyarrow(file_path):
+    table = csv.read_csv(
+        file_path,
+        read_options=csv.ReadOptions(use_threads=True, column_names=["qid_chainid", "tid_chainid", "fident", "qlen", "lddt"]),
+        parse_options=csv.ParseOptions(delimiter="\t"),
+    )
+
+    indices = compute.sort_indices(table, [("lddt", "descending"), ("fident", "descending")])
+    ds = defaultdict(outer_dict)
+    for idx in tqdm(indices):
+        q_id, q_chain = extract(table["qid_chainid"][idx.as_py()].as_py().split("_"))
+        t_id, t_chain = extract(table["tid_chainid"][idx.as_py()].as_py().split("_"))
+        record = ds[q_id][t_id]
+        if q_chain in record[0] or t_chain in record[1]:
+            continue
+        fident = table["fident"][idx.as_py()].as_py()
+        q_len = table["qlen"][idx.as_py()].as_py()
+        record[0] += q_chain
+        record[1] += t_chain
+        record[2] += fident * q_len
+        record[3] += q_len
+    return ds
+
diff --git a/datasail/cluster/mash.py b/datasail/cluster/mash.py
index e11f05c..a7dc19e 100644
--- a/datasail/cluster/mash.py
+++ b/datasail/cluster/mash.py
@@ -58,6 +58,7 @@ def run_mash(dataset: DataSet, threads: int = 1, log_dir: Optional[Path] = None)
     dataset.cluster_names = dataset.names
 
     shutil.rmtree(results_folder, ignore_errors=True)
+    shutil.rmtree(tmp, ignore_errors=True)
 
 
 def read_mash_tsv(filename: Path, num_entities: int) -> np.ndarray:
diff --git a/datasail/reader/read_molecules.py b/datasail/reader/read_molecules.py
index 25711a4..80e4a9d 100644
--- a/datasail/reader/read_molecules.py
+++ b/datasail/reader/read_molecules.py
@@ -95,7 +95,7 @@ def remove_molecule_duplicates(dataset: DataSet) -> DataSet:
         dataset: The dataset to remove duplicates from
 
     Returns:
-        Update arguments as teh location of the data might change and an ID-Map file might be added.
+        Update arguments as the location of the data might change and an ID-Map file might be added.
     """
     if isinstance(dataset.data[dataset.names[0]], (list, tuple, np.ndarray)):
         # TODO: proper check for duplicate embeddings
diff --git a/datasail/reader/utils.py b/datasail/reader/utils.py
index af3aa8e..edaebe5 100644
--- a/datasail/reader/utils.py
+++ b/datasail/reader/utils.py
@@ -3,6 +3,7 @@
 from dataclasses import dataclass, fields
 from pathlib import Path
 from typing import Generator, Tuple, List, Optional, Dict, Union, Any, Callable, Iterable, Set
+from collections.abc import Iterable
 
 import h5py
 import numpy as np
@@ -124,6 +125,9 @@ def strat2oh(self, name: Optional[str] = None, classes: Optional[Union[str, Set[
             classes = [classes]
         if self.classes is not None:
             # print(name, self.class_oh[[self.classes[class_] for class_ in classes]].sum(axis=0))
+            # print("classes", classes)
+            # print("self.cl", self.classes)
+            # print("self.oh", self.class_oh)
             return self.class_oh[[self.classes[class_] for class_ in classes]].sum(axis=0)
         return None
 
@@ -287,7 +291,8 @@ def read_data(
     elif isinstance(weights, Generator):
         dataset.weights = dict(weights)
     elif inter is not None:
-        dataset.weights = dict(count_inter(inter, index))
+        dataset.weights = {k: 0 for k in dataset.data.keys()}
+        dataset.weights.update(dict(count_inter(inter, index)))
     else:
         dataset.weights = {k: 1 for k in dataset.data.keys()}
 
diff --git a/datasail/routine.py b/datasail/routine.py
index f826308..a96236d 100644
--- a/datasail/routine.py
+++ b/datasail/routine.py
@@ -1,4 +1,5 @@
 import time
+import pickle
 from typing import Dict, Tuple, Optional
 
 from datasail.argparse_patch import remove_patch
@@ -6,7 +7,7 @@
 from datasail.reader.read import read_data
 from datasail.reader.utils import DataSet
 from datasail.report import report
-from datasail.settings import LOGGER, KW_TECHNIQUES, KW_EPSILON, KW_RUNS, KW_SPLITS, KW_NAMES, \
+from datasail.settings import LOGGER, KW_INTER, KW_TECHNIQUES, KW_EPSILON, KW_RUNS, KW_SPLITS, KW_NAMES, \
     KW_MAX_SEC, KW_MAX_SOL, KW_SOLVER, KW_LOGDIR, NOT_ASSIGNED, KW_OUTDIR, MODE_E, MODE_F, DIM_2, SRC_CL, KW_DELTA, \
     KW_E_CLUSTERS, KW_F_CLUSTERS, KW_CC, CDHIT, INSTALLED, FOLDSEEK, TMALIGN, CDHIT_EST, DIAMOND, MMSEQS, MASH
 from datasail.solver.solve import run_solver
@@ -40,19 +41,30 @@ def datasail_main(**kwargs) -> Optional[Tuple[Dict, Dict, Dict]]:
     LOGGER.info("Read data")
 
     # read e-entities and f-entities
-    e_dataset, f_dataset, inter = read_data(**kwargs)
+    e_dataset, f_dataset_tmp, inter = read_data(**kwargs)
 
     # if required, cluster the input otherwise define the cluster-maps to be None
     clusters = list(filter(lambda x: x[0].startswith(SRC_CL), kwargs[KW_TECHNIQUES]))
     cluster_e = len(clusters) != 0 and any(c[-1] in {DIM_2, MODE_E} for c in clusters)
     cluster_f = len(clusters) != 0 and any(c[-1] in {DIM_2, MODE_F} for c in clusters)
 
-    if cluster_e:
-        LOGGER.info("Cluster first set of entities.")
-        e_dataset = cluster(e_dataset, **kwargs)
-    if cluster_f:
-        LOGGER.info("Cluster second set of entities.")
-        f_dataset = cluster(f_dataset, **kwargs)
+    #if cluster_e:
+    #    LOGGER.info("Cluster first set of entities.")
+    #    e_dataset = cluster(e_dataset, **kwargs)
+    #if cluster_f:
+    #    LOGGER.info("Cluster second set of entities.")
+    #    f_dataset = cluster(f_dataset, **kwargs)
+
+    split = str(kwargs[KW_INTER]).split("/")[-2]
+    #with open(f"/scratch/SCRATCH_SAS/roman/DataSAIL/PLINDER/{split}.pkl", "wb") as f:
+    #    pickle.dump((e_dataset, f_dataset), f)
+    with open(f"/scratch/SCRATCH_SAS/roman/DataSAIL/PLINDER/{split}.pkl", "rb") as f:
+        e_dataset, f_dataset = pickle.load(f)
+    f_dataset.id_map = f_dataset_tmp.id_map
+
+    #print("E_ID_Map is None:", e_dataset.id_map is None)
+    #print("F_ID_Map is None:", f_dataset.id_map is None)
+    #print("Nones in inter :", sum([x is None for x in inter]))
 
     if inter is not None:
         if e_dataset.type is not None and f_dataset.type is not None:
@@ -88,6 +100,12 @@ def datasail_main(**kwargs) -> Optional[Tuple[Dict, Dict, Dict]]:
 
     LOGGER.info("Store results")
 
+    #print("E name:", e_name_split_map.keys())
+    #print("F name:", f_name_split_map.keys())
+    #print("E cluster:", e_cluster_split_map.keys())
+    #print("F cluster:", f_cluster_split_map.keys())
+    #print("Inter:", inter_split_map.keys())
+
     # infer interaction assignment from entity assignment if necessary and possible
     output_inter_split_map = dict()
     if new_inter is not None:
@@ -95,6 +113,9 @@ def datasail_main(**kwargs) -> Optional[Tuple[Dict, Dict, Dict]]:
             output_inter_split_map[technique] = []
             for run in range(kwargs[KW_RUNS]):
                 output_inter_split_map[technique].append(dict())
+                #print(e_name_split_map.keys())
+                #print(f_name_split_map.keys())
+                #print(technique)
                 for e, f in inter:
                     if technique.endswith(DIM_2) or technique == "R":
                         output_inter_split_map[technique][-1][(e, f)] = inter_split_map[technique][run].get(
diff --git a/datasail/solver/cluster_1d.py b/datasail/solver/cluster_1d.py
index bc79b96..941158d 100644
--- a/datasail/solver/cluster_1d.py
+++ b/datasail/solver/cluster_1d.py
@@ -64,6 +64,7 @@ def solve_c1(
     loss = cvxpy.sum([t for tmp_list in tmp for t in tmp_list])
 
     problem = solve(loss, constraints, max_sec, solver, log_file)
+    print(problem)
 
     return None if problem is None else {
         e: names[s] for s in range(len(splits)) for i, e in enumerate(clusters) if x[s, i].value > 0.1
diff --git a/datasail/solver/cluster_2d.py b/datasail/solver/cluster_2d.py
index cd71fd7..8ad4e65 100644
--- a/datasail/solver/cluster_2d.py
+++ b/datasail/solver/cluster_2d.py
@@ -2,9 +2,10 @@
 from pathlib import Path
 
 import cvxpy
 import numpy as np
+from scipy.optimize import fsolve
 
 from datasail.solver.utils import solve, interaction_contraints, collect_results_2d, leakage_loss, compute_limits, \
-    stratification_constraints
+    stratification_constraints, collect_results_2d2
 
 
 def solve_c2(
@@ -57,39 +58,87 @@ def solve_c2(
         A list of interactions and their assignment to a split and two mappings from entities to splits, one for each
         dataset
     """
-    min_lim = compute_limits(epsilon, int(np.sum(inter)), [s / 2 for s in splits])
+    if splits[0] == 0.771:  # time_v0
+        splits = [0.5650, 0.2175, 0.2175]
+    elif splits[0] == 0.889:  # pl50_v0
+        splits = [2/3, 1/6, 1/6]
+    elif splits[0] == 0.776:  # ecod_v0
+        splits = [0.5685, 0.2254, 0.2061]
+    elif splits[0] == 0.898:  # pl50_v1
+        splits = [0.6773, 0.1645, 0.1582]
+    elif splits[0] == 0.8:  # full splits
+        splits = [0.5858, 0.2071, 0.2071]
+    else:
+        splits = convert(splits)
+    min_lim = compute_limits(epsilon, int(np.sum(inter)), splits)  # [s / 2 for s in splits])
+    #np.set_printoptions(threshold=np.inf)
+    #print(np.sum(inter))
+    #print(e_weights)
+    #print(np.sum(e_weights))
+    #print(f_weights)
+    #print(np.sum(f_weights))
+    #print(min_lim)
+    #np.set_printoptions(threshold=np.inf)
+    #print(inter)
 
     x_e = cvxpy.Variable((len(splits), len(e_clusters)), boolean=True)
     x_f = cvxpy.Variable((len(splits), len(f_clusters)), boolean=True)
-    x_i = {(e, f): cvxpy.Variable(len(splits), boolean=True) for e in range(len(e_clusters)) for f in
-           range(len(f_clusters)) if inter[e, f] != 0}
+    #x_i = {(e, f): cvxpy.Variable(len(splits), boolean=True) for e in range(len(e_clusters)) for f in
+    #      range(len(f_clusters)) if inter[e, f] != 0}
 
     # check if the cluster relations are uniform
     e_intra_weights = e_similarities if e_similarities is not None else 1 - e_distances
     f_intra_weights = f_similarities if f_similarities is not None else 1 - f_distances
-    e_uniform = e_intra_weights is None or np.allclose(e_intra_weights, np.ones_like(e_intra_weights)) or \
-        np.allclose(e_intra_weights, np.zeros_like(e_intra_weights))
-    f_uniform = f_intra_weights is None or np.allclose(f_intra_weights, np.ones_like(f_intra_weights)) or \
-        np.allclose(f_intra_weights, np.zeros_like(f_intra_weights))
+    #e_uniform = e_intra_weights is None or np.allclose(e_intra_weights, np.ones_like(e_intra_weights)) or \
+    #    np.allclose(e_intra_weights, np.zeros_like(e_intra_weights))
+    #f_uniform = f_intra_weights is None or np.allclose(f_intra_weights, np.ones_like(f_intra_weights)) or \
+    #    np.allclose(f_intra_weights, np.zeros_like(f_intra_weights))
 
-    def index(x, y):
-        return (x, y) if (x, y) in x_i else None
+    #def index(x, y):
+    #    return (x, y) if (x, y) in x_i else None
 
     constraints = [
-        cvxpy.sum(x_e, axis=0) == np.ones((len(e_clusters))),
-        cvxpy.sum(x_f, axis=0) == np.ones((len(f_clusters))),
+        cvxpy.sum(x_e, axis=0) == np.ones((len(e_clusters)), dtype=int),
+        cvxpy.sum(x_f, axis=0) == np.ones((len(f_clusters)), dtype=int),
     ]
+    for s, split in enumerate(splits):
+        constraints.append(min_lim[s] <= cvxpy.sum(cvxpy.multiply(x_e[s], e_weights)))
+        constraints.append(min_lim[s] <= cvxpy.sum(cvxpy.multiply(x_f[s], f_weights)))
 
-    if e_s_matrix is not None:
-        constraints.append(stratification_constraints(e_s_matrix, [s / 2 for s in splits], delta / 2, x_e))
-    if f_s_matrix is not None:
-        constraints.append(stratification_constraints(f_s_matrix, [s / 2 for s in splits], delta / 2, x_f))
+    #print(len(constraints))
+    #if e_s_matrix is not None:
+    #    constraints.append(stratification_constraints(e_s_matrix, [s / 2 for s in splits], delta / 2, x_e))
+    #if f_s_matrix is not None:
+    #    constraints.append(stratification_constraints(f_s_matrix, [s / 2 for s in splits], delta / 2, x_f))
+    #print(len(constraints))
+    #exit(0)
 
-    interaction_contraints(e_clusters, f_clusters, x_i, constraints, splits, x_e, x_f, min_lim, lambda key: inter[key],
-                           index)
-
-    e_loss = leakage_loss(e_uniform, e_intra_weights, x_e, e_clusters, e_weights, len(splits))
-    f_loss = leakage_loss(f_uniform, f_intra_weights, x_f, f_clusters, f_weights, len(splits))
+    #interaction_contraints(e_clusters, f_clusters, x_i, constraints, splits, x_e, x_f, min_lim, lambda key: inter[key],
+    #                       index)
+    e_tmp = [[e_weights[e1] * e_weights[e2] * e_intra_weights[e1, e2] * cvxpy.max(
+        cvxpy.vstack([x_e[s, e1] - x_e[s, e2] for s in range(len(splits))])
+    ) for e2 in range(e1 + 1, len(e_clusters))] for e1 in range(len(e_clusters))]
+    f_tmp = [[f_weights[f1] * f_weights[f2] * f_intra_weights[f1, f2] * cvxpy.max(
+        cvxpy.vstack([x_f[s, f1] - x_f[s, f2] for s in range(len(splits))])
+    ) for f2 in range(f1 + 1, len(f_clusters))] for f1 in range(len(f_clusters))]
+    e_loss = cvxpy.sum([e for e_tmp_list in e_tmp for e in e_tmp_list])  # leakage_loss(e_uniform, e_intra_weights, x_e, e_clusters, e_weights, len(splits))
+    f_loss = cvxpy.sum([f for f_tmp_list in f_tmp for f in f_tmp_list])  # leakage_loss(f_uniform, f_intra_weights, x_f, f_clusters, f_weights, len(splits))
 
     problem = solve(e_loss + f_loss, constraints, max_sec, solver, log_file)
 
-    return collect_results_2d(problem, names, splits, e_clusters, f_clusters, x_e, x_f, x_i, index)
+    #return collect_results_2d(problem, names, splits, e_clusters, f_clusters, x_e, x_f, x_i, index)
+    return collect_results_2d2(problem, names, splits, e_clusters, f_clusters, x_e, x_f, inter)
+
+
+def func(x, targets):
+    denom = sum([a ** 2 for a in x])
+    return [(x[i] ** 2 / denom - targets[i]) for i in range(len(x))]
+
+
+def convert(targets):
+    targets = [t / sum(targets) for t in targets]
+    sol = fsolve(
+        lambda x: func(x, targets),
+        [1 / len(targets) for _ in targets]
+    )
+    return [s / sum(sol) for s in sol]
+
diff --git a/datasail/solver/utils.py b/datasail/solver/utils.py
index 1204f31..a49a2f6 100644
--- a/datasail/solver/utils.py
+++ b/datasail/solver/utils.py
@@ -132,6 +132,8 @@ def solve(
         f"The problem has {sum([functools.reduce(operator.mul, v.shape, 1) for v in problem.variables()])} variables "
         f"and {sum([functools.reduce(operator.mul, c.shape, 1) for c in problem.constraints])} constraints.")
 
+    print(max_sec)
+
     if solver == SOLVER_CBC:
         kwargs = {
             "maximumSeconds": max_sec,
@@ -191,16 +193,17 @@ def solve(
         LOGGER.info(f"{solver} status: {problem.status}")
         LOGGER.info(f"Solution's score: {problem.value}")
 
-        if "optimal" not in problem.status:
+        if problem.status in {"infeasible", "unbounded"}:
             LOGGER.warning(
                 f'{solver} cannot solve the problem. Please consider relaxing split restrictions, '
                 'e.g., less splits, or a higher tolerance level for exceeding cluster limits.'
             )
             return None
         return problem
-    except KeyError:
-        LOGGER.warning(f"Solving failed for {''}. Please use try another solver or update your python version.")
-        return None
+    except Exception as e:
+        raise e
+        # LOGGER.warning(f"Solving failed for {''}. Please try another solver or update your python version.")
+        # return None
 
 
 def sample_categorical(
@@ -372,6 +375,52 @@ def collect_results_2d(
     return output
 
 
+def collect_results_2d2(
+        problem: cvxpy.Problem,
+        names: List[str],
+        splits: List[float],
+        e_entities: List[str],
+        f_entities: List[str],
+        x_e: Variable,
+        x_f: Variable,
+        inter: np.ndarray,
+) -> Optional[Tuple[Dict[Tuple[str, str], str], Dict[object, str], Dict[object, str]]]:
+    """
+    Report the found solution for two-dimensional splits.
+
+    Args:
+        problem: Problem object after solving.
+        names: List of names of the splits.
+        splits: List of the relative sizes of the splits.
+        e_entities: List of names of entities in the e-dataset.
+        f_entities: List of names of entities in the f-dataset.
+        x_e: Optimization variables for the e-dataset.
+        x_f: Optimization variables for the f-dataset.
+        inter: Matrix of interactions between the e-clusters and f-clusters.
+
+    Returns:
+        A list of interactions and their assignment to a split and two mappings from entities to splits, one for each dataset.
+    """
+    if problem is None:
+        return None
+
+    # report the found solution
+    output = (
+        {},
+        {e: names[s] for s in range(len(splits)) for i, e in enumerate(e_entities) if x_e[s, i].value > 0.1},
+        {f: names[s] for s in range(len(splits)) for j, f in enumerate(f_entities) if x_f[s, j].value > 0.1},
+    )
+    for i, e in enumerate(e_entities):
+        for j, f in enumerate(f_entities):
+            if inter[i, j] == 0:
+                continue
+            if output[1][e] == output[2][f]:
+                output[0][e, f] = output[1][e]
+            else:
+                output[0][e, f] = NOT_ASSIGNED
+    return output
+
+
 def leakage_loss(
         uniform: bool,
         intra_weights,
@@ -404,3 +453,37 @@ def leakage_loss(
         ] for e1 in range(len(clusters))]
         loss = cvxpy.sum([t for tmp_list in tmp for t in tmp_list])
     return loss
+
+
+#def leakage_loss(
+#        uniform: bool,
+#        intra_weights,
+#        x,
+#        clusters,
+#        weights,
+#        num_splits: int,
+#) -> Union[int, cvxpy.Expression]:
+#    """
+#    Compute the leakage loss for the cluster-based double-cold splitting.
+#
+#    Args:
+#        uniform: Boolean flag if the cluster metric is uniform
+#        intra_weights: Weights of the intra-cluster edges
+#        x: Variables of the optimization problem
+#        clusters: List of cluster names
+#        weights: Weights of the clusters
+#        num_splits: Number of splits
+#
+#    Returns:
+#        Loss describing the leakage between clusters
+#    """
+#    if uniform:
+#        return 0
+#    else:
+#        tmp = [[
+#            weights[e1] * weights[e2] * intra_weights[e1, e2] * cvxpy.max(
+#                cvxpy.vstack([x[s, e1] - x[s, e2] for s in range(num_splits)])
+#            ) for e2 in range(e1 + 1, len(clusters))
+#        ] for e1 in range(len(clusters))]
+#        loss = cvxpy.sum([t for tmp_list in tmp for t in tmp_list])
+#        return loss
diff --git a/experiments/DTI/visualize.py b/experiments/DTI/visualize.py
index 4d19486..58417b7 100644
--- a/experiments/DTI/visualize.py
+++ b/experiments/DTI/visualize.py
@@ -7,6 +7,7 @@
 import pandas as pd
 import umap
 import matplotlib
+import cairosvg
 from matplotlib.colors import LinearSegmentedColormap
 from matplotlib import pyplot as plt, gridspec, cm, colors as mpl_colors
 from matplotlib.lines import Line2D
@@ -81,36 +82,29 @@ def comp_il(base_path: Path):
         else:
             root = base / "deepchem"
 
-        if tech in ["R", "I2", "C2"]:
-            dss = [(lig_dataset, "_lig"), (tar_dataset, "_tar")]
-        elif tech in ["I1e", "C1e", "lohi", "Butina", "Fingerprint", "MaxMin", "Scaffold", "Weight"]:
-            dss = [(lig_dataset, "_lig")]
-        elif tech in ["I1f", "C1f", "graphpart"]:
-            dss = [(tar_dataset, "_tar")]
-        else:
-            print(f"Unknown technique: {tech}")
-            continue
-
-        for ds, n in dss:
+        for ds, n in [(lig_dataset, "_lig"), (tar_dataset, "_tar")]:
+            print(tech, n)
             name = tech + n
-            if name not in output:
-                output[name] = []
-            elif tech not in TECHNIQUES["datasail"]:
+            if name in output:
                 continue
+            output[name] = []
             for run in range(5):
                 print(name, run, end="\t")
                 base = root / tech / f"split_{run}"
                 train_ids = pd.read_csv(base / "train.csv")["ids"]
                 test_ids = pd.read_csv(base / "test.csv")["ids"]
-                assi = np.array(
-                    [1 if x in train_ids.values else -1 if x in test_ids.values else 0 for x in ds.cluster_names])
-                il, total = david.eval(
-                    assi.reshape(-1, 1),
-                    ds.cluster_similarity,
-                    [ds.cluster_weights[c] for c in ds.cluster_names],
-                )
-                print(il)
-                output[name].append((il, total))
+                assi = np.array([1 if x in train_ids.values else -1 if x in test_ids.values else 0 for x in ds.cluster_names]).reshape(-1, 1)
+                mask = assi @ assi.T
+                mask = -mask
+                mask[mask == -1] = 0
+                output[name].append(np.sum(mask * ds.cluster_similarity))
+                #il, total = david.eval(
+                #    assi.reshape(-1, 1),
+                #    ds.cluster_similarity,
+                #    [ds.cluster_weights[c] for c in ds.cluster_names],
+                #)
+                print(output[name][-1])
+                # output[name].append(il)
 
     with open(leak_path, "wb") as f:
         pickle.dump(output, f)
@@ -358,12 +352,8 @@ def viz_sl_models(
             l = pickle.load(f)
 
     def leakage(tech):
-        if tech in ["R", "I2", "C2"]:
-            return [[(l[f"{tech}_lig"][i][j] + l[f"{tech}_tar"][i][j] / 2) for j in range(len(l[f"{tech}_lig"][i]))]
-                    for i in range(5)]
-        if tech in ["I1f", "C1f", "graphpart"]:
-            return l[f"{tech}_tar"]
-        return l[f"{tech}_lig"]
+        return [(l[f"{tech}_lig"][i] + l[f"{tech}_tar"][i]) / (l["lig_sim"] + l["tar_sim"]) for i in range(5)]
+        # return [[(l[f"{tech}_lig"][i][j] + l[f"{tech}_tar"][i][j]) for j in range(len(l[f"{tech}_lig"][i]))] for i in range(5)]
 
     for s, (tool, _, t) in enumerate(techniques):
         for model in models:
@@ -378,7 +368,7 @@ def viz_sl_models(
             ax = fig.add_subplot(gs)
             df = pd.DataFrame(np.array(values).T, columns=[x[1] for x in techniques], index=models)
             c_map = {"I1f": "I1e", "C1f": "C1e", "R": "0d"}
-            df.loc["Splits"] = [np.average([x for x, _ in leakage(tech)]) for _, _, tech in techniques]
+            df.loc["Splits"] = [np.average(leakage(tech)) for _, _, tech in techniques]
             il = plot_bars_2y(df.T, ax, color=[COLORS[c_map.get(x[2], x[2]).lower()] for x in techniques])
             ax.set_ylabel("RMSE (↓)")
             ax.set_xlabel("ML Models")
@@ -386,18 +376,20 @@ def viz_sl_models(
             il.legend(loc=legend, ncol=ncol, framealpha=1)
             ax.set_title("Performance comparison")
             ax.set_xlabel("ML Models")
+            set_subplot_label(ax, fig, label)
         elif ptype == "htm":
+            il = None
             gs_main = gs.subgridspec(1, 2, width_ratios=[15, 1], wspace=0.4)
             ax = fig.add_subplot(gs_main[1])
             values = np.array(values, dtype=float)
-            leak = np.array([np.average([x for x, _ in leakage(tech)]) for _, _, tech in techniques]).reshape(-1, 1)
+            leak = np.array([np.average(leakage(tech)) for _, _, tech in techniques]).reshape(-1, 1)
             cmap = LinearSegmentedColormap.from_list("Custom", [COLORS["train"], COLORS["test"]], N=256)
             cmap.set_bad(color="white")
             create_heatmap(values, leak, cmap, leak_cmap, fig, gs_main[0], "Performance", "RMSE (↓)", y_labels=True,
                            mode="MMB", max_val=max(leak), label=label, yticklabels=[t[1] for t in techniques])
-            plt.colorbar(cm.ScalarMappable(mpl_colors.Normalize(0, max(leak)), leak_cmap), cax=ax, label="$L(\pi)$ ↓")
+            plt.colorbar(cm.ScalarMappable(mpl_colors.Normalize(0, max(leak)), leak_cmap), cax=ax, label="scaled $L(\pi)$ (↓)")
         else:
             raise ValueError(f"Unknown plottype {ptype}")
-    return ax
+    return ax, il
 
 
 def plot_3x3(full_path: Path, data: Dict) -> None:
@@ -420,35 +412,37 @@ def plot_3x3(full_path: Path, data: Dict) -> None:
     ax_c2d = fig.add_subplot(gs[0, 2])
     ax_c2p = fig.add_subplot(gs[1, 2])
 
-    plot_embeds(ax_rd, fig, data["I1e"], "drug", "Random drug baseline (I1)", drop=False, label="A")
-    plot_embeds(ax_sd, fig, data["C1e"], "drug", "DataSAIL drug-based (S1)", drop=False, label="B")
-    plot_embeds(ax_rp, fig, data["I1f"], "prot", "Random protein baseline (I1)", drop=False, label="D")
-    plot_embeds(ax_sp, fig, data["C1f"], "prot", "DataSAIL protein-based (S1)", drop=False, label="E")
-    plot_embeds(ax_c2d, fig, data["C2"], "drug", "DataSAIL 2D split (S2) - drugs", legend="lower right", label="G")
-    plot_embeds(ax_c2p, fig, data["C2"], "prot", "DataSAIL 2D split (S2) - proteins", label="H")
+    plot_embeds(ax_rd, fig, data["I1e"], "drug", "Random drug baseline (I1)", drop=False, label="a")
+    plot_embeds(ax_sd, fig, data["C1e"], "drug", "DataSAIL drug-based (S1)", drop=False, label="b")
+    plot_embeds(ax_rp, fig, data["I1f"], "prot", "Random protein baseline (I1)", drop=False, label="d")
+    plot_embeds(ax_sp, fig, data["C1f"], "prot", "DataSAIL protein-based (S1)", drop=False, label="e")
+    plot_embeds(ax_c2d, fig, data["C2"], "drug", "DataSAIL 2D split (S2) - drugs", legend="lower right", label="g")
+    plot_embeds(ax_c2p, fig, data["C2"], "prot", "DataSAIL 2D split (S2) - proteins", label="h")
 
-    ax_cd = viz_sl_models(full_path, gs[2, 0], fig, [
+    ax_cd, il_cd = viz_sl_models(full_path, gs[2, 0], fig, [
         ("datasail", "DataSAIL drug-based (S1)", "C1e"),
         ("lohi", "LoHi", "lohi"),
-        ("deepchem", "Fingerprint", "Fingerprint"),
+        ("deepchem", "DC - Fingerprint", "Fingerprint"),
         ("datasail", "Random drug baseline (I1)", "I1e"),
-    ], legend="lower left", ptype="bar", label="C")
-    ax_cp = viz_sl_models(full_path, gs[2, 1], fig, [
+    ], legend="lower left", ptype="bar", label="c")
+    ax_cp, il_cp = viz_sl_models(full_path, gs[2, 1], fig, [
         ("datasail", "DataSAIL protein-based (S1)", "C1f"),
         ("graphpart", "GraphPart", "graphpart"),
         ("datasail", "Random protein baseline (I1)", "I1f")
-    ], legend="lower left", ptype="bar", label="F")
-    ax_c2 = viz_sl_models(full_path, gs[2, 2], fig, [
+    ], legend="lower left", ptype="bar", label="f")
+    ax_c2, il_c2 = viz_sl_models(full_path, gs[2, 2], fig, [
        ("datasail", "DataSAIL 2D split (S2)", "C2"),
        ("datasail", "ID-based baseline (I2)", "I2"),
        ("datasail", "Random baseline", "R")
-    ], legend="lower left", ptype="bar", label="I")
+    ], legend="lower left", ptype="bar", label="i")
 
     ax_cd.sharey(ax_c2)
     ax_cp.sharey(ax_c2)
+    il_c2.sharey(il_cd)
+    il_cp.sharey(il_cd)
 
     fig.tight_layout()
-    plt.savefig(full_path / "plots" / f"PDBBind_{'umap' if USE_UMAP else 'tsne'}_3x3.png", transparent=True)
+    plt.savefig(full_path / "plots" / f"PDBBind_{'umap' if USE_UMAP else 'tsne'}_3x3.pdf", transparent=True)
     plt.show()
 
@@ -482,14 +476,14 @@ def plot_cold_drug(full_path: Path, data: Dict) -> None:
     ax_we = fig.add_subplot(gs_comp[2, 1])
     # ax_full = fig.add_subplot(gs_lower[1])
 
-    plot_embeds(ax_i1, fig, i1e, "drug", "Random drug baseline (I1)", legend=4, drop=False, label="A")
-    plot_embeds(ax_c1, fig, c1e, "drug", "DataSAIL drug-based (S1)", drop=False, label="B")
-    plot_embeds(ax_lh, fig, lohi, "drug", "LoHi", drop=False, label="C")
-    plot_embeds(ax_bu, fig, butina, "drug", "DC - Butina Splits", drop=False, label="D")
-    plot_embeds(ax_fi, fig, fingerprint, "drug", "DC - Fingerprint Splits", drop=False, label="E")
-    plot_embeds(ax_mm, fig, minmax, "drug", "DC - MaxMin Splits", drop=False, label="F")
-    plot_embeds(ax_sc, fig, scaffold, "drug", "DC - Scaffold Splits", drop=False, label="G")
-    plot_embeds(ax_we, fig, weight, "drug", "DC - Weight Splits", drop=False, label="H")
+    plot_embeds(ax_i1, fig, i1e, "drug", "Random drug baseline (I1)", legend=4, drop=False, label="a")
+    plot_embeds(ax_c1, fig, c1e, "drug", "DataSAIL drug-based (S1)", drop=False, label="b")
+    plot_embeds(ax_lh, fig, lohi, "drug", "LoHi", drop=False, label="c")
+    plot_embeds(ax_bu, fig, butina, "drug", "DC - Butina Splits", drop=False, label="d")
+    plot_embeds(ax_fi, fig, fingerprint, "drug", "DC - Fingerprint Splits", drop=False, label="e")
+    plot_embeds(ax_mm, fig, minmax, "drug", "DC - MaxMin Splits", drop=False, label="f")
+    plot_embeds(ax_sc, fig, scaffold, "drug", "DC - Scaffold Splits", drop=False, label="g")
+    plot_embeds(ax_we, fig, weight, "drug", "DC - Weight Splits", drop=False, label="h")
 
     viz_sl_models(full_path, gs_lower[1], fig, [
         ("datasail", "DataSAIL (S2)", "C2"),
@@ -501,10 +495,11 @@ def plot_cold_drug(full_path: Path, data: Dict) -> None:
         ("deepchem", "DC - MaxMin", "MaxMin"),
         ("deepchem", "DC - Scaffold", "Scaffold"),
         ("deepchem", "DC - Weight", "Weight")
-    ], ptype="htm", label="I")
+    ], ptype="htm", label="i")
 
     fig.tight_layout()
-    plt.savefig(full_path / "plots" / f"PDBBind_CD_{'umap' if USE_UMAP else 'tsne'}.png")
+    plt.savefig(full_path / "plots" / f"PDBBind_CD_{'umap' if USE_UMAP else 'tsne'}.svg")
+    cairosvg.svg2pdf(url=str(full_path / "plots" / f"PDBBind_CD_{'umap' if USE_UMAP else 'tsne'}.svg"), write_to=str(full_path / "plots" / f"PDBBind_CD_{'umap' if USE_UMAP else 'tsne'}.pdf"))
     plt.show()
 
@@ -525,19 +520,19 @@ def plot_cold_prot(full_path: Path, data: Dict) -> None:
     ax_c1 = fig.add_subplot(gs[0, 1])
     ax_gp = fig.add_subplot(gs[1, 0])
 
-    plot_embeds(ax_i1, fig, i1f, "prot", "Random protein baseline (I1)", legend=4, drop=False, label="A")
-    plot_embeds(ax_c1, fig, c1f, "prot", "DataSAIL protein-based (S1)", drop=False, label="B")
-    plot_embeds(ax_gp, fig, graphpart, "prot", "GraphPart", drop=False, label="C")
+    plot_embeds(ax_i1, fig, i1f, "prot", "Random protein baseline (I1)", legend=4, drop=False, label="a")
+    plot_embeds(ax_c1, fig, c1f, "prot", "DataSAIL protein-based (S1)", drop=False, label="b")
+    plot_embeds(ax_gp, fig, graphpart, "prot", "GraphPart", drop=False, label="c")
 
     viz_sl_models(full_path, gs[1, 1], fig, [
         ("datasail", "DataSAIL (S2)", "C2"),
         ("datasail", "DataSAIL (S1)", "C1f"),
         ("datasail", "Baseline (I1)", "I1f"),
         ("graphpart", "GraphPart", "graphpart"),
-    ], legend="lower left", ptype="bar", ncol=2, label="D")
+    ], legend="lower left", ptype="bar", ncol=2, label="d")
 
     fig.tight_layout()
-    plt.savefig(full_path / "plots" / f"PDBBind_CT_{'umap' if USE_UMAP else 'tsne'}.png")
+    plt.savefig(full_path / "plots" / f"PDBBind_CT_{'umap' if USE_UMAP else 'tsne'}.pdf")
     plt.show()
 
@@ -561,13 +556,13 @@ def plot(full_path: Path):
         data = pickle.load(pickled_data)
 
     print("Plot 3x3")
-    plot_3x3(full_path, data)
-    #print("Plot cold drug")
-    #plot_cold_drug(full_path, data)
-    #print("Plot cold prot")
+    #plot_3x3(full_path, data)
+    print("Plot cold drug")
+    plot_cold_drug(full_path, data)
+    print("Plot cold prot")
     #plot_cold_prot(full_path, data)
 
 
 if __name__ == '__main__':
-    plot(Path(sys.argv[1]))
     # comp_il(Path(sys.argv[1]))
+    plot(Path(sys.argv[1]))
diff --git a/experiments/MPP/split.py b/experiments/MPP/split.py
index d005a1c..f97784c 100644
--- a/experiments/MPP/split.py
+++ b/experiments/MPP/split.py
@@ -1,3 +1,11 @@
+import os
+
+#num_threads = "128"
+#os.environ["OPENBLAS_NUM_THREADS"] = num_threads
+#os.environ["GOTO_NUM_THREADS"] = num_threads
+#os.environ["OMP_NUM_THREADS"] = num_threads
+
+
 import sys
 from pathlib import Path
 import time as T
@@ -49,15 +57,15 @@ def split_w_datasail(base_path: Path, name: str, techniques: List[str], solver:
         techniques=techniques,
         splits=[8, 2],
         names=["train", "test"],
-        runs=5,
+        runs=1,  # 5,
         solver=solver,
         e_type="M",
         e_data=dict(df[["ID", "SMILES"]].values.tolist()),
         max_sec=1000,
         epsilon=0.1,
     )
-    # with open(base_path / "time.txt", "a") as time:
-    #     print("I1+C1", T.time() - start, file=time)
+    with open(base_path / "time2.txt", "a") as time:
+        print(techniques[0], T.time() - start, file=time)
 
     save_datasail_splits(base_path, df, "ID", [(t, t) for t in techniques], e_splits=e_splits)
 
@@ -171,14 +179,17 @@ def split(full_path, name, solver="GUROBI"):
 def specific():
     for run in range(RUNS):
         for name in DATASETS.keys():
-            if name.lower() == "pcba":
+            if name.lower() in {"pcba"}:
                 continue
-            split_w_datasail(Path("/") / "scratch" / "SCRATCH_SAS" / "roman" / "DataSAIL" / "v10" / "datasail" / "MPP" / name, name, ["C1e"])
+            split_w_datasail(Path("/") / "scratch" / "SCRATCH_SAS" / "roman" / "DataSAIL" / "v10" / "MPP" / "datasail_new" / name, name, ["I1e"])
+            split_w_datasail(Path("/") / "scratch" / "SCRATCH_SAS" / "roman" / "DataSAIL" / "v10" / "MPP" / "datasail_new" / name, name, ["C1e"])
 
 
 if __name__ == '__main__':
-    split_w_datasail(Path("/") / "scratch" / "SCRATCH_SAS" / "roman" / "DataSAIL" / "v10" / "MPP" / "datasail" / "hiv", "hiv", ["I1e"])
+    specific()
     exit(0)
+    # split_w_datasail(Path("/") / "scratch" / "SCRATCH_SAS" / "roman" / "DataSAIL" / "v10" / "MPP" / "datasail_test" / "qm8", "qm8", ["C1e"])
+    # exit(0)
     if len(sys.argv) == 1:
         specific()
     elif len(sys.argv) == 2:
diff --git a/experiments/MPP/visualize.py b/experiments/MPP/visualize.py
index 527646f..11f3aab 100644
--- a/experiments/MPP/visualize.py
+++ b/experiments/MPP/visualize.py
@@ -11,6 +11,7 @@ from matplotlib.colors import LinearSegmentedColormap
 from mpl_toolkits.axes_grid1 import make_axes_locatable
 
 import deepchem as dc
+import cairosvg
 
 from datasail.reader.utils import DataSet
 from experiments.ablation import david
@@ -134,15 +135,15 @@ def plot_double(full_path: Path, names: List[str]) -> None:
         ax[i][2].legend(loc="lower left", framealpha=1)
         ax[i][2].set_title("Performance comparison")
         ax[i][2].set_xlabel("ML Models")
-        set_subplot_label(ax[i][2], fig, ["C", "F"][i])
+        set_subplot_label(ax[i][2], fig, ["c", "f"][i])
 
         i_tr, i_te, c_tr, c_te = embed(full_path, name.lower())
         plot_embeds(ax[i][0], i_tr, i_te, "Random baseline (I1)", legend=True)
-        set_subplot_label(ax[i][0], fig, chr(ord("A") + 3 * i))
+        set_subplot_label(ax[i][0], fig, chr(ord("a") + 3 * i))
         plot_embeds(ax[i][1], c_tr, c_te, "DataSAIL split (S1)")
-        set_subplot_label(ax[i][1], fig, chr(ord("B") + 3 * i))
+        set_subplot_label(ax[i][1], fig, chr(ord("b") + 3 * i))
 
     plt.tight_layout()
-    plt.savefig(full_path / "plots" / f"{names[0]}_{names[1]}.png")
+    plt.savefig(full_path / "plots" / f"{names[0]}_{names[1]}.pdf")
     plt.show()
 
@@ -164,7 +165,7 @@ def heatmap_plot(full_path: Path):
     cols, rows = 4, 4
     gs_main = gridspec.GridSpec(1, 2, figure=fig, width_ratios=[85, 3], wspace=0.1)
     gs = gs_main[0].subgridspec(rows, cols, wspace=0.3, hspace=0.25)
-    ax = fig.add_subplot(gs_main[1])
+    ax_main = fig.add_subplot(gs_main[1])
 
     if (leak_path := full_path / "data" / "leakage.pkl").exists():
         with open(leak_path, "rb") as f:
@@ -200,12 +201,14 @@ def get_il(name, tech):
         df["Split"] = [get_il(name, tech) for tech in df.index]
         values = np.array(df.loc[
             ["C1e", "I1e", "lohi", "Butina", "Fingerprint", "MaxMin", "Scaffold", "Weight"],
-            ["RF", "SVM", "XGB", "MLP", "D-MPNN", "Split"],
+            ["RF", "SVM", "XGB", "MLP", "D-MPNN", "Split"],
        ], dtype=float)
-        create_heatmap(values[:, :-1], values[:, -1:], cmap, leak_cmap, fig, gs[i // 4, i % 4], name, METRICS[DATASETS[name.lower()][2]], y_labels=(i % cols == 0), mode="MMB", max_val=max_leak, yticklabels=["DataSAIL (S1)", "Rd. basel. (I1)", "LoHi", "DC - Butina", "DC_Fingerp.", "DC - MinMax", "DC - Scaffold", "DC - Weight"])
+        ax = create_heatmap(values[:, :-1], values[:, -1:], cmap, leak_cmap, fig, gs[i // 4, i % 4], name, METRICS[DATASETS[name.lower()][2]], y_labels=(i % cols == 0), mode="MMB", max_val=max_leak, yticklabels=["DataSAIL (S1)", "Rd. basel. (I1)", "LoHi", "DC - Butina", "DC - Fingerp.", "DC - MinMax", "DC - Scaffold", "DC - Weight"])
+        set_subplot_label(ax[0], fig, chr(ord("a") + i))
 
-    plt.colorbar(cm.ScalarMappable(mpl_colors.Normalize(0, max_leak), leak_cmap), cax=ax, label="$L(\pi)$ ↓")
-    plt.savefig(full_path / "plots" / f"MoleculeNet_comp.png", transparent=True)
+    plt.colorbar(cm.ScalarMappable(mpl_colors.Normalize(0, max_leak), leak_cmap), cax=ax_main, label="scaled $L(\pi)$ (↓)")
+    plt.savefig(full_path / "plots" / f"MoleculeNet_comp.svg", dpi=200)  # transparent=True)
+    cairosvg.svg2pdf(url=str(full_path / "plots" / f"MoleculeNet_comp.svg"), write_to=str(full_path / "plots" / f"MoleculeNet_comp.pdf"))
     plt.show()
 
@@ -257,7 +260,7 @@ def comp_all_il(base: Path):
 
 if __name__ == '__main__':
     # comp_all_il(Path(sys.argv[1]))
-    plot_double(Path(sys.argv[1]), ["QM8", "Tox21"])
+    #plot_double(Path(sys.argv[1]), ["QM8", "Tox21"])
     # plot_double(Path(sys.argv[1]), ["FreeSolv", "ESOL"])
-    # heatmap_plot(Path(sys.argv[1]))
+    heatmap_plot(Path(sys.argv[1]))
diff --git a/experiments/Strat/visualize.py b/experiments/Strat/visualize.py
index d64c92c..ab36ebd 100644
--- a/experiments/Strat/visualize.py
+++ b/experiments/Strat/visualize.py
@@ -21,7 +21,7 @@ def plot_perf(base_path, ax):
     df = pd.DataFrame(values, columns=["DataSAIL split (S1 w/ classes)", "Stratified baseline"], index=models)
     df.loc["IL"] = [np.average([x for x, _ in leakage[k]]) for k in ["datasail", "deepchem"]]
     il = plot_bars_2y(df.T, ax, color=[COLORS["s1d"], COLORS["r1d"]])
-    ax.set_ylabel("AUROC (↑)")
+    ax.set_ylabel("ROC-AUC (↑)")
     ax.set_xlabel("ML Models")
     ax.legend(loc="lower left")
     ax.set_title(f"Performance comparison")
@@ -36,14 +36,14 @@ def main(full_path):
 
     dc_tr, dc_te, ds_tr, ds_te = embed(full_path)
     plot_embeds(ax[0], dc_tr, dc_te, "Stratified baseline", legend=True)
-    set_subplot_label(ax[0], fig, "A")
+    set_subplot_label(ax[0], fig, "a")
     plot_embeds(ax[1], ds_tr, ds_te, "DataSAIL split (S1 w/ classes)")
-    set_subplot_label(ax[1], fig, "B")
+    set_subplot_label(ax[1], fig, "b")
     plot_perf(full_path, ax[2])
-    set_subplot_label(ax[2], fig, "C")
+    set_subplot_label(ax[2], fig, "c")
 
     fig.tight_layout()
-    plt.savefig(plot_dir / "Strat.png")
+    plt.savefig(plot_dir / "Strat.pdf")
     plt.show()
 
diff --git a/experiments/ablation/ablation_plot.py b/experiments/ablation/ablation_plot.py
index cd0d462..b28cbac 100644
--- a/experiments/ablation/ablation_plot.py
+++ b/experiments/ablation/ablation_plot.py
@@ -1,16 +1,17 @@
 import sys
 from pathlib import Path
 
+import cairosvg
 import matplotlib
 from matplotlib import pyplot as plt, gridspec
 
 from experiments.ablation.visualize_de import plot_de_ablation
-from experiments.ablation.time import get_tool_times
+from experiments.ablation.time import plot_times
 from experiments.utils import set_subplot_label
 from experiments.ablation.david import visualize
 
 
-def plot_ablations(full_path):
+def plot_ablations(base_path):
     matplotlib.rc('font', **{'size': 18})
     fig = plt.figure(figsize=(20, 12))
     gs = gridspec.GridSpec(2, 1, figure=fig)
@@ -18,20 +19,21 @@ def plot_ablations(base_path):
     gs_lower = gs[1].subgridspec(1, 3, width_ratios=[1, 1.5, 0.25], wspace=0.4)
     ax = [fig.add_subplot(gs_upper[0]), fig.add_subplot(gs_upper[1]), fig.add_subplot(gs_lower[0]),
           fig.add_subplot(gs_lower[1])]
-    visualize(full_path, list(range(10, 50, 5)) + list(range(50, 150, 10)) + list(range(150, 401, 50)), ["GUROBI", "MOSEK", "SCIP"], ax=(ax[0], ax[1]), fig=fig)
-    set_subplot_label(ax[0], fig, "A")
-    set_subplot_label(ax[1], fig, "B")
+    visualize(base_path / "Clusters", list(range(10, 50, 5)) + list(range(50, 150, 10)) + list(range(150, 401, 50)), ["GUROBI", "MOSEK", "SCIP"], ax=(ax[0], ax[1]), fig=fig)
+    set_subplot_label(ax[0], fig, "a")
+    set_subplot_label(ax[1], fig, "b")
 
-    plot_de_ablation(full_path, ax=ax[2], fig=fig)
-    set_subplot_label(ax[2], fig, "C")
+    plot_de_ablation(base_path / "Strat", ax=ax[2], fig=fig)
+    set_subplot_label(ax[2], fig, "c")
 
-    get_tool_times(Path("experiments") / "MPP", ax=ax[3])
-    set_subplot_label(ax[3], fig, "D")
+    plot_times(base_path / "MPP", ax=ax[3])
+    set_subplot_label(ax[3], fig, "d")
 
     plt.tight_layout()
-    plt.savefig("ablation.png")
+    plt.savefig(base_path / "ablation.svg")
+    cairosvg.svg2pdf(url=str(base_path / "ablation.svg"), write_to=str(base_path / "ablation.pdf"))
     plt.show()
 
 
 if __name__ == '__main__':
-    plot_ablations(Path(sys.argv[1]))
+    plot_ablations(Path(sys.argv[1]))  # .../v10
diff --git a/experiments/ablation/david.py b/experiments/ablation/david.py
index 23cb891..819b2b1 100644
--- a/experiments/ablation/david.py
+++ b/experiments/ablation/david.py
@@ -19,7 +19,6 @@ from datasail.cluster.clustering import additional_clustering
 from datasail.reader.read_molecules import read_molecule_data
 from datasail.solver.utils import solve, compute_limits
 
-from experiments.ablation.time import MARKERS
 from experiments.utils import dc2pd, DATASETS, COLORS
 
 
@@ -217,6 +216,7 @@ def eval(assignments, similarity, weights=None):
     # print(np.min(mask), np.max(mask))
 
     leak = (np.sum(similarity * weights * mask) / np.sum(similarity * weights * alt)) / 2
+    # print("\t", leak, sum(mask), sum(alt))
     return leak, np.sum(similarity * weights * alt) / 2
 
 
@@ -259,21 +259,15 @@ def visualize(full_path: Path, clusters: List[int], solvers, ax: Optional[Tuple]
     ax_p.set_xlabel("Number of clusters")
     ax_t.set_xlabel("Number of clusters")
     ax_t.set_ylabel("Time for solving [s] (↓)")
-    ax_p.set_ylabel("$L(\pi)$ (↓)")
-
-    ax_t.plot(times["GUROBI"][:, 0], times["GUROBI"][:, 1], label="GUROBI", color=COLORS["train"],
-              marker=MARKERS["gurobi"], markersize=9)
-    ax_t.plot(times["MOSEK"][:, 0], times["MOSEK"][:, 1], label="MOSEK", color=COLORS["test"],
-              marker=MARKERS["mosek"], markersize=9)
-    ax_t.plot(times["SCIP"][:, 0], times["SCIP"][:, 1], label="SCIP", color=COLORS["r1d"],
-              marker=MARKERS["scip"], markersize=9)
-
-    ax_p.plot(times["GUROBI"][:, 0], performances["GUROBI"], label="GUROBI", color=COLORS["train"],
-              marker=MARKERS["gurobi"], markersize=9)
-    ax_p.plot(times["MOSEK"][:, 0], performances["MOSEK"], label="MOSEK", color=COLORS["test"],
-              marker=MARKERS["mosek"], markersize=9)
-    ax_p.plot(times["SCIP"][:, 0], performances["SCIP"], label="SCIP", color=COLORS["r1d"],
-              marker=MARKERS["scip"], markersize=9)
+    ax_p.set_ylabel("scaled $L(\pi)$ (↓)")
+
+    ax_t.plot(times["GUROBI"][:, 0], times["GUROBI"][:, 1], label="GUROBI", color=COLORS["train"], marker="o")
+    ax_t.plot(times["MOSEK"][:, 0], times["MOSEK"][:, 1], label="MOSEK", color=COLORS["test"], marker="x")
+    ax_t.plot(times["SCIP"][:, 0], times["SCIP"][:, 1], label="SCIP", color=COLORS["r1d"], marker="D")
+
+    ax_p.plot(times["GUROBI"][:, 0], [2 * x[0] for x in performances["GUROBI"]], label="GUROBI", color=COLORS["train"], marker="o")
+    ax_p.plot(times["MOSEK"][:, 0], [2 * x[0] for x in performances["MOSEK"]], label="MOSEK", color=COLORS["test"], marker="x")
+    ax_p.plot(times["SCIP"][:, 0], [2 * x[0] for x in performances["SCIP"]], label="SCIP", color=COLORS["r1d"], marker="D")
 
     ax_p.legend()
     ax_p.set_title("Leaked Information on Tox21")
diff --git a/experiments/ablation/time.py b/experiments/ablation/time.py
index e10f061..94711aa 100644
--- a/experiments/ablation/time.py
+++ b/experiments/ablation/time.py
@@ -4,12 +4,10 @@
 from typing import List
 
 import matplotlib
-import numpy as np
-from matplotlib import pyplot as plt, gridspec
+import pandas as pd
+from matplotlib import pyplot as plt, gridspec
 
-from experiments.utils import DATASETS, RUNS, COLORS
-
-files = []
+from experiments.utils import DATASETS, COLORS
 
 MARKERS = {
     "i1e": "o",
@@ -25,127 +23,103 @@
     "weight": "D",
 }
 
-
-def get_single_time(path: Path) -> float:
-    """
-    Get the time it took to split the dataset for a single run.
-
-    Args:
-        path: Path to the splitting directory.
-
-    Returns:
-        The time it took to split the dataset.
-    """
-    if not os.path.exists(path / "train.csv"):
-        return 0
-    return os.path.getctime(path / "train.csv") - os.path.getctime(path / "start.txt")
-
-
-def get_run_times(path: Path) -> List[float]:
-    """
-    Get the time it took to split the dataset for all runs.
-
-    Args:
-        path: Path to the technique directory.
-
-    Returns:
-        The time it took to split the dataset for all runs.
-    """
-    return [get_single_time(path / f"split_{run}") for run in range(RUNS)]
-
-
-def get_tech_times(path: Path) -> List[List[float]]:
-    """
-    Get the time it took to split the dataset for all techniques.
-
-    Args:
-        path: Path to the dataset directory.
-
-    Returns:
-        The time it took to split the dataset for all techniques.
-    """
-    if "deepchem" in str(path):
-        techniques = ["Scaffold", "Weight", "MinMax", "Butina", "Fingerprint"]
-    elif "datasail" in str(path):
-        techniques = ["I1e", "C1e"]
-    elif "lohi" in str(path):
-        techniques = ["lohi"]
-    elif "graphpart" in str(path):
-        techniques = ["graphpart"]
-    else:
-        raise ValueError(f"No known technique in path {str(path)}.")
-    return [get_run_times(path / tech) for tech in techniques]
-
-
-def get_dataset_times(path: Path) -> List[List[List[float]]]:
-    """
-    Get the time it took to split the dataset for all datasets.
-
-    Args:
-        path: Path to the dataset directory.
-
-    Returns:
-        The time it took to split the dataset for all datasets.
-    """
-    return [get_tech_times(path / ds_name) for ds_name in sorted(os.listdir(path), key=lambda x: DATASETS[x][3])]
-
-
-def get_tool_times(path, ax=None) -> None:
-    """
-    Plot the time it took to split the dataset for all datasets and techniques.
-
-    Args:
-        path: Path to the dataset directory.
-        ax: Axis to plot on.
-    """
+ax = None
+base_path = Path("/") / "scratch" / "SCRATCH_SAS" / "roman" / "DataSAIL" / "v10" / "MPP"
+
+def read_times(base_path: Path, name, tool):
+    data = {"tech": [], "run": [], "time": []}
+    counter = {"I1e": 0, "C1e": 0, "lohi": 0, "Scaffold": 0, "Weight": 0, "MaxMin": 0, "Butina": 0, "Fingerprint": 0}
+    if (file_path := base_path / tool / name / "time2.txt").exists():
+        with open(file_path, "r") as f:
+            for i, line in enumerate(f.readlines()):
+                if i == 0 and tool != "datasail_new":
+                    continue
+                tech, time = line.split()
+                data["tech"].append(tech)
+                data["time"].append(float(time))
+                data["run"].append(counter[tech])
+                counter[tech] += 1
+        df = pd.DataFrame(data)
+        df["name"] = name
+        df["tool"] = tool.split("_")[0]
+        return df
+    return None
+
+
+def map_tech_name(name: str) -> str:
+    if name.lower() == "c1e":
+        return "DataSAIL (S1)"
+    elif name.lower() == "i1e":
+        return "Rd. Baseline (I1)"
+    elif name.lower() == "lohi":
+        return "LoHi"
+    elif name.lower() == "maxmin":
+        return "DC - MaxMin"
+    return "DC - " + name[0].upper() + name[1:].lower()
+
+
+def map_tech_color(name: str) -> str:
+    if name.lower().endswith("scaffold"):
+        return "train"
+    elif name.lower().endswith("weight"):
+        return "test"
+    return name
+
+
+def plot_times(base_path: Path, ax=None):
     if show := ax is None:
         matplotlib.rc('font', **{'size': 16})
         fig = plt.figure(figsize=(20, 10.67))
         gs = gridspec.GridSpec(1, 1, figure=fig)
         ax = fig.add_subplot(gs[0])
-    pkl_path = Path("../..") / "DataSAIL" / "experiments" / "MPP" / "timing.pkl"
-    if not os.path.exists(pkl_path):
-        times = np.array(get_dataset_times(path / "datasail")), \
-            np.array(get_dataset_times(path / "lohi")), \
-            np.array(get_dataset_times(path / "deepchem"))
-        pickle.dump(times, open(pkl_path, "wb"))
-    else:
-        times = pickle.load(open(pkl_path, "rb"))
-    times = list(times)
-    times[1] = np.concatenate([times[1], np.array([[[0, 0, 0, 0, 0]]])])
-    timings = np.concatenate(times, axis=1)
-    timings = timings[:, [1, 0, 2, 3, 4, 5, 6, 7]]
-    # labels = ["I1e", "C1e", "LoHi", "Scaffold", "Weight", "MaxMin", "Butina", "Fingerprint"]
-    labels = ["C1e", "I1e", "LoHi", "Butina", "Fingerprint", "MaxMin", "Scaffold", "Weight"]
-    x = np.array(list(sorted([6160, 21786, 133885, 1128, 642, 4200, 93087, 41127, 1513, 2039, 7831, 8575, 1427, 1478])))
-    for i, label in enumerate(labels):
-        tmp = timings[:, i].mean(axis=1)
-        tmp_x = x[tmp > 0]
-        tmp = tmp[tmp > 0]
-        ax.plot(tmp_x, tmp, label={"I1e": "Random (I1)", "C1e": "DataSAIL (S1)", "LoHi": "LoHi"}.get(label, "DC - " + label), color=COLORS[label.lower()], marker=MARKERS[label.lower()], markersize=9)
-        if i == 2:
-            ax.plot([tmp_x[-1], x[-1]], [tmp[-1], 22180], color=COLORS[label.lower()], marker=MARKERS[label.lower()], markersize=9, linestyle='dashed')
-    ax.hlines(1, x[0], x[-1], linestyles="dashed", colors="black")
-    ax.text(x[0], 1, "1 sec", verticalalignment="bottom", horizontalalignment="left")
-    ax.hlines(60, x[0], x[-1], linestyles="dashed", colors="black")
-    ax.text(x[0], 60, "1 min", verticalalignment="bottom", horizontalalignment="left")
-    ax.hlines(3600, x[0], x[-1], linestyles="dashed", colors="black")
-    ax.text(x[0], 3600, "1 h", verticalalignment="bottom", horizontalalignment="left")
-
+    times = []
+    for tool in ["deepchem", "lohi", "datasail_new"]:
+        for name in DATASETS.keys():
+            if name == "pcba":
+                continue
+            if (res := read_times(base_path, name, tool)) is not None:
+                times.append(res)
+    times = pd.concat(times)
+
+    names = list(DATASETS.keys())
+    names.remove("pcba")
+    names, sizes = zip(*list(sorted([(k, DATASETS[k][-1]) for k in names], key=lambda x: x[1])))
+
+    for tech in [
+        "C1e", "I1e",
+        "lohi",
+        "Butina", "Fingerprint", "MaxMin", "Scaffold", "Weight"
+    ]:
+        tmp = sorted(dict(times[times["tech"] == tech].groupby("name")["time"].mean()).items(), key=lambda x: names.index(x[0]))
+        if tech == "lohi":
+            ax.plot([sizes[names.index(x[0])] for x in tmp[:-1]], [x[1] for x in tmp[:-1]], label=map_tech_name(tech), color=COLORS[map_tech_color(tech.lower())], marker=MARKERS[tech.lower()])
+            ax.scatter(sizes[names.index(tmp[-1][0])], tmp[-1][1], linestyle="dashed", color=COLORS[map_tech_color(tech.lower())], marker=MARKERS[tech.lower()])
+        else:
+            ax.plot([sizes[names.index(x[0])] for x in tmp], [x[1] for x in tmp], label=map_tech_name(tech), color=COLORS[map_tech_color(tech.lower())], marker=MARKERS[tech.lower()])
+
+    ax.hlines(1, sizes[0], sizes[-1], linestyles="dashed", colors="black")
+    ax.text(sizes[0], 1, "1 sec", verticalalignment="bottom", horizontalalignment="left")
+    ax.hlines(60, sizes[0], sizes[-1], linestyles="dashed", colors="black")
+    ax.text(sizes[0], 60, "1 min", verticalalignment="bottom", horizontalalignment="left")
+    ax.hlines(3600, sizes[0], sizes[-1], linestyles="dashed", colors="black")
+    ax.text(sizes[0], 3600, "1 h", verticalalignment="bottom", horizontalalignment="left")
+
     box = ax.get_position()
     ax.set_position([box.x0, box.y0, box.width * 0.8, box.height])
     ax.legend(loc='center left', bbox_to_anchor=(1, 0.5))
-
+
     ax.set_xscale("log")
     ax.set_yscale("log")
     ax.set_xlabel("#Molecules in Dataset")
     ax.set_ylabel("Time for splitting [s] (↓)")
     ax.set_title("Runtime on MoleculeNet")
+
     if show:
         plt.tight_layout()
-        plt.savefig("timing.png")
+        (plot_path := base_path / "plots").mkdir(exist_ok=True, parents=True)
+        plt.savefig(plot_path / "timing.png")
         plt.show()
 
 
 if __name__ == '__main__':
-    get_tool_times(Path("experiments") / "MPP")
+    plot_times(Path("/scratch") / "SCRATCH_SAS" / "roman" / "DataSAIL" / "v10" / "MPP")
\ No newline at end of file
diff --git a/experiments/ablation/time2.py b/experiments/ablation/time2.py
index dbdbe01..0010618 100644
--- a/experiments/ablation/time2.py
+++ b/experiments/ablation/time2.py
@@ -10,13 +10,14 @@ def read_times(base_path: Path, name, tool):
     data = {"tech": [], "run": [], "time": []}
     counter = {"lohi": 0, "Scaffold": 0, "Weight": 0, "MaxMin": 0, "Butina": 0, "Fingerprint": 0}
-    with open(base_path / tool / name / "time2.txt") as f:
-        for line in f.readlines()[1:]:
-            tech, time = line.split()
-            data["tech"].append(tech)
-            data["time"].append(float(time))
-            data["run"].append(counter[tech])
-            counter[tech] += 1
+    if (file_path := base_path / tool / name / "time2.txt").exists():
+        with open(file_path, "r") as f:
+            for line in f.readlines()[1:]:
+                tech, time = line.split()
+                data["tech"].append(tech)
+                data["time"].append(float(time))
+                data["run"].append(counter[tech])
+                counter[tech] += 1
     df = pd.DataFrame(data)
     df["name"] = name
     df["tool"] = tool
diff --git a/experiments/ablation/time_old.py b/experiments/ablation/time_old.py
new file mode 100644
index 0000000..e10f061
--- /dev/null
+++ b/experiments/ablation/time_old.py
@@ -0,0 +1,151 @@
+import os
+import pickle
+from pathlib import Path
+from typing import List
+
+import matplotlib
+import numpy as np
+from matplotlib import pyplot as plt, gridspec
+
+from experiments.utils import DATASETS, RUNS, COLORS
+
+files = []
+
+MARKERS = {
+    "i1e": "o",
+    "c1e": "P",
+    "lohi": "X",
+    "gurobi": "o",
+    "mosek": "P",
+    "scip": "X",
+    "butina": "v",
+    "fingerprint": "^",
+    "maxmin": "<",
+    "scaffold": ">",
+    "weight": "D",
+}
+
+
+def get_single_time(path: Path) -> float:
+    """
+    Get the time it took to split the dataset for a single run.
+
+    Args:
+        path: Path to the splitting directory.
+
+    Returns:
+        The time it took to split the dataset.
+    """
+    if not os.path.exists(path / "train.csv"):
+        return 0
+    return os.path.getctime(path / "train.csv") - os.path.getctime(path / "start.txt")
+
+
+def get_run_times(path: Path) -> List[float]:
+    """
+    Get the time it took to split the dataset for all runs.
+
+    Args:
+        path: Path to the technique directory.
+
+    Returns:
+        The time it took to split the dataset for all runs.
+    """
+    return [get_single_time(path / f"split_{run}") for run in range(RUNS)]
+
+
+def get_tech_times(path: Path) -> List[List[float]]:
+    """
+    Get the time it took to split the dataset for all techniques.
+
+    Args:
+        path: Path to the dataset directory.
+
+    Returns:
+        The time it took to split the dataset for all techniques.
+    """
+    if "deepchem" in str(path):
+        techniques = ["Scaffold", "Weight", "MinMax", "Butina", "Fingerprint"]
+    elif "datasail" in str(path):
+        techniques = ["I1e", "C1e"]
+    elif "lohi" in str(path):
+        techniques = ["lohi"]
+    elif "graphpart" in str(path):
+        techniques = ["graphpart"]
+    else:
+        raise ValueError(f"No known technique in path {str(path)}.")
+    return [get_run_times(path / tech) for tech in techniques]
+
+
+def get_dataset_times(path: Path) -> List[List[List[float]]]:
+    """
+    Get the time it took to split the dataset for all datasets.
+
+    Args:
+        path: Path to the dataset directory.
+
+    Returns:
+        The time it took to split the dataset for all datasets.
+    """
+    return [get_tech_times(path / ds_name) for ds_name in sorted(os.listdir(path), key=lambda x: DATASETS[x][3])]
+
+
+def get_tool_times(path, ax=None) -> None:
+    """
+    Plot the time it took to split the dataset for all datasets and techniques.
+
+    Args:
+        path: Path to the dataset directory.
+        ax: Axis to plot on.
+    """
+    if show := ax is None:
+        matplotlib.rc('font', **{'size': 16})
+        fig = plt.figure(figsize=(20, 10.67))
+        gs = gridspec.GridSpec(1, 1, figure=fig)
+        ax = fig.add_subplot(gs[0])
+    pkl_path = Path("../..") / "DataSAIL" / "experiments" / "MPP" / "timing.pkl"
+    if not os.path.exists(pkl_path):
+        times = np.array(get_dataset_times(path / "datasail")), \
+            np.array(get_dataset_times(path / "lohi")), \
+            np.array(get_dataset_times(path / "deepchem"))
+        pickle.dump(times, open(pkl_path, "wb"))
+    else:
+        times = pickle.load(open(pkl_path, "rb"))
+    times = list(times)
+    times[1] = np.concatenate([times[1], np.array([[[0, 0, 0, 0, 0]]])])
+    timings = np.concatenate(times, axis=1)
+    timings = timings[:, [1, 0, 2, 3, 4, 5, 6, 7]]
+    # labels = ["I1e", "C1e", "LoHi", "Scaffold", "Weight", "MaxMin", "Butina", "Fingerprint"]
+    labels = ["C1e", "I1e", "LoHi", "Butina", "Fingerprint", "MaxMin", "Scaffold", "Weight"]
+    x = np.array(list(sorted([6160, 21786, 133885, 1128, 642, 4200, 93087, 41127, 1513, 2039, 7831, 8575, 1427, 1478])))
+    for i, label in enumerate(labels):
+        tmp = timings[:, i].mean(axis=1)
+        tmp_x = x[tmp > 0]
+        tmp = tmp[tmp > 0]
+        ax.plot(tmp_x, tmp, label={"I1e": "Random (I1)", "C1e": "DataSAIL (S1)", "LoHi": "LoHi"}.get(label, "DC - " + label), color=COLORS[label.lower()], marker=MARKERS[label.lower()], markersize=9)
+        if i == 2:
+            ax.plot([tmp_x[-1], x[-1]], [tmp[-1], 22180], color=COLORS[label.lower()], marker=MARKERS[label.lower()], markersize=9, linestyle='dashed')
+    ax.hlines(1, x[0], x[-1], linestyles="dashed", colors="black")
+    ax.text(x[0], 1, "1 sec", verticalalignment="bottom", horizontalalignment="left")
+    ax.hlines(60, x[0], x[-1], linestyles="dashed", colors="black")
+    ax.text(x[0], 60, "1 min", verticalalignment="bottom", horizontalalignment="left")
+    ax.hlines(3600, x[0], x[-1], linestyles="dashed", colors="black")
+    ax.text(x[0], 3600, "1 h", verticalalignment="bottom", horizontalalignment="left")
+
+    box = ax.get_position()
+    ax.set_position([box.x0, box.y0, box.width * 0.8, box.height])
+    ax.legend(loc='center left', bbox_to_anchor=(1, 0.5))
+
+    ax.set_xscale("log")
+    ax.set_yscale("log")
+    ax.set_xlabel("#Molecules in Dataset")
+    ax.set_ylabel("Time for splitting [s] (↓)")
+    ax.set_title("Runtime on MoleculeNet")
+    if show:
+        plt.tight_layout()
+        plt.savefig("timing.png")
+        plt.show()
+
+
+if __name__ == '__main__':
+    get_tool_times(Path("experiments") / "MPP")
diff --git a/experiments/ablation/visualize_de.py b/experiments/ablation/visualize_de.py
index dc00b79..c596f10 100644
--- a/experiments/ablation/visualize_de.py
+++ b/experiments/ablation/visualize_de.py
@@ -35,7 +35,7 @@ def score_split(full_path: Path, dataset: DataSet, delta: float, epsilon: float)
             print(f"\r{delta}, {epsilon}, {run}", end=" " * 10)
             train = set(pd.read_csv(base / f"split_{run}" / "train.csv")["ID"].tolist())
             tmp2 = np.array([1 if n in train else -1 for n in dataset.names]).reshape(-1, 1)
-            vals.append(eval(tmp2, dataset.cluster_similarity))
+            vals.append(eval(tmp2, dataset.cluster_similarity)[0])
     return np.mean(vals)
 
@@ -76,7 +76,7 @@ def plot_de_ablation(full_path: Path, ax=None, fig=None) -> None:
     pkl_path = full_path / "strat_data.pkl"
     if Path(pkl_path).exists():
         with open(pkl_path, "rb") as f:
-            qual = pickle.load(f)
+            _, qual = pickle.load(f)
     else:
         qual = read_quality(full_path)
         with open(full_path / "strat_data.pkl", "wb") as out:
@@ -91,14 +91,14 @@ def plot_de_ablation(full_path: Path, ax=None, fig=None) -> None:
         cax = divider.append_axes('right', size='5%', pad=0.05)
     cmap = LinearSegmentedColormap.from_list("Custom", [COLORS["r1d"], COLORS["s1d"]], N=256)
     cmap.set_bad(color="white")
-    q_values = np.array(qual.values, dtype=float)[::-1, :].T
+    q_values = np.array(qual.values, dtype=float)[::-1, :].T * 3.1
     tmp = ax.imshow(q_values, cmap=cmap, vmin=np.nanmin(q_values), vmax=np.nanmax(q_values))
     ax.set_xticks(list(reversed(range(1, 6, 2))), [0.3, 0.2, 0.1])
     ax.set_yticks(list(range(0, 6, 2)), [0.3, 0.2, 0.1])
     ax.set_xlabel("$\epsilon$")
     ax.set_ylabel("$\delta$")
     ax.set_title("Effect of $\delta$ and $\epsilon$")
-    fig.colorbar(tmp, cax=cax, orientation='vertical', label="$L(\pi)$ (↓)")
+    fig.colorbar(tmp, cax=cax, orientation='vertical', label="scaled $L(\pi)$ (↓)")
 
     if show:
         plt.tight_layout()
diff --git a/experiments/utils.py b/experiments/utils.py
index 8a717c1..5402c9f 100644
--- a/experiments/utils.py
+++ b/experiments/utils.py
@@ -51,7 +51,7 @@ DS2UPPER = {
     "qm7": "QM7", "qm8": "QM8", "qm9": "QM9", "esol": "ESOL", "freesolv": "FreeSolv",
     "lipophilicity": "Lipophilicity", "pcba": "PCBA", "muv": "MUV", "hiv": "HIV", "bace": "BACE", "bbbp": "BBBP",
     "tox21": "Tox21", "toxcast": "ToxCast", "sider": "SIDER", "clintox": "ClinTox",
 }
-METRICS = {"mae": "MAE ↓", "rmse": "RMSE ↓", "prc-auc": "PRC-AUC ↑", "auc": "ROC-AUC ↑"}
+METRICS = {"mae": "MAE (↓)", "rmse": "RMSE (↓)", "prc-auc": "PRC-AUC (↑)", "auc": "ROC-AUC (↑)"}
 
 models = {
     "rf-r": RandomForestRegressor(n_estimators=500, n_jobs=-1, random_state=42),
@@ -120,7 +120,7 @@ def save_datasail_splits(base: Path, df: pd.DataFrame, key: str, techniques: Lis
         inter_splits: Interactions splits
     """
     for name, tech in techniques:
-        for run in range(RUNS):
+        for run in range(1):  # RUNS):
             path = base / name / f"split_{run}"
             path.mkdir(parents=True, exist_ok=True)
 
@@ -398,7 +398,7 @@ def plot_bars_2y2(df: pd.DataFrame, ax: plt.Axes, color) -> plt.Axes:
         il.bar(len(x) + addendum, row[-1], width, label=index, color=color[i])
 
     # Adding labels and title
-    il.set_ylabel('$L(\pi)$ ↓')
+    il.set_ylabel('scaled $L(\pi)$ (↓)')
     plt.xticks(np.arange(len(df.columns)) + width / 2, df.columns)
 
     return il
@@ -417,7 +417,7 @@ def plot_bars_2y(df: pd.DataFrame, ax: plt.Axes, color) -> plt.Axes:
         il.bar(len(x) + addendum, row[-1], width, label=index, color=color[i])
 
     # Adding labels and title
-    il.set_ylabel('$L(\pi)$ ↓')
+    il.set_ylabel('scaled $L(\pi)$ (↓)')
     plt.xticks(np.arange(len(df.columns)) + width / 2, df.columns)
 
     return il
@@ -468,3 +468,4 @@ def create_heatmap(main_data, scnd_data, main_cmap, scnd_cmap, fig, main_gs, tit
             l = f"{scnd_data[a, 0]:.2f}"
         ax[2].text(0, a, l, ha='center', va='center')
     set_subplot_label(ax[0], fig, label)
+    return ax
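
Note on the hardcoded split lists in solve_c2 (datasail/solver/cluster_2d.py): convert turns requested interaction-level split fractions into entity-level fractions. Assuming interactions are roughly all-vs-all between clusters, a split holding fraction s_i of the entities holds about s_i^2 of all pairs, so among within-split pairs its share is s_i^2 / sum_j s_j^2; func hands exactly this fixed-point condition to fsolve. Solving it analytically gives s_i proportional to sqrt(t_i). A minimal standalone sketch of that closed form (convert_closed_form is a hypothetical helper, not part of DataSAIL; numbers are illustrative):

import numpy as np

def convert_closed_form(targets):
    # Solving t_i = s_i**2 / sum_j(s_j**2) gives s_i proportional to sqrt(t_i);
    # this is the same fixed point that convert() searches for with fsolve.
    s = np.sqrt(np.asarray(targets, dtype=float))
    return s / s.sum()

# Requesting 80/10/10 on interactions yields ~[0.5858, 0.2071, 0.2071] on the
# entity level, matching the hardcoded "full splits" case in the diff.
entity = convert_closed_form([0.8, 0.1, 0.1])
induced = entity ** 2 / np.sum(entity ** 2)
print(entity.round(4))   # [0.5858 0.2071 0.2071]
print(induced.round(4))  # [0.8 0.1 0.1]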
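Note on the leakage computation in comp_il (experiments/DTI/visualize.py): entries of assi are +1 (train), -1 (test), or 0 (unassigned), so the outer product assi @ assi.T is -1 exactly on train/test pairs. Negating it and zeroing the remaining -1 entries leaves an indicator of cross-split pairs, and multiplying with the cluster similarity sums the leaked similarity. A self-contained toy check (values illustrative only, not from the experiments):

import numpy as np

assi = np.array([1, 1, -1, 0]).reshape(-1, 1)  # train, train, test, unassigned
sim = np.ones((4, 4))                          # stand-in for ds.cluster_similarity
mask = -(assi @ assi.T)                        # +1 exactly on train/test pairs
mask[mask == -1] = 0                           # drop same-split pairs and the diagonal
print(np.sum(mask * sim))                      # 4.0: pairs (0,2), (2,0), (1,2), (2,1)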