From 916c8d700b546e01d626e22499ceb2f14e249f7a Mon Sep 17 00:00:00 2001 From: inisis Date: Wed, 31 Aug 2022 17:38:23 +0900 Subject: [PATCH 1/2] update position encoding to make detr torch fx traceable --- models/position_encoding.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/models/position_encoding.py b/models/position_encoding.py index 73ae39edf..3df273491 100644 --- a/models/position_encoding.py +++ b/models/position_encoding.py @@ -37,7 +37,7 @@ def forward(self, tensor_list: NestedTensor): y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale - dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device) + dim_t = torch_arange(self.num_pos_feats, dtype=torch.float32, device=x.device) dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats) pos_x = x_embed[:, :, :, None] / dim_t @@ -87,3 +87,8 @@ def build_position_encoding(args): raise ValueError(f"not supported {args.position_embedding}") return position_embedding + + +@torch.fx.wrap +def torch_arange(x, dtype, device): + return torch.arange(x, dtype=dtype, device=device) From 7d31e0f16a6700d620df2ad6820b786f2a9a5d63 Mon Sep 17 00:00:00 2001 From: inisis Date: Wed, 31 Aug 2022 18:20:36 +0900 Subject: [PATCH 2/2] fix torch script bug introduced by adding torch fx --- models/position_encoding.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/models/position_encoding.py b/models/position_encoding.py index 3df273491..32ea9d458 100644 --- a/models/position_encoding.py +++ b/models/position_encoding.py @@ -90,5 +90,5 @@ def build_position_encoding(args): @torch.fx.wrap -def torch_arange(x, dtype, device): +def torch_arange(x: int, dtype: torch.dtype, device: torch.device) -> torch.Tensor: return torch.arange(x, dtype=dtype, device=device)