-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathinfer_profile.py
executable file
·662 lines (512 loc) · 20.7 KB
/
infer_profile.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
#!/usr/bin/env python
"""
Infer JASPAR profiles for protein sequences by homology.

For each input sequence the script finds its Pfam DNA-binding domains
(hmmscan + hmmalign), BLAST+ searches it against the per-taxon databases in
the files directory, compares the DBDs of each homolog with the Cis-BP
similarity regression models and reports the JASPAR profiles of homologs whose
DBDs are classified as highly similar ("HSim").
"""
import argparse
from Bio import SearchIO
from Bio.SeqRecord import SeqRecord
# NOTE(review): Bio.SubsMat was deprecated and later removed from Biopython;
# this import requires an older Biopython release — confirm the pinned version.
from Bio.SubsMat.MatrixInfo import blosum62
import copy
from functools import partial
import json
import math
from multiprocessing import Pool
import numpy as np
import os
import re
import shutil
import string
import subprocess
import sys
from tqdm import tqdm
# tqdm progress-bar layout (used by infer_profiles)
bar_format="{percentage:3.0f}%|{bar:20}{r_bar}"
import warnings

# Defaults
root_dir = os.path.dirname(os.path.realpath(__file__))
files_dir = os.path.join(root_dir, "files")
# models_dir = os.path.join(root_dir, "models")

# Append JASPAR-profile-inference to path
# (this lets the sibling __init__.py be imported as a plain module below)
sys.path.append(root_dir)

# Import globals
from __init__ import CisBP2Pfam, Jglobals, ReadSRModel, ScoreAlignmentResult
#-------------#
# Functions #
#-------------#
def parse_args():
    """
    This function parses arguments provided via the command line and returns
    an {argparse} object.
    """

    # Initialize
    parser = argparse.ArgumentParser()

    # Mandatory args
    parser.add_argument("sequences",
        help="input sequence(s) in FASTA format")

    # Optional args
    parser.add_argument("--dummy-dir", default="/tmp/", metavar="DIR",
        help="dummy directory (default = /tmp/)")
    parser.add_argument("--files-dir", default=files_dir, metavar="DIR",
        help="files directory from get_files.py (default = ./files/)")
    parser.add_argument("--output-file", metavar="FILE",
        help="output file (default = STDOUT)")
    parser.add_argument("--threads", default=1, metavar="INT", type=int,
        help="number of threads to use (default = 1)")
    parser.add_argument("-w", "--warnings", action="store_true",
        help="issue warnings (default = False)")

    # Inference args
    inference_group = parser.add_argument_group("inference arguments")
    inference_group.add_argument("-l", "--latest", action="store_true",
        help="return the latest version of each profile")
    # type=int: without it a value given on the command line stays a string
    # and the Rost's cut-off arithmetic (n + float) raises TypeError
    inference_group.add_argument("--rost", default=5, metavar="INT", type=int,
        help="\"n\" parameter for the Rost's curve (default = 5)")
    inference_group.add_argument("--taxon", nargs="*", default=Jglobals.taxons,
        metavar="STR", help="return profiles from given taxon (default = all)")

    args = parser.parse_args()

    return(args)
def main():
    """Command-line entry point: read the options and run the inference."""
    opts = parse_args()

    # Silence warnings unless the user asked for them with -w/--warnings
    if not opts.warnings:
        warnings.filterwarnings("ignore")

    infer_profiles(
        opts.sequences,
        opts.dummy_dir,
        opts.files_dir,
        opts.output_file,
        opts.threads,
        opts.latest,
        opts.rost,
        opts.taxon,
    )
