Merge conflict resolved
Old-Shatterhand committed Jan 2, 2025
2 parents d576a7e + d3b9ce7 commit d3da45a
Showing 21 changed files with 674 additions and 305 deletions.
10 changes: 9 additions & 1 deletion datasail/__main__.py
@@ -1,5 +1,13 @@
from datasail.sail import sail
import os

#N_THREADS = "1"
#os.environ["OPENBLAS_NUM_THREADS"] = N_THREADS
#os.environ["OPENBLAS_MAX_THREADS"] = N_THREADS
#os.environ["GOTO_NUM_THREADS"] = N_THREADS
#os.environ["OMP_NUM_THREADS"] = N_THREADS
os.environ["GRB_LICENSE_FILE"] = "/home/rjo21/gurobi_mickey.lic"

from datasail.sail import sail

if __name__ == '__main__':
sail()
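
Side note on the commented-out thread caps above: OpenBLAS/OpenMP thread limits only take effect if the environment variables are set before NumPy (or any other BLAS-backed package) is imported, which is presumably why the whole block sits above the `from datasail.sail import sail` import. A minimal, self-contained sketch of that pattern (the cap value is arbitrary and not part of this commit):

```python
import os

# Must run before the first NumPy/SciPy import, otherwise the BLAS thread pool
# has already been spawned with its default size.
N_THREADS = "1"
os.environ["OPENBLAS_NUM_THREADS"] = N_THREADS
os.environ["OMP_NUM_THREADS"] = N_THREADS

import numpy as np  # noqa: E402 -- picks up the caps set above

print(np.ones((2, 2)) @ np.ones((2, 2)))  # runs with single-threaded BLAS
```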
3 changes: 2 additions & 1 deletion datasail/cluster/clustering.py
@@ -73,7 +73,7 @@ def cluster(dataset: DataSet, **kwargs) -> DataSet:
if len(dataset.cluster_names) > dataset.num_clusters:
dataset = force_clustering(dataset, kwargs[KW_LINKAGE])

store_to_cache(dataset, **kwargs)
# store_to_cache(dataset, **kwargs)

return dataset

@@ -212,6 +212,7 @@ def additional_clustering(
)
# cluster the clusters into new, fewer, and bigger clusters
labels = ca.fit_predict(cluster_matrix)
LOGGER.info("Clustering finished")
return labels2clusters(labels, dataset, cluster_matrix, linkage)


121 changes: 95 additions & 26 deletions datasail/cluster/foldseek.py
Expand Up @@ -2,8 +2,12 @@
import shutil
from pathlib import Path
from typing import Optional
import pickle

import numpy as np
from pyarrow import compute, csv
from collections import defaultdict
from tqdm import tqdm

from datasail.parsers import MultiYAMLParser
from datasail.reader.utils import DataSet
@@ -23,25 +27,26 @@ def run_foldseek(dataset: DataSet, threads: int = 1, log_dir: Optional[Path] = N
raise ValueError("Foldseek is not installed.")
user_args = MultiYAMLParser(FOLDSEEK).get_user_arguments(dataset.args, [])

results_folder = Path("fs_results")
results_folder = Path("/scratch/SCRATCH_SAS/roman/DataSAIL/fs_results")

tmp = Path("tmp")
tmp = Path("/scratch/SCRATCH_SAS/roman/DataSAIL/fs_tmp")
tmp.mkdir(parents=True, exist_ok=True)
for name, filepath in dataset.data.items():
shutil.copy(filepath, tmp)
##for name in dataset.names:
## shutil.copy(dataset.data[name], tmp)

cmd = f"mkdir {results_folder} && " \
f"cd {results_folder} && " \
f"foldseek " \
f"easy-search " \
f"../tmp " \
f"../tmp " \
f"{str(tmp.resolve())} " \
f"{str(tmp.resolve())} " \
f"aln.m8 " \
f"tmp " \
f"--format-output 'query,target,fident' " \
f"--format-output 'query,target,fident,qlen,lddt' " \
f"-e inf " \
f"--threads {threads} " \
f"{user_args}" # && " \
# f"--exhaustive-search 1 " \
# f"rm -rf ../tmp"

if log_dir is None:
Expand All @@ -54,27 +59,91 @@ def run_foldseek(dataset: DataSet, threads: int = 1, log_dir: Optional[Path] = N

LOGGER.info("Start FoldSeek clustering")
LOGGER.info(cmd)
os.system(cmd)

if not (results_folder / "aln.m8").exists():
raise ValueError("Something went wrong with foldseek. The output file does not exist.")

namap = dict((n, i) for i, n in enumerate(dataset.names))
cluster_sim = np.zeros((len(dataset.names), len(dataset.names)))
with open(f"{results_folder}/aln.m8", "r") as data:
for line in data.readlines():
q1, q2, sim = line.strip().split("\t")[:3]
if "_" in q1 and "." in q1 and q1.rindex("_") > q1.index("."):
q1 = "_".join(q1.split("_")[:-1])
if "_" in q2 and "." in q2 and q2.rindex("_") > q2.index("."):
q2 = "_".join(q2.split("_")[:-1])
q1 = q1.replace(".pdb", "")
q2 = q2.replace(".pdb", "")
cluster_sim[namap[q1], namap[q2]] = sim
cluster_sim[namap[q2], namap[q1]] = sim
##os.system(cmd)

##if not (results_folder / "aln.m8").exists():
## raise ValueError("Something went wrong with foldseek. The output file does not exist.")

##ds = read_with_pyarrow(f"{results_folder}/aln.m8")
#with open("/scratch/SCRATCH_SAS/roman/DataSAIL/pyarrow.pkl", "rb") as data:
# ds = pickle.load(data)

#try:
#except Exception as e:
# print("pickling failed due to:", e)
# namap = dict((n, i) for i, n in enumerate(dataset.names))
##cluster_sim = np.zeros((len(dataset.names), len(dataset.names)))
#with open(f"{results_folder}/aln.m8", "r") as data:
# for line in data.readlines():
# q1, q2, sim = line.strip().split("\t")[:3]
# if "_" in q1 and "." in q1 and q1.rindex("_") > q1.index("."):
# q1 = "_".join(q1.split("_")[:-1])
# if "_" in q2 and "." in q2 and q2.rindex("_") > q2.index("."):
# q2 = "_".join(q2.split("_")[:-1])
# q1 = q1.replace(".pdb", "")
# q2 = q2.replace(".pdb", "")
# cluster_sim[namap[q1], namap[q2]] = sim
# cluster_sim[namap[q2], namap[q1]] = sim
# print("Additional names:", set(dataset.names).difference(set(ds.keys())))
# print("Additional hits:", set(ds.keys()).difference(set(dataset.names)))
# exit(0)
##for i, name1 in enumerate(dataset.names):
## cluster_sim[i, i] = 1
## for j, name2 in enumerate(dataset.names[i + 1:]):
## if name2 in ds[name1]:
## cluster_sim[i, j] = ds[name1][name2][2] / ds[name1][name2][3]
## if name1 in ds[name2]:
## cluster_sim[j, i] = ds[name2][name1][2] / ds[name2][name1][3]
##cluster_sim = (cluster_sim + cluster_sim.T) / 2

# with open("/scratch/SCRATCH_SAS/roman/DataSAIL/PLINDER/prot_sim_full_v12.pkl", "wb") as out:
# pickle.dump(ds, out)
with open("/scratch/SCRATCH_SAS/roman/DataSAIL/PLINDER/eval/full_v0/prots.pkl", "rb") as f:
ds = pickle.load(f)

shutil.rmtree(results_folder, ignore_errors=True)
shutil.rmtree(tmp, ignore_errors=True)

dataset.cluster_names = dataset.names
dataset.cluster_map = dict((n, n) for n in dataset.names)
dataset.cluster_similarity = cluster_sim
dataset.cluster_similarity = ds.cluster_similarity ## cluster_sim


def extract(tmp):
if len(tmp) == 1:
return tmp[0], "?"
else:
return "_".join(tmp[:-1]), tmp[-1]


def inner_list():
return ["", "", 0, 0]


def outer_dict():
return defaultdict(inner_list)


def read_with_pyarrow(file_path):
table = csv.read_csv(
file_path,
read_options=csv.ReadOptions(use_threads=True, column_names=["qid_chainid", "tid_chainid", "fident", "qlen", "lddt"]),
parse_options=csv.ParseOptions(delimiter="\t"),
)

indices = compute.sort_indices(table, [("lddt", "descending"), ("fident", "descending")])
ds = defaultdict(outer_dict)
for idx in tqdm(indices):
q_id, q_chain = extract(table["qid_chainid"][idx.as_py()].as_py().split("_"))
t_id, t_chain = extract(table["tid_chainid"][idx.as_py()].as_py().split("_"))
record = ds[q_id][t_id]
if q_chain in record[0] or t_chain in record[1]:
continue
fident = table["fident"][idx.as_py()].as_py()
q_len = table["qlen"][idx.as_py()].as_py()
record[0] += q_chain
record[1] += t_chain
record[2] += fident * q_len
record[3] += q_len
return ds
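
For orientation, the per-chain records that `read_with_pyarrow` accumulates (`[q_chains, t_chains, sum(fident * qlen), sum(qlen)]`) can be folded into the symmetric similarity matrix that the commented-out block above computes as `ds[name1][name2][2] / ds[name1][name2][3]`. A hedged sketch of that step — the helper name and the `names` argument are assumptions, not part of this commit:

```python
import numpy as np

def to_similarity_matrix(ds, names):
    """Collapse the nested {query: {target: [q_chains, t_chains, weighted_ident, total_len]}}
    mapping into a symmetric (len(names) x len(names)) similarity matrix."""
    index = {name: i for i, name in enumerate(names)}
    sim = np.zeros((len(names), len(names)))
    np.fill_diagonal(sim, 1.0)
    for query, hits in ds.items():
        for target, (_, _, weighted_ident, total_len) in hits.items():
            if query in index and target in index and total_len > 0:
                # length-weighted identity, as in the commented-out code above
                sim[index[query], index[target]] = weighted_ident / total_len
    return (sim + sim.T) / 2  # symmetrize, mirroring (cluster_sim + cluster_sim.T) / 2
```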

1 change: 1 addition & 0 deletions datasail/cluster/mash.py
@@ -58,6 +58,7 @@ def run_mash(dataset: DataSet, threads: int = 1, log_dir: Optional[Path] = None)
dataset.cluster_names = dataset.names

shutil.rmtree(results_folder, ignore_errors=True)
shutil.rmtree(tmp, ignore_errors=True)


def read_mash_tsv(filename: Path, num_entities: int) -> np.ndarray:
2 changes: 1 addition & 1 deletion datasail/reader/read_molecules.py
@@ -95,7 +95,7 @@ def remove_molecule_duplicates(dataset: DataSet) -> DataSet:
dataset: The dataset to remove duplicates from
Returns:
Update arguments as teh location of the data might change and an ID-Map file might be added.
Update arguments as the location of the data might change and an ID-Map file might be added.
"""
if isinstance(dataset.data[dataset.names[0]], (list, tuple, np.ndarray)):
# TODO: proper check for duplicate embeddings
7 changes: 6 additions & 1 deletion datasail/reader/utils.py
@@ -3,6 +3,7 @@
from dataclasses import dataclass, fields
from pathlib import Path
from typing import Generator, Tuple, List, Optional, Dict, Union, Any, Callable, Iterable, Set
from collections.abc import Iterable

import h5py
import numpy as np
@@ -124,6 +125,9 @@ def strat2oh(self, name: Optional[str] = None, classes: Optional[Union[str, Set[
classes = [classes]
if self.classes is not None:
# print(name, self.class_oh[[self.classes[class_] for class_ in classes]].sum(axis=0))
# print("classes", classes)
# print("self.cl", self.classes)
# print("self.oh", self.class_oh)
return self.class_oh[[self.classes[class_] for class_ in classes]].sum(axis=0)
return None

@@ -287,7 +291,8 @@ def read_data(
elif isinstance(weights, Generator):
dataset.weights = dict(weights)
elif inter is not None:
dataset.weights = dict(count_inter(inter, index))
dataset.weights = {k: 0 for k in dataset.data.keys()}
dataset.weights.update(dict(count_inter(inter, index)))
else:
dataset.weights = {k: 1 for k in dataset.data.keys()}

37 changes: 29 additions & 8 deletions datasail/routine.py
@@ -1,12 +1,13 @@
import time
import pickle
from typing import Dict, Tuple, Optional

from datasail.argparse_patch import remove_patch
from datasail.cluster.clustering import cluster
from datasail.reader.read import read_data
from datasail.reader.utils import DataSet
from datasail.report import report
from datasail.settings import LOGGER, KW_TECHNIQUES, KW_EPSILON, KW_RUNS, KW_SPLITS, KW_NAMES, \
from datasail.settings import LOGGER, KW_INTER, KW_TECHNIQUES, KW_EPSILON, KW_RUNS, KW_SPLITS, KW_NAMES, \
KW_MAX_SEC, KW_MAX_SOL, KW_SOLVER, KW_LOGDIR, NOT_ASSIGNED, KW_OUTDIR, MODE_E, MODE_F, DIM_2, SRC_CL, KW_DELTA, \
KW_E_CLUSTERS, KW_F_CLUSTERS, KW_CC, CDHIT, INSTALLED, FOLDSEEK, TMALIGN, CDHIT_EST, DIAMOND, MMSEQS, MASH
from datasail.solver.solve import run_solver
@@ -40,19 +41,30 @@ def datasail_main(**kwargs) -> Optional[Tuple[Dict, Dict, Dict]]:
LOGGER.info("Read data")

# read e-entities and f-entities
e_dataset, f_dataset, inter = read_data(**kwargs)
e_dataset, f_dataset_tmp, inter = read_data(**kwargs)

# if required, cluster the input otherwise define the cluster-maps to be None
clusters = list(filter(lambda x: x[0].startswith(SRC_CL), kwargs[KW_TECHNIQUES]))
cluster_e = len(clusters) != 0 and any(c[-1] in {DIM_2, MODE_E} for c in clusters)
cluster_f = len(clusters) != 0 and any(c[-1] in {DIM_2, MODE_F} for c in clusters)

if cluster_e:
LOGGER.info("Cluster first set of entities.")
e_dataset = cluster(e_dataset, **kwargs)
if cluster_f:
LOGGER.info("Cluster second set of entities.")
f_dataset = cluster(f_dataset, **kwargs)
#if cluster_e:
# LOGGER.info("Cluster first set of entities.")
# e_dataset = cluster(e_dataset, **kwargs)
#if cluster_f:
# LOGGER.info("Cluster second set of entities.")
# f_dataset = cluster(f_dataset, **kwargs)

split = str(kwargs[KW_INTER]).split("/")[-2]
#with open(f"/scratch/SCRATCH_SAS/roman/DataSAIL/PLINDER/{split}.pkl", "wb") as f:
# pickle.dump((e_dataset, f_dataset), f)
with open(f"/scratch/SCRATCH_SAS/roman/DataSAIL/PLINDER/{split}.pkl", "rb") as f:
e_dataset, f_dataset = pickle.load(f)
f_dataset.id_map = f_dataset_tmp.id_map

#print("E_ID_Map is None:", e_dataset.id_map is None)
#print("F_ID_Map is None:", f_dataset.id_map is None)
#print("Nones in inter :", sum([x is None for x in inter]))

if inter is not None:
if e_dataset.type is not None and f_dataset.type is not None:
@@ -88,13 +100,22 @@ def datasail_main(**kwargs) -> Optional[Tuple[Dict, Dict, Dict]]:

LOGGER.info("Store results")

#print("E name:", e_name_split_map.keys())
#print("F name:", f_name_split_map.keys())
#print("E cluster:", e_cluster_split_map.keys())
#print("F cluster:", f_cluster_split_map.keys())
#print("Inter:", inter_split_map.keys())

# infer interaction assignment from entity assignment if necessary and possible
output_inter_split_map = dict()
if new_inter is not None:
for technique in kwargs[KW_TECHNIQUES]:
output_inter_split_map[technique] = []
for run in range(kwargs[KW_RUNS]):
output_inter_split_map[technique].append(dict())
#print(e_name_split_map.keys())
#print(f_name_split_map.keys())
#print(techique)
for e, f in inter:
if technique.endswith(DIM_2) or technique == "R":
output_inter_split_map[technique][-1][(e, f)] = inter_split_map[technique][run].get(
1 change: 1 addition & 0 deletions datasail/solver/cluster_1d.py
@@ -64,6 +64,7 @@ def solve_c1(

loss = cvxpy.sum([t for tmp_list in tmp for t in tmp_list])
problem = solve(loss, constraints, max_sec, solver, log_file)
print(problem)

return None if problem is None else {
e: names[s] for s in range(len(splits)) for i, e in enumerate(clusters) if x[s, i].value > 0.1