Skip to content

Commit

Permalink
CUDA Graphs: enable in forward (Lightning-AI#430)
Browse files Browse the repository at this point in the history
  • Loading branch information
nikitaved authored May 27, 2024
1 parent 11e6143 commit 7e23c5a
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 4 deletions.
11 changes: 10 additions & 1 deletion thunder/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,14 +322,17 @@ def jit(
assert type(record_history) is bool

# TODO RC1 Refine the compile data option to remove unused options
# TODO: refine options
# NOTE(fixme): use_cudagraphs is being absorbed into compile_options
use_cudagraphs = compile_options.get("use_cudagraphs", False)
cd = CompileData(
fn=fn,
langctx=langctx,
executors_list=executors,
cache_option=cache,
sharp_edges=sharp_edges,
using_jit=True,
use_cudagraphs=False,
use_cudagraphs=use_cudagraphs,
disable_torch_autograd_support=disable_torch_autograd,
use_rematerialization=False,
only_execute_prims=False,
Expand Down Expand Up @@ -581,6 +584,12 @@ def get_computation_and_inputs(*args, **kwargs):
else:
backward_fn = None

# TODO: using vanilla CUDAGraphExecutor is not safe unless the graph is always static!
# (fixme): inspect torch.cuda.make_graph_callables and/or use it instead!
# See https://github.com/Lightning-AI/lightning-thunder/issues/433
if cd.use_cudagraphs:
comp = CUDAGraphExecutor(comp)

# TODO RC1 Update the cache
cache_entry = CacheEntry(
pro,
Expand Down
12 changes: 9 additions & 3 deletions thunder/tests/test_networks.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,16 +79,22 @@ def test_nanogpt_complete_cudagraphs(executor, device, dtype):
torch_result = gpt(idx)

tom = executor.make_callable(gpt, use_cudagraphs=True, disable_torch_autograd=True)
thunder_result = tom(idx)

thunder_result = tom(idx)
assert_close(torch_result, thunder_result)

# Check we really run CUDAGraphExecutor {
assert tom._lc_cd.use_cudagraphs == True

from thunder.cudagraphs import CUDAGraphExecutor

assert type(tom._lc_cs.last_executed) == CUDAGraphExecutor
# }


@instantiate(dtypes=(thunder.float32,), devicetypes=(thunder.devices.DeviceType.CUDA,))
@requiresCUDA
def test_nanogpt_complete_cuda_graphs_autograd(executor, device, dtype):
pytest.skip("https://github.com/Lightning-AI/lightning-thunder/issues/1403")

tdtype = ttorch.to_torch_dtype(dtype)

# Creates a nanoGPT model with a smaller size than any of the default options for testing
Expand Down

0 comments on commit 7e23c5a

Please sign in to comment.