def infer_profiles(fasta_file, dummy_dir="/tmp/", files_dir=files_dir,
    output_file=None, threads=1, latest=False, n=5, taxons=Jglobals.taxons):
    """
    Infer JASPAR profiles for every sequence of a FASTA file and write them
    as a tab-separated table.

    :param fasta_file: path to the input FASTA file
    :param dummy_dir: parent directory of the temporary working directory
    :param files_dir: directory holding the Cis-BP, JASPAR and Pfam files
    :param output_file: output TSV path; if None, the table goes to STDOUT
    :param threads: number of worker processes
    :param latest: keep only the latest version of each profile
    :param n: "n" parameter of the Rost's curve
    :param taxons: taxons whose JASPAR data are searched
    """

    # Initialize
    base_name = os.path.basename(__file__)
    pid = os.getpid()

    # Load data
    cisbp = __load_CisBP_models(files_dir)
    jaspar = __load_JASPAR_files_n_models(files_dir, taxons)

    # Create dummy dir
    dummy_dir = os.path.join(dummy_dir, "%s.%s" % (base_name, pid))
    dummy_file = os.path.join(dummy_dir, "inferred_profiles.tsv")
    os.makedirs(dummy_dir, exist_ok=True)

    # Whatever happens below, the dummy dir must not be left behind
    try:

        # Get sequences as SeqRecords
        # Note: https://biopython.org/wiki/SeqRecord
        seq_records = list(Jglobals.parse_fasta_file(fasta_file))

        # Write header
        columns = ["Query", "TF Name", "TF Matrix", "E-value",
            "Query Start-End", "TF Start-End", "DBD %ID"]
        Jglobals.write(dummy_file, "\t".join(columns))

        # Infer SeqRecord profiles
        if seq_records:
            # Pool() rejects 0 processes (empty FASTA or threads < 1):
            # clamp the pool size to [1, number of sequences]
            processes = max(1, min(threads, len(seq_records)))
            kwargs = {"total": len(seq_records), "bar_format": bar_format}
            p = partial(infer_SeqRecord_profiles, cisbp=cisbp,
                dummy_dir=dummy_dir, files_dir=files_dir, jaspar=jaspar,
                latest=latest, n=n, taxons=taxons)
            # The context manager terminates the workers if anything raises
            with Pool(processes) as pool:
                for inferences in tqdm(pool.imap(p, seq_records), **kwargs):
                    for inference in inferences:
                        Jglobals.write(dummy_file,
                            "\t".join(map(str, inference)))
                pool.close()
                pool.join()

        # Write
        if output_file:
            shutil.copy(dummy_file, output_file)
        else:
            with open(dummy_file) as f:
                for line in f:
                    Jglobals.write(None, line.strip("\n"))

    finally:
        # Remove dummy dir (best effort, so it cannot mask a real error)
        shutil.rmtree(dummy_dir, ignore_errors=True)
def __load_CisBP_models(files_dir="./files/"):
    """
    Load the Cis-BP models stored as files under <files_dir>/cisbp/.

    Each file is read with ReadSRModel and keyed by the Pfam identifier that
    CisBP2Pfam assigns to the model's "Family_Name". If several files map to
    the same key, the first one in sorted file-name order is kept.

    :returns: dict {pfam_id: model}
    """

    # Initialize
    cisbp = {}
    cisbp_dir = os.path.join(files_dir, "cisbp")

    # os.listdir() returns entries in arbitrary order; sort them so that
    # precedence among models sharing a Pfam key is reproducible
    for json_file in sorted(os.listdir(cisbp_dir)):
        model = ReadSRModel(os.path.join(cisbp_dir, json_file))
        cisbp.setdefault(CisBP2Pfam[model["Family_Name"]], model)

    return(cisbp)
