grad rule for copy_with_setitem (Lightning-AI#1322)
kshitij12345 authored Oct 29, 2024
1 parent 49e4b57 commit b28d5b3
Showing 2 changed files with 72 additions and 10 deletions.
20 changes: 20 additions & 0 deletions thunder/core/transforms.py
@@ -1414,6 +1414,26 @@ def _maximum_grad(a: TensorProxy, b: TensorProxy, /):
 # This operation creates no grad associations
 register_grad(pids.SHAPE, prims.shape)
 
+
+def _copy_with_setitem_grad(a: TensorProxy, index, value: Number | TensorProxy):
+    fwd = prims.copy_with_setitem(a, index, value)
+    g = get_grad(fwd)
+
+    a_grad = prims.copy_with_setitem(g, index, 0)
+    put_grad(a, a_grad)
+
+    if isinstance(value, TensorProxy):
+        value_grad = g[index]
+        expanded_dims = value_grad.ndim - value.ndim
+        if expanded_dims > 0:
+            value_grad = prims.sum(value_grad, tuple(range(expanded_dims)))
+        put_grad(value, value_grad)
+
+    return fwd
+
+
+register_grad(pids.COPY_WITH_SETITEM, _copy_with_setitem_grad)
+
 #
 # Phantom grad transform helpers
 #
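For context, the rule above follows the usual calculus for index assignment: the gradient reaching a is the incoming gradient with the written positions zeroed (by writing 0 through the same index), and the gradient reaching value is the incoming gradient gathered at the index, summed over any leading dimensions introduced by broadcasting. A minimal sketch of those semantics in plain PyTorch (illustrative only; not part of the commit):

import torch

# Sketch (plain PyTorch, not thunder): the autograd semantics the new rule reproduces.
a = torch.randn(5, requires_grad=True)
value = torch.tensor(2.0, requires_grad=True)  # 0-dim tensor, broadcast over a[:3]

work = a.clone()      # clone so the in-place write does not hit a leaf tensor
work[:3] = value
out = work * 2

g = torch.ones_like(out)
grad_a, grad_value = torch.autograd.grad(out, (a, value), g)

# Overwritten positions receive no gradient ...
assert torch.equal(grad_a[:3], torch.zeros(3))
# ... and value's gradient is the incoming gradient at those positions,
# summed over the broadcast dimensions (here: all of them).
assert torch.isclose(grad_value, (2 * g)[:3].sum())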
62 changes: 52 additions & 10 deletions thunder/tests/test_ops.py
@@ -239,18 +239,60 @@ def foo():
     tfoo()
 
 
-def test_setitem():
-    def fn(a):
-        a[:3] = 2
+@pytest.mark.parametrize("requires_grad", (True, False))
+def test_setitem(requires_grad):
+
+    def _test_forward_and_backward(fn, a, value):
+        a_ref = a.detach().clone()
+        a_ref.requires_grad_(a.requires_grad)
+
+        if isinstance(value, torch.Tensor):
+            value_ref = value.detach().clone()
+            value_ref.requires_grad_(value.requires_grad)
+        else:
+            value_ref = value
+
+        out_ref = fn(a_ref, value_ref)
+        jf = thunder.jit(fn)
+        out = jf(a, value)
+        assert_close(a, a_ref)
+        assert_close(out, out_ref)
+
+        if requires_grad:
+            g = torch.randn_like(out)
+            inputs = (a, value) if isinstance(value, torch.Tensor) else (a,)
+            actual_grad = torch.autograd.grad(out, inputs, g)
+
+            inputs_ref = (a_ref, value_ref) if isinstance(value, torch.Tensor) else (a_ref,)
+            expected_grad = torch.autograd.grad(out_ref, inputs_ref, g)
+            assert_close(actual_grad, expected_grad)
+
+    def clone_if_requires_grad(a):
+        if requires_grad:
+            # Without the clone,
+            # PyTorch eager errors with
+            # `RuntimeError: a view of a leaf Variable that requires grad is being used in an in-place operation.`
+            # and thunder has a silent correctness issue - https://github.com/Lightning-AI/lightning-thunder/issues/1284
+            return a.clone()
+        return a
+
+    def fn(a, value):
+        a = clone_if_requires_grad(a)
+        a[:3] = value
         return a * 2
 
-    a_ref = torch.ones(5)
-    out_ref = fn(a_ref)
-    a = torch.ones(5)
-    jf = thunder.jit(fn)
-    out = jf(a)
-    assert_close(a, a_ref)
-    assert_close(out, out_ref)
+    # set value: scalar
+    _test_forward_and_backward(fn, torch.randn(5, requires_grad=requires_grad), 2.0)
+
+    # set value: tensor which needs to be broadcast
+    _test_forward_and_backward(
+        fn, torch.randn(5, requires_grad=requires_grad), torch.tensor(2.0, requires_grad=requires_grad)
+    )
+
+    # set value: tensor of the same rank
+    _test_forward_and_backward(
+        fn, torch.randn(5, requires_grad=requires_grad), torch.tensor([1.0, 2.0, 3.0], requires_grad=requires_grad)
+    )
 
 
 # TODO: Add random operator support to OpInfo
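With the rule registered, a jitted function that assigns into a slice can be differentiated end to end, as the new test checks. A hedged usage sketch, assuming the same thunder.jit entry point the test uses (variable names here are illustrative):

import torch
import thunder

def fn(a, value):
    a = a.clone()      # avoid writing into a leaf tensor in place
    a[:3] = value      # traced as a copy_with_setitem, per the grad rule above
    return a * 2

jfn = thunder.jit(fn)

a = torch.randn(5, requires_grad=True)
value = torch.tensor(2.0, requires_grad=True)
out = jfn(a, value)

grad_a, grad_value = torch.autograd.grad(out, (a, value), torch.ones_like(out))
# grad_a is zero at the overwritten positions and 2 elsewhere;
# grad_value collects the gradient flowing into the written slice (summed, since value was broadcast).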
