Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refresh user guide for 3.1 release #724

Open
amd-chrissosa opened this issue Dec 20, 2024 · 5 comments
Open

Refresh user guide for 3.1 release #724

amd-chrissosa opened this issue Dec 20, 2024 · 5 comments
Assignees
Labels
documentation Improvements or additions to documentation

Comments

@amd-chrissosa
Copy link
Contributor

Refresh the user guide to cover new use cases, including:

  • Refreshing SDXL guide with any changes since 3.0
  • Adding pointers to both the developer and user guides for the supported Llama 3.1 variants, with a special callout to the distributed inference demo.
  • Extending the above with support for Llama 3.1 variants that require multiple devices.
  • Expanding on the SDXL guide to include instructions for Flux.
@amd-chrissosa amd-chrissosa self-assigned this Dec 20, 2024
@amd-chrissosa amd-chrissosa added the documentation Improvements or additions to documentation label Dec 20, 2024
@amd-chrissosa
Copy link
Contributor Author

Getting the latest versions requires an unusual combination of install steps:

  • Install the nightly builds of sharktank and shortfin
  • Then install sharkaiapps

@amd-chrissosa
Copy link
Contributor Author

amd-chrissosa commented Dec 21, 2024

Ran into the first issue, which may be due to not installing the latest dependencies:

/home/sosa/export/model.mlir:1072:12: error: failed to run translation of source executable to target executable for backend #hal.executable.target<"rocm", "rocm-hsaco-fb", {abi = "hip", iree.gpu.target = #iree_gpu.target<arch = "gfx942", features = "", wgp = <compute =  fp64|fp32|fp16|int64|int32|int16|int8, storage =  b64|b32|b16|b8, subgroup =  shuffle|arithmetic, dot =  dp4xi8toi32, mma = [<MFMA_F32_16x16x4_F32>, <MFMA_F32_16x16x16_F16>, <MFMA_F32_32x32x8_F16>, <MFMA_F64_16x16x4_F64>, <MFMA_F32_16x16x16_BF16>, <MFMA_F32_32x32x8_BF16>, <MFMA_F32_16x16x32_F8E5M2FNUZ>, <MFMA_F32_16x16x32_F8E5M2FNUZ_F8E4M3FNUZ>, <MFMA_F32_16x16x32_F8E4M3FNUZ>, <MFMA_F32_16x16x32_F8E4M3FNUZ_F8E5M2FNUZ>, <MFMA_F32_32x32x16_F8E5M2FNUZ>, <MFMA_F32_32x32x16_F8E5M2FNUZ_F8E4M3FNUZ>, <MFMA_F32_32x32x16_F8E4M3FNUZ>, <MFMA_F32_32x32x16_F8E4M3FNUZ_F8E5M2FNUZ>, <MFMA_I32_16x16x32_I8>, <MFMA_I32_32x32x16_I8>], subgroup_size_choices = [64], max_workgroup_sizes = [1024, 1024, 1024], max_thread_count_per_workgroup = 1024, max_workgroup_memory_bytes = 65536, max_workgroup_counts = [2147483647, 2147483647, 2147483647], max_load_instruction_bits = 128, simds_per_wgp = 4, vgpr_space_bits = 16384>>, ukernels = "none"}>
    %363 = torch.prims.convert_element_type %362, %int5_65 : !torch.vtensor<[1,?,32,128],f32>, !torch.int -> !torch.vtensor<[1,?,32,128],f16>
           ^