def __load_JASPAR_files_n_models(files_dir="./files/", taxons=("fungi",
    "insects", "nematodes", "plants", "vertebrates")):
    """
    Load the per-taxon JASPAR JSON files and index them by UniProt accession.

    For every taxon, three files are read from files_dir:
      <taxon>.pfam.json      accession -> Pfam DBD alignments
      <taxon>.profiles.json  matrix ID -> gene name
      <taxon>.uniprot.json   accession -> list whose first item is the list
                             of matrix IDs of that protein
    When a key occurs in several taxons, the first taxon listed wins.

    :returns: {accession: {"pfam": [...],
                           "profiles": [[matrix_id, gene_name], ...]}}
        for every accession present in the Pfam files
    """

    # Initialize
    pfams = {}
    profiles = {}
    uniprots = {}

    for taxon in taxons:
        for suffix, merged in (("pfam", pfams), ("profiles", profiles),
                               ("uniprot", uniprots)):
            json_file = os.path.join(files_dir, "%s.%s.json" % (taxon, suffix))
            # JSON is UTF-8 by specification; don't rely on the locale
            with open(json_file, encoding="utf-8") as f:
                for key, values in json.load(f).items():
                    merged.setdefault(key, values)

    jaspar = {}
    for uniprot, pfam_alignments in pfams.items():
        jaspar[uniprot] = {
            # Add Pfam domains
            "pfam": pfam_alignments,
            # Add profiles
            "profiles": [[profile, profiles[profile]]
                         for profile in uniprots[uniprot][0]],
        }

    return(jaspar)
def infer_SeqRecord_profiles(seq_record, cisbp, jaspar, dummy_dir="/tmp/",
    files_dir="./files/", latest=False, n=5, taxons=("fungi", "insects",
    "nematodes", "plants", "vertebrates")):
    """
    Infer the JASPAR profiles of a single protein sequence.

    The query's Pfam DBDs are compared with those of every BLAST+ homolog
    (filtered by the Rost's curve). When the Cis-BP similarity regression
    model classifies a shared DBD as "HSim" (highly similar), all JASPAR
    profiles of that homolog are reported.

    :param seq_record: query SeqRecord
    :param cisbp: Cis-BP models keyed by Pfam ID (key None = generic model)
    :param jaspar: JASPAR data keyed by UniProt accession
    :param dummy_dir: directory for scratch files
    :param files_dir: directory holding the Pfam and BLAST+ files
    :param latest: keep only the latest version of each profile
    :param n: "n" parameter of the Rost's curve
    :param taxons: taxons whose BLAST+ databases are searched
    :returns: list of [query ID, TF name, TF matrix, E-value, query start-end,
        TF start-end, DBD %ID] rows sorted by E-value, TF name and matrix
        version
    """

    # Initialize
    inference_results = []

    # Get SeqRecord Pfam DBDs, grouped by Pfam ID
    _, alignments = __get_SeqRecord_Pfam_alignments(seq_record,
        files_dir, dummy_dir)
    if not alignments:
        return(inference_results)
    query_dbds = {}
    for alignment in alignments:
        query_dbds.setdefault(alignment[0], []).append(alignment[1])

    # BLAST+ search
    blast_results = blast(seq_record, files_dir, taxons, n)

    # Get Pfam DBDs of BLAST+ (filtered) results
    target_dbds = __get_blast_results_Pfam_alignments(blast_results, jaspar)

    for r in blast_results:
        for DBD in query_dbds:
            if DBD not in target_dbds[r[1]]:
                continue

            # Clean sequences
            seq1 = [__remove_insertions(s) for s in query_dbds[DBD]]
            seq2 = [__remove_insertions(s) for s in target_dbds[r[1]][DBD]]

            # Similarity regression alignment
            sr_alignment = {}
            ByPosPctID = __get_X(copy.copy(seq1), copy.copy(seq2), "identity")
            sr_alignment.setdefault("ByPos.PctID", ByPosPctID)
            ByPosAvgB62 = __get_X(copy.copy(seq1), copy.copy(seq2), "blosum62")
            sr_alignment.setdefault("ByPos.AvgB62", ByPosAvgB62)
            PctID_L = sum(ByPosPctID) / len(ByPosPctID)
            sr_alignment.setdefault("PctID_L", PctID_L)

            # Inference: Cis-BP (generic model under key None if this DBD
            # has no specific one)
            model = cisbp[DBD] if DBD in cisbp else cisbp[None]
            _, Classification = ScoreAlignmentResult(sr_alignment, model)
            if Classification != "HSim":
                continue

            # i.e. inferred result: report every profile of this homolog
            for matrix, gene_name in jaspar[r[1]]["profiles"]:
                inference_results.append([r[0], gene_name, matrix, r[4],
                    r[2], r[3], round(PctID_L, 3)])
            # One qualifying DBD is enough: report each hit at most once
            break

    # Sort
    inference_results.sort(key=lambda x: (x[3], x[1], float(x[2][2:])))

    # Remove profiles from older versions
    if latest:
        inference_results = __keep_latest_profile_versions(inference_results)

    return(inference_results)


