diff --git a/deepblast/alignment.py b/deepblast/alignment.py
index 25a222d..168cf5b 100644
--- a/deepblast/alignment.py
+++ b/deepblast/alignment.py
@@ -1,16 +1,21 @@
import torch
import torch.nn as nn
+import torch.nn.functional as F
from deepblast.language_model import BiLM, pretrained_language_models
from deepblast.nw_cuda import NeedlemanWunschDecoder as NWDecoderCUDA
-from deepblast.embedding import StackedRNN, EmbedLinear
+from deepblast.embedding import StackedRNN, EmbedLinear, MultiheadProduct
from deepblast.dataset.utils import unpack_sequences
-import torch.nn.functional as F
+import math
+
+
+def swish(x):
+ return x * F.sigmoid(x)
class NeedlemanWunschAligner(nn.Module):
def __init__(self, n_alpha, n_input, n_units, n_embed,
- n_layers=2, lm=None, device='gpu'):
+ n_layers=2, n_heads=16, lm=None, device='gpu', local=True):
""" NeedlemanWunsch Alignment model
Parameters
@@ -25,6 +30,8 @@ def __init__(self, n_alpha, n_input, n_units, n_embed,
Embedding dimension
n_layers : int
Number of RNN layers.
+ n_heads : int
+ Number of heads in multilinear layer.
lm : BiLM
Pretrained language model (optional)
padding_idx : int
@@ -32,6 +39,8 @@ def __init__(self, n_alpha, n_input, n_units, n_embed,
transform : function
Activation function (default relu)
sparse : False?
+ local : bool
+ Specifies if local alignment should be performed on the traceback
"""
super(NeedlemanWunschAligner, self).__init__()
if lm is None:
@@ -39,22 +48,30 @@ def __init__(self, n_alpha, n_input, n_units, n_embed,
self.lm = BiLM()
self.lm.load_state_dict(torch.load(path))
self.lm.eval()
+ transform = swish
if n_layers > 1:
self.match_embedding = StackedRNN(
- n_alpha, n_input, n_units, n_embed, n_layers, lm=lm)
+ n_alpha, n_input, n_units, n_embed, n_layers, lm=lm,
+ transform=swish, rnn_type='gru')
self.gap_embedding = StackedRNN(
- n_alpha, n_input, n_units, n_embed, n_layers, lm=lm)
+ n_alpha, n_input, n_units, n_embed, n_layers, lm=lm,
+ transform=swish, rnn_type='gru')
else:
self.match_embedding = EmbedLinear(
- n_alpha, n_input, n_embed, lm=lm)
+ n_alpha, n_input, n_embed, lm=lm,
+ transform=swish)
self.gap_embedding = EmbedLinear(
- n_alpha, n_input, n_embed, lm=lm)
+ n_alpha, n_input, n_embed, lm=lm,
+ transform=swish)
+ self.match_mixture = MultiheadProduct(n_embed, n_embed, n_heads)
+ self.gap_mixture = MultiheadProduct(n_embed, n_embed, n_heads)
# TODO: make cpu compatible version
# if device == 'cpu':
# self.nw = NWDecoderCPU(operator='softmax')
# else:
self.nw = NWDecoderCUDA(operator='softmax')
+ self.local = local
def forward(self, x, order):
""" Generate alignment matrix.
@@ -72,10 +89,13 @@ def forward(self, x, order):
with torch.enable_grad():
zx, _, zy, _ = unpack_sequences(self.match_embedding(x), order)
gx, _, gy, _ = unpack_sequences(self.gap_embedding(x), order)
-
# Obtain theta through an inner product across latent dimensions
- theta = F.softplus(torch.einsum('bid,bjd->bij', zx, zy))
- A = F.logsigmoid(torch.einsum('bid,bjd->bij', gx, gy))
+ theta = self.match_mixture(zx, zy)
+ #A = self.gap_mixture(gx, gy)
+ # zero out first row and first column for local alignments
+ G = self.gap_mixture(gx, gy)
+ A = torch.zeros(G.shape).to(G.device)
+ A[:, 1:, 1:] += G[:, 1:, 1:]
aln = self.nw.decode(theta, A)
return aln, theta, A
@@ -84,13 +104,26 @@ def traceback(self, x, order):
with torch.enable_grad():
zx, _, zy, _ = unpack_sequences(self.match_embedding(x), order)
gx, xlen, gy, ylen = unpack_sequences(self.gap_embedding(x), order)
- match = F.softplus(torch.einsum('bid,bjd->bij', zx, zy))
- gap = F.logsigmoid(torch.einsum('bid,bjd->bij', gx, gy))
+ match = self.match_mixture(zx, zy)
+ # gap = self.gap_mixture(gx, gy)
+ A = self.gap_mixture(gx, gy)
+ gap = torch.zeros(A.shape).to(A.device)
+ gap[:, 1:, 1:] += A[:, 1:, 1:]
+
+ # zero out first row and first column for local alignments
+ # L = gx.shape[1]
+ # gap = torch.zeros((L, L))
+ # gap[1:, 1:] = self.gap_mixture(gx[:, 1:, :], gy[:, 1:, :])
+
B, _, _ = match.shape
+
for b in range(B):
- aln = self.nw.decode(
- match[b, :xlen[b], :ylen[b]].unsqueeze(0),
- gap[b, :xlen[b], :ylen[b]].unsqueeze(0)
- )
+ M = match[b, :xlen[b], :ylen[b]].unsqueeze(0)
+ G = gap[b, :xlen[b], :ylen[b]].unsqueeze(0)
+ # val = math.log(1 - (1/50)) # based on average insertion length
+ # if self.local:
+ # G[0, 0, :] = val
+ # G[0, :, 0] = val
+ aln = self.nw.decode(M, G)
decoded = self.nw.traceback(aln.squeeze())
yield decoded, aln
diff --git a/deepblast/constants.py b/deepblast/constants.py
index 633a8c5..bf7c572 100644
--- a/deepblast/constants.py
+++ b/deepblast/constants.py
@@ -1 +1,7 @@
x, m, y = 0, 1, 2 # state numberings
+
+match_mean = 2
+match_std = 10
+
+gap_mean = -4
+gap_std = 10
diff --git a/deepblast/dataset/dataset.py b/deepblast/dataset/dataset.py
index 45a0e4b..4c6f1e9 100644
--- a/deepblast/dataset/dataset.py
+++ b/deepblast/dataset/dataset.py
@@ -7,11 +7,21 @@
from deepblast.constants import m
from deepblast.dataset.utils import (
state_f, tmstate_f,
- clip_boundaries, states2matrix, states2edges,
- path_distance_matrix
+ remove_gaps, states2matrix, states2edges,
+ path_distance_matrix, gap_mask
)
+def reshape(x, N, M):
+ if x.shape != (N, M) and x.shape != (M, N):
+ raise ValueError(f'The shape of `x` {x.shape} '
+ f'does not agree with ({N}, {M})')
+ if tuple(x.shape) != (N, M):
+ return x.t()
+ else:
+ return x
+
+
class AlignmentDataset(Dataset):
def __init__(self, pairs, tokenizer=UniprotTokenizer()):
self.tokenizer = tokenizer
@@ -38,13 +48,19 @@ def __iter__(self):
yield self.__getitem__(i)
+class FastaDataset(AlignmentDataset):
+ """ Dataset for searching. """
+ def __init__(self, query_path, db_path, tokenizer=UniprotTokenizer()):
+ pass
+
+
class TMAlignDataset(AlignmentDataset):
""" Dataset for training and testing.
This is appropriate for the Malisam / Malidup datasets.
"""
def __init__(self, path, tokenizer=UniprotTokenizer(),
- tm_threshold=0.4, max_len=1024, pad_ends=False,
+ tm_threshold=0.5, max_len=1024, pad_ends=False,
clip_ends=True, construct_paths=False):
""" Read in pairs of proteins.
@@ -126,12 +142,11 @@ def __getitem__(self, i):
pos = self.pairs.iloc[i]['chain2']
states = self.pairs.iloc[i]['alignment']
states = list(map(tmstate_f, states))
- if self.clip_ends:
- gene, pos, states = clip_boundaries(gene, pos, states)
+ gene, pos, states = remove_gaps(gene, pos, states, self.clip_ends)
+ gene_mask, pos_mask = gap_mask(states)
if self.pad_ends:
states = [m] + states + [m]
-
states = torch.Tensor(states).long()
gene = self.tokenizer(str.encode(gene))
pos = self.tokenizer(str.encode(pos))
@@ -148,7 +163,11 @@ def __getitem__(self, i):
path_matrix = path_matrix.t()
if tuple(alignment_matrix.shape) != (len(gene), len(pos)):
alignment_matrix = alignment_matrix.t()
- return gene, pos, states, alignment_matrix, path_matrix
+
+ # gene_mask = torch.Tensor(gene_mask).long()
+ # pos_mask = torch.Tensor(pos_mask).long()
+ return (gene, pos, states, alignment_matrix, path_matrix,
+ gene_mask, pos_mask)
class MaliAlignmentDataset(AlignmentDataset):
diff --git a/deepblast/dataset/tests/test_dataset.py b/deepblast/dataset/tests/test_dataset.py
index 5270189..d314fa9 100644
--- a/deepblast/dataset/tests/test_dataset.py
+++ b/deepblast/dataset/tests/test_dataset.py
@@ -18,14 +18,13 @@ def test_getitem(self):
x = TMAlignDataset(self.data_path, tm_threshold=0,
pad_ends=False, clip_ends=False)
res = x[0]
- self.assertEqual(len(res), 5)
- gene, pos, states, alignment_matrix, _ = res
+ self.assertEqual(len(res), 7)
+ gene, pos, states, alignment_matrix, _, _, _ = res
# test the lengths
self.assertEqual(len(gene), 103)
self.assertEqual(len(pos), 21)
self.assertEqual(len(states), 103)
- # wtf is going on here??
- self.assertEqual(alignment_matrix.shape, (22, 103))
+ self.assertEqual(alignment_matrix.shape, (103, 21))
class TestMaliDataset(unittest.TestCase):
@@ -47,7 +46,7 @@ def test_getitem(self):
self.assertEqual(len(gene), 81)
self.assertEqual(len(pos), 81)
self.assertEqual(len(states), 100)
- self.assertEqual(alignment_matrix.shape, (81, 82))
+ self.assertEqual(alignment_matrix.shape, (81, 81))
if __name__ == '__main__':
diff --git a/deepblast/dataset/tests/test_utils.py b/deepblast/dataset/tests/test_utils.py
index 7c0db9b..b1139cf 100644
--- a/deepblast/dataset/tests/test_utils.py
+++ b/deepblast/dataset/tests/test_utils.py
@@ -1,8 +1,10 @@
import unittest
from deepblast.dataset.utils import (
tmstate_f, states2matrix, states2alignment,
- path_distance_matrix, clip_boundaries,
- pack_sequences, unpack_sequences)
+ path_distance_matrix, remove_gaps,
+ pack_sequences, unpack_sequences,
+ gap_mask, merge_mask,
+ remove_orphans)
from math import sqrt
import numpy as np
import numpy.testing as npt
@@ -30,6 +32,11 @@ def test_states2matrix_zinc(self):
s = np.array(list(map(tmstate_f, s)))
states2matrix(s, sparse=True)
+ def test_states2matrix_insert(self):
+ # Test how this is constructed if there are
+ # gaps in the beginning of the alignment
+ pass
+
def test_states2matrix_only_matches(self):
s = ":11::11:"
s = np.array(list(map(tmstate_f, s)))
@@ -223,7 +230,7 @@ def test_clip_ends_none(self):
s_ = [m, m, m, m]
x_ = 'GSSG'
y_ = 'GEIR'
- rx, ry, rs = clip_boundaries(x_, y_, s_)
+ rx, ry, rs = remove_gaps(x_, y_, s_)
self.assertEqual(x_, rx)
self.assertEqual(y_, ry)
self.assertEqual(s_, rs)
@@ -233,7 +240,7 @@ def test_clip_ends(self):
s = [x, m, m, m, y]
x = 'GSSG'
y = 'GEIR'
- rx, ry, rs = clip_boundaries(x, y, s)
+ rx, ry, rs = remove_gaps(x, y, s)
ex, ey, es = 'SSG', 'GEI', [m, m, m]
self.assertEqual(ex, rx)
self.assertEqual(ey, ry)
@@ -245,7 +252,7 @@ def test_clip_ends_2(self):
st = np.array([1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 1, 1,
1, 2, 2, 2, 2, 1, 1, 1, 0, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 2, 1])
- rx, ry, rs = clip_boundaries(gen, oth, st)
+ rx, ry, rs = remove_gaps(gen, oth, st)
self.assertTrue(1)
def test_pack_sequences(self):
@@ -273,5 +280,116 @@ def test_unpack_sequences(self):
tt.assert_allclose(expY, resY)
+class TestPreprocess(unittest.TestCase):
+
+ def test_gap_mask(self):
+ s = ":11::22:"
+ res = gap_mask(s)
+ exp_x = np.array([3, 4])
+ exp_y = np.array([1, 2])
+
+ npt.assert_equal(res[0], exp_x)
+ npt.assert_equal(res[1], exp_y)
+
+ s = ":11:.:22:"
+ res = gap_mask(s)
+ exp_x = np.array([2, 4, 5])
+ exp_y = np.array([1, 2, 4])
+ npt.assert_equal(res[0], exp_x)
+ npt.assert_equal(res[1], exp_y)
+
+ def test_gap_mask2(self):
+ s = (
+ '222222222222222222.11112222222222222222222222222'
+ '222222222222222222222222222222222222222222222222'
+ '22222222...::::::..:2:22::2:::::::..11.111...::.'
+ '::::::::::.::::......:::::::::::222:.::::::::.11'
+ '.:::::::::.:22.::::::::::::2:::::::::::::::1::..'
+ '.::::::::::::::::::::::22:2:2::::::::::1::::::::'
+ '::::22222::::::::::1::::::.'
+ )
+ # N, M = 197, 283
+ gap_mask(s)
+
+ def test_gap_mask3(self):
+ seq = ('TSKINKELITTANDKKYTIATVVKVDGIAWFDRRDGVDQFKADTGNDVWVGPSQA'
+ 'DAAAQVQIVENLIAQGVDAIAIVPFSVEAVEPVLKKARERGIVVISHEASNIQNV'
+ 'DYDIEAFDNKAYGANLKELGKSGGKGKYVTTVGSLTSKSQEWIDGAVEYQKANFP'
+ 'ESEATGRLETYDDANTDYNKLKEATAYPDITGILGAPPTSAGAGRLIAEGGLKGK'
+ 'VFFAGTGLVSVAGEYIKNDDVQYIQFWDPAVAGYANLAVAALEKKNDQIKAGLNL'
+ 'GLPGYESLLAPDAAKPNLLYGAGWVGVTKEND')
+ st = ('222222222222222:::::::::::2::::::1:::::::::::1::::::::.:'
+ ':::::::::::::::::::::::::::::::::::::::::::::::::::1.:::'
+ ':::::::2:::::::::1:::::::1:::::::::::::::::1::::::::::::'
+ ':::::::::2:::::::::::::::::1:::::::::::::1::::::::::::::'
+ ':::::::::::1:::11:::::::::::::::::::2::.:::1:::::::::22:'
+ '::222222222222222222222::::::::::::11::11:11.11111111111'
+ '1111111')
+ xmask, ymask = gap_mask(st)
+ xidx = merge_mask(xmask, len(seq), len(seq))
+ yidx = merge_mask(ymask, len(seq), len(seq))
+
+
+ def test_gap_mask4(self):
+
+ st1 = ('2222222222222222222222222222222222222222222222222222222222222.'
+ '22222222222222222222222222222222222222222222222222222222222222'
+ '22222222222222222222222222222222222222222222222222222222222222'
+ '2::.:.22222222222222222222222222222..2..:.::::::::::::11...::.'
+ '..:::.111111111..:..::::::1:.11.111.::::::.:::::::::::::::::2:'
+ ':..11111111111111122222222222222222222...:::::::::::::::::::2:'
+ '::..:::::::::::::::::::::::::2.:.:::::::::::::::2:::::::::::::'
+ '1::::::::::1:......:::::::::::::::1111.1.11:11:11111111.111111'
+ '1111111111111111111111111111.:1:::::::.2222.22.:...:..::..::.:'
+ '::.::::::.11.....::::::::222.22222222222222222222222222 ')
+ st1 = list(map(tmstate_f, st1))
+ L = 205
+ xmask, ymask = gap_mask(st1)
+ xidx = merge_mask(xmask, L, L)
+ yidx = merge_mask(ymask, L, L)
+ self.assertGreater(len(xidx), 0)
+ self.assertGreater(len(yidx), 0)
+
+ seq = ('PKYQIIDAAVEVIAENGYHQSQVSKIAKQAGVADGTIYLYFKNKEDILISLFKEKGQFI'
+ 'EREEDIKEKATAKEKLALVISKHFSLLAGDHNLAIVTQLELRQSNLELRQKINEILKGY'
+ 'LNILDGILTEGIQSGEIKEGLDVRLARQIFGTIDETVTTWVNDQKYDLVALSNSVLELL'
+ 'VSGIHNK')
+ states = [2, 2, 2, 2, 2, 1, 1, 2, 2, 2, 2, 2, 2, 1, 1, 2, 1, 1, 1, 1,
+ 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
+ 1, 1, 1, 1, 2, 2, 2, 1, 1, 1, 1, 1, 1, 0, 0, 0, 1, 1, 1, 1,
+ 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 1, 1, 2, 1, 1, 1,
+ 2, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
+ 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1,
+ 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
+ 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
+ 1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
+ 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 1, 2, 1,
+ 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 2, 2, 1]
+ xmask, ymask = gap_mask(states)
+
+ self.assertLess(max(xmask), len(seq))
+
+ def test_replace_orphans_small(self):
+ s = ":11:11:"
+ e = ":111211:"
+ r = remove_orphans(s, threshold=3)
+ self.assertEqual(r, e)
+
+ def test_replace_orphans(self):
+ s = ":1111111111:11111111111111:"
+ e = ":11111111111211111111111111:"
+ r = remove_orphans(s, threshold=9)
+ self.assertEqual(r, e)
+
+ s = ":2222222222:22222222222222:"
+ e = ":22222222221222222222222222:"
+ r = remove_orphans(s, threshold=9)
+ self.assertEqual(r, e)
+
+ s = ":1111111111:22222222222222:"
+ r = remove_orphans(s, threshold=9)
+ self.assertEqual(r, s)
+
+
if __name__ == '__main__':
unittest.main()
diff --git a/deepblast/dataset/utils.py b/deepblast/dataset/utils.py
index 14a35f3..7cac1c5 100644
--- a/deepblast/dataset/utils.py
+++ b/deepblast/dataset/utils.py
@@ -4,6 +4,8 @@
from scipy.sparse import coo_matrix
from scipy.spatial import cKDTree
from deepblast.constants import x, m, y
+from itertools import islice
+from functools import reduce
def state_f(z):
@@ -25,17 +27,29 @@ def tmstate_f(z):
return m
-def clip_boundaries(X, Y, A):
+def revstate_f(z):
+ if z == x:
+ return '1'
+ if z == y:
+ return '2'
+ if z == m:
+ return ':'
+
+
+def remove_gaps(X, Y, A, clip_ends=True):
""" Remove xs and ys from ends. """
- if A[0] == m:
- first = 0
- else:
- first = A.index(m)
+ first = 0
+ last = len(A)
+ if clip_ends:
+ if A[0] == m:
+ first = 0
+ else:
+ first = A.index(m)
- if A[-1] == m:
- last = len(A)
- else:
- last = len(A) - A[::-1].index(m)
+ if A[-1] == m:
+ last = len(A)
+ else:
+ last = len(A) - A[::-1].index(m)
X, Y = states2alignment(np.array(A), X, Y)
X_ = X[first:last].replace('-', '')
Y_ = Y[first:last].replace('-', '')
@@ -98,8 +112,17 @@ def states2edges(states):
prev_s, next_s = states[:-1], states[1:]
transitions = list(zip(prev_s, next_s))
state_diffs = np.array(list(map(state_diff_f, transitions)))
- coords = np.cumsum(state_diffs, axis=0).tolist()
- coords = [(0, 0)] + list(map(tuple, coords))
+ coords = np.cumsum(state_diffs, axis=0)
+ if states[0] == m:
+ coords = [(0, 0)] + list(map(tuple, coords.tolist()))
+ elif states[0] == y:
+ coords[:, 0] = np.maximum(coords[:, 0] - 1, 0)
+ coords = [(0, 0)] + list(map(tuple, coords.tolist()))
+ elif states[0] == x:
+ coords[:, 1] = np.maximum(coords[:, 1] - 1, 0)
+ coords = [(0, 0)] + list(map(tuple, coords.tolist()))
+ else:
+ raise ValueError(f'Unrecognized state {states[0]}')
return coords
@@ -138,7 +161,6 @@ def states2alignment(states: np.array, X: str, Y: str):
f'The state string length {sy} does not match '
f'the length of sequence {len(X)}.\n'
f'SequenceX: {X}\nSequenceY: {Y}\nStates: {states}\n'
-
)
i, j = 0, 0
@@ -236,16 +258,30 @@ def collate_f(batch):
states = [x[2] for x in batch]
alignments = [x[3] for x in batch]
paths = [x[4] for x in batch]
- max_x = max(map(len, genes))
- max_y = max(map(len, others))
+ g_mask = [x[5] for x in batch]
+ p_mask = [x[6] for x in batch]
+
+ x_len = list(map(len, genes))
+ y_len = list(map(len, others))
+ max_x = max(x_len)
+ max_y = max(y_len)
+ max_l = max(max_x, max_y)
+ x_mask = []
+ y_mask = []
B = len(genes)
- dm = torch.zeros((B, max_x, max_y))
- p = torch.zeros((B, max_x, max_y))
+ dm = torch.zeros((B, max_l, max_l))
+ p = torch.zeros((B, max_l, max_l))
for b in range(B):
n, m = len(genes[b]), len(others[b])
dm[b, :n, :m] = alignments[b]
p[b, :n, :m] = paths[b]
- return genes, others, states, dm, p
+ gm = merge_mask(g_mask[b], n, max_l)
+ pm = merge_mask(p_mask[b], m, max_l)
+ assert len(gm) > 0, (len(g_mask[b]), max(g_mask[b]), n, max_l)
+ assert len(pm) > 0, (len(p_mask[b]), max(p_mask[b]), m, max_l)
+ x_mask.append(gm)
+ y_mask.append(pm)
+ return genes, others, states, dm, p, (x_mask, y_mask)
def path_distance_matrix(pi):
@@ -273,3 +309,93 @@ def path_distance_matrix(pi):
d, i = model.query(coords)
Pdist = np.array(coo_matrix((d, (coords[:, 0], coords[:, 1]))).todense())
return Pdist
+
+
+def merge_mask(idx, length, max_len):
+ pads = set(list(range(length, max_len)))
+ idx = set(idx.tolist()) | pads
+ allx = set(list(range(0, max_len)))
+ idx = torch.Tensor(list(allx - idx)).long()
+ return idx
+
+
+# Preprocessing functions
+def gap_mask(states):
+ """ Builds a mask for all gaps.
+
+ Reports rows and columns that should be completely masked.
+
+ Parameters
+ ----------
+ states : str
+ List of alignment states
+
+ Returns
+ -------
+ mask : np.array
+ Masked array.
+ """
+ i, j = 0, 0
+ rows, cols = [], []
+ for k in range(len(states)):
+ if states[k] == x:
+ cols.append(i)
+ i += 1
+ elif states[k] == y:
+ rows.append(j)
+ j += 1
+ elif states[k] == m:
+ i += 1
+ j += 1
+ else:
+ raise ValueError(f'{states[k]} is not recognized')
+ return np.array(cols), np.array(rows)
+
+
+def window(seq, n=2):
+ "Returns a sliding window (of width n) over data from the iterable"
+ " s -> (s0,s1,...s[n-1]), (s1,s2,...,sn), ... "
+ it = iter(seq)
+ result = tuple(islice(it, n))
+ if len(result) == n:
+ yield result
+ for elem in it:
+ result = result[1:] + (elem,)
+ yield result
+
+
+def replace_orphan(w, s=5):
+ i = len(w) // 2
+ # identify orphans and replace with gaps
+ sw = ''.join(w)
+ if (w[i] == ':') and ((('1' * s) in sw[:i] and ('1' * s) in sw[i:]) or
+ (('2' * s) in sw[:i] and ('2' * s) in sw[i:])):
+ return ['1', '2']
+ else:
+ return [w[i]]
+
+
+def remove_orphans(states, threshold: int = 11):
+ """ Removes singletons and doubletons that are orphaned.
+ A match is considered orphaned if it exceeds the `threshold` gap.
+ Parameters
+ ----------
+ states : np.array
+ List of alignment states
+ threshold : int
+ Number of consecutive gaps surrounding a matched required for it
+ to be considered an orphan.
+ Returns
+ -------
+ new_states : np.array
+ States string with orphans removed.
+ Notes
+ -----
+ The threshold *must* be an odd number. This determines the window size.
+ """
+ wins = list(window(states, threshold))
+ rwins = list(map(lambda x: replace_orphan(x, threshold // 2), list(wins)))
+ new_states = list(reduce(lambda x, y: x + y, rwins))
+ new_states = list(states[:threshold // 2]) + new_states
+ new_states += list(states[-threshold // 2 + 1:])
+ return ''.join(new_states)
diff --git a/deepblast/embedding.py b/deepblast/embedding.py
index 510e572..909573f 100644
--- a/deepblast/embedding.py
+++ b/deepblast/embedding.py
@@ -1,7 +1,48 @@
+import torch
import torch.nn as nn
from torch.nn.utils.rnn import PackedSequence
+def init_weights(m):
+ # https://stackoverflow.com/a/49433937/1167475
+ if type(m) == nn.Linear:
+ nn.init.xavier_uniform(m.weight)
+ m.bias.data.fill_(0.01)
+
+
+class MultiLinear(nn.Module):
+ """ Multiple linear layers concatenated together"""
+ def __init__(self, n_input, n_output, n_heads=16):
+ super(MultiLinear, self).__init__()
+ self.multi_output = nn.ModuleList(
+ [
+ nn.Linear(n_input, n_output)
+ for i in range(n_heads)
+ ]
+ )
+ self.multi_output.apply(init_weights)
+
+ def forward(self, x):
+ outputs = torch.stack(
+ [head(x) for head in self.multi_output], dim=-1)
+ return outputs
+
+
+class MultiheadProduct(nn.Module):
+ def __init__(self, n_input, n_output, n_heads=16):
+ super(MultiheadProduct, self).__init__()
+ self.multilinear = MultiLinear(n_input, n_output, n_heads)
+ self.linear = nn.Linear(n_heads, 1)
+ nn.init.xavier_uniform(self.linear.weight)
+
+ def forward(self, x, y):
+ zx = self.multilinear(x)
+ zy = self.multilinear(y)
+ dists = torch.einsum('bidh,bjdh->bijh', zx, zy)
+ output = self.linear(dists)
+ return output.squeeze()
+
+
class LMEmbed(nn.Module):
def __init__(self, nin, nout, lm, padding_idx=-1, transform=nn.ReLU(),
sparse=False):
@@ -41,7 +82,7 @@ def forward(self, x):
class EmbedLinear(nn.Module):
def __init__(self, nin, nhidden, nout, padding_idx=-1,
- sparse=False, lm=None):
+ sparse=False, lm=None, transform=nn.ReLU()):
super(EmbedLinear, self).__init__()
if padding_idx == -1:
@@ -49,7 +90,8 @@ def __init__(self, nin, nhidden, nout, padding_idx=-1,
if lm is not None:
self.embed = LMEmbed(
- nin, nhidden, lm, padding_idx=padding_idx, sparse=sparse)
+ nin, nhidden, lm, padding_idx=padding_idx, sparse=sparse,
+ transform=transform)
self.proj = nn.Linear(self.embed.nout, nout)
self.lm = True
else:
@@ -58,6 +100,7 @@ def __init__(self, nin, nhidden, nout, padding_idx=-1,
self.proj = nn.Linear(nout, nout)
self.lm = False
+ init_weights(self.proj)
self.nout = nout
def forward(self, x):
@@ -85,7 +128,7 @@ def forward(self, x):
class StackedRNN(nn.Module):
def __init__(self, nin, nembed, nunits, nout, nlayers=2,
padding_idx=-1, dropout=0, rnn_type='lstm',
- sparse=False, lm=None):
+ sparse=False, lm=None, transform=nn.ReLU()):
super(StackedRNN, self).__init__()
if padding_idx == -1:
@@ -93,7 +136,8 @@ def __init__(self, nin, nembed, nunits, nout, nlayers=2,
if lm is not None:
self.embed = LMEmbed(
- nin, nembed, lm, padding_idx=padding_idx, sparse=sparse)
+ nin, nembed, lm, padding_idx=padding_idx, sparse=sparse,
+ transform=transform)
nembed = self.embed.nout
self.lm = True
else:
diff --git a/deepblast/losses.py b/deepblast/losses.py
index 7d5b509..4cb727a 100644
--- a/deepblast/losses.py
+++ b/deepblast/losses.py
@@ -1,4 +1,10 @@
import torch
+from torch.distributions import Normal
+from deepblast.constants import match_mean, match_std, gap_mean, gap_std
+
+
+def mask_tensor(A, x_mask, y_mask):
+ return A[x_mask][:, y_mask]
class AlignmentAccuracy:
@@ -6,8 +12,52 @@ def __call__(self, true_edges, pred_edges):
pass
+class L2MatrixCrossEntropy:
+ def __call__(self, Ytrue, Ypred, M, G, x_mask, y_mask):
+ """ Computes binary cross entropy on the matrix with regularizers.
+
+ The matrix cross entropy loss is given by
+
+ d(ypred, ytrue) = - (mean(ytrue x log(ypred))
+ + mean((1 - ytrue) x log(1 - ypred)))
+
+ Parameters
+ ----------
+ Ytrue : torch.Tensor
+ Ground truth alignment matrix of dimension N x M.
+ All entries are marked by 0 and 1.
+ Ypred : torch.Tensor
+ Predicted alignment matrix of dimension N x M.
+ M : torch.Tensor
+ Match score matrix
+ G : torch.Tensor
+ Gap score matrix
+ """
+ score = 0
+ eps = 3e-8 # unfortunately, this is the smallest eps we can have :(
+ Ypred = torch.clamp(Ypred, min=eps, max=1 - eps)
+ for b in range(len(x_mask)):
+ pos = torch.mean(
+ mask_tensor(Ytrue[b], x_mask[b], y_mask[b]) * torch.log(
+ mask_tensor(Ypred[b], x_mask[b], y_mask[b]))
+ )
+ neg = torch.mean(
+ (1 - mask_tensor(Ytrue[b], x_mask[b], y_mask[b])) * torch.log(
+ 1 - mask_tensor(Ypred[b], x_mask[b], y_mask[b]))
+ )
+ score += -(pos + neg)
+
+ match_prior = Normal(match_mean, match_std)
+ gap_prior = Normal(gap_mean, gap_std)
+ log_like = score / len(x_mask)
+ match_log = match_prior.log_prob(M).mean()
+ gap_log = gap_prior.log_prob(G).mean()
+ score = log_like - match_log - gap_log
+ return score
+
+
class MatrixCrossEntropy:
- def __call__(self, Ytrue, Ypred, x_len, y_len):
+ def __call__(self, Ytrue, Ypred, x_mask, y_mask):
""" Computes binary cross entropy on the matrix
The matrix cross entropy loss is given by
@@ -26,21 +76,29 @@ def __call__(self, Ytrue, Ypred, x_len, y_len):
score = 0
eps = 3e-8 # unfortunately, this is the smallest eps we can have :(
Ypred = torch.clamp(Ypred, min=eps, max=1 - eps)
- for b in range(len(x_len)):
+ for b in range(len(x_mask)):
pos = torch.mean(
- Ytrue[b, :x_len[b], :y_len[b]] * torch.log(
- Ypred[b, :x_len[b], :y_len[b]])
+ mask_tensor(Ytrue[b], x_mask[b], y_mask[b]) * torch.log(
+ mask_tensor(Ypred[b], x_mask[b], y_mask[b]))
)
neg = torch.mean(
- (1 - Ytrue[b, :x_len[b], :y_len[b]]) * torch.log(
- 1 - Ypred[b, :x_len[b], :y_len[b]])
+ (1 - mask_tensor(Ytrue[b], x_mask[b], y_mask[b])) * torch.log(
+ 1 - mask_tensor(Ypred[b], x_mask[b], y_mask[b]))
)
+ # pos = torch.mean(
+ # Ytrue[b, x_mask[b], y_mask[b]] * torch.log(
+ # Ypred[b, x_mask[b], y_mask[b]])
+ # )
+ # neg = torch.mean(
+ # (1 - Ytrue[b, x_mask[b], y_mask[b]]) * torch.log(
+ # 1 - Ypred[b, x_mask[b], y_mask[b]])
+ # f)
score += -(pos + neg)
- return score / len(x_len)
+ return score / len(x_mask)
class SoftPathLoss:
- def __call__(self, Pdist, Ypred, x_len, y_len):
+ def __call__(self, Pdist, Ypred, x_mask, y_mask):
""" Computes a soft path loss
The soft path loss is given by
@@ -60,15 +118,15 @@ def __call__(self, Pdist, Ypred, x_len, y_len):
Predicted alignment matrix of dimension N x M.
"""
score = 0
- for b in range(len(x_len)):
+ for b in range(len(x_mask)):
score += torch.norm(
- Pdist[b, :x_len[b], :y_len[b]] * Ypred[b, :x_len[b], :y_len[b]]
+ Pdist[b, x_mask[b], y_mask[b]] * Ypred[b, x_mask[b], y_mask[b]]
)
- return score / len(x_len)
+ return score / len(x_mask)
class SoftAlignmentLoss:
- def __call__(self, Ytrue, Ypred, x_len, y_len):
+ def __call__(self, Ytrue, Ypred, x_mask, y_mask):
""" Computes soft alignment loss as proposed in Mensch et al.
The soft alignment loss is given by
@@ -96,8 +154,8 @@ def __call__(self, Ytrue, Ypred, x_len, y_len):
since it is possible to leave out important parts of the alignment.
"""
score = 0
- for b in range(len(x_len)):
+ for b in range(len(x_mask)):
score += torch.norm(
- Ytrue[b, :x_len[b], :y_len[b]] - Ypred[b, :x_len[b], :y_len[b]]
+ Ytrue[b, x_mask[b], y_mask[b]] - Ypred[b, x_mask[b], y_mask[b]]
)
- return score / len(x_len)
+ return score / len(x_mask)
diff --git a/deepblast/score.py b/deepblast/score.py
index 2397ec1..81b920d 100644
--- a/deepblast/score.py
+++ b/deepblast/score.py
@@ -1,6 +1,6 @@
import numpy as np
import matplotlib.pyplot as plt
-from deepblast.dataset.utils import states2alignment
+from deepblast.dataset.utils import states2alignment, states2edges, tmstate_f
def roc_edges(true_edges, pred_edges):
@@ -16,6 +16,25 @@ def roc_edges(true_edges, pred_edges):
return tp, fp, fn, perc_id, ppv, fnr, fdr
+def alignment_score(true_states: str, pred_states: str):
+ """
+ Computes ROC statistics on alignment
+
+ Parameters
+ ----------
+ true_states : str
+ Ground truth state string
+ pred_states : str
+ Predicted state string
+ """
+ pred_states = list(map(tmstate_f, pred_states))
+ true_states = list(map(tmstate_f, true_states))
+ pred_edges = states2edges(pred_states)
+ true_edges = states2edges(true_states)
+ stats = roc_edges(true_edges, pred_edges)
+ return stats
+
+
def alignment_visualization(truth, pred, match, gap, xlen, ylen):
""" Visualize alignment matrix
diff --git a/deepblast/tests/test_alignment.py b/deepblast/tests/test_alignment.py
index f713ab7..a539768 100644
--- a/deepblast/tests/test_alignment.py
+++ b/deepblast/tests/test_alignment.py
@@ -18,20 +18,20 @@ def setUp(self):
nalpha, ninput, nunits, nembed = 22, 1024, 1024, 1024
self.aligner = NeedlemanWunschAligner(nalpha, ninput, nunits, nembed)
- @unittest.skip
+ @unittest.skipUnless(torch.cuda.is_available(), 'No GPU was detected')
def test_alignment(self):
self.embedding = self.embedding.cuda()
self.aligner = self.aligner.cuda()
x = torch.Tensor(
self.tokenizer(b'ARNDCQEGHILKMFPSTWYVXOUBZ')
- ).unsqueeze(0).long().cuda()
+ ).long().cuda()
y = torch.Tensor(
self.tokenizer(b'ARNDCQEGHILKARNDCQMFPSTWYVXOUBZ')
- ).unsqueeze(0).long().cuda()
- N, M = x.shape[1], y.shape[1]
- seq, order = pack_sequences([x], [y])
+ ).long().cuda()
+ M = max(x.shape[0], y.shape[0])
+ seq, order = pack_sequences([x, x], [y, y])
aln, theta, A = self.aligner(seq, order)
- self.assertEqual(aln.shape, (1, N, M))
+ self.assertEqual(aln.shape, (2, M, M))
@unittest.skipUnless(torch.cuda.is_available(), "No GPU detected")
def test_batch_alignment(self):
@@ -64,8 +64,10 @@ def test_collate_alignment(self):
A2 = torch.ones((len(x2), len(y2))).long()
P1 = torch.ones((len(x1), len(y1))).long()
P2 = torch.ones((len(x2), len(y2))).long()
- batch = [(x1, y1, s1, A1, P1), (x2, y2, s2, A2, P2)]
- gene_codes, other_codes, states, dm, p = collate_f(batch)
+ mask = [torch.Tensor([0, 1]), torch.Tensor([0, 1])]
+ batch = [(x1, y1, s1, A1, P1, mask[0], mask[1]),
+ (x2, y2, s2, A2, P2, mask[0], mask[1])]
+ gene_codes, other_codes, states, dm, p, mask = collate_f(batch)
self.embedding = self.embedding.cuda()
self.aligner = self.aligner.cuda()
seq, order = pack_sequences(gene_codes, other_codes)
diff --git a/deepblast/tests/test_embedding.py b/deepblast/tests/test_embedding.py
new file mode 100644
index 0000000..371db34
--- /dev/null
+++ b/deepblast/tests/test_embedding.py
@@ -0,0 +1,28 @@
+import torch
+from deepblast.embedding import MultiLinear, MultiheadProduct
+import unittest
+
+
+class TestEmbedding(unittest.TestCase):
+ def setUp(self):
+ b, L, d, h = 3, 100, 50, 8
+ self.x = torch.randn(b, L, d)
+ self.y = torch.randn(b, L, d)
+ self.b = b
+ self.L = L
+ self.d = d
+ self.h = h
+
+ def test_multilinear(self):
+ model = MultiLinear(self.d, self.d, self.h)
+ res = model(self.x)
+ self.assertEqual(tuple(res.shape), (self.b, self.L, self.d, self.h))
+
+ def test_multihead_product(self):
+ model = MultiheadProduct(self.d, self.d, self.h)
+ res = model(self.x, self.y)
+ self.assertEqual(tuple(res.shape), (self.b, self.L, self.L))
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/deepblast/tests/test_score.py b/deepblast/tests/test_score.py
new file mode 100644
index 0000000..e1c8a25
--- /dev/null
+++ b/deepblast/tests/test_score.py
@@ -0,0 +1,90 @@
+from deepblast.score import roc_edges, alignment_text
+from deepblast.dataset.utils import states2edges, tmstate_f
+import pandas as pd
+import numpy as np
+import unittest
+
+
+class TestScore(unittest.TestCase):
+
+ def setUp(self):
+ pass
+
+ def test_alignment_text(self):
+ gene = 'YACSGGCGQNFRTMSEFNEHMIRLVH'
+ other = 'LICPKHTRDCGKVFKRNSSLRVHEKTH'
+ pred = np.array([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
+ 1, 2, 1, 1, 1, 2, 0, 1, 1, 1, 1, 1])
+ truth = np.array([1, 1, 1, 1, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1,
+ 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1])
+ stats = np.array([1, 1, 1, 1, 1, 1, 1])
+ alignment_text(gene, other, pred, truth, stats)
+
+ def test_roc_edges(self):
+ cols = ['tp', 'fp', 'fn', 'perc_id', 'ppv', 'fnr', 'fdr']
+
+ exp_alignment = (
+ 'FRCPRPAGCE--KLYSTSSHVNKHLLL',
+ 'YDCE---ICQSFKDFSPYMKLRKHRAT',
+ '::::111:::22:::::::::::::::'
+ )
+ res_alignment = (
+ 'FRCPRPAGCEKLYSTSSHVNKHLL',
+ 'YDCEICQSFKDFSPYMKLRKHRAT',
+ '::::::::::::::::::::::::'
+ )
+ # TODO: there are still parts of the alignment
+ # that are being clipped erronously
+ exp_edges = states2edges(
+ list(map(tmstate_f, exp_alignment[2])))
+ res_edges = states2edges(
+ list(map(tmstate_f, res_alignment[2])))
+
+ res = pd.Series(roc_edges(exp_edges, res_edges), index=cols)
+
+ self.assertGreater(res.perc_id, 0.1)
+
+ def test_roc_edges_2(self):
+ cols = ['tp', 'fp', 'fn', 'perc_id', 'ppv', 'fnr', 'fdr']
+ exp_alignment = (
+ 'SVHTLLDEKHETLDSEWEKLVRDAMTSGVSKKQFREFLDYQKWRKSQ',
+ ':1111111111111111111111111111::::::::::::::::::',
+ 'I----------------------------FTYGELQRMQEKERNKGQ'
+ )
+ res_alignment = (
+ 'SVHTLLDEKHETLDSEWEKLVRDAMTSGVSKKQFREFLDYQKWRKSQ',
+ '1:1111111:111111111111111111:11::::::::::::::::',
+ '-I-------F------------------T--YGELQRMQEKERNKGQ'
+ )
+
+ exp_edges = states2edges(
+ list(map(tmstate_f, exp_alignment[2])))
+ res_edges = states2edges(
+ list(map(tmstate_f, res_alignment[2])))
+ res = pd.Series(roc_edges(exp_edges, res_edges), index=cols)
+ self.assertGreater(res.tp, 20)
+ self.assertGreater(res.perc_id, 0.5)
+
+ def test_roc_edges_3(self):
+ cols = ['tp', 'fp', 'fn', 'perc_id', 'ppv', 'fnr', 'fdr']
+ exp_alignment = (
+ 'F--GD--D--------QN-PYTESVDILEDLVIEFITEMTHKAMSI',
+ 'ISHLVIMHEEGEVDGKAIPDLTAPVSAVQAAVSNLVRVGKETVQTT',
+ ':22::22:22222222::2:::::::::::::::::::::::::::'
+ )
+ res_alignment = (
+ '-FG---D------D--QN-PYTESVDILEDLVIEFITEMTHKAMSI',
+ 'ISHLVIMHEEGEVDGKAIPDLTAPVSAVQAAVSNLVRVGKETVQTT',
+ '2::222:222222:22::2:::::::::::::::::::::::::::'
+ )
+ exp_edges = states2edges(
+ list(map(tmstate_f, exp_alignment[2])))
+ res_edges = states2edges(
+ list(map(tmstate_f, res_alignment[2])))
+ res = pd.Series(roc_edges(exp_edges, res_edges), index=cols)
+ self.assertGreater(res.tp, 20)
+ self.assertGreater(res.perc_id, 0.5)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/deepblast/trainer.py b/deepblast/trainer.py
index ead2fc7..e5640c0 100644
--- a/deepblast/trainer.py
+++ b/deepblast/trainer.py
@@ -13,10 +13,13 @@
from deepblast.dataset.alphabet import UniprotTokenizer
from deepblast.dataset import TMAlignDataset
from deepblast.dataset.utils import (
- decode, states2edges, collate_f, unpack_sequences, pack_sequences)
+ decode, states2edges, collate_f, unpack_sequences,
+ pack_sequences, revstate_f)
from deepblast.losses import (
- SoftAlignmentLoss, SoftPathLoss, MatrixCrossEntropy)
+ SoftAlignmentLoss, SoftPathLoss, MatrixCrossEntropy,
+ L2MatrixCrossEntropy)
from deepblast.score import roc_edges, alignment_visualization, alignment_text
+import warnings
class LightningAligner(pl.LightningModule):
@@ -28,6 +31,8 @@ def __init__(self, args):
self.initialize_aligner()
if self.hparams.loss == 'sse':
self.loss_func = SoftAlignmentLoss()
+ elif self.hparams.loss == 'l2_cross_entropy':
+ self.loss_func = L2MatrixCrossEntropy()
elif self.hparams.loss == 'cross_entropy':
self.loss_func = MatrixCrossEntropy()
elif self.hparams.loss == 'path':
@@ -41,15 +46,31 @@ def initialize_aligner(self):
n_input = self.hparams.rnn_input_dim
n_units = self.hparams.rnn_dim
n_layers = self.hparams.layers
- if self.hparams.aligner == 'nw':
- self.aligner = NeedlemanWunschAligner(
- n_alpha, n_input, n_units, n_embed, n_layers)
- else:
- raise NotImplementedError(
- f'Aligner {self.hparams.aligner_type} not implemented.')
+ n_heads = self.hparams.heads
+ self.aligner = NeedlemanWunschAligner(
+ n_alpha, n_input, n_units, n_embed, n_layers, n_heads)
+
def forward(self, x, y):
- return self.aligner.forward(x, y)
+ x_code = torch.Tensor(self.tokenizer(str.encode(x))).long()
+ y_code = torch.Tensor(self.tokenizer(str.encode(y))).long()
+ x_code = x_code.to(self.device)
+ y_code = y_code.to(self.device)
+ seq, order = pack_sequences([x_code], [y_code])
+ A, theta, gap = self.aligner(seq, order)
+ return A, theta, gap
+
+ def align(self, x, y):
+ x_code = torch.Tensor(self.tokenizer(str.encode(x))).long()
+ y_code = torch.Tensor(self.tokenizer(str.encode(y))).long()
+ x_code = x_code.to(self.device)
+ y_code = y_code.to(self.device)
+ seq, order = pack_sequences([x_code], [y_code])
+ gen = self.aligner.traceback(seq, order)
+ decoded = next(gen)
+ pred_x, pred_y, pred_states = list(zip(*decoded))
+ s = ''.join(list(map(revstate_f, pred_states)))
+ return s
def initialize_logging(self, root_dir='./', logging_path=None):
if logging_path is None:
@@ -62,7 +83,7 @@ def initialize_logging(self, root_dir='./', logging_path=None):
def train_dataloader(self):
train_dataset = TMAlignDataset(
- self.hparams.train_pairs,
+ self.hparams.train_pairs, clip_ends=self.hparams.clip_ends,
construct_paths=isinstance(self.loss_func, SoftPathLoss))
train_dataloader = DataLoader(
train_dataset, self.hparams.batch_size, collate_fn=collate_f,
@@ -72,38 +93,42 @@ def train_dataloader(self):
def val_dataloader(self):
valid_dataset = TMAlignDataset(
- self.hparams.valid_pairs,
+ self.hparams.valid_pairs, clip_ends=self.hparams.clip_ends,
construct_paths=isinstance(self.loss_func, SoftPathLoss))
valid_dataloader = DataLoader(
valid_dataset, self.hparams.batch_size, collate_fn=collate_f,
- shuffle=False, num_workers=self.hparams.num_workers,
+ shuffle=True, num_workers=self.hparams.num_workers,
pin_memory=True)
return valid_dataloader
def test_dataloader(self):
+ # Held-out TM-align dataset
test_dataset = TMAlignDataset(
- self.hparams.test_pairs,
+ self.hparams.test_pairs, clip_ends=self.hparams.clip_ends,
construct_paths=isinstance(self.loss_func, SoftPathLoss))
test_dataloader = DataLoader(
- test_dataset, self.hparams.batch_size, shuffle=False,
+ test_dataset, self.hparams.batch_size, shuffle=True,
collate_fn=collate_f, num_workers=self.hparams.num_workers,
pin_memory=True)
return test_dataloader
- def compute_loss(self, x, y, predA, A, P, theta):
-
+ def compute_loss(self, mask, predA, A, P, theta, gap):
+ x_mask, y_mask = mask
if isinstance(self.loss_func, SoftAlignmentLoss):
- loss = self.loss_func(A, predA, x, y)
+ loss = self.loss_func(A, predA, x_mask, y_mask)
elif isinstance(self.loss_func, MatrixCrossEntropy):
- loss = self.loss_func(A, predA, x, y)
+ loss = self.loss_func(A, predA, x_mask, y_mask)
+ elif isinstance(self.loss_func, L2MatrixCrossEntropy):
+ loss = self.loss_func(A, predA, theta, gap, x_mask, y_mask)
elif isinstance(self.loss_func, SoftPathLoss):
- loss = self.loss_func(P, predA, x, y)
+ loss = self.loss_func(P, predA, x_mask, y_mask)
if self.hparams.multitask:
current_lr = self.trainer.lr_schedulers[0]['scheduler']
current_lr = current_lr.get_last_lr()[0]
max_lr = self.hparams.learning_rate
lam = current_lr / max_lr
- match_loss = self.loss_func(torch.sigmoid(theta), predA, x, y)
+ match_loss = self.loss_func(torch.sigmoid(theta), predA,
+ x_mask, y_mask)
# when learning rate is large, weight match loss
# otherwise, weight towards DP
loss = lam * match_loss + (1 - lam) * loss
@@ -111,17 +136,18 @@ def compute_loss(self, x, y, predA, A, P, theta):
def training_step(self, batch, batch_idx):
self.aligner.train()
- genes, others, s, A, P = batch
+ genes, others, s, A, P, mask = batch
seq, order = pack_sequences(genes, others)
predA, theta, gap = self.aligner(seq, order)
- _, xlen, _, ylen = unpack_sequences(seq, order)
- loss = self.compute_loss(xlen, ylen, predA, A, P, theta)
+ x_mask, y_mask = mask
+ loss = self.compute_loss(mask, predA, A, P, theta, gap)
assert torch.isnan(loss).item() is False
if len(self.trainer.lr_schedulers) >= 1:
current_lr = self.trainer.lr_schedulers[0]['scheduler']
current_lr = current_lr.get_last_lr()[0]
else:
current_lr = self.hparams.learning_rate
+
tensorboard_logs = {'train_loss': loss, 'lr': current_lr}
# log the learning rate
return {'loss': loss, 'log': tensorboard_logs}
@@ -143,6 +169,11 @@ def validation_stats(self, x, y, xlen, ylen, gen,
truth_states = states[b].cpu().detach().numpy()
pred_edges = states2edges(pred_states)
true_edges = states2edges(truth_states)
+ if len(pred_edges) == 0:
+ raise ValueError('No predicted edges', pred_states)
+ if len(true_edges) == 0:
+ raise ValueError('No truth edges', truth_states)
+
stats = roc_edges(true_edges, pred_edges)
if random.random() < self.hparams.visualization_fraction:
Av = A[b].cpu().detach().numpy().squeeze()
@@ -169,17 +200,14 @@ def validation_stats(self, x, y, xlen, ylen, gen,
return statistics
def validation_step(self, batch, batch_idx):
- # TODO: something weird is going on with the lengths
- # Need to make sure that they are being sorted properly
- genes, others, s, A, P = batch
+ genes, others, s, A, P, mask = batch
seq, order = pack_sequences(genes, others)
predA, theta, gap = self.aligner(seq, order)
x, xlen, y, ylen = unpack_sequences(seq, order)
- loss = self.compute_loss(xlen, ylen, predA, A, P, theta)
+ loss = self.compute_loss(mask, predA, A, P, theta, gap)
assert torch.isnan(loss).item() is False
# Obtain alignment statistics + visualizations
gen = self.aligner.traceback(seq, order)
- # TODO; compare the traceback and the forward
statistics = self.validation_stats(
x, y, xlen, ylen, gen, s, A, predA, theta, gap, batch_idx)
statistics = pd.DataFrame(
@@ -194,12 +222,26 @@ def validation_step(self, batch, batch_idx):
return {'validation_loss': loss,
'log': tensorboard_logs}
- def test_step(self, batch, batch_idx):
- pass
+ def custom_parameter_histogram(self):
+ # iterating through all parameters
+ for name, params in self.named_parameters():
+ if params.requires_grad and (params.grad is not None):
+ self.logger.experiment.add_histogram(
+ f'{name}/value', params, self.global_step)
+
+ # def on_after_backward(self):
+ # # example to inspect gradient information in tensorboard
+ # if self.trainer.global_step % 200 == 0: # don't make the tf file huge
+ # for name, params in self.named_parameters():
+ # if params.requires_grad and (params.grad is not None):
+ # self.logger.experiment.add_histogram(
+ # f'{name}/grad', params.grad, self.global_step)
def validation_epoch_end(self, outputs):
loss_f = lambda x: x['validation_loss']
losses = list(map(loss_f, outputs))
+ if len(losses) == 0:
+ raise ValueError('No losses reported', output)
loss = sum(losses) / len(losses)
self.logger.experiment.add_scalar('val_loss', loss, self.global_step)
metrics = ['val_tp', 'val_fp', 'val_fn', 'val_perc_id',
@@ -207,16 +249,24 @@ def validation_epoch_end(self, outputs):
scores = []
for i, m in enumerate(metrics):
loss_f = lambda x: x['log'][m]
- losses = list(map(loss_f, outputs))
- scalar = sum(losses) / len(losses)
- scores.append(scalar)
- self.logger.experiment.add_scalar(m, scalar, self.global_step)
+ losses = np.array(list(map(loss_f, outputs)))
+ losses = losses[np.logical_not(np.isnan(losses))]
+ if len(losses) > 0:
+ scalar = sum(losses) / len(losses)
+ scores.append(scalar)
+ self.logger.experiment.add_scalar(m, scalar, self.global_step)
+ else:
+ warnings.warn(f'No losses reported for {m}.', RuntimeWarning)
+ self.custom_parameter_histogram()
tensorboard_logs = dict(
[('val_loss', loss)] + list(zip(metrics, scores))
)
return {'val_loss': loss, 'log': tensorboard_logs}
+ def test_step(self, batch, batch_idx):
+ pass
+
def test_epoch_end(self, outputs):
pass
@@ -229,7 +279,7 @@ def configure_optimizers(self):
grad_params, lr=self.hparams.learning_rate)
if self.hparams.scheduler == 'cosine_restarts':
scheduler = CosineAnnealingWarmRestarts(
- optimizer, T_0=1, T_mult=2)
+ optimizer, T_0=1, T_mult=1)
elif self.hparams.scheduler == 'cosine':
scheduler = CosineAnnealingLR(optimizer, T_max=self.hparams.epochs)
elif self.hparams.scheduler == 'triangular':
@@ -246,6 +296,13 @@ def configure_optimizers(self):
steps = int(np.log2(self.hparams.learning_rate / m))
steps = self.hparams.epochs // steps
scheduler = StepLR(optimizer, step_size=steps, gamma=0.5)
+ elif self.hparams.scheduler == 'inv_steplr':
+ m = 1e-4 # maximum learning rate
+ optimizer = torch.optim.Adam(
+ grad_params, lr=m)
+ steps = int(np.log2(m / self.hparams.learning_rate))
+ steps = self.hparams.epochs // steps
+ scheduler = StepLR(optimizer, step_size=steps, gamma=0.5)
elif self.hparams.scheduler == 'none':
return [optimizer]
else:
@@ -262,10 +319,6 @@ def add_model_specific_args(parent_parser):
'--test-pairs', help='Testing pairs file', required=True)
parser.add_argument(
'--valid-pairs', help='Validation pairs file', required=True)
- parser.add_argument(
- '-a', '--aligner',
- help='Aligner type. Choices include (nw, hmm).',
- required=False, type=str, default='nw')
parser.add_argument(
'--embedding-dim', help='Embedding dimension (default 512).',
required=False, type=int, default=512)
@@ -278,9 +331,13 @@ def add_model_specific_args(parent_parser):
parser.add_argument(
'--layers', help='Number of RNN layers (default 2).',
required=False, type=int, default=2)
+ parser.add_argument(
+ '--heads', help='Number heads in attention layer (default 8).',
+ required=False, type=int, default=8)
parser.add_argument(
'--loss',
- help=('Loss function. Options include {sse, path, cross_entropy} '
+ help=('Loss function. Options include '
+ '{sse, path, cross_entropy, l2_cross_entropy} '
'(default cross_entropy)'),
default='cross_entropy', required=False, type=str)
parser.add_argument(
diff --git a/ipynb/simulation-benchmark.ipynb b/ipynb/simulation-benchmark.ipynb
index 9b986e8..583db3f 100644
--- a/ipynb/simulation-benchmark.ipynb
+++ b/ipynb/simulation-benchmark.ipynb
@@ -19,6 +19,26 @@
"import numpy as np"
]
},
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Beta-lactamase.hmm tm-align-0.9-30.tab tm_align_output_10k.ali\r\n",
+ "I-set.hmm\t tm-align-0.9-50.tab tm_align_output_10k.tab\r\n",
+ "PPR_2.hmm\t tm-align-0.9.tab\t zf-C2H2.hmm\r\n",
+ "tm-align-0.9-10.tab tm-align-0.9.txt\r\n"
+ ]
+ }
+ ],
+ "source": [
+ "!ls ../data"
+ ]
+ },
{
"cell_type": "markdown",
"metadata": {},
@@ -29,11 +49,11 @@
},
{
"cell_type": "code",
- "execution_count": 2,
+ "execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
- "hmm = '../data/zf-C2H2.hmm'\n",
+ "hmm = '../data/Beta-lactamase.hmm'\n",
"n_alignments = 100\n",
"np.random.seed(0)\n",
"align_df = hmm_alignments(n=40, seed=0, n_alignments=n_alignments, hmmfile=hmm)\n",
@@ -54,7 +74,7 @@
},
{
"cell_type": "code",
- "execution_count": 3,
+ "execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
@@ -78,7 +98,7 @@
},
{
"cell_type": "code",
- "execution_count": 4,
+ "execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
@@ -92,7 +112,7 @@
},
{
"cell_type": "code",
- "execution_count": 5,
+ "execution_count": 6,
"metadata": {},
"outputs": [
{
@@ -101,7 +121,7 @@
"'/home/juermieboop/Documents/research/garfunkel/ipynb'"
]
},
- "execution_count": 5,
+ "execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
@@ -120,56 +140,25 @@
},
{
"cell_type": "code",
- "execution_count": 6,
+ "execution_count": 7,
"metadata": {},
- "outputs": [
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "usage: ipykernel_launcher.py [-h] --train-pairs TRAIN_PAIRS --test-pairs TEST_PAIRS --valid-pairs VALID_PAIRS [-a ALIGNER]\n",
- " [--embedding-dim EMBEDDING_DIM] [--rnn-input-dim RNN_INPUT_DIM] [--rnn-dim RNN_DIM] [--layers LAYERS]\n",
- " [--loss LOSS] [--learning-rate LEARNING_RATE] [--batch-size BATCH_SIZE] [--multitask MULTITASK]\n",
- " [--finetune FINETUNE] [--clip-ends CLIP_ENDS] [--scheduler SCHEDULER] [--epochs EPOCHS]\n",
- " [--visualization-fraction VISUALIZATION_FRACTION] -o OUTPUT_DIRECTORY [--num-workers NUM_WORKERS]\n",
- " [--gpus GPUS]\n",
- "ipykernel_launcher.py: error: unrecognized arguments: --load-from-checkpoint lightning_logs/version_5/checkpoints\n"
- ]
- },
- {
- "ename": "SystemExit",
- "evalue": "2",
- "output_type": "error",
- "traceback": [
- "An exception has occurred, use %tb to see the full traceback.\n",
- "\u001b[0;31mSystemExit\u001b[0m\u001b[0;31m:\u001b[0m 2\n"
- ]
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "/home/juermieboop/miniconda3/envs/pytorch/lib/python3.8/site-packages/IPython/core/interactiveshell.py:3339: UserWarning: To exit: use 'exit', 'quit', or Ctrl-D.\n",
- " warn(\"To exit: use 'exit', 'quit', or Ctrl-D.\", stacklevel=1)\n"
- ]
- }
- ],
+ "outputs": [],
"source": [
"args = [\n",
" '--train-pairs', f'{os.getcwd()}/data/train.txt',\n",
" '--test-pairs', f'{os.getcwd()}/data/test.txt',\n",
- " '--valid-pairs', f'{os.getcwd()}/data/valid.txt',\n",
+ " '--valid-pairs', f'{os.getcwd()}/data/train.txt',\n",
" '--output-directory', output_dir,\n",
- " '--epochs', '32',\n",
+ " '--epochs', '128',\n",
" '--batch-size', '20', \n",
" '--num-workers', '30',\n",
- " '--learning-rate', '1e-4',\n",
- " '--layers', '2',\n",
+ " '--learning-rate', '1e-4', \n",
+ " '--layers', '4',\n",
+ " '--heads', '4',\n",
" '--visualization-fraction', '1',\n",
- " '--loss', 'cross_entropy',\n",
- " '--scheduler', 'none',\n",
- " '--gpus', '1',\n",
- " '--load-from-checkpoint', 'lightning_logs/version_5/checkpoints'\n",
+ " '--loss', 'l2_cross_entropy',\n",
+ " '--scheduler', 'steplr',\n",
+ " '--gpus', '1'\n",
"]\n",
"parser = argparse.ArgumentParser(add_help=False)\n",
"parser = LightningAligner.add_model_specific_args(parser)\n",
@@ -180,9 +169,17 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 8,
"metadata": {},
- "outputs": [],
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "No traceback available to show.\n"
+ ]
+ }
+ ],
"source": [
"%tb"
]
@@ -196,9 +193,20 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 9,
"metadata": {},
- "outputs": [],
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/home/juermieboop/Documents/research/garfunkel/deepblast/embedding.py:9: UserWarning: nn.init.xavier_uniform is now deprecated in favor of nn.init.xavier_uniform_.\n",
+ " nn.init.xavier_uniform(m.weight)\n",
+ "/home/juermieboop/Documents/research/garfunkel/deepblast/embedding.py:36: UserWarning: nn.init.xavier_uniform is now deprecated in favor of nn.init.xavier_uniform_.\n",
+ " nn.init.xavier_uniform(self.linear.weight)\n"
+ ]
+ }
+ ],
"source": [
"model = LightningAligner(args)"
]
@@ -212,18 +220,215 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 10,
"metadata": {
"scrolled": false
},
- "outputs": [],
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "GPU available: True, used: True\n",
+ "TPU available: False, using: 0 TPU cores\n",
+ "CUDA_VISIBLE_DEVICES: [0]\n",
+ "\n",
+ " | Name | Type | Params\n",
+ "---------------------------------------------------\n",
+ "0 | aligner | NeedlemanWunschAligner | 48 M \n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validation sanity check', layout=Layout…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "a15504f7cc40479f86df76fd73aed1e2",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/home/juermieboop/miniconda3/envs/pytorch/lib/python3.8/site-packages/pytorch_lightning/utilities/distributed.py:25: UserWarning: Detected KeyboardInterrupt, attempting graceful shutdown...\n",
+ " warnings.warn(*args, **kwargs)\n"
+ ]
+ },
+ {
+ "data": {
+ "text/plain": [
+ "1"
+ ]
+ },
+ "execution_count": 10,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
"source": [
+ "from pytorch_lightning.profiler import AdvancedProfiler\n",
+ "profiler=AdvancedProfiler()\n",
"trainer = Trainer(\n",
" max_epochs=args.epochs,\n",
" gpus=args.gpus,\n",
" check_val_every_n_epoch=10,\n",
+ " gradient_clip_val=10,\n",
+ " # val_percent_check=0.25 \n",
" # profiler=profiler,\n",
- " fast_dev_run=False,\n",
+ " # fast_dev_run=True,\n",
" # auto_scale_batch_size='power'\n",
")\n",
"\n",
@@ -239,16 +444,26 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 11,
"metadata": {},
- "outputs": [],
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "version_0 version_11\tversion_14 version_4 version_7\r\n",
+ "version_1 version_12\tversion_2 version_5 version_8\r\n",
+ "version_10 version_13\tversion_3 version_6 version_9\r\n"
+ ]
+ }
+ ],
"source": [
"!ls lightning_logs"
]
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
@@ -257,11 +472,44 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 13,
"metadata": {
"scrolled": false
},
- "outputs": [],
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "Reusing TensorBoard on port 6006 (pid 3827), started 7:41:35 ago. (Use '!kill 3827' to kill it.)"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ " \n",
+ " \n",
+ " "
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
"source": [
"%tensorboard --logdir lightning_logs"
]
@@ -275,27 +523,102 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 14,
"metadata": {},
- "outputs": [],
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "NeedlemanWunschAligner(\n",
+ " (lm): BiLM(\n",
+ " (embed): Embedding(22, 21, padding_idx=21)\n",
+ " (dropout): Dropout(p=0, inplace=False)\n",
+ " (rnn): ModuleList(\n",
+ " (0): LSTM(21, 1024, batch_first=True)\n",
+ " (1): LSTM(1024, 1024, batch_first=True)\n",
+ " )\n",
+ " (linear): Linear(in_features=1024, out_features=21, bias=True)\n",
+ " )\n",
+ " (match_embedding): StackedRNN(\n",
+ " (embed): Embedding(21, 512, padding_idx=20)\n",
+ " (dropout): Dropout(p=0, inplace=False)\n",
+ " (rnn): GRU(512, 512, num_layers=4, batch_first=True, bidirectional=True)\n",
+ " (proj): Linear(in_features=1024, out_features=512, bias=True)\n",
+ " )\n",
+ " (gap_embedding): StackedRNN(\n",
+ " (embed): Embedding(21, 512, padding_idx=20)\n",
+ " (dropout): Dropout(p=0, inplace=False)\n",
+ " (rnn): GRU(512, 512, num_layers=4, batch_first=True, bidirectional=True)\n",
+ " (proj): Linear(in_features=1024, out_features=512, bias=True)\n",
+ " )\n",
+ " (match_mixture): MultiheadProduct(\n",
+ " (multilinear): MultiLinear(\n",
+ " (multi_output): ModuleList(\n",
+ " (0): Linear(in_features=512, out_features=512, bias=True)\n",
+ " )\n",
+ " )\n",
+ " (linear): Linear(in_features=1, out_features=1, bias=True)\n",
+ " )\n",
+ " (gap_mixture): MultiheadProduct(\n",
+ " (multilinear): MultiLinear(\n",
+ " (multi_output): ModuleList(\n",
+ " (0): Linear(in_features=512, out_features=512, bias=True)\n",
+ " )\n",
+ " )\n",
+ " (linear): Linear(in_features=1, out_features=1, bias=True)\n",
+ " )\n",
+ " (nw): NeedlemanWunschDecoder()\n",
+ ")"
+ ]
+ },
+ "execution_count": 14,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
"source": [
"model.aligner"
]
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 15,
"metadata": {},
- "outputs": [],
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "'epoch=119.ckpt'\r\n"
+ ]
+ }
+ ],
"source": [
"!ls lightning_logs/version_5/checkpoints"
]
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 16,
"metadata": {},
- "outputs": [],
+ "outputs": [
+ {
+ "ename": "FileNotFoundError",
+ "evalue": "[Errno 2] No such file or directory: 'lightning_logs/version_70/checkpoints/epoch=59.ckpt'",
+ "output_type": "error",
+ "traceback": [
+ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
+ "\u001b[0;31mFileNotFoundError\u001b[0m Traceback (most recent call last)",
+ "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[0mcheckpoint_dir\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m'lightning_logs/version_70/checkpoints'\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0mpath\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34mf'{checkpoint_dir}/epoch=59.ckpt'\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 4\u001b[0;31m \u001b[0mmodel\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mLightningAligner\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mload_from_checkpoint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mpath\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
+ "\u001b[0;32m~/miniconda3/envs/pytorch/lib/python3.8/site-packages/pytorch_lightning/core/saving.py\u001b[0m in \u001b[0;36mload_from_checkpoint\u001b[0;34m(cls, checkpoint_path, map_location, hparams_file, tags_csv, *args, **kwargs)\u001b[0m\n\u001b[1;32m 140\u001b[0m \u001b[0mcheckpoint\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mpl_load\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcheckpoint_path\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmap_location\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mmap_location\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 141\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 142\u001b[0;31m \u001b[0mcheckpoint\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mpl_load\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcheckpoint_path\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmap_location\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mlambda\u001b[0m \u001b[0mstorage\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mloc\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mstorage\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 143\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 144\u001b[0m \u001b[0;31m# add the hparams from csv file to checkpoint\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
+ "\u001b[0;32m~/miniconda3/envs/pytorch/lib/python3.8/site-packages/pytorch_lightning/utilities/cloud_io.py\u001b[0m in \u001b[0;36mload\u001b[0;34m(path_or_url, map_location)\u001b[0m\n\u001b[1;32m 7\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mload\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mpath_or_url\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mstr\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmap_location\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mNone\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 8\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0murlparse\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mpath_or_url\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mscheme\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;34m''\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0mPath\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mpath_or_url\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdrive\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;31m# no scheme or with a drive letter\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 9\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mload\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mpath_or_url\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmap_location\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mmap_location\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 10\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 11\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mhub\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mload_state_dict_from_url\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mpath_or_url\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmap_location\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mmap_location\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
+ "\u001b[0;32m~/miniconda3/envs/pytorch/lib/python3.8/site-packages/torch/serialization.py\u001b[0m in \u001b[0;36mload\u001b[0;34m(f, map_location, pickle_module, **pickle_load_args)\u001b[0m\n\u001b[1;32m 523\u001b[0m \u001b[0mpickle_load_args\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'encoding'\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m'utf-8'\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 524\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 525\u001b[0;31m \u001b[0;32mwith\u001b[0m \u001b[0m_open_file_like\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mf\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'rb'\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mopened_file\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 526\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0m_is_zipfile\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mopened_file\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 527\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0m_open_zipfile_reader\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mf\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mopened_zipfile\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
+ "\u001b[0;32m~/miniconda3/envs/pytorch/lib/python3.8/site-packages/torch/serialization.py\u001b[0m in \u001b[0;36m_open_file_like\u001b[0;34m(name_or_buffer, mode)\u001b[0m\n\u001b[1;32m 210\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_open_file_like\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mname_or_buffer\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmode\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 211\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0m_is_path\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mname_or_buffer\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 212\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0m_open_file\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mname_or_buffer\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmode\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 213\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 214\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0;34m'w'\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mmode\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
+ "\u001b[0;32m~/miniconda3/envs/pytorch/lib/python3.8/site-packages/torch/serialization.py\u001b[0m in \u001b[0;36m__init__\u001b[0;34m(self, name, mode)\u001b[0m\n\u001b[1;32m 191\u001b[0m \u001b[0;32mclass\u001b[0m \u001b[0m_open_file\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0m_opener\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 192\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m__init__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mname\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmode\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 193\u001b[0;31m \u001b[0msuper\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0m_open_file\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__init__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mopen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mname\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmode\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 194\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 195\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m__exit__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
+ "\u001b[0;31mFileNotFoundError\u001b[0m: [Errno 2] No such file or directory: 'lightning_logs/version_70/checkpoints/epoch=59.ckpt'"
+ ]
+ }
+ ],
"source": [
"from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint\n",
"checkpoint_dir = 'lightning_logs/version_70/checkpoints'\n",
diff --git a/ipynb/struct-benchmark.ipynb b/ipynb/struct-benchmark.ipynb
index dea9251..b4cee16 100644
--- a/ipynb/struct-benchmark.ipynb
+++ b/ipynb/struct-benchmark.ipynb
@@ -9,7 +9,7 @@
},
{
"cell_type": "code",
- "execution_count": 1,
+ "execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
@@ -29,7 +29,7 @@
},
{
"cell_type": "code",
- "execution_count": 2,
+ "execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
@@ -52,7 +52,7 @@
},
{
"cell_type": "code",
- "execution_count": 3,
+ "execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
@@ -76,7 +76,7 @@
},
{
"cell_type": "code",
- "execution_count": 4,
+ "execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
@@ -90,7 +90,7 @@
},
{
"cell_type": "code",
- "execution_count": 5,
+ "execution_count": 14,
"metadata": {},
"outputs": [
{
@@ -99,7 +99,7 @@
"'/home/juermieboop/Documents/research/garfunkel/ipynb'"
]
},
- "execution_count": 5,
+ "execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
@@ -118,7 +118,7 @@
},
{
"cell_type": "code",
- "execution_count": 6,
+ "execution_count": 15,
"metadata": {},
"outputs": [],
"source": [
@@ -127,15 +127,16 @@
" '--test-pairs', f'{os.getcwd()}/data/test.txt',\n",
" '--valid-pairs', f'{os.getcwd()}/data/valid.txt',\n",
" '--output-directory', output_dir,\n",
- " '--epochs', '128',\n",
+ " '--epochs', '32',\n",
" '--batch-size', '20', \n",
" '--num-workers', '30',\n",
- " '--layers', '2',\n",
+ " '--layers', '4',\n",
+ " '--heads', '8',\n",
" '--learning-rate', '5e-5',\n",
" '--visualization-fraction', '1',\n",
" '--loss', 'cross_entropy',\n",
- " '--scheduler', 'steplr', \n",
- " '--clip-ends', 'True',\n",
+ " '--scheduler', 'inv_steplr', \n",
+ " # '--clip-ends', 'True',\n",
" '--gpus', '1'\n",
"]"
]
@@ -149,7 +150,7 @@
},
{
"cell_type": "code",
- "execution_count": 7,
+ "execution_count": 16,
"metadata": {},
"outputs": [],
"source": [
@@ -161,6 +162,26 @@
"model = LightningAligner(args)"
]
},
+ {
+ "cell_type": "code",
+ "execution_count": 17,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "Namespace(batch_size=20, clip_ends=False, embedding_dim=512, epochs=32, finetune=False, gpus=1, heads=8, layers=4, learning_rate=5e-05, loss='cross_entropy', multitask=False, num_workers=30, output_directory='struct_results', rnn_dim=512, rnn_input_dim=512, scheduler='inv_steplr', test_pairs='/home/juermieboop/Documents/research/garfunkel/ipynb/data/test.txt', train_pairs='/home/juermieboop/Documents/research/garfunkel/ipynb/data/train.txt', valid_pairs='/home/juermieboop/Documents/research/garfunkel/ipynb/data/valid.txt', visualization_fraction=1.0)"
+ ]
+ },
+ "execution_count": 17,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "args"
+ ]
+ },
{
"cell_type": "markdown",
"metadata": {},
@@ -170,9 +191,9 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 18,
"metadata": {
- "scrolled": false
+ "scrolled": true
},
"outputs": [
{
@@ -185,13 +206,13 @@
"\n",
" | Name | Type | Params\n",
"---------------------------------------------------\n",
- "0 | aligner | NeedlemanWunschAligner | 34 M \n"
+ "0 | aligner | NeedlemanWunschAligner | 52 M \n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
- "model_id": "",
+ "model_id": "c45254eb1e5f42e6acd2b65d4a560fa1",
"version_major": 2,
"version_minor": 0
},
@@ -203,69 +224,38 @@
"output_type": "display_data"
},
{
- "data": {
- "application/vnd.jupyter.widget-view+json": {
- "model_id": "ce3249fdf417451c896ecfa630cf1943",
- "version_major": 2,
- "version_minor": 0
- },
- "text/plain": [
- "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "data": {
- "application/vnd.jupyter.widget-view+json": {
- "model_id": "",
- "version_major": 2,
- "version_minor": 0
- },
- "text/plain": [
- "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "data": {
- "application/vnd.jupyter.widget-view+json": {
- "model_id": "",
- "version_major": 2,
- "version_minor": 0
- },
- "text/plain": [
- "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "data": {
- "application/vnd.jupyter.widget-view+json": {
- "model_id": "cc3be93a41df413195269996a5009591",
- "version_major": 2,
- "version_minor": 0
- },
- "text/plain": [
- "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
+ "ename": "RuntimeError",
+ "evalue": "CUDA out of memory. Tried to allocate 328.00 MiB (GPU 0; 10.92 GiB total capacity; 4.26 GiB already allocated; 115.12 MiB free; 4.55 GiB reserved in total by PyTorch)",
+ "output_type": "error",
+ "traceback": [
+ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
+ "\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)",
+ "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 11\u001b[0m )\n\u001b[1;32m 12\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 13\u001b[0;31m \u001b[0mtrainer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
+ "\u001b[0;32m~/miniconda3/envs/pytorch/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py\u001b[0m in \u001b[0;36mfit\u001b[0;34m(self, model, train_dataloader, val_dataloaders)\u001b[0m\n\u001b[1;32m 977\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 978\u001b[0m \u001b[0;32melif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msingle_gpu\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 979\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msingle_gpu_train\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 980\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 981\u001b[0m \u001b[0;32melif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0muse_tpu\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;31m# pragma: no-cover\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
+ "\u001b[0;32m~/miniconda3/envs/pytorch/lib/python3.8/site-packages/pytorch_lightning/trainer/distrib_parts.py\u001b[0m in \u001b[0;36msingle_gpu_train\u001b[0;34m(self, model)\u001b[0m\n\u001b[1;32m 183\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mreinit_scheduler_properties\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0moptimizers\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlr_schedulers\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 184\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 185\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrun_pretrain_routine\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 186\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 187\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mtpu_train\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtpu_core_idx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
+ "\u001b[0;32m~/miniconda3/envs/pytorch/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py\u001b[0m in \u001b[0;36mrun_pretrain_routine\u001b[0;34m(self, model)\u001b[0m\n\u001b[1;32m 1134\u001b[0m \u001b[0mnum_loaders\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mval_dataloaders\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1135\u001b[0m \u001b[0mmax_batches\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnum_sanity_val_steps\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m*\u001b[0m \u001b[0mnum_loaders\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1136\u001b[0;31m eval_results = self._evaluate(model,\n\u001b[0m\u001b[1;32m 1137\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mval_dataloaders\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1138\u001b[0m \u001b[0mmax_batches\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
+ "\u001b[0;32m~/miniconda3/envs/pytorch/lib/python3.8/site-packages/pytorch_lightning/trainer/evaluation_loop.py\u001b[0m in \u001b[0;36m_evaluate\u001b[0;34m(self, model, dataloaders, max_batches, test_mode)\u001b[0m\n\u001b[1;32m 291\u001b[0m \u001b[0moutput\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mevaluation_forward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbatch\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbatch_idx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdataloader_idx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtest_mode\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 292\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 293\u001b[0;31m \u001b[0moutput\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mevaluation_forward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbatch\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbatch_idx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdataloader_idx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtest_mode\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 294\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 295\u001b[0m \u001b[0;31m# on dp / ddp2 might still want to do something with the batch parts\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
+ "\u001b[0;32m~/miniconda3/envs/pytorch/lib/python3.8/site-packages/pytorch_lightning/trainer/evaluation_loop.py\u001b[0m in \u001b[0;36mevaluation_forward\u001b[0;34m(self, model, batch, batch_idx, dataloader_idx, test_mode)\u001b[0m\n\u001b[1;32m 483\u001b[0m \u001b[0moutput\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtest_step\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 484\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 485\u001b[0;31m \u001b[0moutput\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalidation_step\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 486\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 487\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0moutput\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
+ "\u001b[0;32m~/Documents/research/garfunkel/deepblast/trainer.py\u001b[0m in \u001b[0;36mvalidation_step\u001b[0;34m(self, batch, batch_idx)\u001b[0m\n\u001b[1;32m 203\u001b[0m \u001b[0mgenes\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mothers\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0ms\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mA\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mP\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmask\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mbatch\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 204\u001b[0m \u001b[0mseq\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0morder\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mpack_sequences\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mgenes\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mothers\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 205\u001b[0;31m \u001b[0mpredA\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtheta\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgap\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0maligner\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mseq\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0morder\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 206\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mxlen\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mylen\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0munpack_sequences\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mseq\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0morder\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 207\u001b[0m \u001b[0mloss\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcompute_loss\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmask\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mpredA\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mA\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mP\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtheta\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgap\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
+ "\u001b[0;32m~/miniconda3/envs/pytorch/lib/python3.8/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 530\u001b[0m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_slow_forward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 531\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 532\u001b[0;31m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 533\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mhook\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_forward_hooks\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 534\u001b[0m \u001b[0mhook_result\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mhook\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mresult\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
+ "\u001b[0;32m~/Documents/research/garfunkel/deepblast/alignment.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, x, order)\u001b[0m\n\u001b[1;32m 92\u001b[0m \u001b[0;31m# Obtain theta through an inner product across latent dimensions\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 93\u001b[0m \u001b[0mtheta\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmatch_mixture\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mzx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mzy\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 94\u001b[0;31m \u001b[0mgap\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mgap_mixture\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mgx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgy\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 95\u001b[0m \u001b[0;31m#G = self.gap_mixture(gx, gy)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 96\u001b[0m \u001b[0;31m# zero out first row and first column for local alignments\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
+ "\u001b[0;32m~/miniconda3/envs/pytorch/lib/python3.8/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 530\u001b[0m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_slow_forward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 531\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 532\u001b[0;31m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 533\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mhook\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_forward_hooks\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 534\u001b[0m \u001b[0mhook_result\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mhook\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mresult\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
+ "\u001b[0;32m~/Documents/research/garfunkel/deepblast/embedding.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, x, y)\u001b[0m\n\u001b[1;32m 39\u001b[0m \u001b[0mzx\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmultilinear\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 40\u001b[0m \u001b[0mzy\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmultilinear\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0my\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 41\u001b[0;31m \u001b[0mdists\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0meinsum\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'bidh,bjdh->bijh'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mzx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mzy\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 42\u001b[0m \u001b[0moutput\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlinear\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdists\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 43\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0moutput\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msqueeze\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
+ "\u001b[0;32m~/miniconda3/envs/pytorch/lib/python3.8/site-packages/torch/functional.py\u001b[0m in \u001b[0;36meinsum\u001b[0;34m(equation, *operands)\u001b[0m\n\u001b[1;32m 239\u001b[0m \u001b[0;31m# the old interface of passing the operands as one list argument\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 240\u001b[0m \u001b[0moperands\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0moperands\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 241\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_C\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_VariableFunctions\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0meinsum\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mequation\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moperands\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 242\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 243\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
+ "\u001b[0;31mRuntimeError\u001b[0m: CUDA out of memory. Tried to allocate 328.00 MiB (GPU 0; 10.92 GiB total capacity; 4.26 GiB already allocated; 115.12 MiB free; 4.55 GiB reserved in total by PyTorch)"
+ ]
}
],
"source": [
+ "from pytorch_lightning.profiler import AdvancedProfiler\n",
+ "profiler=AdvancedProfiler()\n",
"trainer = Trainer(\n",
" max_epochs=args.epochs,\n",
" gpus=args.gpus,\n",
- " check_val_every_n_epoch=8,\n",
- " # profiler=profiler,\n",
- " # fast_dev_run=True,\n",
+ " check_val_every_n_epoch=1,\n",
+ " val_percent_check=0.1\n",
+ " #profiler=profiler,\n",
+ " #fast_dev_run=True,\n",
" # auto_scale_batch_size='power'\n",
")\n",
"\n",
@@ -290,7 +280,7 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
@@ -299,11 +289,44 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 2,
"metadata": {
"scrolled": false
},
- "outputs": [],
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "Reusing TensorBoard on port 6006 (pid 3827), started 5:48:33 ago. (Use '!kill 3827' to kill it.)"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ " \n",
+ " \n",
+ " "
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
"source": [
"%tensorboard --logdir lightning_logs"
]
diff --git a/scripts/deepblast-train b/scripts/deepblast-train
index 85c7b06..6d3f15c 100644
--- a/scripts/deepblast-train
+++ b/scripts/deepblast-train
@@ -24,12 +24,16 @@ def main(args):
num_nodes=args.nodes,
accumulate_grad_batches=args.grad_accum,
gradient_clip_val=args.grad_clip,
-
distributed_backend=args.backend,
precision=args.precision,
- # check_val_every_n_epoch=1,
+ #check_val_every_n_epoch=1,
+ limit_train_batches=0.1,
+ limit_val_batches=0.2,
+ limit_test_batches=0.2,
val_check_interval=0.25,
fast_dev_run=False,
+ # overfit the data
+ # overfit_pct=0.01,
# auto_scale_batch_size='power',
# profiler=profiler,
)
@@ -44,7 +48,7 @@ def main(args):
# initialize Model Checkpoint Saver
checkpoint_callback = ModelCheckpoint(
filepath=ckpt_path,
- period=1,
+ save_top_k=1,
monitor='validation_loss',
mode='min',
verbose=True
@@ -63,7 +67,7 @@ if __name__ == '__main__':
parser = argparse.ArgumentParser(add_help=False)
parser.add_argument('--gpus', type=int, default=None)
parser.add_argument('--grad-accum', type=int, default=1)
- parser.add_argument('--grad-clip', type=int, default=0)
+ parser.add_argument('--grad-clip', type=int, default=10)
parser.add_argument('--nodes', type=int, default=1)
parser.add_argument('--num-workers', type=int, default=1)
parser.add_argument('--precision', type=int, default=32)