/home/sosa/export/model.mlir:1072:12: note: see current operation:
"hal.executable.variant"() ({
  "hal.executable.export"() ({
  ^bb0(%arg2: !hal.device, %arg3: index):
    %67 = "arith.constant"() <{value = 1 : index}> : () -> index
    %68 = "arith.constant"() <{value = 32 : index}> : () -> index
    "hal.return"(%68, %arg3, %67) : (index, index, index) -> ()
  }) {layout = #hal.pipeline.layout<constants = 3, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>, ordinal = 0 : index, subgroup_size = 64 : index, sym_name = "prefill_bs1$async_dispatch_6_elementwise_broadcast_Dx32x128_f32xf16", workgroup_size = [128 : index, 1 : index, 1 : index]} : () -> ()
  "builtin.module"() ({
    "func.func"() <{function_type = () -> (), sym_name = "prefill_bs1$async_dispatch_6_elementwise_broadcast_Dx32x128_f32xf16"}> ({
      %0 = "arith.constant"() <{value = dense<0> : vector<1x1x1xindex>}> : () -> vector<1x1x1xindex>
      %1 = "arith.constant"() <{value = dense<0> : vector<1x1xindex>}> : () -> vector<1x1xindex>
      %2 = "arith.constant"() <{value = 32 : index}> : () -> index
      %3 = "arith.constant"() <{value = 1 : index}> : () -> index
      %4 = "arith.constant"() <{value = dense<128> : vector<1xindex>}> : () -> vector<1xindex>
      %5 = "arith.constant"() <{value = dense<32> : vector<1xindex>}> : () -> vector<1xindex>
      %6 = "arith.constant"() <{value = 0 : index}> : () -> index
      %7 = "arith.constant"() <{value = dense<0> : vector<1xindex>}> : () -> vector<1xindex>
      %8 = "arith.constant"() <{value = dense<0.000000e+00> : vector<1x1x1xf32>}> : () -> vector<1x1x1xf32>
      %9 = "arith.constant"() <{value = dense<true> : vector<1x1x1xi1>}> : () -> vector<1x1x1xi1>
      %10 = "arith.constant"() <{value = dense<1> : vector<1xindex>}> : () -> vector<1xindex>
      %11 = "arith.constant"() <{value = dense<2> : vector<1xindex>}> : () -> vector<1xindex>
      %12 = "hal.interface.constant.load"() {layout = #hal.pipeline.layout<constants = 3, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>, ordinal = 0 : index} : () -> i32
      %13 = "hal.interface.constant.load"() {layout = #hal.pipeline.layout<constants = 3, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>, ordinal = 1 : index} : () -> i32
      %14 = "hal.interface.constant.load"() {layout = #hal.pipeline.layout<constants = 3, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>, ordinal = 2 : index} : () -> i32
      %15 = "arith.index_castui"(%12) : (i32) -> index
      %16 = "arith.index_castui"(%13) : (i32) -> index
      %17 = "arith.index_castui"(%14) : (i32) -> index
      %18:3 = "util.assume.int"(%15, %16, %17) <{assumptions = [[#util.int.assumption<umin = 134742016, umax = 2281177088>], [#util.int.assumption<umin = 134479872, umax = 1207697408>], [#util.int.assumption<umin = 32, umax = 131040, udiv = 32>]]}> : (index, index, index) -> (index, index, index)
      %19 = "hal.interface.binding.subspan"(%6) {alignment = 64 : index, binding = 0 : index, descriptor_flags = 3 : i32, layout = #hal.pipeline.layout<constants = 3, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>, operandSegmentSizes = array<i32: 1, 0>} : (index) -> memref<131072x128xf32, #gpu.address_space<global>>
      "memref.assume_alignment"(%19) <{alignment = 64 : i32}> : (memref<131072x128xf32, #gpu.address_space<global>>) -> ()
      %20 = "hal.interface.binding.subspan"(%18#0, %18#2) {alignment = 64 : index, binding = 0 : index, descriptor_flags = 3 : i32, layout = #hal.pipeline.layout<constants = 3, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>, operandSegmentSizes = array<i32: 1, 1>} : (index, index) -> memref<1x?x32x128xf16, strided<[?, 4096, 128, 1], offset: ?>, #gpu.address_space<global>>
      "memref.assume_alignment"(%20) <{alignment = 1 : i32}> : (memref<1x?x32x128xf16, strided<[?, 4096, 128, 1], offset: ?>, #gpu.address_space<global>>) -> ()
      %21 = "hal.interface.binding.subspan"(%18#1, %18#2) {alignment = 64 : index, binding = 1 : index, descriptor_flags = 2 : i32, layout = #hal.pipeline.layout<constants = 3, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>, operandSegmentSizes = array<i32: 1, 1>} : (index, index) -> memref<?x32x128xf16, strided<[4096, 128, 1], offset: ?>, #gpu.address_space<global>>
      "memref.assume_alignment"(%21) <{alignment = 1 : i32}> : (memref<?x32x128xf16, strided<[4096, 128, 1], offset: ?>, #gpu.address_space<global>>) -> ()
      %22 = "hal.interface.workgroup.id"() {dimension = 1 : index} : () -> index
      %23 = "hal.interface.workgroup.id"() {dimension = 0 : index} : () -> index
      %24 = "memref.alloc"() <{operandSegmentSizes = array<i32: 0, 0>}> : () -> memref<1x131040x32x128xf32, #gpu.address_space<workgroup>>
      %25 = "memref.reinterpret_cast"(%24, %18#2) <{operandSegmentSizes = array<i32: 1, 0, 1, 0>, static_offsets = array<i64: 0>, static_sizes = array<i64: 1, -9223372036854775808, 32, 128>, static_strides = array<i64: 536739840, 4096, 128, 1>}> : (memref<1x131040x32x128xf32, #gpu.address_space<workgroup>>, index) -> memref<1x?x32x128xf32, strided<[536739840, 4096, 128, 1]>, #gpu.address_space<workgroup>>
      %26 = "gpu.thread_id"() <{dimension = #gpu<dim x>}> : () -> index
      "scf.for"(%6, %18#2, %3) ({
      ^bb0(%arg0: index):
        "scf.for"(%6, %2, %3) ({
        ^bb0(%arg1: index):
          %65 = "memref.load"(%20, %6, %arg0, %arg1, %26) <{nontemporal = false}> : (memref<1x?x32x128xf16, strided<[?, 4096, 128, 1], offset: ?>, #gpu.address_space<global>>, index, index, index, index) -> f16
          %66 = "arith.extf"(%65) : (f16) -> f32
          "memref.store"(%66, %24, %6, %arg0, %arg1, %26) <{nontemporal = false}> : (f32, memref<1x131040x32x128xf32, #gpu.address_space<workgroup>>, index, index, index, index) -> ()
          "scf.yield"() : () -> ()
        }) : (index, index, index) -> ()
        "scf.yield"() : () -> ()
      }) : (index, index, index) -> ()
      "gpu.barrier"() : () -> ()
      %27 = "memref.load"(%19, %22, %26) <{nontemporal = false}> : (memref<131072x128xf32, #gpu.address_space<global>>, index, index) -> f32
      %28 = "vector.broadcast"(%27) : (f32) -> vector<1xf32>
      %29 = "vector.step"() : () -> vector<1xindex>
      %30 = "vector.splat"(%22) : (index) -> vector<1xindex>
      %31 = "arith.addi"(%29, %30) <{overflowFlags = #arith.overflow<none>}> : (vector<1xindex>, vector<1xindex>) -> vector<1xindex>
      %32 = "vector.splat"(%23) : (index) -> vector<1xindex>
      %33 = "arith.addi"(%29, %32) <{overflowFlags = #arith.overflow<none>}> : (vector<1xindex>, vector<1xindex>) -> vector<1xindex>
      %34 = "vector.splat"(%26) : (index) -> vector<1xindex>
      %35 = "arith.addi"(%34, %29) <{overflowFlags = #arith.overflow<none>}> : (vector<1xindex>, vector<1xindex>) -> vector<1xindex>
      %36 = "arith.divui"(%35, %11) : (vector<1xindex>, vector<1xindex>) -> vector<1xindex>
      %37 = "arith.remui"(%35, %11) : (vector<1xindex>, vector<1xindex>) -> vector<1xindex>
      %38 = "math.cos"(%28) <{fastmath = #arith.fastmath<none>}> : (vector<1xf32>) -> vector<1xf32>
      %39 = "math.sin"(%28) <{fastmath = #arith.fastmath<none>}> : (vector<1xf32>) -> vector<1xf32>
      %40 = "arith.muli"(%36, %11) <{overflowFlags = #arith.overflow<none>}> : (vector<1xindex>, vector<1xindex>) -> vector<1xindex>
      %41 = "arith.addi"(%40, %10) <{overflowFlags = #arith.overflow<none>}> : (vector<1xindex>, vector<1xindex>) -> vector<1xindex>
      %42 = "arith.muli"(%31, %5) <{overflowFlags = #arith.overflow<none>}> : (vector<1xindex>, vector<1xindex>) -> vector<1xindex>
      %43 = "arith.addi"(%33, %42) <{overflowFlags = #arith.overflow<none>}> : (vector<1xindex>, vector<1xindex>) -> vector<1xindex>
      %44 = "arith.muli"(%43, %4) <{overflowFlags = #arith.overflow<none>}> : (vector<1xindex>, vector<1xindex>) -> vector<1xindex>
      %45 = "arith.addi"(%40, %44) <{overflowFlags = #arith.overflow<none>}> : (vector<1xindex>, vector<1xindex>) -> vector<1xindex>
      %46 = "vector.insert"(%45, %1) <{static_position = array<i64: 0>}> : (vector<1xindex>, vector<1x1xindex>) -> vector<1x1xindex>
      %47 = "vector.insert"(%46, %0) <{static_position = array<i64: 0>}> : (vector<1x1xindex>, vector<1x1x1xindex>) -> vector<1x1x1xindex>
      %48 = "vector.gather"(%25, %6, %6, %6, %6, %47, %9, %8) : (memref<1x?x32x128xf32, strided<[536739840, 4096, 128, 1]>, #gpu.address_space<workgroup>>, index, index, index, index, vector<1x1x1xindex>, vector<1x1x1xi1>, vector<1x1x1xf32>) -> vector<1x1x1xf32>
      %49 = "arith.addi"(%41, %44) <{overflowFlags = #arith.overflow<none>}> : (vector<1xindex>, vector<1xindex>) -> vector<1xindex>
      %50 = "vector.insert"(%49, %1) <{static_position = array<i64: 0>}> : (vector<1xindex>, vector<1x1xindex>) -> vector<1x1xindex>
      %51 = "vector.insert"(%50, %0) <{static_position = array<i64: 0>}> : (vector<1x1xindex>, vector<1x1x1xindex>) -> vector<1x1x1xindex>
      %52 = "vector.gather"(%25, %6, %6, %6, %6, %51, %9, %8) : (memref<1x?x32x128xf32, strided<[536739840, 4096, 128, 1]>, #gpu.address_space<workgroup>>, index, index, index, index, vector<1x1x1xindex>, vector<1x1x1xi1>, vector<1x1x1xf32>) -> vector<1x1x1xf32>
      %53 = "arith.cmpi"(%37, %7) <{predicate = 0 : i64}> : (vector<1xindex>, vector<1xindex>) -> vector<1xi1>
      %54 = "vector.extract"(%48) <{static_position = array<i64: 0, 0>}> : (vector<1x1x1xf32>) -> vector<1xf32>
      %55 = "arith.mulf"(%54, %38) <{fastmath = #arith.fastmath<none>}> : (vector<1xf32>, vector<1xf32>) -> vector<1xf32>
      %56 = "vector.extract"(%52) <{static_position = array<i64: 0, 0>}> : (vector<1x1x1xf32>) -> vector<1xf32>
      %57 = "arith.mulf"(%56, %39) <{fastmath = #arith.fastmath<none>}> : (vector<1xf32>, vector<1xf32>) -> vector<1xf32>
      %58 = "arith.subf"(%55, %57) <{fastmath = #arith.fastmath<none>}> : (vector<1xf32>, vector<1xf32>) -> vector<1xf32>
      %59 = "arith.mulf"(%56, %38) <{fastmath = #arith.fastmath<none>}> : (vector<1xf32>, vector<1xf32>) -> vector<1xf32>
      %60 = "arith.mulf"(%54, %39) <{fastmath = #arith.fastmath<none>}> : (vector<1xf32>, vector<1xf32>) -> vector<1xf32>
      %61 = "arith.addf"(%59, %60) <{fastmath = #arith.fastmath<none>}> : (vector<1xf32>, vector<1xf32>) -> vector<1xf32>
      %62 = "arith.select"(%53, %58, %61) : (vector<1xi1>, vector<1xf32>, vector<1xf32>) -> vector<1xf32>
      %63 = "arith.truncf"(%62) : (vector<1xf32>) -> vector<1xf16>
      %64 = "vector.extract"(%63) <{static_position = array<i64: 0>}> : (vector<1xf16>) -> f16
      "memref.store"(%64, %21, %22, %23, %26) <{nontemporal = false}> : (f16, memref<?x32x128xf16, strided<[4096, 128, 1], offset: ?>, #gpu.address_space<global>>, index, index, index) -> ()
      "func.return"() : () -> ()
    }) : () -> ()
  }) : () -> ()
  "hal.executable.variant_end"() : () -> ()
}) {sym_name = "rocm_hsaco_fb", target = #hal.executable.target<"rocm", "rocm-hsaco-fb", {abi = "hip", iree.gpu.target = #iree_gpu.target<arch = "gfx942", features = "", wgp = <compute =  fp64|fp32|fp16|int64|int32|int16|int8, storage =  b64|b32|b16|b8, subgroup =  shuffle|arithmetic, dot =  dp4xi8toi32, mma = [<MFMA_F32_16x16x4_F32>, <MFMA_F32_16x16x16_F16>, <MFMA_F32_32x32x8_F16>, <MFMA_F64_16x16x4_F64>, <MFMA_F32_16x16x16_BF16>, <MFMA_F32_32x32x8_BF16>, <MFMA_F32_16x16x32_F8E5M2FNUZ>, <MFMA_F32_16x16x32_F8E5M2FNUZ_F8E4M3FNUZ>, <MFMA_F32_16x16x32_F8E4M3FNUZ>, <MFMA_F32_16x16x32_F8E4M3FNUZ_F8E5M2FNUZ>, <MFMA_F32_32x32x16_F8E5M2FNUZ>, <MFMA_F32_32x32x16_F8E5M2FNUZ_F8E4M3FNUZ>, <MFMA_F32_32x32x16_F8E4M3FNUZ>, <MFMA_F32_32x32x16_F8E4M3FNUZ_F8E5M2FNUZ>, <MFMA_I32_16x16x32_I8>, <MFMA_I32_32x32x16_I8>], subgroup_size_choices = [64], max_workgroup_sizes = [1024, 1024, 1024], max_thread_count_per_workgroup = 1024, max_workgroup_memory_bytes = 65536, max_workgroup_counts = [2147483647, 2147483647, 2147483647], max_load_instruction_bits = 128, simds_per_wgp = 4, vgpr_space_bits = 16384>>, ukernels = "none"}>} : () -> ()
/home/sosa/export/model.mlir:1154:12: error: 'func.func' op uses -131072 bytes of shared memory; exceeded the limit of 65536 bytes
    %390 = torch.prims.convert_element_type %389, %int5_110 : !torch.vtensor<[1,?,8,128],f32>, !torch.int -> !torch.vtensor<[1,?,8,128],f16>
           ^
/home/sosa/export/model.mlir:1154:12: note: see current operation:
"func.func"() <{function_type = () -> (), sym_name = "prefill_bs1$async_dispatch_8_elementwise_broadcast_Dx8x128_f32xf16"}> ({
  %0 = "arith.constant"() <{value = dense<0> : vector<1x1x1xindex>}> : () -> vector<1x1x1xindex>
  %1 = "arith.constant"() <{value = dense<0> : vector<1x1xindex>}> : () -> vector<1x1xindex>
  %2 = "arith.constant"() <{value = 8 : index}> : () -> index
  %3 = "arith.constant"() <{value = 1 : index}> : () -> index
  %4 = "arith.constant"() <{value = dense<128> : vector<1xindex>}> : () -> vector<1xindex>
  %5 = "arith.constant"() <{value = dense<8> : vector<1xindex>}> : () -> vector<1xindex>
  %6 = "arith.constant"() <{value = 0 : index}> : () -> index
  %7 = "arith.constant"() <{value = dense<0> : vector<1xindex>}> : () -> vector<1xindex>
  %8 = "arith.constant"() <{value = dense<0.000000e+00> : vector<1x1x1xf32>}> : () -> vector<1x1x1xf32>
  %9 = "arith.constant"() <{value = dense<true> : vector<1x1x1xi1>}> : () -> vector<1x1x1xi1>
  %10 = "arith.constant"() <{value = dense<1> : vector<1xindex>}> : () -> vector<1xindex>
  %11 = "arith.constant"() <{value = dense<2> : vector<1xindex>}> : () -> vector<1xindex>
  %12 = "arith.constant"() <{value = 67108864 : index}> : () -> index
  %13 = "arith.constant"() <{value = 32 : i64}> : () -> i64
  %14 = "hal.interface.constant.load"() {layout = #hal.pipeline.layout<constants = 5, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>, ordinal = 0 : index} : () -> i32
  %15 = "hal.interface.constant.load"() {layout = #hal.pipeline.layout<constants = 5, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>, ordinal = 1 : index} : () -> i32
  %16 = "hal.interface.constant.load"() {layout = #hal.pipeline.layout<constants = 5, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>, ordinal = 2 : index} : () -> i32
  %17 = "hal.interface.constant.load"() {layout = #hal.pipeline.layout<constants = 5, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>, ordinal = 3 : index} : () -> i32
  %18 = "hal.interface.constant.load"() {layout = #hal.pipeline.layout<constants = 5, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>, ordinal = 4 : index} : () -> i32
  %19 = "arith.extui"(%14) : (i32) -> i64
  %20 = "arith.extui"(%15) : (i32) -> i64
  %21 = "arith.shli"(%20, %13) <{overflowFlags = #arith.overflow<none>}> : (i64, i64) -> i64
  %22 = "arith.ori"(%19, %21) : (i64, i64) -> i64
  %23 = "arith.index_castui"(%22) : (i64) -> index
  %24 = "arith.extui"(%16) : (i32) -> i64
  %25 = "arith.extui"(%17) : (i32) -> i64
  %26 = "arith.shli"(%25, %13) <{overflowFlags = #arith.overflow<none>}> : (i64, i64) -> i64
  %27 = "arith.ori"(%24, %26) : (i64, i64) -> i64
  %28 = "arith.index_castui"(%27) : (i64) -> index
  %29 = "arith.index_castui"(%18) : (i32) -> index
  %30:3 = "util.assume.int"(%23, %28, %29) <{assumptions = [[#util.int.assumption<umin = 135528448, umax = 5501616128>], [#util.int.assumption<umin = 135659520, umax = 6038355968>], [#util.int.assumption<umin = 32, umax = 131040, udiv = 32>]]}> : (index, index, index) -> (index, index, index)
  %31 = "hal.interface.binding.subspan"(%12) {alignment = 64 : index, binding = 0 : index, descriptor_flags = 3 : i32, layout = #hal.pipeline.layout<constants = 5, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>, operandSegmentSizes = array<i32: 1, 0>} : (index) -> memref<131072x128xf32, strided<[128, 1], offset: 16777216>, #gpu.address_space<global>>
  "memref.assume_alignment"(%31) <{alignment = 64 : i32}> : (memref<131072x128xf32, strided<[128, 1], offset: 16777216>, #gpu.address_space<global>>) -> ()
  %32 = "hal.interface.binding.subspan"(%30#0, %30#2) {alignment = 64 : index, binding = 0 : index, descriptor_flags = 3 : i32, layout = #hal.pipeline.layout<constants = 5, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>, operandSegmentSizes = array<i32: 1, 1>} : (index, index) -> memref<1x?x8x128xf16, strided<[?, 1024, 128, 1], offset: ?>, #gpu.address_space<global>>
  "memref.assume_alignment"(%32) <{alignment = 1 : i32}> : (memref<1x?x8x128xf16, strided<[?, 1024, 128, 1], offset: ?>, #gpu.address_space<global>>) -> ()
  %33 = "hal.interface.binding.subspan"(%30#1, %30#2) {alignment = 64 : index, binding = 1 : index, descriptor_flags = 2 : i32, layout = #hal.pipeline.layout<constants = 5, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>, operandSegmentSizes = array<i32: 1, 1>} : (index, index) -> memref<?x8x128xf16, strided<[1024, 128, 1], offset: ?>, #gpu.address_space<global>>
  "memref.assume_alignment"(%33) <{alignment = 1 : i32}> : (memref<?x8x128xf16, strided<[1024, 128, 1], offset: ?>, #gpu.address_space<global>>) -> ()
  %34 = "hal.interface.workgroup.id"() {dimension = 1 : index} : () -> index
  %35 = "hal.interface.workgroup.id"() {dimension = 0 : index} : () -> index
  %36 = "memref.alloc"() <{operandSegmentSizes = array<i32: 0, 0>}> : () -> memref<1x131040x8x128xf32, #gpu.address_space<workgroup>>
  %37 = "memref.reinterpret_cast"(%36, %30#2) <{operandSegmentSizes = array<i32: 1, 0, 1, 0>, static_offsets = array<i64: 0>, static_sizes = array<i64: 1, -9223372036854775808, 8, 128>, static_strides = array<i64: 134184960, 1024, 128, 1>}> : (memref<1x131040x8x128xf32, #gpu.address_space<workgroup>>, index) -> memref<1x?x8x128xf32, strided<[134184960, 1024, 128, 1]>, #gpu.address_space<workgroup>>
  %38 = "gpu.thread_id"() <{dimension = #gpu<dim x>}> : () -> index
  "scf.for"(%6, %30#2, %3) ({
  ^bb0(%arg0: index):
    "scf.for"(%6, %2, %3) ({
    ^bb0(%arg1: index):
      %77 = "memref.load"(%32, %6, %arg0, %arg1, %38) <{nontemporal = false}> : (memref<1x?x8x128xf16, strided<[?, 1024, 128, 1], offset: ?>, #gpu.address_space<global>>, index, index, index, index) -> f16
      %78 = "arith.extf"(%77) : (f16) -> f32
      "memref.store"(%78, %36, %6, %arg0, %arg1, %38) <{nontemporal = false}> : (f32, memref<1x131040x8x128xf32, #gpu.address_space<workgroup>>, index, index, index, index) -> ()
      "scf.yield"() : () -> ()
    }) : (index, index, index) -> ()
    "scf.yield"() : () -> ()
  }) : (index, index, index) -> ()
  "gpu.barrier"() : () -> ()
  %39 = "memref.load"(%31, %34, %38) <{nontemporal = false}> : (memref<131072x128xf32, strided<[128, 1], offset: 16777216>, #gpu.address_space<global>>, index, index) -> f32
  %40 = "vector.broadcast"(%39) : (f32) -> vector<1xf32>
  %41 = "vector.step"() : () -> vector<1xindex>
  %42 = "vector.splat"(%34) : (index) -> vector<1xindex>
  %43 = "arith.addi"(%41, %42) <{overflowFlags = #arith.overflow<none>}> : (vector<1xindex>, vector<1xindex>) -> vector<1xindex>
  %44 = "vector.splat"(%35) : (index) -> vector<1xindex>
  %45 = "arith.addi"(%41, %44) <{overflowFlags = #arith.overflow<none>}> : (vector<1xindex>, vector<1xindex>) -> vector<1xindex>
  %46 = "vector.splat"(%38) : (index) -> vector<1xindex>
  %47 = "arith.addi"(%46, %41) <{overflowFlags = #arith.overflow<none>}> : (vector<1xindex>, vector<1xindex>) -> vector<1xindex>
  %48 = "arith.divui"(%47, %11) : (vector<1xindex>, vector<1xindex>) -> vector<1xindex>
  %49 = "arith.remui"(%47, %11) : (vector<1xindex>, vector<1xindex>) -> vector<1xindex>
  %50 = "math.cos"(%40) <{fastmath = #arith.fastmath<none>}> : (vector<1xf32>) -> vector<1xf32>
  %51 = "math.sin"(%40) <{fastmath = #arith.fastmath<none>}> : (vector<1xf32>) -> vector<1xf32>
  %52 = "arith.muli"(%48, %11) <{overflowFlags = #arith.overflow<none>}> : (vector<1xindex>, vector<1xindex>) -> vector<1xindex>
  %53 = "arith.addi"(%52, %10) <{overflowFlags = #arith.overflow<none>}> : (vector<1xindex>, vector<1xindex>) -> vector<1xindex>
  %54 = "arith.muli"(%43, %5) <{overflowFlags = #arith.overflow<none>}> : (vector<1xindex>, vector<1xindex>) -> vector<1xindex>
  %55 = "arith.addi"(%45, %54) <{overflowFlags = #arith.overflow<none>}> : (vector<1xindex>, vector<1xindex>) -> vector<1xindex>
  %56 = "arith.muli"(%55, %4) <{overflowFlags = #arith.overflow<none>}> : (vector<1xindex>, vector<1xindex>) -> vector<1xindex>
  %57 = "arith.addi"(%52, %56) <{overflowFlags = #arith.overflow<none>}> : (vector<1xindex>, vector<1xindex>) -> vector<1xindex>
  %58 = "vector.insert"(%57, %1) <{static_position = array<i64: 0>}> : (vector<1xindex>, vector<1x1xindex>) -> vector<1x1xindex>
  %59 = "vector.insert"(%58, %0) <{static_position = array<i64: 0>}> : (vector<1x1xindex>, vector<1x1x1xindex>) -> vector<1x1x1xindex>
  %60 = "vector.gather"(%37, %6, %6, %6, %6, %59, %9, %8) : (memref<1x?x8x128xf32, strided<[134184960, 1024, 128, 1]>, #gpu.address_space<workgroup>>, index, index, index, index, vector<1x1x1xindex>, vector<1x1x1xi1>, vector<1x1x1xf32>) -> vector<1x1x1xf32>
  %61 = "arith.addi"(%53, %56) <{overflowFlags = #arith.overflow<none>}> : (vector<1xindex>, vector<1xindex>) -> vector<1xindex>
  %62 = "vector.insert"(%61, %1) <{static_position = array<i64: 0>}> : (vector<1xindex>, vector<1x1xindex>) -> vector<1x1xindex>
  %63 = "vector.insert"(%62, %0) <{static_position = array<i64: 0>}> : (vector<1x1xindex>, vector<1x1x1xindex>) -> vector<1x1x1xindex>
  %64 = "vector.gather"(%37, %6, %6, %6, %6, %63, %9, %8) : (memref<1x?x8x128xf32, strided<[134184960, 1024, 128, 1]>, #gpu.address_space<workgroup>>, index, index, index, index, vector<1x1x1xindex>, vector<1x1x1xi1>, vector<1x1x1xf32>) -> vector<1x1x1xf32>
  %65 = "arith.cmpi"(%49, %7) <{predicate = 0 : i64}> : (vector<1xindex>, vector<1xindex>) -> vector<1xi1>
  %66 = "vector.extract"(%60) <{static_position = array<i64: 0, 0>}> : (vector<1x1x1xf32>) -> vector<1xf32>
  %67 = "arith.mulf"(%66, %50) <{fastmath = #arith.fastmath<none>}> : (vector<1xf32>, vector<1xf32>) -> vector<1xf32>
  %68 = "vector.extract"(%64) <{static_position = array<i64: 0, 0>}> : (vector<1x1x1xf32>) -> vector<1xf32>
  %69 = "arith.mulf"(%68, %51) <{fastmath = #arith.fastmath<none>}> : (vector<1xf32>, vector<1xf32>) -> vector<1xf32>
  %70 = "arith.subf"(%67, %69) <{fastmath = #arith.fastmath<none>}> : (vector<1xf32>, vector<1xf32>) -> vector<1xf32>
  %71 = "arith.mulf"(%68, %50) <{fastmath = #arith.fastmath<none>}> : (vector<1xf32>, vector<1xf32>) -> vector<1xf32>
  %72 = "arith.mulf"(%66, %51) <{fastmath = #arith.fastmath<none>}> : (vector<1xf32>, vector<1xf32>) -> vector<1xf32>
  %73 = "arith.addf"(%71, %72) <{fastmath = #arith.fastmath<none>}> : (vector<1xf32>, vector<1xf32>) -> vector<1xf32>
  %74 = "arith.select"(%65, %70, %73) : (vector<1xi1>, vector<1xf32>, vector<1xf32>) -> vector<1xf32>
  %75 = "arith.truncf"(%74) : (vector<1xf32>) -> vector<1xf16>
  %76 = "vector.extract"(%75) <{static_position = array<i64: 0>}> : (vector<1xf16>) -> f16
  "memref.store"(%76, %33, %34, %35, %38) <{nontemporal = false}> : (f16, memref<?x8x128xf16, strided<[1024, 128, 1], offset: ?>, #gpu.address_space<global>>, index, index, index) -> ()
  "func.return"() : () -> ()
}) : () -> ()
/home/sosa/export/model.mlir:1154:12: error: failed to run translation of source executable to target executable for backend #hal.executable.target<"rocm", "rocm-hsaco-fb", {abi = "hip", iree.gpu.target = #iree_gpu.target<arch = "gfx942", features = "", wgp = <compute =  fp64|fp32|fp16|int64|int32|int16|int8, storage =  b64|b32|b16|b8, subgroup =  shuffle|arithmetic, dot =  dp4xi8toi32, mma = [<MFMA_F32_16x16x4_F32>, <MFMA_F32_16x16x16_F16>, <MFMA_F32_32x32x8_F16>, <MFMA_F64_16x16x4_F64>, <MFMA_F32_16x16x16_BF16>, <MFMA_F32_32x32x8_BF16>, <MFMA_F32_16x16x32_F8E5M2FNUZ>, <MFMA_F32_16x16x32_F8E5M2FNUZ_F8E4M3FNUZ>, <MFMA_F32_16x16x32_F8E4M3FNUZ>, <MFMA_F32_16x16x32_F8E4M3FNUZ_F8E5M2FNUZ>, <MFMA_F32_32x32x16_F8E5M2FNUZ>, <MFMA_F32_32x32x16_F8E5M2FNUZ_F8E4M3FNUZ>, <MFMA_F32_32x32x16_F8E4M3FNUZ>, <MFMA_F32_32x32x16_F8E4M3FNUZ_F8E5M2FNUZ>, <MFMA_I32_16x16x32_I8>, <MFMA_I32_32x32x16_I8>], subgroup_size_choices = [64], max_workgroup_sizes = [1024, 1024, 1024], max_thread_count_per_workgroup = 1024, max_workgroup_memory_bytes = 65536, max_workgroup_counts = [2147483647, 2147483647, 2147483647], max_load_instruction_bits = 128, simds_per_wgp = 4, vgpr_space_bits = 16384>>, ukernels = "none"}>
    %390 = torch.prims.convert_element_type %389, %int5_110 : !torch.vtensor<[1,?,8,128],f32>, !torch.int -> !torch.vtensor<[1,?,8,128],f16>
           ^
/home/sosa/export/model.mlir:1154:12: note: see current operation:
"hal.executable.variant"() ({
  "hal.executable.export"() ({
  ^bb0(%arg2: !hal.device, %arg3: index):
    %79 = "arith.constant"() <{value = 1 : index}> : () -> index
    %80 = "arith.constant"() <{value = 8 : index}> : () -> index
    "hal.return"(%80, %arg3, %79) : (index, index, index) -> ()
  }) {layout = #hal.pipeline.layout<constants = 5, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>, ordinal = 0 : index, subgroup_size = 64 : index, sym_name = "prefill_bs1$async_dispatch_8_elementwise_broadcast_Dx8x128_f32xf16", workgroup_size = [128 : index, 1 : index, 1 : index]} : () -> ()
  "builtin.module"() ({
    "func.func"() <{function_type = () -> (), sym_name = "prefill_bs1$async_dispatch_8_elementwise_broadcast_Dx8x128_f32xf16"}> ({
      %0 = "arith.constant"() <{value = dense<0> : vector<1x1x1xindex>}> : () -> vector<1x1x1xindex>
      %1 = "arith.constant"() <{value = dense<0> : vector<1x1xindex>}> : () -> vector<1x1xindex>
      %2 = "arith.constant"() <{value = 8 : index}> : () -> index
      %3 = "arith.constant"() <{value = 1 : index}> : () -> index
      %4 = "arith.constant"() <{value = dense<128> : vector<1xindex>}> : () -> vector<1xindex>
      %5 = "arith.constant"() <{value = dense<8> : vector<1xindex>}> : () -> vector<1xindex>
      %6 = "arith.constant"() <{value = 0 : index}> : () -> index
      %7 = "arith.constant"() <{value = dense<0> : vector<1xindex>}> : () -> vector<1xindex>
      %8 = "arith.constant"() <{value = dense<0.000000e+00> : vector<1x1x1xf32>}> : () -> vector<1x1x1xf32>
      %9 = "arith.constant"() <{value = dense<true> : vector<1x1x1xi1>}> : () -> vector<1x1x1xi1>
      %10 = "arith.constant"() <{value = dense<1> : vector<1xindex>}> : () -> vector<1xindex>
      %11 = "arith.constant"() <{value = dense<2> : vector<1xindex>}> : () -> vector<1xindex>
      %12 = "arith.constant"() <{value = 67108864 : index}> : () -> index
      %13 = "arith.constant"() <{value = 32 : i64}> : () -> i64
      %14 = "hal.interface.constant.load"() {layout = #hal.pipeline.layout<constants = 5, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>, ordinal = 0 : index} : () -> i32
      %15 = "hal.interface.constant.load"() {layout = #hal.pipeline.layout<constants = 5, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>, ordinal = 1 : index} : () -> i32
      %16 = "hal.interface.constant.load"() {layout = #hal.pipeline.layout<constants = 5, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>, ordinal = 2 : index} : () -> i32
      %17 = "hal.interface.constant.load"() {layout = #hal.pipeline.layout<constants = 5, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>, ordinal = 3 : index} : () -> i32
      %18 = "hal.interface.constant.load"() {layout = #hal.pipeline.layout<constants = 5, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>, ordinal = 4 : index} : () -> i32
      %19 = "arith.extui"(%14) : (i32) -> i64
      %20 = "arith.extui"(%15) : (i32) -> i64
      %21 = "arith.shli"(%20, %13) <{overflowFlags = #arith.overflow<none>}> : (i64, i64) -> i64
      %22 = "arith.ori"(%19, %21) : (i64, i64) -> i64
      %23 = "arith.index_castui"(%22) : (i64) -> index
      %24 = "arith.extui"(%16) : (i32) -> i64
      %25 = "arith.extui"(%17) : (i32) -> i64
      %26 = "arith.shli"(%25, %13) <{overflowFlags = #arith.overflow<none>}> : (i64, i64) -> i64
      %27 = "arith.ori"(%24, %26) : (i64, i64) -> i64
      %28 = "arith.index_castui"(%27) : (i64) -> index
      %29 = "arith.index_castui"(%18) : (i32) -> index
      %30:3 = "util.assume.int"(%23, %28, %29) <{assumptions = [[#util.int.assumption<umin = 135528448, umax = 5501616128>], [#util.int.assumption<umin = 135659520, umax = 6038355968>], [#util.int.assumption<umin = 32, umax = 131040, udiv = 32>]]}> : (index, index, index) -> (index, index, index)
      %31 = "hal.interface.binding.subspan"(%12) {alignment = 64 : index, binding = 0 : index, descriptor_flags = 3 : i32, layout = #hal.pipeline.layout<constants = 5, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>, operandSegmentSizes = array<i32: 1, 0>} : (index) -> memref<131072x128xf32, strided<[128, 1], offset: 16777216>, #gpu.address_space<global>>
      "memref.assume_alignment"(%31) <{alignment = 64 : i32}> : (memref<131072x128xf32, strided<[128, 1], offset: 16777216>, #gpu.address_space<global>>) -> ()
      %32 = "hal.interface.binding.subspan"(%30#0, %30#2) {alignment = 64 : index, binding = 0 : index, descriptor_flags = 3 : i32, layout = #hal.pipeline.layout<constants = 5, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>, operandSegmentSizes = array<i32: 1, 1>} : (index, index) -> memref<1x?x8x128xf16, strided<[?, 1024, 128, 1], offset: ?>, #gpu.address_space<global>>
      "memref.assume_alignment"(%32) <{alignment = 1 : i32}> : (memref<1x?x8x128xf16, strided<[?, 1024, 128, 1], offset: ?>, #gpu.address_space<global>>) -> ()
      %33 = "hal.interface.binding.subspan"(%30#1, %30#2) {alignment = 64 : index, binding = 1 : index, descriptor_flags = 2 : i32, layout = #hal.pipeline.layout<constants = 5, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>, operandSegmentSizes = array<i32: 1, 1>} : (index, index) -> memref<?x8x128xf16, strided<[1024, 128, 1], offset: ?>, #gpu.address_space<global>>
      "memref.assume_alignment"(%33) <{alignment = 1 : i32}> : (memref<?x8x128xf16, strided<[1024, 128, 1], offset: ?>, #gpu.address_space<global>>) -> ()
      %34 = "hal.interface.workgroup.id"() {dimension = 1 : index} : () -> index
      %35 = "hal.interface.workgroup.id"() {dimension = 0 : index} : () -> index
      %36 = "memref.alloc"() <{operandSegmentSizes = array<i32: 0, 0>}> : () -> memref<1x131040x8x128xf32, #gpu.address_space<workgroup>>
      %37 = "memref.reinterpret_cast"(%36, %30#2) <{operandSegmentSizes = array<i32: 1, 0, 1, 0>, static_offsets = array<i64: 0>, static_sizes = array<i64: 1, -9223372036854775808, 8, 128>, static_strides = array<i64: 134184960, 1024, 128, 1>}> : (memref<1x131040x8x128xf32, #gpu.address_space<workgroup>>, index) -> memref<1x?x8x128xf32, strided<[134184960, 1024, 128, 1]>, #gpu.address_space<workgroup>>
      %38 = "gpu.thread_id"() <{dimension = #gpu<dim x>}> : () -> index
      "scf.for"(%6, %30#2, %3) ({
      ^bb0(%arg0: index):
        "scf.for"(%6, %2, %3) ({
        ^bb0(%arg1: index):
          %77 = "memref.load"(%32, %6, %arg0, %arg1, %38) <{nontemporal = false}> : (memref<1x?x8x128xf16, strided<[?, 1024, 128, 1], offset: ?>, #gpu.address_space<global>>, index, index, index, index) -> f16
          %78 = "arith.extf"(%77) : (f16) -> f32
          "memref.store"(%78, %36, %6, %arg0, %arg1, %38) <{nontemporal = false}> : (f32, memref<1x131040x8x128xf32, #gpu.address_space<workgroup>>, index, index, index, index) -> ()
          "scf.yield"() : () -> ()
        }) : (index, index, index) -> ()
        "scf.yield"() : () -> ()
      }) : (index, index, index) -> ()
      "gpu.barrier"() : () -> ()
      %39 = "memref.load"(%31, %34, %38) <{nontemporal = false}> : (memref<131072x128xf32, strided<[128, 1], offset: 16777216>, #gpu.address_space<global>>, index, index) -> f32
      %40 = "vector.broadcast"(%39) : (f32) -> vector<1xf32>
      %41 = "vector.step"() : () -> vector<1xindex>
      %42 = "vector.splat"(%34) : (index) -> vector<1xindex>
      %43 = "arith.addi"(%41, %42) <{overflowFlags = #arith.overflow<none>}> : (vector<1xindex>, vector<1xindex>) -> vector<1xindex>
      %44 = "vector.splat"(%35) : (index) -> vector<1xindex>
      %45 = "arith.addi"(%41, %44) <{overflowFlags = #arith.overflow<none>}> : (vector<1xindex>, vector<1xindex>) -> vector<1xindex>
      %46 = "vector.splat"(%38) : (index) -> vector<1xindex>
      %47 = "arith.addi"(%46, %41) <{overflowFlags = #arith.overflow<none>}> : (vector<1xindex>, vector<1xindex>) -> vector<1xindex>
      %48 = "arith.divui"(%47, %11) : (vector<1xindex>, vector<1xindex>) -> vector<1xindex>
      %49 = "arith.remui"(%47, %11) : (vector<1xindex>, vector<1xindex>) -> vector<1xindex>
      %50 = "math.cos"(%40) <{fastmath = #arith.fastmath<none>}> : (vector<1xf32>) -> vector<1xf32>
      %51 = "math.sin"(%40) <{fastmath = #arith.fastmath<none>}> : (vector<1xf32>) -> vector<1xf32>
      %52 = "arith.muli"(%48, %11) <{overflowFlags = #arith.overflow<none>}> : (vector<1xindex>, vector<1xindex>) -> vector<1xindex>
      %53 = "arith.addi"(%52, %10) <{overflowFlags = #arith.overflow<none>}> : (vector<1xindex>, vector<1xindex>) -> vector<1xindex>
      %54 = "arith.muli"(%43, %5) <{overflowFlags = #arith.overflow<none>}> : (vector<1xindex>, vector<1xindex>) -> vector<1xindex>
      %55 = "arith.addi"(%45, %54) <{overflowFlags = #arith.overflow<none>}> : (vector<1xindex>, vector<1xindex>) -> vector<1xindex>
      %56 = "arith.muli"(%55, %4) <{overflowFlags = #arith.overflow<none>}> : (vector<1xindex>, vector<1xindex>) -> vector<1xindex>
      %57 = "arith.addi"(%52, %56) <{overflowFlags = #arith.overflow<none>}> : (vector<1xindex>, vector<1xindex>) -> vector<1xindex>
      %58 = "vector.insert"(%57, %1) <{static_position = array<i64: 0>}> : (vector<1xindex>, vector<1x1xindex>) -> vector<1x1xindex>
      %59 = "vector.insert"(%58, %0) <{static_position = array<i64: 0>}> : (vector<1x1xindex>, vector<1x1x1xindex>) -> vector<1x1x1xindex>
      %60 = "vector.gather"(%37, %6, %6, %6, %6, %59, %9, %8) : (memref<1x?x8x128xf32, strided<[134184960, 1024, 128, 1]>, #gpu.address_space<workgroup>>, index, index, index, index, vector<1x1x1xindex>, vector<1x1x1xi1>, vector<1x1x1xf32>) -> vector<1x1x1xf32>
      %61 = "arith.addi"(%53, %56) <{overflowFlags = #arith.overflow<none>}> : (vector<1xindex>, vector<1xindex>) -> vector<1xindex>
      %62 = "vector.insert"(%61, %1) <{static_position = array<i64: 0>}> : (vector<1xindex>, vector<1x1xindex>) -> vector<1x1xindex>
      %63 = "vector.insert"(%62, %0) <{static_position = array<i64: 0>}> : (vector<1x1xindex>, vector<1x1x1xindex>) -> vector<1x1x1xindex>
      %64 = "vector.gather"(%37, %6, %6, %6, %6, %63, %9, %8) : (memref<1x?x8x128xf32, strided<[134184960, 1024, 128, 1]>, #gpu.address_space<workgroup>>, index, index, index, index, vector<1x1x1xindex>, vector<1x1x1xi1>, vector<1x1x1xf32>) -> vector<1x1x1xf32>
      %65 = "arith.cmpi"(%49, %7) <{predicate = 0 : i64}> : (vector<1xindex>, vector<1xindex>) -> vector<1xi1>
      %66 = "vector.extract"(%60) <{static_position = array<i64: 0, 0>}> : (vector<1x1x1xf32>) -> vector<1xf32>
      %67 = "arith.mulf"(%66, %50) <{fastmath = #arith.fastmath<none>}> : (vector<1xf32>, vector<1xf32>) -> vector<1xf32>
      %68 = "vector.extract"(%64) <{static_position = array<i64: 0, 0>}> : (vector<1x1x1xf32>) -> vector<1xf32>
      %69 = "arith.mulf"(%68, %51) <{fastmath = #arith.fastmath<none>}> : (vector<1xf32>, vector<1xf32>) -> vector<1xf32>
      %70 = "arith.subf"(%67, %69) <{fastmath = #arith.fastmath<none>}> : (vector<1xf32>, vector<1xf32>) -> vector<1xf32>
      %71 = "arith.mulf"(%68, %50) <{fastmath = #arith.fastmath<none>}> : (vector<1xf32>, vector<1xf32>) -> vector<1xf32>
      %72 = "arith.mulf"(%66, %51) <{fastmath = #arith.fastmath<none>}> : (vector<1xf32>, vector<1xf32>) -> vector<1xf32>
      %73 = "arith.addf"(%71, %72) <{fastmath = #arith.fastmath<none>}> : (vector<1xf32>, vector<1xf32>) -> vector<1xf32>
      %74 = "arith.select"(%65, %70, %73) : (vector<1xi1>, vector<1xf32>, vector<1xf32>) -> vector<1xf32>
      %75 = "arith.truncf"(%74) : (vector<1xf32>) -> vector<1xf16>
      %76 = "vector.extract"(%75) <{static_position = array<i64: 0>}> : (vector<1xf16>) -> f16
      "memref.store"(%76, %33, %34, %35, %38) <{nontemporal = false}> : (f16, memref<?x8x128xf16, strided<[1024, 128, 1], offset: ?>, #gpu.address_space<global>>, index, index, index) -> ()
      "func.return"() : () -> ()
    }) : () -> ()
  }) : () -> ()
  "hal.executable.variant_end"() : () -> ()
}) {sym_name = "rocm_hsaco_fb", target = #hal.executable.target<"rocm", "rocm-hsaco-fb", {abi = "hip", iree.gpu.target = #iree_gpu.target<arch = "gfx942", features = "", wgp = <compute =  fp64|fp32|fp16|int64|int32|int16|int8, storage =  b64|b32|b16|b8, subgroup =  shuffle|arithmetic, dot =  dp4xi8toi32, mma = [<MFMA_F32_16x16x4_F32>, <MFMA_F32_16x16x16_F16>, <MFMA_F32_32x32x8_F16>, <MFMA_F64_16x16x4_F64>, <MFMA_F32_16x16x16_BF16>, <MFMA_F32_32x32x8_BF16>, <MFMA_F32_16x16x32_F8E5M2FNUZ>, <MFMA_F32_16x16x32_F8E5M2FNUZ_F8E4M3FNUZ>, <MFMA_F32_16x16x32_F8E4M3FNUZ>, <MFMA_F32_16x16x32_F8E4M3FNUZ_F8E5M2FNUZ>, <MFMA_F32_32x32x16_F8E5M2FNUZ>, <MFMA_F32_32x32x16_F8E5M2FNUZ_F8E4M3FNUZ>, <MFMA_F32_32x32x16_F8E4M3FNUZ>, <MFMA_F32_32x32x16_F8E4M3FNUZ_F8E5M2FNUZ>, <MFMA_I32_16x16x32_I8>, <MFMA_I32_32x32x16_I8>], subgroup_size_choices = [64], max_workgroup_sizes = [1024, 1024, 1024], max_thread_count_per_workgroup = 1024, max_workgroup_memory_bytes = 65536, max_workgroup_counts = [2147483647, 2147483647, 2147483647], max_load_instruction_bits = 128, simds_per_wgp = 4, vgpr_space_bits = 16384>>, ukernels = "none"}>} : () -> ()
/home/sosa/export/model.mlir:19105:14: error: failed to run translation of source executable to target executable for backend #hal.executable.target<"rocm", "rocm-hsaco-fb", {abi = "hip", iree.gpu.target = #iree_gpu.target<arch = "gfx942", features = "", wgp = <compute =  fp64|fp32|fp16|int64|int32|int16|int8, storage =  b64|b32|b16|b8, subgroup =  shuffle|arithmetic, dot =  dp4xi8toi32, mma = [<MFMA_F32_16x16x4_F32>, <MFMA_F32_16x16x16_F16>, <MFMA_F32_32x32x8_F16>, <MFMA_F64_16x16x4_F64>, <MFMA_F32_16x16x16_BF16>, <MFMA_F32_32x32x8_BF16>, <MFMA_F32_16x16x32_F8E5M2FNUZ>, <MFMA_F32_16x16x32_F8E5M2FNUZ_F8E4M3FNUZ>, <MFMA_F32_16x16x32_F8E4M3FNUZ>, <MFMA_F32_16x16x32_F8E4M3FNUZ_F8E5M2FNUZ>, <MFMA_F32_32x32x16_F8E5M2FNUZ>, <MFMA_F32_32x32x16_F8E5M2FNUZ_F8E4M3FNUZ>, <MFMA_F32_32x32x16_F8E4M3FNUZ>, <MFMA_F32_32x32x16_F8E4M3FNUZ_F8E5M2FNUZ>, <MFMA_I32_16x16x32_I8>, <MFMA_I32_32x32x16_I8>], subgroup_size_choices = [64], max_workgroup_sizes = [1024, 1024, 1024], max_thread_count_per_workgroup = 1024, max_workgroup_memory_bytes = 65536, max_workgroup_counts = [2147483647, 2147483647, 2147483647], max_load_instruction_bits = 128, simds_per_wgp = 4, vgpr_space_bits = 16384>>, ukernels = "none"}>
    %550:2 = torch.operator "torch.aten._scaled_dot_product_flash_attention_for_cpu"(%547, %548, %549, %float0.000000e00, %false_277, %370, %none_278) : (!torch.vtensor<[1,32,1,128],f16>, !torch.vtensor<[1,32,?,128],f16>, !torch.vtensor<[1,32,?,128],f16>, !torch.float, !torch.bool, !torch.vtensor<[1,1,1,?],f16>, !torch.none) -> (!torch.vtensor<[1,32,1,128],f16>, !torch.vtensor<[1,32,1],f32>)
             ^
/home/sosa/export/model.mlir:19105:14: note: see current operation:
"hal.executable.variant"() ({
  "hal.executable.export"() ({
  ^bb0(%arg1: !hal.device, %arg2: index):
    %28 = "arith.constant"() <{value = 4 : index}> : () -> index
    %29 = "arith.constant"() <{value = 8 : index}> : () -> index
    %30 = "arith.constant"() <{value = 1 : index}> : () -> index
    "hal.return"(%28, %29, %30) : (index, index, index) -> ()
  }) {layout = #hal.pipeline.layout<constants = 6, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>, ordinal = 0 : index, subgroup_size = 64 : index, sym_name = "decode_bs1$async_dispatch_23_attention_1xDx8x1x4x128xf16_dispatch_tensor_store", workgroup_size = [128 : index, 1 : index, 1 : index]} : () -> ()
  "builtin.module"() ({
    "func.func"() <{function_type = () -> (), sym_name = "decode_bs1$async_dispatch_23_attention_1xDx8x1x4x128xf16_dispatch_tensor_store"}> ({
      %0 = "arith.constant"() <{value = 32 : index}> : () -> index
      %1 = "arith.constant"() <{value = 8.837890e-02 : f16}> : () -> f16
      %2 = "hal.interface.constant.load"() {layout = #hal.pipeline.layout<constants = 6, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>, ordinal = 0 : index} : () -> i32
      %3 = "hal.interface.constant.load"() {layout = #hal.pipeline.layout<constants = 6, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>, ordinal = 1 : index} : () -> i32
      %4 = "hal.interface.constant.load"() {layout = #hal.pipeline.layout<constants = 6, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>, ordinal = 2 : index} : () -> i32
      %5 = "hal.interface.constant.load"() {layout = #hal.pipeline.layout<constants = 6, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>, ordinal = 3 : index} : () -> i32
      %6 = "hal.interface.constant.load"() {layout = #hal.pipeline.layout<constants = 6, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>, ordinal = 4 : index} : () -> i32
      %7 = "hal.interface.constant.load"() {layout = #hal.pipeline.layout<constants = 6, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>, ordinal = 5 : index} : () -> i32
      %8 = "arith.index_castui"(%2) : (i32) -> index
      %9 = "arith.index_castui"(%3) : (i32) -> index
      %10 = "arith.index_castui"(%4) : (i32) -> index
      %11 = "arith.index_castui"(%5) : (i32) -> index
      %12 = "arith.index_castui"(%6) : (i32) -> index
      %13 = "arith.index_castui"(%7) : (i32) -> index
      %14:6 = "util.assume.int"(%8, %9, %10, %11, %12, %13) <{assumptions = [[#util.int.assumption<umin = 24640, umax = 24640, udiv = 24640>, #util.int.assumption<umin = 30976, umax = 30976, udiv = 30976>, #util.int.assumption<umin = 30976, umax = 30976, udiv = 30976>, #util.int.assumption<umin = 30976, umax = 30976, udiv = 30976>, #util.int.assumption<umin = 30976, umax = 30976, udiv = 30976>, #util.int.assumption<umin = 30976, umax = 30976, udiv = 30976>, #util.int.assumption<umin = 30976, umax = 30976, udiv = 30976>, #util.int.assumption<umin = 30976, umax = 30976, udiv = 30976>, #util.int.assumption<umin = 30976, umax = 30976, udiv = 30976>, #util.int.assumption<umin = 30976, umax = 30976, udiv = 30976>, #util.int.assumption<umin = 30976, umax = 30976, udiv = 30976>, #util.int.assumption<umin = 30976, umax = 30976, udiv = 30976>, #util.int.assumption<umin = 30976, umax = 30976, udiv = 30976>, #util.int.assumption<umin = 30976, umax = 30976, udiv = 30976>, #util.int.assumption<umin = 30976, umax = 30976, udiv = 30976>, #util.int.assumption<umin = 30976, umax = 30976, udiv = 30976>, #util.int.assumption<umin = 30976, umax = 30976, udiv = 30976>, #util.int.assumption<umin = 30976, umax = 30976, udiv = 30976>, #util.int.assumption<umin = 30976, umax = 30976, udiv = 30976>, #util.int.assumption<umin = 30976, umax = 30976, udiv = 30976>, #util.int.assumption<umin = 30976, umax = 30976, udiv = 30976>, #util.int.assumption<umin = 30976, umax = 30976, udiv = 30976>, #util.int.assumption<umin = 30976, umax = 30976, udiv = 30976>, #util.int.assumption<umin = 30976, umax = 30976, udiv = 30976>, #util.int.assumption<umin = 30976, umax = 30976, udiv = 30976>, #util.int.assumption<umin = 30976, umax = 30976, udiv = 30976>, #util.int.assumption<umin = 30976, umax = 30976, udiv = 30976>, #util.int.assumption<umin = 30976, umax = 30976, udiv = 30976>, #util.int.assumption<umin = 30976, umax = 30976, udiv = 30976>, #util.int.assumption<umin = 30976, umax = 30976, udiv = 30976>, 
#util.int.assumption<umin = 30976, umax = 30976, udiv = 30976>, #util.int.assumption<umin = 30976, umax = 30976, udiv = 30976>], [#util.int.assumption<umin = 67248768, umax = 603922944>], [#util.int.assumption<umin = 67510912, umax = 1677402624>], [#util.int.assumption<umin = 67773056, umax = 2750882304>], [#util.int.assumption<umin = 14592, umax = 14592, udiv = 14592>, #util.int.assumption<umin = 22784, umax = 22784, udiv = 22784>, #util.int.assumption<umin = 22784, umax = 22784, udiv = 22784>, #util.int.assumption<umin = 22784, umax = 22784, udiv = 22784>, #util.int.assumption<umin = 22784, umax = 22784, udiv = 22784>, #util.int.assumption<umin = 22784, umax = 22784, udiv = 22784>, #util.int.assumption<umin = 22784, umax = 22784, udiv = 22784>, #util.int.assumption<umin = 22784, umax = 22784, udiv = 22784>, #util.int.assumption<umin = 22784, umax = 22784, udiv = 22784>, #util.int.assumption<umin = 22784, umax = 22784, udiv = 22784>, #util.int.assumption<umin = 22784, umax = 22784, udiv = 22784>, #util.int.assumption<umin = 22784, umax = 22784, udiv = 22784>, #util.int.assumption<umin = 22784, umax = 22784, udiv = 22784>, #util.int.assumption<umin = 22784, umax = 22784, udiv = 22784>, #util.int.assumption<umin = 22784, umax = 22784, udiv = 22784>, #util.int.assumption<umin = 22784, umax = 22784, udiv = 22784>, #util.int.assumption<umin = 22784, umax = 22784, udiv = 22784>, #util.int.assumption<umin = 22784, umax = 22784, udiv = 22784>, #util.int.assumption<umin = 22784, umax = 22784, udiv = 22784>, #util.int.assumption<umin = 22784, umax = 22784, udiv = 22784>, #util.int.assumption<umin = 22784, umax = 22784, udiv = 22784>, #util.int.assumption<umin = 22784, umax = 22784, udiv = 22784>, #util.int.assumption<umin = 22784, umax = 22784, udiv = 22784>, #util.int.assumption<umin = 22784, umax = 22784, udiv = 22784>, #util.int.assumption<umin = 22784, umax = 22784, udiv = 22784>, #util.int.assumption<umin = 22784, umax = 22784, udiv = 22784>, #util.int.assumption<umin 
= 22784, umax = 22784, udiv = 22784>, #util.int.assumption<umin = 22784, umax = 22784, udiv = 22784>, #util.int.assumption<umin = 22784, umax = 22784, udiv = 22784>, #util.int.assumption<umin = 22784, umax = 22784, udiv = 22784>, #util.int.assumption<umin = 22784, umax = 22784, udiv = 22784>, #util.int.assumption<umin = 22784, umax = 22784, udiv = 22784>], [#util.int.assumption<umin = 32, umax = 131040, udiv = 32>]]}> : (index, index, index, index, index, index) -> (index, index, index, index, index, index)
      %15 = "hal.interface.binding.subspan"(%14#0) {alignment = 64 : index, binding = 0 : index, descriptor_flags = 3 : i32, layout = #hal.pipeline.layout<constants = 6, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>, operandSegmentSizes = array<i32: 1, 0>} : (index) -> memref<8x1x4x1x128xf16, strided<[512, 512, 128, 128, 1], offset: ?>, #gpu.address_space<global>>
      "memref.assume_alignment"(%15) <{alignment = 1 : i32}> : (memref<8x1x4x1x128xf16, strided<[512, 512, 128, 128, 1], offset: ?>, #gpu.address_space<global>>) -> ()
      %16 = "hal.interface.binding.subspan"(%14#4) {alignment = 64 : index, binding = 1 : index, descriptor_flags = 2 : i32, layout = #hal.pipeline.layout<constants = 6, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>, operandSegmentSizes = array<i32: 1, 0>} : (index) -> memref<8x1x4x1x128xf16, strided<[512, 512, 128, 128, 1], offset: ?>, #gpu.address_space<global>>
      "memref.assume_alignment"(%16) <{alignment = 1 : i32}> : (memref<8x1x4x1x128xf16, strided<[512, 512, 128, 128, 1], offset: ?>, #gpu.address_space<global>>) -> ()
      %17 = "hal.interface.workgroup.id"() {dimension = 1 : index} : () -> index
      %18 = "hal.interface.workgroup.id"() {dimension = 0 : index} : () -> index
      %19 = "memref.subview"(%16, %17, %18) <{operandSegmentSizes = array<i32: 1, 2, 0, 0>, static_offsets = array<i64: -9223372036854775808, 0, -9223372036854775808, 0, 0>, static_sizes = array<i64: 1, 1, 1, 1, 128>, static_strides = array<i64: 1, 1, 1, 1, 1>}> : (memref<8x1x4x1x128xf16, strided<[512, 512, 128, 128, 1], offset: ?>, #gpu.address_space<global>>, index, index) -> memref<1x1x1x1x128xf16, strided<[512, 512, 128, 128, 1], offset: ?>, #gpu.address_space<global>>
      %20 = "arith.divui"(%14#5, %0) : (index, index) -> index
      %21 = "hal.interface.binding.subspan"(%14#1, %20) {alignment = 64 : index, binding = 0 : index, descriptor_flags = 3 : i32, layout = #hal.pipeline.layout<constants = 6, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>, operandSegmentSizes = array<i32: 1, 1>} : (index, index) -> memref<1x?x32x8x1x4x128xf16, strided<[?, 131072, 4096, 512, 512, 128, 1], offset: ?>, #gpu.address_space<global>>
      "memref.assume_alignment"(%21) <{alignment = 1 : i32}> : (memref<1x?x32x8x1x4x128xf16, strided<[?, 131072, 4096, 512, 512, 128, 1], offset: ?>, #gpu.address_space<global>>) -> ()
      %22 = "hal.interface.binding.subspan"(%14#2, %20) {alignment = 64 : index, binding = 0 : index, descriptor_flags = 3 : i32, layout = #hal.pipeline.layout<constants = 6, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>, operandSegmentSizes = array<i32: 1, 1>} : (index, index) -> memref<1x?x32x8x1x4x128xf16, strided<[?, 131072, 4096, 512, 512, 128, 1], offset: ?>, #gpu.address_space<global>>
      "memref.assume_alignment"(%22) <{alignment = 1 : i32}> : (memref<1x?x32x8x1x4x128xf16, strided<[?, 131072, 4096, 512, 512, 128, 1], offset: ?>, #gpu.address_space<global>>) -> ()
      %23 = "hal.interface.binding.subspan"(%14#3, %20) {alignment = 64 : index, binding = 0 : index, descriptor_flags = 3 : i32, layout = #hal.pipeline.layout<constants = 6, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>, operandSegmentSizes = array<i32: 1, 1>} : (index, index) -> memref<8x1x4x1x1x?x32xf16, strided<[?, ?, ?, ?, ?, 32, 1], offset: ?>, #gpu.address_space<global>>
      "memref.assume_alignment"(%23) <{alignment = 1 : i32}> : (memref<8x1x4x1x1x?x32xf16, strided<[?, ?, ?, ?, ?, 32, 1], offset: ?>, #gpu.address_space<global>>) -> ()
      %24 = "memref.subview"(%15, %17, %18) <{operandSegmentSizes = array<i32: 1, 2, 0, 0>, static_offsets = array<i64: -9223372036854775808, 0, -9223372036854775808, 0, 0>, static_sizes = array<i64: 1, 1, 1, 1, 128>, static_strides = array<i64: 1, 1, 1, 1, 1>}> : (memref<8x1x4x1x128xf16, strided<[512, 512, 128, 128, 1], offset: ?>, #gpu.address_space<global>>, index, index) -> memref<1x1x1x1x128xf16, strided<[512, 512, 128, 128, 1], offset: ?>, #gpu.address_space<global>>
      %25 = "memref.subview"(%21, %17, %18, %20) <{operandSegmentSizes = array<i32: 1, 2, 1, 0>, static_offsets = array<i64: 0, 0, 0, -9223372036854775808, 0, -9223372036854775808, 0>, static_sizes = array<i64: 1, -9223372036854775808, 32, 1, 1, 1, 128>, static_strides = array<i64: 1, 1, 1, 1, 1, 1, 1>}> : (memref<1x?x32x8x1x4x128xf16, strided<[?, 131072, 4096, 512, 512, 128, 1], offset: ?>, #gpu.address_space<global>>, index, index, index) -> memref<1x?x32x1x1x1x128xf16, strided<[?, 131072, 4096, 512, 512, 128, 1], offset: ?>, #gpu.address_space<global>>
      %26 = "memref.subview"(%22, %17, %18, %20) <{operandSegmentSizes = array<i32: 1, 2, 1, 0>, static_offsets = array<i64: 0, 0, 0, -9223372036854775808, 0, -9223372036854775808, 0>, static_sizes = array<i64: 1, -9223372036854775808, 32, 1, 1, 1, 128>, static_strides = array<i64: 1, 1, 1, 1, 1, 1, 1>}> : (memref<1x?x32x8x1x4x128xf16, strided<[?, 131072, 4096, 512, 512, 128, 1], offset: ?>, #gpu.address_space<global>>, index, index, index) -> memref<1x?x32x1x1x1x128xf16, strided<[?, 131072, 4096, 512, 512, 128, 1], offset: ?>, #gpu.address_space<global>>
      %27 = "memref.subview"(%23, %17, %18, %20) <{operandSegmentSizes = array<i32: 1, 2, 1, 0>, static_offsets = array<i64: -9223372036854775808, 0, -9223372036854775808, 0, 0, 0, 0>, static_sizes = array<i64: 1, 1, 1, 1, 1, -9223372036854775808, 32>, static_strides = array<i64: 1, 1, 1, 1, 1, 1, 1>}> : (memref<8x1x4x1x1x?x32xf16, strided<[?, ?, ?, ?, ?, 32, 1], offset: ?>, #gpu.address_space<global>>, index, index, index) -> memref<1x1x1x1x1x?x32xf16, strided<[?, ?, ?, ?, ?, 32, 1], offset: ?>, #gpu.address_space<global>>
      "iree_linalg_ext.attention"(%24, %25, %26, %1, %27, %19) <{indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d2, d3, d5)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d6, d7, d8, d0, d1, d2, d5)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d6, d7, d8, d0, d1, d2, d4)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> ()>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d2, d3, d6, d7, d8)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d2, d3, d4)>]}> ({
      ^bb0(%arg0: f32):
        "iree_linalg_ext.yield"(%arg0) : (f32) -> ()
      }) : (memref<1x1x1x1x128xf16, strided<[512, 512, 128, 128, 1], offset: ?>, #gpu.address_space<global>>, memref<1x?x32x1x1x1x128xf16, strided<[?, 131072, 4096, 512, 512, 128, 1], offset: ?>, #gpu.address_space<global>>, memref<1x?x32x1x1x1x128xf16, strided<[?, 131072, 4096, 512, 512, 128, 1], offset: ?>, #gpu.address_space<global>>, f16, memref<1x1x1x1x1x?x32xf16, strided<[?, ?, ?, ?, ?, 32, 1], offset: ?>, #gpu.address_space<global>>, memref<1x1x1x1x128xf16, strided<[512, 512, 128, 128, 1], offset: ?>, #gpu.address_space<global>>) -> ()
      "func.return"() : () -> ()
    }) : () -> ()
  }) : () -> ()
  "hal.executable.variant_end"() : () -> ()
}) {sym_name = "rocm_hsaco_fb", target = #hal.executable.target<"rocm", "rocm-hsaco-fb", {abi = "hip", iree.gpu.target = #iree_gpu.target<arch = "gfx942", features = "", wgp = <compute =  fp64|fp32|fp16|int64|int32|int16|int8, storage =  b64|b32|b16|b8, subgroup =  shuffle|arithmetic, dot =  dp4xi8toi32, mma = [<MFMA_F32_16x16x4_F32>, <MFMA_F32_16x16x16_F16>, <MFMA_F32_32x32x8_F16>, <MFMA_F64_16x16x4_F64>, <MFMA_F32_16x16x16_BF16>, <MFMA_F32_32x32x8_BF16>, <MFMA_F32_16x16x32_F8E5M2FNUZ>, <MFMA_F32_16x16x32_F8E5M2FNUZ_F8E4M3FNUZ>, <MFMA_F32_16x16x32_F8E4M3FNUZ>, <MFMA_F32_16x16x32_F8E4M3FNUZ_F8E5M2FNUZ>, <MFMA_F32_32x32x16_F8E5M2FNUZ>, <MFMA_F32_32x32x16_F8E5M2FNUZ_F8E4M3FNUZ>, <MFMA_F32_32x32x16_F8E4M3FNUZ>, <MFMA_F32_32x32x16_F8E4M3FNUZ_F8E5M2FNUZ>, <MFMA_I32_16x16x32_I8>, <MFMA_I32_32x32x16_I8>], subgroup_size_choices = [64], max_workgroup_sizes = [1024, 1024, 1024], max_thread_count_per_workgroup = 1024, max_workgroup_memory_bytes = 65536, max_workgroup_counts = [2147483647, 2147483647, 2147483647], max_load_instruction_bits = 128, simds_per_wgp = 4, vgpr_space_bits = 16384>>, ukernels = "none"}>} : () -> ()
/home/sosa/export/model.mlir:53831:14: error: failed to run translation of source executable to target executable for backend #hal.executable.target<"rocm", "rocm-hsaco-fb", {abi = "hip", iree.gpu.target = #iree_gpu.target<arch = "gfx942", features = "", wgp = <compute =  fp64|fp32|fp16|int64|int32|int16|int8, storage =  b64|b32|b16|b8, subgroup =  shuffle|arithmetic, dot =  dp4xi8toi32, mma = [<MFMA_F32_16x16x4_F32>, <MFMA_F32_16x16x16_F16>, <MFMA_F32_32x32x8_F16>, <MFMA_F64_16x16x4_F64>, <MFMA_F32_16x16x16_BF16>, <MFMA_F32_32x32x8_BF16>, <MFMA_F32_16x16x32_F8E5M2FNUZ>, <MFMA_F32_16x16x32_F8E5M2FNUZ_F8E4M3FNUZ>, <MFMA_F32_16x16x32_F8E4M3FNUZ>, <MFMA_F32_16x16x32_F8E4M3FNUZ_F8E5M2FNUZ>, <MFMA_F32_32x32x16_F8E5M2FNUZ>, <MFMA_F32_32x32x16_F8E5M2FNUZ_F8E4M3FNUZ>, <MFMA_F32_32x32x16_F8E4M3FNUZ>, <MFMA_F32_32x32x16_F8E4M3FNUZ_F8E5M2FNUZ>, <MFMA_I32_16x16x32_I8>, <MFMA_I32_32x32x16_I8>], subgroup_size_choices = [64], max_workgroup_sizes = [1024, 1024, 1024], max_thread_count_per_workgroup = 1024, max_workgroup_memory_bytes = 65536, max_workgroup_counts = [2147483647, 2147483647, 2147483647], max_load_instruction_bits = 128, simds_per_wgp = 4, vgpr_space_bits = 16384>>, ukernels = "none"}>
    %551:2 = torch.operator "torch.aten._scaled_dot_product_flash_attention_for_cpu"(%548, %549, %550, %float0.000000e00, %false_278, %370, %none_279) : (!torch.vtensor<[4,32,1,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.float, !torch.bool, !torch.vtensor<[4,1,1,?],f16>, !torch.none) -> (!torch.vtensor<[4,32,1,128],f16>, !torch.vtensor<[4,32,1],f32>)
             ^
/home/sosa/export/model.mlir:53831:14: note: see current operation:
"hal.executable.variant"() ({
  "hal.executable.export"() ({
  ^bb0(%arg1: !hal.device, %arg2: index, %arg3: index, %arg4: index):
    %38 = "arith.constant"() <{value = 4 : index}> : () -> index
    %39 = "arith.constant"() <{value = 8 : index}> : () -> index
    "hal.return"(%38, %39, %38) : (index, index, index) -> ()
  }) {layout = #hal.pipeline.layout<constants = 8, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>, ordinal = 0 : index, subgroup_size = 64 : index, sym_name = "decode_bs4$async_dispatch_31_attention_4xDx32x8x4x128xf16_dispatch_tensor_store", workgroup_size = [128 : index, 1 : index, 1 : index]} : () -> ()
  "builtin.module"() ({
    "func.func"() <{function_type = () -> (), sym_name = "decode_bs4$async_dispatch_31_attention_4xDx32x8x4x128xf16_dispatch_tensor_store"}> ({
      %0 = "arith.constant"() <{value = 32 : i64}> : () -> i64
      %1 = "arith.constant"() <{value = 8.837890e-02 : f16}> : () -> f16
      %2 = "hal.interface.constant.load"() {layout = #hal.pipeline.layout<constants = 8, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>, ordinal = 0 : index} : () -> i32
      %3 = "hal.interface.constant.load"() {layout = #hal.pipeline.layout<constants = 8, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>, ordinal = 1 : index} : () -> i32
      %4 = "hal.interface.constant.load"() {layout = #hal.pipeline.layout<constants = 8, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>, ordinal = 2 : index} : () -> i32
      %5 = "hal.interface.constant.load"() {layout = #hal.pipeline.layout<constants = 8, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>, ordinal = 3 : index} : () -> i32
      %6 = "hal.interface.constant.load"() {layout = #hal.pipeline.layout<constants = 8, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>, ordinal = 4 : index} : () -> i32
      %7 = "hal.interface.constant.load"() {layout = #hal.pipeline.layout<constants = 8, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>, ordinal = 5 : index} : () -> i32
      %8 = "hal.interface.constant.load"() {layout = #hal.pipeline.layout<constants = 8, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>, ordinal = 6 : index} : () -> i32
      %9 = "hal.interface.constant.load"() {layout = #hal.pipeline.layout<constants = 8, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>, ordinal = 7 : index} : () -> i32
      %10 = "arith.index_castui"(%2) : (i32) -> index
      %11 = "arith.index_castui"(%3) : (i32) -> index
      %12 = "arith.extui"(%4) : (i32) -> i64
      %13 = "arith.extui"(%5) : (i32) -> i64
      %14 = "arith.shli"(%13, %0) <{overflowFlags = #arith.overflow<none>}> : (i64, i64) -> i64
      %15 = "arith.ori"(%12, %14) : (i64, i64) -> i64
      %16 = "arith.index_castui"(%15) : (i64) -> index
      %17 = "arith.extui"(%6) : (i32) -> i64
      %18 = "arith.extui"(%7) : (i32) -> i64
      %19 = "arith.shli"(%18, %0) <{overflowFlags = #arith.overflow<none>}> : (i64, i64) -> i64
      %20 = "arith.ori"(%17, %19) : (i64, i64) -> i64
      %21 = "arith.index_castui"(%20) : (i64) -> index
      %22 = "arith.index_castui"(%8) : (i32) -> index
      %23 = "arith.index_castui"(%9) : (i32) -> index
      %24:6 = "util.assume.int"(%10, %11, %16, %21, %22, %23) <{assumptions = [[#util.int.assumption<umin = 81920, umax = 81920, udiv = 81920>, #util.int.assumption<umin = 115008, umax = 115008, udiv = 115008>, #util.int.assumption<umin = 115008, umax = 115008, udiv = 115008>, #util.int.assumption<umin = 115008, umax = 115008, udiv = 115008>, #util.int.assumption<umin = 115008, umax = 115008, udiv = 115008>, #util.int.assumption<umin = 115008, umax = 115008, udiv = 115008>, #util.int.assumption<umin = 115008, umax = 115008, udiv = 115008>, #util.int.assumption<umin = 115008, umax = 115008, udiv = 115008>, #util.int.assumption<umin = 115008, umax = 115008, udiv = 115008>, #util.int.assumption<umin = 115008, umax = 115008, udiv = 115008>, #util.int.assumption<umin = 115008, umax = 115008, udiv = 115008>, #util.int.assumption<umin = 115008, umax = 115008, udiv = 115008>, #util.int.assumption<umin = 115008, umax = 115008, udiv = 115008>, #util.int.assumption<umin = 115008, umax = 115008, udiv = 115008>, #util.int.assumption<umin = 115008, umax = 115008, udiv = 115008>, #util.int.assumption<umin = 115008, umax = 115008, udiv = 115008>, #util.int.assumption<umin = 115008, umax = 115008, udiv = 115008>, #util.int.assumption<umin = 115008, umax = 115008, udiv = 115008>, #util.int.assumption<umin = 115008, umax = 115008, udiv = 115008>, #util.int.assumption<umin = 115008, umax = 115008, udiv = 115008>, #util.int.assumption<umin = 115008, umax = 115008, udiv = 115008>, #util.int.assumption<umin = 115008, umax = 115008, udiv = 115008>, #util.int.assumption<umin = 115008, umax = 115008, udiv = 115008>, #util.int.assumption<umin = 115008, umax = 115008, udiv = 115008>, #util.int.assumption<umin = 115008, umax = 115008, udiv = 115008>, #util.int.assumption<umin = 115008, umax = 115008, udiv = 115008>, #util.int.assumption<umin = 115008, umax = 115008, udiv = 115008>, #util.int.assumption<umin = 115008, umax = 115008, udiv = 115008>, #util.int.assumption<umin = 115008, umax = 
115008, udiv = 115008>, #util.int.assumption<umin = 115008, umax = 115008, udiv = 115008>, #util.int.assumption<umin = 115008, umax = 115008, udiv = 115008>, #util.int.assumption<umin = 115008, umax = 115008, udiv = 115008>], [#util.int.assumption<umin = 67668160, umax = 2214496256>], [#util.int.assumption<umin = 68716736, umax = 6508414976>], [#util.int.assumption<umin = 69765312, umax = 10802333696>], [#util.int.assumption<umin = 320, umax = 320, udiv = 320>, #util.int.assumption<umin = 33088, umax = 33088, udiv = 33088>, #util.int.assumption<umin = 33088, umax = 33088, udiv = 33088>, #util.int.assumption<umin = 33088, umax = 33088, udiv = 33088>, #util.int.assumption<umin = 33088, umax = 33088, udiv = 33088>, #util.int.assumption<umin = 33088, umax = 33088, udiv = 33088>, #util.int.assumption<umin = 33088, umax = 33088, udiv = 33088>, #util.int.assumption<umin = 33088, umax = 33088, udiv = 33088>, #util.int.assumption<umin = 33088, umax = 33088, udiv = 33088>, #util.int.assumption<umin = 33088, umax = 33088, udiv = 33088>, #util.int.assumption<umin = 33088, umax = 33088, udiv = 33088>, #util.int.assumption<umin = 33088, umax = 33088, udiv = 33088>, #util.int.assumption<umin = 33088, umax = 33088, udiv = 33088>, #util.int.assumption<umin = 33088, umax = 33088, udiv = 33088>, #util.int.assumption<umin = 33088, umax = 33088, udiv = 33088>, #util.int.assumption<umin = 33088, umax = 33088, udiv = 33088>, #util.int.assumption<umin = 33088, umax = 33088, udiv = 33088>, #util.int.assumption<umin = 33088, umax = 33088, udiv = 33088>, #util.int.assumption<umin = 33088, umax = 33088, udiv = 33088>, #util.int.assumption<umin = 33088, umax = 33088, udiv = 33088>, #util.int.assumption<umin = 33088, umax = 33088, udiv = 33088>, #util.int.assumption<umin = 33088, umax = 33088, udiv = 33088>, #util.int.assumption<umin = 33088, umax = 33088, udiv = 33088>, #util.int.assumption<umin = 33088, umax = 33088, udiv = 33088>, #util.int.assumption<umin = 33088, umax = 33088, udiv = 
33088>, #util.int.assumption<umin = 33088, umax = 33088, udiv = 33088>, #util.int.assumption<umin = 33088, umax = 33088, udiv = 33088>, #util.int.assumption<umin = 33088, umax = 33088, udiv = 33088>, #util.int.assumption<umin = 33088, umax = 33088, udiv = 33088>, #util.int.assumption<umin = 33088, umax = 33088, udiv = 33088>, #util.int.assumption<umin = 33088, umax = 33088, udiv = 33088>, #util.int.assumption<umin = 33088, umax = 33088, udiv = 33088>], [#util.int.assumption<umin = 1, umax = 4095>]]}> : (index, index, index, index, index, index) -> (index, index, index, index, index, index)
      %25 = "hal.interface.binding.subspan"(%24#0) {alignment = 64 : index, binding = 0 : index, descriptor_flags = 3 : i32, layout = #hal.pipeline.layout<constants = 8, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>, operandSegmentSizes = array<i32: 1, 0>} : (index) -> memref<4x8x4x1x128xf16, strided<[4096, 512, 128, 128, 1], offset: ?>, #gpu.address_space<global>>
      "memref.assume_alignment"(%25) <{alignment = 1 : i32}> : (memref<4x8x4x1x128xf16, strided<[4096, 512, 128, 128, 1], offset: ?>, #gpu.address_space<global>>) -> ()
      %26 = "hal.interface.binding.subspan"(%24#4) {alignment = 64 : index, binding = 1 : index, descriptor_flags = 2 : i32, layout = #hal.pipeline.layout<constants = 8, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>, operandSegmentSizes = array<i32: 1, 0>} : (index) -> memref<4x8x4x1x128xf16, strided<[4096, 512, 128, 128, 1], offset: ?>, #gpu.address_space<global>>
      "memref.assume_alignment"(%26) <{alignment = 1 : i32}> : (memref<4x8x4x1x128xf16, strided<[4096, 512, 128, 128, 1], offset: ?>, #gpu.address_space<global>>) -> ()
      %27 = "hal.interface.workgroup.id"() {dimension = 2 : index} : () -> index
      %28 = "hal.interface.workgroup.id"() {dimension = 1 : index} : () -> index
      %29 = "hal.interface.workgroup.id"() {dimension = 0 : index} : () -> index
      %30 = "memref.subview"(%26, %27, %28, %29) <{operandSegmentSizes = array<i32: 1, 3, 0, 0>, static_offsets = array<i64: -9223372036854775808, -9223372036854775808, -9223372036854775808, 0, 0>, static_sizes = array<i64: 1, 1, 1, 1, 128>, static_strides = array<i64: 1, 1, 1, 1, 1>}> : (memref<4x8x4x1x128xf16, strided<[4096, 512, 128, 128, 1], offset: ?>, #gpu.address_space<global>>, index, index, index) -> memref<1x1x1x1x128xf16, strided<[4096, 512, 128, 128, 1], offset: ?>, #gpu.address_space<global>>
      %31 = "hal.interface.binding.subspan"(%24#1, %24#5) {alignment = 64 : index, binding = 0 : index, descriptor_flags = 3 : i32, layout = #hal.pipeline.layout<constants = 8, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>, operandSegmentSizes = array<i32: 1, 1>} : (index, index) -> memref<4x?x32x8x4x128xf16, strided<[?, 131072, 4096, 512, 128, 1], offset: ?>, #gpu.address_space<global>>
      "memref.assume_alignment"(%31) <{alignment = 1 : i32}> : (memref<4x?x32x8x4x128xf16, strided<[?, 131072, 4096, 512, 128, 1], offset: ?>, #gpu.address_space<global>>) -> ()
      %32 = "hal.interface.binding.subspan"(%24#2, %24#5) {alignment = 64 : index, binding = 0 : index, descriptor_flags = 3 : i32, layout = #hal.pipeline.layout<constants = 8, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>, operandSegmentSizes = array<i32: 1, 1>} : (index, index) -> memref<4x?x32x8x4x128xf16, strided<[?, 131072, 4096, 512, 128, 1], offset: ?>, #gpu.address_space<global>>
      "memref.assume_alignment"(%32) <{alignment = 1 : i32}> : (memref<4x?x32x8x4x128xf16, strided<[?, 131072, 4096, 512, 128, 1], offset: ?>, #gpu.address_space<global>>) -> ()
      %33 = "hal.interface.binding.subspan"(%24#3, %24#5) {alignment = 64 : index, binding = 0 : index, descriptor_flags = 3 : i32, layout = #hal.pipeline.layout<constants = 8, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>, operandSegmentSizes = array<i32: 1, 1>} : (index, index) -> memref<4x8x4x1x?x32xf16, strided<[?, ?, ?, ?, 32, 1], offset: ?>, #gpu.address_space<global>>
      "memref.assume_alignment"(%33) <{alignment = 1 : i32}> : (memref<4x8x4x1x?x32xf16, strided<[?, ?, ?, ?, 32, 1], offset: ?>, #gpu.address_space<global>>) -> ()
      %34 = "memref.subview"(%25, %27, %28, %29) <{operandSegmentSizes = array<i32: 1, 3, 0, 0>, static_offsets = array<i64: -9223372036854775808, -9223372036854775808, -9223372036854775808, 0, 0>, static_sizes = array<i64: 1, 1, 1, 1, 128>, static_strides = array<i64: 1, 1, 1, 1, 1>}> : (memref<4x8x4x1x128xf16, strided<[4096, 512, 128, 128, 1], offset: ?>, #gpu.address_space<global>>, index, index, index) -> memref<1x1x1x1x128xf16, strided<[4096, 512, 128, 128, 1], offset: ?>, #gpu.address_space<global>>
      %35 = "memref.subview"(%31, %27, %28, %29, %24#5) <{operandSegmentSizes = array<i32: 1, 3, 1, 0>, static_offsets = array<i64: -9223372036854775808, 0, 0, -9223372036854775808, -9223372036854775808, 0>, static_sizes = array<i64: 1, -9223372036854775808, 32, 1, 1, 128>, static_strides = array<i64: 1, 1, 1, 1, 1, 1>}> : (memref<4x?x32x8x4x128xf16, strided<[?, 131072, 4096, 512, 128, 1], offset: ?>, #gpu.address_space<global>>, index, index, index, index) -> memref<1x?x32x1x1x128xf16, strided<[?, 131072, 4096, 512, 128, 1], offset: ?>, #gpu.address_space<global>>
      %36 = "memref.subview"(%32, %27, %28, %29, %24#5) <{operandSegmentSizes = array<i32: 1, 3, 1, 0>, static_offsets = array<i64: -9223372036854775808, 0, 0, -9223372036854775808, -9223372036854775808, 0>, static_sizes = array<i64: 1, -9223372036854775808, 32, 1, 1, 128>, static_strides = array<i64: 1, 1, 1, 1, 1, 1>}> : (memref<4x?x32x8x4x128xf16, strided<[?, 131072, 4096, 512, 128, 1], offset: ?>, #gpu.address_space<global>>, index, index, index, index) -> memref<1x?x32x1x1x128xf16, strided<[?, 131072, 4096, 512, 128, 1], offset: ?>, #gpu.address_space<global>>
      %37 = "memref.subview"(%33, %27, %28, %29, %24#5) <{operandSegmentSizes = array<i32: 1, 3, 1, 0>, static_offsets = array<i64: -9223372036854775808, -9223372036854775808, -9223372036854775808, 0, 0, 0>, static_sizes = array<i64: 1, 1, 1, 1, -9223372036854775808, 32>, static_strides = array<i64: 1, 1, 1, 1, 1, 1>}> : (memref<4x8x4x1x?x32xf16, strided<[?, ?, ?, ?, 32, 1], offset: ?>, #gpu.address_space<global>>, index, index, index, index) -> memref<1x1x1x1x?x32xf16, strided<[?, ?, ?, ?, 32, 1], offset: ?>, #gpu.address_space<global>>
      "iree_linalg_ext.attention"(%34, %35, %36, %1, %37, %30) <{indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d5)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d6, d7, d1, d2, d5)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d6, d7, d1, d2, d4)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> ()>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d6, d7)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)>]}> ({
      ^bb0(%arg0: f32):
        "iree_linalg_ext.yield"(%arg0) : (f32) -> ()
      }) : (memref<1x1x1x1x128xf16, strided<[4096, 512, 128, 128, 1], offset: ?>, #gpu.address_space<global>>, memref<1x?x32x1x1x128xf16, strided<[?, 131072, 4096, 512, 128, 1], offset: ?>, #gpu.address_space<global>>, memref<1x?x32x1x1x128xf16, strided<[?, 131072, 4096, 512, 128, 1], offset: ?>, #gpu.address_space<global>>, f16, memref<1x1x1x1x?x32xf16, strided<[?, ?, ?, ?, 32, 1], offset: ?>, #gpu.address_space<global>>, memref<1x1x1x1x128xf16, strided<[4096, 512, 128, 128, 1], offset: ?>, #gpu.address_space<global>>) -> ()
      "func.return"() : () -> ()
    }) : () -> ()
  }) : () -> ()
  "hal.executable.variant_end"() : () -> ()
}) {sym_name = "rocm_hsaco_fb", target = #hal.executable.target<"rocm", "rocm-hsaco-fb", {abi = "hip", iree.gpu.target = #iree_gpu.target<arch = "gfx942", features = "", wgp = <compute =  fp64|fp32|fp16|int64|int32|int16|int8, storage =  b64|b32|b16|b8, subgroup =  shuffle|arithmetic, dot =  dp4xi8toi32, mma = [<MFMA_F32_16x16x4_F32>, <MFMA_F32_16x16x16_F16>, <MFMA_F32_32x32x8_F16>, <MFMA_F64_16x16x4_F64>, <MFMA_F32_16x16x16_BF16>, <MFMA_F32_32x32x8_BF16>, <MFMA_F32_16x16x32_F8E5M2FNUZ>, <MFMA_F32_16x16x32_F8E5M2FNUZ_F8E4M3FNUZ>, <MFMA_F32_16x16x32_F8E4M3FNUZ>, <MFMA_F32_16x16x32_F8E4M3FNUZ_F8E5M2FNUZ>, <MFMA_F32_32x32x16_F8E5M2FNUZ>, <MFMA_F32_32x32x16_F8E5M2FNUZ_F8E4M3FNUZ>, <MFMA_F32_32x32x16_F8E4M3FNUZ>, <MFMA_F32_32x32x16_F8E4M3FNUZ_F8E5M2FNUZ>, <MFMA_I32_16x16x32_I8>, <MFMA_I32_32x32x16_I8>], subgroup_size_choices = [64], max_workgroup_sizes = [1024, 1024, 1024], max_thread_count_per_workgroup = 1024, max_workgroup_memory_bytes = 65536, max_workgroup_counts = [2147483647, 2147483647, 2147483647], max_load_instruction_bits = 128, simds_per_wgp = 4, vgpr_space_bits = 16384>>, ukernels = "none"}>} : () -> ()
(3.11.venv) sosa@sharkmi300x-3:~$ python -m sharktank.examples.export_paged_llm_v1   --gguf-file=$MODEL_PARAMS_PATH   --output-mlir=$MLIR_PATH   --output-config=$OUTPUT_C 

@DhruvDh
Copy link

DhruvDh commented Dec 23, 2024

I ran into the same problem when using only the nightly release for the gfx1100 target: "func.func op uses -131072 bytes of shared memory; exceeded the limit of 65536 bytes"

`pip freeze`
➜  shark-ai git:(main) uv pip freeze
aiohappyeyeballs==2.4.4
aiohttp==3.11.11
aiosignal==1.3.2
annotated-types==0.7.0
anyio==4.7.0
attrs==24.3.0
certifi==2024.12.14
charset-normalizer==3.4.0
click==8.1.8
dataclasses-json==0.6.7
datasets==3.0.1
dill==0.3.8
fastapi==0.115.6
filelock==3.13.1
frozenlist==1.5.0
fsspec==2024.2.0
gguf==0.13.0
h11==0.14.0
huggingface-hub==0.22.2
idna==3.10
iree-base-compiler==3.0.0
iree-base-runtime==3.0.0
iree-turbine==3.0.0
jinja2==3.1.3
markupsafe==2.1.5
marshmallow==3.23.2
mpmath==1.3.0
multidict==6.1.0
multiprocess==0.70.16
mypy-extensions==1.0.0
networkx==3.2.1
numpy==1.26.3
nvidia-cublas-cu12==12.4.5.8
nvidia-cuda-cupti-cu12==12.4.127
nvidia-cuda-nvrtc-cu12==12.4.127
nvidia-cuda-runtime-cu12==12.4.127
nvidia-cudnn-cu12==9.1.0.70
nvidia-cufft-cu12==11.2.1.3
nvidia-curand-cu12==10.3.5.147
nvidia-cusolver-cu12==11.6.1.9
nvidia-cusparse-cu12==12.3.1.170
nvidia-nccl-cu12==2.21.5
nvidia-nvjitlink-cu12==12.4.127
nvidia-nvtx-cu12==12.4.127
packaging==24.2
pandas==2.2.3
pillow==10.2.0
propcache==0.2.1
pyarrow==18.1.0
pydantic==2.10.4
pydantic-core==2.27.2
python-dateutil==2.9.0.post0
pytorch-triton-rocm==3.1.0
pytz==2024.2
pyyaml==6.0.2
regex==2024.11.6
requests==2.32.3
safetensors==0.4.6.dev0
sentencepiece==0.2.0
shark-ai==3.0.0
sharktank==3.1.0rc20241223
shortfin==3.1.0rc20241223
six==1.17.0
sniffio==1.3.1
starlette==0.41.3
sympy==1.13.1
tokenizers==0.19.1
torch==2.5.1+rocm6.2
torchaudio==2.5.1+rocm6.2
torchvision==0.20.1+rocm6.2
tqdm==4.67.1
transformers==4.40.0
triton==3.1.0
typing-extensions==4.9.0
typing-inspect==0.9.0
tzdata==2024.2
urllib3==2.3.0
uvicorn==0.34.0
xxhash==3.5.0
yarl==1.18.3

@amd-chrissosa
Copy link
Contributor Author

Associated CL landed: #751

@amd-chrissosa
Copy link
Contributor Author

SDXL is broken in testing as of 1/3/2025; see #753

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
documentation Improvements or additions to documentation
Projects
None yet
Development

No branches or pull requests

2 participants