diff --git a/src/yacht/hypothesis_recovery_src.py b/src/yacht/hypothesis_recovery_src.py
index 8c95096..7f3d3dd 100644
--- a/src/yacht/hypothesis_recovery_src.py
+++ b/src/yacht/hypothesis_recovery_src.py
@@ -9,8 +9,9 @@
 from tqdm import tqdm
 from multiprocessing import Pool
 import sourmash
+import glob
 from typing import List, Set, Tuple
-from .utils import load_signature_with_ksize
+from .utils import load_signature_with_ksize, decompress_all_sig_files
 
 # Configure Loguru logger
 from loguru import logger
@@ -23,7 +24,7 @@
     sys.stdout,
     format="{time:YYYY-MM-DD HH:mm:ss} - {level} - {message}",
     level="INFO"
 )
 
-SIG_SUFFIX = ".sig.gz"
+SIG_SUFFIX = ".sig"
 
 
 def get_organisms_with_nonzero_overlap(
@@ -61,6 +62,12 @@ def get_organisms_with_nonzero_overlap(
     logger.info("Unzipping the sample signature zip file")
     with zipfile.ZipFile(sample_file, "r") as sample_zip_file:
         sample_zip_file.extractall(path_to_sample_temp_dir)
+    all_gz_files = glob.glob(f"{path_to_sample_temp_dir}/signatures/*.sig.gz")
+
+    # decompress all signature files
+    logger.info(f"Decompressing {len(all_gz_files)} .sig.gz files using {num_threads} threads.")
+    decompress_all_sig_files(all_gz_files, num_threads)
+
     sample_sig_file = pd.DataFrame(
         [
@@ -141,7 +148,7 @@ def __find_exclusive_hashes(
 ) -> Set[int]:
     # load genome signature
     sig = load_signature_with_ksize(
-        os.path.join(path_to_temp_dir, "signatures", md5sum + ".sig.gz"), ksize
+        os.path.join(path_to_temp_dir, "signatures", md5sum + SIG_SUFFIX), ksize
     )
     return {hash for hash in sig.minhash.hashes if hash in single_occurrence_hashes}
@@ -155,7 +162,7 @@ def __find_exclusive_hashes(
     multiple_occurrence_hashes: Set[int] = set()
     for md5sum in tqdm(organism_md5sum_list, desc="Processing organism signatures"):
         sig = load_signature_with_ksize(
-            os.path.join(path_to_genome_temp_dir, "signatures", md5sum + ".sig.gz"),
+            os.path.join(path_to_genome_temp_dir, "signatures", md5sum + SIG_SUFFIX),
             ksize,
         )
         for hash in sig.minhash.hashes:
diff --git a/src/yacht/make_training_data_from_sketches.py b/src/yacht/make_training_data_from_sketches.py
index 3699470..8971911 100644
--- a/src/yacht/make_training_data_from_sketches.py
+++ b/src/yacht/make_training_data_from_sketches.py
@@ -7,6 +7,7 @@
 from loguru import logger
 import json
 import shutil
+import glob
 from . import utils
 
 # Configure Loguru logger
@@ -107,6 +108,11 @@ def main(args):
     logger.info("Unzipping the sourmash signature file to the temporary directory")
     with zipfile.ZipFile(ref_file, "r") as sourmash_db:
         sourmash_db.extractall(path_to_temp_dir)
+    all_gz_files = glob.glob(f"{path_to_temp_dir}/signatures/*.sig.gz")
+
+    # decompress all signature files
+    logger.info(f"Decompressing {len(all_gz_files)} .sig.gz files using {num_threads} threads.")
+    utils.decompress_all_sig_files(all_gz_files, num_threads)
 
     # Extract signature information
     logger.info("Extracting signature information")
diff --git a/src/yacht/run_yacht_train_core b/src/yacht/run_yacht_train_core
deleted file mode 100755
index dd9715b..0000000
Binary files a/src/yacht/run_yacht_train_core and /dev/null differ
diff --git a/src/yacht/utils.py b/src/yacht/utils.py
index 4b4d9d2..d402622 100644
--- a/src/yacht/utils.py
+++ b/src/yacht/utils.py
@@ -8,6 +8,7 @@
 from loguru import logger
 from typing import Optional, List, Set, Dict, Tuple
 import shutil
+import gzip
 from glob import glob
 
 # Configure Loguru logger
@@ -477,75 +478,32 @@ def check_download_args(args, db_type):
         logger.error("We now haven't supported for virus database.")
         sys.exit(1)
 
-def _temp_get_genome_name(sig_file_path, ksize):
-    res = get_info_from_single_sig(sig_file_path, ksize)
-    if res:
-        return res[0]
-    else:
-        return None
-
-def temp_generate_inputs(
-    selected_genomes_file_path: str,
-    sig_info_dict: Dict[str, Tuple[str, float, int, int]],
-    ksize: int,
-    num_threads: int = 16,
-) -> Tuple[pd.DataFrame, pd.DataFrame]:
+def _decompress_and_remove(file_path: str) -> None:
     """
-    Temporary Helper function that generates the required input for `yacht run`.
-    :param selected_genomes_file_path: Path to a file containing all the genome file path.
-    :param num_threads: Number of threads to use for multiprocessing when reading the comparison files. Default is 16.
-    :param sig_info_dict:
-        A dictionary mapping each genome signature name to a tuple containing metadata:
-        (md5sum, minhash mean abundance, minhash hashes length, minhash scaled).
-        - md5sum: Checksum for data integrity.
-        - minhash mean abundance: The mean abundance for the genome's minhash.
-        - minhash hashes length: The length of minhash hashes.
-        - minhash scaled: The scaling factor for the minhash.
-    :return
-        manifest_df: a dataframe containing the processed reference signature information
+    Decompresses a GZIP-compressed file and removes the original compressed file.
+    :param file_path: The path to the .sig.gz file that needs to be decompressed and deleted.
+    :return: None
     """
-    # get info from the signature files of selected genomes
-    selected_sig_files = pd.read_csv(selected_genomes_file_path, sep="\t", header=None)
-    selected_sig_files = selected_sig_files[0].to_list()
-
-    # get the genome name from the signature files using multiprocessing
-    with Pool(num_threads) as p:
-        result_list = p.starmap(_temp_get_genome_name, [(sig_file_path, ksize) for sig_file_path in selected_sig_files])
-        selected_genome_names_set = set([x for x in result_list if x])
+    try:
+        output_filename = os.path.splitext(file_path)[0]
+        with gzip.open(file_path, 'rb') as f_in:
+            with open(output_filename, 'wb') as f_out:
+                f_out.write(f_in.read())
 
-    # remove the close related organisms from the reference genome list
-    manifest_df = []
-    for sig_name, (
-        md5sum,
-        minhash_mean_abundance,
-        minhash_hashes_len,
-        minhash_scaled,
-    ) in tqdm(sig_info_dict.items(), desc="Removing close related organisms from the reference genome list"):
-        if sig_name in selected_genome_names_set:
-            manifest_df.append(
-                (
-                    sig_name,
-                    md5sum,
-                    minhash_hashes_len,
-                    get_num_kmers(
-                        minhash_mean_abundance,
-                        minhash_hashes_len,
-                        minhash_scaled,
-                        False,
-                    ),
-                    minhash_scaled,
-                )
-            )
-    manifest_df = pd.DataFrame(
-        manifest_df,
-        columns=[
-            "organism_name",
-            "md5sum",
-            "num_unique_kmers_in_genome_sketch",
-            "num_total_kmers_in_genome_sketch",
-            "genome_scale_factor",
-        ],
-    )
+        os.remove(file_path)
-
-    return manifest_df
\ No newline at end of file
+    except Exception as e:
+        logger.info(f"Failed to process {file_path}: {e}")
+
+def decompress_all_sig_files(sig_files: List[str], num_threads: int) -> None:
+    """
+    Decompresses all .sig.gz files in the list using multiple threads.
+    :param sig_files: List of .sig.gz files that need to be decompressed.
+    :param num_threads: Number of threads to use for decompression.
+    :return: None
+    """
+    with Pool(num_threads) as p:
+        p.map(_decompress_and_remove, sig_files)
+
+    logger.info("All .sig.gz files have been decompressed.")
\ No newline at end of file
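
Note, not part of the patch: the new _decompress_and_remove helper reads each archive fully into memory with f_in.read(), which is fine for typical sourmash signatures but can add up when num_threads worker processes each hold a large file at once. A minimal alternative sketch under the same .sig.gz naming convention, using only stdlib calls already imported in utils.py (gzip, os, shutil); the name _decompress_and_remove_streaming is hypothetical:

import gzip
import os
import shutil

def _decompress_and_remove_streaming(file_path: str) -> None:
    """Decompress <name>.sig.gz to <name>.sig, then delete the compressed original."""
    # os.path.splitext strips only the ".gz", so "x.sig.gz" -> "x.sig";
    # this is also why SIG_SUFFIX changes from ".sig.gz" to ".sig" above
    output_filename = os.path.splitext(file_path)[0]
    with gzip.open(file_path, "rb") as f_in, open(output_filename, "wb") as f_out:
        # chunked copy: never holds the whole decompressed payload in memory
        shutil.copyfileobj(f_in, f_out)
    os.remove(file_path)

Since Pool.map only needs a one-argument callable, this variant would drop into decompress_all_sig_files unchanged.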