
Adding Priors #68

Closed
wants to merge 31 commits
Commits (31)
eb4003b
adding methods for prediction / evaluation
mortonjt Aug 14, 2020
087aa0b
Addressing issue #65
mortonjt Aug 18, 2020
a372784
WIP: adding priors, multilinear layer and mask
mortonjt Aug 19, 2020
2ac1d61
adding embedding tests
mortonjt Aug 19, 2020
1aaa91e
adding multihead product. zeroing out first gap row/column for local …
mortonjt Aug 19, 2020
efdcbd3
updating penalized cross entropy
mortonjt Aug 19, 2020
dae23c4
adjusting gap mask
mortonjt Aug 19, 2020
e8c8a48
flake8
mortonjt Aug 19, 2020
6b38107
more flake8
mortonjt Aug 19, 2020
81b7ee1
fix dataset tests
mortonjt Aug 19, 2020
2d828de
fixing tests in alignment
mortonjt Aug 19, 2020
0374422
fixing dimensionality issues
mortonjt Aug 19, 2020
9ae419d
fixing gap indexing error
mortonjt Aug 19, 2020
82ca197
making priors looser
mortonjt Aug 19, 2020
e752757
remove local alignment in traceback. adding clip-ends option.
mortonjt Aug 22, 2020
1475d60
fixing scalar issue in validation
mortonjt Aug 22, 2020
f23e513
fixed bug with gap_mask - the outputs were swapped before
mortonjt Aug 26, 2020
837082a
adding more asserts to hunt down validation TB issue
mortonjt Aug 28, 2020
ec392c3
adding more logging options
mortonjt Aug 29, 2020
6237622
adding batch normalization
mortonjt Aug 30, 2020
9439ccc
adding weight initialization
mortonjt Aug 30, 2020
514d0a7
removing batch norm
mortonjt Aug 30, 2020
cdecd87
bye bye batch norm
mortonjt Aug 30, 2020
7bd2f27
adding more heads
mortonjt Aug 30, 2020
9190411
adding more layers
mortonjt Aug 30, 2020
60cc74a
simplifying validation
mortonjt Aug 30, 2020
41e28e4
zeroing out gap edges
mortonjt Aug 30, 2020
8208bcf
Merge branch 'priors' of github.com:mortonjt/garfunkel into priors
mortonjt Aug 30, 2020
b5d3e7a
removing end gap zeroing
mortonjt Aug 30, 2020
2e5f326
fixing local alignments
mortonjt Aug 30, 2020
83ded77
Speeding up epochs for easier debugging
mortonjt Sep 1, 2020
65 changes: 49 additions & 16 deletions deepblast/alignment.py
@@ -1,16 +1,21 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from deepblast.language_model import BiLM, pretrained_language_models
from deepblast.nw_cuda import NeedlemanWunschDecoder as NWDecoderCUDA
from deepblast.embedding import StackedRNN, EmbedLinear
from deepblast.embedding import StackedRNN, EmbedLinear, MultiheadProduct
from deepblast.dataset.utils import unpack_sequences
import torch.nn.functional as F
import math


def swish(x):
return x * F.sigmoid(x)


class NeedlemanWunschAligner(nn.Module):

def __init__(self, n_alpha, n_input, n_units, n_embed,
n_layers=2, lm=None, device='gpu'):
n_layers=2, n_heads=16, lm=None, device='gpu', local=True):
""" NeedlemanWunsch Alignment model

Parameters
@@ -25,36 +30,48 @@ def __init__(self, n_alpha, n_input, n_units, n_embed,
Embedding dimension
n_layers : int
Number of RNN layers.
n_heads : int
Number of heads in multilinear layer.
lm : BiLM
Pretrained language model (optional)
padding_idx : int
Location of padding index in embedding (default -1)
transform : function
Activation function (default relu)
sparse : False?
local : bool
Specifies if local alignment should be performed on the traceback
"""
super(NeedlemanWunschAligner, self).__init__()
if lm is None:
path = pretrained_language_models['bilstm']
self.lm = BiLM()
self.lm.load_state_dict(torch.load(path))
self.lm.eval()
transform = swish
if n_layers > 1:
self.match_embedding = StackedRNN(
n_alpha, n_input, n_units, n_embed, n_layers, lm=lm)
n_alpha, n_input, n_units, n_embed, n_layers, lm=lm,
transform=swish, rnn_type='gru')
self.gap_embedding = StackedRNN(
n_alpha, n_input, n_units, n_embed, n_layers, lm=lm)
n_alpha, n_input, n_units, n_embed, n_layers, lm=lm,
transform=swish, rnn_type='gru')
else:
self.match_embedding = EmbedLinear(
n_alpha, n_input, n_embed, lm=lm)
n_alpha, n_input, n_embed, lm=lm,
transform=swish)
self.gap_embedding = EmbedLinear(
n_alpha, n_input, n_embed, lm=lm)
n_alpha, n_input, n_embed, lm=lm,
transform=swish)

self.match_mixture = MultiheadProduct(n_embed, n_embed, n_heads)
self.gap_mixture = MultiheadProduct(n_embed, n_embed, n_heads)
# TODO: make cpu compatible version
# if device == 'cpu':
# self.nw = NWDecoderCPU(operator='softmax')
# else:
self.nw = NWDecoderCUDA(operator='softmax')
self.local = local

def forward(self, x, order):
""" Generate alignment matrix.
@@ -72,10 +89,13 @@ def forward(self, x, order):
with torch.enable_grad():
zx, _, zy, _ = unpack_sequences(self.match_embedding(x), order)
gx, _, gy, _ = unpack_sequences(self.gap_embedding(x), order)

# Obtain theta through an inner product across latent dimensions
theta = F.softplus(torch.einsum('bid,bjd->bij', zx, zy))
A = F.logsigmoid(torch.einsum('bid,bjd->bij', gx, gy))
theta = self.match_mixture(zx, zy)
#A = self.gap_mixture(gx, gy)
# zero out first row and first column for local alignments
G = self.gap_mixture(gx, gy)
A = torch.zeros(G.shape).to(G.device)
A[:, 1:, 1:] += G[:, 1:, 1:]
aln = self.nw.decode(theta, A)
return aln, theta, A

@@ -84,13 +104,26 @@ def traceback(self, x, order):
with torch.enable_grad():
zx, _, zy, _ = unpack_sequences(self.match_embedding(x), order)
gx, xlen, gy, ylen = unpack_sequences(self.gap_embedding(x), order)
match = F.softplus(torch.einsum('bid,bjd->bij', zx, zy))
gap = F.logsigmoid(torch.einsum('bid,bjd->bij', gx, gy))
match = self.match_mixture(zx, zy)
# gap = self.gap_mixture(gx, gy)
A = self.gap_mixture(gx, gy)
gap = torch.zeros(A.shape).to(A.device)
gap[:, 1:, 1:] += A[:, 1:, 1:]

# zero out first row and first column for local alignments
# L = gx.shape[1]
# gap = torch.zeros((L, L))
# gap[1:, 1:] = self.gap_mixture(gx[:, 1:, :], gy[:, 1:, :])

B, _, _ = match.shape

for b in range(B):
aln = self.nw.decode(
match[b, :xlen[b], :ylen[b]].unsqueeze(0),
gap[b, :xlen[b], :ylen[b]].unsqueeze(0)
)
M = match[b, :xlen[b], :ylen[b]].unsqueeze(0)
G = gap[b, :xlen[b], :ylen[b]].unsqueeze(0)
# val = math.log(1 - (1/50)) # based on average insertion length
# if self.local:
# G[0, 0, :] = val
# G[0, :, 0] = val
aln = self.nw.decode(M, G)
decoded = self.nw.traceback(aln.squeeze())
yield decoded, aln
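
Note: the MultiheadProduct layer imported above lives in deepblast/embedding.py and is not part of this diff. A minimal sketch of the idea — several bilinear similarity heads whose scores are mixed into a single (batch, len_x, len_y) matrix — might look like the following; the class body here is an assumption for illustration, not the PR's actual implementation.

import torch
import torch.nn as nn


class MultiheadProduct(nn.Module):
    """Hypothetical sketch of a multi-head bilinear scoring layer."""

    def __init__(self, n_in, n_out, n_heads):
        super().__init__()
        # One linear projection per head for each side of the product.
        self.x_proj = nn.Linear(n_in, n_out * n_heads)
        self.y_proj = nn.Linear(n_in, n_out * n_heads)
        # Learnable mixture weights over the heads.
        self.mixture = nn.Linear(n_heads, 1)
        self.n_heads = n_heads
        self.n_out = n_out

    def forward(self, zx, zy):
        B, N, _ = zx.shape
        _, M, _ = zy.shape
        hx = self.x_proj(zx).view(B, N, self.n_heads, self.n_out)
        hy = self.y_proj(zy).view(B, M, self.n_heads, self.n_out)
        # Inner product across the latent dimension for every head: (B, N, M, H)
        scores = torch.einsum('bihd,bjhd->bijh', hx, hy)
        # Collapse the head dimension into a single (B, N, M) score matrix.
        return self.mixture(scores).squeeze(-1)

With a layer of this shape, self.match_mixture(zx, zy) in forward() replaces the single softplus inner product that previously produced theta.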
6 changes: 6 additions & 0 deletions deepblast/constants.py
@@ -1 +1,7 @@
x, m, y = 0, 1, 2 # state numberings

match_mean = 2
match_std = 10

gap_mean = -4
gap_std = 10
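
These new constants read as Gaussian prior parameters for the match and gap score distributions (see the "making priors looser" and "updating penalized cross entropy" commits). A hedged sketch of how such priors could be turned into a regularization term — the helper name prior_penalty and the equal weighting of the two terms are assumptions, not code from this PR — is:

from deepblast.constants import match_mean, match_std, gap_mean, gap_std


def prior_penalty(theta, A):
    """Hypothetical sketch: Gaussian (L2) penalty pulling the predicted
    match scores toward N(match_mean, match_std**2) and the gap scores
    toward N(gap_mean, gap_std**2). The PR's penalized cross entropy may
    combine these terms differently."""
    match_term = ((theta - match_mean) ** 2 / (2 * match_std ** 2)).mean()
    gap_term = ((A - gap_mean) ** 2 / (2 * gap_std ** 2)).mean()
    return match_term + gap_term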
33 changes: 26 additions & 7 deletions deepblast/dataset/dataset.py
@@ -7,11 +7,21 @@
from deepblast.constants import m
from deepblast.dataset.utils import (
state_f, tmstate_f,
clip_boundaries, states2matrix, states2edges,
path_distance_matrix
remove_gaps, states2matrix, states2edges,
path_distance_matrix, gap_mask
)


def reshape(x, N, M):
if x.shape != (N, M) and x.shape != (M, N):
raise ValueError(f'The shape of `x` {x.shape} '
f'does not agree with ({N}, {M})')
if tuple(x.shape) != (N, M):
return x.t()
else:
return x


class AlignmentDataset(Dataset):
def __init__(self, pairs, tokenizer=UniprotTokenizer()):
self.tokenizer = tokenizer
@@ -38,13 +48,19 @@ def __iter__(self):
yield self.__getitem__(i)


class FastaDataset(AlignmentDataset):
""" Dataset for searching. """
def __init__(self, query_path, db_path, tokenizer=UniprotTokenizer()):
pass


class TMAlignDataset(AlignmentDataset):
""" Dataset for training and testing.

This is appropriate for the Malisam / Malidup datasets.
"""
def __init__(self, path, tokenizer=UniprotTokenizer(),
tm_threshold=0.4, max_len=1024, pad_ends=False,
tm_threshold=0.5, max_len=1024, pad_ends=False,
clip_ends=True, construct_paths=False):
""" Read in pairs of proteins.

@@ -126,12 +142,11 @@ def __getitem__(self, i):
pos = self.pairs.iloc[i]['chain2']
states = self.pairs.iloc[i]['alignment']
states = list(map(tmstate_f, states))
if self.clip_ends:
gene, pos, states = clip_boundaries(gene, pos, states)
gene, pos, states = remove_gaps(gene, pos, states, self.clip_ends)
gene_mask, pos_mask = gap_mask(states)

if self.pad_ends:
states = [m] + states + [m]

states = torch.Tensor(states).long()
gene = self.tokenizer(str.encode(gene))
pos = self.tokenizer(str.encode(pos))
@@ -148,7 +163,11 @@ def __getitem__(self, i):
path_matrix = path_matrix.t()
if tuple(alignment_matrix.shape) != (len(gene), len(pos)):
alignment_matrix = alignment_matrix.t()
return gene, pos, states, alignment_matrix, path_matrix

# gene_mask = torch.Tensor(gene_mask).long()
# pos_mask = torch.Tensor(pos_mask).long()
return (gene, pos, states, alignment_matrix, path_matrix,
gene_mask, pos_mask)


class MaliAlignmentDataset(AlignmentDataset):
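
gap_mask is imported from deepblast.dataset.utils and called on the state list above, but its implementation is not shown in this diff. A plausible sketch, assuming the x/m/y state encoding from constants.py and that each returned mask lists the gap positions for one sequence (the PR's actual return values may differ), is:

import numpy as np
from deepblast.constants import x, y


def gap_mask(states):
    """Hypothetical sketch: return, for each sequence, the indices of
    alignment states where that sequence is aligned to a gap, so those
    positions can be masked or zeroed downstream."""
    states = np.asarray(states)
    gene_mask = np.where(states == y)[0]  # gaps in the first sequence
    pos_mask = np.where(states == x)[0]   # gaps in the second sequence
    return gene_mask, pos_mask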
9 changes: 4 additions & 5 deletions deepblast/dataset/tests/test_dataset.py
@@ -18,14 +18,13 @@ def test_getitem(self):
x = TMAlignDataset(self.data_path, tm_threshold=0,
pad_ends=False, clip_ends=False)
res = x[0]
self.assertEqual(len(res), 5)
gene, pos, states, alignment_matrix, _ = res
self.assertEqual(len(res), 7)
gene, pos, states, alignment_matrix, _, _, _ = res
# test the lengths
self.assertEqual(len(gene), 103)
self.assertEqual(len(pos), 21)
self.assertEqual(len(states), 103)
# wtf is going on here??
self.assertEqual(alignment_matrix.shape, (22, 103))
self.assertEqual(alignment_matrix.shape, (103, 21))


class TestMaliDataset(unittest.TestCase):
@@ -47,7 +46,7 @@ def test_getitem(self):
self.assertEqual(len(gene), 81)
self.assertEqual(len(pos), 81)
self.assertEqual(len(states), 100)
self.assertEqual(alignment_matrix.shape, (81, 82))
self.assertEqual(alignment_matrix.shape, (81, 81))


if __name__ == '__main__':
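
For reference, a minimal usage sketch of the updated NeedlemanWunschAligner signature with the new n_heads and local arguments; the dimensions below are placeholders rather than values taken from this PR, and a CUDA device plus the pretrained BiLM weights are assumed to be available:

from deepblast.alignment import NeedlemanWunschAligner

# Placeholder hyperparameters for illustration only; the training
# configuration used in this PR may differ.
aligner = NeedlemanWunschAligner(
    n_alpha=22,    # alphabet size
    n_input=512,   # input embedding dimension
    n_units=512,   # hidden units per RNN layer
    n_embed=512,   # output embedding dimension
    n_layers=2,
    n_heads=16,    # new: heads in the MultiheadProduct layers
    local=True,    # new: local alignment handling in the traceback
).cuda()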