Skip to content

Commit

Permalink
More bugs with stratification fixed
Browse files Browse the repository at this point in the history
  • Loading branch information
Old-Shatterhand committed May 21, 2024
1 parent 4c0a7c9 commit 903fa27
Show file tree
Hide file tree
Showing 8 changed files with 30 additions and 48 deletions.
10 changes: 4 additions & 6 deletions datasail/cluster/clustering.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,18 +153,16 @@ def finish_clustering(dataset: DataSet) -> None:
"""
# compute the weights and the stratification for the clusters
dataset.cluster_weights = {}
if dataset.stratification is not None:
dataset.cluster_stratification = {}
dataset.cluster_stratification = {}

for key, value in dataset.cluster_map.items():
if value not in dataset.cluster_weights:
dataset.cluster_weights[value] = 0
dataset.cluster_weights[value] += dataset.weights[key]

if dataset.stratification is not None:
if value not in dataset.cluster_stratification:
dataset.cluster_stratification[value] = np.zeros(len(dataset.classes))
dataset.cluster_stratification[value] += dataset.strat2oh(name=key)
if value not in dataset.cluster_stratification:
dataset.cluster_stratification[value] = np.zeros(len(dataset.classes))
dataset.cluster_stratification[value] += dataset.strat2oh(name=key)


def additional_clustering(
Expand Down
2 changes: 0 additions & 2 deletions datasail/cluster/diamond.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,5 @@ def run_diamond(dataset: DataSet, threads: int = 1, log_dir: Optional[Path] = No
shutil.rmtree(result_folder, ignore_errors=True)

dataset.cluster_names = table.index.tolist()
print(dataset.cluster_names)
dataset.cluster_map = {n: n for n in dataset.cluster_names}
dataset.cluster_similarity = table.to_numpy()
dataset.cluster_weights = {n: 1 for n in dataset.cluster_names}
2 changes: 1 addition & 1 deletion datasail/cluster/vectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,8 +149,8 @@ def run_vector(dataset: DataSet, method: SIM_OPTIONS = "tanimoto") -> None:
else:
raise ValueError(f"Unknown method {method}")
fps = [dataset.data[name] for name in dataset.names]
run(dataset, fps, method)

run(dataset, fps, method)
dataset.cluster_names = copy.deepcopy(dataset.names)
dataset.cluster_map = {n: n for n in dataset.names}

Expand Down
57 changes: 22 additions & 35 deletions datasail/reader/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,7 +269,7 @@ def read_data(
Returns:
A dataset storing all information on that datatype
"""
# parse the protein weights
# parse the weights
if isinstance(weights, Path) and weights.is_file():
if weights.suffix[1:].lower() == "csv":
dataset.weights = dict((n, float(w)) for n, w in read_csv(weights, ","))
Expand All @@ -286,13 +286,31 @@ def read_data(
elif inter is not None:
dataset.weights = dict(count_inter(inter, index))
else:
dataset.weights = dict((p, 1) for p in list(dataset.data.keys()))
dataset.weights = {k: 1 for k in dataset.data.keys()}

dataset.classes, dataset.stratification = read_stratification(strats)
# parse the stratification
if isinstance(strats, Path) and strats.is_file():
if strats.suffix[1:].lower() == "csv":
dataset.stratification = dict(read_csv(strats, ","))
elif strats.suffix[1:].lower() == "tsv":
dataset.stratification = dict(read_csv(strats, "\t"))
else:
raise ValueError()
elif isinstance(strats, dict):
dataset.stratification = strats
elif isinstance(strats, Callable):
dataset.stratification = strats()
elif isinstance(strats, Generator):
dataset.stratification = dict(strats)
else:
dataset.stratification = {k: 0 for k in dataset.data.keys()}

# .classes maps the individual classes to their index in one-hot encoding, important for non-numeric classes
dataset.classes = {s: i for i, s in enumerate(set(dataset.stratification.values()))}
dataset.class_oh = np.eye(len(dataset.classes))
dataset.num_clusters = num_clusters

# parse the protein similarity measure
# parse the similarity or distance measure
if sim is None and dist is None:
dataset.similarity, dataset.distance = get_default(dataset.type, dataset.format)
dataset.names = list(dataset.data.keys())
Expand All @@ -312,37 +330,6 @@ def read_data(
return dataset


def read_stratification(strats: DATA_INPUT) -> Tuple[Dict[Any, int], Optional[Dict[str, np.ndarray]]]:
"""
Read in the stratification for the data.
Args:
strats: Stratification input
Returns:
Set of all classes and a dictionary mapping the entity names to their class
"""
# parse the stratification
if isinstance(strats, Path) and strats.is_file():
if strats.suffix[1:].lower() == "csv":
stratification = dict(read_csv(strats, ","))
elif strats.suffix[1:].lower() == "tsv":
stratification = dict(read_csv(strats, "\t"))
else:
raise ValueError()
elif isinstance(strats, dict):
stratification = strats
elif isinstance(strats, Callable):
stratification = strats()
elif isinstance(strats, Generator):
stratification = dict(strats)
else:
return {0: 0}, None

classes = {s: i for i, s in enumerate(set(stratification.values()))}
return classes, stratification


def read_folder(folder_path: Path, file_extension: Optional[str] = None) -> Generator[Tuple[str, str], None, None]:
"""
Read in all PDB file from a folder and ignore non-PDB files.
Expand Down
2 changes: 1 addition & 1 deletion datasail/version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "1.0.3"
__version__ = "1.0.4"
2 changes: 1 addition & 1 deletion recipe/meta.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
package:
name: "datasail"
version: '1.0.3'
version: '1.0.4'

source:
path: ..
Expand Down
2 changes: 1 addition & 1 deletion recipe_lite/meta.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
package:
name: "datasail-lite"
version: '1.0.3'
version: '1.0.4'

source:
path: ..
Expand Down
1 change: 0 additions & 1 deletion tests/test_caching.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import copy
import os
from pathlib import Path

import pytest
Expand Down

0 comments on commit 903fa27

Please sign in to comment.