def __keep_latest_profile_versions(inference_results):
    """
    Keep, for each JASPAR base ID (e.g. "MA0001" of "MA0001.2"), only the
    first row carrying that ID's highest version; relative order is kept.

    The matrix ID is read from column 2 of each row.
    """

    def split_id(matrix_id):
        # "MA0001.2" -> ("MA0001", 2); IDs without a numeric version sort lowest
        base_id, _, version = matrix_id.partition(".")
        return(base_id, int(version) if version.isdigit() else -1)

    # Highest version seen for every base ID
    newest = {}
    for row in inference_results:
        base_id, version = split_id(row[2])
        newest[base_id] = max(version, newest.get(base_id, version))

    kept = []
    seen = set()
    for row in inference_results:
        base_id, version = split_id(row[2])
        if base_id in seen or version != newest[base_id]:
            continue
        seen.add(base_id)
        kept.append(row)

    return(kept)
def __get_SeqRecord_Pfam_alignments(seq_record, files_dir="./files/",
    dummy_dir="/tmp/"):
    """
    Find the Pfam DBDs of a sequence and align each one to its Pfam HMM.

    Domains come from hmmscan against <files_dir>/pfam/All.hmm (overlaps
    resolved); each domain's subsequence is then aligned with hmmalign to
    <files_dir>/pfam/<pfam_id>.hmm.

    :returns: (seq_record.id, [(pfam_id, alignment, start + 1, end, evalue),
        ...]); the subsequence is seq_record.seq[start:end], so start + 1/end
        are 1-based inclusive coordinates (assuming 0-based half-open input)
    """

    # Initialize
    pfam_alignments = []
    hmm_db = os.path.join(files_dir, "pfam", "All.hmm")
    seq_file = os.path.join(dummy_dir, ".%s.seq.fasta" % os.getpid())

    try:
        # Make seq file
        __make_seq_file(seq_record, seq_file)

        # For each DBD... (hmmscan has read seq_file before yielding the first
        # domain, so overwriting seq_file inside the loop is safe)
        for pfam_id_std, start, end, evalue in hmmscan(seq_file, hmm_db,
            dummy_dir, non_overlapping_domains=True):

            # Initialize
            hmm_file = os.path.join(files_dir, "pfam", "%s.hmm" % pfam_id_std)

            # Make seq file
            sub_seq_record = SeqRecord(seq_record.seq[start:end],
                id=seq_record.id, name=seq_record.name,
                description=seq_record.description)
            __make_seq_file(sub_seq_record, seq_file)

            # Add DBDs
            alignment = hmmalign(seq_file, hmm_file)
            pfam_alignments.append((pfam_id_std, alignment, start + 1, end,
                evalue))
    finally:
        # Don't leave the scratch FASTA behind (e.g. in /tmp/ by default)
        if os.path.exists(seq_file):
            os.remove(seq_file)

    return(seq_record.id, pfam_alignments)
def __make_seq_file(seq_record, file_name=".seq.fa"):
    """
    Write seq_record in FASTA format to file_name, deleting any file that
    already exists under that name first.
    """
    # Start from an empty file (presumably because Jglobals.write appends —
    # confirm against Jglobals)
    already_there = os.path.exists(file_name)
    if already_there:
        os.remove(file_name)

    fasta_text = seq_record.format("fasta")
    Jglobals.write(file_name, fasta_text)
