diff --git a/deepblast/alignment.py b/deepblast/alignment.py index 25a222d..168cf5b 100644 --- a/deepblast/alignment.py +++ b/deepblast/alignment.py @@ -1,16 +1,21 @@ import torch import torch.nn as nn +import torch.nn.functional as F from deepblast.language_model import BiLM, pretrained_language_models from deepblast.nw_cuda import NeedlemanWunschDecoder as NWDecoderCUDA -from deepblast.embedding import StackedRNN, EmbedLinear +from deepblast.embedding import StackedRNN, EmbedLinear, MultiheadProduct from deepblast.dataset.utils import unpack_sequences -import torch.nn.functional as F +import math + + +def swish(x): + return x * F.sigmoid(x) class NeedlemanWunschAligner(nn.Module): def __init__(self, n_alpha, n_input, n_units, n_embed, - n_layers=2, lm=None, device='gpu'): + n_layers=2, n_heads=16, lm=None, device='gpu', local=True): """ NeedlemanWunsch Alignment model Parameters @@ -25,6 +30,8 @@ def __init__(self, n_alpha, n_input, n_units, n_embed, Embedding dimension n_layers : int Number of RNN layers. + n_heads : int + Number of heads in multilinear layer. lm : BiLM Pretrained language model (optional) padding_idx : int @@ -32,6 +39,8 @@ def __init__(self, n_alpha, n_input, n_units, n_embed, transform : function Activation function (default relu) sparse : False? + local : bool + Specifies if local alignment should be performed on the traceback """ super(NeedlemanWunschAligner, self).__init__() if lm is None: @@ -39,22 +48,30 @@ def __init__(self, n_alpha, n_input, n_units, n_embed, self.lm = BiLM() self.lm.load_state_dict(torch.load(path)) self.lm.eval() + transform = swish if n_layers > 1: self.match_embedding = StackedRNN( - n_alpha, n_input, n_units, n_embed, n_layers, lm=lm) + n_alpha, n_input, n_units, n_embed, n_layers, lm=lm, + transform=swish, rnn_type='gru') self.gap_embedding = StackedRNN( - n_alpha, n_input, n_units, n_embed, n_layers, lm=lm) + n_alpha, n_input, n_units, n_embed, n_layers, lm=lm, + transform=swish, rnn_type='gru') else: self.match_embedding = EmbedLinear( - n_alpha, n_input, n_embed, lm=lm) + n_alpha, n_input, n_embed, lm=lm, + transform=swish) self.gap_embedding = EmbedLinear( - n_alpha, n_input, n_embed, lm=lm) + n_alpha, n_input, n_embed, lm=lm, + transform=swish) + self.match_mixture = MultiheadProduct(n_embed, n_embed, n_heads) + self.gap_mixture = MultiheadProduct(n_embed, n_embed, n_heads) # TODO: make cpu compatible version # if device == 'cpu': # self.nw = NWDecoderCPU(operator='softmax') # else: self.nw = NWDecoderCUDA(operator='softmax') + self.local = local def forward(self, x, order): """ Generate alignment matrix. @@ -72,10 +89,13 @@ def forward(self, x, order): with torch.enable_grad(): zx, _, zy, _ = unpack_sequences(self.match_embedding(x), order) gx, _, gy, _ = unpack_sequences(self.gap_embedding(x), order) - # Obtain theta through an inner product across latent dimensions - theta = F.softplus(torch.einsum('bid,bjd->bij', zx, zy)) - A = F.logsigmoid(torch.einsum('bid,bjd->bij', gx, gy)) + theta = self.match_mixture(zx, zy) + #A = self.gap_mixture(gx, gy) + # zero out first row and first column for local alignments + G = self.gap_mixture(gx, gy) + A = torch.zeros(G.shape).to(G.device) + A[:, 1:, 1:] += G[:, 1:, 1:] aln = self.nw.decode(theta, A) return aln, theta, A @@ -84,13 +104,26 @@ def traceback(self, x, order): with torch.enable_grad(): zx, _, zy, _ = unpack_sequences(self.match_embedding(x), order) gx, xlen, gy, ylen = unpack_sequences(self.gap_embedding(x), order) - match = F.softplus(torch.einsum('bid,bjd->bij', zx, zy)) - gap = F.logsigmoid(torch.einsum('bid,bjd->bij', gx, gy)) + match = self.match_mixture(zx, zy) + # gap = self.gap_mixture(gx, gy) + A = self.gap_mixture(gx, gy) + gap = torch.zeros(A.shape).to(A.device) + gap[:, 1:, 1:] += A[:, 1:, 1:] + + # zero out first row and first column for local alignments + # L = gx.shape[1] + # gap = torch.zeros((L, L)) + # gap[1:, 1:] = self.gap_mixture(gx[:, 1:, :], gy[:, 1:, :]) + B, _, _ = match.shape + for b in range(B): - aln = self.nw.decode( - match[b, :xlen[b], :ylen[b]].unsqueeze(0), - gap[b, :xlen[b], :ylen[b]].unsqueeze(0) - ) + M = match[b, :xlen[b], :ylen[b]].unsqueeze(0) + G = gap[b, :xlen[b], :ylen[b]].unsqueeze(0) + # val = math.log(1 - (1/50)) # based on average insertion length + # if self.local: + # G[0, 0, :] = val + # G[0, :, 0] = val + aln = self.nw.decode(M, G) decoded = self.nw.traceback(aln.squeeze()) yield decoded, aln diff --git a/deepblast/constants.py b/deepblast/constants.py index 633a8c5..bf7c572 100644 --- a/deepblast/constants.py +++ b/deepblast/constants.py @@ -1 +1,7 @@ x, m, y = 0, 1, 2 # state numberings + +match_mean = 2 +match_std = 10 + +gap_mean = -4 +gap_std = 10 diff --git a/deepblast/dataset/dataset.py b/deepblast/dataset/dataset.py index 45a0e4b..4c6f1e9 100644 --- a/deepblast/dataset/dataset.py +++ b/deepblast/dataset/dataset.py @@ -7,11 +7,21 @@ from deepblast.constants import m from deepblast.dataset.utils import ( state_f, tmstate_f, - clip_boundaries, states2matrix, states2edges, - path_distance_matrix + remove_gaps, states2matrix, states2edges, + path_distance_matrix, gap_mask ) +def reshape(x, N, M): + if x.shape != (N, M) and x.shape != (M, N): + raise ValueError(f'The shape of `x` {x.shape} ' + f'does not agree with ({N}, {M})') + if tuple(x.shape) != (N, M): + return x.t() + else: + return x + + class AlignmentDataset(Dataset): def __init__(self, pairs, tokenizer=UniprotTokenizer()): self.tokenizer = tokenizer @@ -38,13 +48,19 @@ def __iter__(self): yield self.__getitem__(i) +class FastaDataset(AlignmentDataset): + """ Dataset for searching. """ + def __init__(self, query_path, db_path, tokenizer=UniprotTokenizer()): + pass + + class TMAlignDataset(AlignmentDataset): """ Dataset for training and testing. This is appropriate for the Malisam / Malidup datasets. """ def __init__(self, path, tokenizer=UniprotTokenizer(), - tm_threshold=0.4, max_len=1024, pad_ends=False, + tm_threshold=0.5, max_len=1024, pad_ends=False, clip_ends=True, construct_paths=False): """ Read in pairs of proteins. @@ -126,12 +142,11 @@ def __getitem__(self, i): pos = self.pairs.iloc[i]['chain2'] states = self.pairs.iloc[i]['alignment'] states = list(map(tmstate_f, states)) - if self.clip_ends: - gene, pos, states = clip_boundaries(gene, pos, states) + gene, pos, states = remove_gaps(gene, pos, states, self.clip_ends) + gene_mask, pos_mask = gap_mask(states) if self.pad_ends: states = [m] + states + [m] - states = torch.Tensor(states).long() gene = self.tokenizer(str.encode(gene)) pos = self.tokenizer(str.encode(pos)) @@ -148,7 +163,11 @@ def __getitem__(self, i): path_matrix = path_matrix.t() if tuple(alignment_matrix.shape) != (len(gene), len(pos)): alignment_matrix = alignment_matrix.t() - return gene, pos, states, alignment_matrix, path_matrix + + # gene_mask = torch.Tensor(gene_mask).long() + # pos_mask = torch.Tensor(pos_mask).long() + return (gene, pos, states, alignment_matrix, path_matrix, + gene_mask, pos_mask) class MaliAlignmentDataset(AlignmentDataset): diff --git a/deepblast/dataset/tests/test_dataset.py b/deepblast/dataset/tests/test_dataset.py index 5270189..d314fa9 100644 --- a/deepblast/dataset/tests/test_dataset.py +++ b/deepblast/dataset/tests/test_dataset.py @@ -18,14 +18,13 @@ def test_getitem(self): x = TMAlignDataset(self.data_path, tm_threshold=0, pad_ends=False, clip_ends=False) res = x[0] - self.assertEqual(len(res), 5) - gene, pos, states, alignment_matrix, _ = res + self.assertEqual(len(res), 7) + gene, pos, states, alignment_matrix, _, _, _ = res # test the lengths self.assertEqual(len(gene), 103) self.assertEqual(len(pos), 21) self.assertEqual(len(states), 103) - # wtf is going on here?? - self.assertEqual(alignment_matrix.shape, (22, 103)) + self.assertEqual(alignment_matrix.shape, (103, 21)) class TestMaliDataset(unittest.TestCase): @@ -47,7 +46,7 @@ def test_getitem(self): self.assertEqual(len(gene), 81) self.assertEqual(len(pos), 81) self.assertEqual(len(states), 100) - self.assertEqual(alignment_matrix.shape, (81, 82)) + self.assertEqual(alignment_matrix.shape, (81, 81)) if __name__ == '__main__': diff --git a/deepblast/dataset/tests/test_utils.py b/deepblast/dataset/tests/test_utils.py index 7c0db9b..b1139cf 100644 --- a/deepblast/dataset/tests/test_utils.py +++ b/deepblast/dataset/tests/test_utils.py @@ -1,8 +1,10 @@ import unittest from deepblast.dataset.utils import ( tmstate_f, states2matrix, states2alignment, - path_distance_matrix, clip_boundaries, - pack_sequences, unpack_sequences) + path_distance_matrix, remove_gaps, + pack_sequences, unpack_sequences, + gap_mask, merge_mask, + remove_orphans) from math import sqrt import numpy as np import numpy.testing as npt @@ -30,6 +32,11 @@ def test_states2matrix_zinc(self): s = np.array(list(map(tmstate_f, s))) states2matrix(s, sparse=True) + def test_states2matrix_insert(self): + # Test how this is constructed if there are + # gaps in the beginning of the alignment + pass + def test_states2matrix_only_matches(self): s = ":11::11:" s = np.array(list(map(tmstate_f, s))) @@ -223,7 +230,7 @@ def test_clip_ends_none(self): s_ = [m, m, m, m] x_ = 'GSSG' y_ = 'GEIR' - rx, ry, rs = clip_boundaries(x_, y_, s_) + rx, ry, rs = remove_gaps(x_, y_, s_) self.assertEqual(x_, rx) self.assertEqual(y_, ry) self.assertEqual(s_, rs) @@ -233,7 +240,7 @@ def test_clip_ends(self): s = [x, m, m, m, y] x = 'GSSG' y = 'GEIR' - rx, ry, rs = clip_boundaries(x, y, s) + rx, ry, rs = remove_gaps(x, y, s) ex, ey, es = 'SSG', 'GEI', [m, m, m] self.assertEqual(ex, rx) self.assertEqual(ey, ry) @@ -245,7 +252,7 @@ def test_clip_ends_2(self): st = np.array([1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 1, 1, 1, 2, 2, 2, 2, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1]) - rx, ry, rs = clip_boundaries(gen, oth, st) + rx, ry, rs = remove_gaps(gen, oth, st) self.assertTrue(1) def test_pack_sequences(self): @@ -273,5 +280,116 @@ def test_unpack_sequences(self): tt.assert_allclose(expY, resY) +class TestPreprocess(unittest.TestCase): + + def test_gap_mask(self): + s = ":11::22:" + res = gap_mask(s) + exp_x = np.array([3, 4]) + exp_y = np.array([1, 2]) + + npt.assert_equal(res[0], exp_x) + npt.assert_equal(res[1], exp_y) + + s = ":11:.:22:" + res = gap_mask(s) + exp_x = np.array([2, 4, 5]) + exp_y = np.array([1, 2, 4]) + npt.assert_equal(res[0], exp_x) + npt.assert_equal(res[1], exp_y) + + def test_gap_mask2(self): + s = ( + '222222222222222222.11112222222222222222222222222' + '222222222222222222222222222222222222222222222222' + '22222222...::::::..:2:22::2:::::::..11.111...::.' + '::::::::::.::::......:::::::::::222:.::::::::.11' + '.:::::::::.:22.::::::::::::2:::::::::::::::1::..' + '.::::::::::::::::::::::22:2:2::::::::::1::::::::' + '::::22222::::::::::1::::::.' + ) + # N, M = 197, 283 + gap_mask(s) + + def test_gap_mask3(self): + seq = ('TSKINKELITTANDKKYTIATVVKVDGIAWFDRRDGVDQFKADTGNDVWVGPSQA' + 'DAAAQVQIVENLIAQGVDAIAIVPFSVEAVEPVLKKARERGIVVISHEASNIQNV' + 'DYDIEAFDNKAYGANLKELGKSGGKGKYVTTVGSLTSKSQEWIDGAVEYQKANFP' + 'ESEATGRLETYDDANTDYNKLKEATAYPDITGILGAPPTSAGAGRLIAEGGLKGK' + 'VFFAGTGLVSVAGEYIKNDDVQYIQFWDPAVAGYANLAVAALEKKNDQIKAGLNL' + 'GLPGYESLLAPDAAKPNLLYGAGWVGVTKEND') + st = ('222222222222222:::::::::::2::::::1:::::::::::1::::::::.:' + ':::::::::::::::::::::::::::::::::::::::::::::::::::1.:::' + ':::::::2:::::::::1:::::::1:::::::::::::::::1::::::::::::' + ':::::::::2:::::::::::::::::1:::::::::::::1::::::::::::::' + ':::::::::::1:::11:::::::::::::::::::2::.:::1:::::::::22:' + '::222222222222222222222::::::::::::11::11:11.11111111111' + '1111111') + xmask, ymask = gap_mask(st) + xidx = merge_mask(xmask, len(seq), len(seq)) + yidx = merge_mask(ymask, len(seq), len(seq)) + + + def test_gap_mask4(self): + + st1 = ('2222222222222222222222222222222222222222222222222222222222222.' + '22222222222222222222222222222222222222222222222222222222222222' + '22222222222222222222222222222222222222222222222222222222222222' + '2::.:.22222222222222222222222222222..2..:.::::::::::::11...::.' + '..:::.111111111..:..::::::1:.11.111.::::::.:::::::::::::::::2:' + ':..11111111111111122222222222222222222...:::::::::::::::::::2:' + '::..:::::::::::::::::::::::::2.:.:::::::::::::::2:::::::::::::' + '1::::::::::1:......:::::::::::::::1111.1.11:11:11111111.111111' + '1111111111111111111111111111.:1:::::::.2222.22.:...:..::..::.:' + '::.::::::.11.....::::::::222.22222222222222222222222222 ') + st1 = list(map(tmstate_f, st1)) + L = 205 + xmask, ymask = gap_mask(st1) + xidx = merge_mask(xmask, L, L) + yidx = merge_mask(ymask, L, L) + self.assertGreater(len(xidx), 0) + self.assertGreater(len(yidx), 0) + + seq = ('PKYQIIDAAVEVIAENGYHQSQVSKIAKQAGVADGTIYLYFKNKEDILISLFKEKGQFI' + 'EREEDIKEKATAKEKLALVISKHFSLLAGDHNLAIVTQLELRQSNLELRQKINEILKGY' + 'LNILDGILTEGIQSGEIKEGLDVRLARQIFGTIDETVTTWVNDQKYDLVALSNSVLELL' + 'VSGIHNK') + states = [2, 2, 2, 2, 2, 1, 1, 2, 2, 2, 2, 2, 2, 1, 1, 2, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 2, 2, 2, 1, 1, 1, 1, 1, 1, 0, 0, 0, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 1, 1, 2, 1, 1, 1, + 2, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 1, 2, 1, + 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 2, 2, 1] + xmask, ymask = gap_mask(states) + + self.assertLess(max(xmask), len(seq)) + + def test_replace_orphans_small(self): + s = ":11:11:" + e = ":111211:" + r = remove_orphans(s, threshold=3) + self.assertEqual(r, e) + + def test_replace_orphans(self): + s = ":1111111111:11111111111111:" + e = ":11111111111211111111111111:" + r = remove_orphans(s, threshold=9) + self.assertEqual(r, e) + + s = ":2222222222:22222222222222:" + e = ":22222222221222222222222222:" + r = remove_orphans(s, threshold=9) + self.assertEqual(r, e) + + s = ":1111111111:22222222222222:" + r = remove_orphans(s, threshold=9) + self.assertEqual(r, s) + + if __name__ == '__main__': unittest.main() diff --git a/deepblast/dataset/utils.py b/deepblast/dataset/utils.py index 14a35f3..7cac1c5 100644 --- a/deepblast/dataset/utils.py +++ b/deepblast/dataset/utils.py @@ -4,6 +4,8 @@ from scipy.sparse import coo_matrix from scipy.spatial import cKDTree from deepblast.constants import x, m, y +from itertools import islice +from functools import reduce def state_f(z): @@ -25,17 +27,29 @@ def tmstate_f(z): return m -def clip_boundaries(X, Y, A): +def revstate_f(z): + if z == x: + return '1' + if z == y: + return '2' + if z == m: + return ':' + + +def remove_gaps(X, Y, A, clip_ends=True): """ Remove xs and ys from ends. """ - if A[0] == m: - first = 0 - else: - first = A.index(m) + first = 0 + last = len(A) + if clip_ends: + if A[0] == m: + first = 0 + else: + first = A.index(m) - if A[-1] == m: - last = len(A) - else: - last = len(A) - A[::-1].index(m) + if A[-1] == m: + last = len(A) + else: + last = len(A) - A[::-1].index(m) X, Y = states2alignment(np.array(A), X, Y) X_ = X[first:last].replace('-', '') Y_ = Y[first:last].replace('-', '') @@ -98,8 +112,17 @@ def states2edges(states): prev_s, next_s = states[:-1], states[1:] transitions = list(zip(prev_s, next_s)) state_diffs = np.array(list(map(state_diff_f, transitions))) - coords = np.cumsum(state_diffs, axis=0).tolist() - coords = [(0, 0)] + list(map(tuple, coords)) + coords = np.cumsum(state_diffs, axis=0) + if states[0] == m: + coords = [(0, 0)] + list(map(tuple, coords.tolist())) + elif states[0] == y: + coords[:, 0] = np.maximum(coords[:, 0] - 1, 0) + coords = [(0, 0)] + list(map(tuple, coords.tolist())) + elif states[0] == x: + coords[:, 1] = np.maximum(coords[:, 1] - 1, 0) + coords = [(0, 0)] + list(map(tuple, coords.tolist())) + else: + raise ValueError(f'Unrecognized state {states[0]}') return coords @@ -138,7 +161,6 @@ def states2alignment(states: np.array, X: str, Y: str): f'The state string length {sy} does not match ' f'the length of sequence {len(X)}.\n' f'SequenceX: {X}\nSequenceY: {Y}\nStates: {states}\n' - ) i, j = 0, 0 @@ -236,16 +258,30 @@ def collate_f(batch): states = [x[2] for x in batch] alignments = [x[3] for x in batch] paths = [x[4] for x in batch] - max_x = max(map(len, genes)) - max_y = max(map(len, others)) + g_mask = [x[5] for x in batch] + p_mask = [x[6] for x in batch] + + x_len = list(map(len, genes)) + y_len = list(map(len, others)) + max_x = max(x_len) + max_y = max(y_len) + max_l = max(max_x, max_y) + x_mask = [] + y_mask = [] B = len(genes) - dm = torch.zeros((B, max_x, max_y)) - p = torch.zeros((B, max_x, max_y)) + dm = torch.zeros((B, max_l, max_l)) + p = torch.zeros((B, max_l, max_l)) for b in range(B): n, m = len(genes[b]), len(others[b]) dm[b, :n, :m] = alignments[b] p[b, :n, :m] = paths[b] - return genes, others, states, dm, p + gm = merge_mask(g_mask[b], n, max_l) + pm = merge_mask(p_mask[b], m, max_l) + assert len(gm) > 0, (len(g_mask[b]), max(g_mask[b]), n, max_l) + assert len(pm) > 0, (len(p_mask[b]), max(p_mask[b]), m, max_l) + x_mask.append(gm) + y_mask.append(pm) + return genes, others, states, dm, p, (x_mask, y_mask) def path_distance_matrix(pi): @@ -273,3 +309,93 @@ def path_distance_matrix(pi): d, i = model.query(coords) Pdist = np.array(coo_matrix((d, (coords[:, 0], coords[:, 1]))).todense()) return Pdist + + +def merge_mask(idx, length, max_len): + pads = set(list(range(length, max_len))) + idx = set(idx.tolist()) | pads + allx = set(list(range(0, max_len))) + idx = torch.Tensor(list(allx - idx)).long() + return idx + + +# Preprocessing functions +def gap_mask(states): + """ Builds a mask for all gaps. + + Reports rows and columns that should be completely masked. + + Parameters + ---------- + states : str + List of alignment states + + Returns + ------- + mask : np.array + Masked array. + """ + i, j = 0, 0 + rows, cols = [], [] + for k in range(len(states)): + if states[k] == x: + cols.append(i) + i += 1 + elif states[k] == y: + rows.append(j) + j += 1 + elif states[k] == m: + i += 1 + j += 1 + else: + raise ValueError(f'{states[k]} is not recognized') + return np.array(cols), np.array(rows) + + +def window(seq, n=2): + "Returns a sliding window (of width n) over data from the iterable" + " s -> (s0,s1,...s[n-1]), (s1,s2,...,sn), ... " + it = iter(seq) + result = tuple(islice(it, n)) + if len(result) == n: + yield result + for elem in it: + result = result[1:] + (elem,) + yield result + + +def replace_orphan(w, s=5): + i = len(w) // 2 + # identify orphans and replace with gaps + sw = ''.join(w) + if (w[i] == ':') and ((('1' * s) in sw[:i] and ('1' * s) in sw[i:]) or + (('2' * s) in sw[:i] and ('2' * s) in sw[i:])): + return ['1', '2'] + else: + return [w[i]] + + +def remove_orphans(states, threshold: int = 11): + """ Removes singletons and doubletons that are orphaned. + A match is considered orphaned if it exceeds the `threshold` gap. + Parameters + ---------- + states : np.array + List of alignment states + threshold : int + Number of consecutive gaps surrounding a matched required for it + to be considered an orphan. + Returns + ------- + new_states : np.array + States string with orphans removed. + Notes + ----- + The threshold *must* be an odd number. This determines the window size. + """ + wins = list(window(states, threshold)) + rwins = list(map(lambda x: replace_orphan(x, threshold // 2), list(wins))) + new_states = list(reduce(lambda x, y: x + y, rwins)) + new_states = list(states[:threshold // 2]) + new_states + new_states += list(states[-threshold // 2 + 1:]) + return ''.join(new_states) diff --git a/deepblast/embedding.py b/deepblast/embedding.py index 510e572..909573f 100644 --- a/deepblast/embedding.py +++ b/deepblast/embedding.py @@ -1,7 +1,48 @@ +import torch import torch.nn as nn from torch.nn.utils.rnn import PackedSequence +def init_weights(m): + # https://stackoverflow.com/a/49433937/1167475 + if type(m) == nn.Linear: + nn.init.xavier_uniform(m.weight) + m.bias.data.fill_(0.01) + + +class MultiLinear(nn.Module): + """ Multiple linear layers concatenated together""" + def __init__(self, n_input, n_output, n_heads=16): + super(MultiLinear, self).__init__() + self.multi_output = nn.ModuleList( + [ + nn.Linear(n_input, n_output) + for i in range(n_heads) + ] + ) + self.multi_output.apply(init_weights) + + def forward(self, x): + outputs = torch.stack( + [head(x) for head in self.multi_output], dim=-1) + return outputs + + +class MultiheadProduct(nn.Module): + def __init__(self, n_input, n_output, n_heads=16): + super(MultiheadProduct, self).__init__() + self.multilinear = MultiLinear(n_input, n_output, n_heads) + self.linear = nn.Linear(n_heads, 1) + nn.init.xavier_uniform(self.linear.weight) + + def forward(self, x, y): + zx = self.multilinear(x) + zy = self.multilinear(y) + dists = torch.einsum('bidh,bjdh->bijh', zx, zy) + output = self.linear(dists) + return output.squeeze() + + class LMEmbed(nn.Module): def __init__(self, nin, nout, lm, padding_idx=-1, transform=nn.ReLU(), sparse=False): @@ -41,7 +82,7 @@ def forward(self, x): class EmbedLinear(nn.Module): def __init__(self, nin, nhidden, nout, padding_idx=-1, - sparse=False, lm=None): + sparse=False, lm=None, transform=nn.ReLU()): super(EmbedLinear, self).__init__() if padding_idx == -1: @@ -49,7 +90,8 @@ def __init__(self, nin, nhidden, nout, padding_idx=-1, if lm is not None: self.embed = LMEmbed( - nin, nhidden, lm, padding_idx=padding_idx, sparse=sparse) + nin, nhidden, lm, padding_idx=padding_idx, sparse=sparse, + transform=transform) self.proj = nn.Linear(self.embed.nout, nout) self.lm = True else: @@ -58,6 +100,7 @@ def __init__(self, nin, nhidden, nout, padding_idx=-1, self.proj = nn.Linear(nout, nout) self.lm = False + init_weights(self.proj) self.nout = nout def forward(self, x): @@ -85,7 +128,7 @@ def forward(self, x): class StackedRNN(nn.Module): def __init__(self, nin, nembed, nunits, nout, nlayers=2, padding_idx=-1, dropout=0, rnn_type='lstm', - sparse=False, lm=None): + sparse=False, lm=None, transform=nn.ReLU()): super(StackedRNN, self).__init__() if padding_idx == -1: @@ -93,7 +136,8 @@ def __init__(self, nin, nembed, nunits, nout, nlayers=2, if lm is not None: self.embed = LMEmbed( - nin, nembed, lm, padding_idx=padding_idx, sparse=sparse) + nin, nembed, lm, padding_idx=padding_idx, sparse=sparse, + transform=transform) nembed = self.embed.nout self.lm = True else: diff --git a/deepblast/losses.py b/deepblast/losses.py index 7d5b509..4cb727a 100644 --- a/deepblast/losses.py +++ b/deepblast/losses.py @@ -1,4 +1,10 @@ import torch +from torch.distributions import Normal +from deepblast.constants import match_mean, match_std, gap_mean, gap_std + + +def mask_tensor(A, x_mask, y_mask): + return A[x_mask][:, y_mask] class AlignmentAccuracy: @@ -6,8 +12,52 @@ def __call__(self, true_edges, pred_edges): pass +class L2MatrixCrossEntropy: + def __call__(self, Ytrue, Ypred, M, G, x_mask, y_mask): + """ Computes binary cross entropy on the matrix with regularizers. + + The matrix cross entropy loss is given by + + d(ypred, ytrue) = - (mean(ytrue x log(ypred)) + + mean((1 - ytrue) x log(1 - ypred))) + + Parameters + ---------- + Ytrue : torch.Tensor + Ground truth alignment matrix of dimension N x M. + All entries are marked by 0 and 1. + Ypred : torch.Tensor + Predicted alignment matrix of dimension N x M. + M : torch.Tensor + Match score matrix + G : torch.Tensor + Gap score matrix + """ + score = 0 + eps = 3e-8 # unfortunately, this is the smallest eps we can have :( + Ypred = torch.clamp(Ypred, min=eps, max=1 - eps) + for b in range(len(x_mask)): + pos = torch.mean( + mask_tensor(Ytrue[b], x_mask[b], y_mask[b]) * torch.log( + mask_tensor(Ypred[b], x_mask[b], y_mask[b])) + ) + neg = torch.mean( + (1 - mask_tensor(Ytrue[b], x_mask[b], y_mask[b])) * torch.log( + 1 - mask_tensor(Ypred[b], x_mask[b], y_mask[b])) + ) + score += -(pos + neg) + + match_prior = Normal(match_mean, match_std) + gap_prior = Normal(gap_mean, gap_std) + log_like = score / len(x_mask) + match_log = match_prior.log_prob(M).mean() + gap_log = gap_prior.log_prob(G).mean() + score = log_like - match_log - gap_log + return score + + class MatrixCrossEntropy: - def __call__(self, Ytrue, Ypred, x_len, y_len): + def __call__(self, Ytrue, Ypred, x_mask, y_mask): """ Computes binary cross entropy on the matrix The matrix cross entropy loss is given by @@ -26,21 +76,29 @@ def __call__(self, Ytrue, Ypred, x_len, y_len): score = 0 eps = 3e-8 # unfortunately, this is the smallest eps we can have :( Ypred = torch.clamp(Ypred, min=eps, max=1 - eps) - for b in range(len(x_len)): + for b in range(len(x_mask)): pos = torch.mean( - Ytrue[b, :x_len[b], :y_len[b]] * torch.log( - Ypred[b, :x_len[b], :y_len[b]]) + mask_tensor(Ytrue[b], x_mask[b], y_mask[b]) * torch.log( + mask_tensor(Ypred[b], x_mask[b], y_mask[b])) ) neg = torch.mean( - (1 - Ytrue[b, :x_len[b], :y_len[b]]) * torch.log( - 1 - Ypred[b, :x_len[b], :y_len[b]]) + (1 - mask_tensor(Ytrue[b], x_mask[b], y_mask[b])) * torch.log( + 1 - mask_tensor(Ypred[b], x_mask[b], y_mask[b])) ) + # pos = torch.mean( + # Ytrue[b, x_mask[b], y_mask[b]] * torch.log( + # Ypred[b, x_mask[b], y_mask[b]]) + # ) + # neg = torch.mean( + # (1 - Ytrue[b, x_mask[b], y_mask[b]]) * torch.log( + # 1 - Ypred[b, x_mask[b], y_mask[b]]) + # f) score += -(pos + neg) - return score / len(x_len) + return score / len(x_mask) class SoftPathLoss: - def __call__(self, Pdist, Ypred, x_len, y_len): + def __call__(self, Pdist, Ypred, x_mask, y_mask): """ Computes a soft path loss The soft path loss is given by @@ -60,15 +118,15 @@ def __call__(self, Pdist, Ypred, x_len, y_len): Predicted alignment matrix of dimension N x M. """ score = 0 - for b in range(len(x_len)): + for b in range(len(x_mask)): score += torch.norm( - Pdist[b, :x_len[b], :y_len[b]] * Ypred[b, :x_len[b], :y_len[b]] + Pdist[b, x_mask[b], y_mask[b]] * Ypred[b, x_mask[b], y_mask[b]] ) - return score / len(x_len) + return score / len(x_mask) class SoftAlignmentLoss: - def __call__(self, Ytrue, Ypred, x_len, y_len): + def __call__(self, Ytrue, Ypred, x_mask, y_mask): """ Computes soft alignment loss as proposed in Mensch et al. The soft alignment loss is given by @@ -96,8 +154,8 @@ def __call__(self, Ytrue, Ypred, x_len, y_len): since it is possible to leave out important parts of the alignment. """ score = 0 - for b in range(len(x_len)): + for b in range(len(x_mask)): score += torch.norm( - Ytrue[b, :x_len[b], :y_len[b]] - Ypred[b, :x_len[b], :y_len[b]] + Ytrue[b, x_mask[b], y_mask[b]] - Ypred[b, x_mask[b], y_mask[b]] ) - return score / len(x_len) + return score / len(x_mask) diff --git a/deepblast/score.py b/deepblast/score.py index 2397ec1..81b920d 100644 --- a/deepblast/score.py +++ b/deepblast/score.py @@ -1,6 +1,6 @@ import numpy as np import matplotlib.pyplot as plt -from deepblast.dataset.utils import states2alignment +from deepblast.dataset.utils import states2alignment, states2edges, tmstate_f def roc_edges(true_edges, pred_edges): @@ -16,6 +16,25 @@ def roc_edges(true_edges, pred_edges): return tp, fp, fn, perc_id, ppv, fnr, fdr +def alignment_score(true_states: str, pred_states: str): + """ + Computes ROC statistics on alignment + + Parameters + ---------- + true_states : str + Ground truth state string + pred_states : str + Predicted state string + """ + pred_states = list(map(tmstate_f, pred_states)) + true_states = list(map(tmstate_f, true_states)) + pred_edges = states2edges(pred_states) + true_edges = states2edges(true_states) + stats = roc_edges(true_edges, pred_edges) + return stats + + def alignment_visualization(truth, pred, match, gap, xlen, ylen): """ Visualize alignment matrix diff --git a/deepblast/tests/test_alignment.py b/deepblast/tests/test_alignment.py index f713ab7..a539768 100644 --- a/deepblast/tests/test_alignment.py +++ b/deepblast/tests/test_alignment.py @@ -18,20 +18,20 @@ def setUp(self): nalpha, ninput, nunits, nembed = 22, 1024, 1024, 1024 self.aligner = NeedlemanWunschAligner(nalpha, ninput, nunits, nembed) - @unittest.skip + @unittest.skipUnless(torch.cuda.is_available(), 'No GPU was detected') def test_alignment(self): self.embedding = self.embedding.cuda() self.aligner = self.aligner.cuda() x = torch.Tensor( self.tokenizer(b'ARNDCQEGHILKMFPSTWYVXOUBZ') - ).unsqueeze(0).long().cuda() + ).long().cuda() y = torch.Tensor( self.tokenizer(b'ARNDCQEGHILKARNDCQMFPSTWYVXOUBZ') - ).unsqueeze(0).long().cuda() - N, M = x.shape[1], y.shape[1] - seq, order = pack_sequences([x], [y]) + ).long().cuda() + M = max(x.shape[0], y.shape[0]) + seq, order = pack_sequences([x, x], [y, y]) aln, theta, A = self.aligner(seq, order) - self.assertEqual(aln.shape, (1, N, M)) + self.assertEqual(aln.shape, (2, M, M)) @unittest.skipUnless(torch.cuda.is_available(), "No GPU detected") def test_batch_alignment(self): @@ -64,8 +64,10 @@ def test_collate_alignment(self): A2 = torch.ones((len(x2), len(y2))).long() P1 = torch.ones((len(x1), len(y1))).long() P2 = torch.ones((len(x2), len(y2))).long() - batch = [(x1, y1, s1, A1, P1), (x2, y2, s2, A2, P2)] - gene_codes, other_codes, states, dm, p = collate_f(batch) + mask = [torch.Tensor([0, 1]), torch.Tensor([0, 1])] + batch = [(x1, y1, s1, A1, P1, mask[0], mask[1]), + (x2, y2, s2, A2, P2, mask[0], mask[1])] + gene_codes, other_codes, states, dm, p, mask = collate_f(batch) self.embedding = self.embedding.cuda() self.aligner = self.aligner.cuda() seq, order = pack_sequences(gene_codes, other_codes) diff --git a/deepblast/tests/test_embedding.py b/deepblast/tests/test_embedding.py new file mode 100644 index 0000000..371db34 --- /dev/null +++ b/deepblast/tests/test_embedding.py @@ -0,0 +1,28 @@ +import torch +from deepblast.embedding import MultiLinear, MultiheadProduct +import unittest + + +class TestEmbedding(unittest.TestCase): + def setUp(self): + b, L, d, h = 3, 100, 50, 8 + self.x = torch.randn(b, L, d) + self.y = torch.randn(b, L, d) + self.b = b + self.L = L + self.d = d + self.h = h + + def test_multilinear(self): + model = MultiLinear(self.d, self.d, self.h) + res = model(self.x) + self.assertEqual(tuple(res.shape), (self.b, self.L, self.d, self.h)) + + def test_multihead_product(self): + model = MultiheadProduct(self.d, self.d, self.h) + res = model(self.x, self.y) + self.assertEqual(tuple(res.shape), (self.b, self.L, self.L)) + + +if __name__ == '__main__': + unittest.main() diff --git a/deepblast/tests/test_score.py b/deepblast/tests/test_score.py new file mode 100644 index 0000000..e1c8a25 --- /dev/null +++ b/deepblast/tests/test_score.py @@ -0,0 +1,90 @@ +from deepblast.score import roc_edges, alignment_text +from deepblast.dataset.utils import states2edges, tmstate_f +import pandas as pd +import numpy as np +import unittest + + +class TestScore(unittest.TestCase): + + def setUp(self): + pass + + def test_alignment_text(self): + gene = 'YACSGGCGQNFRTMSEFNEHMIRLVH' + other = 'LICPKHTRDCGKVFKRNSSLRVHEKTH' + pred = np.array([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 2, 1, 1, 1, 2, 0, 1, 1, 1, 1, 1]) + truth = np.array([1, 1, 1, 1, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1]) + stats = np.array([1, 1, 1, 1, 1, 1, 1]) + alignment_text(gene, other, pred, truth, stats) + + def test_roc_edges(self): + cols = ['tp', 'fp', 'fn', 'perc_id', 'ppv', 'fnr', 'fdr'] + + exp_alignment = ( + 'FRCPRPAGCE--KLYSTSSHVNKHLLL', + 'YDCE---ICQSFKDFSPYMKLRKHRAT', + '::::111:::22:::::::::::::::' + ) + res_alignment = ( + 'FRCPRPAGCEKLYSTSSHVNKHLL', + 'YDCEICQSFKDFSPYMKLRKHRAT', + '::::::::::::::::::::::::' + ) + # TODO: there are still parts of the alignment + # that are being clipped erronously + exp_edges = states2edges( + list(map(tmstate_f, exp_alignment[2]))) + res_edges = states2edges( + list(map(tmstate_f, res_alignment[2]))) + + res = pd.Series(roc_edges(exp_edges, res_edges), index=cols) + + self.assertGreater(res.perc_id, 0.1) + + def test_roc_edges_2(self): + cols = ['tp', 'fp', 'fn', 'perc_id', 'ppv', 'fnr', 'fdr'] + exp_alignment = ( + 'SVHTLLDEKHETLDSEWEKLVRDAMTSGVSKKQFREFLDYQKWRKSQ', + ':1111111111111111111111111111::::::::::::::::::', + 'I----------------------------FTYGELQRMQEKERNKGQ' + ) + res_alignment = ( + 'SVHTLLDEKHETLDSEWEKLVRDAMTSGVSKKQFREFLDYQKWRKSQ', + '1:1111111:111111111111111111:11::::::::::::::::', + '-I-------F------------------T--YGELQRMQEKERNKGQ' + ) + + exp_edges = states2edges( + list(map(tmstate_f, exp_alignment[2]))) + res_edges = states2edges( + list(map(tmstate_f, res_alignment[2]))) + res = pd.Series(roc_edges(exp_edges, res_edges), index=cols) + self.assertGreater(res.tp, 20) + self.assertGreater(res.perc_id, 0.5) + + def test_roc_edges_3(self): + cols = ['tp', 'fp', 'fn', 'perc_id', 'ppv', 'fnr', 'fdr'] + exp_alignment = ( + 'F--GD--D--------QN-PYTESVDILEDLVIEFITEMTHKAMSI', + 'ISHLVIMHEEGEVDGKAIPDLTAPVSAVQAAVSNLVRVGKETVQTT', + ':22::22:22222222::2:::::::::::::::::::::::::::' + ) + res_alignment = ( + '-FG---D------D--QN-PYTESVDILEDLVIEFITEMTHKAMSI', + 'ISHLVIMHEEGEVDGKAIPDLTAPVSAVQAAVSNLVRVGKETVQTT', + '2::222:222222:22::2:::::::::::::::::::::::::::' + ) + exp_edges = states2edges( + list(map(tmstate_f, exp_alignment[2]))) + res_edges = states2edges( + list(map(tmstate_f, res_alignment[2]))) + res = pd.Series(roc_edges(exp_edges, res_edges), index=cols) + self.assertGreater(res.tp, 20) + self.assertGreater(res.perc_id, 0.5) + + +if __name__ == '__main__': + unittest.main() diff --git a/deepblast/trainer.py b/deepblast/trainer.py index ead2fc7..e5640c0 100644 --- a/deepblast/trainer.py +++ b/deepblast/trainer.py @@ -13,10 +13,13 @@ from deepblast.dataset.alphabet import UniprotTokenizer from deepblast.dataset import TMAlignDataset from deepblast.dataset.utils import ( - decode, states2edges, collate_f, unpack_sequences, pack_sequences) + decode, states2edges, collate_f, unpack_sequences, + pack_sequences, revstate_f) from deepblast.losses import ( - SoftAlignmentLoss, SoftPathLoss, MatrixCrossEntropy) + SoftAlignmentLoss, SoftPathLoss, MatrixCrossEntropy, + L2MatrixCrossEntropy) from deepblast.score import roc_edges, alignment_visualization, alignment_text +import warnings class LightningAligner(pl.LightningModule): @@ -28,6 +31,8 @@ def __init__(self, args): self.initialize_aligner() if self.hparams.loss == 'sse': self.loss_func = SoftAlignmentLoss() + elif self.hparams.loss == 'l2_cross_entropy': + self.loss_func = L2MatrixCrossEntropy() elif self.hparams.loss == 'cross_entropy': self.loss_func = MatrixCrossEntropy() elif self.hparams.loss == 'path': @@ -41,15 +46,31 @@ def initialize_aligner(self): n_input = self.hparams.rnn_input_dim n_units = self.hparams.rnn_dim n_layers = self.hparams.layers - if self.hparams.aligner == 'nw': - self.aligner = NeedlemanWunschAligner( - n_alpha, n_input, n_units, n_embed, n_layers) - else: - raise NotImplementedError( - f'Aligner {self.hparams.aligner_type} not implemented.') + n_heads = self.hparams.heads + self.aligner = NeedlemanWunschAligner( + n_alpha, n_input, n_units, n_embed, n_layers, n_heads) + def forward(self, x, y): - return self.aligner.forward(x, y) + x_code = torch.Tensor(self.tokenizer(str.encode(x))).long() + y_code = torch.Tensor(self.tokenizer(str.encode(y))).long() + x_code = x_code.to(self.device) + y_code = y_code.to(self.device) + seq, order = pack_sequences([x_code], [y_code]) + A, theta, gap = self.aligner(seq, order) + return A, theta, gap + + def align(self, x, y): + x_code = torch.Tensor(self.tokenizer(str.encode(x))).long() + y_code = torch.Tensor(self.tokenizer(str.encode(y))).long() + x_code = x_code.to(self.device) + y_code = y_code.to(self.device) + seq, order = pack_sequences([x_code], [y_code]) + gen = self.aligner.traceback(seq, order) + decoded = next(gen) + pred_x, pred_y, pred_states = list(zip(*decoded)) + s = ''.join(list(map(revstate_f, pred_states))) + return s def initialize_logging(self, root_dir='./', logging_path=None): if logging_path is None: @@ -62,7 +83,7 @@ def initialize_logging(self, root_dir='./', logging_path=None): def train_dataloader(self): train_dataset = TMAlignDataset( - self.hparams.train_pairs, + self.hparams.train_pairs, clip_ends=self.hparams.clip_ends, construct_paths=isinstance(self.loss_func, SoftPathLoss)) train_dataloader = DataLoader( train_dataset, self.hparams.batch_size, collate_fn=collate_f, @@ -72,38 +93,42 @@ def train_dataloader(self): def val_dataloader(self): valid_dataset = TMAlignDataset( - self.hparams.valid_pairs, + self.hparams.valid_pairs, clip_ends=self.hparams.clip_ends, construct_paths=isinstance(self.loss_func, SoftPathLoss)) valid_dataloader = DataLoader( valid_dataset, self.hparams.batch_size, collate_fn=collate_f, - shuffle=False, num_workers=self.hparams.num_workers, + shuffle=True, num_workers=self.hparams.num_workers, pin_memory=True) return valid_dataloader def test_dataloader(self): + # Held-out TM-align dataset test_dataset = TMAlignDataset( - self.hparams.test_pairs, + self.hparams.test_pairs, clip_ends=self.hparams.clip_ends, construct_paths=isinstance(self.loss_func, SoftPathLoss)) test_dataloader = DataLoader( - test_dataset, self.hparams.batch_size, shuffle=False, + test_dataset, self.hparams.batch_size, shuffle=True, collate_fn=collate_f, num_workers=self.hparams.num_workers, pin_memory=True) return test_dataloader - def compute_loss(self, x, y, predA, A, P, theta): - + def compute_loss(self, mask, predA, A, P, theta, gap): + x_mask, y_mask = mask if isinstance(self.loss_func, SoftAlignmentLoss): - loss = self.loss_func(A, predA, x, y) + loss = self.loss_func(A, predA, x_mask, y_mask) elif isinstance(self.loss_func, MatrixCrossEntropy): - loss = self.loss_func(A, predA, x, y) + loss = self.loss_func(A, predA, x_mask, y_mask) + elif isinstance(self.loss_func, L2MatrixCrossEntropy): + loss = self.loss_func(A, predA, theta, gap, x_mask, y_mask) elif isinstance(self.loss_func, SoftPathLoss): - loss = self.loss_func(P, predA, x, y) + loss = self.loss_func(P, predA, x_mask, y_mask) if self.hparams.multitask: current_lr = self.trainer.lr_schedulers[0]['scheduler'] current_lr = current_lr.get_last_lr()[0] max_lr = self.hparams.learning_rate lam = current_lr / max_lr - match_loss = self.loss_func(torch.sigmoid(theta), predA, x, y) + match_loss = self.loss_func(torch.sigmoid(theta), predA, + x_mask, y_mask) # when learning rate is large, weight match loss # otherwise, weight towards DP loss = lam * match_loss + (1 - lam) * loss @@ -111,17 +136,18 @@ def compute_loss(self, x, y, predA, A, P, theta): def training_step(self, batch, batch_idx): self.aligner.train() - genes, others, s, A, P = batch + genes, others, s, A, P, mask = batch seq, order = pack_sequences(genes, others) predA, theta, gap = self.aligner(seq, order) - _, xlen, _, ylen = unpack_sequences(seq, order) - loss = self.compute_loss(xlen, ylen, predA, A, P, theta) + x_mask, y_mask = mask + loss = self.compute_loss(mask, predA, A, P, theta, gap) assert torch.isnan(loss).item() is False if len(self.trainer.lr_schedulers) >= 1: current_lr = self.trainer.lr_schedulers[0]['scheduler'] current_lr = current_lr.get_last_lr()[0] else: current_lr = self.hparams.learning_rate + tensorboard_logs = {'train_loss': loss, 'lr': current_lr} # log the learning rate return {'loss': loss, 'log': tensorboard_logs} @@ -143,6 +169,11 @@ def validation_stats(self, x, y, xlen, ylen, gen, truth_states = states[b].cpu().detach().numpy() pred_edges = states2edges(pred_states) true_edges = states2edges(truth_states) + if len(pred_edges) == 0: + raise ValueError('No predicted edges', pred_states) + if len(true_edges) == 0: + raise ValueError('No truth edges', truth_states) + stats = roc_edges(true_edges, pred_edges) if random.random() < self.hparams.visualization_fraction: Av = A[b].cpu().detach().numpy().squeeze() @@ -169,17 +200,14 @@ def validation_stats(self, x, y, xlen, ylen, gen, return statistics def validation_step(self, batch, batch_idx): - # TODO: something weird is going on with the lengths - # Need to make sure that they are being sorted properly - genes, others, s, A, P = batch + genes, others, s, A, P, mask = batch seq, order = pack_sequences(genes, others) predA, theta, gap = self.aligner(seq, order) x, xlen, y, ylen = unpack_sequences(seq, order) - loss = self.compute_loss(xlen, ylen, predA, A, P, theta) + loss = self.compute_loss(mask, predA, A, P, theta, gap) assert torch.isnan(loss).item() is False # Obtain alignment statistics + visualizations gen = self.aligner.traceback(seq, order) - # TODO; compare the traceback and the forward statistics = self.validation_stats( x, y, xlen, ylen, gen, s, A, predA, theta, gap, batch_idx) statistics = pd.DataFrame( @@ -194,12 +222,26 @@ def validation_step(self, batch, batch_idx): return {'validation_loss': loss, 'log': tensorboard_logs} - def test_step(self, batch, batch_idx): - pass + def custom_parameter_histogram(self): + # iterating through all parameters + for name, params in self.named_parameters(): + if params.requires_grad and (params.grad is not None): + self.logger.experiment.add_histogram( + f'{name}/value', params, self.global_step) + + # def on_after_backward(self): + # # example to inspect gradient information in tensorboard + # if self.trainer.global_step % 200 == 0: # don't make the tf file huge + # for name, params in self.named_parameters(): + # if params.requires_grad and (params.grad is not None): + # self.logger.experiment.add_histogram( + # f'{name}/grad', params.grad, self.global_step) def validation_epoch_end(self, outputs): loss_f = lambda x: x['validation_loss'] losses = list(map(loss_f, outputs)) + if len(losses) == 0: + raise ValueError('No losses reported', output) loss = sum(losses) / len(losses) self.logger.experiment.add_scalar('val_loss', loss, self.global_step) metrics = ['val_tp', 'val_fp', 'val_fn', 'val_perc_id', @@ -207,16 +249,24 @@ def validation_epoch_end(self, outputs): scores = [] for i, m in enumerate(metrics): loss_f = lambda x: x['log'][m] - losses = list(map(loss_f, outputs)) - scalar = sum(losses) / len(losses) - scores.append(scalar) - self.logger.experiment.add_scalar(m, scalar, self.global_step) + losses = np.array(list(map(loss_f, outputs))) + losses = losses[np.logical_not(np.isnan(losses))] + if len(losses) > 0: + scalar = sum(losses) / len(losses) + scores.append(scalar) + self.logger.experiment.add_scalar(m, scalar, self.global_step) + else: + warnings.warn(f'No losses reported for {m}.', RuntimeWarning) + self.custom_parameter_histogram() tensorboard_logs = dict( [('val_loss', loss)] + list(zip(metrics, scores)) ) return {'val_loss': loss, 'log': tensorboard_logs} + def test_step(self, batch, batch_idx): + pass + def test_epoch_end(self, outputs): pass @@ -229,7 +279,7 @@ def configure_optimizers(self): grad_params, lr=self.hparams.learning_rate) if self.hparams.scheduler == 'cosine_restarts': scheduler = CosineAnnealingWarmRestarts( - optimizer, T_0=1, T_mult=2) + optimizer, T_0=1, T_mult=1) elif self.hparams.scheduler == 'cosine': scheduler = CosineAnnealingLR(optimizer, T_max=self.hparams.epochs) elif self.hparams.scheduler == 'triangular': @@ -246,6 +296,13 @@ def configure_optimizers(self): steps = int(np.log2(self.hparams.learning_rate / m)) steps = self.hparams.epochs // steps scheduler = StepLR(optimizer, step_size=steps, gamma=0.5) + elif self.hparams.scheduler == 'inv_steplr': + m = 1e-4 # maximum learning rate + optimizer = torch.optim.Adam( + grad_params, lr=m) + steps = int(np.log2(m / self.hparams.learning_rate)) + steps = self.hparams.epochs // steps + scheduler = StepLR(optimizer, step_size=steps, gamma=0.5) elif self.hparams.scheduler == 'none': return [optimizer] else: @@ -262,10 +319,6 @@ def add_model_specific_args(parent_parser): '--test-pairs', help='Testing pairs file', required=True) parser.add_argument( '--valid-pairs', help='Validation pairs file', required=True) - parser.add_argument( - '-a', '--aligner', - help='Aligner type. Choices include (nw, hmm).', - required=False, type=str, default='nw') parser.add_argument( '--embedding-dim', help='Embedding dimension (default 512).', required=False, type=int, default=512) @@ -278,9 +331,13 @@ def add_model_specific_args(parent_parser): parser.add_argument( '--layers', help='Number of RNN layers (default 2).', required=False, type=int, default=2) + parser.add_argument( + '--heads', help='Number heads in attention layer (default 8).', + required=False, type=int, default=8) parser.add_argument( '--loss', - help=('Loss function. Options include {sse, path, cross_entropy} ' + help=('Loss function. Options include ' + '{sse, path, cross_entropy, l2_cross_entropy} ' '(default cross_entropy)'), default='cross_entropy', required=False, type=str) parser.add_argument( diff --git a/ipynb/simulation-benchmark.ipynb b/ipynb/simulation-benchmark.ipynb index 9b986e8..583db3f 100644 --- a/ipynb/simulation-benchmark.ipynb +++ b/ipynb/simulation-benchmark.ipynb @@ -19,6 +19,26 @@ "import numpy as np" ] }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Beta-lactamase.hmm tm-align-0.9-30.tab tm_align_output_10k.ali\r\n", + "I-set.hmm\t tm-align-0.9-50.tab tm_align_output_10k.tab\r\n", + "PPR_2.hmm\t tm-align-0.9.tab\t zf-C2H2.hmm\r\n", + "tm-align-0.9-10.tab tm-align-0.9.txt\r\n" + ] + } + ], + "source": [ + "!ls ../data" + ] + }, { "cell_type": "markdown", "metadata": {}, @@ -29,11 +49,11 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 3, "metadata": {}, "outputs": [], "source": [ - "hmm = '../data/zf-C2H2.hmm'\n", + "hmm = '../data/Beta-lactamase.hmm'\n", "n_alignments = 100\n", "np.random.seed(0)\n", "align_df = hmm_alignments(n=40, seed=0, n_alignments=n_alignments, hmmfile=hmm)\n", @@ -54,7 +74,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 4, "metadata": {}, "outputs": [], "source": [ @@ -78,7 +98,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 5, "metadata": {}, "outputs": [], "source": [ @@ -92,7 +112,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 6, "metadata": {}, "outputs": [ { @@ -101,7 +121,7 @@ "'/home/juermieboop/Documents/research/garfunkel/ipynb'" ] }, - "execution_count": 5, + "execution_count": 6, "metadata": {}, "output_type": "execute_result" } @@ -120,56 +140,25 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 7, "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "usage: ipykernel_launcher.py [-h] --train-pairs TRAIN_PAIRS --test-pairs TEST_PAIRS --valid-pairs VALID_PAIRS [-a ALIGNER]\n", - " [--embedding-dim EMBEDDING_DIM] [--rnn-input-dim RNN_INPUT_DIM] [--rnn-dim RNN_DIM] [--layers LAYERS]\n", - " [--loss LOSS] [--learning-rate LEARNING_RATE] [--batch-size BATCH_SIZE] [--multitask MULTITASK]\n", - " [--finetune FINETUNE] [--clip-ends CLIP_ENDS] [--scheduler SCHEDULER] [--epochs EPOCHS]\n", - " [--visualization-fraction VISUALIZATION_FRACTION] -o OUTPUT_DIRECTORY [--num-workers NUM_WORKERS]\n", - " [--gpus GPUS]\n", - "ipykernel_launcher.py: error: unrecognized arguments: --load-from-checkpoint lightning_logs/version_5/checkpoints\n" - ] - }, - { - "ename": "SystemExit", - "evalue": "2", - "output_type": "error", - "traceback": [ - "An exception has occurred, use %tb to see the full traceback.\n", - "\u001b[0;31mSystemExit\u001b[0m\u001b[0;31m:\u001b[0m 2\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/home/juermieboop/miniconda3/envs/pytorch/lib/python3.8/site-packages/IPython/core/interactiveshell.py:3339: UserWarning: To exit: use 'exit', 'quit', or Ctrl-D.\n", - " warn(\"To exit: use 'exit', 'quit', or Ctrl-D.\", stacklevel=1)\n" - ] - } - ], + "outputs": [], "source": [ "args = [\n", " '--train-pairs', f'{os.getcwd()}/data/train.txt',\n", " '--test-pairs', f'{os.getcwd()}/data/test.txt',\n", - " '--valid-pairs', f'{os.getcwd()}/data/valid.txt',\n", + " '--valid-pairs', f'{os.getcwd()}/data/train.txt',\n", " '--output-directory', output_dir,\n", - " '--epochs', '32',\n", + " '--epochs', '128',\n", " '--batch-size', '20', \n", " '--num-workers', '30',\n", - " '--learning-rate', '1e-4',\n", - " '--layers', '2',\n", + " '--learning-rate', '1e-4', \n", + " '--layers', '4',\n", + " '--heads', '4',\n", " '--visualization-fraction', '1',\n", - " '--loss', 'cross_entropy',\n", - " '--scheduler', 'none',\n", - " '--gpus', '1',\n", - " '--load-from-checkpoint', 'lightning_logs/version_5/checkpoints'\n", + " '--loss', 'l2_cross_entropy',\n", + " '--scheduler', 'steplr',\n", + " '--gpus', '1'\n", "]\n", "parser = argparse.ArgumentParser(add_help=False)\n", "parser = LightningAligner.add_model_specific_args(parser)\n", @@ -180,9 +169,17 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 8, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "No traceback available to show.\n" + ] + } + ], "source": [ "%tb" ] @@ -196,9 +193,20 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 9, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/juermieboop/Documents/research/garfunkel/deepblast/embedding.py:9: UserWarning: nn.init.xavier_uniform is now deprecated in favor of nn.init.xavier_uniform_.\n", + " nn.init.xavier_uniform(m.weight)\n", + "/home/juermieboop/Documents/research/garfunkel/deepblast/embedding.py:36: UserWarning: nn.init.xavier_uniform is now deprecated in favor of nn.init.xavier_uniform_.\n", + " nn.init.xavier_uniform(self.linear.weight)\n" + ] + } + ], "source": [ "model = LightningAligner(args)" ] @@ -212,18 +220,215 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 10, "metadata": { "scrolled": false }, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "GPU available: True, used: True\n", + "TPU available: False, using: 0 TPU cores\n", + "CUDA_VISIBLE_DEVICES: [0]\n", + "\n", + " | Name | Type | Params\n", + "---------------------------------------------------\n", + "0 | aligner | NeedlemanWunschAligner | 48 M \n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validation sanity check', layout=Layout…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "a15504f7cc40479f86df76fd73aed1e2", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/juermieboop/miniconda3/envs/pytorch/lib/python3.8/site-packages/pytorch_lightning/utilities/distributed.py:25: UserWarning: Detected KeyboardInterrupt, attempting graceful shutdown...\n", + " warnings.warn(*args, **kwargs)\n" + ] + }, + { + "data": { + "text/plain": [ + "1" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ + "from pytorch_lightning.profiler import AdvancedProfiler\n", + "profiler=AdvancedProfiler()\n", "trainer = Trainer(\n", " max_epochs=args.epochs,\n", " gpus=args.gpus,\n", " check_val_every_n_epoch=10,\n", + " gradient_clip_val=10,\n", + " # val_percent_check=0.25 \n", " # profiler=profiler,\n", - " fast_dev_run=False,\n", + " # fast_dev_run=True,\n", " # auto_scale_batch_size='power'\n", ")\n", "\n", @@ -239,16 +444,26 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 11, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "version_0 version_11\tversion_14 version_4 version_7\r\n", + "version_1 version_12\tversion_2 version_5 version_8\r\n", + "version_10 version_13\tversion_3 version_6 version_9\r\n" + ] + } + ], "source": [ "!ls lightning_logs" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 12, "metadata": {}, "outputs": [], "source": [ @@ -257,11 +472,44 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 13, "metadata": { "scrolled": false }, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "Reusing TensorBoard on port 6006 (pid 3827), started 7:41:35 ago. (Use '!kill 3827' to kill it.)" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n", + " \n", + " \n", + " " + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], "source": [ "%tensorboard --logdir lightning_logs" ] @@ -275,27 +523,102 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 14, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "NeedlemanWunschAligner(\n", + " (lm): BiLM(\n", + " (embed): Embedding(22, 21, padding_idx=21)\n", + " (dropout): Dropout(p=0, inplace=False)\n", + " (rnn): ModuleList(\n", + " (0): LSTM(21, 1024, batch_first=True)\n", + " (1): LSTM(1024, 1024, batch_first=True)\n", + " )\n", + " (linear): Linear(in_features=1024, out_features=21, bias=True)\n", + " )\n", + " (match_embedding): StackedRNN(\n", + " (embed): Embedding(21, 512, padding_idx=20)\n", + " (dropout): Dropout(p=0, inplace=False)\n", + " (rnn): GRU(512, 512, num_layers=4, batch_first=True, bidirectional=True)\n", + " (proj): Linear(in_features=1024, out_features=512, bias=True)\n", + " )\n", + " (gap_embedding): StackedRNN(\n", + " (embed): Embedding(21, 512, padding_idx=20)\n", + " (dropout): Dropout(p=0, inplace=False)\n", + " (rnn): GRU(512, 512, num_layers=4, batch_first=True, bidirectional=True)\n", + " (proj): Linear(in_features=1024, out_features=512, bias=True)\n", + " )\n", + " (match_mixture): MultiheadProduct(\n", + " (multilinear): MultiLinear(\n", + " (multi_output): ModuleList(\n", + " (0): Linear(in_features=512, out_features=512, bias=True)\n", + " )\n", + " )\n", + " (linear): Linear(in_features=1, out_features=1, bias=True)\n", + " )\n", + " (gap_mixture): MultiheadProduct(\n", + " (multilinear): MultiLinear(\n", + " (multi_output): ModuleList(\n", + " (0): Linear(in_features=512, out_features=512, bias=True)\n", + " )\n", + " )\n", + " (linear): Linear(in_features=1, out_features=1, bias=True)\n", + " )\n", + " (nw): NeedlemanWunschDecoder()\n", + ")" + ] + }, + "execution_count": 14, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "model.aligner" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 15, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "'epoch=119.ckpt'\r\n" + ] + } + ], "source": [ "!ls lightning_logs/version_5/checkpoints" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 16, "metadata": {}, - "outputs": [], + "outputs": [ + { + "ename": "FileNotFoundError", + "evalue": "[Errno 2] No such file or directory: 'lightning_logs/version_70/checkpoints/epoch=59.ckpt'", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mFileNotFoundError\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[0mcheckpoint_dir\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m'lightning_logs/version_70/checkpoints'\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0mpath\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34mf'{checkpoint_dir}/epoch=59.ckpt'\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 4\u001b[0;31m \u001b[0mmodel\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mLightningAligner\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mload_from_checkpoint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mpath\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", + "\u001b[0;32m~/miniconda3/envs/pytorch/lib/python3.8/site-packages/pytorch_lightning/core/saving.py\u001b[0m in \u001b[0;36mload_from_checkpoint\u001b[0;34m(cls, checkpoint_path, map_location, hparams_file, tags_csv, *args, **kwargs)\u001b[0m\n\u001b[1;32m 140\u001b[0m \u001b[0mcheckpoint\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mpl_load\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcheckpoint_path\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmap_location\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mmap_location\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 141\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 142\u001b[0;31m \u001b[0mcheckpoint\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mpl_load\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcheckpoint_path\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmap_location\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mlambda\u001b[0m \u001b[0mstorage\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mloc\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mstorage\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 143\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 144\u001b[0m \u001b[0;31m# add the hparams from csv file to checkpoint\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/miniconda3/envs/pytorch/lib/python3.8/site-packages/pytorch_lightning/utilities/cloud_io.py\u001b[0m in \u001b[0;36mload\u001b[0;34m(path_or_url, map_location)\u001b[0m\n\u001b[1;32m 7\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mload\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mpath_or_url\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mstr\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmap_location\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mNone\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 8\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0murlparse\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mpath_or_url\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mscheme\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;34m''\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0mPath\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mpath_or_url\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdrive\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;31m# no scheme or with a drive letter\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 9\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mload\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mpath_or_url\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmap_location\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mmap_location\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 10\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 11\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mhub\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mload_state_dict_from_url\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mpath_or_url\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmap_location\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mmap_location\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/miniconda3/envs/pytorch/lib/python3.8/site-packages/torch/serialization.py\u001b[0m in \u001b[0;36mload\u001b[0;34m(f, map_location, pickle_module, **pickle_load_args)\u001b[0m\n\u001b[1;32m 523\u001b[0m \u001b[0mpickle_load_args\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'encoding'\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m'utf-8'\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 524\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 525\u001b[0;31m \u001b[0;32mwith\u001b[0m \u001b[0m_open_file_like\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mf\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'rb'\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mopened_file\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 526\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0m_is_zipfile\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mopened_file\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 527\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0m_open_zipfile_reader\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mf\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mopened_zipfile\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/miniconda3/envs/pytorch/lib/python3.8/site-packages/torch/serialization.py\u001b[0m in \u001b[0;36m_open_file_like\u001b[0;34m(name_or_buffer, mode)\u001b[0m\n\u001b[1;32m 210\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_open_file_like\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mname_or_buffer\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmode\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 211\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0m_is_path\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mname_or_buffer\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 212\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0m_open_file\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mname_or_buffer\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmode\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 213\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 214\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0;34m'w'\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mmode\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/miniconda3/envs/pytorch/lib/python3.8/site-packages/torch/serialization.py\u001b[0m in \u001b[0;36m__init__\u001b[0;34m(self, name, mode)\u001b[0m\n\u001b[1;32m 191\u001b[0m \u001b[0;32mclass\u001b[0m \u001b[0m_open_file\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0m_opener\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 192\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m__init__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mname\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmode\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 193\u001b[0;31m \u001b[0msuper\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0m_open_file\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__init__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mopen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mname\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmode\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 194\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 195\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m__exit__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mFileNotFoundError\u001b[0m: [Errno 2] No such file or directory: 'lightning_logs/version_70/checkpoints/epoch=59.ckpt'" + ] + } + ], "source": [ "from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint\n", "checkpoint_dir = 'lightning_logs/version_70/checkpoints'\n", diff --git a/ipynb/struct-benchmark.ipynb b/ipynb/struct-benchmark.ipynb index dea9251..b4cee16 100644 --- a/ipynb/struct-benchmark.ipynb +++ b/ipynb/struct-benchmark.ipynb @@ -9,7 +9,7 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": 10, "metadata": {}, "outputs": [], "source": [ @@ -29,7 +29,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 11, "metadata": {}, "outputs": [], "source": [ @@ -52,7 +52,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 12, "metadata": {}, "outputs": [], "source": [ @@ -76,7 +76,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 13, "metadata": {}, "outputs": [], "source": [ @@ -90,7 +90,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 14, "metadata": {}, "outputs": [ { @@ -99,7 +99,7 @@ "'/home/juermieboop/Documents/research/garfunkel/ipynb'" ] }, - "execution_count": 5, + "execution_count": 14, "metadata": {}, "output_type": "execute_result" } @@ -118,7 +118,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 15, "metadata": {}, "outputs": [], "source": [ @@ -127,15 +127,16 @@ " '--test-pairs', f'{os.getcwd()}/data/test.txt',\n", " '--valid-pairs', f'{os.getcwd()}/data/valid.txt',\n", " '--output-directory', output_dir,\n", - " '--epochs', '128',\n", + " '--epochs', '32',\n", " '--batch-size', '20', \n", " '--num-workers', '30',\n", - " '--layers', '2',\n", + " '--layers', '4',\n", + " '--heads', '8',\n", " '--learning-rate', '5e-5',\n", " '--visualization-fraction', '1',\n", " '--loss', 'cross_entropy',\n", - " '--scheduler', 'steplr', \n", - " '--clip-ends', 'True',\n", + " '--scheduler', 'inv_steplr', \n", + " # '--clip-ends', 'True',\n", " '--gpus', '1'\n", "]" ] @@ -149,7 +150,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 16, "metadata": {}, "outputs": [], "source": [ @@ -161,6 +162,26 @@ "model = LightningAligner(args)" ] }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Namespace(batch_size=20, clip_ends=False, embedding_dim=512, epochs=32, finetune=False, gpus=1, heads=8, layers=4, learning_rate=5e-05, loss='cross_entropy', multitask=False, num_workers=30, output_directory='struct_results', rnn_dim=512, rnn_input_dim=512, scheduler='inv_steplr', test_pairs='/home/juermieboop/Documents/research/garfunkel/ipynb/data/test.txt', train_pairs='/home/juermieboop/Documents/research/garfunkel/ipynb/data/train.txt', valid_pairs='/home/juermieboop/Documents/research/garfunkel/ipynb/data/valid.txt', visualization_fraction=1.0)" + ] + }, + "execution_count": 17, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "args" + ] + }, { "cell_type": "markdown", "metadata": {}, @@ -170,9 +191,9 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 18, "metadata": { - "scrolled": false + "scrolled": true }, "outputs": [ { @@ -185,13 +206,13 @@ "\n", " | Name | Type | Params\n", "---------------------------------------------------\n", - "0 | aligner | NeedlemanWunschAligner | 34 M \n" + "0 | aligner | NeedlemanWunschAligner | 52 M \n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "", + "model_id": "c45254eb1e5f42e6acd2b65d4a560fa1", "version_major": 2, "version_minor": 0 }, @@ -203,69 +224,38 @@ "output_type": "display_data" }, { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "ce3249fdf417451c896ecfa630cf1943", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "cc3be93a41df413195269996a5009591", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…" - ] - }, - "metadata": {}, - "output_type": "display_data" + "ename": "RuntimeError", + "evalue": "CUDA out of memory. Tried to allocate 328.00 MiB (GPU 0; 10.92 GiB total capacity; 4.26 GiB already allocated; 115.12 MiB free; 4.55 GiB reserved in total by PyTorch)", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 11\u001b[0m )\n\u001b[1;32m 12\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 13\u001b[0;31m \u001b[0mtrainer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", + "\u001b[0;32m~/miniconda3/envs/pytorch/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py\u001b[0m in \u001b[0;36mfit\u001b[0;34m(self, model, train_dataloader, val_dataloaders)\u001b[0m\n\u001b[1;32m 977\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 978\u001b[0m \u001b[0;32melif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msingle_gpu\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 979\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msingle_gpu_train\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 980\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 981\u001b[0m \u001b[0;32melif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0muse_tpu\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;31m# pragma: no-cover\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/miniconda3/envs/pytorch/lib/python3.8/site-packages/pytorch_lightning/trainer/distrib_parts.py\u001b[0m in \u001b[0;36msingle_gpu_train\u001b[0;34m(self, model)\u001b[0m\n\u001b[1;32m 183\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mreinit_scheduler_properties\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0moptimizers\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlr_schedulers\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 184\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 185\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrun_pretrain_routine\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 186\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 187\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mtpu_train\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtpu_core_idx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/miniconda3/envs/pytorch/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py\u001b[0m in \u001b[0;36mrun_pretrain_routine\u001b[0;34m(self, model)\u001b[0m\n\u001b[1;32m 1134\u001b[0m \u001b[0mnum_loaders\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mval_dataloaders\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1135\u001b[0m \u001b[0mmax_batches\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnum_sanity_val_steps\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m*\u001b[0m \u001b[0mnum_loaders\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1136\u001b[0;31m eval_results = self._evaluate(model,\n\u001b[0m\u001b[1;32m 1137\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mval_dataloaders\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1138\u001b[0m \u001b[0mmax_batches\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/miniconda3/envs/pytorch/lib/python3.8/site-packages/pytorch_lightning/trainer/evaluation_loop.py\u001b[0m in \u001b[0;36m_evaluate\u001b[0;34m(self, model, dataloaders, max_batches, test_mode)\u001b[0m\n\u001b[1;32m 291\u001b[0m \u001b[0moutput\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mevaluation_forward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbatch\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbatch_idx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdataloader_idx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtest_mode\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 292\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 293\u001b[0;31m \u001b[0moutput\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mevaluation_forward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbatch\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbatch_idx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdataloader_idx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtest_mode\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 294\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 295\u001b[0m \u001b[0;31m# on dp / ddp2 might still want to do something with the batch parts\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/miniconda3/envs/pytorch/lib/python3.8/site-packages/pytorch_lightning/trainer/evaluation_loop.py\u001b[0m in \u001b[0;36mevaluation_forward\u001b[0;34m(self, model, batch, batch_idx, dataloader_idx, test_mode)\u001b[0m\n\u001b[1;32m 483\u001b[0m \u001b[0moutput\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtest_step\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 484\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 485\u001b[0;31m \u001b[0moutput\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalidation_step\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 486\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 487\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0moutput\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/Documents/research/garfunkel/deepblast/trainer.py\u001b[0m in \u001b[0;36mvalidation_step\u001b[0;34m(self, batch, batch_idx)\u001b[0m\n\u001b[1;32m 203\u001b[0m \u001b[0mgenes\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mothers\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0ms\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mA\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mP\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmask\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mbatch\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 204\u001b[0m \u001b[0mseq\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0morder\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mpack_sequences\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mgenes\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mothers\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 205\u001b[0;31m \u001b[0mpredA\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtheta\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgap\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0maligner\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mseq\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0morder\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 206\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mxlen\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mylen\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0munpack_sequences\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mseq\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0morder\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 207\u001b[0m \u001b[0mloss\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcompute_loss\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmask\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mpredA\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mA\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mP\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtheta\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgap\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/miniconda3/envs/pytorch/lib/python3.8/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 530\u001b[0m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_slow_forward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 531\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 532\u001b[0;31m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 533\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mhook\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_forward_hooks\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 534\u001b[0m \u001b[0mhook_result\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mhook\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mresult\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/Documents/research/garfunkel/deepblast/alignment.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, x, order)\u001b[0m\n\u001b[1;32m 92\u001b[0m \u001b[0;31m# Obtain theta through an inner product across latent dimensions\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 93\u001b[0m \u001b[0mtheta\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmatch_mixture\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mzx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mzy\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 94\u001b[0;31m \u001b[0mgap\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mgap_mixture\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mgx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgy\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 95\u001b[0m \u001b[0;31m#G = self.gap_mixture(gx, gy)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 96\u001b[0m \u001b[0;31m# zero out first row and first column for local alignments\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/miniconda3/envs/pytorch/lib/python3.8/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 530\u001b[0m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_slow_forward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 531\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 532\u001b[0;31m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 533\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mhook\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_forward_hooks\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 534\u001b[0m \u001b[0mhook_result\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mhook\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mresult\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/Documents/research/garfunkel/deepblast/embedding.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, x, y)\u001b[0m\n\u001b[1;32m 39\u001b[0m \u001b[0mzx\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmultilinear\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 40\u001b[0m \u001b[0mzy\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmultilinear\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0my\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 41\u001b[0;31m \u001b[0mdists\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0meinsum\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'bidh,bjdh->bijh'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mzx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mzy\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 42\u001b[0m \u001b[0moutput\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlinear\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdists\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 43\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0moutput\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msqueeze\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/miniconda3/envs/pytorch/lib/python3.8/site-packages/torch/functional.py\u001b[0m in \u001b[0;36meinsum\u001b[0;34m(equation, *operands)\u001b[0m\n\u001b[1;32m 239\u001b[0m \u001b[0;31m# the old interface of passing the operands as one list argument\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 240\u001b[0m \u001b[0moperands\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0moperands\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 241\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_C\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_VariableFunctions\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0meinsum\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mequation\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moperands\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 242\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 243\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mRuntimeError\u001b[0m: CUDA out of memory. Tried to allocate 328.00 MiB (GPU 0; 10.92 GiB total capacity; 4.26 GiB already allocated; 115.12 MiB free; 4.55 GiB reserved in total by PyTorch)" + ] } ], "source": [ + "from pytorch_lightning.profiler import AdvancedProfiler\n", + "profiler=AdvancedProfiler()\n", "trainer = Trainer(\n", " max_epochs=args.epochs,\n", " gpus=args.gpus,\n", - " check_val_every_n_epoch=8,\n", - " # profiler=profiler,\n", - " # fast_dev_run=True,\n", + " check_val_every_n_epoch=1,\n", + " val_percent_check=0.1\n", + " #profiler=profiler,\n", + " #fast_dev_run=True,\n", " # auto_scale_batch_size='power'\n", ")\n", "\n", @@ -290,7 +280,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 1, "metadata": {}, "outputs": [], "source": [ @@ -299,11 +289,44 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 2, "metadata": { "scrolled": false }, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "Reusing TensorBoard on port 6006 (pid 3827), started 5:48:33 ago. (Use '!kill 3827' to kill it.)" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n", + " \n", + " \n", + " " + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], "source": [ "%tensorboard --logdir lightning_logs" ] diff --git a/scripts/deepblast-train b/scripts/deepblast-train index 85c7b06..6d3f15c 100644 --- a/scripts/deepblast-train +++ b/scripts/deepblast-train @@ -24,12 +24,16 @@ def main(args): num_nodes=args.nodes, accumulate_grad_batches=args.grad_accum, gradient_clip_val=args.grad_clip, - distributed_backend=args.backend, precision=args.precision, - # check_val_every_n_epoch=1, + #check_val_every_n_epoch=1, + limit_train_batches=0.1, + limit_val_batches=0.2, + limit_test_batches=0.2, val_check_interval=0.25, fast_dev_run=False, + # overfit the data + # overfit_pct=0.01, # auto_scale_batch_size='power', # profiler=profiler, ) @@ -44,7 +48,7 @@ def main(args): # initialize Model Checkpoint Saver checkpoint_callback = ModelCheckpoint( filepath=ckpt_path, - period=1, + save_top_k=1, monitor='validation_loss', mode='min', verbose=True @@ -63,7 +67,7 @@ if __name__ == '__main__': parser = argparse.ArgumentParser(add_help=False) parser.add_argument('--gpus', type=int, default=None) parser.add_argument('--grad-accum', type=int, default=1) - parser.add_argument('--grad-clip', type=int, default=0) + parser.add_argument('--grad-clip', type=int, default=10) parser.add_argument('--nodes', type=int, default=1) parser.add_argument('--num-workers', type=int, default=1) parser.add_argument('--precision', type=int, default=32)