Skip to content

Commit

Permalink
[BugFix,Temp] quick fix for edge case of #228
Browse files: browse the repository at this point in the history
  • Loading branch information
fedebotu committed Oct 28, 2024
1 parent 28c1166 commit c5b9045
Showing 1 changed file with 20 additions and 2 deletions.
22 changes: 20 additions & 2 deletions rl4co/models/nn/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,7 @@ def __init__(
mask_inner: bool = True,
out_bias: bool = False,
check_nan: bool = True,
sdpa_fn: Optional[Callable] = None,
sdpa_fn: Optional[Union[Callable, str]] = "default",
**kwargs,
):
super(PointerAttention, self).__init__()
Expand All @@ -258,9 +258,27 @@ def __init__(

# Projection - query, key, value already include projections
self.project_out = nn.Linear(embed_dim, embed_dim, bias=out_bias)
self.sdpa_fn = sdpa_fn if sdpa_fn is not None else scaled_dot_product_attention
self.check_nan = check_nan

# Defaults for sdpa_fn implementation
# see https://github.com/ai4co/rl4co/issues/228
if isinstance(sdpa_fn, str):
if sdpa_fn == "default":
sdpa_fn = scaled_dot_product_attention
elif sdpa_fn == "simple":
sdpa_fn = scaled_dot_product_attention_simple
else:
raise ValueError(
f"Unknown sdpa_fn: {sdpa_fn}. Available options: ['default', 'simple']"
)
else:
if sdpa_fn is None:
sdpa_fn = scaled_dot_product_attention
log.info(
"Using default scaled_dot_product_attention for PointerAttention"
)
self.sdpa_fn = sdpa_fn

def forward(self, query, key, value, logit_key, attn_mask=None):
"""Compute attention logits given query, key, value, logit key and attention mask.
Expand Down

0 comments on commit c5b9045

Please sign in to comment.