Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
rdspring1 committed Apr 27, 2024
1 parent 5db4698 commit b4621e1
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 1 deletion.
5 changes: 4 additions & 1 deletion thunder/core/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -701,7 +701,10 @@ def _broadcast_in_dim_prim_grad(
bcast_dims = tuple(b for i, b in enumerate(broadcast_dimensions) if i not in unit_dims)
reduce_dims = tuple(s for i, s in enumerate(range(len(shape))) if i not in bcast_dims)

g = ltorch.sum(g, reduce_dims)
# NOTE When the reduce_dims tuple is empty, pytorch reduces all dimensions.
# In this case, we do not want to reduce any dimensions, so skip this sum.
if len(reduce_dims) > 0:
g = ltorch.sum(g, reduce_dims)

# NOTE This must be clang.unsqueeze because torch.unsqueeze, unlike clang.unsqueeze, only accepts an integer
# (put another way, torch only allows one unsqueeze at a time)
Expand Down
1 change: 1 addition & 0 deletions thunder/tests/opinfos.py
Original file line number Diff line number Diff line change
Expand Up @@ -2786,6 +2786,7 @@ def broadcast_in_dim_sample_generator(op, device, dtype, requires_grad, **kwargs

# inshape, outshape, dims
cases = (
([5], [5], [0]),
([2], [2, 2], [0]),
([2], [2, 2], [1]),
([2], [2, 3], [0]),
Expand Down

0 comments on commit b4621e1

Please sign in to comment.