Skip to content

Commit

Permalink
Refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
rdspring1 committed Apr 24, 2024
1 parent b805c45 commit f4ee626
Showing 1 changed file with 27 additions and 43 deletions.
70 changes: 27 additions & 43 deletions thunder/tests/opinfos.py
Original file line number Diff line number Diff line change
Expand Up @@ -4389,51 +4389,35 @@ def scatter_add_sample_generator(op, device, dtype, requires_grad, **kwargs):
make_source = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)

# NOTE The value gradient is only valid when src.shape == index.shape.
if requires_grad:
for shape_a, dim, shape_b in take_along_axis_cases:
canonicalized_dim = dim if dim >= 0 else dim + len(shape_a)
a = make(shape_a)
b = make_index(shape_b, low=0, high=shape_a[dim])
c = make_source(shape_b)
yield SampleInput(a, index=b, src=c, dim=dim)

# a.shape, dim, index.shape
scatter_add_cases = (
((4, 5, 3), 0, (3, 2, 3)),
((4, 5, 3), 1, (3, 5, 2)),
((4, 5, 3), 2, (3, 2, 8)),
)
for shape_a, dim, shape_b in scatter_add_cases:
a = make(shape_a)
b = make_index(shape_b, low=0, high=shape_a[dim])
c = make_source(shape_b)
yield SampleInput(a, index=b, src=c, dim=dim)
else:
for shape_a, dim, shape_b in take_along_axis_cases:
canonicalized_dim = dim if dim >= 0 else dim + len(shape_a)
# For gradient testing, we use the index shape for the source tensor.
for shape_a, dim, shape_b in take_along_axis_cases:
canonicalized_dim = dim if dim >= 0 else dim + len(shape_a)
if requires_grad:
shape_source = shape_b
else:
shape_source = list(shape_a)
shape_source[canonicalized_dim] = shape_b[canonicalized_dim]
a = make(shape_a)
b = make_index(shape_b, low=0, high=shape_a[dim])
c = make_source(shape_source)
yield SampleInput(a, index=b, src=c, dim=dim)

# Questionable use case. Do we want to support these?!
# Note that scatter_add doesn't have the broadcast requirement, it only requires
# 1. a.shape[i] >= index.shape[i] for i != dim
# 2. source.shape[i] >= index.shape[i] for all i
#
# a.shape, dim, index.shape, source.shape
scatter_add_cases = (
((4, 5, 3), 0, (3, 2, 3), (4, 3, 9)),
((4, 5, 3), 1, (3, 5, 2), (3, 8, 8)),
((4, 5, 3), 2, (3, 2, 8), (5, 8, 8)),
)
for shape_a, dim, shape_b, shape_source in scatter_add_cases:
a = make(shape_a)
b = make_index(shape_b, low=0, high=shape_a[dim])
c = make_source(shape_source)
yield SampleInput(a, index=b, src=c, dim=dim)
a = make(shape_a)
b = make_index(shape_b, low=0, high=shape_a[dim])
c = make_source(shape_source)
yield SampleInput(a, index=b, src=c, dim=dim)

# Questionable use case. Do we want to support these?!
# Note that scatter_add doesn't have the broadcast requirement, it only requires
# 1. a.shape[i] >= index.shape[i] for i != dim
# 2. source.shape[i] >= index.shape[i] for all i
#
# a.shape, dim, index.shape, source.shape
scatter_add_cases = (
((4, 5, 3), 0, (3, 2, 3), (4, 3, 9)),
((4, 5, 3), 1, (3, 5, 2), (3, 8, 8)),
((4, 5, 3), 2, (3, 2, 8), (5, 8, 8)),
)
for shape_a, dim, shape_b, shape_source in scatter_add_cases:
a = make(shape_a)
b = make_index(shape_b, low=0, high=shape_a[dim])
c = make_source(shape_b if requires_grad else shape_source)
yield SampleInput(a, index=b, src=c, dim=dim)


scatter_add_opinfo = OpInfo(
Expand Down

0 comments on commit f4ee626

Please sign in to comment.