
Commit

Merge pull request #111 from jrzaurin/fix_additive_attention
Fix additive attention (#110)
jrzaurin authored Oct 7, 2022
2 parents 1c8709f + 7e5d118 commit f3297c4
Showing 5 changed files with 20 additions and 8 deletions.
2 changes: 1 addition & 1 deletion VERSION
@@ -1 +1 @@
- 1.2.0
+ 1.2.1
10 changes: 8 additions & 2 deletions examples/notebooks/10_3rd_party_integration-RayTune_WnB.ipynb
@@ -143,6 +143,7 @@
" quantity.\n",
"\n",
" \"\"\"\n",
"\n",
" def __init__(\n",
" self,\n",
" wb: object,\n",
@@ -1061,7 +1062,7 @@
" \"wandb\": {\n",
" \"project\": \"test\",\n",
" # \"api_key_file\": os.getcwd() + \"/wandb_api.key\",\n",
" \"api_key\": \"WNB_API_KEY\", \n",
" \"api_key\": \"WNB_API_KEY\",\n",
" },\n",
"}\n",
"\n",
@@ -1080,7 +1081,12 @@
" trainer = Trainer(\n",
" model,\n",
" objective=\"binary_focal_loss\",\n",
" callbacks=[RayTuneReporter, WnBReportBest(wb=wandb), early_stopping, model_checkpoint],\n",
" callbacks=[\n",
" RayTuneReporter,\n",
" WnBReportBest(wb=wandb),\n",
" early_stopping,\n",
" model_checkpoint,\n",
" ],\n",
" lr_schedulers={\"deeptabular\": deep_sch},\n",
" initializers={\"deeptabular\": XavierNormal},\n",
" optimizers={\"deeptabular\": deep_opt},\n",
10 changes: 8 additions & 2 deletions
@@ -143,6 +143,7 @@
" quantity.\n",
"\n",
" \"\"\"\n",
"\n",
" def __init__(\n",
" self,\n",
" wb: object,\n",
@@ -1061,7 +1062,7 @@
" \"wandb\": {\n",
" \"project\": \"test\",\n",
" # \"api_key_file\": os.getcwd() + \"/wandb_api.key\",\n",
" \"api_key\": \"WNB_API_KEY\", \n",
" \"api_key\": \"WNB_API_KEY\",\n",
" },\n",
"}\n",
"\n",
@@ -1080,7 +1081,12 @@
" trainer = Trainer(\n",
" model,\n",
" objective=\"binary_focal_loss\",\n",
" callbacks=[RayTuneReporter, WnBReportBest(wb=wandb), early_stopping, model_checkpoint],\n",
" callbacks=[\n",
" RayTuneReporter,\n",
" WnBReportBest(wb=wandb),\n",
" early_stopping,\n",
" model_checkpoint,\n",
" ],\n",
" lr_schedulers={\"deeptabular\": deep_sch},\n",
" initializers={\"deeptabular\": XavierNormal},\n",
" optimizers={\"deeptabular\": deep_opt},\n",
4 changes: 2 additions & 2 deletions
@@ -217,14 +217,14 @@ def forward(self, X: Tensor) -> Tensor:
v = self.qv_proj(X) if self.share_qv_weights else self.v_proj(X)
k = self.k_proj(X)

- alphas = (self.W_q(q) / math.sqrt(self.head_dim)).softmax(dim=-1)
+ alphas = (self.W_q(q) / math.sqrt(self.head_dim)).softmax(dim=1)
q_r = einops.rearrange(q, "b s (h d) -> b s h d", h=self.n_heads)
global_query = einsum(" b s h, b s h d -> b h d", alphas, q_r)
global_query = einops.rearrange(global_query, "b h d -> b () (h d)")

p = k * global_query

- betas = (self.W_k(p) / math.sqrt(self.head_dim)).softmax(dim=-1)
+ betas = (self.W_k(p) / math.sqrt(self.head_dim)).softmax(dim=1)
p_r = einops.rearrange(p, "b s (h d) -> b s h d", h=self.n_heads)
global_key = einsum(" b s h, b s h d -> b h d", betas, p_r)
global_key = einops.rearrange(global_key, "b h d -> b () (h d)")
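The substance of the fix is the softmax axis. The scores produced by `W_q(q)` (and `W_k(p)`) have shape `(batch, seq_len, n_heads)`, so `dim=-1` normalized across heads, whereas additive attention must normalize across the sequence (`dim=1`) so each head's weights sum to one over the tokens it pools into the global query/key. A minimal sketch of the corrected behavior, with illustrative shapes and names rather than the library's exact module:

```python
import torch

# Illustrative sizes; the real module derives these from its config.
batch, seq_len, n_heads, head_dim = 2, 5, 4, 8

# Attention scores, analogous to W_q(q) / sqrt(head_dim) in the diff:
# one scalar score per (batch, token, head).
scores = torch.randn(batch, seq_len, n_heads)

# The fix: normalize over the sequence axis (dim=1), not the head axis.
alphas = scores.softmax(dim=1)
print(alphas.sum(dim=1))  # ~1.0 for every (batch, head) pair

# Pool the per-token queries into one global query per head,
# mirroring the einsum in the patched forward pass.
q_r = torch.randn(batch, seq_len, n_heads, head_dim)
global_query = torch.einsum("bsh,bshd->bhd", alphas, q_r)
print(global_query.shape)  # torch.Size([2, 4, 8])

# With the old dim=-1, the weights would sum to 1 across heads instead,
# so the pooled vector would not be a convex combination of the tokens.
```

The same reasoning applies to the `betas`/`global_key` branch changed just below it.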
2 changes: 1 addition & 1 deletion pytorch_widedeep/version.py
@@ -1 +1 @@
- __version__ = "1.2.0"
+ __version__ = "1.2.1"
