Merge pull request #8 from nod-ai/triton_and_f8_attention
Triton harness & FP8 Attention
suryajasper authored Aug 9, 2024
2 parents e1cf1ad + 30275b7 commit 8f0bd1d
Showing 534 changed files with 6,508 additions and 448 deletions.
17 changes: 14 additions & 3 deletions .github/workflows/run_bench.yml
@@ -62,7 +62,7 @@ jobs:
         cd $GITHUB_WORKSPACE
         source $GITHUB_WORKSPACE/venv/bin/activate
         for device in $(seq 3 7); do (sudo $GITHUB_WORKSPACE/build/gemm-bench --device=$device &); done
-        ./gb run --backends=rocblas --repeat=1 --output=results/rocblas.hdf
+        ./gb run --backends=rocblas --repeat=1 --output=results/rocblas.csv
         sudo pkill -f gemm-bench
         deactivate
@@ -72,16 +72,27 @@ jobs:
         cd $GITHUB_WORKSPACE
         source $GITHUB_WORKSPACE/venv/bin/activate
         for device in $(seq 3 7); do (sudo $GITHUB_WORKSPACE/build/gemm-bench --device=$device &); done
-        ./gb run --backends=hipblaslt --repeat=1 --output=results/hipblaslt.hdf
+        ./gb run --backends=hipblaslt --repeat=1 --output=results/hipblaslt.csv
         sudo pkill -f gemm-bench
         deactivate
     - name: Run IREE Benchmarks
       run: |
         sudo pkill -f gemm-bench
         cd $GITHUB_WORKSPACE
         source $GITHUB_WORKSPACE/venv/bin/activate
         for device in $(seq 3 7); do (sudo $GITHUB_WORKSPACE/build/gemm-bench --device=$device &); done
+        ./gb run --backends=iree --repeat=1 --output=results/iree.csv
+        sudo pkill -f gemm-bench
+        deactivate
+    - name: Run SHARK Attention Benchmarks
+      run: |
+        sudo pkill -f gemm-bench
+        cd $GITHUB_WORKSPACE
+        source $GITHUB_WORKSPACE/venv/bin/activate
+        for device in $(seq 3 7); do (sudo $GITHUB_WORKSPACE/build/gemm-bench --device=$device &); done
-        ./gb run --backends=iree --repeat=1 --output=results/iree.hdf
+        ./gb run --backends=sharkfa --suite=flash_attention --repeat=1 --output=results/sharkfa_llama_sdxl_attention.csv
         sudo pkill -f gemm-bench
         deactivate
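
With every backend now writing CSV instead of HDF, the per-backend result files can be merged with stock pandas. A minimal sketch (illustration only, not part of this commit — the column layout inside each CSV is defined by the gb harness and is not shown in this diff):

import pandas as pd
from pathlib import Path

# Collect the per-backend outputs produced by the workflow steps above
# (rocblas.csv, hipblaslt.csv, iree.csv, sharkfa_llama_sdxl_attention.csv).
frames = []
for path in Path("results").glob("*.csv"):
    df = pd.read_csv(path)
    df["source"] = path.stem  # tag rows so backends stay distinguishable
    frames.append(df)

combined = pd.concat(frames, ignore_index=True)
combined.to_csv("results/combined.csv", index=False)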
4 changes: 1 addition & 3 deletions .gitignore
@@ -2,6 +2,4 @@ build
 rocm_gemm_venv
 *venv
 __pycache__
-*.hdf
-attention/mlir
-attention/vmfb
+*.hdf
7 changes: 5 additions & 2 deletions .vscode/settings.json
@@ -114,11 +114,14 @@
     "syncstream": "c",
     "runtime.hpp": "c",
     "strstream": "cpp",
-    "__functional_03": "cpp"
+    "__functional_03": "cpp",
+    "hash_map": "cpp",
+    "hash_set": "cpp"
   },
   "editor.formatOnSave": true,
   "editor.rulers": [80],
   "clang-format.executable": "usr/bin/clang-format",
   "C_Cpp.errorSquiggles": "enabled",
-  "dotnet.defaultSolution": "disable"
+  "dotnet.defaultSolution": "disable",
+  "git.ignoreLimitWarning": true
 }
11 changes: 11 additions & 0 deletions attention/mlir/attention_B1_H32_SQ1024_SKV1024_DH128_f16.mlir
@@ -0,0 +1,11 @@

module {
func.func @main_0(%295 : !torch.vtensor<[1,32,1024,128],f16>, %298 : !torch.vtensor<[1,32,1024,128],f16>, %301 : !torch.vtensor<[1,32,1024,128],f16>) -> !torch.vtensor<[1,32,1024,128],f16> {
%false_371 = torch.constant.bool false
%float0.000000e00 = torch.constant.float 0.000000e+00
%none_372 = torch.constant.none
%none_373 = torch.constant.none
%282:2 = torch.operator "torch.aten._scaled_dot_product_flash_attention_for_cpu"(%295, %298, %301, %float0.000000e00, %false_371, %none_372, %none_373) : (!torch.vtensor<[1,32,1024,128],f16>, !torch.vtensor<[1,32,1024,128],f16>, !torch.vtensor<[1,32,1024,128],f16>, !torch.float, !torch.bool, !torch.none, !torch.none) -> (!torch.vtensor<[1,32,1024,128],f16>, !torch.vtensor<[1,32,1024], f32>)
return %282#0 : !torch.vtensor<[1,32,1024,128],f16>
}
}
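
For reference, the torch.operator above is torch-mlir's capture of PyTorch's scaled dot-product flash attention; its second result ([1,32,1024] f32) is the per-row logsumexp. A short PyTorch sketch (illustration only, not part of this commit) of the same computation for this file's shapes, where the false/0.0/none operands correspond to non-causal, zero-dropout, unmasked attention at the default 1/sqrt(DH) scale:

import torch

B, H, SQ, SKV, DH = 1, 32, 1024, 1024, 128  # shapes from the filename
q = torch.randn(B, H, SQ, DH, dtype=torch.float16)
k = torch.randn(B, H, SKV, DH, dtype=torch.float16)
v = torch.randn(B, H, SKV, DH, dtype=torch.float16)

# Computes softmax(Q @ K^T / sqrt(DH)) @ V; dropout_p=0.0 and is_causal=False
# mirror the %float0.000000e00 and %false_371 constants in the MLIR above.
out = torch.nn.functional.scaled_dot_product_attention(
    q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)
assert out.shape == (B, H, SQ, DH)

The remaining attention/mlir files below repeat this module verbatim, varying only the K/V sequence length and the dtype tag in the filename.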
11 changes: 11 additions & 0 deletions attention/mlir/attention_B1_H32_SQ1024_SKV1024_DH128_f8.mlir
@@ -0,0 +1,11 @@

module {
func.func @main_0(%295 : !torch.vtensor<[1,32,1024,128],f16>, %298 : !torch.vtensor<[1,32,1024,128],f16>, %301 : !torch.vtensor<[1,32,1024,128],f16>) -> !torch.vtensor<[1,32,1024,128],f16> {
%false_371 = torch.constant.bool false
%float0.000000e00 = torch.constant.float 0.000000e+00
%none_372 = torch.constant.none
%none_373 = torch.constant.none
%282:2 = torch.operator "torch.aten._scaled_dot_product_flash_attention_for_cpu"(%295, %298, %301, %float0.000000e00, %false_371, %none_372, %none_373) : (!torch.vtensor<[1,32,1024,128],f16>, !torch.vtensor<[1,32,1024,128],f16>, !torch.vtensor<[1,32,1024,128],f16>, !torch.float, !torch.bool, !torch.none, !torch.none) -> (!torch.vtensor<[1,32,1024,128],f16>, !torch.vtensor<[1,32,1024], f32>)
return %282#0 : !torch.vtensor<[1,32,1024,128],f16>
}
}
11 changes: 11 additions & 0 deletions attention/mlir/attention_B1_H32_SQ1024_SKV1024_DH128_fp16.mlir
@@ -0,0 +1,11 @@

module {
func.func @main_0(%295 : !torch.vtensor<[1,32,1024,128],f16>, %298 : !torch.vtensor<[1,32,1024,128],f16>, %301 : !torch.vtensor<[1,32,1024,128],f16>) -> !torch.vtensor<[1,32,1024,128],f16> {
%false_371 = torch.constant.bool false
%float0.000000e00 = torch.constant.float 0.000000e+00
%none_372 = torch.constant.none
%none_373 = torch.constant.none
%282:2 = torch.operator "torch.aten._scaled_dot_product_flash_attention_for_cpu"(%295, %298, %301, %float0.000000e00, %false_371, %none_372, %none_373) : (!torch.vtensor<[1,32,1024,128],f16>, !torch.vtensor<[1,32,1024,128],f16>, !torch.vtensor<[1,32,1024,128],f16>, !torch.float, !torch.bool, !torch.none, !torch.none) -> (!torch.vtensor<[1,32,1024,128],f16>, !torch.vtensor<[1,32,1024], f32>)
return %282#0 : !torch.vtensor<[1,32,1024,128],f16>
}
}
11 changes: 11 additions & 0 deletions attention/mlir/attention_B1_H32_SQ1024_SKV1024_DH128_fp8.mlir
@@ -0,0 +1,11 @@

module {
func.func @main_0(%295 : !torch.vtensor<[1,32,1024,128],f16>, %298 : !torch.vtensor<[1,32,1024,128],f16>, %301 : !torch.vtensor<[1,32,1024,128],f16>) -> !torch.vtensor<[1,32,1024,128],f16> {
%false_371 = torch.constant.bool false
%float0.000000e00 = torch.constant.float 0.000000e+00
%none_372 = torch.constant.none
%none_373 = torch.constant.none
%282:2 = torch.operator "torch.aten._scaled_dot_product_flash_attention_for_cpu"(%295, %298, %301, %float0.000000e00, %false_371, %none_372, %none_373) : (!torch.vtensor<[1,32,1024,128],f16>, !torch.vtensor<[1,32,1024,128],f16>, !torch.vtensor<[1,32,1024,128],f16>, !torch.float, !torch.bool, !torch.none, !torch.none) -> (!torch.vtensor<[1,32,1024,128],f16>, !torch.vtensor<[1,32,1024], f32>)
return %282#0 : !torch.vtensor<[1,32,1024,128],f16>
}
}
11 changes: 11 additions & 0 deletions attention/mlir/attention_B1_H32_SQ1024_SKV2048_DH128_f16.mlir
@@ -0,0 +1,11 @@

module {
func.func @main_0(%295 : !torch.vtensor<[1,32,1024,128],f16>, %298 : !torch.vtensor<[1,32,2048,128],f16>, %301 : !torch.vtensor<[1,32,2048,128],f16>) -> !torch.vtensor<[1,32,1024,128],f16> {
%false_371 = torch.constant.bool false
%float0.000000e00 = torch.constant.float 0.000000e+00
%none_372 = torch.constant.none
%none_373 = torch.constant.none
%282:2 = torch.operator "torch.aten._scaled_dot_product_flash_attention_for_cpu"(%295, %298, %301, %float0.000000e00, %false_371, %none_372, %none_373) : (!torch.vtensor<[1,32,1024,128],f16>, !torch.vtensor<[1,32,2048,128],f16>, !torch.vtensor<[1,32,2048,128],f16>, !torch.float, !torch.bool, !torch.none, !torch.none) -> (!torch.vtensor<[1,32,1024,128],f16>, !torch.vtensor<[1,32,1024], f32>)
return %282#0 : !torch.vtensor<[1,32,1024,128],f16>
}
}
11 changes: 11 additions & 0 deletions attention/mlir/attention_B1_H32_SQ1024_SKV2048_DH128_f8.mlir
@@ -0,0 +1,11 @@

module {
func.func @main_0(%295 : !torch.vtensor<[1,32,1024,128],f16>, %298 : !torch.vtensor<[1,32,2048,128],f16>, %301 : !torch.vtensor<[1,32,2048,128],f16>) -> !torch.vtensor<[1,32,1024,128],f16> {
%false_371 = torch.constant.bool false
%float0.000000e00 = torch.constant.float 0.000000e+00
%none_372 = torch.constant.none
%none_373 = torch.constant.none
%282:2 = torch.operator "torch.aten._scaled_dot_product_flash_attention_for_cpu"(%295, %298, %301, %float0.000000e00, %false_371, %none_372, %none_373) : (!torch.vtensor<[1,32,1024,128],f16>, !torch.vtensor<[1,32,2048,128],f16>, !torch.vtensor<[1,32,2048,128],f16>, !torch.float, !torch.bool, !torch.none, !torch.none) -> (!torch.vtensor<[1,32,1024,128],f16>, !torch.vtensor<[1,32,1024], f32>)
return %282#0 : !torch.vtensor<[1,32,1024,128],f16>
}
}
11 changes: 11 additions & 0 deletions attention/mlir/attention_B1_H32_SQ1024_SKV2048_DH128_fp16.mlir
@@ -0,0 +1,11 @@

module {
func.func @main_0(%295 : !torch.vtensor<[1,32,1024,128],f16>, %298 : !torch.vtensor<[1,32,2048,128],f16>, %301 : !torch.vtensor<[1,32,2048,128],f16>) -> !torch.vtensor<[1,32,1024,128],f16> {
%false_371 = torch.constant.bool false
%float0.000000e00 = torch.constant.float 0.000000e+00
%none_372 = torch.constant.none
%none_373 = torch.constant.none
%282:2 = torch.operator "torch.aten._scaled_dot_product_flash_attention_for_cpu"(%295, %298, %301, %float0.000000e00, %false_371, %none_372, %none_373) : (!torch.vtensor<[1,32,1024,128],f16>, !torch.vtensor<[1,32,2048,128],f16>, !torch.vtensor<[1,32,2048,128],f16>, !torch.float, !torch.bool, !torch.none, !torch.none) -> (!torch.vtensor<[1,32,1024,128],f16>, !torch.vtensor<[1,32,1024], f32>)
return %282#0 : !torch.vtensor<[1,32,1024,128],f16>
}
}
11 changes: 11 additions & 0 deletions attention/mlir/attention_B1_H32_SQ1024_SKV2048_DH128_fp8.mlir
@@ -0,0 +1,11 @@

module {
func.func @main_0(%295 : !torch.vtensor<[1,32,1024,128],f16>, %298 : !torch.vtensor<[1,32,2048,128],f16>, %301 : !torch.vtensor<[1,32,2048,128],f16>) -> !torch.vtensor<[1,32,1024,128],f16> {
%false_371 = torch.constant.bool false
%float0.000000e00 = torch.constant.float 0.000000e+00
%none_372 = torch.constant.none
%none_373 = torch.constant.none
%282:2 = torch.operator "torch.aten._scaled_dot_product_flash_attention_for_cpu"(%295, %298, %301, %float0.000000e00, %false_371, %none_372, %none_373) : (!torch.vtensor<[1,32,1024,128],f16>, !torch.vtensor<[1,32,2048,128],f16>, !torch.vtensor<[1,32,2048,128],f16>, !torch.float, !torch.bool, !torch.none, !torch.none) -> (!torch.vtensor<[1,32,1024,128],f16>, !torch.vtensor<[1,32,1024], f32>)
return %282#0 : !torch.vtensor<[1,32,1024,128],f16>
}
}
11 changes: 11 additions & 0 deletions attention/mlir/attention_B1_H32_SQ1024_SKV4096_DH128_f16.mlir
@@ -0,0 +1,11 @@

module {
func.func @main_0(%295 : !torch.vtensor<[1,32,1024,128],f16>, %298 : !torch.vtensor<[1,32,4096,128],f16>, %301 : !torch.vtensor<[1,32,4096,128],f16>) -> !torch.vtensor<[1,32,1024,128],f16> {
%false_371 = torch.constant.bool false
%float0.000000e00 = torch.constant.float 0.000000e+00
%none_372 = torch.constant.none
%none_373 = torch.constant.none
%282:2 = torch.operator "torch.aten._scaled_dot_product_flash_attention_for_cpu"(%295, %298, %301, %float0.000000e00, %false_371, %none_372, %none_373) : (!torch.vtensor<[1,32,1024,128],f16>, !torch.vtensor<[1,32,4096,128],f16>, !torch.vtensor<[1,32,4096,128],f16>, !torch.float, !torch.bool, !torch.none, !torch.none) -> (!torch.vtensor<[1,32,1024,128],f16>, !torch.vtensor<[1,32,1024], f32>)
return %282#0 : !torch.vtensor<[1,32,1024,128],f16>
}
}
11 changes: 11 additions & 0 deletions attention/mlir/attention_B1_H32_SQ1024_SKV4096_DH128_f8.mlir
@@ -0,0 +1,11 @@

module {
func.func @main_0(%295 : !torch.vtensor<[1,32,1024,128],f16>, %298 : !torch.vtensor<[1,32,4096,128],f16>, %301 : !torch.vtensor<[1,32,4096,128],f16>) -> !torch.vtensor<[1,32,1024,128],f16> {
%false_371 = torch.constant.bool false
%float0.000000e00 = torch.constant.float 0.000000e+00
%none_372 = torch.constant.none
%none_373 = torch.constant.none
%282:2 = torch.operator "torch.aten._scaled_dot_product_flash_attention_for_cpu"(%295, %298, %301, %float0.000000e00, %false_371, %none_372, %none_373) : (!torch.vtensor<[1,32,1024,128],f16>, !torch.vtensor<[1,32,4096,128],f16>, !torch.vtensor<[1,32,4096,128],f16>, !torch.float, !torch.bool, !torch.none, !torch.none) -> (!torch.vtensor<[1,32,1024,128],f16>, !torch.vtensor<[1,32,1024], f32>)
return %282#0 : !torch.vtensor<[1,32,1024,128],f16>
}
}
11 changes: 11 additions & 0 deletions attention/mlir/attention_B1_H32_SQ1024_SKV4096_DH128_fp16.mlir
@@ -0,0 +1,11 @@

module {
func.func @main_0(%295 : !torch.vtensor<[1,32,1024,128],f16>, %298 : !torch.vtensor<[1,32,4096,128],f16>, %301 : !torch.vtensor<[1,32,4096,128],f16>) -> !torch.vtensor<[1,32,1024,128],f16> {
%false_371 = torch.constant.bool false
%float0.000000e00 = torch.constant.float 0.000000e+00
%none_372 = torch.constant.none
%none_373 = torch.constant.none
%282:2 = torch.operator "torch.aten._scaled_dot_product_flash_attention_for_cpu"(%295, %298, %301, %float0.000000e00, %false_371, %none_372, %none_373) : (!torch.vtensor<[1,32,1024,128],f16>, !torch.vtensor<[1,32,4096,128],f16>, !torch.vtensor<[1,32,4096,128],f16>, !torch.float, !torch.bool, !torch.none, !torch.none) -> (!torch.vtensor<[1,32,1024,128],f16>, !torch.vtensor<[1,32,1024], f32>)
return %282#0 : !torch.vtensor<[1,32,1024,128],f16>
}
}
11 changes: 11 additions & 0 deletions attention/mlir/attention_B1_H32_SQ1024_SKV4096_DH128_fp8.mlir
@@ -0,0 +1,11 @@

module {
func.func @main_0(%295 : !torch.vtensor<[1,32,1024,128],f16>, %298 : !torch.vtensor<[1,32,4096,128],f16>, %301 : !torch.vtensor<[1,32,4096,128],f16>) -> !torch.vtensor<[1,32,1024,128],f16> {
%false_371 = torch.constant.bool false
%float0.000000e00 = torch.constant.float 0.000000e+00
%none_372 = torch.constant.none
%none_373 = torch.constant.none
%282:2 = torch.operator "torch.aten._scaled_dot_product_flash_attention_for_cpu"(%295, %298, %301, %float0.000000e00, %false_371, %none_372, %none_373) : (!torch.vtensor<[1,32,1024,128],f16>, !torch.vtensor<[1,32,4096,128],f16>, !torch.vtensor<[1,32,4096,128],f16>, !torch.float, !torch.bool, !torch.none, !torch.none) -> (!torch.vtensor<[1,32,1024,128],f16>, !torch.vtensor<[1,32,1024], f32>)
return %282#0 : !torch.vtensor<[1,32,1024,128],f16>
}
}
11 changes: 11 additions & 0 deletions attention/mlir/attention_B1_H32_SQ2048_SKV1024_DH128_f16.mlir
@@ -0,0 +1,11 @@

module {
func.func @main_0(%295 : !torch.vtensor<[1,32,2048,128],f16>, %298 : !torch.vtensor<[1,32,1024,128],f16>, %301 : !torch.vtensor<[1,32,1024,128],f16>) -> !torch.vtensor<[1,32,2048,128],f16> {
%false_371 = torch.constant.bool false
%float0.000000e00 = torch.constant.float 0.000000e+00
%none_372 = torch.constant.none
%none_373 = torch.constant.none
%282:2 = torch.operator "torch.aten._scaled_dot_product_flash_attention_for_cpu"(%295, %298, %301, %float0.000000e00, %false_371, %none_372, %none_373) : (!torch.vtensor<[1,32,2048,128],f16>, !torch.vtensor<[1,32,1024,128],f16>, !torch.vtensor<[1,32,1024,128],f16>, !torch.float, !torch.bool, !torch.none, !torch.none) -> (!torch.vtensor<[1,32,2048,128],f16>, !torch.vtensor<[1,32,2048], f32>)
return %282#0 : !torch.vtensor<[1,32,2048,128],f16>
}
}
11 changes: 11 additions & 0 deletions attention/mlir/attention_B1_H32_SQ2048_SKV1024_DH128_f8.mlir
@@ -0,0 +1,11 @@

module {
func.func @main_0(%295 : !torch.vtensor<[1,32,2048,128],f16>, %298 : !torch.vtensor<[1,32,1024,128],f16>, %301 : !torch.vtensor<[1,32,1024,128],f16>) -> !torch.vtensor<[1,32,2048,128],f16> {
%false_371 = torch.constant.bool false
%float0.000000e00 = torch.constant.float 0.000000e+00
%none_372 = torch.constant.none
%none_373 = torch.constant.none
%282:2 = torch.operator "torch.aten._scaled_dot_product_flash_attention_for_cpu"(%295, %298, %301, %float0.000000e00, %false_371, %none_372, %none_373) : (!torch.vtensor<[1,32,2048,128],f16>, !torch.vtensor<[1,32,1024,128],f16>, !torch.vtensor<[1,32,1024,128],f16>, !torch.float, !torch.bool, !torch.none, !torch.none) -> (!torch.vtensor<[1,32,2048,128],f16>, !torch.vtensor<[1,32,2048], f32>)
return %282#0 : !torch.vtensor<[1,32,2048,128],f16>
}
}
11 changes: 11 additions & 0 deletions attention/mlir/attention_B1_H32_SQ2048_SKV1024_DH128_fp16.mlir
@@ -0,0 +1,11 @@

module {
func.func @main_0(%295 : !torch.vtensor<[1,32,2048,128],f16>, %298 : !torch.vtensor<[1,32,1024,128],f16>, %301 : !torch.vtensor<[1,32,1024,128],f16>) -> !torch.vtensor<[1,32,2048,128],f16> {
%false_371 = torch.constant.bool false
%float0.000000e00 = torch.constant.float 0.000000e+00
%none_372 = torch.constant.none
%none_373 = torch.constant.none
%282:2 = torch.operator "torch.aten._scaled_dot_product_flash_attention_for_cpu"(%295, %298, %301, %float0.000000e00, %false_371, %none_372, %none_373) : (!torch.vtensor<[1,32,2048,128],f16>, !torch.vtensor<[1,32,1024,128],f16>, !torch.vtensor<[1,32,1024,128],f16>, !torch.float, !torch.bool, !torch.none, !torch.none) -> (!torch.vtensor<[1,32,2048,128],f16>, !torch.vtensor<[1,32,2048], f32>)
return %282#0 : !torch.vtensor<[1,32,2048,128],f16>
}
}
11 changes: 11 additions & 0 deletions attention/mlir/attention_B1_H32_SQ2048_SKV1024_DH128_fp8.mlir
@@ -0,0 +1,11 @@

module {
func.func @main_0(%295 : !torch.vtensor<[1,32,2048,128],f16>, %298 : !torch.vtensor<[1,32,1024,128],f16>, %301 : !torch.vtensor<[1,32,1024,128],f16>) -> !torch.vtensor<[1,32,2048,128],f16> {
%false_371 = torch.constant.bool false
%float0.000000e00 = torch.constant.float 0.000000e+00
%none_372 = torch.constant.none
%none_373 = torch.constant.none
%282:2 = torch.operator "torch.aten._scaled_dot_product_flash_attention_for_cpu"(%295, %298, %301, %float0.000000e00, %false_371, %none_372, %none_373) : (!torch.vtensor<[1,32,2048,128],f16>, !torch.vtensor<[1,32,1024,128],f16>, !torch.vtensor<[1,32,1024,128],f16>, !torch.float, !torch.bool, !torch.none, !torch.none) -> (!torch.vtensor<[1,32,2048,128],f16>, !torch.vtensor<[1,32,2048], f32>)
return %282#0 : !torch.vtensor<[1,32,2048,128],f16>
}
}
11 changes: 11 additions & 0 deletions attention/mlir/attention_B1_H32_SQ2048_SKV2048_DH128_f16.mlir
@@ -0,0 +1,11 @@

module {
func.func @main_0(%295 : !torch.vtensor<[1,32,2048,128],f16>, %298 : !torch.vtensor<[1,32,2048,128],f16>, %301 : !torch.vtensor<[1,32,2048,128],f16>) -> !torch.vtensor<[1,32,2048,128],f16> {
%false_371 = torch.constant.bool false
%float0.000000e00 = torch.constant.float 0.000000e+00
%none_372 = torch.constant.none
%none_373 = torch.constant.none
%282:2 = torch.operator "torch.aten._scaled_dot_product_flash_attention_for_cpu"(%295, %298, %301, %float0.000000e00, %false_371, %none_372, %none_373) : (!torch.vtensor<[1,32,2048,128],f16>, !torch.vtensor<[1,32,2048,128],f16>, !torch.vtensor<[1,32,2048,128],f16>, !torch.float, !torch.bool, !torch.none, !torch.none) -> (!torch.vtensor<[1,32,2048,128],f16>, !torch.vtensor<[1,32,2048], f32>)
return %282#0 : !torch.vtensor<[1,32,2048,128],f16>
}
}
11 changes: 11 additions & 0 deletions attention/mlir/attention_B1_H32_SQ2048_SKV2048_DH128_f8.mlir
@@ -0,0 +1,11 @@

module {
func.func @main_0(%295 : !torch.vtensor<[1,32,2048,128],f16>, %298 : !torch.vtensor<[1,32,2048,128],f16>, %301 : !torch.vtensor<[1,32,2048,128],f16>) -> !torch.vtensor<[1,32,2048,128],f16> {
%false_371 = torch.constant.bool false
%float0.000000e00 = torch.constant.float 0.000000e+00
%none_372 = torch.constant.none
%none_373 = torch.constant.none
%282:2 = torch.operator "torch.aten._scaled_dot_product_flash_attention_for_cpu"(%295, %298, %301, %float0.000000e00, %false_371, %none_372, %none_373) : (!torch.vtensor<[1,32,2048,128],f16>, !torch.vtensor<[1,32,2048,128],f16>, !torch.vtensor<[1,32,2048,128],f16>, !torch.float, !torch.bool, !torch.none, !torch.none) -> (!torch.vtensor<[1,32,2048,128],f16>, !torch.vtensor<[1,32,2048], f32>)
return %282#0 : !torch.vtensor<[1,32,2048,128],f16>
}
}
11 changes: 11 additions & 0 deletions attention/mlir/attention_B1_H32_SQ2048_SKV2048_DH128_fp16.mlir
@@ -0,0 +1,11 @@

module {
func.func @main_0(%295 : !torch.vtensor<[1,32,2048,128],f16>, %298 : !torch.vtensor<[1,32,2048,128],f16>, %301 : !torch.vtensor<[1,32,2048,128],f16>) -> !torch.vtensor<[1,32,2048,128],f16> {
%false_371 = torch.constant.bool false
%float0.000000e00 = torch.constant.float 0.000000e+00
%none_372 = torch.constant.none
%none_373 = torch.constant.none
%282:2 = torch.operator "torch.aten._scaled_dot_product_flash_attention_for_cpu"(%295, %298, %301, %float0.000000e00, %false_371, %none_372, %none_373) : (!torch.vtensor<[1,32,2048,128],f16>, !torch.vtensor<[1,32,2048,128],f16>, !torch.vtensor<[1,32,2048,128],f16>, !torch.float, !torch.bool, !torch.none, !torch.none) -> (!torch.vtensor<[1,32,2048,128],f16>, !torch.vtensor<[1,32,2048], f32>)
return %282#0 : !torch.vtensor<[1,32,2048,128],f16>
}
}
11 changes: 11 additions & 0 deletions attention/mlir/attention_B1_H32_SQ2048_SKV2048_DH128_fp8.mlir
@@ -0,0 +1,11 @@

module {
func.func @main_0(%295 : !torch.vtensor<[1,32,2048,128],f16>, %298 : !torch.vtensor<[1,32,2048,128],f16>, %301 : !torch.vtensor<[1,32,2048,128],f16>) -> !torch.vtensor<[1,32,2048,128],f16> {
%false_371 = torch.constant.bool false
%float0.000000e00 = torch.constant.float 0.000000e+00
%none_372 = torch.constant.none
%none_373 = torch.constant.none
%282:2 = torch.operator "torch.aten._scaled_dot_product_flash_attention_for_cpu"(%295, %298, %301, %float0.000000e00, %false_371, %none_372, %none_373) : (!torch.vtensor<[1,32,2048,128],f16>, !torch.vtensor<[1,32,2048,128],f16>, !torch.vtensor<[1,32,2048,128],f16>, !torch.float, !torch.bool, !torch.none, !torch.none) -> (!torch.vtensor<[1,32,2048,128],f16>, !torch.vtensor<[1,32,2048], f32>)
return %282#0 : !torch.vtensor<[1,32,2048,128],f16>
}
}
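
Every attention/mlir file added in this commit instantiates the same template: B=1, H=32, DH=128 are fixed; only the query length (SQ), the key/value length (SKV), and the filename's dtype tag vary, and all bodies use f16 tensors regardless of the f8/fp8 tag. A hypothetical enumeration of the naming scheme (sketch only — the PR's actual file generator, if any, is not part of this diff, and not every combination below necessarily appears among the 534 changed files):

from itertools import product

B, H, DH = 1, 32, 128
seq_lens = [1024, 2048, 4096]          # SQ/SKV values seen in this diff
dtypes = ["f16", "f8", "fp16", "fp8"]  # filename tags; bodies are all f16

for sq, skv, dtype in product(seq_lens, seq_lens, dtypes):
    # Q is [B,H,sq,DH]; K and V are [B,H,skv,DH]; output is [B,H,sq,DH].
    print(f"attention_B{B}_H{H}_SQ{sq}_SKV{skv}_DH{DH}_{dtype}.mlir")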