
Commit 9099424

remove fullgraph=True from torch.compile
chengzeyi committed Nov 5, 2024
1 parent d53ee5c
Showing 2 changed files with 3 additions and 3 deletions.
README.md (1 addition, 1 deletion)

@@ -71,7 +71,7 @@ with torch.no_grad(), torch.cuda.device(rank):
     def func(*args, **kwargs):
         return F.scaled_dot_product_attention(*args, **kwargs)
 
-    func = torch.compile(func, fullgraph=True)
+    func = torch.compile(func)
 
     for _ in range(2):
         mesh = init_device_mesh(device, mesh_shape, mesh_dim_names=("ulysses", "ring"))
tests/test_interface.py (2 additions, 2 deletions)

@@ -48,7 +48,7 @@ def _test_attn_func(self, dtype, device, B, H, S_Q, S_KV, D, is_causal, compile)

         func = self.attn_func
         if compile:
-            func = torch.compile(func, fullgraph=True)
+            func = torch.compile(func)
 
         for _ in range(2 if compile else 1):
             out_slice = func(
@@ -93,7 +93,7 @@ def func(*args, **kwargs):
             return F.scaled_dot_product_attention(*args, **kwargs)
 
         if compile:
-            func = torch.compile(func, fullgraph=True)
+            func = torch.compile(func)
 
         for _ in range(2 if compile else 1):
             with self.attn_mode(device):
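
For context on the change itself (not part of the commit): with fullgraph=True, torch.compile requires TorchDynamo to capture the entire function as a single graph and raises an error on any graph break; without it, Dynamo compiles the capturable pieces and falls back to eager Python across the breaks. A minimal sketch of the difference, assuming a recent PyTorch install; the function f and its deliberate graph break are illustrative, not taken from this repo:

import torch

def f(x):
    # .item() plus data-dependent control flow forces a graph break
    # under default Dynamo settings.
    if x.sum().item() > 0:
        return x * 2
    return x - 1

x = torch.randn(4)

# Default (the behavior after this commit): graph breaks are tolerated;
# the function still runs, partly compiled, partly eager.
relaxed = torch.compile(f)
print(relaxed(x))

# fullgraph=True (the removed behavior): the same graph break raises an
# error at trace time instead of silently splitting the graph.
strict = torch.compile(f, fullgraph=True)
try:
    strict(x)
except Exception as exc:
    print(type(exc).__name__)

Presumably one of the compiled attention paths here introduces such a graph break, so dropping fullgraph=True lets compilation proceed with partial graphs rather than failing outright.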