def hmmscan(seq_file, hmm_file, dummy_dir="/tmp/",
    non_overlapping_domains=False):
    """
    Scan seq_file against hmm_file with hmmscan and yield the domains that
    pass the E-value cut-offs of __read_domains, ordered by start position.

    This is a generator: hmmscan runs (and its --domtblout file is parsed and
    deleted) when iteration starts, before the first domain is yielded.
    The exit status of hmmscan is not checked.

    :param non_overlapping_domains: if True, overlapping domains are resolved
        in favour of the best E-value first
    :yields: (pfam_ac, start, end, evalue) tuples
    """

    # Initialize
    out_file = os.path.join(dummy_dir, ".%s.out.txt" % os.getpid())

    # Scan (argument list without a shell: paths are passed verbatim, even
    # if they contain spaces or shell metacharacters)
    cmd = ["hmmscan", "--domtblout", out_file, hmm_file, seq_file]
    subprocess.run(cmd, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)

    # Read domains; remove the output file even if parsing fails
    try:
        domains = __read_domains(out_file)
    finally:
        if os.path.exists(out_file):
            os.remove(out_file)

    # Filter overlapping domains
    if non_overlapping_domains:
        domains = __get_non_overlapping_domains(domains)

    # Yield domains one by one
    for pfam_ac, start, end, evalue in sorted(domains, key=lambda x: x[1]):
        yield(pfam_ac, start, end, evalue)
def __read_domains(file_name):
    """
    Parse an hmmscan --domtblout file and return its significant domains.

    A model hit is skipped when its per-sequence E-value exceeds 1e-5, and a
    domain when its per-domain conditional E-value exceeds 0.01.

    :returns: list of (model ID, query start, query end, conditional E-value)

    Rationale for the cut-offs:
    - PMID:22942020: a hit has equal probability of being in the same clan as
      in a different clan when the E-value is 0.01 (log10 = -2); when the
      E-value is 1e-5, the probability that a sequence belongs to the same
      clan is >95%.
    - Cis-BP paper: "We scanned all protein sequences for putative
      DNA-binding domains (DBDs) using the 81 Pfam (Finn et al., 2010) models
      listed in (Weirauch and Hughes, 2011) and the HMMER tool (Eddy, 2009),
      with the recommended detection thresholds of Per-sequence Eval < 0.01
      and Per-domain conditional Eval < 0.01."
    """
    max_model_evalue = 1e-5
    max_domain_evalue = 0.01

    found = []
    for query_result in SearchIO.parse(file_name, "hmmscan3-domtab"):
        for hit in query_result.iterhits():
            if hit.evalue > max_model_evalue:
                continue  # poor model
            # "not >" mirrors the skip-if-greater rule exactly
            found.extend(
                (hit.id, hsp.query_start, hsp.query_end, hsp.evalue_cond)
                for hsp in hit.hsps
                if not hsp.evalue_cond > max_domain_evalue
            )
    return found
def __get_non_overlapping_domains(domains):
    """
    Greedily pick non-overlapping domains, best (lowest) E-value first.

    Each domain is a tuple whose [1] and [2] items are its start and end and
    whose last item is its E-value. Two domains overlap when each one starts
    before the other ends:

        ---------1111111---------
        -------22222------------- True
        ----------22222---------- True
        -------------22222------- True
        -----22222--------------- False
        ---------------22222----- False

    :returns: the kept domains, in E-value order
    """
    kept = []
    for candidate in sorted(domains, key=lambda d: d[-1]):
        clashes = any(
            candidate[1] < other[2] and candidate[2] > other[1]
            for other in kept
        )
        if not clashes:
            kept.append(candidate)
    return kept
