
Add SDPA backend tests and refactor generate.py #1477

Open
wants to merge 6 commits into base: main

Conversation

mikekgfb
Contributor

Push backend manager into caller

pytorch-bot bot commented Jan 24, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchchat/1477

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit f44fad1 with merge base 9686c79:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot added the CLA Signed label (managed by the Meta Open Source bot) Jan 24, 2025
Add tests for backends
@mikekgfb changed the title from "Update generate.py" to "Add SDPA backend tests and refactor generate.py" Jan 24, 2025
@mikekgfb
Contributor Author

The new attention-backend option has an interesting wrinkle: it specifies only one backend, and if that kernel does not work for the given parameters, things fail. At a minimum, should we always specify "math" as a backup? (See the specific error below.)

Also, should we have a backend option "auto" that simply relies on the SDPA selection logic to find the best kernel?

  ******* cuda bfloat16 flash_attention 
  + python torchchat.py generate --checkpoint-path checkpoints/stories15M/stories15M.pt --attention-backend flash_attention --device cuda --dtype bfloat16 --temperature 0
  NumExpr defaulting to 16 threads.
  PyTorch version 2.7.0.dev20250124+cu124 available.
  Warning: PTEModel (ExecuTorch) not available with exception: No module named 'executorch'
  Unable to import torchao experimental quant_api with error:  [Errno 2] No such file or directory: '/pytorch/torchchat/torchao-build/src/ao/torchao/experimental/quant_api.py'
  Using device=cuda NVIDIA A10G
  Loading model...
  Time to load model: 0.18 seconds
  -----------------------------------------------------------
  /pytorch/torchchat/torchchat/model.py:897: UserWarning: Memory efficient kernel not used because: (Triggered internally at /pytorch/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp:784.)
    y = F.scaled_dot_product_attention(
  /pytorch/torchchat/torchchat/model.py:897: UserWarning: Memory Efficient attention has been runtime disabled. (Triggered internally at /pytorch/aten/src/ATen/native/transformers/sdp_utils_cpp.h:548.)
    File "/home/ec2-user/actions-runner/_work/torchchat/torchchat/test-infra/.github/scripts/run_with_env_secrets.py", line 102, in <module>
      main()
    File "/home/ec2-user/actions-runner/_work/torchchat/torchchat/test-infra/.github/scripts/run_with_env_secrets.py", line 98, in main
      run_cmd_or_die(f"docker exec -t {container_name} /exec")
    File "/home/ec2-user/actions-runner/_work/torchchat/torchchat/test-infra/.github/scripts/run_with_env_secrets.py", line 39, in run_cmd_or_die
      raise RuntimeError(f"Command {cmd} failed with exit code {exit_code}")
  RuntimeError: Command docker exec -t 0fdb53b0e3bc3278763141f75f84b0cac3a9f6b5b3f38b52267a9b6b16de883b /exec failed with exit code 1
    y = F.scaled_dot_product_attention(
  /pytorch/torchchat/torchchat/model.py:897: UserWarning: Flash attention kernel not used because: (Triggered internally at /pytorch/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp:786.)
    y = F.scaled_dot_product_attention(
  /pytorch/torchchat/torchchat/model.py:897: UserWarning: Flash Attention does not support non-null attn_mask. (Triggered internally at /pytorch/aten/src/ATen/native/transformers/sdp_utils_cpp.h:261.)
    y = F.scaled_dot_product_attention(
  /pytorch/torchchat/torchchat/model.py:897: UserWarning: CuDNN attention kernel not used because: (Triggered internally at /pytorch/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp:788.)
    y = F.scaled_dot_product_attention(
  /pytorch/torchchat/torchchat/model.py:897: UserWarning: CuDNN attention has been runtime disabled. (Triggered internally at /pytorch/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp:535.)
    y = F.scaled_dot_product_attention(
  Traceback (most recent call last):
    File "/pytorch/torchchat/torchchat.py", line 96, in <module>
      generate_main(args)
    File "/pytorch/torchchat/torchchat/generate.py", line 1648, in main
      run_generator(args)
    File "/pytorch/torchchat/torchchat/generate.py", line 1630, in run_generator
      for _ in gen.chat(generator_args):
    File "/pytorch/torchchat/torchchat/generate.py", line 1194, in chat
      for token_tensor, metrics in generator_func:
    File "/opt/conda/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 36, in generator_context
      response = gen.send(None)
                 ^^^^^^^^^^^^^^
    File "/pytorch/torchchat/torchchat/generate.py", line 735, in generate
      next_token = self.prefill(
                   ^^^^^^^^^^^^^
    File "/pytorch/torchchat/torchchat/generate.py", line 491, in prefill
      logits = model(x, input_pos)
               ^^^^^^^^^^^^^^^^^^^
    File "/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1749, in _wrapped_call_impl
      return self._call_impl(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    File "/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1760, in _call_impl
      return forward_call(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    File "/pytorch/torchchat/torchchat/model.py", line 568, in forward
      return self.model(tokens, input_pos)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    File "/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1749, in _wrapped_call_impl
      return self._call_impl(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    File "/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1760, in _call_impl
      return forward_call(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    File "/pytorch/torchchat/torchchat/model.py", line 736, in forward
      x = layer(x, input_pos, freqs_cis, mask, cache_lane=cache_lane)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    File "/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1749, in _wrapped_call_impl
      return self._call_impl(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    File "/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1760, in _call_impl
      return forward_call(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    File "/pytorch/torchchat/torchchat/model.py", line 770, in forward
      h = x + self.attention(
              ^^^^^^^^^^^^^^^
    File "/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1749, in _wrapped_call_impl
      return self._call_impl(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    File "/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1760, in _call_impl
      return forward_call(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    File "/pytorch/torchchat/torchchat/model.py", line 897, in forward
      y = F.scaled_dot_product_attention(
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  RuntimeError: No available kernel. Aborting execution.
  Error: Process completed with exit code 1.
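
For concreteness, here is a minimal sketch of what the "math as a backup" and "auto" ideas could look like, built on torch.nn.attention.sdpa_kernel. The helper name, its math_fallback flag, and the CLI-string-to-backend map are hypothetical rather than code from this PR:

```python
from contextlib import nullcontext

from torch.nn.attention import SDPBackend, sdpa_kernel

# Hypothetical map from CLI strings to SDPBackend values.
_BACKEND_MAP = {
    "math": SDPBackend.MATH,
    "flash_attention": SDPBackend.FLASH_ATTENTION,
    "efficient_attention": SDPBackend.EFFICIENT_ATTENTION,
    "cudnn_attention": SDPBackend.CUDNN_ATTENTION,
}

def sdpa_context(backend: str, math_fallback: bool = True):
    """Return a context manager constraining SDPA kernel selection.

    "auto" leaves selection entirely to PyTorch's own SDPA dispatch.
    Otherwise we pin the requested backend and, optionally, keep MATH as a
    backup so unsupported parameter combinations degrade to the math kernel
    instead of raising "RuntimeError: No available kernel".
    """
    if backend == "auto":
        return nullcontext()
    backends = [_BACKEND_MAP[backend]]
    if math_fallback and backend != "math":
        backends.append(SDPBackend.MATH)
    return sdpa_kernel(backends)
```

The call sites in generate.py (prefill/decode) would then be wrapped in `sdpa_context(...)`; the trade-off is that a silent MATH fallback can hide the fact that the requested kernel was never actually used.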

@Jack-Khuu

@Jack-Khuu
Contributor

it specifies only one backend, and if that kernel does not work for the given parameters, things fail.

In a perfect world, we would have a support matrix of configurations we are confident in (with backing tests), and spit out, right at the start, a warning when arguments trek into territory that we are not guaranteeing. cc: @yanbing-j

At a minimum, should we always specify "math" as a backup?

I'm cautious of torchchat's use of fallbacks (and of fallbacks in other PyTorch projects in general, but that's a different can of worms), since it opens the door for misinterpretation: users thinking we are doing what they ask, when in reality it's succeeding with a different config.
In a perfect world, if an argument isn't in a config we are confident in, we should (rough sketch after the list):

  • warn early and fail early [In prebuilt package] (packaging is something we'll add this half)
  • warn early and continue execution to potential failure [In dev mode]
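
A rough sketch of the two modes above; the function name, the dev_mode flag, and the contents of the support matrix are all hypothetical, not part of this PR:

```python
import warnings

# Hypothetical support matrix: (device, dtype, backend) combinations that are
# covered by tests and that we are willing to stand behind.
SUPPORTED_CONFIGS = {
    ("cpu", "float32", "math"),
    ("cpu", "float32", "flash_attention"),
    ("cuda", "bfloat16", "math"),
    ("cuda", "bfloat16", "efficient_attention"),
}

def check_attention_config(device: str, dtype: str, backend: str, dev_mode: bool) -> None:
    """Warn early; fail early only in the prebuilt package."""
    if (device, dtype, backend) in SUPPORTED_CONFIGS:
        return
    msg = (
        f"attention backend '{backend}' with device={device}, dtype={dtype} "
        "is outside the tested support matrix"
    )
    if dev_mode:
        warnings.warn(msg + "; continuing, but execution may fail later.")
    else:
        raise RuntimeError(msg + "; refusing to run in the prebuilt package.")
```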

Thoughts?

Also, should we have a backend option "auto" that simply relies on the SDPA selection logic to find the best kernel?

This sounds amazing. Having an explicit list with justifications for the rankings also lets us catch when numerics look fishy.

@yanbing-j
Contributor

it specifies only one backend, and if that kernel does not work for the given parameters, things fail.

In a perfect world, we would have a support matrix of configurations we are confident in (with backing tests), and spit out, right at the start, a warning when arguments trek into territory that we are not guaranteeing. cc: @yanbing-j

Confirmed that on the CPU side, only math and flash_attention can be chosen as the attention backend. So far, flash_attention has not hit the "kernel does not work for these parameters" issue in our validation. Please ping me when an error occurs and we will investigate case by case.

@mikekgfb
Contributor Author

@Jack-Khuu:

I'm cautious of torchchat's use of fallbacks (and of fallbacks in other PyTorch projects in general, but that's a different can of worms), since it opens the door for misinterpretation: users thinking we are doing what they ask, when in reality it's succeeding with a different config.
In a perfect world, if an argument isn't in a config we are confident in, we should

warn early and fail early [In prebuilt package] (packaging is something we'll add this half)
warn early and continue execution to potential failure [In dev mode]

I'm not so enamored of the "fail early" option. "Fail early" turns into "no configuration ever works unless it is perfect", and that quickly becomes the empty set with overly fragile systems. I do agree with the "misleading" configuration issue: a user might feel misled if we give them MATH when they asked for FLASH. This one in particular is a very real concern -- at least in the past, the causal mask needed to be specified with a bool flag for flash attention to be selected, and passing a general mask forced the model into the MATH fallback, because at runtime it could in practice be any mask. The transformer.py in PyTorch detects that case in a very brute-force way -> https://github.com/pytorch/pytorch/blob/30dea8429d977bfaf0f7e17bd39c2b7d1794359c/torch/nn/modules/transformer.py#L1170
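
To make the mask point concrete, a standalone illustration (not torchchat code; assumes a CUDA device whose GPU supports flash attention, e.g. the A10G above):

```python
import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

q = k = v = torch.randn(1, 8, 128, 64, device="cuda", dtype=torch.bfloat16)

# Causality expressed via the bool flag: the flash kernel is eligible.
with sdpa_kernel([SDPBackend.FLASH_ATTENTION]):
    y = F.scaled_dot_product_attention(q, k, v, is_causal=True)

# Causality expressed via an explicit mask tensor: flash attention rejects any
# non-null attn_mask, so with only the flash backend enabled no kernel remains.
causal_mask = torch.ones(128, 128, dtype=torch.bool, device="cuda").tril()
with sdpa_kernel([SDPBackend.FLASH_ATTENTION]):
    try:
        y = F.scaled_dot_product_attention(q, k, v, attn_mask=causal_mask)
    except RuntimeError as e:
        print(e)  # "No available kernel. Aborting execution."
```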

OTOH, consider what happens if we don't posit that things will just work. The moral equivalent in my mind is a compiler that will only compile programs it has been tested on, because everything else can't be expected to work. That's not useful for the end user, and it transfers a lot of responsibility to them -- responsibility they don't have enough data to exercise meaningfully. Imagine if every compiler returned "Unknown program. Compiled code not guaranteed to be correct." and thereby transferred the responsibility to the compiler user.

mikekgfb added a commit to mikekgfb/torchchat-1 that referenced this pull request Jan 27, 2025
Tests from pytorch#1477 only, without the generate.py refactor