In this exercise we will write an attention mechanism to solve the following problem:
"Given a length 10 sequence of integers, are there more '2s' than '4s' ?"
This could of course be handled easily by a fully connected network, but we will force the network to solve it by learning to place attention on the right values. That is, the strategy is:
- embed the individual integers into a higher-dimensional vector space (using torch.nn.Embedding)
- once we have those embeddings, compute how much attention to place on each position by comparing "key" vectors, computed from the embeddings, against a query
- compute the "answer" to our query by weighting the individual values by their attention weights
$$ v'_{ik} = \underset{j}{\mathrm{softmax}}\!\left(\frac{q_{i}\cdot k_{j}}{\sqrt{d}}\right) v_{jk} $$
where the softmax and the implicit sum (Einstein convention) run over the key index j, and d is the dimension of the keys and queries.
Write a data-generating function that produces a batch of N examples of length-10 sequences of random integers between 0 and 9, together with a binary label indicating whether the sequence has more 2s than 4s. The function should return (X,y), where X has shape (N,10) and y has shape (N,1).
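One possible implementation is sketched below (the name make_batch matches the plotting helper further down; the exact way the label is constructed is up to you):

import torch

def make_batch(N):
    # N random length-10 sequences of integers in [0,9]
    X = torch.randint(0, 10, (N, 10))
    # label is 1.0 if a sequence contains more 2s than 4s, else 0.0
    y = ((X == 2).sum(dim=1) > (X == 4).sum(dim=1)).float().unsqueeze(1)
    return X, y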
Deep learning works well in higher dimensions, so we'll embed the 10 possible integers into a vector space using torch.nn.Embedding
- Verify that a module like torch.nn.Embedding(10,embed_dim) achieves this
- Take a random tensor of integers of shape (N,M) and evaluate it through an embedding module (a quick shape check is sketched below)
- Does the output dimension make sense?
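For example, a check could look like this (embed_dim = 5 and the input shape are arbitrary choices):

embed_dim = 5
emb = torch.nn.Embedding(10, embed_dim)
x = torch.randint(0, 10, (3, 10))   # a random (N, M) tensor of integers in [0,9]
print(emb(x).shape)                 # torch.Size([3, 10, 5]): one embed_dim vector per integer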
Once the data is embedded, we can extract keys and values with linear projections.
- Using two linear layers torch.nn.Linear(embed_dim,att_dim), extract keys and values from the output of the previous step
- Verify that this works from a shape perspective (see the sketch below)
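As a sketch (reusing emb and x from the check above): torch.nn.Linear acts on the last dimension, so the sequence structure is preserved.

att_dim = 5
WK = torch.nn.Linear(embed_dim, att_dim)   # projects embeddings to keys
WV = torch.nn.Linear(embed_dim, att_dim)   # projects embeddings to values

e = emb(x)                          # (N, 10, embed_dim)
keys, values = WK(e), WV(e)
print(keys.shape, values.shape)     # both (N, 10, att_dim)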
Implement the attention formula from above in a batched manner, such that for an input set of sequences of shape (N,10) you get an output set of attention-weighted values of shape (N,1).
- To test it, use a random "single query" vector of shape (1,att_dim)
- It is easiest to use torch.einsum, which follows the Einstein summation convention you may know from special relativity. For example, a "batched" dot product is performed with einsum('bik,bjk->bij'), where b indicates the batch index, i and j are position indices, and k indexes the coordinates of the vectors. A sketch follows below.
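One possible sketch, using random keys and values as stand-ins and one-dimensional values as in the model below (the einsum letters differ from the example above but follow the same convention):

N, att_dim = 4, 5
q = torch.randn(1, att_dim)        # the single random query
k = torch.randn(N, 10, att_dim)    # one key per position
v = torch.randn(N, 10, 1)          # one scalar value per position

# scaled dot product between the query and every key: shape (N, 1, 10)
scores = torch.einsum('id,bjd->bij', q, k) / att_dim**0.5
att = torch.softmax(scores, dim=-1)          # softmax over the key positions j
# attention-weighted sum over positions: (N, 1, 10) x (N, 10, 1) -> (N, 1, 1)
out = torch.einsum('bij,bjk->bik', att, v)
print(out[:, :, 0].shape)                    # (N, 1)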
Complete the following torch Module:
class AttentionModel(torch.nn.Module):
    def __init__(self):
        super(AttentionModel,self).__init__()
        self.embed_dim = 5
        self.att_dim = 5
        self.embed = torch.nn.Embedding(10,self.embed_dim)
        # one query
        self.query = torch.nn.Parameter(torch.randn(1,self.att_dim))
        # used to compute keys
        self.WK = torch.nn.Linear(self.embed_dim,self.att_dim)
        # used to compute values
        self.WV = torch.nn.Linear(self.embed_dim,1)
        # final decision based on the attention-weighted value
        self.nn = torch.nn.Sequential(
            torch.nn.Linear(1,200),
            torch.nn.ReLU(),
            torch.nn.Linear(200,1),
            torch.nn.Sigmoid(),
        )

    def attention(self,x):
        # compute attention weights
        ...

    def values(self,x):
        # compute values
        ...

    def forward(self,x):
        # compute the final classification using attention, values and the final NN
        ...
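If you want something to compare against, one possible completion of the three methods is sketched here (these go inside the class above; other layouts work just as well):

    def attention(self,x):
        # x: embedded inputs of shape (N, 10, embed_dim)
        keys = self.WK(x)                                   # (N, 10, att_dim)
        scores = torch.einsum('id,bjd->bij', self.query, keys) / self.att_dim**0.5
        return torch.softmax(scores, dim=-1)                # (N, 1, 10)

    def values(self,x):
        return self.WV(x)                                   # (N, 10, 1)

    def forward(self,x):
        e = self.embed(x)                                   # (N, 10, embed_dim)
        # attention-weighted sum of the values, then the final NN
        weighted = torch.einsum('bij,bjk->bik', self.attention(e), self.values(e))  # (N, 1, 1)
        return self.nn(weighted[:,:,0])                     # (N, 1)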
Once you have completed the model, train it. My suggestion is to use a batch size of 500 and a learning rate of 5e-5.
- Write a training loop that trains on binary cross entropy (a minimal sketch follows below).
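A minimal loop might look like this (the choice of optimizer and the number of steps are assumptions; only the batch size and learning rate come from the suggestion above):

model = AttentionModel()
opt = torch.optim.Adam(model.parameters(), lr=5e-5)
loss_fn = torch.nn.BCELoss()

traj = []
for step in range(2000):               # train until the loss curve flattens
    X, y = make_batch(500)
    opt.zero_grad()
    loss = loss_fn(model(X), y)
    loss.backward()
    opt.step()
    traj.append(loss.item())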
You can visualize the output e.g. with the plotting function below, which receives the model, a batch size, and a loss trajectory. If everything is working, you should see that the model places its attention mostly on the 2s and 4s, as expected.
import numpy as np
import matplotlib.pyplot as plt

def plot(model,N,traj):
    x,y = make_batch(N)
    f,axarr = plt.subplots(1,3)
    f.set_size_inches(13,2)

    # left panel: attention weights per sequence position
    ax = axarr[0]
    at = model.attention(model.embed(x))[:,0,:].detach().numpy()
    ax.imshow(at)

    # middle panel: values, shown only where the attention is non-negligible
    ax = axarr[1]
    vals = model.values(model.embed(x))[:,:,0].detach().numpy()
    nan = np.ones_like(vals)*np.nan
    nan = np.where(at > 0.1, vals, nan)
    ax.imshow(nan,vmin = -1, vmax = 1)

    # overlay the integers, highlighting the 2s and 4s in red
    for i,xx in enumerate(x):
        for j,xxx in enumerate(xx):
            ax = axarr[0]
            ax.text(j,i,xxx.numpy(), c = 'r' if (xxx in [2,4]) else 'w')
            ax = axarr[1]
            ax.text(j,i,xxx.numpy(), c = 'r' if (xxx in [2,4]) else 'w')

    # right panel: the loss trajectory
    ax = axarr[2]
    ax.plot(traj)
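For example, after training you could call plot(model, 20, traj) to inspect the attention pattern on 20 fresh sequences alongside the loss curve.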