def hmmalign(seq_file, hmm_file):
    """
    Align the sequence in seq_file to hmm_file with hmmalign and return the
    aligned sequence as one string (hmmalign's PSIBLAST output, parsed by
    __read_PSIBLAST_format).

    :raises subprocess.CalledProcessError: if hmmalign exits non-zero
    :raises FileNotFoundError: if hmmalign is not installed
    """
    # Argument list without a shell: paths are passed verbatim
    cmd = ["hmmalign", "--outformat", "PSIBLAST", hmm_file, seq_file]
    psiblast_alignment = subprocess.check_output(cmd, universal_newlines=True)

    return(__read_PSIBLAST_format(psiblast_alignment))
# Last whitespace-separated field of a PSIBLAST alignment row (raw string:
# "\s" in a plain literal is an invalid escape sequence)
_PSIBLAST_SEQ_FIELD = re.compile(r"\s+(\S+)$")


def __read_PSIBLAST_format(psiblast_alignment):
    """
    Concatenate the alignment fields of a PSIBLAST-formatted alignment.

    Every line that ends in whitespace followed by a non-blank field
    contributes that last field; other lines (blank, or a single token) are
    ignored.

    :param psiblast_alignment: hmmalign output in PSIBLAST format
    :returns: the aligned sequence as a single string ("" if none)
    """
    chunks = []
    for line in psiblast_alignment.split("\n"):
        m = _PSIBLAST_SEQ_FIELD.search(line)
        if m:
            chunks.append(m.group(1))

    return("".join(chunks))
def blast(seq_record, files_dir="./files/", taxons=("fungi", "insects",
    "nematodes", "plants", "vertebrates"), n=5):
    """
    BLAST+ (blastp) a protein against the per-taxon databases and keep the
    hits that lie on or above the Rost's sequence identity curve.

    :param seq_record: query SeqRecord
    :param files_dir: directory holding the "<taxon>.fa" BLAST+ databases
    :param taxons: taxons whose databases are searched
    :param n: "n" parameter of the Rost's curve (larger n = stricter cut-off)
    :returns: list of (query_id, target_id, "qstart-qend", "sstart-send",
        e_value, bit_score, percent_identities, alignment_length,
        percent_similarity, joint_coverage) tuples, by decreasing joint
        coverage
    :raises FileNotFoundError: if blastp is not installed
    """

    # Initialize
    blast_results = set()
    outfmt = "sseqid pident length qstart qend sstart send evalue bitscore ppos qlen slen"
    fasta_sequence = ">%s\n%s" % (seq_record.id, seq_record.seq)

    # For each taxon...
    for taxon in taxons:

        # Taxon db
        taxon_db = os.path.join(files_dir, "%s.fa" % taxon)

        # Run BLAST+ (argument list without a shell; "6 <fields>" is passed
        # to -outfmt as a single argument, as the quoted shell form did)
        cmd = ["blastp", "-db", taxon_db, "-outfmt", "6 %s" % outfmt]
        process = subprocess.run(cmd, input=fasta_sequence.encode(),
            stdout=subprocess.PIPE)

        # For each BLAST+ record...
        for blast_record in process.stdout.decode("utf-8").split("\n"):

            # Custom BLAST+ record:
            # (1) identifier of target sequence;
            # (2) percentage of identical matches;
            # (3) alignment length;
            # (4-5, 6-7) start and end-position in query and in target;
            # (8) E-value;
            # (9) bit score;
            # (10) percentage of positive-scoring matches; and
            # (4-7, 11, 12) joint coverage (i.e. square root of the coverage
            # on the query and the target).
            blast_record = blast_record.split("\t")

            # Skip if not a BLAST+ record (e.g. the trailing empty line)
            if len(blast_record) != 12:
                continue

            # Get BLAST+ record
            target_id = blast_record[0]
            percent_identities = float(blast_record[1])
            alignment_length = int(blast_record[2])
            query_start_end = "%s-%s" % (blast_record[3], blast_record[4])
            target_start_end = "%s-%s" % (blast_record[5], blast_record[6])
            e_value = float(blast_record[7])
            score = float(blast_record[8])
            percent_similarity = float(blast_record[9])
            query_aligned_residues = \
                int(blast_record[4]) - int(blast_record[3]) + 1
            query_length = float(blast_record[10])
            target_aligned_residues = \
                int(blast_record[6]) - int(blast_record[5]) + 1
            target_length = float(blast_record[11])
            query_coverage = query_aligned_residues * 100 / query_length
            target_coverage = target_aligned_residues * 100 / target_length
            joint_coverage = math.sqrt(query_coverage * target_coverage)

            # Add BLAST+ record to search results
            blast_results.add((seq_record.id, target_id, query_start_end,
                target_start_end, e_value, score, percent_identities,
                alignment_length, percent_similarity, joint_coverage))

    # Sort by joint coverage (last field), best first; ties are broken on the
    # whole tuple so the order does not depend on set iteration order
    # NOTE(review): ordering is by joint coverage, not bit score (field 5) —
    # confirm this is the intended ranking
    blast_results = sorted(blast_results, key=lambda x: (-x[-1], x))

    # Forward n: without it the Rost's filter always used its default n=5
    return(__filter_blast_results_by_Rost(blast_results, n))
