From 9099424b2829c6f72946c6429b7766161e5860a9 Mon Sep 17 00:00:00 2001
From: chengzeyi
Date: Tue, 5 Nov 2024 09:14:50 +0000
Subject: [PATCH] remove fullgraph=True from torch.compile

---
 README.md               | 2 +-
 tests/test_interface.py | 4 ++--
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/README.md b/README.md
index 6ffc919..34a34a5 100644
--- a/README.md
+++ b/README.md
@@ -71,7 +71,7 @@ with torch.no_grad(), torch.cuda.device(rank):
     def func(*args, **kwargs):
         return F.scaled_dot_product_attention(*args, **kwargs)
 
-    func = torch.compile(func, fullgraph=True)
+    func = torch.compile(func)
 
     for _ in range(2):
         mesh = init_device_mesh(device, mesh_shape, mesh_dim_names=("ulysses", "ring"))
diff --git a/tests/test_interface.py b/tests/test_interface.py
index 1b9fc42..df40ee4 100644
--- a/tests/test_interface.py
+++ b/tests/test_interface.py
@@ -48,7 +48,7 @@ def _test_attn_func(self, dtype, device, B, H, S_Q, S_KV, D, is_causal, compile)
         func = self.attn_func
 
         if compile:
-            func = torch.compile(func, fullgraph=True)
+            func = torch.compile(func)
 
         for _ in range(2 if compile else 1):
             out_slice = func(
@@ -93,7 +93,7 @@ def func(*args, **kwargs):
             return F.scaled_dot_product_attention(*args, **kwargs)
 
         if compile:
-            func = torch.compile(func, fullgraph=True)
+            func = torch.compile(func)
 
         for _ in range(2 if compile else 1):
             with self.attn_mode(device):
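
For context on the flag being dropped: `fullgraph=True` makes `torch.compile` raise an error on any graph break, whereas the default mode lets Dynamo fall back to eager execution around ops it cannot trace. A minimal sketch of the difference follows; the `func` and tensor shapes below are illustrative, not taken from the patched files.

import torch
import torch.nn.functional as F

def func(q, k, v):
    # Fully traceable body, so both compile modes behave the same here.
    return F.scaled_dot_product_attention(q, k, v)

# Strict: any graph break inside func raises an error when tracing.
strict = torch.compile(func, fullgraph=True)

# Default (what this patch switches to): graph breaks fall back to eager.
relaxed = torch.compile(func)

q = k = v = torch.randn(1, 8, 128, 64)
assert torch.allclose(strict(q, k, v), relaxed(q, k, v), atol=1e-6)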