From 9068564d042795caa65f776e1eaf10efd5c6d833 Mon Sep 17 00:00:00 2001
From: Benkovichnikita
Date: Tue, 19 Nov 2024 23:41:10 +0400
Subject: [PATCH] change resampler attention to scaled_dot_product_attention

---
 ip_adapter/resampler.py | 6 ++----
 1 file changed, 2 insertions(+), 4 deletions(-)

diff --git a/ip_adapter/resampler.py b/ip_adapter/resampler.py
index 2426667..6f7e50c 100644
--- a/ip_adapter/resampler.py
+++ b/ip_adapter/resampler.py
@@ -5,6 +5,7 @@
 
 import torch
 import torch.nn as nn
+import torch.nn.functional as F
 from einops import rearrange
 from einops.layers.torch import Rearrange
 
@@ -68,10 +69,7 @@ def forward(self, x, latents):
         v = reshape_tensor(v, self.heads)
 
         # attention
-        scale = 1 / math.sqrt(math.sqrt(self.dim_head))
-        weight = (q * scale) @ (k * scale).transpose(-2, -1)  # More stable with f16 than dividing afterwards
-        weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
-        out = weight @ v
+        out = F.scaled_dot_product_attention(q, k, v)
 
         out = out.permute(0, 2, 1, 3).reshape(b, l, -1)
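
Below the patch, a minimal equivalence check (not part of the patch itself): it sketches why the replacement is safe by asserting that F.scaled_dot_product_attention with default arguments reproduces the manual attention being removed, since SDPA's default softmax scale is 1/sqrt(dim_head), the same factor the old code split across q and k for fp16 stability. The batch/head/sequence sizes and dim_head = 64 are illustrative assumptions, not values taken from the repository.

    import math

    import torch
    import torch.nn.functional as F

    # Illustrative shapes (assumed, not from the repo): batch, heads, tokens, head dim.
    b, heads, seq, dim_head = 2, 4, 8, 64
    q = torch.randn(b, heads, seq, dim_head)
    k = torch.randn(b, heads, seq, dim_head)
    v = torch.randn(b, heads, seq, dim_head)

    # Old path: the 1/sqrt(dim_head) scale is split across q and k,
    # which is more stable in fp16 than dividing the logits afterwards.
    scale = 1 / math.sqrt(math.sqrt(dim_head))
    weight = (q * scale) @ (k * scale).transpose(-2, -1)
    weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
    out_manual = weight @ v

    # New path: SDPA defaults to the same 1/sqrt(dim_head) softmax scaling.
    out_sdpa = F.scaled_dot_product_attention(q, k, v)

    assert torch.allclose(out_manual, out_sdpa, atol=1e-5)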