def __filter_blast_results_by_Rost(blast_results, n=5):
    """
    Keep the BLAST+ results whose alignment lies on or above the Rost's
    percent identity curve.

    Each result is read as a tuple holding the percent identity at index 6
    and the alignment length at index 7. Input order is preserved.
    """
    return [
        hit for hit in blast_results
        if __is_alignment_over_Rost_seq_id_curve(hit[6], hit[7], n)
    ]
def __is_alignment_over_Rost_seq_id_curve(pid, L, n=5):
    """
    Tell whether an alignment of length L with percent identity pid lies on
    or above the Rost's pairwise sequence identity curve shifted by n.
    """
    cutoff = __get_Rost_cutoff_percent_identity(L, n)
    return pid >= cutoff
def __get_Rost_cutoff_percent_identity(L, n=5):
    """
    Return the Rost's cut-off percentage of identical residues for an
    alignment of length L:

        n + 480 * L ** (-0.32 * (1 + e ** (-L / 1000)))
    """
    decay = pow(math.e, float(-L) / 1000)
    exponent = -0.32 * (1 + decay)
    return n + (480 * pow(L, exponent))
# def __is_alignment_over_Rost_seq_sim_curve(psim, L, n=12):
# """
# This function returns whether an alignment is over the Rost's pairwise
# sequence similarity curve or not.
# """
# return(psim >= _get_Rost_cutoff_percent_similarity(L, n))
# def __get_Rost_cutoff_percent_similarity(L, n=12):
# """
# This function returns the Rost's cut-off percentage of "similar" residues
# for an alignment of length "L".
# """
# return(n + (420 * pow(L, -0.335 * (1 + pow(math.e, float(-L) / 2000)))))
def __get_blast_results_Pfam_alignments(blast_results, jaspar):
    """
    Group the Pfam DBD alignments of every BLAST+ target by Pfam ID.

    :param blast_results: BLAST+ results; the target accession is at index 1
    :param jaspar: JASPAR data keyed by accession; jaspar[acc]["pfam"] is a
        list whose items hold the Pfam ID at [0] and the alignment at [1]
    :returns: {target_id: {pfam_id: [alignment, ...]}}
    """

    # Initialize
    pfam_alignments = {}

    # For each BLAST result...
    for blast_result in blast_results:
        target_id = blast_result[1]
        # A target can be hit more than once (e.g. several HSPs); add its
        # DBDs only once, otherwise they are duplicated and the multi-DBD
        # comparison done in __get_X is skewed
        if target_id in pfam_alignments:
            continue
        target_dbds = pfam_alignments.setdefault(target_id, {})
        for alignment in jaspar[target_id]["pfam"]:
            target_dbds.setdefault(alignment[0], []).append(alignment[1])

    return(pfam_alignments)
