Skip to content

Commit

Permalink
Add pad tests (Lightning-AI#536)
Browse files Browse the repository at this point in the history
  • Loading branch information
tfogal authored Jun 6, 2024
1 parent 0c6e38c commit d3c06ff
Showing 1 changed file with 5 additions and 0 deletions.
5 changes: 5 additions & 0 deletions thunder/tests/opinfos.py
Original file line number Diff line number Diff line change
Expand Up @@ -3748,6 +3748,11 @@ def pad_torch_sample_generator(op, device, dtype, requires_grad, **kwargs):
for shape, padding_config in cases:
yield SampleInput(make(shape), padding_config, "constant", make_number(dtype=dtype))

# The `value` parameter of the pad op is unceremoniously cast to the type of the
# tensor being padded. Yield some tests with explicitly-different data types.
yield SampleInput(make((2, 3)), pad=(1, 2), value=6.4)
yield SampleInput(make((2,)), pad=(1, 2), value=1)


def pad_torch_error_generator(op, device, dtype=torch.float32, **kwargs):
make = partial(make_tensor, device=device, dtype=dtype)
Expand Down

0 comments on commit d3c06ff

Please sign in to comment.