From b4621e18e736959bb7be62b895c5313e013a6b24 Mon Sep 17 00:00:00 2001 From: Ryan Spring Date: Fri, 26 Apr 2024 21:41:25 -0700 Subject: [PATCH] Skip sum in broadcast_in_dim grad when reduce_dims is empty --- thunder/core/transforms.py | 5 ++++- thunder/tests/opinfos.py | 1 + 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/thunder/core/transforms.py b/thunder/core/transforms.py index bf6843fe43..85a0af4921 100644 --- a/thunder/core/transforms.py +++ b/thunder/core/transforms.py @@ -701,7 +701,10 @@ def _broadcast_in_dim_prim_grad( bcast_dims = tuple(b for i, b in enumerate(broadcast_dimensions) if i not in unit_dims) reduce_dims = tuple(s for i, s in enumerate(range(len(shape))) if i not in bcast_dims) - g = ltorch.sum(g, reduce_dims) + # NOTE When the reduce_dims tuple is empty, pytorch reduces all dimensions. + # In this case, we do not want to reduce any dimensions, so skip this sum. + if len(reduce_dims) > 0: + g = ltorch.sum(g, reduce_dims) # NOTE This must be clang.unsqueeze because torch.unsqueeze, unlike clang.unsqueeze, only accepts an integer # (put another way, torch only allows one unsqueeze at a time) diff --git a/thunder/tests/opinfos.py b/thunder/tests/opinfos.py index ea1bc1a469..18b7f4ff0c 100644 --- a/thunder/tests/opinfos.py +++ b/thunder/tests/opinfos.py @@ -2786,6 +2786,7 @@ def broadcast_in_dim_sample_generator(op, device, dtype, requires_grad, **kwargs # inshape, outshape, dims cases = ( + ([5], [5], [0]), ([2], [2, 2], [0]), ([2], [2, 2], [1]), ([2], [2, 3], [0]),