def __get_CisBP_models(DBDs, cisbp):
    """
    Collect the Cis-BP thresholds and similarity regression model per DBD.

    The model of a DBD is cisbp[DBD], or the generic cisbp[None] model when
    the DBD has none. For each DBD the result holds:
      "pid": {"dis": t, "hsim": t}  percent-identity thresholds for the
                                    dissimilar and highly-similar classes,
                                    taken from cisbp[None] when the DBD's
                                    model returns None
      "sr":  model.get_model("sr")  the similarity regression model
    Repeated DBDs keep their first entry.
    """
    thresholds = {}
    for dbd in DBDs:
        entry = thresholds.setdefault(dbd, {})
        model = cisbp[dbd] if dbd in cisbp else cisbp[None]
        for kind in ("dis", "hsim"):
            value = model.get_model_threshold("pid", kind)
            if value is None:
                value = cisbp[None].get_model_threshold("pid", kind)
            entry.setdefault("pid", {}).setdefault(kind, value)
        entry.setdefault("sr", model.get_model("sr"))
    return thresholds
# Translation table that deletes every lowercase ASCII letter
_INSERTION_DELETION_TABLE = str.maketrans("", "", string.ascii_lowercase)


def __remove_insertions(s):
    """
    Return s without its insertions, i.e. with every lowercase ASCII letter
    removed (uppercase residues, gaps and other characters are kept).
    """
    return s.translate(_INSERTION_DELETION_TABLE)
def __get_X(seq1, seq2, similarity="identity"):
    """
    Compare two lists of aligned DBD sequences position by position.

    The list with more DBDs is slid over the other one. For every offset the
    per-position scores (from __score, with the given similarity) are summed
    over the overlapping DBD pairs and divided by the number of DBDs in the
    shorter list. The per-position score array of the offset with the highest
    total is returned (the first such offset on ties).
    """
    longer, shorter = __reassign(seq1, seq2)
    n_short = len(shorter)

    windows = []
    for offset in range(len(longer) - n_short + 1):
        totals = [0] * len(longer[offset])
        for j, dbd in enumerate(shorter):
            partner = longer[offset + j]
            for k in range(len(partner)):
                totals[k] += __score(partner[k], dbd[k], similarity)
        # Rescale by the number of DBDs in the shorter list
        windows.append(np.array(totals) / n_short)

    return max(windows, key=sum)
def __reassign(seq1, seq2):
    """
    Return the pair ordered so that the first item is the longer one; on a
    tie the original order is kept.
    """
    return (seq2, seq1) if len(seq1) < len(seq2) else (seq1, seq2)
def __score(aa1, aa2, similarity="identity"):
    """
    Score a pair of aligned residues ("-" denotes a gap).

    :param similarity: "identity" -> 1 if aa1 == aa2 else 0;
        "blosum62" -> 1 for gap/gap, -4 for residue/gap, otherwise the
        BLOSUM62 score of the pair
    :raises ValueError: for any other similarity
    """
    if similarity == "identity":
        return(1 if aa1 == aa2 else 0)

    if similarity == "blosum62":
        if aa1 == "-" and aa2 == "-":
            return(1)
        if aa1 == "-" or aa2 == "-":
            return(-4)
        # blosum62 is looked up in both orders (presumably it stores each
        # unordered pair only once)
        if (aa1, aa2) in blosum62:
            return(blosum62[(aa1, aa2)])
        return(blosum62[(aa2, aa1)])

    # Falling through used to return None, which only failed later inside
    # the score arithmetic of the caller
    raise ValueError("unknown similarity: %r" % (similarity,))
#-------------#
# Main #
#-------------#
if __name__ == "__main__":
    # Run the command-line interface only when executed as a script
    main()