streamline autograd function wrapping (Lightning-AI#892)
t-vi authored Jul 31, 2024
1 parent d202ba3 commit 67ec2c6
Showing 3 changed files with 62 additions and 64 deletions.
8 changes: 2 additions & 6 deletions thunder/core/jit_ext.py
@@ -921,13 +921,9 @@ def _general_jit_torch_autograd_function_apply_lookaside(obj: Any, *args, **kwar
 
     custom_autograd_function_cls = unwrap(obj)
     custom_forward = custom_autograd_function_cls.forward
-    args_, kwargs_ = tree_map(unwrap, (args, kwargs))
     ctx = torch.autograd.function.FunctionCtx()
-
-    pr = ProvenanceRecord(PseudoInst.LOOKASIDE, inputs=[wrap_const(custom_forward).provenance])
-    wrapped_ctx = wrap(ctx, provenance=pr)
-    args_, kwargs_ = tree_map(lambda a: wrap(a, provenance=pr), (args_, kwargs_))
-    return _interpret_call(custom_forward, wrapped_ctx, *args_, **kwargs_)
+    wrapped_ctx = wrap_const(ctx)
+    return _interpret_call(custom_forward, wrapped_ctx, *args, **kwargs)
 
 
 @register_general_jit_lookaside(torch.finfo)
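With this change, the lookaside wraps only the freshly created FunctionCtx as a constant and hands the already-wrapped call arguments straight to the interpreted forward, instead of unwrapping and re-wrapping them with a LOOKASIDE provenance record. For orientation, a minimal sketch of the user-facing pattern this code path handles, mirroring the test added below; the names Doubler, fn, and jfn are illustrative, not taken from the diff:

import torch
import thunder

# Illustrative only: a custom autograd.Function whose forward is traced by thunder.
class Doubler(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x: torch.Tensor) -> torch.Tensor:
        return x * 2.0

    @staticmethod
    def backward(ctx, grad_output):
        return 2.0 * grad_output

def fn(x: torch.Tensor) -> torch.Tensor:
    # Doubler.apply is intercepted by the lookaside above; its forward runs under
    # the interpreter with ctx wrapped as a constant and the arguments passed through.
    return Doubler.apply(x)

jfn = thunder.jit(fn)
x = torch.randn(2, 2, dtype=torch.float64)
torch.testing.assert_close(jfn(x), fn(x))

The relocated test below goes further and checks gradients of the jitted module with gradcheck.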
58 changes: 0 additions & 58 deletions thunder/tests/test_core.py
@@ -2850,64 +2850,6 @@ def foo(x):
     torch.testing.assert_close(actual, expected)
 
 
-@pytest.mark.filterwarnings("ignore:Please use `torch.vmap`")
-def test_custom_autograd_function():
-    from torch.autograd.gradcheck import GradcheckError
-    from torch.testing._internal.common_utils import gradcheck
-
-    class MyFunction(torch.autograd.Function):
-
-        @staticmethod
-        def forward(ctx, x: torch.Tensor) -> torch.Tensor:
-            return x * 2.0
-
-        # this is wrong on purpose.
-        @staticmethod
-        def backward(ctx, grad_output) -> torch.Tensor:
-            return grad_output
-
-    class Model(torch.nn.Module):
-        def __init__(self):
-            super().__init__()
-
-        def forward(self, x) -> torch.Tensor:
-            return MyFunction.apply(x)
-
-    x = torch.randn((2, 2), dtype=torch.float64, requires_grad=True)
-    model = Model().to(dtype=torch.float64)
-    jitted = thunder.jit(model)
-
-    gradcheck(jitted, (x,))
-    with pytest.raises(GradcheckError):
-        gradcheck(model, (x,))
-
-    class MyLinear(torch.autograd.Function):
-        @staticmethod
-        def forward(ctx, x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
-            ctx.save_for_backward(x)
-            ctx.pretty_attr = 100
-            return torch.matmul(x, weight.t())
-
-        @staticmethod
-        def backward(ctx, grad_output):
-            (x,) = ctx.saved_tensors
-            return torch.matmul(grad_output, weight), torch.matmul(grad_output.t(), x)
-
-    class Model(torch.nn.Module):
-        def __init__(self):
-            super().__init__()
-            self.l1 = torch.nn.Linear(2, 2, bias=False)
-
-        def forward(self, x):
-            return MyLinear.apply(x, self.l1.weight)
-
-    x = torch.randn((2, 2), dtype=torch.float64, requires_grad=True)
-    model = Model().to(dtype=torch.float64)
-    jitted = thunder.jit(model)
-
-    gradcheck(jitted, (x,))
-
-
 def test_proxy_repr():
     # Verify that we can call `__repr__` on different proxy subclasses.
     t = thunder.core.trace.TraceCtx()
60 changes: 60 additions & 0 deletions thunder/tests/test_jit_general.py
@@ -1135,3 +1135,63 @@ def foo(t, batch_size):
     assert_close(expected, actual)
     assert thunder.cache_misses(jfoo) == 1
     assert thunder.cache_hits(jfoo) == 1
+
+
+@pytest.mark.filterwarnings("ignore:Please use `torch.vmap`")
+def test_custom_autograd_function():
+    from torch.autograd.gradcheck import GradcheckError
+    from torch.testing._internal.common_utils import gradcheck
+
+    class MyFunction(torch.autograd.Function):
+
+        @staticmethod
+        def forward(ctx, x: torch.Tensor) -> torch.Tensor:
+            return x * 2.0
+
+        # this is wrong on purpose.
+        @staticmethod
+        def backward(ctx, grad_output) -> torch.Tensor:
+            return grad_output
+
+    class Model(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+
+        def forward(self, x) -> torch.Tensor:
+            return MyFunction.apply(x)
+
+    x = torch.randn((2, 2), dtype=torch.float64, requires_grad=True)
+    model = Model().to(dtype=torch.float64)
+    jitted = thunder.jit(model)
+
+    gradcheck(jitted, (x,))
+    with pytest.raises(GradcheckError):
+        gradcheck(model, (x,))
+
+    class MyLinear(torch.autograd.Function):
+        @staticmethod
+        def forward(ctx, x: torch.Tensor, weight: torch.Tensor, shape: tuple) -> torch.Tensor:
+            ctx.shape = shape
+            ctx.save_for_backward(x, weight)
+            ctx.pretty_attr = 100
+            return torch.matmul(x, weight.t())
+
+        @staticmethod
+        def backward(ctx, grad_output):
+            (x, weight) = ctx.saved_tensors
+            assert weight.shape == ctx.shape  # really bogus, just to use ctx.shape
+            return torch.matmul(grad_output, weight), torch.matmul(grad_output.t(), x)
+
+    class Model(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.l1 = torch.nn.Linear(2, 2, bias=False)
+
+        def forward(self, x):
+            return MyLinear.apply(x, self.l1.weight, self.l1.weight.shape)
+
+    x = torch.randn((2, 2), dtype=torch.float64, requires_grad=True)
+    model = Model().to(dtype=torch.float64)
+    jitted = thunder.jit(model)
+
+    gradcheck(jitted, (x,))

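To exercise just the relocated test locally, a standard pytest selection should work; a sketch, assuming pytest is available in the environment (the test suite already relies on it):

import pytest

# Select only the custom-autograd-function test from the file touched above.
pytest.main(["thunder/tests/test_jit_general.py", "-k", "test_custom_autograd_function"])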