From 7e23c5a6a1f9357c89df05fde30610c14b0077fc Mon Sep 17 00:00:00 2001
From: nikitaved
Date: Mon, 27 May 2024 15:14:55 +0200
Subject: [PATCH] CUDA Graphs: enable in forward (#430)

---
 thunder/__init__.py            | 11 ++++++++++-
 thunder/tests/test_networks.py | 12 +++++++++---
 2 files changed, 19 insertions(+), 4 deletions(-)

diff --git a/thunder/__init__.py b/thunder/__init__.py
index adc6d6927b..e545564b1e 100644
--- a/thunder/__init__.py
+++ b/thunder/__init__.py
@@ -322,6 +322,9 @@ def jit(
     assert type(record_history) is bool
 
     # TODO RC1 Refine the compile data option to remove unused options
+    # TODO: refine options
+    # NOTE(fixme): use_cudagraphs is being absorbed into compile_options
+    use_cudagraphs = compile_options.get("use_cudagraphs", False)
     cd = CompileData(
         fn=fn,
         langctx=langctx,
@@ -329,7 +332,7 @@ def jit(
         cache_option=cache,
         sharp_edges=sharp_edges,
         using_jit=True,
-        use_cudagraphs=False,
+        use_cudagraphs=use_cudagraphs,
         disable_torch_autograd_support=disable_torch_autograd,
         use_rematerialization=False,
         only_execute_prims=False,
@@ -581,6 +584,12 @@ def get_computation_and_inputs(*args, **kwargs):
         else:
             backward_fn = None
 
+        # TODO: using vanilla CUDAGraphExecutor is not safe unless the graph is always static!
+        # (fixme): inspect torch.cuda.make_graphed_callables and/or use it instead!
+        # See https://github.com/Lightning-AI/lightning-thunder/issues/433
+        if cd.use_cudagraphs:
+            comp = CUDAGraphExecutor(comp)
+
         # TODO RC1 Update the cache
         cache_entry = CacheEntry(
             pro,
diff --git a/thunder/tests/test_networks.py b/thunder/tests/test_networks.py
index 15760f2948..cac1666cab 100644
--- a/thunder/tests/test_networks.py
+++ b/thunder/tests/test_networks.py
@@ -79,16 +79,22 @@ def test_nanogpt_complete_cudagraphs(executor, device, dtype):
     torch_result = gpt(idx)
 
     tom = executor.make_callable(gpt, use_cudagraphs=True, disable_torch_autograd=True)
-    thunder_result = tom(idx)
 
+    thunder_result = tom(idx)
     assert_close(torch_result, thunder_result)
 
+    # Check we really run CUDAGraphExecutor {
+    assert tom._lc_cd.use_cudagraphs == True
+
+    from thunder.cudagraphs import CUDAGraphExecutor
+
+    assert type(tom._lc_cs.last_executed) == CUDAGraphExecutor
+    # }
+
 
 @instantiate(dtypes=(thunder.float32,), devicetypes=(thunder.devices.DeviceType.CUDA,))
 @requiresCUDA
 def test_nanogpt_complete_cuda_graphs_autograd(executor, device, dtype):
-    pytest.skip("https://github.com/Lightning-AI/lightning-thunder/issues/1403")
-
     tdtype = ttorch.to_torch_dtype(dtype)
 
     # Creates a nanoGPT model with a smaller size than any of the default options for testing
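
---

Usage sketch (not part of the diff): with this change, CUDA graphs can be
requested for the forward pass through a compile option on thunder.jit,
which jit() reads from compile_options and forwards to CompileData. A
minimal example; the toy module and tensor sizes are illustrative, and it
assumes a CUDA-enabled PyTorch build plus a model whose input shapes stay
static across calls (the CUDAGraphExecutor caveat noted in the patch):

    import torch
    import thunder

    # Hypothetical toy module; any static-shape model would do.
    model = torch.nn.Linear(64, 64).cuda()

    # use_cudagraphs=True makes get_computation_and_inputs wrap the
    # computation in CUDAGraphExecutor.
    jmodel = thunder.jit(model, use_cudagraphs=True)

    x = torch.randn(8, 64, device="cuda")
    out = jmodel(x)  # first call captures; later same-shape calls replay

Replay is only valid while the captured graph stays static, so varying
input shapes across calls is unsafe here, per the TODO in the patch.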