diff --git a/third_party/xla/xla/codegen/ir/BUILD b/third_party/xla/xla/codegen/ir/BUILD new file mode 100644 index 00000000000000..b0065ce84cc096 --- /dev/null +++ b/third_party/xla/xla/codegen/ir/BUILD @@ -0,0 +1,143 @@ +load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library", "td_library") +load("//xla/tests:build_defs.bzl", "xla_test") + +package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = [":friends"], + licenses = ["notice"], +) + +package_group( + name = "friends", + includes = [ + "//xla:friends", + ], +) + +td_library( + name = "xla_td_files", + srcs = glob(["*.td"]), + includes = ["."], + deps = [ + "@llvm-project//mlir:BuiltinDialectTdFiles", + "@llvm-project//mlir:CallInterfacesTdFiles", + "@llvm-project//mlir:ControlFlowInterfacesTdFiles", + "@llvm-project//mlir:InferTypeOpInterfaceTdFiles", + "@llvm-project//mlir:OpBaseTdFiles", + "@llvm-project//mlir:SideEffectInterfacesTdFiles", + ], +) + +gentbl_cc_library( + name = "xla_dialect_inc_gen", + strip_include_prefix = ".", + tbl_outs = [ + ( + ["-gen-dialect-decls"], + "xla_dialect.h.inc", + ), + ( + ["-gen-dialect-defs"], + "xla_dialect.cc.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "xla_dialect.td", + deps = [":xla_td_files"], +) + +gentbl_cc_library( + name = "xla_ops_inc_gen", + strip_include_prefix = ".", + tbl_outs = [ + ( + ["-gen-op-decls"], + "xla_ops.h.inc", + ), + ( + ["-gen-op-defs"], + "xla_ops.cc.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "xla_ops.td", + deps = [":xla_td_files"], +) + +gentbl_cc_library( + name = "xla_attrs_inc_gen", + strip_include_prefix = ".", + tbl_outs = [ + ( + ["-gen-enum-decls"], + "xla_enums.h.inc", + ), + ( + ["-gen-enum-defs"], + "xla_enums.cc.inc", + ), + ( + [ + "-gen-attrdef-decls", + ], + "xla_attrs.h.inc", + ), + ( + [ + "-gen-attrdef-defs", + ], + "xla_attrs.cc.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "xla_attrs.td", + deps = [":xla_td_files"], +) + +cc_library( + name = "xla", + srcs = [ + "xla_attrs.cc", + "xla_dialect.cc", + "xla_ops.cc", + ], + hdrs = ["xla_ops.h"], + deps = [ + ":xla_attrs_inc_gen", + ":xla_dialect_inc_gen", + ":xla_ops_inc_gen", + "//xla:status_macros", + "//xla/hlo/analysis:indexing_analysis", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:BytecodeOpInterface", + "@llvm-project//mlir:CallOpInterfaces", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:InferTypeOpInterface", + "@llvm-project//mlir:InliningUtils", + "@llvm-project//mlir:SideEffectInterfaces", + "@llvm-project//mlir:Support", + ], +) + +xla_test( + name = "xla_ops_test", + srcs = ["xla_ops_test.cc"], + backends = ["cpu"], + deps = [ + ":xla", + "//xla/hlo/analysis:indexing_analysis", + "//xla/hlo/analysis:indexing_test_utils", + "//xla/tests:hlo_test_base", + "//xla/tests:xla_internal_test_main", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_googletest//:gtest", + "@llvm-project//llvm:Support", + "@local_tsl//tsl/platform:test", + ], +) diff --git a/third_party/xla/xla/codegen/ir/tests/BUILD b/third_party/xla/xla/codegen/ir/tests/BUILD new file mode 100644 index 00000000000000..10228fcc460af8 --- /dev/null +++ b/third_party/xla/xla/codegen/ir/tests/BUILD @@ -0,0 +1,16 @@ +load("//xla:lit.bzl", 
"lit_test_suite") + +package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + licenses = ["notice"], +) + +lit_test_suite( + name = "tests", + srcs = glob(["*.mlir"]), + cfg = "//xla:lit.cfg.py", + tools = [ + "//xla/codegen/tools:emitters_opt", + "@llvm-project//llvm:FileCheck", + ], +) diff --git a/third_party/xla/xla/service/gpu/fusions/ir/tests/canonicalize.mlir b/third_party/xla/xla/codegen/ir/tests/canonicalize.mlir similarity index 60% rename from third_party/xla/xla/service/gpu/fusions/ir/tests/canonicalize.mlir rename to third_party/xla/xla/codegen/ir/tests/canonicalize.mlir index 08086e34f60b05..ae0e54c70f9ba4 100644 --- a/third_party/xla/xla/service/gpu/fusions/ir/tests/canonicalize.mlir +++ b/third_party/xla/xla/codegen/ir/tests/canonicalize.mlir @@ -1,26 +1,26 @@ -// RUN: mlir_fusions_opt %s --split-input-file -canonicalize | FileCheck %s +// RUN: emitters_opt %s --split-input-file -canonicalize | FileCheck %s -#map0 = #xla_gpu.indexing_map<"()[s0, s1] -> (1 + s0 + s1 mod 3 - s1, s0 mod 2), domain: s0 in [-10, 10], s1 in [0, 2]"> +#map0 = #xla.indexing_map<"()[s0, s1] -> (1 + s0 + s1 mod 3 - s1, s0 mod 2), domain: s0 in [-10, 10], s1 in [0, 2]"> func.func @simplify_apply_indexing(%s0: index, %s1: index) -> (index, index) { - %0:2 = xla_gpu.apply_indexing #map0 [%s0, %s1] + %0:2 = xla.apply_indexing #map0 [%s0, %s1] func.return %0#0, %0#1 : index, index } -// CHECK: #[[$MAP:.*]] = #xla_gpu.indexing_map<"(d0) -> (d0 + 1, d0 mod 2), +// CHECK: #[[$MAP:.*]] = #xla.indexing_map<"(d0) -> (d0 + 1, d0 mod 2), // CHECK-SAME: domain: d0 in [-10, 10]"> // CHECK-LABEL: func.func @simplify_apply_indexing // CHECK-SAME: %[[ARG_0:.*]]: index, %[[ARG_1:.*]]: index) -// CHECK: xla_gpu.apply_indexing #[[$MAP]](%[[ARG_0]]) +// CHECK: xla.apply_indexing #[[$MAP]](%[[ARG_0]]) // ----- -#map0 = #xla_gpu.indexing_map<"(d0, d1, d2)[s0, s1] -> (1 + s0 + s1 mod 4 - s1, s0 mod 2, d0 + d2), domain: d0 in [0, 1], d1 in [0, 2], d2 in [0, 3], s0 in [-11, 11], s1 in [0, 3]"> +#map0 = #xla.indexing_map<"(d0, d1, d2)[s0, s1] -> (1 + s0 + s1 mod 4 - s1, s0 mod 2, d0 + d2), domain: d0 in [0, 1], d1 in [0, 2], d2 in [0, 3], s0 in [-11, 11], s1 in [0, 3]"> func.func @simplify_apply_indexing_remove_dims(%d0: index, %d1: index, %d2: index, %s0: index, %s1: index) -> (index, index, index) { - %0:3 = xla_gpu.apply_indexing #map0(%d0, %d1, %d2)[%s0, %s1] + %0:3 = xla.apply_indexing #map0(%d0, %d1, %d2)[%s0, %s1] func.return %0#0, %0#1, %0#2 : index, index, index } -// CHECK: #[[$MAP:.*]] = #xla_gpu.indexing_map<"(d0, d1, d2) -> (d2 + 1, d2 mod 2, d0 + d1), +// CHECK: #[[$MAP:.*]] = #xla.indexing_map<"(d0, d1, d2) -> (d2 + 1, d2 mod 2, d0 + d1), // CHECK-SAME: domain: d0 in [0, 1], d1 in [0, 3], d2 in [-11, 11] // CHECK-LABEL: func.func @simplify_apply_indexing_remove_dims @@ -29,18 +29,18 @@ func.func @simplify_apply_indexing_remove_dims(%d0: index, %d1: index, // CHECK-SAME: %[[ARG_2:[a-zA-Z0-9_]+]]: index, // CHECK-SAME: %[[ARG_3:[a-zA-Z0-9_]+]]: index, // CHECK-SAME: %[[ARG_4:[a-zA-Z0-9_]+]]: index) -// CHECK: xla_gpu.apply_indexing #[[$MAP]] +// CHECK: xla.apply_indexing #[[$MAP]] // CHECK-SAME: (%[[ARG_0]], %[[ARG_2]], %[[ARG_3]]) // ----- -#map0 = #xla_gpu.indexing_map<"(d0, d1)[s0] -> (d0 + s0, 4, d1, 1, s0), domain: d0 in [-10, 10], d1 in [0, 2], s0 in [-1, 1]"> +#map0 = #xla.indexing_map<"(d0, d1)[s0] -> (d0 + s0, 4, d1, 1, s0), domain: d0 in [-10, 10], d1 in [0, 2], s0 in [-1, 1]"> func.func @fold_indexing_map_results(%d0: index, %d1: index, %s0: index) -> (index, index, 
index, index, index) { - %0:5 = xla_gpu.apply_indexing #map0 (%d0, %d1)[%s0] + %0:5 = xla.apply_indexing #map0 (%d0, %d1)[%s0] func.return %0#0, %0#1, %0#2, %0#3, %0#4 : index, index, index, index, index } -// CHECK: #[[$MAP:.*]] = #xla_gpu.indexing_map<"(d0, d1) -> (d0 + d1), +// CHECK: #[[$MAP:.*]] = #xla.indexing_map<"(d0, d1) -> (d0 + d1), // CHECK-LABEL: func.func @fold_indexing_map_results // CHECK-SAME: %[[ARG_0:.*]]: index, %[[ARG_1:.*]]: index, %[[ARG_2:.*]]: index) @@ -48,49 +48,49 @@ func.func @fold_indexing_map_results(%d0: index, %d1: index, %s0: index) // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index // CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index -// CHECK: %[[NEW_RESULT:.*]] = xla_gpu.apply_indexing #[[$MAP]] +// CHECK: %[[NEW_RESULT:.*]] = xla.apply_indexing #[[$MAP]] // CHECK: return %[[NEW_RESULT]], %[[C4]], %[[ARG_1]], %[[C1]], %[[ARG_2]] // ----- -#map0 = #xla_gpu.indexing_map<"(d0, d1)[s0] -> (d0 + s0, s0 + 4, d1 mod 2, 1 + d1, s0)," +#map0 = #xla.indexing_map<"(d0, d1)[s0] -> (d0 + s0, s0 + 4, d1 mod 2, 1 + d1, s0)," "domain: d0 in [-10, 10], d1 in [0, 2], s0 in [-1, 1]"> func.func @remove_unused_results(%d0: index, %d1: index, %s0: index) -> (index) { - %0:5 = xla_gpu.apply_indexing #map0 (%d0, %d1)[%s0] + %0:5 = xla.apply_indexing #map0 (%d0, %d1)[%s0] func.return %0#2 : index } -// CHECK: #[[$MAP:.*]] = #xla_gpu.indexing_map<"(d0) -> (d0 mod 2), +// CHECK: #[[$MAP:.*]] = #xla.indexing_map<"(d0) -> (d0 mod 2), // CHECK-SAME: domain: d0 in [0, 2] // CHECK-LABEL: func.func @remove_unused_results // CHECK-SAME: %[[ARG_0:.*]]: index, %[[ARG_1:.*]]: index, %[[ARG_2:.*]]: index) -// CHECK: %[[NEW_RESULT:.*]] = xla_gpu.apply_indexing #[[$MAP]](%[[ARG_1]]) +// CHECK: %[[NEW_RESULT:.*]] = xla.apply_indexing #[[$MAP]](%[[ARG_1]]) // CHECK: return %[[NEW_RESULT]] // ----- -#map0 = #xla_gpu.indexing_map<"(d0, d1)[s0, s1] -> (d0 + d1 + s0 + s1 mod 3)," +#map0 = #xla.indexing_map<"(d0, d1)[s0, s1] -> (d0 + d1 + s0 + s1 mod 3)," "domain: d0 in [0, 10], d1 in [0, 5], s0 in [-10, 10], s1 in [0, 4]"> func.func @fold_operands(%d0: index) -> index { %d1 = arith.constant 1 : index %s0 = arith.constant 2 : index %s1 = arith.constant 3 : index - %0 = xla_gpu.apply_indexing #map0 (%d0, %d1)[%s0, %s1] + %0 = xla.apply_indexing #map0 (%d0, %d1)[%s0, %s1] func.return %0 : index } -// CHECK: #[[$MAP:.*]] = #xla_gpu.indexing_map<"(d0) -> (d0 + 3), +// CHECK: #[[$MAP:.*]] = #xla.indexing_map<"(d0) -> (d0 + 3), // CHECK-SAME: domain: d0 in [0, 10] // CHECK-LABEL: func.func @fold_operands // CHECK-SAME: %[[ARG_0:.*]]: index) -// CHECK: xla_gpu.apply_indexing #[[$MAP]](%[[ARG_0]]) +// CHECK: xla.apply_indexing #[[$MAP]](%[[ARG_0]]) // ----- func.func @fold_operands_and_results(%arg0: index, %arg1: index) -> (index, index) { - %0:2 = xla_gpu.apply_indexing #xla_gpu.indexing_map<"(d0, d1) -> (0, d1)," + %0:2 = xla.apply_indexing #xla.indexing_map<"(d0, d1) -> (0, d1)," "domain: d0 in [0, 4], d1 in [0, 5]">(%arg0, %arg1) return %0#0, %0#1 : index, index } @@ -103,96 +103,96 @@ func.func @fold_operands_and_results(%arg0: index, %arg1: index) // ----- func.func @fold_sequence(%arg0: index, %arg1: index) -> index { - %0 = xla_gpu.apply_indexing #xla_gpu.indexing_map< + %0 = xla.apply_indexing #xla.indexing_map< "(d0, d1) -> (d0 + d1), domain: d0 in [0, 5], d1 in [0, 4]">(%arg0, %arg1) - %1 = xla_gpu.apply_indexing #xla_gpu.indexing_map<"(d0) -> (d0 mod 100 + 42)," + %1 = xla.apply_indexing #xla.indexing_map<"(d0) -> (d0 mod 100 + 42)," "domain: d0 in [0, 10000]">(%0) func.return %1 : index } -// 
CHECK: #[[$MAP:.*]] = #xla_gpu.indexing_map<"(d0, d1) -> (d0 + d1 + 42), +// CHECK: #[[$MAP:.*]] = #xla.indexing_map<"(d0, d1) -> (d0 + d1 + 42), // CHECK-SAME: domain: d0 in [0, 5], d1 in [0, 4] // CHECK-LABEL: func.func @fold_sequence // CHECK-SAME: %[[ARG0:.*]]: index, %[[ARG1:.*]]: index) -// CHECK-NEXT: xla_gpu.apply_indexing #[[$MAP]] +// CHECK-NEXT: xla.apply_indexing #[[$MAP]] // CHECK-SAME: (%[[ARG0]], %[[ARG1]]) // ----- func.func @fold_sequence_sym(%arg0: index, %arg1: index) -> index { - %0 = xla_gpu.apply_indexing #xla_gpu.indexing_map<"(d0, d1) -> (d0 + d1), " + %0 = xla.apply_indexing #xla.indexing_map<"(d0, d1) -> (d0 + d1), " "domain: d0 in [0, 5], d1 in [0, 4]">(%arg0, %arg1) - %1 = xla_gpu.apply_indexing #xla_gpu.indexing_map< + %1 = xla.apply_indexing #xla.indexing_map< "()[s0] -> (s0 mod 100 + 42), domain: s0 in [0, 10000]">(%0) func.return %1 : index } -// CHECK: #[[$MAP:.*]] = #xla_gpu.indexing_map<"(d0, d1) -> (d0 + d1 + 42), +// CHECK: #[[$MAP:.*]] = #xla.indexing_map<"(d0, d1) -> (d0 + d1 + 42), // CHECK-SAME: domain: d0 in [0, 5], d1 in [0, 4] // CHECK-LABEL: func.func @fold_sequence_sym // CHECK-SAME: %[[ARG0:.*]]: index, %[[ARG1:.*]]: index) -// CHECK-NEXT: xla_gpu.apply_indexing #[[$MAP]] +// CHECK-NEXT: xla.apply_indexing #[[$MAP]] // CHECK-SAME: (%[[ARG0]], %[[ARG1]]) // ----- -#indexing_map1 = #xla_gpu.indexing_map<"(d0, d1) -> (d1 * 2 + d0 + 8512)," +#indexing_map1 = #xla.indexing_map<"(d0, d1) -> (d1 * 2 + d0 + 8512)," "domain: d0 in [0, 1], d1 in [0, 607]"> -#indexing_map2 = #xla_gpu.indexing_map<"(d0, d1, d2) -> (" +#indexing_map2 = #xla.indexing_map<"(d0, d1, d2) -> (" "((d1 floordiv 32 + 1) mod 3) * 64 + (d1 mod 32) * 2 + (d0 floordiv 192) * 192 + d2)," "domain: d0 in [0, 9407], d1 in [0, 607], d2 in [0, 1]"> func.func @fold_sequence_no_simplification_needed(%i: index) -> index { %thread_id_x = gpu.thread_id x {xla.range = [0 : index, 607 : index]} - %ind0 = xla_gpu.apply_indexing #indexing_map1(%i, %thread_id_x) - %ind1 = xla_gpu.apply_indexing #indexing_map2(%ind0, %thread_id_x, %i) + %ind0 = xla.apply_indexing #indexing_map1(%i, %thread_id_x) + %ind1 = xla.apply_indexing #indexing_map2(%ind0, %thread_id_x, %i) func.return %ind1 : index } -// CHECK: xla_gpu.apply_indexing -// CHECK-NOT: xla_gpu.apply_indexing +// CHECK: xla.apply_indexing +// CHECK-NOT: xla.apply_indexing // ----- -#indexing_map1 = #xla_gpu.indexing_map< +#indexing_map1 = #xla.indexing_map< "(d0) -> (3 * d0), domain: d0 in [0, 9407]"> -#indexing_map2 = #xla_gpu.indexing_map<"(d0, d1, d2) -> (d0 floordiv 32 + 1)," +#indexing_map2 = #xla.indexing_map<"(d0, d1, d2) -> (d0 floordiv 32 + 1)," "domain: d0 in [0, 9407], d1 in [0, 607], d2 in [0, 1]"> -#indexing_map3 = #xla_gpu.indexing_map<"(d0, d1, d2) -> (d0 floordiv 32 + 2)," +#indexing_map3 = #xla.indexing_map<"(d0, d1, d2) -> (d0 floordiv 32 + 2)," "domain: d0 in [0, 9407], d1 in [0, 607], d2 in [0, 1]"> func.func @no_fold_when_producer_has_two_users(%i: index) -> (index, index) { %thread_id_x = gpu.thread_id x {xla.range = [0 : index, 607 : index]} - %ind0 = xla_gpu.apply_indexing #indexing_map1(%thread_id_x) - %ind1 = xla_gpu.apply_indexing #indexing_map2(%ind0, %thread_id_x, %i) - %ind2 = xla_gpu.apply_indexing #indexing_map3(%ind0, %thread_id_x, %i) + %ind0 = xla.apply_indexing #indexing_map1(%thread_id_x) + %ind1 = xla.apply_indexing #indexing_map2(%ind0, %thread_id_x, %i) + %ind2 = xla.apply_indexing #indexing_map3(%ind0, %thread_id_x, %i) func.return %ind1, %ind2 : index, index } -// CHECK-COUNT-3: xla_gpu.apply_indexing +// 
CHECK-COUNT-3: xla.apply_indexing // ----- func.func @fold_sequence_shared_operands(%arg0: index, %arg1: index) -> index { - %0 = xla_gpu.apply_indexing #xla_gpu.indexing_map<"(d0, d1) -> (d0 + d1)," + %0 = xla.apply_indexing #xla.indexing_map<"(d0, d1) -> (d0 + d1)," "domain: d0 in [0, 5], d1 in [0, 4]">(%arg0, %arg1) - %1 = xla_gpu.apply_indexing #xla_gpu.indexing_map<"(d0, d1) -> (d0 + d1)," + %1 = xla.apply_indexing #xla.indexing_map<"(d0, d1) -> (d0 + d1)," "domain: d0 in [0, 4], d1 in [0, 10000]">(%arg1, %0) func.return %1 : index } -// CHECK: #[[$MAP:.*]] = #xla_gpu.indexing_map<"(d0, d1) -> (d0 * 2 + d1), +// CHECK: #[[$MAP:.*]] = #xla.indexing_map<"(d0, d1) -> (d0 * 2 + d1), // CHECK-SAME: domain: d0 in [0, 4], d1 in [0, 5] // CHECK-LABEL: func.func @fold_sequence_shared_operands // CHECK-SAME: %[[ARG0:.*]]: index, %[[ARG1:.*]]: index) -// CHECK-NEXT: xla_gpu.apply_indexing #[[$MAP]] +// CHECK-NEXT: xla.apply_indexing #[[$MAP]] // CHECK-SAME: (%[[ARG1]], %[[ARG0]]) // ----- func.func @atomic_rmw_empty(%in: tensor<2x3xf32>, %i: index, %j: index) -> (tensor<2x3xf32>) { - %ret = xla_gpu.atomic_rmw %in[%i, %j] : tensor<2x3xf32> { + %ret = xla.atomic_rmw %in[%i, %j] : tensor<2x3xf32> { ^bb0(%current : f32): - xla_gpu.yield %current : f32 + xla.yield %current : f32 } return %ret : tensor<2x3xf32> } @@ -206,9 +206,9 @@ func.func @atomic_rmw_empty(%in: tensor<2x3xf32>, %i: index, %j: index) func.func @atomic_rmw_cst(%in: tensor<2x3xf32>, %i: index, %j: index) -> (tensor<2x3xf32>) { %cst = arith.constant 0.0 : f32 - %ret = xla_gpu.atomic_rmw %in[%i, %j] : tensor<2x3xf32> { + %ret = xla.atomic_rmw %in[%i, %j] : tensor<2x3xf32> { ^bb0(%current : f32): - xla_gpu.yield %cst : f32 + xla.yield %cst : f32 } return %ret : tensor<2x3xf32> } @@ -216,65 +216,65 @@ func.func @atomic_rmw_cst(%in: tensor<2x3xf32>, %i: index, %j: index) // CHECK-SAME: %[[ARG0:.*]]: tensor<2x3xf32> // CHECK-NEXT: %[[CST:.*]] = arith.constant // CHECK-NEXT: atomic_rmw -// CHECK: xla_gpu.yield %[[CST]] +// CHECK: xla.yield %[[CST]] // ----- -#map0 = #xla_gpu.indexing_map<"(d0)[s0] -> (2 * d0 * s0)," +#map0 = #xla.indexing_map<"(d0)[s0] -> (2 * d0 * s0)," "domain: d0 in [0, 3], s0 in [0, 2]"> func.func @apply_indexing_move_syms_to_dims(%dim0: index, %sym0: index) -> index { - %0 = xla_gpu.apply_indexing #map0(%dim0)[%sym0] + %0 = xla.apply_indexing #map0(%dim0)[%sym0] func.return %0 : index } -// CHECK: #[[$MAP:.*]] = #xla_gpu.indexing_map<"(d0, d1) -> ((d0 * d1) * 2), +// CHECK: #[[$MAP:.*]] = #xla.indexing_map<"(d0, d1) -> ((d0 * d1) * 2), // CHECK-SAME: domain: d0 in [0, 3], d1 in [0, 2] // CHECK-LABEL: func.func @apply_indexing_move_syms_to_dims -// CHECK-NEXT: xla_gpu.apply_indexing #[[$MAP]] +// CHECK-NEXT: xla.apply_indexing #[[$MAP]] // CHECK-SAME: (%[[ARG0:.*]], %[[ARG1:.*]]) // // ----- -#map0 = #xla_gpu.indexing_map<"(d0) -> (4 * d0), domain: d0 in [0, 3]"> -#map1 = #xla_gpu.indexing_map<"(d0)[s0, s1] -> (d0 + s0, s1)," +#map0 = #xla.indexing_map<"(d0) -> (4 * d0), domain: d0 in [0, 3]"> +#map1 = #xla.indexing_map<"(d0)[s0, s1] -> (d0 + s0, s1)," "domain: d0 in [0, 12], s0 in [0, 1024], s1 in [0, 32]"> func.func @loop_of_apply_indexing(%input: tensor<1024x32xf32>, %init: f32, %dim: index) -> (f32) { - %idx = xla_gpu.apply_indexing #map0(%dim) - %sum = xla_gpu.loop (%idx)[%i, %j] -> (%r0, %r1) in #map1 iter_args(%sum_ = %init) -> (f32) { + %idx = xla.apply_indexing #map0(%dim) + %sum = xla.loop (%idx)[%i, %j] -> (%r0, %r1) in #map1 iter_args(%sum_ = %init) -> (f32) { %t = tensor.extract %input[%i, %j] : 
tensor<1024x32xf32> %add = arith.addf %sum_, %t : f32 - xla_gpu.yield %add : f32 + xla.yield %add : f32 } func.return %sum : f32 } -// CHECK: #[[$MAP:.*]] = #xla_gpu.indexing_map<"(d0)[s0, s1] -> (d0 * 4 + s0, s1), +// CHECK: #[[$MAP:.*]] = #xla.indexing_map<"(d0)[s0, s1] -> (d0 * 4 + s0, s1), // CHECK-SAME: domain: d0 in [0, 3], s0 in [0, 1024], s1 in [0, 32] // CHECK-LABEL: func.func @loop_of_apply_indexing // CHECK-SAME: %[[ARG0:.*]]: tensor<1024x32xf32>, %[[ARG1:.*]]: f32, %[[ARG2:.*]]: index) -// CHECK: xla_gpu.loop (%[[ARG2]]) +// CHECK: xla.loop (%[[ARG2]]) // CHECK-SAME: in #[[$MAP]] // ----- -#map0 = #xla_gpu.indexing_map<"(d0)[s0] -> (2 * d0 * s0)," +#map0 = #xla.indexing_map<"(d0)[s0] -> (2 * d0 * s0)," "domain: d0 in [0, 3], s0 in [0, 2]"> -#map1 = #xla_gpu.indexing_map<"(d0)[s0, s1] -> (d0 + s0 + s1)," +#map1 = #xla.indexing_map<"(d0)[s0, s1] -> (d0 + s0 + s1)," "domain: d0 in [0, 12], s0 in [0, 1024], s1 in [0, 32]"> func.func @loop_of_apply_indexing_with_syms(%dim0: index, %sym0: index, %input: tensor<1024x32xf32>, %init: f32) -> (f32) { - %0 = xla_gpu.apply_indexing #map0(%dim0)[%sym0] - %sum = xla_gpu.loop (%0)[%i, %j] -> (%r0) in #map1 iter_args(%sum_ = %init) -> (f32) { + %0 = xla.apply_indexing #map0(%dim0)[%sym0] + %sum = xla.loop (%0)[%i, %j] -> (%r0) in #map1 iter_args(%sum_ = %init) -> (f32) { %t = tensor.extract %input[%i, %j] : tensor<1024x32xf32> %add = arith.addf %sum_, %t : f32 - xla_gpu.yield %add : f32 + xla.yield %add : f32 } func.return %sum : f32 } -// CHECK: #[[$MAP:.*]] = #xla_gpu.indexing_map<"(d0, d1)[s0, s1] -> ((d0 * d1) * 2 + s0 + s1), +// CHECK: #[[$MAP:.*]] = #xla.indexing_map<"(d0, d1)[s0, s1] -> ((d0 * d1) * 2 + s0 + s1), // CHECK-SAME: domain: d0 in [0, 3], d1 in [0, 2], s0 in [0, 1024], s1 in [0, 32] // CHECK-LABEL: func.func @loop_of_apply_indexing_with_syms // CHECK-SAME: %[[ARG0:.*]]: index, %[[ARG1:.*]]: index -// CHECK: xla_gpu.loop (%[[ARG0]], %[[ARG1]]) +// CHECK: xla.loop (%[[ARG0]], %[[ARG1]]) // CHECK-SAME: in #[[$MAP]] diff --git a/third_party/xla/xla/service/gpu/fusions/ir/tests/inlining.mlir b/third_party/xla/xla/codegen/ir/tests/inlining.mlir similarity index 74% rename from third_party/xla/xla/service/gpu/fusions/ir/tests/inlining.mlir rename to third_party/xla/xla/codegen/ir/tests/inlining.mlir index f15b37b040b848..fac0ff321778f4 100644 --- a/third_party/xla/xla/service/gpu/fusions/ir/tests/inlining.mlir +++ b/third_party/xla/xla/codegen/ir/tests/inlining.mlir @@ -1,4 +1,4 @@ -// RUN: mlir_fusions_opt %s -split-input-file -xla-erase-dead-functions -inline | FileCheck %s +// RUN: emitters_opt %s -split-input-file -xla-erase-dead-functions -inline | FileCheck %s module { func.func private @mul(%a: f32, %b: f32) -> f32 { @@ -8,21 +8,21 @@ module { func.func private @add(%a: f32, %b: f32) -> f32 { %add = arith.addf %a, %b : f32 - %ret = xla_gpu.pure_call @mul(%add, %add) : (f32, f32) -> (f32) + %ret = xla.pure_call @mul(%add, %add) : (f32, f32) -> (f32) return %ret : f32 } func.func @caller(%a: f32, %b: f32) -> f32 { - %ret = xla_gpu.pure_call @add(%a, %b) : (f32, f32) -> (f32) + %ret = xla.pure_call @add(%a, %b) : (f32, f32) -> (f32) return %ret : f32 } } // CHECK-LABEL: module { // CHECK: @caller -// CHECK-NOT: xla_gpu.pure_call @add +// CHECK-NOT: xla.pure_call @add // CHECK: arith.addf -// CHECK-NOT: xla_gpu.pure_call @mul +// CHECK-NOT: xla.pure_call @mul // CHECK: arith.mulf // ----- @@ -30,7 +30,7 @@ module { module { func.func @fused_computation(%arg0: tensor<2xf32> {xla.slice_index = 0 : index}, %arg1: tensor<2xf32> 
{xla.slice_index = 1 : index}, %arg2: tensor<2xf32> {xla.slice_index = 2 : index}) -> tensor<2xf32> attributes {xla.entry} { %0 = gpu.thread_id x {xla.range = [0 : index, 1 : index]} - %1 = xla_gpu.pure_call @fused_computation_atan2(%arg0, %arg1, %0) : (tensor<2xf32>, tensor<2xf32>, index) -> f32 + %1 = xla.pure_call @fused_computation_atan2(%arg0, %arg1, %0) : (tensor<2xf32>, tensor<2xf32>, index) -> f32 %inserted = tensor.insert %1 into %arg2[%0] : tensor<2xf32> return %inserted : tensor<2xf32> } @@ -48,7 +48,7 @@ module { // CHECK-LABEL: module { // CHECK: @fused_computation -// CHECK-NOT: xla_gpu.pure_call @add +// CHECK-NOT: xla.pure_call @add // CHECK: gpu.thread_id // CHECK-NEXT: tensor.extract // CHECK-NEXT: tensor.extract @@ -80,13 +80,13 @@ module { func.func private @add(%a: f32, %b: f32) -> f32 { %add = arith.addf %a, %b : f32 - %ret = xla_gpu.pure_call @large(%add, %add) : (f32, f32) -> (f32) + %ret = xla.pure_call @large(%add, %add) : (f32, f32) -> (f32) return %ret : f32 } func.func @caller(%a: f32, %b: f32) -> f32 { - %add = xla_gpu.pure_call @add(%a, %b) : (f32, f32) -> (f32) - %ret = xla_gpu.pure_call @large(%add, %add) : (f32, f32) -> (f32) + %add = xla.pure_call @add(%a, %b) : (f32, f32) -> (f32) + %ret = xla.pure_call @large(%add, %add) : (f32, f32) -> (f32) return %ret : f32 } } @@ -94,8 +94,8 @@ module { // CHECK-LABEL: module { // CHECK: @caller // CHECK: arith.addf -// CHECK: xla_gpu.pure_call @large -// CHECK: xla_gpu.pure_call @large +// CHECK: xla.pure_call @large +// CHECK: xla.pure_call @large // ----- @@ -106,15 +106,15 @@ module { } func.func @caller(%a: f32, %b: f32) -> f32 { - %add = xla_gpu.pure_call @add(%a, %b) : (f32, f32) -> (f32) - %ret = xla_gpu.pure_call @add(%add, %add) : (f32, f32) -> (f32) + %add = xla.pure_call @add(%a, %b) : (f32, f32) -> (f32) + %ret = xla.pure_call @add(%add, %add) : (f32, f32) -> (f32) return %ret : f32 } } // CHECK-LABEL: module { // CHECK: @caller -// CHECK-NOT: xla_gpu.pure_call +// CHECK-NOT: xla.pure_call // CHECK: arith.addf // CHECK: arith.addf @@ -129,48 +129,48 @@ module { return %start : f32 } func.func private @fib2(%start : f32) -> f32 { - %a = xla_gpu.pure_call @fib0(%start) : (f32) -> (f32) - %b = xla_gpu.pure_call @fib1(%start) : (f32) -> (f32) + %a = xla.pure_call @fib0(%start) : (f32) -> (f32) + %b = xla.pure_call @fib1(%start) : (f32) -> (f32) %ret = arith.addf %a, %b : f32 return %ret : f32 } func.func private @fib3(%start : f32) -> f32 { - %a = xla_gpu.pure_call @fib1(%start) : (f32) -> (f32) - %b = xla_gpu.pure_call @fib2(%start) : (f32) -> (f32) + %a = xla.pure_call @fib1(%start) : (f32) -> (f32) + %b = xla.pure_call @fib2(%start) : (f32) -> (f32) %ret = arith.addf %a, %b : f32 return %ret : f32 } func.func private @fib4(%start : f32) -> f32 { - %a = xla_gpu.pure_call @fib2(%start) : (f32) -> (f32) - %b = xla_gpu.pure_call @fib3(%start) : (f32) -> (f32) + %a = xla.pure_call @fib2(%start) : (f32) -> (f32) + %b = xla.pure_call @fib3(%start) : (f32) -> (f32) %ret = arith.addf %a, %b : f32 return %ret : f32 } // When inlining the other functions into @fib5, this function exceeds the // threshold for inlining. 
func.func private @fib5(%start : f32) -> f32 { - %a = xla_gpu.pure_call @fib3(%start) : (f32) -> (f32) - %b = xla_gpu.pure_call @fib4(%start) : (f32) -> (f32) + %a = xla.pure_call @fib3(%start) : (f32) -> (f32) + %b = xla.pure_call @fib4(%start) : (f32) -> (f32) %ret = arith.addf %a, %b : f32 return %ret : f32 } // As we do not inline @fib5 into @fib6, this function stays below the // threshold for inlining. func.func private @fib6(%start : f32) -> f32 { - %a = xla_gpu.pure_call @fib4(%start) : (f32) -> (f32) - %b = xla_gpu.pure_call @fib5(%start) : (f32) -> (f32) + %a = xla.pure_call @fib4(%start) : (f32) -> (f32) + %b = xla.pure_call @fib5(%start) : (f32) -> (f32) %ret = arith.addf %a, %b : f32 return %ret : f32 } func.func private @fib7(%start : f32) -> f32 { - %a = xla_gpu.pure_call @fib5(%start) : (f32) -> (f32) - %b = xla_gpu.pure_call @fib6(%start) : (f32) -> (f32) + %a = xla.pure_call @fib5(%start) : (f32) -> (f32) + %b = xla.pure_call @fib6(%start) : (f32) -> (f32) %ret = arith.addf %a, %b : f32 return %ret : f32 } func.func @caller(%a: f32) -> f32 { - %ret = xla_gpu.pure_call @fib7(%a) : (f32) -> (f32) + %ret = xla.pure_call @fib7(%a) : (f32) -> (f32) return %ret : f32 } } @@ -178,12 +178,12 @@ module { // CHECK-LABEL: module { // CHECK: @caller // CHECK: arith.constant 0.000000e+00 -// CHECK: xla_gpu.pure_call @fib5 +// CHECK: xla.pure_call @fib5 // CHECK: arith.addf // CHECK: arith.addf // CHECK: arith.addf // CHECK: arith.addf -// CHECK: xla_gpu.pure_call @fib5 +// CHECK: xla.pure_call @fib5 // CHECK: arith.addf // CHECK: arith.addf @@ -196,7 +196,7 @@ module { } func.func @caller(%a: f32, %b: f32) -> complex { - %ret = xla_gpu.pure_call @complex(%a, %b) : (f32, f32) -> (complex) + %ret = xla.pure_call @complex(%a, %b) : (f32, f32) -> (complex) return %ret : complex } } @@ -214,7 +214,7 @@ module { } func.func private @callee1(%a: f32) -> f32 { - %c1 = xla_gpu.pure_call @callee2(%a) : (f32) -> (f32) + %c1 = xla.pure_call @callee2(%a) : (f32) -> (f32) %b0 = arith.addf %a, %a : f32 %b1 = arith.addf %b0, %a : f32 %b2 = arith.addf %b1, %a : f32 @@ -223,18 +223,18 @@ module { %b5 = arith.addf %b4, %a : f32 %b6 = arith.addf %b5, %a : f32 %b7 = arith.addf %b6, %a : f32 - %c2 = xla_gpu.pure_call @callee2(%b7) : (f32) -> (f32) + %c2 = xla.pure_call @callee2(%b7) : (f32) -> (f32) %ret = arith.addf %c1, %c2 : f32 return %ret : f32 } func.func private @dead(%a: f32) -> f32 { - %ret = xla_gpu.pure_call @callee1(%a) : (f32) -> (f32) + %ret = xla.pure_call @callee1(%a) : (f32) -> (f32) return %ret : f32 } func.func @caller(%a: f32, %b: f32) -> f32 { - %ret = xla_gpu.pure_call @callee1(%a) : (f32) -> (f32) + %ret = xla.pure_call @callee1(%a) : (f32) -> (f32) return %ret : f32 } } @@ -242,7 +242,7 @@ module { // CHECK-LABEL: module { // CHECK-NOT: func.func // CHECK: func.func @caller -// CHECK-NOT: xla_gpu.pure_call +// CHECK-NOT: xla.pure_call // CHECK-NOT: func.func // ----- @@ -265,7 +265,7 @@ module { } func.func private @callee2(%a: f32) -> f32 { - %call = xla_gpu.pure_call @callee1(%a) : (f32) -> (f32) + %call = xla.pure_call @callee1(%a) : (f32) -> (f32) %b0 = arith.addf %a, %a : f32 %b1 = arith.addf %b0, %a : f32 %b2 = arith.addf %b1, %a : f32 @@ -281,8 +281,8 @@ module { } func.func @caller(%a: f32, %b: f32) -> f32 { - %call1 = xla_gpu.pure_call @callee2(%a) : (f32) -> (f32) - %call2 = xla_gpu.pure_call @callee1(%a) : (f32) -> (f32) + %call1 = xla.pure_call @callee2(%a) : (f32) -> (f32) + %call2 = xla.pure_call @callee1(%a) : (f32) -> (f32) %ret = arith.addf %call1, %call2 : f32 
return %ret : f32 } diff --git a/third_party/xla/xla/codegen/ir/tests/invalid.mlir b/third_party/xla/xla/codegen/ir/tests/invalid.mlir new file mode 100644 index 00000000000000..2e929d0624e8fe --- /dev/null +++ b/third_party/xla/xla/codegen/ir/tests/invalid.mlir @@ -0,0 +1,86 @@ +// RUN: emitters_opt %s -split-input-file -verify-diagnostics + +#map0 = #xla.indexing_map<"(d0, d1)[s0] -> (d0, d1 + s0), domain: d0 in [1, 2], d1 in [5, 8], s0 in [0, 32]"> +func.func @apply_indexing(%d0: index, %d1: index, %s0: index) -> (index, index) { + // expected-error @+1 {{operand count must match the number of dimensions and symbols in the affine map}} + %0:2 = xla.apply_indexing #map0 (%d0) + func.return %0#0, %0#1 : index, index +} + +// ----- + +#map0 = #xla.indexing_map<"(d0, d1)[s0] -> (d0, d1 + s0), domain: d0 in [1, 2], d1 in [5, 8], s0 in [0, 32], d0 mod 2 in [0, 1], d0 + s0 in [1, 10]"> +func.func @cannot_have_constraints(%d0: index, %d1: index, %s0: index) -> (index, index) { + // expected-error @+1 {{apply indexing op cannot have any constraints}} + %0:2 = xla.apply_indexing #map0 (%d0, %d1)[%s0] + func.return %0#0, %0#1 : index, index +} + +// ----- + +#map = #xla.indexing_map<"()[s0, s1] -> (s0, s1), domain: s0 in [0, 1024], s1 in [0, 32]"> +func.func @loop_result_num_mismatch(%input: tensor<1024x32xf32>, + %init: f32) -> (f32) { + // expected-error @+1 {{mismatch in number of loop-carried values and results}} + %sum:2 = "xla.loop"(%init) <{ + indexing_map_attr = #map, + operandSegmentSizes = array + }> ({ + ^bb0(%i: index, %j: index, %r0: index, %r1: index, %iter: f32): + %t = tensor.extract %input[%i, %j] : tensor<1024x32xf32> + %add = arith.addf %iter, %t : f32 + xla.yield %add : f32 + }) : (f32) -> (f32, f32) + func.return %sum#0 : f32 +} + +// ----- + +#map = #xla.indexing_map<"()[s0] -> (s0, s0), domain: s0 in [0, 1024]"> +func.func @loop_iv_num_mismatch(%input: tensor<1024x32xf32>, + %init: f32) -> (f32) { + // expected-error @+1 {{mismatch in number of induction variables 2 and RangeVars}} + %sum = "xla.loop"(%init) <{ + indexing_map_attr = #map, + operandSegmentSizes = array + }> ({ + ^bb0(%i: index, %j: index, %r0: index, %r1: index, %iter: f32): + %t = tensor.extract %input[%i, %j] : tensor<1024x32xf32> + %add = arith.addf %iter, %t : f32 + xla.yield %add : f32 + }) : (f32) -> (f32) + func.return %sum : f32 +} + +// ----- + +#map = #xla.indexing_map<"()[s0, s1] -> (s0, s1), domain: s0 in [0, 1024], s1 in [0, 32]"> + +func.func @loop_types_mismatch(%input: tensor<1024x32xf32>, %init: f32) -> (i32) { + // expected-error @+1 {{block iter arg type = 'f32', result type = 'i32' and init operand type = 'f32' should match}} + %sum = "xla.loop"(%init) <{ + indexing_map_attr = #map, + operandSegmentSizes = array + }> ({ + ^bb0(%i: index, %j: index, %r0: index, %r1: index, %iter: f32): + %t = tensor.extract %input[%i, %j] : tensor<1024x32xf32> + %add = arith.addf %iter, %t : f32 + xla.yield %add : f32 + }) : (f32) -> (i32) + func.return %sum : i32 +} + +// ----- + +#map = #xla.indexing_map<"(d0)[s0, s1] -> (s0, s1), domain: d0 in [0, 3], s0 in [0, 1024], s1 in [0, 32]"> + +func.func @loop_op(%input: tensor<1024x32xf32>, %init: f32, %dim: index) -> (f32) { + // expected-error @+1 {{mismatch in number of dims operands 0 and DimVars in the indexing map}} + %sum = xla.loop ()[%i, %j] -> (%r0, %r1) + in #map iter_args(%sum_ = %init) -> (f32) { + %t = tensor.extract %input[%i, %j] : tensor<1024x32xf32> + %add = arith.addf %sum_, %t : f32 + xla.yield %add : f32 + } {xla.range = [0 : index, 42 : 
index]} + func.return %sum : f32 +} diff --git a/third_party/xla/xla/codegen/ir/tests/ops.mlir b/third_party/xla/xla/codegen/ir/tests/ops.mlir new file mode 100644 index 00000000000000..f2264693c7f541 --- /dev/null +++ b/third_party/xla/xla/codegen/ir/tests/ops.mlir @@ -0,0 +1,122 @@ +// RUN: emitters_opt %s --split-input-file | FileCheck %s +// Verify the printed output can be parsed. +// RUN: emitters_opt %s --split-input-file | emitters_opt --split-input-file | FileCheck %s +// Verify the generic form can be parsed. +// RUN: emitters_opt %s --split-input-file --mlir-print-op-generic | emitters_opt --split-input-file | FileCheck %s + +func.func @atomic_rmw(%in: tensor<2x3xf32>, %i: index, %j: index) + -> (tensor<2x3xf32>) { + %ret = xla.atomic_rmw %in[%i, %j] : tensor<2x3xf32> { + ^bb0(%current : f32): + %c42 = arith.constant 42.0 : f32 + %add = arith.addf %current, %c42 : f32 + xla.yield %add : f32 + } + return %ret : tensor<2x3xf32> +} +// CHECK-LABEL: @atomic_rmw +// CHECK: xla.atomic_rmw + +// ----- + +func.func private @add(%a: f32, %b: f32) -> f32 { + %ret = arith.addf %a, %b : f32 + return %ret : f32 +} + +func.func @caller(%a: f32, %b: f32) -> f32 { + %c = xla.pure_call @add(%a, %b) : (f32, f32) -> (f32) + %d = xla.pure_call @add(%a, %b) : (f32, f32) -> (f32) + %ret = arith.addf %c, %d : f32 + return %ret : f32 +} +// CHECK-LABEL: @caller +// CHECK: %[[C:.*]] = xla.pure_call @add +// CHECK: %[[D:.*]] = xla.pure_call @add +// CHECK: arith.addf %[[C]], %[[D]] + +// CHECK-CSE: @caller +// CHECK-CSE: %[[C:.*]] = xla.pure_call @add +// CHECK-CSE: arith.addf %[[C]], %[[C]] + +// ----- + +#map0 = #xla.indexing_map<"(d0, d1)[s0] -> (d0, d1 + s0)," + "domain: d0 in [1, 2], d1 in [5, 8], s0 in [0, 32]"> +func.func @apply_indexing(%d0: index, %d1: index, %s0: index) -> (index, index) { + %0:2 = xla.apply_indexing #map0 (%d0, %d1)[%s0] + func.return %0#0, %0#1 : index, index +} +// CHECK: #[[$MAP0:.*]] = #xla.indexing_map<" +// CHECK-SAME: (d0, d1)[s0] -> (d0, d1 + s0) +// CHECK-SAME: domain: +// CHECK-SAME: d0 in [1, 2] +// CHECK-SAME: d1 in [5, 8] +// CHECK-SAME: s0 in [0, 32] +// CHECK-SAME: > + +// CHECK-LABEL: @apply_indexing +// CHECK: (%[[d0:.*]]: index, %[[d1:.*]]: index, %[[s0:.*]]: index) +// CHECK: xla.apply_indexing #[[$MAP0]] +// CHECK-SAME: (%[[d0]], %[[d1]])[%[[s0]]] + +// ----- + +#map0 = #xla.indexing_map<"(d0, d1) -> (d0, d1)," + "domain: d0 in [0, 2], d1 in [1, 3]"> +func.func @apply_indexing_no_symbols(%d0: index, %d1: index) -> (index, index) { + %0:2 = xla.apply_indexing #map0 (%d0, %d1) + func.return %0#0, %0#1 : index, index +} +// CHECK: #[[$MAP0:.*]] = #xla.indexing_map<" +// CHECK-SAME: (d0, d1) -> (d0, d1) +// CHECK-SAME: domain: +// CHECK-SAME: d0 in [0, 2] +// CHECK-SAME: d1 in [1, 3] +// CHECK-SAME: > + +// CHECK-LABEL: @apply_indexing_no_symbols +// CHECK: (%[[d0:.*]]: index, %[[d1:.*]]: index) +// CHECK: xla.apply_indexing #[[$MAP0]] +// CHECK-SAME: (%[[d0]], %[[d1]]) + +// ----- + +#map0 = #xla.indexing_map<"()[s0] -> (s0, s0)," + "domain: s0 in [2, 4]"> +func.func @apply_indexing_no_dims(%s0: index) -> (index, index) { + %0:2 = xla.apply_indexing #map0 [%s0] + func.return %0#0, %0#1 : index, index +} +// CHECK: #[[$MAP0:.*]] = #xla.indexing_map<" +// CHECK-SAME: ()[s0] -> (s0, s0) +// CHECK-SAME: domain: +// CHECK-SAME: s0 in [2, 4] +// CHECK-SAME: > + +// CHECK-LABEL: @apply_indexing_no_dims +// CHECK: (%[[s0:.*]]: index) +// CHECK: xla.apply_indexing #[[$MAP0]][%[[s0]]] + +// ----- + +#map = #xla.indexing_map<"(d0)[s0, s1] -> (s0, s1), " + "domain: d0 in 
[0, 3], s0 in [0, 1024], s1 in [0, 32]">
+func.func @loop_op(%input: tensor<1024x32xf32>, %init: f32,
+    %dim: index) -> (f32) {
+  %sum = xla.loop (%dim)[%i, %j] -> (%r0, %r1)
+      in #map iter_args(%sum_ = %init) -> (f32) {
+    %t = tensor.extract %input[%i, %j] : tensor<1024x32xf32>
+    %add = arith.addf %sum_, %t : f32
+    xla.yield %add : f32
+  } {xla.range = [0 : index, 42 : index]}
+  func.return %sum : f32
+}
+// CHECK: #[[$MAP:.*]] = #xla.indexing_map
+// CHECK: xla.loop (%{{.*}})[%[[I:.*]], %[[J:.*]]] ->
+// CHECK-SAME: (%[[R0:.*]], %[[R1:.*]]) in #[[$MAP]]
+// CHECK-SAME: iter_args(%[[SUM_ITER:.*]] = %{{.*}}) -> (f32) {
+// CHECK: %[[EXTRACTED:.*]] = tensor.extract %{{.*}}[%[[I]], %[[J]]]
+// CHECK: %[[ADD:.*]] = arith.addf %{{.*}}, %[[EXTRACTED]] : f32
+// CHECK: xla.yield %[[ADD]] : f32
+// CHECK: } {xla.range = [0 : index, 42 : index]}
diff --git a/third_party/xla/xla/codegen/ir/xla_attrs.cc b/third_party/xla/xla/codegen/ir/xla_attrs.cc
new file mode 100644
index 00000000000000..5f0e416064ec6c
--- /dev/null
+++ b/third_party/xla/xla/codegen/ir/xla_attrs.cc
@@ -0,0 +1,95 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <cstdint>
+#include <optional>
+#include <sstream>
+#include <string>
+#include <utility>
+
+#include "llvm/ADT/StringRef.h"
+#include "llvm/ADT/TypeSwitch.h"  // IWYU pragma: keep
+#include "llvm/Support/LogicalResult.h"
+#include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Diagnostics.h"
+#include "mlir/IR/OpImplementation.h"
+#include "mlir/IR/Types.h"
+#include "mlir/Support/LLVM.h"
+#include "xla/codegen/ir/xla_ops.h"
+#include "xla/hlo/analysis/indexing_map.h"
+#include "xla/hlo/analysis/indexing_map_serialization.h"
+
+namespace xla {
+
+using llvm::ParseResult;
+using llvm::SmallVector;
+using mlir::AffineExpr;
+using mlir::ArrayRef;
+using mlir::AsmParser;
+using mlir::AsmPrinter;
+using mlir::failure;
+using mlir::success;
+
+mlir::Attribute IndexingMapAttr::parse(mlir::AsmParser& parser, mlir::Type) {
+  if (parser.parseLess()) {
+    return {};
+  }
+  auto indexing_map = parseChainOfStringsAsIndexingMap(parser);
+  if (!indexing_map.has_value() || parser.parseGreater()) {
+    return {};
+  }
+  return IndexingMapAttr::get(parser.getContext(), *indexing_map);
+}
+
+void IndexingMapAttr::print(mlir::AsmPrinter& printer) const {
+  printer << "<\"" << ToString(getIndexingMap()) << "\">";
+}
+
+IndexingMapAttr IndexingMapAttr::get(mlir::MLIRContext* context,
+                                     const IndexingMap& indexing_map) {
+  llvm::SmallVector<std::pair<AffineExpr, Interval>> constraints;
+  for (auto& constraint : indexing_map.GetConstraints()) {
+    constraints.push_back({constraint.first, constraint.second});
+  }
+  return get(context, indexing_map.GetAffineMap(), indexing_map.GetDimVars(),
+             indexing_map.GetRangeVars(), constraints);
+}
+
+mlir::LogicalResult IndexingMapAttr::verify(
+    mlir::function_ref<mlir::InFlightDiagnostic()> emitError,
+    mlir::AffineMap map, ArrayRef<IndexingMap::Variable> dim_vars,
+    ArrayRef<IndexingMap::Variable> range_vars,
+    ArrayRef<std::pair<AffineExpr, Interval>> constraints) {
+  auto indexing_map =
+      IndexingMap(map, dim_vars, range_vars, /*rt_vars=*/{}, constraints);
+  std::stringstream ss;
+  if (!indexing_map.Verify(ss)) {
+    return emitError() << ss.str();
+  }
+  return success();
+}
+
+IndexingMap IndexingMapAttr::getIndexingMap() const {
+  return IndexingMap(getMap(), getDimVars(), getRangeVars(), /*rt_vars=*/{},
+                     getConstraints());
+}
+
+int64_t IndexingMapAttr::getNumResults() const {
+  return getMap().getNumResults();
+}
+
+}  // namespace xla
diff --git a/third_party/xla/xla/codegen/ir/xla_attrs.td b/third_party/xla/xla/codegen/ir/xla_attrs.td
new file mode 100644
index 00000000000000..d52f29fea13614
--- /dev/null
+++ b/third_party/xla/xla/codegen/ir/xla_attrs.td
@@ -0,0 +1,65 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef XLA_CODEGEN_IR_XLA_ATTRS
+#define XLA_CODEGEN_IR_XLA_ATTRS
+
+include "mlir/IR/AttrTypeBase.td"
+include "mlir/IR/EnumAttr.td"
+include "xla/codegen/ir/xla_dialect.td"
+
+class XLA_Attr<string name, list<Trait> traits = []> :
+      AttrDef<XlaDialect, name, traits> {
+}
+
+def XLA_AffineMapParameter :
+    AttrOrTypeParameter<"::mlir::AffineMap", ""> {
+}
+
+def XLA_IndexingMapVariableParameter
+  : ArrayRefParameter<"::xla::IndexingMap::Variable",
+                      "IndexingMapVariableArray"> {
+}
+
+def XLA_ConstraintsParameter :
+    ArrayRefParameter<"::std::pair<::mlir::AffineExpr, ::xla::Interval>",
+                      "ContraintsArray"> {
+}
+
+def XLA_IndexingMapAttr : XLA_Attr<"IndexingMap"> {
+  let summary = "An Attribute representing an indexing map.";
+  let mnemonic = "indexing_map";
+  let description = [{This attribute stores an indexing map. See
+    https://openxla.org/xla/indexing for more details.
+  }];
+  let parameters = (ins XLA_AffineMapParameter:$map,
+                        XLA_IndexingMapVariableParameter:$dim_vars,
+                        XLA_IndexingMapVariableParameter:$range_vars,
+                        XLA_ConstraintsParameter:$constraints);
+  let hasCustomAssemblyFormat = 1;
+  let builders =  [
+    AttrBuilder<(ins "const ::xla::IndexingMap&":$indexing_map)>,
+  ];
+  let genVerifyDecl = 1;
+  let extraClassDeclaration = [{
+    // Returns the indexing map constructed from IndexingMapAttr.
+    xla::IndexingMap getIndexingMap() const;
+
+    // Returns the number of indexing map results.
+    int64_t getNumResults() const;
+  }];
+}
+
+#endif // XLA_CODEGEN_IR_XLA_ATTRS
diff --git a/third_party/xla/xla/codegen/ir/xla_dialect.cc b/third_party/xla/xla/codegen/ir/xla_dialect.cc
new file mode 100644
index 00000000000000..5d84acc6da6997
--- /dev/null
+++ b/third_party/xla/xla/codegen/ir/xla_dialect.cc
@@ -0,0 +1,132 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "llvm/ADT/TypeSwitch.h"  // IWYU pragma: keep
+#include "mlir/IR/DialectImplementation.h"  // IWYU pragma: keep
+#include "mlir/IR/OpImplementation.h"  // IWYU pragma: keep
+#include "mlir/Transforms/InliningUtils.h"
+#include "xla/codegen/ir/xla_ops.h"
+
+// The order of these includes is important.
+#define GET_ATTRDEF_CLASSES
+#include "xla/codegen/ir/xla_attrs.cc.inc"
+
+namespace xla {
+namespace {
+
+struct XlaInlinerInterface : public mlir::DialectInlinerInterface {
+  using DialectInlinerInterface::DialectInlinerInterface;
+  // Returns true if the given operation 'callable', that implements the
+  // 'CallableOpInterface', can be inlined into the position given call
+  // operation 'call', that is registered to the current dialect and implements
+  // the `CallOpInterface`. 'wouldBeCloned' is set to true if the region of the
+  // given 'callable' is set to be cloned during the inlining process, or false
+  // if the region is set to be moved in-place (i.e. no duplicates would be
+  // created).
+  bool isLegalToInline(mlir::Operation* call, mlir::Operation* callable,
+                       bool wouldBeCloned) const final {
+    if (!wouldBeCloned) {
+      // If no duplicate would be created, 'call' is likely the only caller of
+      // 'callable'.
+      return true;
+    }
+    // Otherwise, inline only if the called function is small. We could
+    // theoretically also inline if there is no other caller in the function
+    // that contains the callee that has a call path to the callable, but that
+    // is more expensive to check.
+    auto func_op = mlir::dyn_cast<mlir::func::FuncOp>(callable);
+    if (!func_op) {
+      return false;
+    }
+    auto region = func_op.getCallableRegion();
+    if (!region) {
+      return false;
+    }
+
+    // If callee and caller call the same third function, inline. We have no
+    // guarantee that the indices are the same, but there is a good chance they
+    // are (or if the callee gets inlined as well, there will be CSE
+    // opportunities).
+    // This is duct tape to work around the limitations of our partitioner.
+    // Ideally, the partitioner would be aware of the actual indexing and create
+    // the partitions based on it (i.e., the case where the indices are the same
+    // would never happen).
+    llvm::SmallDenseSet<llvm::StringRef> callee_calls;
+    for (auto call : region->getOps<PureCallOp>()) {
+      callee_calls.insert(call.getCallee());
+    }
+    for (auto call : call->getParentRegion()->getOps<PureCallOp>()) {
+      if (callee_calls.contains(call.getCallee())) {
+        return true;
+      }
+    }
+
+    constexpr int kMaxOperationsToInline = 8;
+    int num_ops = 0;
+    region->front().walk([&](mlir::Operation* op) { ++num_ops; });
+
+    // Don't inline functions that are called more than once and contain more
+    // than one call themselves.
+    return num_ops <= kMaxOperationsToInline;
+  }
+  // Returns true if the given operation 'op', that is registered to this
+  // dialect, can be inlined into the given region, false otherwise.
+  // 'wouldBeCloned' is set to true if the given 'op' is set to be cloned
+  // during the inlining process, or false if the operation is set to be moved
+  // in-place(i.e. no duplicates would be created). 'valueMapping' contains any
+  // remapped values from within the 'src' region. This can be used to examine
+  // what values may potentially replace the operands to 'op'.
+  bool isLegalToInline(mlir::Operation* op, mlir::Region* dest,
+                       bool wouldBeCloned,
+                       mlir::IRMapping& valueMapping) const final {
+    // We allow any op from the xla dialect to be inlined.
+    return true;
+  }
+  // We don't have any special restrictions on what can be inlined into
+  // destination regions (e.g. while/conditional bodies). Always allow it.
+  bool isLegalToInline(mlir::Region* dest, mlir::Region* src,
+                       bool wouldBeCloned,
+                       mlir::IRMapping& valueMapping) const final {
+    return true;
+  }
+};
+
+struct XlaOpAsmDialectInterface : public mlir::OpAsmDialectInterface {
+  using OpAsmDialectInterface::OpAsmDialectInterface;
+  AliasResult getAlias(mlir::Attribute attr,
+                       mlir::raw_ostream& os) const final {
+    if (llvm::isa<IndexingMapAttr>(attr)) {
+      os << "indexing_map";
+      return AliasResult::FinalAlias;
+    }
+    return AliasResult::NoAlias;
+  }
+};
+
+}  // namespace
+
+void XlaDialect::initialize() {
+  addOperations<
+#define GET_OP_LIST
+#include "xla/codegen/ir/xla_ops.cc.inc"
+      >();
+  addAttributes<
+#define GET_ATTRDEF_LIST
+#include "xla/codegen/ir/xla_attrs.cc.inc"
+      >();
+  addInterfaces<XlaInlinerInterface, XlaOpAsmDialectInterface>();
+}
+
+}  // namespace xla
diff --git a/third_party/xla/xla/codegen/ir/xla_dialect.td b/third_party/xla/xla/codegen/ir/xla_dialect.td
new file mode 100644
index 00000000000000..33e6a68528d7e8
--- /dev/null
+++ b/third_party/xla/xla/codegen/ir/xla_dialect.td
@@ -0,0 +1,30 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef XLA_CODEGEN_IR_XLA_DIALECT
+#define XLA_CODEGEN_IR_XLA_DIALECT
+
+include "mlir/IR/DialectBase.td"
+
+def XlaDialect : Dialect {
+  let name = "xla";
+  let description = [{
+    This dialect contains ops required for lowering HLO to LLVM.
+  }];
+  let cppNamespace = "::xla";
+  let useDefaultAttributePrinterParser = 1;
+}
+
+#endif // XLA_CODEGEN_IR_XLA_DIALECT
diff --git a/third_party/xla/xla/codegen/ir/xla_ops.cc b/third_party/xla/xla/codegen/ir/xla_ops.cc
new file mode 100644
index 00000000000000..9cc80547c34fd4
--- /dev/null
+++ b/third_party/xla/xla/codegen/ir/xla_ops.cc
@@ -0,0 +1,950 @@
+/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/ + +#include "xla/codegen/ir/xla_ops.h" + +#include +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/STLFunctionalExtras.h" +#include "llvm/ADT/SmallBitVector.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/LogicalResult.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/IR/AffineExpr.h" +#include "mlir/IR/Builders.h" // IWYU pragma: keep +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/DialectImplementation.h" // IWYU pragma: keep +#include "mlir/IR/MLIRContext.h" // IWYU pragma: keep +#include "mlir/IR/OpDefinition.h" +#include "mlir/IR/OpImplementation.h" +#include "mlir/IR/OperationSupport.h" +#include "mlir/IR/PatternMatch.h" // IWYU pragma: keep +#include "mlir/IR/SymbolTable.h" +#include "mlir/IR/TypeRange.h" +#include "mlir/IR/TypeUtilities.h" // IWYU pragma: keep +#include "mlir/IR/Types.h" +#include "mlir/IR/Value.h" +#include "mlir/IR/ValueRange.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/LogicalResult.h" +#include "xla/codegen/ir/xla_dialect.cc.inc" +#include "xla/hlo/analysis/indexing_map.h" +#include "xla/hlo/analysis/indexing_map_serialization.h" + +namespace xla { +namespace { + +using llvm::ArrayRef; +using mlir::AffineExpr; +using mlir::AffineMap; +using mlir::Block; +using mlir::DenseI64ArrayAttr; +using mlir::failure; +using mlir::getAffineConstantExpr; +using mlir::getAffineDimExpr; +using mlir::getAffineSymbolExpr; +using mlir::Location; +using mlir::LogicalResult; +using mlir::MLIRContext; +using mlir::OpAsmParser; +using mlir::OpAsmPrinter; +using mlir::OpBuilder; +using mlir::OperationState; +using mlir::ParseResult; +using mlir::PatternRewriter; +using mlir::RankedTensorType; +using mlir::Region; +using mlir::SmallVector; +using mlir::success; +using mlir::Type; +using mlir::TypeRange; +using mlir::Value; +using mlir::ValueRange; + +namespace arith = mlir::arith; + +} // namespace + +LogicalResult PureCallOp::verifySymbolUses( + mlir::SymbolTableCollection& symbolTable) { + auto callee = getCalleeAttr(); + auto function = + symbolTable.lookupNearestSymbolFrom(*this, callee); + if (!function) { + return emitError("'f' attribute refers to an undefined function: ") + << callee; + } + + int func_arg_count = function.getFunctionType().getNumInputs(); + int arg_count = getOperands().size(); + + if (arg_count != func_arg_count) { + return emitError() << "argument count mismatch: 'operands' has " + << arg_count << " arguments, but '" << callee + << "' expects " << func_arg_count; + } + + return success(); +} + +//===----------------------------------------------------------------------===// +// ApplyIndexingOp +//===----------------------------------------------------------------------===// + +void ApplyIndexingOp::build(OpBuilder& builder, OperationState& result, + ValueRange dims, ValueRange symbols, + const IndexingMap& indexing_map) { + SmallVector operands; + operands.reserve(dims.size() + symbols.size()); + operands.append(dims.begin(), dims.end()); + operands.append(symbols.begin(), symbols.end()); + build(builder, result, operands, indexing_map); +} + +void ApplyIndexingOp::build(OpBuilder& builder, OperationState& result, + ValueRange operands, + const IndexingMap& indexing_map) { + SmallVector 
result_types(indexing_map.GetAffineMap().getNumResults(), + builder.getIndexType()); + IndexingMapAttr indexing_map_attr = + IndexingMapAttr::get(builder.getContext(), indexing_map); + build(builder, result, result_types, operands, indexing_map_attr); +} + +void ApplyIndexingOp::build(OpBuilder& builder, OperationState& result, + ValueRange operands, AffineMap affine_map, + ArrayRef dim_vars, + ArrayRef range_vars) { + IndexingMap indexing_map(affine_map, dim_vars, range_vars, {}); + build(builder, result, operands, indexing_map); +} + +ParseResult ApplyIndexingOp::parse(OpAsmParser& parser, + OperationState& result) { + mlir::Builder& builder = parser.getBuilder(); + auto index_type = builder.getIndexType(); + + IndexingMapAttr indexing_map_attr; + if (parser.parseAttribute(indexing_map_attr, "indexing_map_attr", + result.attributes)) { + return failure(); + } + + SmallVector operands; + SmallVector lower_bounds, upper_bounds; + if (succeeded(parser.parseOptionalLParen())) { + if (parseOperands(parser, &operands) || parser.parseRParen()) { + return failure(); + } + } + if (succeeded(parser.parseOptionalLSquare())) { + if (parseOperands(parser, &operands) || parser.parseRSquare()) { + return failure(); + } + } + if (parser.resolveOperands(operands, index_type, result.operands) || + parser.parseOptionalAttrDict(result.attributes)) { + return failure(); + } + auto map = indexing_map_attr.getIndexingMap().GetAffineMap(); + result.addTypes(SmallVector(map.getNumResults(), index_type)); + return success(); +} + +void ApplyIndexingOp::print(OpAsmPrinter& p) { + AffineMap affine_map = getIndexingMapAttr().getIndexingMap().GetAffineMap(); + p << " " << getIndexingMapAttr(); + + auto operands = getOperands(); + unsigned num_dimensions = affine_map.getNumDims(); + if (num_dimensions > 0) { + p << '('; + auto dimension_operands = operands.slice(0, num_dimensions); + llvm::interleaveComma(dimension_operands, p); + p << ')'; + } + + unsigned num_symbols = affine_map.getNumSymbols(); + if (num_symbols > 0) { + p << '['; + auto symbol_operands = operands.slice(num_dimensions, num_symbols); + llvm::interleaveComma(symbol_operands, p); + p << ']'; + } + + p.printOptionalAttrDict((*this)->getAttrs(), + /*elidedAttrs=*/{"indexing_map_attr"}); +} + +LogicalResult ApplyIndexingOp::verify() { + auto affine_map = getIndexingMapAttr().getIndexingMap().GetAffineMap(); + unsigned num_variables = affine_map.getNumDims() + affine_map.getNumSymbols(); + if (getOperands().size() != num_variables) { + return emitOpError( + "operand count must match the number of dimensions and symbols in the " + "affine map"); + } + if (!getIndexingMap().GetConstraints().empty()) { + return emitOpError("apply indexing op cannot have any constraints"); + } + return success(); +} + +IndexingMap ApplyIndexingOp::getIndexingMap() { + return getIndexingMapAttr().getIndexingMap(); +} + +namespace { +struct IndexingMapWithAdditions { + IndexingMap indexing_map; + SmallVector added_dim_args; +}; + +// Returns the new indexing map after replacing ops. +// Only supports replacing operands that are dimensions. 
+absl::StatusOr GetNewIndexingMapAfterFoldingSequence( + IndexingMap indexing_map, + SmallVector, 2> apply_indexing_ops, + mlir::DenseMap operand_exprs, MLIRContext* ctx) { + int num_dims = indexing_map.GetDimensionCount(); + + SmallVector added_dim_args; + auto new_dim_vars = indexing_map.GetDimVars(); + + mlir::DenseMap replacements; + for (auto& [operand_id, producer] : apply_indexing_ops) { + auto producer_map = producer.getIndexingMap(); + mlir::OpResult producer_result = producer->getOpResult(0); + int producer_result_id = producer_result.getResultNumber(); + int num_producer_dims = producer.getAffineMap().getNumDims(); + SmallVector producer_dim_replacements; + for (auto& producer_operand : producer->getOpOperands()) { + int producer_operand_number = producer_operand.getOperandNumber(); + if (producer_operand_number >= num_producer_dims) { + // MoveSymbolsToDims will convert symbols to dimensions later. + return absl::Status( + absl::StatusCode::kInvalidArgument, + absl::StrCat("In apply_indexing_op ", operand_id, + " producer has operand number ", + producer_operand_number, " outside of dimensions [0,", + num_producer_dims, ")")); + } + auto& replacement_expr = operand_exprs[producer_operand.get()]; + if (!replacement_expr) { + int dim_num = producer_operand_number; + replacement_expr = + getAffineDimExpr(num_dims + added_dim_args.size(), ctx); + added_dim_args.push_back(producer_operand.get()); + new_dim_vars.push_back(producer_map.GetDimVars(dim_num)); + } + producer_dim_replacements.push_back(replacement_expr); + } + replacements[operand_exprs[producer_result]] = + producer.getAffineMap() + .getResult(producer_result_id) + .replaceDims(producer_dim_replacements); + } + + auto new_affine_map = indexing_map.GetAffineMap().replace( + replacements, num_dims + added_dim_args.size(), + indexing_map.GetSymbolCount()); + IndexingMap new_indexing_map(new_affine_map, new_dim_vars, + indexing_map.GetRangeVars(), + indexing_map.GetRTVars()); + + return IndexingMapWithAdditions{new_indexing_map, added_dim_args}; +} +} // namespace + +namespace { + +// Simplifies the indexing map. +struct SimplifyIndexingMap : public mlir::OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(ApplyIndexingOp indexing_op, + PatternRewriter& rewriter) const override { + IndexingMap indexing_map = indexing_op.getIndexingMap(); + if (!indexing_map.Simplify()) { + return rewriter.notifyMatchFailure(indexing_op, + "IndexingMap is already simplified"); + } + rewriter.replaceOpWithNewOp( + indexing_op, indexing_op.getOperands(), indexing_map); + return success(); + } +}; + +// Removes unused variables. 
+struct RemoveUnusedVariables : public mlir::OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(ApplyIndexingOp indexing_op, + PatternRewriter& rewriter) const override { + IndexingMap indexing_map = indexing_op.getIndexingMap(); + auto unused_symbols_bit_vector = indexing_map.RemoveUnusedVars(); + if (unused_symbols_bit_vector.count() == 0) { + return rewriter.notifyMatchFailure(indexing_op, + "IndexingMap stayed unchanged"); + } + SmallVector operands; + operands.reserve(unused_symbols_bit_vector.count()); + for (int i = 0; i < unused_symbols_bit_vector.size(); ++i) { + if (!unused_symbols_bit_vector[i]) { + operands.push_back(indexing_op.getOperand(i)); + } + } + rewriter.replaceOpWithNewOp(indexing_op, operands, + indexing_map); + return success(); + } +}; + +struct MoveSymbolsToDims : public mlir::OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(ApplyIndexingOp indexing_op, + PatternRewriter& rewriter) const override { + IndexingMap indexing_map = indexing_op.getIndexingMap(); + if (indexing_map.GetSymbolCount() == 0) { + return rewriter.notifyMatchFailure(indexing_op, "No symbols found"); + } + rewriter.replaceOpWithNewOp( + indexing_op, indexing_op->getOperands(), + indexing_map.ConvertSymbolsToDimensions()); + return success(); + } +}; + +struct FoldApplyIndexingSequence + : public mlir::OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(ApplyIndexingOp indexing_op, + PatternRewriter& rewriter) const override { + auto indexing_map = indexing_op.getIndexingMap(); + SmallVector, 2> apply_indexing_ops; + bool all_apply_indexing_operands_have_one_use = true; + for (auto& operand : indexing_op->getOpOperands()) { + if (auto producer = operand.get().getDefiningOp()) { + apply_indexing_ops.push_back({operand.getOperandNumber(), producer}); + all_apply_indexing_operands_have_one_use &= producer->hasOneUse(); + } + } + if (apply_indexing_ops.empty()) { + return rewriter.notifyMatchFailure(indexing_op, + "No apply_indexing sequences found"); + } + // If the indexing map has unused variables, we can accidentally fuse an + // operand that is not used in the map and it can lead to an infinite loop + // in canonicalizer. + auto indexing_map_with_no_unused_vars = indexing_map; + if (indexing_map_with_no_unused_vars.RemoveUnusedVars().count() > 0) { + indexing_map_with_no_unused_vars.RemoveUnusedVars(); + return rewriter.notifyMatchFailure(indexing_op, + "IndexingMap has unused variables"); + } + MLIRContext* ctx = indexing_op.getContext(); + int num_dims = indexing_op.getAffineMap().getNumDims(); + int num_syms = indexing_op.getAffineMap().getNumSymbols(); + mlir::DenseMap operand_exprs; + for (auto& operand : indexing_op->getOpOperands()) { + int operand_number = operand.getOperandNumber(); + operand_exprs[operand.get()] = + operand_number < num_dims + ? 
getAffineDimExpr(operand_number, ctx) + : getAffineSymbolExpr(operand_number - num_dims, ctx); + } + + auto replacement = GetNewIndexingMapAfterFoldingSequence( + indexing_map, apply_indexing_ops, operand_exprs, ctx); + + if (!replacement.ok()) { + return rewriter.notifyMatchFailure(indexing_op, + replacement.status().message()); + } + + if (!all_apply_indexing_operands_have_one_use && + !replacement->indexing_map.Simplify()) { + return rewriter.notifyMatchFailure( + indexing_op, "Folded indexing map was not simplified"); + } + + int new_num_operands = + indexing_op->getNumOperands() + replacement->added_dim_args.size(); + SmallVector new_operands; + new_operands.reserve(new_num_operands); + + auto begin = indexing_op.getOperands().begin(); + new_operands.append(begin, begin + num_dims); + new_operands.append(replacement->added_dim_args); + new_operands.append(begin + num_dims, begin + num_dims + num_syms); + + rewriter.replaceOpWithNewOp(indexing_op, new_operands, + replacement->indexing_map); + + return success(); + } +}; + +// Folds constants into the indexing map. +struct FoldApplyIndexingOperands + : public mlir::OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(ApplyIndexingOp indexing_op, + PatternRewriter& rewriter) const override { + IndexingMap indexing_map = indexing_op.getIndexingMap(); + AffineMap affine_map = indexing_map.GetAffineMap(); + + MLIRContext* ctx = affine_map.getContext(); + unsigned num_operands = indexing_op->getNumOperands(); + unsigned num_dims = affine_map.getNumDims(); + unsigned num_symbols = affine_map.getNumSymbols(); + + SmallVector> constant_values(num_operands, + std::nullopt); + int num_constants = 0; + for (auto& operand : indexing_op->getOpOperands()) { + if (auto constant = + operand.get().getDefiningOp()) { + constant_values[operand.getOperandNumber()] = constant.value(); + ++num_constants; + } + } + if (num_constants == 0) { + return rewriter.notifyMatchFailure(indexing_op, + "No constant operands found"); + } + SmallVector dim_replacements, symbol_replacements; + dim_replacements.reserve(num_dims); + symbol_replacements.reserve(num_symbols); + + unsigned new_num_operands = indexing_op->getNumOperands() - num_constants; + SmallVector new_operands; + new_operands.reserve(new_num_operands); + SmallVector new_dim_vars; + new_dim_vars.reserve(num_dims); + + unsigned new_num_dims = 0; + for (auto [operand, constant_value] : + llvm::zip(indexing_op->getOpOperands(), constant_values)) { + unsigned operand_id = operand.getOperandNumber(); + if (operand_id >= num_dims) { + return rewriter.notifyMatchFailure( + indexing_op, + absl::StrCat("operand ", operand_id, + " is outside of dimension range [0, ", num_dims, ")")); + } + if (constant_value.has_value()) { + dim_replacements.push_back(getAffineConstantExpr(*constant_value, ctx)); + } else { + new_operands.push_back(operand.get()); + dim_replacements.push_back(getAffineDimExpr(new_num_dims++, ctx)); + new_dim_vars.push_back(indexing_map.GetDimVars(operand_id)); + } + } + rewriter.replaceOpWithNewOp( + indexing_op, new_operands, + affine_map.replaceDimsAndSymbols(dim_replacements, {}, new_num_dims, 0), + new_dim_vars, SmallVector()); + return success(); + } +}; + +// Folds constant and dim/symbol expression results. 
+struct FoldApplyIndexingResults + : public mlir::OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(ApplyIndexingOp indexing_op, + PatternRewriter& rewriter) const override { + Location loc = indexing_op.getLoc(); + IndexingMap indexing_map = indexing_op.getIndexingMap(); + if (indexing_map.IsKnownEmpty()) { + return rewriter.notifyMatchFailure(indexing_op, + "Domain of the indexing map is empty"); + } + AffineMap* affine_map = &indexing_map.GetMutableAffineMap(); + unsigned num_results = affine_map->getNumResults(); + SmallVector new_exprs; + new_exprs.reserve(num_results); + SmallVector new_values; + new_values.reserve(num_results); + for (mlir::OpResult opresult : indexing_op->getOpResults()) { + if (opresult.use_empty()) { + new_values.push_back(rewriter.create(loc, 0)); + continue; + } + + unsigned id = opresult.getResultNumber(); + AffineExpr result_expr = affine_map->getResult(id); + if (auto const_expr = + mlir::dyn_cast(result_expr)) { + new_values.push_back(rewriter.create( + loc, const_expr.getValue())); + continue; + } + if (auto dim_expr = mlir::dyn_cast(result_expr)) { + new_values.push_back(indexing_op.getOperand(dim_expr.getPosition())); + continue; + } + if (auto symbol_expr = + mlir::dyn_cast(result_expr)) { + new_values.push_back(indexing_op.getOperand( + indexing_map.GetDimVarsCount() + symbol_expr.getPosition())); + continue; + } + new_exprs.push_back(result_expr); + new_values.push_back(Value{}); + } + if (new_exprs.size() == num_results) { + return rewriter.notifyMatchFailure( + indexing_op, "No constant or dim/symbol expression found"); + } + *affine_map = + AffineMap::get(affine_map->getNumDims(), affine_map->getNumSymbols(), + new_exprs, affine_map->getContext()); + auto new_indexing_op = rewriter.create( + loc, indexing_op.getOperands(), indexing_map); + for (int new_result_id = 0, new_indexing_op_result_id = 0; + new_result_id < new_values.size(); ++new_result_id) { + auto& new_value = new_values[new_result_id]; + if (new_value) continue; + new_value = new_indexing_op.getResult(new_indexing_op_result_id++); + } + rewriter.replaceOp(indexing_op, new_values); + return success(); + } +}; + +} // namespace + +void ApplyIndexingOp::getCanonicalizationPatterns( + mlir::RewritePatternSet& results, MLIRContext* context) { + results.add(context); +} + +mlir::LogicalResult ApplyIndexingOp::fold( + FoldAdaptor adaptor, llvm::SmallVectorImpl& results) { + auto map = getAffineMap(); + for (auto expr : map.getResults()) { + if (auto dim = mlir::dyn_cast(expr)) { + results.push_back(getOperand(dim.getPosition())); + } else if (auto sym = mlir::dyn_cast(expr)) { + results.push_back(getOperand(map.getNumDims() + sym.getPosition())); + } else { + results.clear(); + return failure(); + } + } + return success(); +} + +//===----------------------------------------------------------------------===// +// AtomicRMWOp +//===----------------------------------------------------------------------===// + +void AtomicRMWOp::getAsmResultNames( + llvm::function_ref setNameFn) { + setNameFn(getResult(), "atomic_rmw"); +} + +void AtomicRMWOp::build(OpBuilder& builder, OperationState& result, + Value tensor, ValueRange ivs) { + OpBuilder::InsertionGuard g(builder); + result.addOperands(tensor); + result.addOperands(ivs); + result.addTypes(tensor.getType()); + + auto tensor_type = llvm::cast(tensor.getType()); + Region* body = result.addRegion(); + builder.createBlock(body); + body->addArgument(tensor_type.getElementType(), tensor.getLoc()); +} + 
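+// Folding example (illustrative only; operand names are hypothetical): an
+// atomic_rmw whose body is just the terminator yielding the current value
+// unchanged, e.g.
+//   %ret = xla.atomic_rmw %t[%i] : tensor<8xf32> {
+//   ^bb0(%cur: f32):
+//     xla.yield %cur : f32
+//   }
+// performs no update, so the folder below replaces %ret with %t.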
+mlir::OpFoldResult AtomicRMWOp::fold(FoldAdaptor adaptor) {
+  auto* body = getBody();
+  if (&body->front() == body->getTerminator() &&
+      body->front().getOperand(0) == body->getArgument(0)) {
+    return getOperand(0);
+  }
+  return {};
+}
+
+//===----------------------------------------------------------------------===//
+// PureCallOp
+//===----------------------------------------------------------------------===//
+
+void PureCallOp::getAsmResultNames(
+    llvm::function_ref<void(mlir::Value, mlir::StringRef)> setNameFn) {
+  for (auto result : getResults()) {
+    setNameFn(result, "pure_call");
+  }
+}
+
+//===----------------------------------------------------------------------===//
+// LoopOp
+//===----------------------------------------------------------------------===//
+
+void LoopOp::getAsmResultNames(
+    llvm::function_ref<void(mlir::Value, mlir::StringRef)> setNameFn) {
+  for (auto result : getResults()) {
+    setNameFn(result, "xla_loop");
+  }
+}
+
+void LoopOp::getAsmBlockArgumentNames(mlir::Region& region,
+                                      mlir::OpAsmSetValueNameFn setFn) {
+  // i, j, k, l, m, n, n_0, n_1, ...
+  char iv_name = 'i';
+  for (auto iv : getInductionVars()) {
+    setFn(iv, std::string{iv_name});
+    if (iv_name <= 'n') {
+      ++iv_name;
+    }
+  }
+  // ra, rb, rc, rd, ..., rz, rz_0, ...
+  std::string map_result_name = "ra";
+  char map_result_char = 'a';
+  for (auto map_result : getIndexingMapResults()) {
+    setFn(map_result, map_result_name);
+    if (map_result_char <= 'z') {
+      ++map_result_char;
+      map_result_name[1] = map_result_char;
+    }
+  }
+  for (auto iv : getRegionIterArgs()) {
+    setFn(iv, "iter");
+  }
+}
+
+void LoopOp::build(OpBuilder& builder, OperationState& result,
+                   IndexingMapAttr indexing_map_attr, ValueRange dims,
+                   ValueRange inits, BodyBuilderFn bodyBuilder) {
+  OpBuilder::InsertionGuard guard(builder);
+
+  int64_t num_ivs = indexing_map_attr.getRangeVars().size();
+  int64_t num_indexing_map_results =
+      indexing_map_attr.getIndexingMap().GetNumResults();
+  int64_t num_inits = inits.size();
+  result.addOperands(dims);
+  result.addOperands(inits);
+  result.addTypes(TypeRange(inits));
+  Block* body_block = builder.createBlock(result.addRegion());
+  // Add induction variables and indexing map results block args.
+  for (int i = 0, e = num_ivs + num_indexing_map_results; i < e; ++i) {
+    body_block->addArgument(builder.getIndexType(), result.location);
+  }
+  // Add iteration arguments block args.
+ for (auto init_type : TypeRange(inits)) { + body_block->addArguments(init_type, result.location); + } + mlir::OperationName opname(LoopOp::getOperationName(), builder.getContext()); + result.addAttribute(LoopOp::getIndexingMapAttrAttrName(opname), + indexing_map_attr); + result.addAttribute( + LoopOp::getOperandSegmentSizesAttrName(opname), + builder.getDenseI32ArrayAttr({static_cast(dims.size()), + static_cast(inits.size())})); + if (bodyBuilder) { + OpBuilder::InsertionGuard guard(builder); + builder.setInsertionPointToStart(body_block); + bodyBuilder( + builder, result.location, + body_block->getArguments().take_front(num_ivs), + body_block->getArguments().drop_front(num_ivs).drop_back(num_inits), + body_block->getArguments().take_back(num_inits)); + } +} + +void LoopOp::build(OpBuilder& builder, OperationState& result, + const IndexingMap& indexing_map, ValueRange dims, + ValueRange inits, BodyBuilderFn bodyBuilder) { + build(builder, result, + IndexingMapAttr::get(builder.getContext(), indexing_map), dims, inits, + bodyBuilder); +} + +ParseResult LoopOp::parse(OpAsmParser& parser, OperationState& result) { + SmallVector region_args, ivs, map_results, + iter_args; + SmallVector dim_operands; + + // Parse the dimension values. + auto* ctx = parser.getContext(); + OpBuilder b(ctx); + Type index_type = b.getIndexType(); + if (parser.parseOperandList(dim_operands, OpAsmParser::Delimiter::Paren) || + parser.resolveOperands(dim_operands, index_type, result.operands)) + return failure(); + // Parse the induction variables. + if (parser.parseArgumentList(ivs, OpAsmParser::Delimiter::Square)) + return failure(); + for (auto iv : ivs) { + region_args.push_back(iv); + region_args.back().type = index_type; + } + + // Parse the indexing map results variables. + if (parser.parseArrow() || + parser.parseArgumentList(map_results, OpAsmParser::Delimiter::Paren)) + return failure(); + for (auto map_result : map_results) { + region_args.push_back(map_result); + region_args.back().type = index_type; + } + + // Parse the indexing map attribute. + IndexingMapAttr indexing_map_attr; + if (parser.parseKeyword("in") || + parser.parseAttribute(indexing_map_attr, "indexing_map_attr", + result.attributes)) { + return failure(); + } + + // Parse the arguments. + SmallVector init_operands; + if (parser.parseKeyword("iter_args") || + parser.parseAssignmentList(iter_args, init_operands) || + parser.parseArrowTypeList(result.types) || + parser.resolveOperands(init_operands, result.types, parser.getNameLoc(), + result.operands)) + return failure(); + + for (auto [index, iter_arg] : llvm::enumerate(iter_args)) { + region_args.push_back(iter_arg); + region_args.back().type = result.types[index]; + } + + if (region_args.size() != + result.types.size() + ivs.size() + map_results.size()) { + return parser.emitError( + parser.getNameLoc(), + "mismatch in number of induction variables + loop-carried values + " + "number of indexing map results variables and the number of results"); + } + + // Parse the body region. + Region* body = result.addRegion(); + if (parser.parseRegion(*body, region_args)) return failure(); + LoopOp::ensureTerminator(*body, b, result.location); + + // Add the necessary attributes. 
+ result.addAttribute( + LoopOp::getOperandSegmentSizeAttr(), + b.getDenseI32ArrayAttr({static_cast(dim_operands.size()), + static_cast(iter_args.size())})); + + // Parse the optional attribute list + if (parser.parseOptionalAttrDict(result.attributes)) return failure(); + + return success(); +} + +void LoopOp::print(OpAsmPrinter& p) { + p << " (" << getDims() << ")[" << getInductionVars() << "] -> (" + << getIndexingMapResults() << ") in " << getIndexingMapAttr() + << " iter_args("; + llvm::interleaveComma( + llvm::zip(getRegionIterArgs(), getInits()), p, + [&](auto it) { p << std::get<0>(it) << " = " << std::get<1>(it); }); + p << ") -> (" << getInits().getTypes() << ") "; + p.printRegion(getRegion(), /*printEntryBlockArgs=*/false, + /*printBlockTerminators=*/true); + p.printOptionalAttrDict((*this)->getAttrs(), + /*elidedAttrs=*/{ + getIndexingMapAttrAttrName(), + getOperandSegmentSizesAttrName(), + }); +} + +LogicalResult LoopOp::verify() { + if (getInits().size() != getNumResults()) { + return emitOpError("mismatch in number of loop-carried values and results"); + } + IndexingMap indexing_map = getIndexingMap(); + if (indexing_map.GetRangeVarsCount() != getNumInductionVars()) { + return emitOpError() << "mismatch in number of induction variables " + << getNumInductionVars() + << " and RangeVars in the indexing map " + << ToString(indexing_map); + } + if (indexing_map.GetDimVarsCount() != getDims().size()) { + return emitOpError() << "mismatch in number of dims operands " + << getDims().size() + << " and DimVars in the indexing map " + << ToString(indexing_map); + } + for (auto [bb_arg, result_type, init] : + llvm::zip(getRegionIterArgs(), getResultTypes(), getInits())) { + if (bb_arg.getType() != result_type || init.getType() != result_type) { + return emitOpError() << "block iter arg type = " << bb_arg.getType() + << ", result type = " << result_type + << " and init operand type = " << init.getType() + << " should match"; + } + } + return success(); +} + +IndexingMap LoopOp::getIndexingMap() { + return getIndexingMapAttr().getIndexingMap(); +} + +namespace { +// This pattern is tasked to simplify loop(apply_indexing, ...) patterns to fold +// the apply_indexing ops into the loop op where it can. The pattern assumes +// that the apply_indexing ops have already been simplified via +// MoveSymbolsToDims pattern, which basically means that the producer +// apply_indexing ops should not have any symbols. +struct SimplifyLoopOfApplyIndexing : public mlir::OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(LoopOp loop_op, + PatternRewriter& rewriter) const override { + auto loop_indexing_map = loop_op.getIndexingMap(); + MLIRContext* ctx = loop_op.getContext(); + int num_dims = loop_indexing_map.GetDimVarsCount(); + + SmallVector, 2> apply_indexing_ops; + bool all_apply_indexing_operands_have_one_use = true; + + // Only consider dims. + for (auto& operand : loop_op->getOpOperands().take_front(num_dims)) { + if (auto producer = operand.get().getDefiningOp()) { + // Producer should be canonicalized via MoveSymbolsToDims pattern. + if (producer.getIndexingMap().GetSymbolCount() > 0) { + continue; + } + apply_indexing_ops.push_back({operand.getOperandNumber(), producer}); + all_apply_indexing_operands_have_one_use &= producer->hasOneUse(); + } + } + if (apply_indexing_ops.empty()) { + return rewriter.notifyMatchFailure( + loop_op, + "No loop(apply_indexing) patterns found. 
Note that producer " + "apply_indexing should have already been simplified via " + "MoveSymbolsToDims pattern."); + } + + mlir::DenseMap operand_exprs; + for (auto& operand : loop_op->getOpOperands().take_front(num_dims)) { + int operand_number = operand.getOperandNumber(); + operand_exprs[operand.get()] = getAffineDimExpr(operand_number, ctx); + } + + auto replacement = GetNewIndexingMapAfterFoldingSequence( + loop_indexing_map, apply_indexing_ops, operand_exprs, ctx); + + if (!replacement.ok()) { + return rewriter.notifyMatchFailure(loop_op, + replacement.status().message()); + } + + if (!all_apply_indexing_operands_have_one_use && + !replacement->indexing_map.Simplify()) { + return rewriter.notifyMatchFailure( + loop_op, "Folded indexing map of the loop op was not simplified"); + } + + int new_num_dims = num_dims + replacement->added_dim_args.size(); + SmallVector aggregate_dims; + aggregate_dims.reserve(new_num_dims); + + auto begin = loop_op.getOperands().begin(); + aggregate_dims.append(begin, begin + num_dims); + aggregate_dims.append(replacement->added_dim_args); + + // Remove unused dims. + SmallVector used_dims; + used_dims.reserve(aggregate_dims.size()); + auto used_dim_bit_vector = ~replacement->indexing_map.RemoveUnusedVars(); + for (auto used_dim_idx : used_dim_bit_vector.set_bits()) { + if (used_dim_idx < new_num_dims) { + used_dims.push_back(aggregate_dims[used_dim_idx]); + } + } + + auto new_loop_op = + rewriter.create(loop_op.getLoc(), replacement->indexing_map, + used_dims, loop_op.getInits()); + Block* original_block = &loop_op.getRegion().front(); + Block* new_block = &new_loop_op.getRegion().front(); + rewriter.mergeBlocks(original_block, new_block, new_block->getArguments()); + rewriter.replaceOp(loop_op, new_loop_op.getResults()); + + return success(); + } +}; + +} // namespace + +VariableConstraints GetConstraintsForVariables(const IndexingMap& map) { + VariableConstraints result; + result.constraints_for_dims.resize(map.GetDimensionCount()); + result.constraints_for_symbols.resize(map.GetSymbolCount()); + for (const auto& constraint : map.GetConstraints()) { + constraint.first.walk([&](mlir::AffineExpr leaf) { + if (auto dim = mlir::dyn_cast(leaf)) { + result.constraints_for_dims[dim.getPosition()].insert(constraint); + } else if (auto sym = mlir::dyn_cast(leaf)) { + result.constraints_for_symbols[sym.getPosition()].insert(constraint); + } + }); + } + return result; +} + +// Parses a comma-separated list of operands, ex: %d1, %d2. 
+ParseResult parseOperands( + OpAsmParser& parser, + SmallVector* operands) { + OpAsmParser::UnresolvedOperand operand; + return parser.parseCommaSeparatedList( + [&]() { return parser.parseOperand(operands->emplace_back()); }); +} + +std::optional parseChainOfStringsAsIndexingMap( + mlir::AsmParser& parser) { + mlir::StringAttr indexing_map_attr; + std::string indexing_map_str; + while (parser.parseOptionalAttribute(indexing_map_attr).has_value()) { + indexing_map_str.append(indexing_map_attr.getValue()); + } + return ParseIndexingMap(indexing_map_str, parser.getContext()); +} + +void LoopOp::getCanonicalizationPatterns(mlir::RewritePatternSet& results, + MLIRContext* context) { + results.add(context); +} + +} // namespace xla + +#define GET_OP_CLASSES +#include "xla/codegen/ir/xla_ops.cc.inc" diff --git a/third_party/xla/xla/codegen/ir/xla_ops.h b/third_party/xla/xla/codegen/ir/xla_ops.h new file mode 100644 index 00000000000000..0888d0485567b3 --- /dev/null +++ b/third_party/xla/xla/codegen/ir/xla_ops.h @@ -0,0 +1,66 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef XLA_CODEGEN_IR_XLA_OPS_H_ +#define XLA_CODEGEN_IR_XLA_OPS_H_ + +#include + +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/SmallVector.h" +#include "mlir/Bytecode/BytecodeOpInterface.h" // IWYU pragma: keep +#include "mlir/Dialect/Func/IR/FuncOps.h" // IWYU pragma: keep +#include "mlir/IR/Attributes.h" // IWYU pragma: keep +#include "mlir/IR/BuiltinTypes.h" // IWYU pragma: keep +#include "mlir/IR/Dialect.h" // IWYU pragma: keep +#include "mlir/IR/MLIRContext.h" // IWYU pragma: keep +#include "mlir/IR/OpDefinition.h" // IWYU pragma: keep +#include "mlir/IR/OpImplementation.h" // IWYU pragma: keep +#include "mlir/Interfaces/CallInterfaces.h" // IWYU pragma: keep +#include "mlir/Interfaces/InferTypeOpInterface.h" // IWYU pragma: keep +#include "mlir/Interfaces/SideEffectInterfaces.h" // IWYU pragma: keep +#include "xla/codegen/ir/xla_dialect.h.inc" +#include "xla/hlo/analysis/indexing_map.h" // IWYU pragma: keep +#define GET_ATTRDEF_CLASSES +#include "xla/codegen/ir/xla_attrs.h.inc" +#define GET_OP_CLASSES +#include "xla/codegen/ir/xla_ops.h.inc" + +namespace xla { + +struct VariableConstraints { + llvm::SmallVector> + constraints_for_dims; + llvm::SmallVector> + constraints_for_symbols; +}; +VariableConstraints GetConstraintsForVariables(const IndexingMap& map); + +// Parses a comma-separated list of operands, ex: %d1, %d2. +mlir::ParseResult parseOperands( + mlir::OpAsmParser& parser, + mlir::SmallVector* operands); + +// Parses a chain of string attributes into an indexing map. +// Example: +// "()[s0, s1] -> (1 + s0 + s1 mod 3 - s1, s0 mod 2)," +// " domain: s0 in [-10, 10], s1 in [0, 2]" +// will be parsed as 3 StringAttrs, concatenated into a single string, and then +// parsed into an IndexingMap. 
+std::optional<IndexingMap> parseChainOfStringsAsIndexingMap(
+    mlir::AsmParser& parser);
+
+}  // namespace xla
+
+#endif  // XLA_CODEGEN_IR_XLA_OPS_H_
diff --git a/third_party/xla/xla/codegen/ir/xla_ops.td b/third_party/xla/xla/codegen/ir/xla_ops.td
new file mode 100644
index 00000000000000..3e4702579181c5
--- /dev/null
+++ b/third_party/xla/xla/codegen/ir/xla_ops.td
@@ -0,0 +1,283 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef XLA_CODEGEN_IR_XLA_OPS
+#define XLA_CODEGEN_IR_XLA_OPS
+
+include "mlir/IR/AttrTypeBase.td"
+include "mlir/IR/OpBase.td"
+include "mlir/IR/SymbolInterfaces.td"
+include "mlir/IR/OpAsmInterface.td"
+include "mlir/Interfaces/CallInterfaces.td"
+include "mlir/Interfaces/InferTypeOpInterface.td"
+include "mlir/Interfaces/SideEffectInterfaces.td"
+include "mlir/Interfaces/ControlFlowInterfaces.td"
+include "xla/codegen/ir/xla_dialect.td"
+include "xla/codegen/ir/xla_attrs.td"
+
+class XLA_Op<string mnemonic, list<Trait> traits = []> :
+    Op<XlaDialect, mnemonic, traits> {
+}
+
+def XLA_AtomicRMWOp : XLA_Op<"atomic_rmw",
+    [Pure,
+     TypesMatchWith<"result type matches type of dest",
+                    "input", "result", "$_self">,
+     DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>
+  ]> {
+  let summary = "Atomically updates an element of a tensor.";
+
+  let description = [{
+    Reads an element from a tensor, computes the updated value for it, and
+    writes back the result.
+  }];
+
+  let arguments = (ins AnyRankedTensor:$input, Variadic<Index>:$indices);
+  let results = (outs AnyRankedTensor:$result);
+  // The region takes the current value in the tensor as an argument and yields
+  // the updated value.
+  let regions = (region SizedRegion<1>:$computation);
+
+  let skipDefaultBuilders = 1;
+  let builders = [OpBuilder<(ins "mlir::Value":$memref, "mlir::ValueRange":$ivs)>];
+
+  let extraClassDeclaration = [{
+    mlir::Block* getBody() { return &getComputation().front(); }
+    mlir::OpBuilder getBodyBuilder() {
+      return mlir::OpBuilder(getBody(), std::prev(getBody()->end()));
+    }
+    // The value stored in tensor[ivs].
+    mlir::Value getCurrentValue() {
+      return getRegion().getArgument(0);
+    }
+  }];
+  let hasFolder = 1;
+
+  let assemblyFormat = [{
+    $input `[` $indices `]` `:` type($input) $computation attr-dict
+  }];
+}
+
+def XLA_YieldOp : XLA_Op<"yield", [
+    ParentOneOf<["::xla::AtomicRMWOp", "::xla::LoopOp"]>,
+    ReturnLike, Terminator]> {
+  let summary = "Terminator for the regions of atomic_rmw and loop ops.";
+  let arguments = (ins Variadic<AnyType>:$result);
+
+  let builders = [OpBuilder<(ins), [{ /* nothing to do */ }]>];
+  let assemblyFormat = "$result attr-dict `:` type($result)";
+}
+
+def XLA_PureCallOp : XLA_Op<"pure_call",
+    [Pure, CallOpInterface,
+     DeclareOpInterfaceMethods<SymbolUserOpInterface>,
+     DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>
+  ]> {
+  let summary = "Function call without side effects.";
+  let arguments = (ins FlatSymbolRefAttr:$callee, Variadic<AnyType>:$operands);
+  let results = (outs Variadic<AnyType>);
+  let builders = [
+    OpBuilder<(ins "mlir::func::FuncOp":$callee, CArg<"mlir::ValueRange", "{}">:$operands), [{
+      $_state.addOperands(operands);
+      $_state.addAttribute("callee", mlir::SymbolRefAttr::get(callee));
+      $_state.addTypes(callee.getFunctionType().getResults());
+    }]>,
+    OpBuilder<(ins "mlir::FlatSymbolRefAttr":$callee, CArg<"mlir::ValueRange", "{}">:$operands, CArg<"llvm::ArrayRef<mlir::Type>", "{}">:$result_types), [{
+      $_state.addOperands(operands);
+      $_state.addAttribute("callee", callee);
+      $_state.addTypes(result_types);
+    }]>,
+  ];
+  let assemblyFormat = [{
+    $callee `(` $operands `)` attr-dict `:` functional-type($operands, results)
+  }];
+
+  let extraClassDeclaration = [{
+    operand_range getArgOperands() {
+      return getOperands();
+    }
+
+    mlir::MutableOperandRange getArgOperandsMutable() {
+      return getOperandsMutable();
+    }
+
+    mlir::CallInterfaceCallable getCallableForCallee() {
+      return (*this)->getAttrOfType<mlir::SymbolRefAttr>("callee");
+    }
+
+    void setCalleeFromCallable(mlir::CallInterfaceCallable callee) {
+      (*this)->setAttr("callee", callee.get<mlir::SymbolRefAttr>());
+    }
+  }];
+}
+
+def XLA_PredicatedInsertOp : XLA_Op<"predicated_insert",
+    [Pure,
+     TypesMatchWith<"result type matches type of dest",
+                    "dest", "result", "$_self">,
+     TypesMatchWith<"value type matches element type of dest",
+                    "dest", "value",
+                    "::llvm::cast<mlir::TensorType>($_self).getElementType()">]> {
+  let summary = "Inserts a value into a tensor if a condition holds";
+  let arguments = (ins I1:$condition, AnyType:$value,
+                       AnyStaticShapeTensor:$dest, Variadic<Index>:$indices);
+  let results = (outs AnyStaticShapeTensor:$result);
+
+  let assemblyFormat = [{
+    $value `into` $dest `[` $indices `]` `if` $condition attr-dict `:` type($dest)
+  }];
+}
+
+def XLA_PredicatedExtractOp : XLA_Op<"predicated_extract",
+    [Pure,
+     TypesMatchWith<"fallback type matches element type of src",
+                    "src", "fallback",
+                    "::llvm::cast<mlir::TensorType>($_self).getElementType()">,
+     TypesMatchWith<"result type matches element type of src",
+                    "src", "result",
+                    "::llvm::cast<mlir::TensorType>($_self).getElementType()">]> {
+  let summary = "Extracts a value from a tensor if a condition holds, otherwise returns the fallback value";
+  let arguments = (ins I1:$condition, AnyType:$fallback,
+                       AnyStaticShapeTensor:$src, Variadic<Index>:$indices);
+  let results = (outs AnyType:$result);
+
+  let assemblyFormat = [{
+    $src `[` $indices `]` `if` $condition `else` $fallback attr-dict `:` type($src)
+  }];
+}
+
+def ApplyIndexingOp : XLA_Op<"apply_indexing", [Pure]> {
+  let summary = "Applies indexing map to a list of SSA values";
+  let description = [{
+    The `apply_indexing` operation applies an indexing_map to a list of SSA
+    values, yielding one SSA value per result of the map.
+    The number of dimension and symbol arguments must be equal to the
+    respective number of dimensional and symbolic inputs in the indexing_map.
+    The indexing map may have multiple results, in which case the operation
+    returns one value per result. The operands and results must all have
+    `index` type.
+
+    Example:
+
+    ```mlir
+    #map = #xla.indexing_map<"(d0, d1)[s0] -> (d0 floordiv 8 + d1 floordiv 128, s0), domain: d0 in [0, 10], d1 in [0, 11], s0 in [11, 32]">
+    %results:2 = xla.apply_indexing #map (%d0, %d1)[%s0]
+    ```
+  }];
+  let arguments = (ins Variadic<Index>:$operands,
+                       XLA_IndexingMapAttr:$indexing_map_attr);
+  let results = (outs Variadic<Index>);
+
+  let builders = [
+    OpBuilder<(ins "mlir::ValueRange":$dims, "mlir::ValueRange":$symbols,
+      "const IndexingMap&":$indexing_map)>,
+    OpBuilder<(ins "mlir::ValueRange":$operands,
+      "const IndexingMap&":$indexing_map)>,
+    OpBuilder<(ins "mlir::ValueRange":$operands, "mlir::AffineMap":$affine_map,
+      "llvm::ArrayRef<IndexingMap::Variable>":$dim_vars,
+      "llvm::ArrayRef<IndexingMap::Variable>":$range_vars)>,
+  ];
+  let extraClassDeclaration = [{
+    // Returns the indexing map constructed from IndexingMapAttr.
+    xla::IndexingMap getIndexingMap();
+    // Extracts the affine map from the attribute.
+    mlir::AffineMap getAffineMap() { return getIndexingMapAttr().getMap(); }
+  }];
+  let hasCustomAssemblyFormat = 1;
+  let hasVerifier = 1;
+  let hasCanonicalizer = 1;
+  let hasFolder = 1;
+}
+
+def LoopOp : XLA_Op<"loop", [
+    AttrSizedOperandSegments, Pure,
+    DeclareOpInterfaceMethods<OpAsmOpInterface,
+        ["getAsmResultNames", "getAsmBlockArgumentNames"]>,
+    SingleBlockImplicitTerminator<"xla::YieldOp">
+  ]> {
+  let summary = "Loop nest that iterates over all feasible values of RangeVars.";
+  let description = [{
+
+    ```mlir
+    #map = #xla.indexing_map<(d0)[s0, s1] -> (d0 + s0, s1),
+      domain:
+      d0 in [0, 3],
+      s0 in [0, 1024],
+      s1 in [0, 32]
+    >
+    // Initial sum set to 0.
+    %sum_0 = arith.constant 0.0 : f32
+    %dim = arith.constant 1 : index
+    // iter_args binds initial values to the loop's region arguments.
+    %sum = xla.loop (%dim)[%i, %j] -> (%r0, %r1)
+        in #map iter_args(%sum_iter = %sum_0) -> (f32) {
+      %t = tensor.extract %buffer[%i, %j] : tensor<1024x32xf32>
+      %sum_next = arith.addf %sum_iter, %t : f32
+      // Yield current iteration sum to next iteration %sum_iter or to %sum
+      // if final iteration.
+      xla.yield %sum_next : f32
+    }
+    ```
+  }];
+  let arguments = (ins XLA_IndexingMapAttr:$indexing_map_attr,
+                       Variadic<Index>:$dims,
+                       Variadic<AnyType>:$inits);
+  let results = (outs Variadic<AnyType>);
+  let regions = (region SizedRegion<1>:$region);
+
+  let builders = [
+    OpBuilder<(ins "IndexingMapAttr":$indexing_map_attr,
+      "mlir::ValueRange":$dims, "mlir::ValueRange":$inits,
+      CArg<"llvm::function_ref<void(mlir::OpBuilder&, mlir::Location, mlir::ValueRange, mlir::ValueRange, mlir::ValueRange)>", "nullptr">)>,
+    OpBuilder<(ins "const IndexingMap&":$indexing_map,
+      "mlir::ValueRange":$dims, "mlir::ValueRange":$inits,
+      CArg<"llvm::function_ref<void(mlir::OpBuilder&, mlir::Location, mlir::ValueRange, mlir::ValueRange, mlir::ValueRange)>", "nullptr">)>
+  ];
+
+  let extraClassDeclaration = [{
+    using BodyBuilderFn = llvm::function_ref<
+        void(mlir::OpBuilder&, mlir::Location, mlir::ValueRange,
+             mlir::ValueRange, mlir::ValueRange)>;
+
+    // Returns the indexing map constructed from IndexingMapAttr.
+ xla::IndexingMap getIndexingMap(); + int64_t getNumInductionVars() { + return getBody()->getNumArguments() - getIndexingMapAttr().getNumResults() + - getNumResults(); + } + mlir::BlockArgument getInductionVar(int64_t index) { + return getBody()->getArgument(index); + } + mlir::Block::BlockArgListType getInductionVars() { + return getBody()->getArguments().take_front(getNumInductionVars()); + } + mlir::Block::BlockArgListType getIndexingMapResults() { + return getBody()->getArguments().drop_front(getNumInductionVars()) + .drop_back(getNumResults()); + } + mlir::Block::BlockArgListType getRegionIterArgs() { + return getBody()->getArguments().take_back(getNumResults()); + } + }]; + let hasCustomAssemblyFormat = 1; + let hasVerifier = 1; + let hasCanonicalizer = 1; +} + +#endif // XLA_CODEGEN_IR_XLA_OPS + diff --git a/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_ops_test.cc b/third_party/xla/xla/codegen/ir/xla_ops_test.cc similarity index 93% rename from third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_ops_test.cc rename to third_party/xla/xla/codegen/ir/xla_ops_test.cc index 001477150f771a..91cd2b8adeb3c2 100644 --- a/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_ops_test.cc +++ b/third_party/xla/xla/codegen/ir/xla_ops_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/fusions/ir/xla_gpu_ops.h" +#include "xla/codegen/ir/xla_ops.h" #include #include @@ -29,11 +29,10 @@ limitations under the License. #include "xla/tests/hlo_test_base.h" #include "tsl/platform/test.h" -namespace xla::gpu { +namespace xla { namespace { - -class XLAGPUOpsTest : public HloTestBase { +class XLAOpsTest : public HloTestBase { public: mlir::MLIRContext mlir_context_; }; @@ -80,7 +79,7 @@ std::string VariableConstraintsToString(const IndexingMap& map) { return result; } -TEST_F(XLAGPUOpsTest, GetConstraintsForVariables) { +TEST_F(XLAOpsTest, GetConstraintsForVariables) { auto map = *ParseIndexingMap(R"( (x, y)[s0, w] -> (x + s0, y + w), domain: x in [0, 5], @@ -102,7 +101,7 @@ w: s0 + w in [0, 3], w mod 4 in [0, 2], y + w in [0, 4] )"); } -TEST_F(XLAGPUOpsTest, GetConstraintsForVariablesEmpty) { +TEST_F(XLAOpsTest, GetConstraintsForVariablesEmpty) { auto map = *ParseIndexingMap(R"( (d0, d1)[s0, s1] -> (d0 + s0, d1 + s1), domain: d0 in [0, 5], @@ -120,4 +119,4 @@ s1: no constraints } } // namespace -} // namespace xla::gpu +} // namespace xla diff --git a/third_party/xla/xla/codegen/tools/BUILD b/third_party/xla/xla/codegen/tools/BUILD new file mode 100644 index 00000000000000..96e73bff4f1668 --- /dev/null +++ b/third_party/xla/xla/codegen/tools/BUILD @@ -0,0 +1,46 @@ +load("//xla:xla.bzl", "xla_cc_binary") + +package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + licenses = ["notice"], +) + +xla_cc_binary( + name = "emitters_opt", + srcs = ["emitters_opt.cc"], + # We want to use this tool for lit tests. Due to hermetic cuda, we need to + # set linkopts in such a way that dynamic libraries are found, which are + # symlinked from the lit_lib directory. 
+ linkopts = ["-Wl,-rpath,$$ORIGIN/../lit_lib"], + visibility = [ + "//xla/codegen/ir/tests:__subpackages__", + "//xla/service/gpu/fusions:__subpackages__", + ], + deps = [ + "//xla/codegen/ir:xla", + "//xla/mlir_hlo", + "//xla/service/gpu:gpu_device_info_for_tests", + "//xla/service/gpu/fusions/ir:xla_gpu", + "//xla/service/gpu/fusions/mlir:mlir_fusion_emitter", + "//xla/service/gpu/fusions/transforms:passes", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:AffineDialect", + "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:ComplexDialect", + "@llvm-project//mlir:DLTIDialect", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:FuncExtensions", + "@llvm-project//mlir:GPUDialect", + "@llvm-project//mlir:LLVMDialect", + "@llvm-project//mlir:LLVMIRTransforms", + "@llvm-project//mlir:MathDialect", + "@llvm-project//mlir:MlirOptLib", + "@llvm-project//mlir:NVVMDialect", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:SCFDialect", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TensorDialect", + "@llvm-project//mlir:Transforms", + "@llvm-project//mlir:VectorDialect", + ], +) diff --git a/third_party/xla/xla/codegen/tools/emitters_opt.cc b/third_party/xla/xla/codegen/tools/emitters_opt.cc new file mode 100644 index 00000000000000..940a655245faba --- /dev/null +++ b/third_party/xla/xla/codegen/tools/emitters_opt.cc @@ -0,0 +1,89 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "llvm/ADT/STLFunctionalExtras.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/ADT/Twine.h" +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Complex/IR/Complex.h" +#include "mlir/Dialect/DLTI/DLTI.h" +#include "mlir/Dialect/Func/Extensions/AllExtensions.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/GPU/IR/GPUDialect.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/LLVMIR/NVVMDialect.h" +#include "mlir/Dialect/LLVMIR/Transforms/InlinerInterfaceImpl.h" +#include "mlir/Dialect/Math/IR/Math.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Tools/mlir-opt/MlirOptMain.h" +#include "mlir/Transforms/Passes.h" +#include "xla/codegen/ir/xla_ops.h" +#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" +#include "xla/service/gpu/fusions/ir/xla_gpu_ops.h" +#include "xla/service/gpu/fusions/mlir/mlir_fusion_emitter.h" +#include "xla/service/gpu/fusions/transforms/passes.h" +#include "xla/service/gpu/gpu_device_info_for_tests.h" + +int main(int argc, char** argv) { + mlir::DialectRegistry registry; + registry.insert(); + mlir::func::registerAllExtensions(registry); + mlir::LLVM::registerInlinerInterface(registry); + mlir::registerCanonicalizerPass(); + mlir::registerCSEPass(); + mlir::registerInliner(); + xla::gpu::registerGpuFusionTransformsPasses(); + mlir::registerPassPipeline( + "xla-gpu-test-optimize", + "Test pipeline of passes up to inlining. No vectorization, also does not " + "lower xla_gpu. Intended to simplify IR in tests.", + [=](mlir::OpPassManager& pm, llvm::StringRef options, + llvm::function_ref + errorHandler) { + if (!options.empty()) return mlir::failure(); + + xla::gpu::AddXlaGpuOpsOptimizationPasses(pm); + return mlir::success(); + }, + [](llvm::function_ref) {}); + mlir::registerPassPipeline( + "xla-gpu-test-transform-loops", + "Test pipeline for vectorization. 
Should run after " + "xla-gpu-test-to-inline.", + [=](mlir::OpPassManager& pm, llvm::StringRef options, + llvm::function_ref + errorHandler) { + if (!options.empty()) return mlir::failure(); + xla::gpu::AddLoopTransformationPasses( + pm, xla::gpu::TestGpuDeviceInfo::RTXA6000DeviceInfo()); + return mlir::success(); + }, + [](llvm::function_ref) {}); + + return mlir::failed( + MlirOptMain(argc, argv, "XLA Emitters Pass Driver\n", registry)); +} diff --git a/third_party/xla/xla/service/gpu/fusions/BUILD b/third_party/xla/xla/service/gpu/fusions/BUILD index 80c504bbff8571..f9510d67170cb7 100644 --- a/third_party/xla/xla/service/gpu/fusions/BUILD +++ b/third_party/xla/xla/service/gpu/fusions/BUILD @@ -270,6 +270,7 @@ cc_library( deps = [ "//xla:shape_util", "//xla:xla_data_proto_cc", + "//xla/codegen/ir:xla", "//xla/hlo/analysis:indexing_analysis", "//xla/hlo/ir:hlo", "//xla/service:scatter_simplifier", diff --git a/third_party/xla/xla/service/gpu/fusions/ir/BUILD b/third_party/xla/xla/service/gpu/fusions/ir/BUILD index ec19475abc6041..a6fbc1bb04ff1d 100644 --- a/third_party/xla/xla/service/gpu/fusions/ir/BUILD +++ b/third_party/xla/xla/service/gpu/fusions/ir/BUILD @@ -1,5 +1,4 @@ load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library", "td_library") -load("//xla/tests:build_defs.bzl", "xla_test") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], @@ -19,6 +18,7 @@ td_library( srcs = glob(["*.td"]), includes = ["."], deps = [ + "//xla/codegen/ir:xla_td_files", "@llvm-project//mlir:BuiltinDialectTdFiles", "@llvm-project//mlir:CallInterfacesTdFiles", "@llvm-project//mlir:ControlFlowInterfacesTdFiles", @@ -79,12 +79,14 @@ gentbl_cc_library( ( [ "-gen-attrdef-decls", + "-attrdefs-dialect=xla_gpu", ], "xla_gpu_attrs.h.inc", ), ( [ "-gen-attrdef-defs", + "-attrdefs-dialect=xla_gpu", ], "xla_gpu_attrs.cc.inc", ), @@ -126,15 +128,14 @@ cc_library( "xla_gpu_ops.cc", "xla_gpu_types.cc", ], - hdrs = [ - "xla_gpu_ops.h", - ], + hdrs = ["xla_gpu_ops.h"], deps = [ ":xla_gpu_attrs_inc_gen", ":xla_gpu_dialect_inc_gen", ":xla_gpu_ops_inc_gen", ":xla_gpu_types_inc_gen", "//xla:status_macros", + "//xla/codegen/ir:xla", "//xla/hlo/analysis:indexing_analysis", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", @@ -151,21 +152,3 @@ cc_library( "@llvm-project//mlir:Support", ], ) - -xla_test( - name = "xla_gpu_ops_test", - srcs = ["xla_gpu_ops_test.cc"], - backends = ["gpu"], - deps = [ - ":xla_gpu", - "//xla/hlo/analysis:indexing_analysis", - "//xla/hlo/analysis:indexing_test_utils", - "//xla/tests:hlo_test_base", - "//xla/tests:xla_internal_test_main", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/strings:str_format", - "@com_google_googletest//:gtest", - "@llvm-project//llvm:Support", - "@local_tsl//tsl/platform:test", - ], -) diff --git a/third_party/xla/xla/service/gpu/fusions/ir/tests/BUILD b/third_party/xla/xla/service/gpu/fusions/ir/tests/BUILD index 381d5a3220b1df..10228fcc460af8 100644 --- a/third_party/xla/xla/service/gpu/fusions/ir/tests/BUILD +++ b/third_party/xla/xla/service/gpu/fusions/ir/tests/BUILD @@ -10,7 +10,7 @@ lit_test_suite( srcs = glob(["*.mlir"]), cfg = "//xla:lit.cfg.py", tools = [ - "//xla/service/gpu/fusions/tools:mlir_fusions_opt", + "//xla/codegen/tools:emitters_opt", "@llvm-project//llvm:FileCheck", ], ) diff --git a/third_party/xla/xla/service/gpu/fusions/ir/tests/invalid.mlir b/third_party/xla/xla/service/gpu/fusions/ir/tests/invalid.mlir index 35064858b23150..c9dce3a9554fb6 100644 --- 
a/third_party/xla/xla/service/gpu/fusions/ir/tests/invalid.mlir +++ b/third_party/xla/xla/service/gpu/fusions/ir/tests/invalid.mlir @@ -1,93 +1,6 @@ -// RUN: mlir_fusions_opt %s -split-input-file -verify-diagnostics +// RUN: emitters_opt %s -split-input-file -verify-diagnostics -#map0 = #xla_gpu.indexing_map<"(d0, d1)[s0] -> (d0, d1 + s0), domain: d0 in [1, 2], d1 in [5, 8], s0 in [0, 32]"> -func.func @apply_indexing(%d0: index, %d1: index, %s0: index) -> (index, index) { - // expected-error @+1 {{operand count must match the number of dimensions and symbols in the affine map}} - %0:2 = xla_gpu.apply_indexing #map0 (%d0) - func.return %0#0, %0#1 : index, index -} - -// ----- - -#map0 = #xla_gpu.indexing_map<"(d0, d1)[s0] -> (d0, d1 + s0), domain: d0 in [1, 2], d1 in [5, 8], s0 in [0, 32], d0 mod 2 in [0, 1], d0 + s0 in [1, 10]"> -func.func @cannot_have_constraints(%d0: index, %d1: index, %s0: index) -> (index, index) { - // expected-error @+1 {{apply indexing op cannot have any constraints}} - %0:2 = xla_gpu.apply_indexing #map0 (%d0, %d1)[%s0] - func.return %0#0, %0#1 : index, index -} - -// ----- - -#map = #xla_gpu.indexing_map<"()[s0, s1] -> (s0, s1), domain: s0 in [0, 1024], s1 in [0, 32]"> -func.func @loop_result_num_mismatch(%input: tensor<1024x32xf32>, - %init: f32) -> (f32) { - // expected-error @+1 {{mismatch in number of loop-carried values and results}} - %sum:2 = "xla_gpu.loop"(%init) <{ - indexing_map_attr = #map, - operandSegmentSizes = array - }> ({ - ^bb0(%i: index, %j: index, %r0: index, %r1: index, %iter: f32): - %t = tensor.extract %input[%i, %j] : tensor<1024x32xf32> - %add = arith.addf %iter, %t : f32 - xla_gpu.yield %add : f32 - }) : (f32) -> (f32, f32) - func.return %sum#0 : f32 -} - -// ----- - -#map = #xla_gpu.indexing_map<"()[s0] -> (s0, s0), domain: s0 in [0, 1024]"> -func.func @loop_iv_num_mismatch(%input: tensor<1024x32xf32>, - %init: f32) -> (f32) { - // expected-error @+1 {{mismatch in number of induction variables 2 and RangeVars}} - %sum = "xla_gpu.loop"(%init) <{ - indexing_map_attr = #map, - operandSegmentSizes = array - }> ({ - ^bb0(%i: index, %j: index, %r0: index, %r1: index, %iter: f32): - %t = tensor.extract %input[%i, %j] : tensor<1024x32xf32> - %add = arith.addf %iter, %t : f32 - xla_gpu.yield %add : f32 - }) : (f32) -> (f32) - func.return %sum : f32 -} - -// ----- - -#map = #xla_gpu.indexing_map<"()[s0, s1] -> (s0, s1), domain: s0 in [0, 1024], s1 in [0, 32]"> - -func.func @loop_types_mismatch(%input: tensor<1024x32xf32>, %init: f32) -> (i32) { - // expected-error @+1 {{block iter arg type = 'f32', result type = 'i32' and init operand type = 'f32' should match}} - %sum = "xla_gpu.loop"(%init) <{ - indexing_map_attr = #map, - operandSegmentSizes = array - }> ({ - ^bb0(%i: index, %j: index, %r0: index, %r1: index, %iter: f32): - %t = tensor.extract %input[%i, %j] : tensor<1024x32xf32> - %add = arith.addf %iter, %t : f32 - xla_gpu.yield %add : f32 - }) : (f32) -> (i32) - func.return %sum : i32 -} - -// ----- - -#map = #xla_gpu.indexing_map<"(d0)[s0, s1] -> (s0, s1), domain: d0 in [0, 3], s0 in [0, 1024], s1 in [0, 32]"> - -func.func @loop_op(%input: tensor<1024x32xf32>, %init: f32, %dim: index) -> (f32) { - // expected-error @+1 {{mismatch in number of dims operands 0 and DimVars in the indexing map}} - %sum = xla_gpu.loop ()[%i, %j] -> (%r0, %r1) - in #map iter_args(%sum_ = %init) -> (f32) { - %t = tensor.extract %input[%i, %j] : tensor<1024x32xf32> - %add = arith.addf %sum_, %t : f32 - xla_gpu.yield %add : f32 - } {xla.range = [0 : index, 42 : 
index]} - func.return %sum : f32 -} - -// ----- - -#map = #xla_gpu.indexing_map<"(d0, d1)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], d1 in [0, 2], s0 in [0, 1024], s1 in [0, 32]"> +#map = #xla.indexing_map<"(d0, d1)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], d1 in [0, 2], s0 in [0, 1024], s1 in [0, 32]"> func.func @indicies_mismatch(%input: tensor<32x64xf32>, %thread_id: index, %output: tensor<32x64xf32>) -> !xla_gpu.indexed_vector<32x64xf32, #map> { @@ -99,8 +12,8 @@ func.func @indicies_mismatch(%input: tensor<32x64xf32>, %thread_id: index, // ----- -#map = #xla_gpu.indexing_map<"()[s0, s1] -> (s0, s1), domain: s0 in [0, 1024], s1 in [0, 32]"> -#map1 = #xla_gpu.indexing_map<"(d0)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], s0 in [0, 1024], s1 in [0, 32]"> +#map = #xla.indexing_map<"()[s0, s1] -> (s0, s1), domain: s0 in [0, 1024], s1 in [0, 32]"> +#map1 = #xla.indexing_map<"(d0)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], s0 in [0, 1024], s1 in [0, 32]"> func.func @no_thread_id_in(%input: tensor<32x64xf32>, %output: tensor<32x64xf32>) -> !xla_gpu.indexed_vector<32x64xf32, #map1> { @@ -112,8 +25,8 @@ func.func @no_thread_id_in(%input: tensor<32x64xf32>, // ----- -#map = #xla_gpu.indexing_map<"(d0)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], s0 in [0, 1024], s1 in [0, 32]"> -#map1 = #xla_gpu.indexing_map<"()[s0, s1] -> (s0, s1), domain: s0 in [0, 1024], s1 in [0, 32]"> +#map = #xla.indexing_map<"(d0)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], s0 in [0, 1024], s1 in [0, 32]"> +#map1 = #xla.indexing_map<"()[s0, s1] -> (s0, s1), domain: s0 in [0, 1024], s1 in [0, 32]"> func.func @no_thread_id_out(%input: tensor<32x64xf32>, %thread_id: index, %output: tensor<32x64xf32>) -> !xla_gpu.indexed_vector<32x64xf32, #map1> { @@ -125,8 +38,8 @@ func.func @no_thread_id_out(%input: tensor<32x64xf32>, %thread_id: index, // ----- -#map = #xla_gpu.indexing_map<"(d0)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], s0 in [0, 1024], s1 in [0, 32]"> -#map1 = #xla_gpu.indexing_map<"(d0)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 64], s0 in [0, 1024], s1 in [0, 32]"> +#map = #xla.indexing_map<"(d0)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], s0 in [0, 1024], s1 in [0, 32]"> +#map1 = #xla.indexing_map<"(d0)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 64], s0 in [0, 1024], s1 in [0, 32]"> func.func @thread_id_bounds_mismatch(%input: tensor<32x64xf32>, %thread_id: index, %output: tensor<32x64xf32>) -> !xla_gpu.indexed_vector<32x64xf32, #map1> { // expected-error @+1 {{thread_id dimension must have the same bounds in both indexing maps}} %0 = xla_gpu.materialize @exp(%input) at #map(%thread_id) : (tensor<32x64xf32>) -> !xla_gpu.indexed_vector<32x64xf32, #map1> @@ -135,8 +48,8 @@ func.func @thread_id_bounds_mismatch(%input: tensor<32x64xf32>, %thread_id: inde // ----- -#map = #xla_gpu.indexing_map<"(d0)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], s0 in [0, 1024], s1 in [0, 32], d0 + s0 in [0, 1024]"> -#map1 = #xla_gpu.indexing_map<"(d0)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], s0 in [0, 1024], s1 in [0, 32]"> +#map = #xla.indexing_map<"(d0)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], s0 in [0, 1024], s1 in [0, 32], d0 + s0 in [0, 1024]"> +#map1 = #xla.indexing_map<"(d0)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], s0 in [0, 1024], s1 in [0, 32]"> func.func @thread_id_constraints_mismatch(%input: tensor<32x64xf32>, %thread_id: index, %output: tensor<32x64xf32>) @@ -149,8 +62,8 @@ func.func @thread_id_constraints_mismatch(%input: tensor<32x64xf32>, // ----- 
-#map = #xla_gpu.indexing_map<"(d0)[s0] -> (d0 + s0, s0), domain: d0 in [0, 32], s0 in [0, 1024]"> -#map1 = #xla_gpu.indexing_map<"(d0)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], s0 in [0, 1024], s1 in [0, 32]"> +#map = #xla.indexing_map<"(d0)[s0] -> (d0 + s0, s0), domain: d0 in [0, 32], s0 in [0, 1024]"> +#map1 = #xla.indexing_map<"(d0)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], s0 in [0, 1024], s1 in [0, 32]"> func.func @symbol_count_mismatch(%input: tensor<32x64xf32>, %thread_id: index, %output: tensor<32x64xf32>) -> !xla_gpu.indexed_vector<32x64xf32, #map1> { // expected-error @+1 {{number of symbols in both indexing_maps must match}} %0 = xla_gpu.materialize @exp(%input) at #map(%thread_id) : (tensor<32x64xf32>) -> !xla_gpu.indexed_vector<32x64xf32, #map1> @@ -159,8 +72,8 @@ func.func @symbol_count_mismatch(%input: tensor<32x64xf32>, %thread_id: index, % // ----- -#map = #xla_gpu.indexing_map<"(d0)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], s0 in [0, 1024], s1 in [0, 64]"> -#map1 = #xla_gpu.indexing_map<"(d0)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], s0 in [0, 1024], s1 in [0, 32]"> +#map = #xla.indexing_map<"(d0)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], s0 in [0, 1024], s1 in [0, 64]"> +#map1 = #xla.indexing_map<"(d0)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], s0 in [0, 1024], s1 in [0, 32]"> func.func @symbol_domain_mismatch(%input: tensor<32x64xf32>, %thread_id: index, %output: tensor<32x64xf32>) -> !xla_gpu.indexed_vector<32x64xf32, #map1> { // expected-error @+1 {{domain of symbols of indexing_maps must match}} %0 = xla_gpu.materialize @exp(%input) at #map(%thread_id) : (tensor<32x64xf32>) -> !xla_gpu.indexed_vector<32x64xf32, #map1> @@ -169,8 +82,8 @@ func.func @symbol_domain_mismatch(%input: tensor<32x64xf32>, %thread_id: index, // ----- -#map = #xla_gpu.indexing_map<"(d0)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], s0 in [0, 1024], s1 in [0, 64], s0 + s1 in [0, 1024]"> -#map1 = #xla_gpu.indexing_map<"(d0)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], s0 in [0, 1024], s1 in [0, 64], s0 + s1 in [0, 32]"> +#map = #xla.indexing_map<"(d0)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], s0 in [0, 1024], s1 in [0, 64], s0 + s1 in [0, 1024]"> +#map1 = #xla.indexing_map<"(d0)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], s0 in [0, 1024], s1 in [0, 64], s0 + s1 in [0, 32]"> func.func @symbol_constraints_mismatch(%input: tensor<32x64xf32>, %thread_id: index, %output: tensor<32x64xf32>) -> !xla_gpu.indexed_vector<32x64xf32, #map1> { @@ -182,8 +95,8 @@ func.func @symbol_constraints_mismatch(%input: tensor<32x64xf32>, // ----- -#map = #xla_gpu.indexing_map<"(d0)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], s0 in [0, 1024], s1 in [0, 64], s0 mod 2 in [0, 0]"> -#map1 = #xla_gpu.indexing_map<"(d0)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], s0 in [0, 1024], s1 in [0, 64], s0 + s1 in [0, 32]"> +#map = #xla.indexing_map<"(d0)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], s0 in [0, 1024], s1 in [0, 64], s0 mod 2 in [0, 0]"> +#map1 = #xla.indexing_map<"(d0)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], s0 in [0, 1024], s1 in [0, 64], s0 + s1 in [0, 32]"> func.func @symbol_constraint_mismatch(%input: tensor<32x64xf32>, %thread_id: index, %output: tensor<32x64xf32>) @@ -195,8 +108,8 @@ func.func @symbol_constraint_mismatch(%input: tensor<32x64xf32>, // ----- -#map = #xla_gpu.indexing_map<"(d0)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], s0 in [0, 1024], s1 in [0, 64], s0 + s1 in [0, 1024]"> -#map1 = 
#xla_gpu.indexing_map<"(d0)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], s0 in [0, 1024], s1 in [0, 64], s0 + s1 in [0, 32]"> +#map = #xla.indexing_map<"(d0)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], s0 in [0, 1024], s1 in [0, 64], s0 + s1 in [0, 1024]"> +#map1 = #xla.indexing_map<"(d0)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], s0 in [0, 1024], s1 in [0, 64], s0 + s1 in [0, 32]"> func.func @symbol_constraint_interval_mismatch(%input: tensor<32x64xf32>, %thread_id: index, %output: tensor<32x64xf32>) @@ -209,8 +122,8 @@ func.func @symbol_constraint_interval_mismatch(%input: tensor<32x64xf32>, // ----- -#map = #xla_gpu.indexing_map<"(d0)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], s0 in [0, 1024], s1 in [0, 64]"> -#map1 = #xla_gpu.indexing_map<"(d0, d1)[s0, s1] -> (d0 + s0, d1 + s1), domain: d0 in [0, 32], d1 in [0, 64], s0 in [0, 1024], s1 in [0, 64]"> +#map = #xla.indexing_map<"(d0)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], s0 in [0, 1024], s1 in [0, 64]"> +#map1 = #xla.indexing_map<"(d0, d1)[s0, s1] -> (d0 + s0, d1 + s1), domain: d0 in [0, 32], d1 in [0, 64], s0 in [0, 1024], s1 in [0, 64]"> func.func @vector_mapping_depends_on_block_id(%input: tensor<32x64xf32>, %thread_id: index, %output: tensor<32x64xf32>) -> !xla_gpu.indexed_vector<32x64xf32, #map1> { @@ -222,8 +135,8 @@ func.func @vector_mapping_depends_on_block_id(%input: tensor<32x64xf32>, // ----- -#map = #xla_gpu.indexing_map<"(d0, d1)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], d1 in [0, 64], s0 in [0, 1024], s1 in [0, 64], d1 mod 2 in [0, 0]"> -#map1 = #xla_gpu.indexing_map<"(d0)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], s0 in [0, 1024], s1 in [0, 64]"> +#map = #xla.indexing_map<"(d0, d1)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], d1 in [0, 64], s0 in [0, 1024], s1 in [0, 64], d1 mod 2 in [0, 0]"> +#map1 = #xla.indexing_map<"(d0)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], s0 in [0, 1024], s1 in [0, 64]"> func.func @block_id_constraints_mismatch(%input: tensor<32x64xf32>, %thread_id: index, %block_id: index, %output: tensor<32x64xf32>) @@ -236,8 +149,8 @@ func.func @block_id_constraints_mismatch(%input: tensor<32x64xf32>, // ----- -#map = #xla_gpu.indexing_map<"(d0)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], s0 in [0, 1024], s1 in [0, 64]"> -#map1 = #xla_gpu.indexing_map<"(d0, d1)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], d1 in [0, 64], s0 in [0, 1024], s1 in [0, 64], d1 mod 2 in [0, 0]"> +#map = #xla.indexing_map<"(d0)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], s0 in [0, 1024], s1 in [0, 64]"> +#map1 = #xla.indexing_map<"(d0, d1)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], d1 in [0, 64], s0 in [0, 1024], s1 in [0, 64], d1 mod 2 in [0, 0]"> func.func @block_id_constraints_mismatch(%input: tensor<32x64xf32>, %thread_id: index, %block_id: index, %output: tensor<32x64xf32>) @@ -250,8 +163,8 @@ func.func @block_id_constraints_mismatch(%input: tensor<32x64xf32>, // ----- -#map = #xla_gpu.indexing_map<"(d0, d1)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], d1 in [0, 64], s0 in [0, 1024], s1 in [0, 64], d1 mod 2 in [0, 0]"> -#map1 = #xla_gpu.indexing_map<"(d0, d1)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], d1 in [0, 64], s0 in [0, 1024], s1 in [0, 64], d1 mod 4 in [0, 0]"> +#map = #xla.indexing_map<"(d0, d1)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], d1 in [0, 64], s0 in [0, 1024], s1 in [0, 64], d1 mod 2 in [0, 0]"> +#map1 = #xla.indexing_map<"(d0, d1)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], d1 in [0, 64], s0 in [0, 1024], s1 in [0, 
64], d1 mod 4 in [0, 0]"> func.func @block_id_constraints_mismatch(%input: tensor<32x64xf32>, %thread_id: index, %block_id: index, %output: tensor<32x64xf32>) @@ -264,8 +177,8 @@ func.func @block_id_constraints_mismatch(%input: tensor<32x64xf32>, // ----- -#map = #xla_gpu.indexing_map<"(d0, d1)[s0, s1] -> (d1*32+d0*2+s0, s1), domain: d0 in [0, 32], d1 in [0, 8], s0 in [0, 1], s1 in [0, 1]"> -#map1 = #xla_gpu.indexing_map<"(d0, d1)[s0] -> (d0 mod 16 + s0, d1), domain: d0 in [0, 32], d1 in [0, 2], s0 in [0, 1]"> +#map = #xla.indexing_map<"(d0, d1)[s0, s1] -> (d1*32+d0*2+s0, s1), domain: d0 in [0, 32], d1 in [0, 8], s0 in [0, 1], s1 in [0, 1]"> +#map1 = #xla.indexing_map<"(d0, d1)[s0] -> (d0 mod 16 + s0, d1), domain: d0 in [0, 32], d1 in [0, 2], s0 in [0, 1]"> func.func @insert(%input: !xla_gpu.indexed_vector<32x64xf32, #map>, %i: index, %j: index, %output: tensor<32x64xf32>) -> tensor<32x64xf32> { @@ -277,8 +190,8 @@ func.func @insert(%input: !xla_gpu.indexed_vector<32x64xf32, #map>, // ----- -#map = #xla_gpu.indexing_map<"(d0, d1)[s0, s1] -> (d1*32+d0*2+s0, s1), domain: d0 in [0, 32], d1 in [0, 8], s0 in [0, 1], s1 in [0, 1]"> -#map1 = #xla_gpu.indexing_map<"(d0, d1, d2) -> (d0 mod 16, d1, d2), domain: d0 in [0, 32], d1 in [0, 2], d2 in [0, 5]"> +#map = #xla.indexing_map<"(d0, d1)[s0, s1] -> (d1*32+d0*2+s0, s1), domain: d0 in [0, 32], d1 in [0, 8], s0 in [0, 1], s1 in [0, 1]"> +#map1 = #xla.indexing_map<"(d0, d1, d2) -> (d0 mod 16, d1, d2), domain: d0 in [0, 32], d1 in [0, 2], d2 in [0, 5]"> func.func @insert(%input: !xla_gpu.indexed_vector<32x64xf32, #map>, %i: index, %j: index, %output: tensor<32x64xf32>) -> tensor<32x64xf32> { diff --git a/third_party/xla/xla/service/gpu/fusions/ir/tests/ops.mlir b/third_party/xla/xla/service/gpu/fusions/ir/tests/ops.mlir index 7f85889c29798f..aff9ede7e50346 100644 --- a/third_party/xla/xla/service/gpu/fusions/ir/tests/ops.mlir +++ b/third_party/xla/xla/service/gpu/fusions/ir/tests/ops.mlir @@ -1,8 +1,8 @@ -// RUN: mlir_fusions_opt %s --split-input-file | FileCheck %s +// RUN: emitters_opt %s --split-input-file | FileCheck %s // Verify the printed output can be parsed. -// RUN: mlir_fusions_opt %s --split-input-file | mlir_fusions_opt --split-input-file | FileCheck %s +// RUN: emitters_opt %s --split-input-file | emitters_opt --split-input-file | FileCheck %s // Verify the generic form can be parsed. 
-// RUN: mlir_fusions_opt %s --split-input-file --mlir-print-op-generic | mlir_fusions_opt --split-input-file | FileCheck %s +// RUN: emitters_opt %s --split-input-file --mlir-print-op-generic | emitters_opt --split-input-file | FileCheck %s func.func @shared_and_sync() -> (tensor<2xf32>, tensor<2xf32>) { %shared1 = xla_gpu.allocate_shared : tensor<2xf32> @@ -19,132 +19,13 @@ func.func @shared_and_sync() -> (tensor<2xf32>, tensor<2xf32>) { // ----- -func.func @atomic_rmw(%in: tensor<2x3xf32>, %i: index, %j: index) - -> (tensor<2x3xf32>) { - %ret = xla_gpu.atomic_rmw %in[%i, %j] : tensor<2x3xf32> { - ^bb0(%current : f32): - %c42 = arith.constant 42.0 : f32 - %add = arith.addf %current, %c42 : f32 - xla_gpu.yield %add : f32 - } - return %ret : tensor<2x3xf32> -} -// CHECK-LABEL: @atomic_rmw -// CHECK: xla_gpu.atomic_rmw - -// ----- - -func.func private @add(%a: f32, %b: f32) -> f32 { - %ret = arith.addf %a, %b : f32 - return %ret : f32 -} - -func.func @caller(%a: f32, %b: f32) -> f32 { - %c = xla_gpu.pure_call @add(%a, %b) : (f32, f32) -> (f32) - %d = xla_gpu.pure_call @add(%a, %b) : (f32, f32) -> (f32) - %ret = arith.addf %c, %d : f32 - return %ret : f32 -} -// CHECK-LABEL: @caller -// CHECK: %[[C:.*]] = xla_gpu.pure_call @add -// CHECK: %[[D:.*]] = xla_gpu.pure_call @add -// CHECK: arith.addf %[[C]], %[[D]] - -// CHECK-CSE: @caller -// CHECK-CSE: %[[C:.*]] = xla_gpu.pure_call @add -// CHECK-CSE: arith.addf %[[C]], %[[C]] - -// ----- - -#map0 = #xla_gpu.indexing_map<"(d0, d1)[s0] -> (d0, d1 + s0)," - "domain: d0 in [1, 2], d1 in [5, 8], s0 in [0, 32]"> -func.func @apply_indexing(%d0: index, %d1: index, %s0: index) -> (index, index) { - %0:2 = xla_gpu.apply_indexing #map0 (%d0, %d1)[%s0] - func.return %0#0, %0#1 : index, index -} -// CHECK: #[[$MAP0:.*]] = #xla_gpu.indexing_map<" -// CHECK-SAME: (d0, d1)[s0] -> (d0, d1 + s0) -// CHECK-SAME: domain: -// CHECK-SAME: d0 in [1, 2] -// CHECK-SAME: d1 in [5, 8] -// CHECK-SAME: s0 in [0, 32] -// CHECK-SAME: > - -// CHECK-LABEL: @apply_indexing -// CHECK: (%[[d0:.*]]: index, %[[d1:.*]]: index, %[[s0:.*]]: index) -// CHECK: xla_gpu.apply_indexing #[[$MAP0]] -// CHECK-SAME: (%[[d0]], %[[d1]])[%[[s0]]] - -// ----- - -#map0 = #xla_gpu.indexing_map<"(d0, d1) -> (d0, d1)," - "domain: d0 in [0, 2], d1 in [1, 3]"> -func.func @apply_indexing_no_symbols(%d0: index, %d1: index) -> (index, index) { - %0:2 = xla_gpu.apply_indexing #map0 (%d0, %d1) - func.return %0#0, %0#1 : index, index -} -// CHECK: #[[$MAP0:.*]] = #xla_gpu.indexing_map<" -// CHECK-SAME: (d0, d1) -> (d0, d1) -// CHECK-SAME: domain: -// CHECK-SAME: d0 in [0, 2] -// CHECK-SAME: d1 in [1, 3] -// CHECK-SAME: > - -// CHECK-LABEL: @apply_indexing_no_symbols -// CHECK: (%[[d0:.*]]: index, %[[d1:.*]]: index) -// CHECK: xla_gpu.apply_indexing #[[$MAP0]] -// CHECK-SAME: (%[[d0]], %[[d1]]) - -// ----- - -#map0 = #xla_gpu.indexing_map<"()[s0] -> (s0, s0)," - "domain: s0 in [2, 4]"> -func.func @apply_indexing_no_dims(%s0: index) -> (index, index) { - %0:2 = xla_gpu.apply_indexing #map0 [%s0] - func.return %0#0, %0#1 : index, index -} -// CHECK: #[[$MAP0:.*]] = #xla_gpu.indexing_map<" -// CHECK-SAME: ()[s0] -> (s0, s0) -// CHECK-SAME: domain: -// CHECK-SAME: s0 in [2, 4] -// CHECK-SAME: > - -// CHECK-LABEL: @apply_indexing_no_dims -// CHECK: (%[[s0:.*]]: index) -// CHECK: xla_gpu.apply_indexing #[[$MAP0]][%[[s0]]] - -// ----- - -#map = #xla_gpu.indexing_map<"(d0)[s0, s1] -> (s0, s1), " - "domain: d0 in [0, 3], s0 in [0, 1024], s1 in [0, 32]"> -func.func @loop_op(%input: tensor<1024x32xf32>, %init: f32, - 
%dim: index) -> (f32) { - %sum = xla_gpu.loop (%dim)[%i, %j] -> (%r0, %r1) - in #map iter_args(%sum_ = %init) -> (f32) { - %t = tensor.extract %input[%i, %j] : tensor<1024x32xf32> - %add = arith.addf %sum_, %t : f32 - xla_gpu.yield %add : f32 - } {xla.range = [0 : index, 42 : index]} - func.return %sum : f32 -} -// CHECK: #[[$MAP:.*]] = #xla_gpu.indexing_map -// CHECK: xla_gpu.loop (%{{.*}})[%[[I:.*]], %[[J:.*]]] -> -// CHECK-SAME: (%[[R0:.*]], %[[R1:.*]]) in #[[$MAP]] -// CHECK-SAME: iter_args(%[[SUM_ITER:.*]] = %{{.*}}) -> (f32) { -// CHECK: %[[EXTRACTED:.*]] = tensor.extract %{{.*}}[%[[I]], %[[J]]] -// CHECK: %[[ADD:.*]] = arith.addf %{{.*}}, %[[EXTRACTED]] : f32 -// CHECK: xla_gpu.yield %[[ADD]] : f32 -// CHECK: } {xla.range = [0 : index, 42 : index]} - -// ----- - func.func private @exp(%p0: tensor<32x64xf32>, %i: index, %j: index) -> f32 -#map = #xla_gpu.indexing_map<"(d0, d1)[s0, s1] -> (d0 + s0, d1 + s1)," +#map = #xla.indexing_map<"(d0, d1)[s0, s1] -> (d0 + s0, d1 + s1)," "domain: d0 in [0, 32], d1 in [0, 2], s0 in [0, 1024], s1 in [0, 32]"> -#map1 = #xla_gpu.indexing_map<"(d0, d1)[s0, s1] -> (s0, s1)," +#map1 = #xla.indexing_map<"(d0, d1)[s0, s1] -> (s0, s1)," "domain: d0 in [0, 32], d1 in [0, 2], s0 in [0, 1024], s1 in [0, 32]"> -#map2 = #xla_gpu.indexing_map<"(d0, d1) -> (d0, d1)," +#map2 = #xla.indexing_map<"(d0, d1) -> (d0, d1)," "domain: d0 in [0, 32], d1 in [0, 2]"> func.func @materialize_and_insert(%input: tensor<32x64xf32>, %i: index, @@ -156,11 +37,11 @@ func.func @materialize_and_insert(%input: tensor<32x64xf32>, %i: index, func.return %1 : tensor<32x64xf32> } -// CHECK: #[[$MAP:.*]] = #xla_gpu.indexing_map<"(d0, d1)[s0, s1] -> (d0 + s0, d1 + s1) +// CHECK: #[[$MAP:.*]] = #xla.indexing_map<"(d0, d1)[s0, s1] -> (d0 + s0, d1 + s1) // CHECK-SAME: d0 in [0, 32], d1 in [0, 2], s0 in [0, 1024], s1 in [0, 32] -// CHECK: #[[$MAP1:.*]] = #xla_gpu.indexing_map<"(d0, d1)[s0, s1] -> (s0, s1) +// CHECK: #[[$MAP1:.*]] = #xla.indexing_map<"(d0, d1)[s0, s1] -> (s0, s1) // CHECK-SAME: d0 in [0, 32], d1 in [0, 2], s0 in [0, 1024], s1 in [0, 32] -// CHECK: #[[$MAP2:.*]] = #xla_gpu.indexing_map<"(d0, d1) -> (d0, d1) +// CHECK: #[[$MAP2:.*]] = #xla.indexing_map<"(d0, d1) -> (d0, d1) // CHECK-SAME: d0 in [0, 32], d1 in [0, 2]"> // CHECK-LABEL: @materialize_and_insert // CHECK: %[[MATERIALIZED:.*]] = xla_gpu.materialize @exp(%{{.*}}) at diff --git a/third_party/xla/xla/service/gpu/fusions/ir/tests/attrs.mlir b/third_party/xla/xla/service/gpu/fusions/ir/tests/types.mlir similarity index 83% rename from third_party/xla/xla/service/gpu/fusions/ir/tests/attrs.mlir rename to third_party/xla/xla/service/gpu/fusions/ir/tests/types.mlir index 6a199f5f024241..a5e52c3d3eb9d7 100644 --- a/third_party/xla/xla/service/gpu/fusions/ir/tests/attrs.mlir +++ b/third_party/xla/xla/service/gpu/fusions/ir/tests/types.mlir @@ -1,6 +1,6 @@ -// RUN: mlir_fusions_opt %s -split-input-file | mlir_fusions_opt -split-input-file | FileCheck %s +// RUN: emitters_opt %s -split-input-file | emitters_opt -split-input-file | FileCheck %s -// CHECK: #[[$INDEX_MAP:.*]] = #xla_gpu.indexing_map< +// CHECK: #[[$INDEX_MAP:.*]] = #xla.indexing_map< // CHECK-SAME: (d0, d1, d2)[s0] -> (d0), // CHECK-SAME: domain: // CHECK-SAME: d0 in [1, 2], @@ -10,7 +10,7 @@ // CHECK-SAME: d0 + s0 in [1, 10], // CHECK-SAME: d0 mod 2 in [0, 1] // CHECK-SAME: > -#map = #xla_gpu.indexing_map<"(d0, d1, d2)[s0] -> (d0)," +#map = #xla.indexing_map<"(d0, d1, d2)[s0] -> (d0)," "domain:" "d0 in [1, 2]," "d1 in [5, 8]," @@ -26,7 +26,7 @@ func.func private 
@indexing_map_attr(!xla_gpu.indexed_vector<64x64x32xf64, #map> // ----- -// CHECK: #[[$INDEX_MAP:.*]] = #xla_gpu.indexing_map< +// CHECK: #[[$INDEX_MAP:.*]] = #xla.indexing_map< // CHECK-SAME: (d0, d1)[s0, s1, s2] -> (d0 + s0, d1 + s1, d1 + s2) // CHECK-SAME: domain: // CHECK-SAME: d0 in [1, 2] @@ -38,7 +38,7 @@ func.func private @indexing_map_attr(!xla_gpu.indexed_vector<64x64x32xf64, #map> // CHECK-SAME: d0 mod 2 in [0, 1] // CHECK-SAME: d1 + s1 + s2 in [1, 32] // CHECK-SAME: > -#map = #xla_gpu.indexing_map< +#map = #xla.indexing_map< "(d0, d1)[s0, s1, s2] -> (d0 + s0, d1 + s1, d1 + s2)," "domain:" "d0 in [1, 2]," @@ -56,13 +56,13 @@ func.func private @more_range_vars(!xla_gpu.indexed_vector<100x32xf64, #map>) // ----- -// CHECK: #[[$INDEX_MAP:.*]] = #xla_gpu.indexing_map< +// CHECK: #[[$INDEX_MAP:.*]] = #xla.indexing_map< // CHECK-SAME: (d0)[s0] -> (d0) // CHECK-SAME: domain: // CHECK-SAME: d0 in [0, 100] // CHECK-SAME: s0 in [-3, -1] // CHECK-SAME: > -#map = #xla_gpu.indexing_map<"(d0)[s0] -> (d0)," +#map = #xla.indexing_map<"(d0)[s0] -> (d0)," "domain:" "d0 in [0, 100]," "s0 in [-3, -1]" @@ -73,7 +73,7 @@ func.func private @indexing_map_small(!xla_gpu.indexed_vector<100xf64, #map>) // ----- -// CHECK: #[[$INDEX_MAP:.*]] = #xla_gpu.indexing_map< +// CHECK: #[[$INDEX_MAP:.*]] = #xla.indexing_map< // CHECK-SAME: (d0, d1, d2)[s0] -> (d0) // CHECK-SAME: domain: // CHECK-SAME: d0 in [1, 2] @@ -81,7 +81,7 @@ func.func private @indexing_map_small(!xla_gpu.indexed_vector<100xf64, #map>) // CHECK-SAME: d2 in [10, 12] // CHECK-SAME: s0 in [0, 32] // CHECK-SAME: > -#map = #xla_gpu.indexing_map<"(d0, d1, d2)[s0] -> (d0)," +#map = #xla.indexing_map<"(d0, d1, d2)[s0] -> (d0)," "domain:" "d0 in [1, 2]," "d1 in [5, 8]," @@ -94,13 +94,13 @@ func.func private @no_constraints(!xla_gpu.indexed_vector<32xf64, #map>) // ----- -// CHECK: #[[$INDEX_MAP:.*]] = #xla_gpu.indexing_map< +// CHECK: #[[$INDEX_MAP:.*]] = #xla.indexing_map< // CHECK-SAME: ()[s0] -> (s0) // CHECK-SAME: domain: // CHECK-SAME: s0 in [3, 5] // CHECK-SAME: s0 mod 2 in [0, 1] // CHECK-SAME: > -#map = #xla_gpu.indexing_map<"()[s0] -> (s0)," +#map = #xla.indexing_map<"()[s0] -> (s0)," "domain:" "s0 in [3, 5]," "s0 mod 2 in [0, 1]" @@ -111,13 +111,13 @@ func.func private @no_dimensions(!xla_gpu.indexed_vector<100xf64, #map>) // ----- -// CHECK: #[[$INDEX_MAP:.*]] = #xla_gpu.indexing_map< +// CHECK: #[[$INDEX_MAP:.*]] = #xla.indexing_map< // CHECK-SAME: (d0) -> (d0) // CHECK-SAME: domain: // CHECK-SAME: d0 in [3, 5] // CHECK-SAME: d0 mod 2 in [0, 1] // CHECK-SAME: > -#map = #xla_gpu.indexing_map<"(d0) -> (d0)," +#map = #xla.indexing_map<"(d0) -> (d0)," "domain:" "d0 in [3, 5]," "d0 mod 2 in [0, 1]," @@ -128,10 +128,10 @@ func.func private @no_symbols(!xla_gpu.indexed_vector<100xf64, #map>) // ----- -// CHECK: #[[$INDEX_MAP:.*]] = #xla_gpu.indexing_map< +// CHECK: #[[$INDEX_MAP:.*]] = #xla.indexing_map< // CHECK-SAME: () -> () // CHECK-SAME: > -#map = #xla_gpu.indexing_map<"() -> ()"> +#map = #xla.indexing_map<"() -> ()"> func.func private @empty(!xla_gpu.indexed_vector<100xf64, #map>) // CHECK-LABEL: @empty // CHECK: !xla_gpu.indexed_vector<100xf64, #[[$INDEX_MAP]]> diff --git a/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_attrs.cc b/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_attrs.cc index 6f10ec6d507153..16de41e05cf5c6 100644 --- a/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_attrs.cc +++ b/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_attrs.cc @@ -29,6 +29,7 @@ limitations under the License. 
#include "mlir/IR/OpImplementation.h" #include "mlir/IR/Types.h" #include "mlir/Support/LLVM.h" +#include "xla/codegen/ir/xla_ops.h" #include "xla/hlo/analysis/indexing_map.h" #include "xla/hlo/analysis/indexing_map_serialization.h" #include "xla/service/gpu/fusions/ir/xla_gpu_ops.h" @@ -36,78 +37,7 @@ limitations under the License. namespace xla { namespace gpu { -using llvm::ParseResult; -using llvm::SmallVector; -using mlir::AffineExpr; -using mlir::ArrayRef; -using mlir::AsmParser; using mlir::AsmPrinter; -using mlir::failure; -using mlir::success; - -// Parses a chain of string attributes into an indexing map. -// Example: -// "()[s0, s1] -> (1 + s0 + s1 mod 3 - s1, s0 mod 2)," -// " domain: s0 in [-10, 10], s1 in [0, 2]" -// will be parsed as 3 StringAttrs, concatenated into a single string, and then -// parsed into an IndexingMap. -std::optional parseChainOfStringsAsIndexingMap( - mlir::AsmParser& parser) { - mlir::StringAttr indexing_map_attr; - std::string indexing_map_str; - while (parser.parseOptionalAttribute(indexing_map_attr).has_value()) { - indexing_map_str.append(indexing_map_attr.getValue()); - } - return ParseIndexingMap(indexing_map_str, parser.getContext()); -} - -mlir::Attribute IndexingMapAttr::parse(mlir::AsmParser& parser, mlir::Type) { - if (parser.parseLess()) { - return {}; - } - auto indexing_map = parseChainOfStringsAsIndexingMap(parser); - if (!indexing_map.has_value() || parser.parseGreater()) { - return {}; - } - return IndexingMapAttr::get(parser.getContext(), *indexing_map); -} - -void IndexingMapAttr::print(mlir::AsmPrinter& printer) const { - printer << "<\"" << ToString(getIndexingMap()) << "\">"; -} - -IndexingMapAttr IndexingMapAttr::get(mlir::MLIRContext* context, - const IndexingMap& indexing_map) { - llvm::SmallVector> constraints; - for (auto& constraint : indexing_map.GetConstraints()) { - constraints.push_back({constraint.first, constraint.second}); - } - return get(context, indexing_map.GetAffineMap(), indexing_map.GetDimVars(), - indexing_map.GetRangeVars(), constraints); -} - -mlir::LogicalResult IndexingMapAttr::verify( - mlir::function_ref emitError, - mlir::AffineMap map, ArrayRef dim_vars, - ArrayRef range_vars, - ArrayRef> constraints) { - auto indexing_map = - IndexingMap(map, dim_vars, range_vars, /*rt_vars=*/{}, constraints); - std::stringstream ss; - if (!indexing_map.Verify(ss)) { - return emitError() << ss.str(); - } - return success(); -} - -IndexingMap IndexingMapAttr::getIndexingMap() const { - return IndexingMap(getMap(), getDimVars(), getRangeVars(), /*rt_vars=*/{}, - getConstraints()); -} - -int64_t IndexingMapAttr::getNumResults() const { - return getMap().getNumResults(); -} mlir::Attribute LayoutAttr::parse(mlir::AsmParser& parser, mlir::Type) { mlir::StringAttr memory_space_str; @@ -126,6 +56,7 @@ mlir::Attribute LayoutAttr::parse(mlir::AsmParser& parser, mlir::Type) { return {}; } auto* context = parser.getContext(); + context->getOrLoadDialect(); return LayoutAttr::get(context, MemorySpaceAttr::get(context, *memspace), IndexingMapAttr::get(context, *indexing_map)); } diff --git a/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_attrs.td b/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_attrs.td index cc01ce257b3fc1..858d4ec82278ec 100644 --- a/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_attrs.td +++ b/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_attrs.td @@ -19,49 +19,12 @@ limitations under the License. 
include "mlir/IR/AttrTypeBase.td" include "mlir/IR/EnumAttr.td" include "xla/service/gpu/fusions/ir/xla_gpu_dialect.td" +include "xla/codegen/ir/xla_attrs.td" class XLAGPU_Attr traits = []> : AttrDef { } -def XLAGPU_AffineMapParameter : - AttrOrTypeParameter<"::mlir::AffineMap", ""> { -} - -def XLAGPU_IndexingMapVariableParameter - : ArrayRefParameter<"::xla::IndexingMap::Variable", - "IndexingMapVariableArray"> { -} - -def XLAGPU_ConstraintsParameter : - ArrayRefParameter<"::std::pair<::mlir::AffineExpr, ::xla::Interval>", - "ContraintsArray"> { -} - -def XLAGPU_IndexingMapAttr : XLAGPU_Attr<"IndexingMap"> { - let summary = "An Attribute representing an indexing map."; - let mnemonic = "indexing_map"; - let description = [{This attribute stores an indexing map. See - https://openxla.org/xla/indexing for more details. - }]; - let parameters = (ins XLAGPU_AffineMapParameter:$map, - XLAGPU_IndexingMapVariableParameter:$dim_vars, - XLAGPU_IndexingMapVariableParameter:$range_vars, - XLAGPU_ConstraintsParameter:$constraints); - let hasCustomAssemblyFormat = 1; - let builders = [ - AttrBuilder<(ins "const ::xla::IndexingMap&":$indexing_map)>, - ]; - let genVerifyDecl = 1; - let extraClassDeclaration = [{ - // Returns the indexing map constructed from IndexingMapAttr. - xla::IndexingMap getIndexingMap() const; - - // Returns the number of indexing map results. - int64_t getNumResults() const; - }]; -} - //===----------------------------------------------------------------------===// // Tensor layout attribute //===----------------------------------------------------------------------===// @@ -98,5 +61,4 @@ def XLAGPU_LayoutAttr : XLAGPU_Attr<"Layout"> { let hasCustomAssemblyFormat = 1; } - #endif // XLA_SERVICE_GPU_FUSIONS_MLIR_ATTRS diff --git a/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_dialect.cc b/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_dialect.cc index c46561a98d0d45..b7ee5f43d9d68a 100644 --- a/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_dialect.cc +++ b/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_dialect.cc @@ -30,91 +30,11 @@ namespace xla { namespace gpu { namespace { -struct XlaGpuInlinerInterface : public mlir::DialectInlinerInterface { - using DialectInlinerInterface::DialectInlinerInterface; - // Returns true if the given operation 'callable', that implements the - // 'CallableOpInterface', can be inlined into the position given call - // operation 'call', that is registered to the current dialect and implements - // the `CallOpInterface`. 'wouldBeCloned' is set to true if the region of the - // given 'callable' is set to be cloned during the inlining process, or false - // if the region is set to be moved in-place (i.e. no duplicates would be - // created). - bool isLegalToInline(mlir::Operation* call, mlir::Operation* callable, - bool wouldBeCloned) const final { - if (!wouldBeCloned) { - // If no duplicate would be created, 'call' is likely the only caller of - // 'callable'. - return true; - } - // Otherwise, inline only if the called function is small. We could - // theoretically also inline if there is no other caller in the function - // that contains the callee that has a call path to the callable, but that - // is more expensive to check. - auto func_op = mlir::dyn_cast(callable); - if (!func_op) { - return false; - } - auto region = func_op.getCallableRegion(); - if (!region) { - return false; - } - - // If callee and caller call the same third function, inline. 
We have no - // guarantee that the indices are the same, but there is a good chance they - // are (or if the callee gets inlined as well, there will be CSE - // opportunities). - // This is duct tape to work around the limitations of our partitioner. - // Ideally, the partitioner would be aware of the actual indexing and create - // the partitions based on it (i.e., the case where the indices are the same - // would never happen). - llvm::SmallDenseSet callee_calls; - for (auto call : region->getOps()) { - callee_calls.insert(call.getCallee()); - } - for (auto call : call->getParentRegion()->getOps()) { - if (callee_calls.contains(call.getCallee())) { - return true; - } - } - - constexpr int kMaxOperationsToInline = 8; - int num_ops = 0; - region->front().walk([&](mlir::Operation* op) { ++num_ops; }); - - // Don't inline functions that are called more than once and contain more - // than one call themselves. - return num_ops <= kMaxOperationsToInline; - } - // Returns true if the given operation 'op', that is registered to this - // dialect, can be inlined into the given region, false otherwise. - // 'wouldBeCloned' is set to true if the given 'op' is set to be cloned - // during the inlining process, or false if the operation is set to be moved - // in-place(i.e. no duplicates would be created). 'valueMapping' contains any - // remapped values from within the 'src' region. This can be used to examine - // what values may potentially replace the operands to 'op'. - bool isLegalToInline(mlir::Operation* op, mlir::Region* dest, - bool wouldBeCloned, - mlir::IRMapping& valueMapping) const final { - // We allow any op from the xla_gpu dialect to be inlined. - return true; - } - // We don't have any special restrictions on what can be inlined into - // destination regions (e.g. while/conditional bodies). Always allow it. 
- bool isLegalToInline(mlir::Region* dest, mlir::Region* src, - bool wouldBeCloned, - mlir::IRMapping& valueMapping) const final { - return true; - } -}; struct XlaGpuOpAsmDialectInterface : public mlir::OpAsmDialectInterface { using OpAsmDialectInterface::OpAsmDialectInterface; AliasResult getAlias(mlir::Attribute attr, mlir::raw_ostream& os) const final { - if (llvm::isa(attr)) { - os << "indexing_map"; - return AliasResult::FinalAlias; - } if (llvm::isa(attr)) { os << "layout"; return AliasResult::FinalAlias; @@ -134,7 +54,7 @@ void XlaGpuDialect::initialize() { #define GET_ATTRDEF_LIST #include "xla/service/gpu/fusions/ir/xla_gpu_attrs.cc.inc" >(); - addInterfaces(); + addInterfaces(); addTypes< #define GET_TYPEDEF_LIST #include "xla/service/gpu/fusions/ir/xla_gpu_types.cc.inc" diff --git a/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_ops.cc b/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_ops.cc index 9c2ec518975e22..bdb4a8cc516fb8 100644 --- a/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_ops.cc +++ b/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_ops.cc @@ -87,28 +87,6 @@ namespace arith = mlir::arith; } // namespace -LogicalResult PureCallOp::verifySymbolUses( - mlir::SymbolTableCollection& symbolTable) { - auto callee = getCalleeAttr(); - auto function = - symbolTable.lookupNearestSymbolFrom(*this, callee); - if (!function) { - return emitError("'f' attribute refers to an undefined function: ") - << callee; - } - - int func_arg_count = function.getFunctionType().getNumInputs(); - int arg_count = getOperands().size(); - - if (arg_count != func_arg_count) { - return emitError() << "argument count mismatch: 'operands' has " - << arg_count << " arguments, but '" << callee - << "' expects " << func_arg_count; - } - - return success(); -} - //===----------------------------------------------------------------------===// // AllocateSharedOp //===----------------------------------------------------------------------===// @@ -118,846 +96,10 @@ void AllocateSharedOp::getAsmResultNames( setNameFn(getResult(), "shmem"); } -//===----------------------------------------------------------------------===// -// ApplyIndexingOp -//===----------------------------------------------------------------------===// - -void ApplyIndexingOp::build(OpBuilder& builder, OperationState& result, - ValueRange dims, ValueRange symbols, - const IndexingMap& indexing_map) { - SmallVector operands; - operands.reserve(dims.size() + symbols.size()); - operands.append(dims.begin(), dims.end()); - operands.append(symbols.begin(), symbols.end()); - build(builder, result, operands, indexing_map); -} - -void ApplyIndexingOp::build(OpBuilder& builder, OperationState& result, - ValueRange operands, - const IndexingMap& indexing_map) { - SmallVector result_types(indexing_map.GetAffineMap().getNumResults(), - builder.getIndexType()); - IndexingMapAttr indexing_map_attr = - IndexingMapAttr::get(builder.getContext(), indexing_map); - build(builder, result, result_types, operands, indexing_map_attr); -} - -void ApplyIndexingOp::build(OpBuilder& builder, OperationState& result, - ValueRange operands, AffineMap affine_map, - ArrayRef dim_vars, - ArrayRef range_vars) { - IndexingMap indexing_map(affine_map, dim_vars, range_vars, {}); - build(builder, result, operands, indexing_map); -} - -// Parses a comma-separated list of operands, ex: %d1, %d2. 
-ParseResult parseOperands( - OpAsmParser& parser, - SmallVector* operands) { - OpAsmParser::UnresolvedOperand operand; - return parser.parseCommaSeparatedList( - [&]() { return parser.parseOperand(operands->emplace_back()); }); -} - -ParseResult ApplyIndexingOp::parse(OpAsmParser& parser, - OperationState& result) { - mlir::Builder& builder = parser.getBuilder(); - auto index_type = builder.getIndexType(); - - IndexingMapAttr indexing_map_attr; - if (parser.parseAttribute(indexing_map_attr, "indexing_map_attr", - result.attributes)) { - return failure(); - } - - SmallVector operands; - SmallVector lower_bounds, upper_bounds; - if (succeeded(parser.parseOptionalLParen())) { - if (parseOperands(parser, &operands) || parser.parseRParen()) { - return failure(); - } - } - if (succeeded(parser.parseOptionalLSquare())) { - if (parseOperands(parser, &operands) || parser.parseRSquare()) { - return failure(); - } - } - if (parser.resolveOperands(operands, index_type, result.operands) || - parser.parseOptionalAttrDict(result.attributes)) { - return failure(); - } - auto map = indexing_map_attr.getIndexingMap().GetAffineMap(); - result.addTypes(SmallVector(map.getNumResults(), index_type)); - return success(); -} - -void ApplyIndexingOp::print(OpAsmPrinter& p) { - AffineMap affine_map = getIndexingMapAttr().getIndexingMap().GetAffineMap(); - p << " " << getIndexingMapAttr(); - - auto operands = getOperands(); - unsigned num_dimensions = affine_map.getNumDims(); - if (num_dimensions > 0) { - p << '('; - auto dimension_operands = operands.slice(0, num_dimensions); - llvm::interleaveComma(dimension_operands, p); - p << ')'; - } - - unsigned num_symbols = affine_map.getNumSymbols(); - if (num_symbols > 0) { - p << '['; - auto symbol_operands = operands.slice(num_dimensions, num_symbols); - llvm::interleaveComma(symbol_operands, p); - p << ']'; - } - - p.printOptionalAttrDict((*this)->getAttrs(), - /*elidedAttrs=*/{"indexing_map_attr"}); -} - -LogicalResult ApplyIndexingOp::verify() { - auto affine_map = getIndexingMapAttr().getIndexingMap().GetAffineMap(); - unsigned num_variables = affine_map.getNumDims() + affine_map.getNumSymbols(); - if (getOperands().size() != num_variables) { - return emitOpError( - "operand count must match the number of dimensions and symbols in the " - "affine map"); - } - if (!getIndexingMap().GetConstraints().empty()) { - return emitOpError("apply indexing op cannot have any constraints"); - } - return success(); -} - -IndexingMap ApplyIndexingOp::getIndexingMap() { - return getIndexingMapAttr().getIndexingMap(); -} - -namespace { -struct IndexingMapWithAdditions { - IndexingMap indexing_map; - SmallVector added_dim_args; -}; - -// Returns the new indexing map after replacing ops. -// Only supports replacing operands that are dimensions. 
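// Illustrative sketch (hypothetical maps and operands, not from the codebase):
// given a consumer apply_indexing with map "(d0, d1) -> (d0 + d1)" over
// operands (%p, %b), where %p comes from a producer apply_indexing with map
// "(d0) -> (d0 mod 4)" over %a, the producer's result expression is
// substituted into the consumer map. %a is appended as a new dimension and
// reported in `added_dim_args`, giving "(d0, d1, d2) -> (d2 mod 4 + d1)";
// the now-unused dimension for %p is expected to be cleaned up separately
// (e.g. by RemoveUnusedVariables).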
-absl::StatusOr GetNewIndexingMapAfterFoldingSequence( - IndexingMap indexing_map, - SmallVector, 2> apply_indexing_ops, - mlir::DenseMap operand_exprs, MLIRContext* ctx) { - int num_dims = indexing_map.GetDimensionCount(); - - SmallVector added_dim_args; - auto new_dim_vars = indexing_map.GetDimVars(); - - mlir::DenseMap replacements; - for (auto& [operand_id, producer] : apply_indexing_ops) { - auto producer_map = producer.getIndexingMap(); - mlir::OpResult producer_result = producer->getOpResult(0); - int producer_result_id = producer_result.getResultNumber(); - int num_producer_dims = producer.getAffineMap().getNumDims(); - SmallVector producer_dim_replacements; - for (auto& producer_operand : producer->getOpOperands()) { - int producer_operand_number = producer_operand.getOperandNumber(); - if (producer_operand_number >= num_producer_dims) { - // MoveSymbolsToDims will convert symbols to dimensions later. - return absl::Status( - absl::StatusCode::kInvalidArgument, - absl::StrCat("In apply_indexing_op ", operand_id, - " producer has operand number ", - producer_operand_number, " outside of dimensions [0,", - num_producer_dims, ")")); - } - auto& replacement_expr = operand_exprs[producer_operand.get()]; - if (!replacement_expr) { - int dim_num = producer_operand_number; - replacement_expr = - getAffineDimExpr(num_dims + added_dim_args.size(), ctx); - added_dim_args.push_back(producer_operand.get()); - new_dim_vars.push_back(producer_map.GetDimVars(dim_num)); - } - producer_dim_replacements.push_back(replacement_expr); - } - replacements[operand_exprs[producer_result]] = - producer.getAffineMap() - .getResult(producer_result_id) - .replaceDims(producer_dim_replacements); - } - - auto new_affine_map = indexing_map.GetAffineMap().replace( - replacements, num_dims + added_dim_args.size(), - indexing_map.GetSymbolCount()); - IndexingMap new_indexing_map(new_affine_map, new_dim_vars, - indexing_map.GetRangeVars(), - indexing_map.GetRTVars()); - - return IndexingMapWithAdditions{new_indexing_map, added_dim_args}; -} -} // namespace - -namespace { - -// Simplifies the indexing map. -struct SimplifyIndexingMap : public mlir::OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(ApplyIndexingOp indexing_op, - PatternRewriter& rewriter) const override { - IndexingMap indexing_map = indexing_op.getIndexingMap(); - if (!indexing_map.Simplify()) { - return rewriter.notifyMatchFailure(indexing_op, - "IndexingMap is already simplified"); - } - rewriter.replaceOpWithNewOp( - indexing_op, indexing_op.getOperands(), indexing_map); - return success(); - } -}; - -// Removes unused variables. 
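// For example (hypothetical op and maps): an apply_indexing whose map
// "(d0, d1) -> (d0)" references d1 neither in its results nor in any
// constraint can drop the unused dimension together with its operand,
//   %0 = xla_gpu.apply_indexing #map (%i, %j)    // #map:  (d0, d1) -> (d0)
// becoming
//   %0 = xla_gpu.apply_indexing #map1 (%i)       // #map1: (d0) -> (d0)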
-struct RemoveUnusedVariables : public mlir::OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(ApplyIndexingOp indexing_op, - PatternRewriter& rewriter) const override { - IndexingMap indexing_map = indexing_op.getIndexingMap(); - auto unused_symbols_bit_vector = indexing_map.RemoveUnusedVars(); - if (unused_symbols_bit_vector.count() == 0) { - return rewriter.notifyMatchFailure(indexing_op, - "IndexingMap stayed unchanged"); - } - SmallVector operands; - operands.reserve(unused_symbols_bit_vector.count()); - for (int i = 0; i < unused_symbols_bit_vector.size(); ++i) { - if (!unused_symbols_bit_vector[i]) { - operands.push_back(indexing_op.getOperand(i)); - } - } - rewriter.replaceOpWithNewOp(indexing_op, operands, - indexing_map); - return success(); - } -}; - -struct MoveSymbolsToDims : public mlir::OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(ApplyIndexingOp indexing_op, - PatternRewriter& rewriter) const override { - IndexingMap indexing_map = indexing_op.getIndexingMap(); - if (indexing_map.GetSymbolCount() == 0) { - return rewriter.notifyMatchFailure(indexing_op, "No symbols found"); - } - rewriter.replaceOpWithNewOp( - indexing_op, indexing_op->getOperands(), - indexing_map.ConvertSymbolsToDimensions()); - return success(); - } -}; - -struct FoldApplyIndexingSequence - : public mlir::OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(ApplyIndexingOp indexing_op, - PatternRewriter& rewriter) const override { - auto indexing_map = indexing_op.getIndexingMap(); - SmallVector, 2> apply_indexing_ops; - bool all_apply_indexing_operands_have_one_use = true; - for (auto& operand : indexing_op->getOpOperands()) { - if (auto producer = operand.get().getDefiningOp()) { - apply_indexing_ops.push_back({operand.getOperandNumber(), producer}); - all_apply_indexing_operands_have_one_use &= producer->hasOneUse(); - } - } - if (apply_indexing_ops.empty()) { - return rewriter.notifyMatchFailure(indexing_op, - "No apply_indexing sequences found"); - } - // If the indexing map has unused variables, we can accidentally fuse an - // operand that is not used in the map and it can lead to an infinite loop - // in canonicalizer. - auto indexing_map_with_no_unused_vars = indexing_map; - if (indexing_map_with_no_unused_vars.RemoveUnusedVars().count() > 0) { - indexing_map_with_no_unused_vars.RemoveUnusedVars(); - return rewriter.notifyMatchFailure(indexing_op, - "IndexingMap has unused variables"); - } - MLIRContext* ctx = indexing_op.getContext(); - int num_dims = indexing_op.getAffineMap().getNumDims(); - int num_syms = indexing_op.getAffineMap().getNumSymbols(); - mlir::DenseMap operand_exprs; - for (auto& operand : indexing_op->getOpOperands()) { - int operand_number = operand.getOperandNumber(); - operand_exprs[operand.get()] = - operand_number < num_dims - ? 
getAffineDimExpr(operand_number, ctx) - : getAffineSymbolExpr(operand_number - num_dims, ctx); - } - - auto replacement = GetNewIndexingMapAfterFoldingSequence( - indexing_map, apply_indexing_ops, operand_exprs, ctx); - - if (!replacement.ok()) { - return rewriter.notifyMatchFailure(indexing_op, - replacement.status().message()); - } - - if (!all_apply_indexing_operands_have_one_use && - !replacement->indexing_map.Simplify()) { - return rewriter.notifyMatchFailure( - indexing_op, "Folded indexing map was not simplified"); - } - - int new_num_operands = - indexing_op->getNumOperands() + replacement->added_dim_args.size(); - SmallVector new_operands; - new_operands.reserve(new_num_operands); - - auto begin = indexing_op.getOperands().begin(); - new_operands.append(begin, begin + num_dims); - new_operands.append(replacement->added_dim_args); - new_operands.append(begin + num_dims, begin + num_dims + num_syms); - - rewriter.replaceOpWithNewOp(indexing_op, new_operands, - replacement->indexing_map); - - return success(); - } -}; - -// Folds constants into the indexing map. -struct FoldApplyIndexingOperands - : public mlir::OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(ApplyIndexingOp indexing_op, - PatternRewriter& rewriter) const override { - IndexingMap indexing_map = indexing_op.getIndexingMap(); - AffineMap affine_map = indexing_map.GetAffineMap(); - - MLIRContext* ctx = affine_map.getContext(); - unsigned num_operands = indexing_op->getNumOperands(); - unsigned num_dims = affine_map.getNumDims(); - unsigned num_symbols = affine_map.getNumSymbols(); - - SmallVector> constant_values(num_operands, - std::nullopt); - int num_constants = 0; - for (auto& operand : indexing_op->getOpOperands()) { - if (auto constant = - operand.get().getDefiningOp()) { - constant_values[operand.getOperandNumber()] = constant.value(); - ++num_constants; - } - } - if (num_constants == 0) { - return rewriter.notifyMatchFailure(indexing_op, - "No constant operands found"); - } - SmallVector dim_replacements, symbol_replacements; - dim_replacements.reserve(num_dims); - symbol_replacements.reserve(num_symbols); - - unsigned new_num_operands = indexing_op->getNumOperands() - num_constants; - SmallVector new_operands; - new_operands.reserve(new_num_operands); - SmallVector new_dim_vars; - new_dim_vars.reserve(num_dims); - - unsigned new_num_dims = 0; - for (auto [operand, constant_value] : - llvm::zip(indexing_op->getOpOperands(), constant_values)) { - unsigned operand_id = operand.getOperandNumber(); - if (operand_id >= num_dims) { - return rewriter.notifyMatchFailure( - indexing_op, - absl::StrCat("operand ", operand_id, - " is outside of dimension range [0, ", num_dims, ")")); - } - if (constant_value.has_value()) { - dim_replacements.push_back(getAffineConstantExpr(*constant_value, ctx)); - } else { - new_operands.push_back(operand.get()); - dim_replacements.push_back(getAffineDimExpr(new_num_dims++, ctx)); - new_dim_vars.push_back(indexing_map.GetDimVars(operand_id)); - } - } - rewriter.replaceOpWithNewOp( - indexing_op, new_operands, - affine_map.replaceDimsAndSymbols(dim_replacements, {}, new_num_dims, 0), - new_dim_vars, SmallVector()); - return success(); - } -}; - -// Folds constant and dim/symbol expression results. 
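// Sketch (hypothetical map): for an apply_indexing with map
// "(d0) -> (42, d0, d0 floordiv 8)", the first result is replaced by an
// arith.constant 42, the second by the operand bound to d0, and only the
// remaining expression is kept in a smaller apply_indexing with map
// "(d0) -> (d0 floordiv 8)". Results without uses are likewise not carried
// over to the new op.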
-struct FoldApplyIndexingResults - : public mlir::OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(ApplyIndexingOp indexing_op, - PatternRewriter& rewriter) const override { - Location loc = indexing_op.getLoc(); - IndexingMap indexing_map = indexing_op.getIndexingMap(); - if (indexing_map.IsKnownEmpty()) { - return rewriter.notifyMatchFailure(indexing_op, - "Domain of the indexing map is empty"); - } - AffineMap* affine_map = &indexing_map.GetMutableAffineMap(); - unsigned num_results = affine_map->getNumResults(); - SmallVector new_exprs; - new_exprs.reserve(num_results); - SmallVector new_values; - new_values.reserve(num_results); - for (mlir::OpResult opresult : indexing_op->getOpResults()) { - if (opresult.use_empty()) { - new_values.push_back(rewriter.create(loc, 0)); - continue; - } - - unsigned id = opresult.getResultNumber(); - AffineExpr result_expr = affine_map->getResult(id); - if (auto const_expr = - mlir::dyn_cast(result_expr)) { - new_values.push_back(rewriter.create( - loc, const_expr.getValue())); - continue; - } - if (auto dim_expr = mlir::dyn_cast(result_expr)) { - new_values.push_back(indexing_op.getOperand(dim_expr.getPosition())); - continue; - } - if (auto symbol_expr = - mlir::dyn_cast(result_expr)) { - new_values.push_back(indexing_op.getOperand( - indexing_map.GetDimVarsCount() + symbol_expr.getPosition())); - continue; - } - new_exprs.push_back(result_expr); - new_values.push_back(Value{}); - } - if (new_exprs.size() == num_results) { - return rewriter.notifyMatchFailure( - indexing_op, "No constant or dim/symbol expression found"); - } - *affine_map = - AffineMap::get(affine_map->getNumDims(), affine_map->getNumSymbols(), - new_exprs, affine_map->getContext()); - auto new_indexing_op = rewriter.create( - loc, indexing_op.getOperands(), indexing_map); - for (int new_result_id = 0, new_indexing_op_result_id = 0; - new_result_id < new_values.size(); ++new_result_id) { - auto& new_value = new_values[new_result_id]; - if (new_value) continue; - new_value = new_indexing_op.getResult(new_indexing_op_result_id++); - } - rewriter.replaceOp(indexing_op, new_values); - return success(); - } -}; - -} // namespace - -void ApplyIndexingOp::getCanonicalizationPatterns( - mlir::RewritePatternSet& results, MLIRContext* context) { - results.add(context); -} - -mlir::LogicalResult ApplyIndexingOp::fold( - FoldAdaptor adaptor, llvm::SmallVectorImpl& results) { - auto map = getAffineMap(); - for (auto expr : map.getResults()) { - if (auto dim = mlir::dyn_cast(expr)) { - results.push_back(getOperand(dim.getPosition())); - } else if (auto sym = mlir::dyn_cast(expr)) { - results.push_back(getOperand(map.getNumDims() + sym.getPosition())); - } else { - results.clear(); - return failure(); - } - } - return success(); -} - -//===----------------------------------------------------------------------===// -// AtomicRMWOp -//===----------------------------------------------------------------------===// - -void AtomicRMWOp::getAsmResultNames( - llvm::function_ref setNameFn) { - setNameFn(getResult(), "atomic_rmw"); -} - -void AtomicRMWOp::build(OpBuilder& builder, OperationState& result, - Value tensor, ValueRange ivs) { - OpBuilder::InsertionGuard g(builder); - result.addOperands(tensor); - result.addOperands(ivs); - result.addTypes(tensor.getType()); - - auto tensor_type = llvm::cast(tensor.getType()); - Region* body = result.addRegion(); - builder.createBlock(body); - body->addArgument(tensor_type.getElementType(), tensor.getLoc()); -} - 
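// Folds an atomic_rmw whose body immediately yields the current value
// unchanged back to its input tensor. Sketch (hypothetical values):
//   %0 = xla_gpu.atomic_rmw %t[%i] : tensor<8xf32> {
//   ^bb0(%cur: f32):
//     xla_gpu.yield %cur : f32
//   }
// folds to %t, since the update is a no-op.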
-mlir::OpFoldResult AtomicRMWOp::fold(FoldAdaptor adaptor) { - auto* body = getBody(); - if (&body->front() == body->getTerminator() && - body->front().getOperand(0) == body->getArgument(0)) { - return getOperand(0); - } - return {}; -} - -//===----------------------------------------------------------------------===// -// PureCallOp -//===----------------------------------------------------------------------===// - -void PureCallOp::getAsmResultNames( - llvm::function_ref setNameFn) { - for (auto result : getResults()) { - setNameFn(result, "pure_call"); - } -} - -//===----------------------------------------------------------------------===// -// SyncThreadsOp -//===----------------------------------------------------------------------===// - -void SyncThreadsOp::getAsmResultNames( - llvm::function_ref setNameFn) { - for (auto result : getResults()) { - setNameFn(result, "synced_tensor"); - } -} - -//===----------------------------------------------------------------------===// -// LoopOp -//===----------------------------------------------------------------------===// - -void LoopOp::getAsmResultNames( - llvm::function_ref setNameFn) { - for (auto result : getResults()) { - setNameFn(result, "xla_loop"); - } -} - -void LoopOp::getAsmBlockArgumentNames(mlir::Region& region, - mlir::OpAsmSetValueNameFn setFn) { - // i, j, k, l, m, n, n_0, n_1, ... - char iv_name = 'i'; - for (auto iv : getInductionVars()) { - setFn(iv, std::string{iv_name}); - if (iv_name <= 'n') { - ++iv_name; - } - } - // ra, rb, rc, rd, ..., rz, rz_0, ... - std::string map_result_name = "ra"; - char map_result_char = 'a'; - for (auto map_result : getIndexingMapResults()) { - setFn(map_result, map_result_name); - if (map_result_char <= 'z') { - ++map_result_char; - map_result_name[1] = map_result_char; - } - } - for (auto iv : getRegionIterArgs()) { - setFn(iv, "iter"); - } -} - -void LoopOp::build(OpBuilder& builder, OperationState& result, - IndexingMapAttr indexing_map_attr, ValueRange dims, - ValueRange inits, BodyBuilderFn bodyBuilder) { - OpBuilder::InsertionGuard guard(builder); - - int64_t num_ivs = indexing_map_attr.getRangeVars().size(); - int64_t num_indexing_map_results = - indexing_map_attr.getIndexingMap().GetNumResults(); - int64_t num_inits = inits.size(); - result.addOperands(dims); - result.addOperands(inits); - result.addTypes(TypeRange(inits)); - Block* body_block = builder.createBlock(result.addRegion()); - // Add induction variables and indexing map results block args. - for (int i = 0, e = num_ivs + num_indexing_map_results; i < e; ++i) { - body_block->addArgument(builder.getIndexType(), result.location); - } - // Add iteration arguments block args. 
- for (auto init_type : TypeRange(inits)) { - body_block->addArguments(init_type, result.location); - } - mlir::OperationName opname(LoopOp::getOperationName(), builder.getContext()); - result.addAttribute(LoopOp::getIndexingMapAttrAttrName(opname), - indexing_map_attr); - result.addAttribute( - LoopOp::getOperandSegmentSizesAttrName(opname), - builder.getDenseI32ArrayAttr({static_cast(dims.size()), - static_cast(inits.size())})); - if (bodyBuilder) { - OpBuilder::InsertionGuard guard(builder); - builder.setInsertionPointToStart(body_block); - bodyBuilder( - builder, result.location, - body_block->getArguments().take_front(num_ivs), - body_block->getArguments().drop_front(num_ivs).drop_back(num_inits), - body_block->getArguments().take_back(num_inits)); - } -} - -void LoopOp::build(OpBuilder& builder, OperationState& result, - const IndexingMap& indexing_map, ValueRange dims, - ValueRange inits, BodyBuilderFn bodyBuilder) { - build(builder, result, - IndexingMapAttr::get(builder.getContext(), indexing_map), dims, inits, - bodyBuilder); -} - -ParseResult LoopOp::parse(OpAsmParser& parser, OperationState& result) { - SmallVector region_args, ivs, map_results, - iter_args; - SmallVector dim_operands; - - // Parse the dimension values. - auto* ctx = parser.getContext(); - OpBuilder b(ctx); - Type index_type = b.getIndexType(); - if (parser.parseOperandList(dim_operands, OpAsmParser::Delimiter::Paren) || - parser.resolveOperands(dim_operands, index_type, result.operands)) - return failure(); - // Parse the induction variables. - if (parser.parseArgumentList(ivs, OpAsmParser::Delimiter::Square)) - return failure(); - for (auto iv : ivs) { - region_args.push_back(iv); - region_args.back().type = index_type; - } - - // Parse the indexing map results variables. - if (parser.parseArrow() || - parser.parseArgumentList(map_results, OpAsmParser::Delimiter::Paren)) - return failure(); - for (auto map_result : map_results) { - region_args.push_back(map_result); - region_args.back().type = index_type; - } - - // Parse the indexing map attribute. - IndexingMapAttr indexing_map_attr; - if (parser.parseKeyword("in") || - parser.parseAttribute(indexing_map_attr, "indexing_map_attr", - result.attributes)) { - return failure(); - } - - // Parse the arguments. - SmallVector init_operands; - if (parser.parseKeyword("iter_args") || - parser.parseAssignmentList(iter_args, init_operands) || - parser.parseArrowTypeList(result.types) || - parser.resolveOperands(init_operands, result.types, parser.getNameLoc(), - result.operands)) - return failure(); - - for (auto [index, iter_arg] : llvm::enumerate(iter_args)) { - region_args.push_back(iter_arg); - region_args.back().type = result.types[index]; - } - - if (region_args.size() != - result.types.size() + ivs.size() + map_results.size()) { - return parser.emitError( - parser.getNameLoc(), - "mismatch in number of induction variables + loop-carried values + " - "number of indexing map results variables and the number of results"); - } - - // Parse the body region. - Region* body = result.addRegion(); - if (parser.parseRegion(*body, region_args)) return failure(); - LoopOp::ensureTerminator(*body, b, result.location); - - // Add the necessary attributes. 
- result.addAttribute( - LoopOp::getOperandSegmentSizeAttr(), - b.getDenseI32ArrayAttr({static_cast(dim_operands.size()), - static_cast(iter_args.size())})); - - // Parse the optional attribute list - if (parser.parseOptionalAttrDict(result.attributes)) return failure(); - - return success(); -} - -void LoopOp::print(OpAsmPrinter& p) { - p << " (" << getDims() << ")[" << getInductionVars() << "] -> (" - << getIndexingMapResults() << ") in " << getIndexingMapAttr() - << " iter_args("; - llvm::interleaveComma( - llvm::zip(getRegionIterArgs(), getInits()), p, - [&](auto it) { p << std::get<0>(it) << " = " << std::get<1>(it); }); - p << ") -> (" << getInits().getTypes() << ") "; - p.printRegion(getRegion(), /*printEntryBlockArgs=*/false, - /*printBlockTerminators=*/true); - p.printOptionalAttrDict((*this)->getAttrs(), - /*elidedAttrs=*/{ - getIndexingMapAttrAttrName(), - getOperandSegmentSizesAttrName(), - }); -} - -LogicalResult LoopOp::verify() { - if (getInits().size() != getNumResults()) { - return emitOpError("mismatch in number of loop-carried values and results"); - } - IndexingMap indexing_map = getIndexingMap(); - if (indexing_map.GetRangeVarsCount() != getNumInductionVars()) { - return emitOpError() << "mismatch in number of induction variables " - << getNumInductionVars() - << " and RangeVars in the indexing map " - << ToString(indexing_map); - } - if (indexing_map.GetDimVarsCount() != getDims().size()) { - return emitOpError() << "mismatch in number of dims operands " - << getDims().size() - << " and DimVars in the indexing map " - << ToString(indexing_map); - } - for (auto [bb_arg, result_type, init] : - llvm::zip(getRegionIterArgs(), getResultTypes(), getInits())) { - if (bb_arg.getType() != result_type || init.getType() != result_type) { - return emitOpError() << "block iter arg type = " << bb_arg.getType() - << ", result type = " << result_type - << " and init operand type = " << init.getType() - << " should match"; - } - } - return success(); -} - -IndexingMap LoopOp::getIndexingMap() { - return getIndexingMapAttr().getIndexingMap(); -} - -namespace { -// This pattern is tasked to simplify loop(apply_indexing, ...) patterns to fold -// the apply_indexing ops into the loop op where it can. The pattern assumes -// that the apply_indexing ops have already been simplified via -// MoveSymbolsToDims pattern, which basically means that the producer -// apply_indexing ops should not have any symbols. -struct SimplifyLoopOfApplyIndexing : public mlir::OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(LoopOp loop_op, - PatternRewriter& rewriter) const override { - auto loop_indexing_map = loop_op.getIndexingMap(); - MLIRContext* ctx = loop_op.getContext(); - int num_dims = loop_indexing_map.GetDimVarsCount(); - - SmallVector, 2> apply_indexing_ops; - bool all_apply_indexing_operands_have_one_use = true; - - // Only consider dims. - for (auto& operand : loop_op->getOpOperands().take_front(num_dims)) { - if (auto producer = operand.get().getDefiningOp()) { - // Producer should be canonicalized via MoveSymbolsToDims pattern. - if (producer.getIndexingMap().GetSymbolCount() > 0) { - continue; - } - apply_indexing_ops.push_back({operand.getOperandNumber(), producer}); - all_apply_indexing_operands_have_one_use &= producer->hasOneUse(); - } - } - if (apply_indexing_ops.empty()) { - return rewriter.notifyMatchFailure( - loop_op, - "No loop(apply_indexing) patterns found. 
Note that producer " - "apply_indexing should have already been simplified via " - "MoveSymbolsToDims pattern."); - } - - mlir::DenseMap operand_exprs; - for (auto& operand : loop_op->getOpOperands().take_front(num_dims)) { - int operand_number = operand.getOperandNumber(); - operand_exprs[operand.get()] = getAffineDimExpr(operand_number, ctx); - } - - auto replacement = GetNewIndexingMapAfterFoldingSequence( - loop_indexing_map, apply_indexing_ops, operand_exprs, ctx); - - if (!replacement.ok()) { - return rewriter.notifyMatchFailure(loop_op, - replacement.status().message()); - } - - if (!all_apply_indexing_operands_have_one_use && - !replacement->indexing_map.Simplify()) { - return rewriter.notifyMatchFailure( - loop_op, "Folded indexing map of the loop op was not simplified"); - } - - int new_num_dims = num_dims + replacement->added_dim_args.size(); - SmallVector aggregate_dims; - aggregate_dims.reserve(new_num_dims); - - auto begin = loop_op.getOperands().begin(); - aggregate_dims.append(begin, begin + num_dims); - aggregate_dims.append(replacement->added_dim_args); - - // Remove unused dims. - SmallVector used_dims; - used_dims.reserve(aggregate_dims.size()); - auto used_dim_bit_vector = ~replacement->indexing_map.RemoveUnusedVars(); - for (auto used_dim_idx : used_dim_bit_vector.set_bits()) { - if (used_dim_idx < new_num_dims) { - used_dims.push_back(aggregate_dims[used_dim_idx]); - } - } - - auto new_loop_op = - rewriter.create(loop_op.getLoc(), replacement->indexing_map, - used_dims, loop_op.getInits()); - Block* original_block = &loop_op.getRegion().front(); - Block* new_block = &new_loop_op.getRegion().front(); - rewriter.mergeBlocks(original_block, new_block, new_block->getArguments()); - rewriter.replaceOp(loop_op, new_loop_op.getResults()); - - return success(); - } -}; -} // namespace - -void LoopOp::getCanonicalizationPatterns(mlir::RewritePatternSet& results, - MLIRContext* context) { - results.add(context); -} - //===----------------------------------------------------------------------===// // MaterializeOp //===----------------------------------------------------------------------===// -VariableConstraints GetConstraintsForVariables(const IndexingMap& map) { - VariableConstraints result; - result.constraints_for_dims.resize(map.GetDimensionCount()); - result.constraints_for_symbols.resize(map.GetSymbolCount()); - for (const auto& constraint : map.GetConstraints()) { - constraint.first.walk([&](mlir::AffineExpr leaf) { - if (auto dim = mlir::dyn_cast(leaf)) { - result.constraints_for_dims[dim.getPosition()].insert(constraint); - } else if (auto sym = mlir::dyn_cast(leaf)) { - result.constraints_for_symbols[sym.getPosition()].insert(constraint); - } - }); - } - return result; -} - LogicalResult MaterializeOp::verify() { IndexingMap map_in = getMap().getIndexingMap(); IndexingMap map_out = @@ -1219,6 +361,17 @@ void ShuffleReduceOp::print(OpAsmPrinter& p) { p << " : " << TypeRange(getResultTypes()); } +//===----------------------------------------------------------------------===// +// SyncThreadsOp +//===----------------------------------------------------------------------===// + +void SyncThreadsOp::getAsmResultNames( + llvm::function_ref setNameFn) { + for (auto result : getResults()) { + setNameFn(result, "synced_tensor"); + } +} + } // namespace gpu } // namespace xla diff --git a/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_ops.h b/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_ops.h index 2d95ed9c710ec7..bec4116943f732 100644 --- 
a/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_ops.h +++ b/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_ops.h @@ -30,6 +30,7 @@ limitations under the License. #include "mlir/Interfaces/CallInterfaces.h" // IWYU pragma: keep #include "mlir/Interfaces/InferTypeOpInterface.h" // IWYU pragma: keep #include "mlir/Interfaces/SideEffectInterfaces.h" // IWYU pragma: keep +#include "xla/codegen/ir/xla_ops.h" #include "xla/hlo/analysis/indexing_map.h" // IWYU pragma: keep #include "xla/service/gpu/fusions/ir/xla_gpu_dialect.h.inc" #include "xla/service/gpu/fusions/ir/xla_gpu_enums.h.inc" @@ -40,16 +41,4 @@ limitations under the License. #define GET_OP_CLASSES #include "xla/service/gpu/fusions/ir/xla_gpu_ops.h.inc" -namespace xla::gpu { - -struct VariableConstraints { - llvm::SmallVector> - constraints_for_dims; - llvm::SmallVector> - constraints_for_symbols; -}; -VariableConstraints GetConstraintsForVariables(const IndexingMap& map); - -} // namespace xla::gpu - #endif // XLA_SERVICE_GPU_FUSIONS_IR_XLA_GPU_OPS_H_ diff --git a/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_ops.td b/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_ops.td index f92b2fc759b298..9c184716ffb913 100644 --- a/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_ops.td +++ b/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_ops.td @@ -68,97 +68,6 @@ def XLAGPU_SyncThreadsOp : XLAGPU_Op<"sync_threads", [ let assemblyFormat = "operands attr-dict `:` type($operands)"; } -def XLAGPU_AtomicRMWOp : XLAGPU_Op<"atomic_rmw", - [Pure, - TypesMatchWith<"result type matches type of dest", - "input", "result", "$_self">, - DeclareOpInterfaceMethods - ]> { - let summary = "Atomically updates an element of a tensor."; - - let description = [{ - Reads an element from a tensor, computes the updated value for it, and - writes back the result. - }]; - - let arguments = (ins AnyRankedTensor:$input, Variadic:$indices); - let results = (outs AnyRankedTensor:$result); - // The region takes the current value in the tensor as an argument and yields - // the updated value. - let regions = (region SizedRegion<1>:$computation); - - let skipDefaultBuilders = 1; - let builders = [OpBuilder<(ins "mlir::Value":$memref, "mlir::ValueRange":$ivs)>]; - - let extraClassDeclaration = [{ - mlir::Block* getBody() { return &getComputation().front(); } - mlir::OpBuilder getBodyBuilder() { - return mlir::OpBuilder(getBody(), std::prev(getBody()->end())); - } - // The value stored in tensor[ivs]. 
- mlir::Value getCurrentValue() { - return getRegion().getArgument(0); - } - }]; - let hasFolder = 1; - - let assemblyFormat = [{ - $input `[` $indices `]` `:` type($input) $computation attr-dict - }]; -} - -def XLAGPU_YieldOp : XLAGPU_Op<"yield", [ - ParentOneOf<["::xla::gpu::AtomicRMWOp", "::xla::gpu::LoopOp"]>, - ReturnLike, Terminator]> { - let summary = "Terminator for atomic_rmw ops."; - let arguments = (ins Variadic:$result); - - let builders = [OpBuilder<(ins), [{ /* nothing to do */ }]>]; - let assemblyFormat = "$result attr-dict `:` type($result)"; -} - -def XLAGPU_PureCallOp : XLAGPU_Op<"pure_call", - [Pure, CallOpInterface, DeclareOpInterfaceMethods, - DeclareOpInterfaceMethods - ]> { - let summary = "Function call without side effects."; - let arguments = (ins FlatSymbolRefAttr:$callee, Variadic:$operands); - let results = (outs Variadic); - let builders = [ - OpBuilder<(ins "mlir::func::FuncOp":$callee, CArg<"mlir::ValueRange", "{}">:$operands), [{ - $_state.addOperands(operands); - $_state.addAttribute("callee", mlir::SymbolRefAttr::get(callee)); - $_state.addTypes(callee.getFunctionType().getResults()); - }]>, - OpBuilder<(ins "mlir::FlatSymbolRefAttr":$callee, CArg<"mlir::ValueRange", "{}">:$operands, CArg<"llvm::ArrayRef", "{}">:$result_types), [{ - $_state.addOperands(operands); - $_state.addAttribute("callee", callee); - $_state.addTypes(result_types); - }]>, - ]; - let assemblyFormat = [{ - $callee `(` $operands `)` attr-dict `:` functional-type($operands, results) - }]; - - let extraClassDeclaration = [{ - operand_range getArgOperands() { - return getOperands(); - } - - mlir::MutableOperandRange getArgOperandsMutable() { - return getOperandsMutable(); - } - - mlir::CallInterfaceCallable getCallableForCallee() { - return (*this)->getAttrOfType("callee"); - } - - void setCalleeFromCallable(mlir::CallInterfaceCallable callee) { - (*this)->setAttr("callee", callee.get()); - } - }]; -} - def XLAGPU_ShuffleReduceOp : XLAGPU_Op<"shuffle_reduce", [Pure, CallOpInterface, TypesMatchWith<"result type matches type of operands", @@ -212,169 +121,13 @@ def XLAGPU_ShuffleReduceOp : XLAGPU_Op<"shuffle_reduce", let hasCustomAssemblyFormat = 1; } -def XLAGPU_PredicatedInsertOp : XLAGPU_Op<"predicated_insert", - [Pure, - TypesMatchWith<"result type matches type of operands", - "dest", "result", "$_self">, - TypesMatchWith<"value type matches element type of dest", - "dest", "value", - "::llvm::cast($_self).getElementType()">]> { - let summary = "Inserts a value into a tensor if a condition holds"; - let arguments = (ins I1:$condition, AnyType:$value, - AnyStaticShapeTensor:$dest, Variadic:$indices); - let results = (outs AnyStaticShapeTensor:$result); - - let assemblyFormat = [{ - $value `into` $dest `[` $indices `]` `if` $condition attr-dict `:` type($dest) - }]; -} - -def XLAGPU_PredicatedExtractOp : XLAGPU_Op<"predicated_extract", - [Pure, - TypesMatchWith<"fallback type matches element type of src", - "src", "fallback", - "::llvm::cast($_self).getElementType()">, - TypesMatchWith<"result type matches element type of src", - "src", "result", - "::llvm::cast($_self).getElementType()">]> { - let summary = "Inserts a value into a tensor if a condition holds"; - let arguments = (ins I1:$condition, AnyType:$fallback, - AnyStaticShapeTensor:$src, Variadic:$indices); - let results = (outs AnyType:$result); - - let assemblyFormat = [{ - $src `[` $indices `]` `if` $condition `else` $fallback attr-dict `:` type($src) - }]; -} - -def ApplyIndexingOp : XLAGPU_Op<"apply_indexing", [Pure]> { - let 
summary = "Applies indexing map to a list of SSA values"; - let description = [{ - The `apply_indexing` operation applies an indexing_map to a list - of SSA values, yielding a single SSA value. The number of dimension and - symbol arguments must be equal to the respective number of dimensional and - symbolic inputs in the indexing_map. The index mapping can be - multi-dimensional, and so the `apply_indexing` operation always returns one - value. The operands and results must all have ‘index’ type. - - Example: - - ```mlir - #map = #xla_gpu.indexing_map<(d0, d1)[s0] -> (d0 floordiv 8 + d1 floordiv 128, s0)> - %results:2 = xla_gpu_ops.apply_indexing #map (%0 in [0, 10], %1 in [0, 11])[%2 in [11, 32]] - ``` - }]; - let arguments = (ins Variadic:$operands, - XLAGPU_IndexingMapAttr:$indexing_map_attr); - let results = (outs Variadic); - - let builders = [ - OpBuilder<(ins "mlir::ValueRange":$dims, "mlir::ValueRange":$symbols, - "const IndexingMap&":$indexing_map)>, - OpBuilder<(ins "mlir::ValueRange":$operands, - "const IndexingMap&":$indexing_map)>, - OpBuilder<(ins "mlir::ValueRange":$operands, "mlir::AffineMap":$affine_map, - "llvm::ArrayRef":$dim_vars, - "llvm::ArrayRef":$range_vars)>, - ]; - let extraClassDeclaration = [{ - // Returns the indexing map constructed from IndexingMapAttr. - xla::IndexingMap getIndexingMap(); - // Extracts the affine map from the attribute. - mlir::AffineMap getAffineMap() { return getIndexingMapAttr().getMap(); } - }]; - let hasCustomAssemblyFormat = 1; - let hasVerifier = 1; - let hasCanonicalizer = 1; - let hasFolder = 1; -} - -def LoopOp : XLAGPU_Op<"loop", [ - AttrSizedOperandSegments, Pure, - DeclareOpInterfaceMethods, - SingleBlockImplicitTerminator<"xla::gpu::YieldOp"> - ]> { - let summary = "Loop nest that iterates over all feasible values of RangeVars."; - let description = [{ - - ```mlir - #map = #xla_gpu.indexing_map<(d0)[s0, s1] -> (d0 + s0, s1), - domain: - d0 in [0, 3], - s0 in [0, 1024], - s1 in [0, 32] - > - // Initial sum set to 0. - %sum_0 = arith.constant 0.0 : f32 - %dim = arith.constant 1 : index - // iter_args binds initial values to the loop's region arguments. - %sum = xla_gpu.loop (%dim)[%i, %j] -> (%r0, %r1) - in #map iter_args(%sum_iter = %sum_0) -> (f32) { - %t = tensor.extract %buffer[%i, %j] : tensor<1024x32xf32> - %sum_next = arith.addf %sum_iter, %t : f32 - // Yield current iteration sum to next iteration %sum_iter or to %sum - // if final iteration. - xla_gpu.yield %sum_next : f32 - } - ``` - }]; - let arguments = (ins XLAGPU_IndexingMapAttr:$indexing_map_attr, - Variadic:$dims, - Variadic:$inits); - let results = (outs Variadic); - let regions = (region SizedRegion<1>:$region); - - let builders = [ - OpBuilder<(ins "IndexingMapAttr":$indexing_map_attr, - "mlir::ValueRange":$dims, "mlir::ValueRange":$inits, - CArg<"llvm::function_ref", "nullptr">)>, - OpBuilder<(ins "const IndexingMap&":$indexing_map, - "mlir::ValueRange":$dims, "mlir::ValueRange":$inits, - CArg<"llvm::function_ref", "nullptr">)> - ]; - - let extraClassDeclaration = [{ - using BodyBuilderFn = llvm::function_ref< - void(mlir::OpBuilder&, mlir::Location, mlir::ValueRange, - mlir::ValueRange, mlir::ValueRange)>; - - // Returns the indexing map constructed from IndexingMapAttr. 
- xla::IndexingMap getIndexingMap(); - int64_t getNumInductionVars() { - return getBody()->getNumArguments() - getIndexingMapAttr().getNumResults() - - getNumResults(); - } - mlir::BlockArgument getInductionVar(int64_t index) { - return getBody()->getArgument(index); - } - mlir::Block::BlockArgListType getInductionVars() { - return getBody()->getArguments().take_front(getNumInductionVars()); - } - mlir::Block::BlockArgListType getIndexingMapResults() { - return getBody()->getArguments().drop_front(getNumInductionVars()) - .drop_back(getNumResults()); - } - mlir::Block::BlockArgListType getRegionIterArgs() { - return getBody()->getArguments().take_back(getNumResults()); - } - }]; - let hasCustomAssemblyFormat = 1; - let hasVerifier = 1; - let hasCanonicalizer = 1; -} - def XLAGPU_MaterializeOp : XLAGPU_Op<"materialize", [AttrSizedOperandSegments, CallOpInterface]> { let summary = "Reads a tensor into registers"; let arguments = (ins Variadic:$input, Variadic:$indices, FlatSymbolRefAttr:$callee, - XLAGPU_IndexingMapAttr:$map); + XLA_IndexingMapAttr:$map); let results = (outs XLAGPU_IndexedVectorType:$result); let hasVerifier = 1; let assemblyFormat = [{ @@ -403,7 +156,7 @@ def XLAGPU_InsertOp : XLAGPU_Op<"insert", [TypesMatchWith< let arguments = (ins XLAGPU_IndexedVectorType:$source, Variadic:$indices, AnyRankedTensor:$dest, - XLAGPU_IndexingMapAttr:$map); + XLA_IndexingMapAttr:$map); let results = (outs AnyRankedTensor:$result); let hasVerifier = 1; let assemblyFormat = [{ diff --git a/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_types.td b/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_types.td index 5d73344654d1de..bcb9a9a66c89df 100644 --- a/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_types.td +++ b/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_types.td @@ -20,7 +20,7 @@ include "mlir/IR/AttrTypeBase.td" include "mlir/IR/BuiltinTypes.td" include "mlir/IR/BuiltinTypeInterfaces.td" include "xla/service/gpu/fusions/ir/xla_gpu_dialect.td" -include "xla/service/gpu/fusions/ir/xla_gpu_attrs.td" +include "xla/codegen/ir/xla_attrs.td" class XLAGPU_Type traits = []> : TypeDef { @@ -33,7 +33,7 @@ def XLAGPU_IndexedVectorType : XLAGPU_Type<"IndexedVector", "indexed_vector", let parameters = (ins ArrayRefParameter<"int64_t">:$shape, "mlir::Type":$elementType, - XLAGPU_IndexingMapAttr:$indexing_map_attr + XLA_IndexingMapAttr:$indexing_map_attr ); let hasCustomAssemblyFormat = 1; let extraClassDeclaration = [{ diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/BUILD b/third_party/xla/xla/service/gpu/fusions/mlir/BUILD index 001f1db29f7e64..c4bd5923db38cb 100644 --- a/third_party/xla/xla/service/gpu/fusions/mlir/BUILD +++ b/third_party/xla/xla/service/gpu/fusions/mlir/BUILD @@ -67,6 +67,7 @@ cc_library( "//xla:shape_util", "//xla:status_macros", "//xla:xla_data_proto_cc", + "//xla/codegen/ir:xla", "//xla/hlo/analysis:indexing_analysis", "//xla/hlo/ir:hlo", "//xla/hlo/translate/hlo_to_mhlo:hlo_utils", @@ -74,7 +75,6 @@ cc_library( "//xla/mlir_hlo", "//xla/mlir_hlo:map_mhlo_to_scalar_op", "//xla/service:algorithm_util", - "//xla/service/gpu/fusions/ir:xla_gpu", "//xla/stream_executor:device_description", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", @@ -110,6 +110,7 @@ xla_cc_test( ":computation_partitioner", ":elemental_hlo_to_mlir", "//xla:status_macros", + "//xla/codegen/ir:xla", "//xla/hlo/analysis:indexing_analysis", "//xla/hlo/ir:hlo", "//xla/hlo/parser:hlo_parser", @@ -152,6 +153,7 @@ cc_library( "//xla:status_macros", 
"//xla:util", "//xla:xla_data_proto_cc", + "//xla/codegen/ir:xla", "//xla/hlo/analysis:indexing_analysis", "//xla/hlo/ir:hlo", "//xla/mlir/tools/mlir_replay/public:compiler_trace_instrumentation", diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir.cc b/third_party/xla/xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir.cc index dedf3f806747a3..8ec559156ca0bb 100644 --- a/third_party/xla/xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir.cc +++ b/third_party/xla/xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir.cc @@ -61,6 +61,7 @@ limitations under the License. #include "mlir/IR/ValueRange.h" #include "mlir/Interfaces/DataLayoutInterfaces.h" #include "mlir/Support/LLVM.h" +#include "xla/codegen/ir/xla_ops.h" #include "xla/comparison_util.h" #include "xla/hlo/analysis/indexing_analysis.h" #include "xla/hlo/analysis/indexing_map.h" @@ -74,7 +75,6 @@ limitations under the License. #include "xla/mlir_hlo/mhlo/transforms/map_mhlo_to_scalar_op.h" #include "xla/primitive_util.h" #include "xla/service/algorithm_util.h" -#include "xla/service/gpu/fusions/ir/xla_gpu_ops.h" #include "xla/service/gpu/fusions/mlir/computation_partitioner.h" #include "xla/service/gpu/fusions/mlir/type_util.h" #include "xla/shape_util.h" @@ -1579,7 +1579,7 @@ ValueRange EmitXlaLoopOp(ImplicitLocOpBuilder& b, ValueRange dim_values, } else { results = create_body(ivs, map_results, iter_args); } - nested_builder.create(loc, results); + nested_builder.create(loc, results); }; return b.create(indexing_map, dim_values, iter_args_inits, bb) .getResults(); diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir_test.cc b/third_party/xla/xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir_test.cc index 228f02d1c3f223..aba8e66e13e9f6 100644 --- a/third_party/xla/xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir_test.cc +++ b/third_party/xla/xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir_test.cc @@ -37,6 +37,7 @@ limitations under the License. 
#include "mlir/IR/MLIRContext.h" #include "mlir/Pass/PassManager.h" #include "mlir/Transforms/Passes.h" +#include "xla/codegen/ir/xla_ops.h" #include "xla/hlo/analysis/indexing_map.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/parser/hlo_parser.h" @@ -65,7 +66,8 @@ class ElementalHloToMlirTest : public HloTestBase { mlir::affine::AffineDialect, mlir::arith::ArithDialect, mlir::math::MathDialect, mlir::scf::SCFDialect, mlir::mhlo::MhloDialect, mlir::LLVM::LLVMDialect, - mlir::DLTIDialect, xla::gpu::XlaGpuDialect>(); + mlir::DLTIDialect, xla::XlaDialect, + xla::gpu::XlaGpuDialect>(); } // Converts the root subgraph of the entry function of the given hlo module to @@ -237,9 +239,9 @@ TEST_F(ElementalHloToMlirTest, ReduceWindow) { // CHECK: %[[INIT:.*]] = tensor.extract %[[ARG1]][] // CHECK: %[[RET:.*]] = scf.for %[[I:.*]] = %[[C0]] to %[[C7]] // CHECK-SAME: step %[[C1]] iter_args(%[[ACC:.*]] = %[[INIT]]) - // CHECK: %[[J0:.*]] = xla_gpu.apply_indexing #xla_gpu.indexing_map<"(d0) -> (d0 * 4), domain: d0 in [0, 2]">(%[[Y]]) - // CHECK: %[[J1:.*]] = xla_gpu.apply_indexing - // CHECK-SAME: #xla_gpu.indexing_map<"(d0, d1) -> (d0 + d1 - 3), + // CHECK: %[[J0:.*]] = xla.apply_indexing #xla.indexing_map<"(d0) -> (d0 * 4), domain: d0 in [0, 2]">(%[[Y]]) + // CHECK: %[[J1:.*]] = xla.apply_indexing + // CHECK-SAME: #xla.indexing_map<"(d0, d1) -> (d0 + d1 - 3), // CHECK-SAME: d0 in [0, 7], d1 in [0, 6]">(%[[Z]], %[[I]]) // CHECK: %[[VAL:.*]] = tensor.extract %[[ARG0]] // CHECK-SAME: [%[[X]], %[[J0]], %[[J1]]] @@ -286,8 +288,8 @@ TEST_F(ElementalHloToMlirTest, ReduceWindowWithRescaling) { // CHECK: scf.for %[[I:.*]] = %[[C0]] to %[[C4]] step %[[C1]] // If symbol rescaling wasn't working we would have a // `d1 floordiv ` in the map: - // CHECK: %[[K:.*]] = xla_gpu.apply_indexing - // CHECK-SAME: #xla_gpu.indexing_map<"(d0, d1) -> (d0 * 2 + d1), + // CHECK: %[[K:.*]] = xla.apply_indexing + // CHECK-SAME: #xla.indexing_map<"(d0, d1) -> (d0 * 2 + d1), // CHECK-SAME: d0 in [0, 18], d1 in [0, 3]">(%[[X]], %[[I]]) // CHECK: tensor.extract %[[ARG0]][%[[K]], %[[Y]], %[[Z]]] @@ -507,7 +509,7 @@ TEST_F(ElementalHloToMlirTest, Pad) { // CHECK-DAG: %[[C1:.*]] = arith.constant 1 // CHECK-DAG: %[[C4:.*]] = arith.constant 4 // CHECK-DAG: %[[C7:.*]] = arith.constant 7 - // CHECK: %[[CONSTRAINT_VAL:.*]] = xla_gpu.apply_indexing + // CHECK: %[[CONSTRAINT_VAL:.*]] = xla.apply_indexing // CHECK-SAME: <"(d0) -> ((d0 - 1) mod 2), domain: d0 in [1, 7]">(%[[X]]) // CHECK: %[[CONSTRAINT:.*]] = arith.cmpi eq, %[[CONSTRAINT_VAL]], %[[C0]] // CHECK-DAG: %[[X_L:.*]] = arith.cmpi sge, %[[X]], %[[C1]] @@ -519,9 +521,9 @@ TEST_F(ElementalHloToMlirTest, Pad) { // CHECK: %[[Y_BOUNDS:.*]] = arith.andi %[[Y_L]], %[[Y_H]] // CHECK: %[[FROM_INPUT:.*]] = arith.andi %[[X_AND_CONSTRAINT]], %[[Y_BOUNDS]] // CHECK: %[[RET:.*]] = scf.if %[[FROM_INPUT]] - // CHECK: %[[IN0:.*]] = xla_gpu.apply_indexing + // CHECK: %[[IN0:.*]] = xla.apply_indexing // CHECK-SAME: <"(d0) -> ((d0 - 1) floordiv 2), domain: d0 in [1, 7]">(%[[X]]) - // CHECK: %[[IN1:.*]] = xla_gpu.apply_indexing + // CHECK: %[[IN1:.*]] = xla.apply_indexing // CHECK-SAME: <"(d0) -> (d0 - 4), domain: d0 in [4, 7]">(%[[Y]]) // CHECK: %[[VAL:.*]] = tensor.extract %[[ARG0]][%[[IN0]], %[[IN1]]] // CHECK: scf.yield %[[VAL]] @@ -549,7 +551,7 @@ TEST_F(ElementalHloToMlirTest, PadUnsigned) { // CHECK-DAG: %[[C1:.*]] = arith.constant 1 // CHECK-DAG: %[[C4:.*]] = arith.constant 4 // CHECK-DAG: %[[C7:.*]] = arith.constant 7 - // CHECK: %[[CONSTRAINT_VAL:.*]] = xla_gpu.apply_indexing + // 
CHECK: %[[CONSTRAINT_VAL:.*]] = xla.apply_indexing // CHECK-SAME: <"(d0) -> ((d0 - 1) mod 2), domain: d0 in [1, 7]">(%[[X]]) // CHECK: %[[CONSTRAINT:.*]] = arith.cmpi eq, %[[CONSTRAINT_VAL]], %[[C0]] // CHECK-DAG: %[[X_L:.*]] = arith.cmpi sge, %[[X]], %[[C1]] @@ -561,9 +563,9 @@ TEST_F(ElementalHloToMlirTest, PadUnsigned) { // CHECK: %[[Y_BOUNDS:.*]] = arith.andi %[[Y_L]], %[[Y_H]] // CHECK: %[[FROM_INPUT:.*]] = arith.andi %[[X_AND_CONSTRAINT]], %[[Y_BOUNDS]] // CHECK: %[[RET:.*]] = scf.if %[[FROM_INPUT]] - // CHECK: %[[IN0:.*]] = xla_gpu.apply_indexing + // CHECK: %[[IN0:.*]] = xla.apply_indexing // CHECK-SAME: <"(d0) -> ((d0 - 1) floordiv 2), domain: d0 in [1, 7]">(%[[X]]) - // CHECK: %[[IN1:.*]] = xla_gpu.apply_indexing + // CHECK: %[[IN1:.*]] = xla.apply_indexing // CHECK-SAME: <"(d0) -> (d0 - 4), domain: d0 in [4, 7]">(%[[Y]]) // CHECK: %[[VAL:.*]] = tensor.extract %[[ARG0]][%[[IN0]], %[[IN1]]] // CHECK: scf.yield %[[VAL]] @@ -880,11 +882,11 @@ TEST_F(ElementalHloToMlirTest, ConvolutionSimple) { // CHECK-NEXT: %[[R1:.+]] = scf.for %[[Y:.+]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[A1:.+]] = %[[A0]]) -> (f32) { // CHECK-NEXT: %[[R2:.+]] = scf.for %[[I:.+]] = %[[C0]] to %[[C4]] step %[[C1]] iter_args(%[[ACC:.+]] = %[[A1]]) -> (f32) { // CHECK: %[[R3:.+]] = scf.if {{.+}} -> (f32) { - // CHECK: %[[XX0:.+]] = xla_gpu.apply_indexing - // CHECK-SAME: #xla_gpu.indexing_map<"(d0, d1) -> (d0 + d1), + // CHECK: %[[XX0:.+]] = xla.apply_indexing + // CHECK-SAME: #xla.indexing_map<"(d0, d1) -> (d0 + d1), // CHECK-SAME: d0 in [0, 5], d1 in [0, 2]">(%[[W]], %[[X]]) - // CHECK: %[[XX1:.+]] = xla_gpu.apply_indexing - // CHECK-SAME: #xla_gpu.indexing_map<"(d0, d1) -> (d0 + d1), + // CHECK: %[[XX1:.+]] = xla.apply_indexing + // CHECK-SAME: #xla.indexing_map<"(d0, d1) -> (d0 + d1), // CHECK-SAME: d0 in [0, 7], d1 in [0, 4]">(%[[H]], %[[Y]]) // CHECK-DAG: %[[VL:.+]] = tensor.extract %[[LHS]][%[[B]], %[[XX0]], %[[XX1]], %[[I]]] : tensor<2x8x12x4xf32> // CHECK-DAG: %[[VR:.+]] = tensor.extract %[[RHS]][%[[I]], %[[X]], %[[Y]], %[[O]]] : tensor<4x3x5x16xf32> @@ -926,11 +928,11 @@ TEST_F(ElementalHloToMlirTest, ConvolutionWithWindowStrides) { // CHECK-NEXT: %[[R1:.+]] = scf.for %[[Y:.+]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[A1:.+]] = %[[A0]]) -> (f32) { // CHECK-NEXT: %[[R2:.+]] = scf.for %[[I:.+]] = %[[C0]] to %[[C4]] step %[[C1]] iter_args(%[[ACC:.+]] = %[[A1]]) -> (f32) { // CHECK: %[[R3:.+]] = scf.if {{.+}} -> (f32) { - // CHECK: %[[XX0:.+]] = xla_gpu.apply_indexing - // CHECK-SAME: #xla_gpu.indexing_map<"(d0, d1) -> (d0 * 2 + d1), + // CHECK: %[[XX0:.+]] = xla.apply_indexing + // CHECK-SAME: #xla.indexing_map<"(d0, d1) -> (d0 * 2 + d1), // CHECK-SAME: d0 in [0, 2], d1 in [0, 2]">(%[[W]], %[[X]]) - // CHECK: %[[XX1:.+]] = xla_gpu.apply_indexing - // CHECK-SAME: #xla_gpu.indexing_map<"(d0, d1) -> (d0 * 2 + d1), + // CHECK: %[[XX1:.+]] = xla.apply_indexing + // CHECK-SAME: #xla.indexing_map<"(d0, d1) -> (d0 * 2 + d1), // CHECK-SAME: d0 in [0, 3], d1 in [0, 4]">(%[[H]], %[[Y]]) // CHECK-DAG: %[[VL:.+]] = tensor.extract %[[LHS]][%[[B]], %[[XX0]], %[[XX1]], %[[I]]] : tensor<2x8x12x4xf32> // CHECK-DAG: %[[VR:.+]] = tensor.extract %[[RHS]][%[[I]], %[[X]], %[[Y]], %[[O]]] : tensor<4x3x5x16xf32> @@ -974,20 +976,20 @@ TEST_F(ElementalHloToMlirTest, ConvolutionWithPadding) { // CHECK: %[[R0:.+]] = scf.for %[[X:.+]] = %[[C0]] to %[[C3]] step %[[C1]] iter_args(%[[A0:.+]] = %[[INIT]]) -> (f32) { // CHECK-NEXT: %[[R1:.+]] = scf.for %[[Y:.+]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[A1:.+]] = %[[A0]]) 
-> (f32) { // CHECK-NEXT: %[[R2:.+]] = scf.for %[[I:.+]] = %[[C0]] to %[[C4]] step %[[C1]] iter_args(%[[ACC:.+]] = %[[A1]]) -> (f32) { - // CHECK-DAG: %[[TESTX:.+]] = xla_gpu.apply_indexing #xla_gpu.indexing_map<"(d0, d1) -> (d0 + d1), domain: d0 in [0, 7], d1 in [0, 2]">(%[[W]], %[[X]]) + // CHECK-DAG: %[[TESTX:.+]] = xla.apply_indexing #xla.indexing_map<"(d0, d1) -> (d0 + d1), domain: d0 in [0, 7], d1 in [0, 2]">(%[[W]], %[[X]]) // CHECK-DAG: %[[TXGE:.+]] = arith.cmpi sge, %[[TESTX]], %[[C1]] : index // CHECK-DAG: %[[TXLE:.+]] = arith.cmpi sle, %[[TESTX]], %[[C8]] : index // CHECK-DAG: %[[TX:.+]] = arith.andi %[[TXGE]], %[[TXLE]] : i1 - // CHECK-DAG: %[[TESTY:.+]] = xla_gpu.apply_indexing #xla_gpu.indexing_map<"(d0, d1) -> (d0 + d1), domain: d0 in [0, 11], d1 in [0, 4]">(%[[H]], %[[Y]]) + // CHECK-DAG: %[[TESTY:.+]] = xla.apply_indexing #xla.indexing_map<"(d0, d1) -> (d0 + d1), domain: d0 in [0, 11], d1 in [0, 4]">(%[[H]], %[[Y]]) // CHECK-DAG: %[[TYGE:.+]] = arith.cmpi sge, %[[TESTY]], %[[C2]] : index // CHECK-DAG: %[[TYLE:.+]] = arith.cmpi sle, %[[TESTY]], %[[C13]] : index // CHECK-DAG: %[[TY:.+]] = arith.andi %[[TYGE]], %[[TYLE]] : i1 // CHECK: %[[R3:.+]] = scf.if {{.+}} -> (f32) { - // CHECK: %[[XX0:.+]] = xla_gpu.apply_indexing - // CHECK-SAME: #xla_gpu.indexing_map<"(d0, d1) -> (d0 + d1 - 1), + // CHECK: %[[XX0:.+]] = xla.apply_indexing + // CHECK-SAME: #xla.indexing_map<"(d0, d1) -> (d0 + d1 - 1), // CHECK-SAME: d0 in [0, 7], d1 in [0, 2]">(%[[W]], %[[X]]) - // CHECK: %[[XX1:.+]] = xla_gpu.apply_indexing - // CHECK-SAME: #xla_gpu.indexing_map<"(d0, d1) -> (d0 + d1 - 2), + // CHECK: %[[XX1:.+]] = xla.apply_indexing + // CHECK-SAME: #xla.indexing_map<"(d0, d1) -> (d0 + d1 - 2), // CHECK-SAME: d0 in [0, 11], d1 in [0, 4]">(%[[H]], %[[Y]]) // CHECK-DAG: %[[VL:.+]] = tensor.extract %[[LHS]][%[[B]], %[[XX0]], %[[XX1]], %[[I]]] : tensor<2x8x12x4xf32> // CHECK-DAG: %[[VR:.+]] = tensor.extract %[[RHS]][%[[I]], %[[X]], %[[Y]], %[[O]]] : tensor<4x3x5x16xf32> @@ -1028,16 +1030,16 @@ TEST_F(ElementalHloToMlirTest, ConvolutionWithLhsDilation) { // CHECK: %[[R0:.+]] = scf.for %[[X:.+]] = %[[C0]] to %[[C3]] step %[[C1]] iter_args(%[[A0:.+]] = %[[INIT]]) -> (f32) { // CHECK-NEXT: %[[R1:.+]] = scf.for %[[Y:.+]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[A1:.+]] = %[[A0]]) -> (f32) { // CHECK-NEXT: %[[R2:.+]] = scf.for %[[I:.+]] = %[[C0]] to %[[C4]] step %[[C1]] iter_args(%[[ACC:.+]] = %[[A1]]) -> (f32) { - // CHECK-DAG: %[[TESTX:.+]] = xla_gpu.apply_indexing #xla_gpu.indexing_map<"(d0, d1) -> ((d0 + d1) mod 2), domain: d0 in [0, 12], d1 in [0, 2]">(%[[W]], %[[X]]) + // CHECK-DAG: %[[TESTX:.+]] = xla.apply_indexing #xla.indexing_map<"(d0, d1) -> ((d0 + d1) mod 2), domain: d0 in [0, 12], d1 in [0, 2]">(%[[W]], %[[X]]) // CHECK-DAG: %[[TX:.+]] = arith.cmpi eq, %[[TESTX]], %[[C0]] : index - // CHECK-DAG: %[[TESTY:.+]] = xla_gpu.apply_indexing #xla_gpu.indexing_map<"(d0, d1) -> ((d0 + d1) mod 2), domain: d0 in [0, 18], d1 in [0, 4]">(%[[H]], %[[Y]]) + // CHECK-DAG: %[[TESTY:.+]] = xla.apply_indexing #xla.indexing_map<"(d0, d1) -> ((d0 + d1) mod 2), domain: d0 in [0, 18], d1 in [0, 4]">(%[[H]], %[[Y]]) // CHECK-DAG: %[[TY:.+]] = arith.cmpi eq, %[[TESTY]], %[[C0]] : index // CHECK: %[[R3:.+]] = scf.if {{.+}} -> (f32) { - // CHECK: %[[XX0:.+]] = xla_gpu.apply_indexing - // CHECK-SAME: #xla_gpu.indexing_map<"(d0, d1) -> ((d0 + d1) floordiv 2), + // CHECK: %[[XX0:.+]] = xla.apply_indexing + // CHECK-SAME: #xla.indexing_map<"(d0, d1) -> ((d0 + d1) floordiv 2), // CHECK-SAME: d0 in [0, 12], d1 in [0, 
2]">(%[[W]], %[[X]]) - // CHECK: %[[XX1:.+]] = xla_gpu.apply_indexing - // CHECK-SAME: #xla_gpu.indexing_map<"(d0, d1) -> ((d0 + d1) floordiv 2), + // CHECK: %[[XX1:.+]] = xla.apply_indexing + // CHECK-SAME: #xla.indexing_map<"(d0, d1) -> ((d0 + d1) floordiv 2), // CHECK-SAME: d0 in [0, 18], d1 in [0, 4]">(%[[H]], %[[Y]]) // CHECK-DAG: %[[VL:.+]] = tensor.extract %[[LHS]][%[[B]], %[[XX0]], %[[XX1]], %[[I]]] : tensor<2x8x12x4xf32> // CHECK-DAG: %[[VR:.+]] = tensor.extract %[[RHS]][%[[I]], %[[X]], %[[Y]], %[[O]]] : tensor<4x3x5x16xf32> @@ -1079,11 +1081,11 @@ TEST_F(ElementalHloToMlirTest, ConvolutionWithRhsDilation) { // CHECK-NEXT: %[[R1:.+]] = scf.for %[[Y:.+]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[A1:.+]] = %[[A0]]) -> (f32) { // CHECK-NEXT: %[[R2:.+]] = scf.for %[[I:.+]] = %[[C0]] to %[[C4]] step %[[C1]] iter_args(%[[ACC:.+]] = %[[A1]]) -> (f32) { // CHECK: %[[R3:.+]] = scf.if {{.+}} -> (f32) { - // CHECK: %[[XX0:.+]] = xla_gpu.apply_indexing - // CHECK-SAME: #xla_gpu.indexing_map<"(d0, d1) -> (d1 * 2 + d0), + // CHECK: %[[XX0:.+]] = xla.apply_indexing + // CHECK-SAME: #xla.indexing_map<"(d0, d1) -> (d1 * 2 + d0), // CHECK-SAME: d0 in [0, 3], d1 in [0, 2]">(%[[W]], %[[X]]) - // CHECK: %[[XX1:.+]] = xla_gpu.apply_indexing - // CHECK-SAME: #xla_gpu.indexing_map<"(d0, d1) -> (d1 * 2 + d0), + // CHECK: %[[XX1:.+]] = xla.apply_indexing + // CHECK-SAME: #xla.indexing_map<"(d0, d1) -> (d1 * 2 + d0), // CHECK-SAME: d0 in [0, 3], d1 in [0, 4]">(%[[H]], %[[Y]]) // CHECK-DAG: %[[VL:.+]] = tensor.extract %[[LHS]][%[[B]], %[[XX0]], %[[XX1]], %[[I]]] : tensor<2x8x12x4xf32> // CHECK-DAG: %[[VR:.+]] = tensor.extract %[[RHS]][%[[I]], %[[X]], %[[Y]], %[[O]]] : tensor<4x3x5x16xf32> @@ -1125,14 +1127,14 @@ TEST_F(ElementalHloToMlirTest, ConvolutionWithFeatureGroupCount) { // CHECK-NEXT: %[[R1:.+]] = scf.for %[[Y:.+]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[A1:.+]] = %[[A0]]) -> (f32) { // CHECK-NEXT: %[[R2:.+]] = scf.for %[[I:.+]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ACC:.+]] = %[[A1]]) -> (f32) { // CHECK: %[[R3:.+]] = scf.if {{.+}} -> (f32) { - // CHECK: %[[XX0:.+]] = xla_gpu.apply_indexing - // CHECK-SAME: #xla_gpu.indexing_map<"(d0, d1) -> (d0 + d1), + // CHECK: %[[XX0:.+]] = xla.apply_indexing + // CHECK-SAME: #xla.indexing_map<"(d0, d1) -> (d0 + d1), // CHECK-SAME: d0 in [0, 5], d1 in [0, 2]">(%[[W]], %[[X]]) - // CHECK: %[[XX1:.+]] = xla_gpu.apply_indexing - // CHECK-SAME: #xla_gpu.indexing_map<"(d0, d1) -> (d0 + d1), + // CHECK: %[[XX1:.+]] = xla.apply_indexing + // CHECK-SAME: #xla.indexing_map<"(d0, d1) -> (d0 + d1), // CHECK-SAME: d0 in [0, 7], d1 in [0, 4]">(%[[H]], %[[Y]]) - // CHECK: %[[XX2:.+]] = xla_gpu.apply_indexing - // CHECK-SAME: #xla_gpu.indexing_map<"(d0, d1) -> ((d0 floordiv 8) * 2 + d1), + // CHECK: %[[XX2:.+]] = xla.apply_indexing + // CHECK-SAME: #xla.indexing_map<"(d0, d1) -> ((d0 floordiv 8) * 2 + d1), // CHECK-SAME: d0 in [0, 15], d1 in [0, 1]">(%[[O]], %[[I]]) // CHECK-DAG: %[[VL:.+]] = tensor.extract %[[LHS]][%[[B]], %[[XX0]], %[[XX1]], %[[XX2]]] : tensor<2x8x12x4xf32> // CHECK-DAG: %[[VR:.+]] = tensor.extract %[[RHS]][%[[I]], %[[X]], %[[Y]], %[[O]]] : tensor<2x3x5x16xf32> @@ -1176,11 +1178,11 @@ TEST_F(ElementalHloToMlirTest, ConvolutionWithBatchGroupCount) { // CHECK-NEXT: %[[R2:.+]] = scf.for %[[I:.+]] = %[[C0]] to %[[C4]] step %[[C1]] iter_args(%[[A2:.+]] = %[[A1]]) -> (f32) { // CHECK-NEXT: %[[R3:.+]] = scf.for %[[G:.+]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ACC:.+]] = %[[A2]]) -> (f32) { // CHECK: %[[R4:.+]] = scf.if {{.+}} -> (f32) { 
- // CHECK: %[[XX0:.+]] = xla_gpu.apply_indexing - // CHECK-SAME: #xla_gpu.indexing_map<"(d0, d1) -> (d0 + d1), + // CHECK: %[[XX0:.+]] = xla.apply_indexing + // CHECK-SAME: #xla.indexing_map<"(d0, d1) -> (d0 + d1), // CHECK-SAME: d0 in [0, 5], d1 in [0, 2]">(%[[W]], %[[X]]) - // CHECK: %[[XX1:.+]] = xla_gpu.apply_indexing - // CHECK-SAME: #xla_gpu.indexing_map<"(d0, d1) -> (d0 + d1), + // CHECK: %[[XX1:.+]] = xla.apply_indexing + // CHECK-SAME: #xla.indexing_map<"(d0, d1) -> (d0 + d1), // CHECK-SAME: d0 in [0, 7], d1 in [0, 4]">(%[[H]], %[[Y]]) // CHECK-DAG: %[[VL:.+]] = tensor.extract %[[LHS]][%[[G]], %[[XX0]], %[[XX1]], %[[I]]] : tensor<2x8x12x4xf32> // CHECK-DAG: %[[VR:.+]] = tensor.extract %[[RHS]][%[[I]], %[[X]], %[[Y]], %[[O]]] : tensor<4x3x5x16xf32> @@ -1715,8 +1717,8 @@ TEST_F(ElementalHloToMlirTest, MixedIndexingTuple) { // CHECK-SAME: %[[P1:.*]]: tensor<100xf32>, // CHECK-SAME: %[[X:.*]]: index {{{.*}}}, %[[Y:.*]]: index {{{.*}}} // CHECK: %[[A:.*]] = tensor.extract %[[P0]][%[[X]], %[[Y]]] - // CHECK: %[[IDX:.*]] = xla_gpu.apply_indexing - // CHECK-SAME: #xla_gpu.indexing_map<"(d0, d1) -> (d0 * 10 + d1), + // CHECK: %[[IDX:.*]] = xla.apply_indexing + // CHECK-SAME: #xla.indexing_map<"(d0, d1) -> (d0 * 10 + d1), // CHECK-SAME: d0 in [0, 9], d1 in [0, 9]">(%[[X]], %[[Y]]) // CHECK: %[[B:.*]] = tensor.extract %[[P1]][%[[IDX]]] // CHECK: return %[[A]], %[[B]] @@ -1738,11 +1740,11 @@ TEST_F(ElementalHloToMlirTest, NestedTuple) { // CHECK-SAME: %[[P0:.*]]: tensor<10x10xf32>, // CHECK-SAME: %[[P1:.*]]: tensor<100xf32>, // CHECK-SAME: %[[X:.*]]: index {{{.*}}}, %[[Y:.*]]: index {{{.*}}} - // CHECK: %[[P0_V:.*]] = xla_gpu.pure_call @main_p0 + // CHECK: %[[P0_V:.*]] = xla.pure_call @main_p0 // CHECK: %[[IDX:.*]] = - // CHECK-SAME: #xla_gpu.indexing_map<"(d0, d1) -> (d0 * 10 + d1), + // CHECK-SAME: #xla.indexing_map<"(d0, d1) -> (d0 * 10 + d1), // CHECK-SAME: d0 in [0, 9], d1 in [0, 9]">(%[[X]], %[[Y]]) - // CHECK: %[[P1_V:.*]] = xla_gpu.pure_call @main_p1 + // CHECK: %[[P1_V:.*]] = xla.pure_call @main_p1 // CHECK-SAME: (%[[P0]], %[[P1]], %[[IDX]]) // CHECK: return %[[P0_V]], %[[P1_V]], %[[P1_V]], %[[P1_V]], %[[P0_V]] )")); diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/mlir_fusion_emitter.cc b/third_party/xla/xla/service/gpu/fusions/mlir/mlir_fusion_emitter.cc index e1ce7cc4c56df2..712e060ab71dfc 100644 --- a/third_party/xla/xla/service/gpu/fusions/mlir/mlir_fusion_emitter.cc +++ b/third_party/xla/xla/service/gpu/fusions/mlir/mlir_fusion_emitter.cc @@ -77,6 +77,7 @@ limitations under the License. 
#include "mlir/Target/LLVMIR/Dialect/ROCDL/ROCDLToLLVMIRTranslation.h" #include "mlir/Target/LLVMIR/Export.h" #include "mlir/Transforms/Passes.h" +#include "xla/codegen/ir/xla_ops.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_opcode.h" @@ -329,13 +330,13 @@ MlirFusionEmitterBase::CreateMLIRModule( const std::string& entry_function_name, const BufferAssignment* buffer_assignment, mlir::interpreter::MlirCompilationTrace* trace) const { - context.loadDialect<mlir::DLTIDialect, mlir::NVVM::NVVMDialect, mlir::ROCDL::ROCDLDialect, mlir::affine::AffineDialect, mlir::arith::ArithDialect, mlir::cf::ControlFlowDialect, mlir::func::FuncDialect, mlir::gpu::GPUDialect, mlir::math::MathDialect, mlir::mhlo::MhloDialect, mlir::scf::SCFDialect, mlir::tensor::TensorDialect, mlir::vector::VectorDialect, xla::gpu::XlaGpuDialect>(); + context.loadDialect< + mlir::DLTIDialect, mlir::NVVM::NVVMDialect, mlir::ROCDL::ROCDLDialect, + mlir::affine::AffineDialect, mlir::arith::ArithDialect, + mlir::cf::ControlFlowDialect, mlir::func::FuncDialect, + mlir::gpu::GPUDialect, mlir::math::MathDialect, mlir::mhlo::MhloDialect, + mlir::scf::SCFDialect, mlir::tensor::TensorDialect, + mlir::vector::VectorDialect, xla::XlaDialect, xla::gpu::XlaGpuDialect>(); mlir::DialectRegistry registry; mlir::LLVM::registerInlinerInterface(registry); mlir::func::registerInlinerExtension(registry); diff --git a/third_party/xla/xla/service/gpu/fusions/scatter_mlir.cc b/third_party/xla/xla/service/gpu/fusions/scatter_mlir.cc index e8cf951a532edf..00d7e9735eb7cb 100644 --- a/third_party/xla/xla/service/gpu/fusions/scatter_mlir.cc +++ b/third_party/xla/xla/service/gpu/fusions/scatter_mlir.cc @@ -35,6 +35,7 @@ limitations under the License. #include "mlir/IR/Value.h" #include "mlir/IR/ValueRange.h" #include "mlir/Support/LLVM.h" +#include "xla/codegen/ir/xla_ops.h" #include "xla/hlo/analysis/indexing_map.h" #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_instruction.h" @@ -170,7 +171,7 @@ mlir::Value EmitScatterComputation( auto reduced_val = mlir_converter::InlineBlock( body_builder, reducer.getBody().front(), {atomic_rmw.getCurrentValue(), update_elem})[0]; - body_builder.create<YieldOp>(reducer->getLoc(), reduced_val); + body_builder.create<xla::YieldOp>(reducer->getLoc(), reduced_val); return atomic_rmw->getResult(0); } diff --git a/third_party/xla/xla/service/gpu/fusions/tests/BUILD b/third_party/xla/xla/service/gpu/fusions/tests/BUILD index e6f812bd158a1e..def0e86cdc4e10 100644 --- a/third_party/xla/xla/service/gpu/fusions/tests/BUILD +++ b/third_party/xla/xla/service/gpu/fusions/tests/BUILD @@ -18,8 +18,8 @@ lit_test_suite( exec_properties = tf_exec_properties({"tags": tf_cuda_tests_tags()}), hermetic_cuda_data_dir = "%S/../../../../../../../cuda_nvcc", tools = [ + "//xla/codegen/tools:emitters_opt", "//xla/service/gpu/fusions/tools:fusion_to_mlir", - "//xla/service/gpu/fusions/tools:mlir_fusions_opt", "//xla/service/gpu/fusions/tools:test_correctness", "@llvm-project//llvm:FileCheck", ], diff --git a/third_party/xla/xla/service/gpu/fusions/tests/concatenate/concat_1d.hlo b/third_party/xla/xla/service/gpu/fusions/tests/concatenate/concat_1d.hlo index 875bb871d287e3..d289ab87cf1fd7 100644 --- a/third_party/xla/xla/service/gpu/fusions/tests/concatenate/concat_1d.hlo +++ b/third_party/xla/xla/service/gpu/fusions/tests/concatenate/concat_1d.hlo @@ -1,4 +1,4 @@ -// RUN: fusion_to_mlir %s | mlir_fusions_opt -xla-gpu-test-optimize |\ +// RUN: fusion_to_mlir %s | emitters_opt -xla-gpu-test-optimize |\ // RUN: FileCheck %s // RUN: test_correctness %s @@ -8,10 +8,10 @@ fusion { param2 = f32[300] parameter(2) ROOT concat = f32[900] concatenate(param0, param1, param2), dimensions={0} } -// CHECK-DAG: #[[MAP:.*]] = #xla_gpu.indexing_map<"(th_x, bl_x) -> (bl_x * 128 + th_x) -// CHECK-DAG: #[[LOOPMAP_1:.*]] = #xla_gpu.indexing_map<"(th_x, th_y, th_z, 
bl_x, bl_y, bl_z)[s0, s1] -> (bl_x * 128 + th_x) -// CHECK-DAG: #[[LOOPMAP_2:.*]] = #xla_gpu.indexing_map<"(th_x, th_y, th_z, bl_x, bl_y, bl_z)[s0, s1] -> (bl_x * 128 + th_x + 200) -// CHECK-DAG: #[[LOOPMAP_3:.*]] = #xla_gpu.indexing_map<"(th_x, th_y, th_z, bl_x, bl_y, bl_z)[s0, s1] -> (bl_x * 128 + th_x + 600) +// CHECK-DAG: #[[MAP:.*]] = #xla.indexing_map<"(th_x, bl_x) -> (bl_x * 128 + th_x) +// CHECK-DAG: #[[LOOPMAP_1:.*]] = #xla.indexing_map<"(th_x, th_y, th_z, bl_x, bl_y, bl_z)[s0, s1] -> (bl_x * 128 + th_x) +// CHECK-DAG: #[[LOOPMAP_2:.*]] = #xla.indexing_map<"(th_x, th_y, th_z, bl_x, bl_y, bl_z)[s0, s1] -> (bl_x * 128 + th_x + 200) +// CHECK-DAG: #[[LOOPMAP_3:.*]] = #xla.indexing_map<"(th_x, th_y, th_z, bl_x, bl_y, bl_z)[s0, s1] -> (bl_x * 128 + th_x + 600) // CHECK: func.func @main // CHECK-SAME: %[[ARG_0:[a-zA-Z0-9]*]]: {{[^,]*}}, @@ -22,17 +22,17 @@ fusion { // CHECK: %[[THREAD_ID:.*]] = gpu.thread_id x // CHECK: %[[BLOCK_ID:.*]] = gpu.block_id x -// CHECK: xla_gpu.loop ({{.*}})[{{.*}}] -> (%[[RA:.*]]) in #[[LOOPMAP_1]] -// CHECK: xla_gpu.apply_indexing #[[MAP]] -// CHECK: %[[VAL_1:.*]] = xla_gpu.pure_call @fusion_param0 +// CHECK: xla.loop ({{.*}})[{{.*}}] -> (%[[RA:.*]]) in #[[LOOPMAP_1]] +// CHECK: xla.apply_indexing #[[MAP]] +// CHECK: %[[VAL_1:.*]] = xla.pure_call @fusion_param0 // CHECK: %[[INSERTED_1:.*]] = tensor.insert %[[VAL_1:.*]] into {{.*}}[%[[RA]]] -// CHECK: xla_gpu.loop ({{.*}})[{{.*}}] -> (%[[RA:.*]]) in #[[LOOPMAP_2]] -// CHECK: xla_gpu.apply_indexing #[[MAP]] -// CHECK: %[[VAL_2:.*]] = xla_gpu.pure_call @fusion_param1 +// CHECK: xla.loop ({{.*}})[{{.*}}] -> (%[[RA:.*]]) in #[[LOOPMAP_2]] +// CHECK: xla.apply_indexing #[[MAP]] +// CHECK: %[[VAL_2:.*]] = xla.pure_call @fusion_param1 // CHECK: %[[INSERTED_2:.*]] = tensor.insert %[[VAL_2:.*]] into {{.*}}[%[[RA]]] -// CHECK: xla_gpu.loop ({{.*}})[{{.*}}] -> (%[[RA:.*]]) in #[[LOOPMAP_3]] -// CHECK: xla_gpu.apply_indexing #[[MAP]] -// CHECK: %[[VAL_3:.*]] = xla_gpu.pure_call @fusion_param2 +// CHECK: xla.loop ({{.*}})[{{.*}}] -> (%[[RA:.*]]) in #[[LOOPMAP_3]] +// CHECK: xla.apply_indexing #[[MAP]] +// CHECK: %[[VAL_3:.*]] = xla.pure_call @fusion_param2 // CHECK: %[[INSERTED_3:.*]] = tensor.insert %[[VAL_3:.*]] into {{.*}}[%[[RA]]] diff --git a/third_party/xla/xla/service/gpu/fusions/tests/concatenate/prologue_epilogue.hlo b/third_party/xla/xla/service/gpu/fusions/tests/concatenate/prologue_epilogue.hlo index bb9e0c02c13c50..c7d94586b45517 100644 --- a/third_party/xla/xla/service/gpu/fusions/tests/concatenate/prologue_epilogue.hlo +++ b/third_party/xla/xla/service/gpu/fusions/tests/concatenate/prologue_epilogue.hlo @@ -1,4 +1,4 @@ -// RUN: fusion_to_mlir %s | mlir_fusions_opt -xla-gpu-test-optimize |\ +// RUN: fusion_to_mlir %s | emitters_opt -xla-gpu-test-optimize |\ // RUN: FileCheck %s // RUN: test_correctness %s @@ -12,12 +12,12 @@ fusion { } // CHECK: %[[THREAD_ID:.*]] = gpu.thread_id x -// CHECK: xla_gpu.loop ({{.*}})[{{.*}}] -> (%[[RA:.*]]) in -// CHECK: %[[VAL_1_1:.*]] = xla_gpu.pure_call @fusion_log({{.*}}, %[[THREAD_ID]]) -// CHECK: %[[VAL_1_2:.*]] = xla_gpu.pure_call @fusion__epilogue__neg({{.*}}, %[[RA]], %[[VAL_1_1]]) +// CHECK: xla.loop ({{.*}})[{{.*}}] -> (%[[RA:.*]]) in +// CHECK: %[[VAL_1_1:.*]] = xla.pure_call @fusion_log({{.*}}, %[[THREAD_ID]]) +// CHECK: %[[VAL_1_2:.*]] = xla.pure_call @fusion__epilogue__neg({{.*}}, %[[RA]], %[[VAL_1_1]]) // CHECK: tensor.insert %[[VAL_1_2:.*]] into {{.*}}[%[[RA]]] -// CHECK: xla_gpu.loop ({{.*}})[{{.*}}] -> (%[[RA:.*]]) in -// CHECK: %[[VAL_2_1:.*]] = 
xla_gpu.pure_call @fusion_exp({{.*}}, %[[THREAD_ID]]) -// CHECK: %[[VAL_2_2:.*]] = xla_gpu.pure_call @fusion__epilogue__neg({{.*}}, %[[RA]], %[[VAL_2_1]]) +// CHECK: xla.loop ({{.*}})[{{.*}}] -> (%[[RA:.*]]) in +// CHECK: %[[VAL_2_1:.*]] = xla.pure_call @fusion_exp({{.*}}, %[[THREAD_ID]]) +// CHECK: %[[VAL_2_2:.*]] = xla.pure_call @fusion__epilogue__neg({{.*}}, %[[RA]], %[[VAL_2_1]]) // CHECK: tensor.insert %[[VAL_2_2:.*]] into {{.*}}[%[[RA]]] diff --git a/third_party/xla/xla/service/gpu/fusions/tests/concatenate/vectorization.hlo b/third_party/xla/xla/service/gpu/fusions/tests/concatenate/vectorization.hlo index e46fe1f8b65ccd..f9518777447092 100644 --- a/third_party/xla/xla/service/gpu/fusions/tests/concatenate/vectorization.hlo +++ b/third_party/xla/xla/service/gpu/fusions/tests/concatenate/vectorization.hlo @@ -1,4 +1,4 @@ -// RUN: fusion_to_mlir %s | mlir_fusions_opt -xla-gpu-test-optimize |\ +// RUN: fusion_to_mlir %s | emitters_opt -xla-gpu-test-optimize |\ // RUN: FileCheck %s // RUN: test_correctness %s @@ -7,4 +7,4 @@ fusion { param1 = f32[640000] parameter(1) ROOT concat = f32[1280002] concatenate(param0, param1), dimensions={0} } -// CHECK-COUNT-2: xla_gpu.loop ({{.*}})[%{{.*}}, %{{.*}}] -> \ No newline at end of file +// CHECK-COUNT-2: xla.loop ({{.*}})[%{{.*}}, %{{.*}}] -> \ No newline at end of file diff --git a/third_party/xla/xla/service/gpu/fusions/tests/dynamic_update_slice/int4.hlo b/third_party/xla/xla/service/gpu/fusions/tests/dynamic_update_slice/int4.hlo index 92914a55dc3c42..628b035833344b 100644 --- a/third_party/xla/xla/service/gpu/fusions/tests/dynamic_update_slice/int4.hlo +++ b/third_party/xla/xla/service/gpu/fusions/tests/dynamic_update_slice/int4.hlo @@ -1,4 +1,4 @@ -// RUN: fusion_to_mlir %s | mlir_fusions_opt -xla-gpu-test-optimize \ +// RUN: fusion_to_mlir %s | emitters_opt -xla-gpu-test-optimize \ // RUN: -xla-gpu-test-transform-loops | FileCheck %s // RUN: test_correctness %s --bijection_inputs=dus:1 diff --git a/third_party/xla/xla/service/gpu/fusions/tests/dynamic_update_slice/operand_subgraph_two_roots.hlo b/third_party/xla/xla/service/gpu/fusions/tests/dynamic_update_slice/operand_subgraph_two_roots.hlo index c4d3a427d41e9b..adc0afe278904b 100644 --- a/third_party/xla/xla/service/gpu/fusions/tests/dynamic_update_slice/operand_subgraph_two_roots.hlo +++ b/third_party/xla/xla/service/gpu/fusions/tests/dynamic_update_slice/operand_subgraph_two_roots.hlo @@ -1,4 +1,4 @@ -// RUN: fusion_to_mlir %s | mlir_fusions_opt -xla-gpu-test-optimize |\ +// RUN: fusion_to_mlir %s | emitters_opt -xla-gpu-test-optimize |\ // RUN: FileCheck %s // RUN: test_correctness %s @@ -25,10 +25,10 @@ fusion { // CHECK-DAG: %[[C_0:.*]] = arith.constant 0 // CHECK: %[[THREAD_ID:.*]] = gpu.thread_id x // CHECK: %[[BLOCK_ID:.*]] = gpu.block_id x -// CHECK: xla_gpu.loop +// CHECK: xla.loop // CHECK-SAME: -> (%[[RA:.*]], %[[RB:.*]]) in -// CHECK: %[[I0:.*]] = xla_gpu.pure_call @fusion_param_2_plus_one -// CHECK: %[[I1:.*]] = xla_gpu.pure_call @fusion_param_3_plus_one +// CHECK: %[[I0:.*]] = xla.pure_call @fusion_param_2_plus_one +// CHECK: %[[I1:.*]] = xla.pure_call @fusion_param_3_plus_one // CHECK: %[[IDX0:.*]] = arith.index_cast %[[I0]] // CHECK: %[[MIN0:.*]] = arith.minsi %[[IDX0]], %[[C_384]] // CHECK: %[[MAX0:.*]] = arith.maxsi %[[MIN0]], %[[C_0]] @@ -37,6 +37,6 @@ fusion { // CHECK: %[[MIN1:.*]] = arith.minsi %[[IDX1]], %[[C_384]] // CHECK: %[[MAX1:.*]] = arith.maxsi %[[MIN1]], %[[C_0]] // CHECK: %[[ADD1:.*]] = arith.addi %[[RB]], %[[MAX1]] -// CHECK: %[[UPDATE:.*]] = 
xla_gpu.pure_call @fusion_param_1_10 +// CHECK: %[[UPDATE:.*]] = xla.pure_call @fusion_param_1_10 // CHECK: %[[INSERT:.*]] = tensor.insert %[[UPDATE:.*]] into %{{[a-z0-9]+}}[%[[ADD0]], %[[ADD1]]] -// CHECK: xla_gpu.yield %[[INSERT]] \ No newline at end of file +// CHECK: xla.yield %[[INSERT]] \ No newline at end of file diff --git a/third_party/xla/xla/service/gpu/fusions/tests/dynamic_update_slice/out_of_bounds.hlo b/third_party/xla/xla/service/gpu/fusions/tests/dynamic_update_slice/out_of_bounds.hlo index 40e16cdf13adbc..05abdcb9cf159d 100644 --- a/third_party/xla/xla/service/gpu/fusions/tests/dynamic_update_slice/out_of_bounds.hlo +++ b/third_party/xla/xla/service/gpu/fusions/tests/dynamic_update_slice/out_of_bounds.hlo @@ -1,4 +1,4 @@ -// RUN: fusion_to_mlir %s | mlir_fusions_opt -xla-gpu-test-optimize |\ +// RUN: fusion_to_mlir %s | emitters_opt -xla-gpu-test-optimize |\ // RUN: FileCheck %s // RUN: test_correctness %s @@ -18,10 +18,10 @@ fusion { // CHECK-DAG: %[[C_5:.*]] = arith.constant 5 // CHECK-DAG: %[[C_0:.*]] = arith.constant 0 // CHECK: %[[THREAD_ID:.*]] = gpu.thread_id x -// CHECK: xla_gpu.loop +// CHECK: xla.loop // CHECK-SAME: -> (%[[RA:.*]], %[[RB:.*]]) in -// CHECK: %[[I0:.*]] = xla_gpu.pure_call @fusion_i0 -// CHECK: %[[I1:.*]] = xla_gpu.pure_call @fusion_i1 +// CHECK: %[[I0:.*]] = xla.pure_call @fusion_i0 +// CHECK: %[[I1:.*]] = xla.pure_call @fusion_i1 // CHECK: %[[IDX0:.*]] = arith.index_cast %[[I0]] // CHECK: %[[MIN0:.*]] = arith.minsi %[[IDX0]], %[[C_5]] // CHECK: %[[MAX0:.*]] = arith.maxsi %[[MIN0]], %[[C_0]] @@ -30,6 +30,6 @@ fusion { // CHECK: %[[MIN1:.*]] = arith.minsi %[[IDX1]], %[[C_5]] // CHECK: %[[MAX1:.*]] = arith.maxsi %[[MIN1]], %[[C_0]] // CHECK: %[[ADD1:.*]] = arith.addi %[[RB]], %[[MAX1]] -// CHECK: %[[UPDATE:.*]] = xla_gpu.pure_call @fusion_updates +// CHECK: %[[UPDATE:.*]] = xla.pure_call @fusion_updates // CHECK: %[[INSERT:.*]] = tensor.insert %[[UPDATE:.*]] into %{{[a-z0-9]+}}[%[[ADD0]], %[[ADD1]]] -// CHECK: xla_gpu.yield %[[INSERT]] \ No newline at end of file +// CHECK: xla.yield %[[INSERT]] \ No newline at end of file diff --git a/third_party/xla/xla/service/gpu/fusions/tests/dynamic_update_slice/simple.hlo b/third_party/xla/xla/service/gpu/fusions/tests/dynamic_update_slice/simple.hlo index 0d382a5f8fa169..b4ecdc5ba4f7f4 100644 --- a/third_party/xla/xla/service/gpu/fusions/tests/dynamic_update_slice/simple.hlo +++ b/third_party/xla/xla/service/gpu/fusions/tests/dynamic_update_slice/simple.hlo @@ -1,4 +1,4 @@ -// RUN: fusion_to_mlir %s | mlir_fusions_opt -xla-gpu-test-optimize |\ +// RUN: fusion_to_mlir %s | emitters_opt -xla-gpu-test-optimize |\ // RUN: FileCheck %s // RUN: test_correctness %s --bijection_inputs=updated:1 @@ -20,10 +20,10 @@ fusion { // CHECK-DAG: %[[C_0:.*]] = arith.constant 0 // CHECK: %[[THREAD_ID:.*]] = gpu.thread_id x -// CHECK: xla_gpu.loop +// CHECK: xla.loop // CHECK-SAME: -> (%[[RA:.*]], %[[RB:.*]]) in -// CHECK: %[[I0:.*]] = xla_gpu.pure_call @fusion_i0 -// CHECK: %[[I1:.*]] = xla_gpu.pure_call @fusion_i1 +// CHECK: %[[I0:.*]] = xla.pure_call @fusion_i0 +// CHECK: %[[I1:.*]] = xla.pure_call @fusion_i1 // CHECK: %[[IDX0:.*]] = arith.index_cast %[[I0]] // CHECK: %[[MIN0:.*]] = arith.minsi %[[IDX0]], %[[C_15]] // CHECK: %[[MAX0:.*]] = arith.maxsi %[[MIN0]], %[[C_0]] @@ -32,6 +32,6 @@ fusion { // CHECK: %[[MIN1:.*]] = arith.minsi %[[IDX1]], %[[C_24]] // CHECK: %[[MAX1:.*]] = arith.maxsi %[[MIN1]], %[[C_0]] // CHECK: %[[ADD1:.*]] = arith.addi %[[RB]], %[[MAX1]] -// CHECK: %[[UPDATE:.*]] = xla_gpu.pure_call 
@fusion_updates +// CHECK: %[[UPDATE:.*]] = xla.pure_call @fusion_updates // CHECK: %[[INSERT:.*]] = tensor.insert %[[UPDATE:.*]] into %{{[a-z0-9]+}}[%[[ADD0]], %[[ADD1]]] -// CHECK: xla_gpu.yield %[[INSERT]] +// CHECK: xla.yield %[[INSERT]] diff --git a/third_party/xla/xla/service/gpu/fusions/tests/dynamic_update_slice/vectorize_x1_too_small.hlo b/third_party/xla/xla/service/gpu/fusions/tests/dynamic_update_slice/vectorize_x1_too_small.hlo index b9dbddaa26bb4d..b70681333f1b85 100644 --- a/third_party/xla/xla/service/gpu/fusions/tests/dynamic_update_slice/vectorize_x1_too_small.hlo +++ b/third_party/xla/xla/service/gpu/fusions/tests/dynamic_update_slice/vectorize_x1_too_small.hlo @@ -1,4 +1,4 @@ -// RUN: fusion_to_mlir %s | mlir_fusions_opt -xla-gpu-test-optimize \ +// RUN: fusion_to_mlir %s | emitters_opt -xla-gpu-test-optimize \ // RUN: -xla-gpu-test-transform-loops | FileCheck %s // RUN: test_correctness %s --bijection_inputs=dus:1 diff --git a/third_party/xla/xla/service/gpu/fusions/tests/dynamic_update_slice/vectorize_x4.hlo b/third_party/xla/xla/service/gpu/fusions/tests/dynamic_update_slice/vectorize_x4.hlo index 775dd248d0bb7b..7c019ed0e08b23 100644 --- a/third_party/xla/xla/service/gpu/fusions/tests/dynamic_update_slice/vectorize_x4.hlo +++ b/third_party/xla/xla/service/gpu/fusions/tests/dynamic_update_slice/vectorize_x4.hlo @@ -1,4 +1,4 @@ -// RUN: fusion_to_mlir %s | mlir_fusions_opt -xla-gpu-test-optimize \ +// RUN: fusion_to_mlir %s | emitters_opt -xla-gpu-test-optimize \ // RUN: -xla-gpu-test-transform-loops | FileCheck %s // RUN: test_correctness %s --bijection_inputs=dus:1 diff --git a/third_party/xla/xla/service/gpu/fusions/tests/loop/broadcast.hlo b/third_party/xla/xla/service/gpu/fusions/tests/loop/broadcast.hlo index b006bd6390706a..e0ae317af0f9c7 100644 --- a/third_party/xla/xla/service/gpu/fusions/tests/loop/broadcast.hlo +++ b/third_party/xla/xla/service/gpu/fusions/tests/loop/broadcast.hlo @@ -1,4 +1,4 @@ -// RUN: fusion_to_mlir %s | mlir_fusions_opt -xla-gpu-test-optimize | FileCheck %s +// RUN: fusion_to_mlir %s | emitters_opt -xla-gpu-test-optimize | FileCheck %s // RUN: test_correctness %s --bijection_outputs=bcast fusion { diff --git a/third_party/xla/xla/service/gpu/fusions/tests/loop/broadcast_constant.hlo b/third_party/xla/xla/service/gpu/fusions/tests/loop/broadcast_constant.hlo index e290ccdc980f45..569968afe31850 100644 --- a/third_party/xla/xla/service/gpu/fusions/tests/loop/broadcast_constant.hlo +++ b/third_party/xla/xla/service/gpu/fusions/tests/loop/broadcast_constant.hlo @@ -1,4 +1,4 @@ -// RUN: fusion_to_mlir %s | mlir_fusions_opt --xla-gpu-test-optimize \ +// RUN: fusion_to_mlir %s | emitters_opt --xla-gpu-test-optimize \ // RUN: --inline="default-pipeline='cse'" | FileCheck %s // RUN: test_correctness %s --bijection_outputs=broadcast @@ -7,7 +7,7 @@ bcast { ROOT broadcast = bf16[2,16,48]{2,1,0} broadcast(zero), dimensions={} } // CHECK: func.func @main(%[[ARG0:.*]]: tensor<2x16x48xbf16> -// CHECK: xla_gpu.loop ({{.*}})[{{.*}}] -> (%[[RA:.*]], %[[RB:.*]], %[[RC:.*]]) in +// CHECK: xla.loop ({{.*}})[{{.*}}] -> (%[[RA:.*]], %[[RB:.*]], %[[RC:.*]]) in // CHECK-SAME: iter_args(%[[ITER:.*]] = %[[ARG0]]) // CHECK: %[[CST:.*]] = arith.constant 0.000 // CHECK: %[[INSERTED:.*]] = tensor.insert %[[CST]] into %[[ITER]][%[[RA]], %[[RB]], %[[RC]]] diff --git a/third_party/xla/xla/service/gpu/fusions/tests/loop/broadcast_constant_block_dim_limit.hlo b/third_party/xla/xla/service/gpu/fusions/tests/loop/broadcast_constant_block_dim_limit.hlo index 
b225b378acc732..2bd646346ec9c7 100644 --- a/third_party/xla/xla/service/gpu/fusions/tests/loop/broadcast_constant_block_dim_limit.hlo +++ b/third_party/xla/xla/service/gpu/fusions/tests/loop/broadcast_constant_block_dim_limit.hlo @@ -1,4 +1,4 @@ -// RUN: fusion_to_mlir %s | mlir_fusions_opt --xla-gpu-test-optimize \ +// RUN: fusion_to_mlir %s | emitters_opt --xla-gpu-test-optimize \ // RUN: --inline="default-pipeline='cse'" | FileCheck %s bcast { @@ -9,7 +9,7 @@ bcast { // CHECK: func.func @main(%[[ARG0:.*]]: tensor<24x2048x2048x3x4096xbf16> // CHECK: gpu.block_id x {xla.range = [0 : index, 1207959551 : index]} // CHECK: gpu.block_id y {xla.range = [0 : index, 1 : index]} -// CHECK: xla_gpu.loop ({{.*}})[{{.*}}] -> (%[[RA:.*]], %[[RB:.*]], %[[RC:.*]], %[[RD:.*]], %[[RE:.*]]) in +// CHECK: xla.loop ({{.*}})[{{.*}}] -> (%[[RA:.*]], %[[RB:.*]], %[[RC:.*]], %[[RD:.*]], %[[RE:.*]]) in // CHECK-SAME: iter_args(%[[ITER:.*]] = %[[ARG0]]) // CHECK: %[[CST:.*]] = arith.constant 1.000 // CHECK: %[[INSERTED:.*]] = tensor.insert %[[CST]] into %[[ITER]][%[[RA]], %[[RB]], %[[RC]], %[[RD]], %[[RE]]] diff --git a/third_party/xla/xla/service/gpu/fusions/tests/loop/broadcast_constant_s4.hlo b/third_party/xla/xla/service/gpu/fusions/tests/loop/broadcast_constant_s4.hlo index 5719809d3a327b..a45dd36209207d 100644 --- a/third_party/xla/xla/service/gpu/fusions/tests/loop/broadcast_constant_s4.hlo +++ b/third_party/xla/xla/service/gpu/fusions/tests/loop/broadcast_constant_s4.hlo @@ -1,4 +1,4 @@ -// RUN: fusion_to_mlir %s | mlir_fusions_opt --xla-gpu-test-optimize \ +// RUN: fusion_to_mlir %s | emitters_opt --xla-gpu-test-optimize \ // RUN: --inline="default-pipeline='cse'" | FileCheck %s // RUN: test_correctness %s --bijection_outputs=broadcast @@ -7,7 +7,7 @@ bcast { ROOT broadcast = s4[3]{0} broadcast(x), dimensions={} } // CHECK: func.func @main(%[[ARG0:.*]]: tensor<3xi4> -// CHECK: xla_gpu.loop ({{.*}})[{{.*}}] -> (%[[RA:.*]]) in +// CHECK: xla.loop ({{.*}})[{{.*}}] -> (%[[RA:.*]]) in // CHECK-SAME: iter_args(%[[ITER:.*]] = %[[ARG0]]) // CHECK: %[[CST:.*]] = arith.constant -2 // CHECK: %[[INSERTED:.*]] = tensor.insert %[[CST]] into %[[ITER]][%[[RA]]] diff --git a/third_party/xla/xla/service/gpu/fusions/tests/loop/broadcast_non_scalar_constant_s4.hlo b/third_party/xla/xla/service/gpu/fusions/tests/loop/broadcast_non_scalar_constant_s4.hlo index 782b2e17753f0d..3d65ee3504553b 100644 --- a/third_party/xla/xla/service/gpu/fusions/tests/loop/broadcast_non_scalar_constant_s4.hlo +++ b/third_party/xla/xla/service/gpu/fusions/tests/loop/broadcast_non_scalar_constant_s4.hlo @@ -1,4 +1,4 @@ -// RUN: fusion_to_mlir %s | mlir_fusions_opt --xla-gpu-test-optimize \ +// RUN: fusion_to_mlir %s | emitters_opt --xla-gpu-test-optimize \ // RUN: --inline="default-pipeline='cse'" | FileCheck %s // RUN: test_correctness %s --bijection_outputs=broadcast @@ -12,7 +12,7 @@ ENTRY main { } // CHECK: func.func @main(%[[ARG0:.*]]: tensor<3x31xi4> -// CHECK: xla_gpu.loop ({{.*}})[{{.*}}] -> (%[[RA:.*]], %[[RB:.*]]) in +// CHECK: xla.loop ({{.*}})[{{.*}}] -> (%[[RA:.*]], %[[RB:.*]]) in // CHECK-SAME: iter_args(%[[ITER:.*]] = %[[ARG0]]) // CHECK: %[[CST:.*]] = arith.constant dense<[-2, -3, -4]> // CHECK: %[[EXTRACTED:.*]] = tensor.extract %[[CST]][%[[RA]]] diff --git a/third_party/xla/xla/service/gpu/fusions/tests/loop/complex.hlo b/third_party/xla/xla/service/gpu/fusions/tests/loop/complex.hlo index 09fc596f3000dc..ec8dbd779ea327 100644 --- a/third_party/xla/xla/service/gpu/fusions/tests/loop/complex.hlo +++ 
b/third_party/xla/xla/service/gpu/fusions/tests/loop/complex.hlo @@ -1,5 +1,5 @@ // RUN: fusion_to_mlir %s |\ -// RUN: mlir_fusions_opt --xla-gpu-test-optimize | FileCheck %s +// RUN: emitters_opt --xla-gpu-test-optimize | FileCheck %s // RUN: test_correctness %s fusion { @@ -14,7 +14,7 @@ fusion { // CHECK: func.func @main // CHECK-NEXT: gpu.thread_id -// CHECK: xla_gpu.loop +// CHECK: xla.loop // CHECK-NEXT: pure_call @fusion_mul // CHECK-NEXT: tensor.insert // CHECK: return diff --git a/third_party/xla/xla/service/gpu/fusions/tests/loop/dus_s64.hlo b/third_party/xla/xla/service/gpu/fusions/tests/loop/dus_s64.hlo index 8b19705bf21a26..a144f92dcc179d 100644 --- a/third_party/xla/xla/service/gpu/fusions/tests/loop/dus_s64.hlo +++ b/third_party/xla/xla/service/gpu/fusions/tests/loop/dus_s64.hlo @@ -1,4 +1,4 @@ -// RUN: fusion_to_mlir %s | mlir_fusions_opt -xla-gpu-test-optimize | FileCheck %s +// RUN: fusion_to_mlir %s | emitters_opt -xla-gpu-test-optimize | FileCheck %s // RUN: test_correctness %s fusion { diff --git a/third_party/xla/xla/service/gpu/fusions/tests/loop/dynamic_update_slice.hlo b/third_party/xla/xla/service/gpu/fusions/tests/loop/dynamic_update_slice.hlo index ff015b9d40aa4f..c58be970a885be 100644 --- a/third_party/xla/xla/service/gpu/fusions/tests/loop/dynamic_update_slice.hlo +++ b/third_party/xla/xla/service/gpu/fusions/tests/loop/dynamic_update_slice.hlo @@ -1,5 +1,5 @@ // RUN: fusion_to_mlir %s |\ -// RUN: mlir_fusions_opt --xla-gpu-test-optimize | FileCheck %s +// RUN: emitters_opt --xla-gpu-test-optimize | FileCheck %s // RUN: test_correctness %s fusion { diff --git a/third_party/xla/xla/service/gpu/fusions/tests/loop/iota_copy_bitcast.hlo b/third_party/xla/xla/service/gpu/fusions/tests/loop/iota_copy_bitcast.hlo index 0a6704bbd7ff22..53afdbf15071f9 100644 --- a/third_party/xla/xla/service/gpu/fusions/tests/loop/iota_copy_bitcast.hlo +++ b/third_party/xla/xla/service/gpu/fusions/tests/loop/iota_copy_bitcast.hlo @@ -1,4 +1,4 @@ -// RUN: fusion_to_mlir %s | mlir_fusions_opt -xla-gpu-test-optimize | FileCheck %s +// RUN: fusion_to_mlir %s | emitters_opt -xla-gpu-test-optimize | FileCheck %s // RUN: test_correctness %s fusion { diff --git a/third_party/xla/xla/service/gpu/fusions/tests/loop/minimum_maximum.hlo b/third_party/xla/xla/service/gpu/fusions/tests/loop/minimum_maximum.hlo index 3e17261eef1885..47b230d5eaba0d 100644 --- a/third_party/xla/xla/service/gpu/fusions/tests/loop/minimum_maximum.hlo +++ b/third_party/xla/xla/service/gpu/fusions/tests/loop/minimum_maximum.hlo @@ -11,7 +11,7 @@ fusion { } // CHECK: func.func @main -// CHECK: xla_gpu.pure_call @fusion_tuple +// CHECK: xla.pure_call @fusion_tuple // CHECK: func.func private @fusion_tuple // CHECK-DAG: arith.minimumf // CHECK-DAG: arith.maximumf diff --git a/third_party/xla/xla/service/gpu/fusions/tests/loop/no_duplication.hlo b/third_party/xla/xla/service/gpu/fusions/tests/loop/no_duplication.hlo index e8d64d755c48e1..fa4a4d7745a115 100644 --- a/third_party/xla/xla/service/gpu/fusions/tests/loop/no_duplication.hlo +++ b/third_party/xla/xla/service/gpu/fusions/tests/loop/no_duplication.hlo @@ -1,4 +1,4 @@ -// RUN: fusion_to_mlir %s | mlir_fusions_opt -xla-gpu-test-optimize | FileCheck %s +// RUN: fusion_to_mlir %s | emitters_opt -xla-gpu-test-optimize | FileCheck %s // RUN: test_correctness %s fusion { diff --git a/third_party/xla/xla/service/gpu/fusions/tests/loop/pred_mul.hlo b/third_party/xla/xla/service/gpu/fusions/tests/loop/pred_mul.hlo index 97119a4833523c..44907583c2046b 100644 --- 
a/third_party/xla/xla/service/gpu/fusions/tests/loop/pred_mul.hlo +++ b/third_party/xla/xla/service/gpu/fusions/tests/loop/pred_mul.hlo @@ -1,5 +1,5 @@ // RUN: fusion_to_mlir %s |\ -// RUN: mlir_fusions_opt --xla-gpu-test-optimize | FileCheck %s +// RUN: emitters_opt --xla-gpu-test-optimize | FileCheck %s // RUN: test_correctness %s fusion { diff --git a/third_party/xla/xla/service/gpu/fusions/tests/loop/tuple_heterogeneous.hlo b/third_party/xla/xla/service/gpu/fusions/tests/loop/tuple_heterogeneous.hlo index 4f93eacbfab93d..08a9ce73748ffc 100644 --- a/third_party/xla/xla/service/gpu/fusions/tests/loop/tuple_heterogeneous.hlo +++ b/third_party/xla/xla/service/gpu/fusions/tests/loop/tuple_heterogeneous.hlo @@ -1,5 +1,5 @@ // RUN: fusion_to_mlir %s |\ -// RUN: mlir_fusions_opt --xla-gpu-test-optimize | FileCheck %s +// RUN: emitters_opt --xla-gpu-test-optimize | FileCheck %s // RUN: test_correctness %s fusion { @@ -12,11 +12,11 @@ fusion { ROOT tuple = (f64[8], f64[2,4]) tuple(minimum, bc) } -// CHECK: #[[MAJOR:.*]] = #xla_gpu.indexing_map<"(d0) -> (d0 floordiv 4), -// CHECK: #[[MINOR:.*]] = #xla_gpu.indexing_map<"(d0) -> (d0 mod 4), +// CHECK: #[[MAJOR:.*]] = #xla.indexing_map<"(d0) -> (d0 floordiv 4), +// CHECK: #[[MINOR:.*]] = #xla.indexing_map<"(d0) -> (d0 mod 4), -// CHECK: xla_gpu.loop ({{.*}})[{{.*}}] -> (%[[RA:.*]]) in -// CHECK-DAG: %[[MAJOR_IDX:.*]] = xla_gpu.apply_indexing #[[MAJOR]] -// CHECK-DAG: %[[MINOR_IDX:.*]] = xla_gpu.apply_indexing #[[MINOR]] +// CHECK: xla.loop ({{.*}})[{{.*}}] -> (%[[RA:.*]]) in +// CHECK-DAG: %[[MAJOR_IDX:.*]] = xla.apply_indexing #[[MAJOR]] +// CHECK-DAG: %[[MINOR_IDX:.*]] = xla.apply_indexing #[[MINOR]] // CHECK-DAG: tensor.insert {{.*}}[%[[MAJOR_IDX]], %[[MINOR_IDX]]] // CHECK-DAG: tensor.insert {{.*}}[%[[RA]]] diff --git a/third_party/xla/xla/service/gpu/fusions/tests/loop/tuple_nested.hlo b/third_party/xla/xla/service/gpu/fusions/tests/loop/tuple_nested.hlo index 88634b3f233647..26718e83e2e03c 100644 --- a/third_party/xla/xla/service/gpu/fusions/tests/loop/tuple_nested.hlo +++ b/third_party/xla/xla/service/gpu/fusions/tests/loop/tuple_nested.hlo @@ -24,7 +24,7 @@ fusion { } // CHECK: @main -// CHECK: %[[R0:.*]], %[[R1:.*]], %[[R2:.*]] = xla_gpu.pure_call @fusion_tuple +// CHECK: %[[R0:.*]], %[[R1:.*]], %[[R2:.*]] = xla.pure_call @fusion_tuple // CHECK-DAG: tensor.insert %[[R0]] // CHECK-DAG: tensor.insert %[[R1]] // CHECK-DAG: tensor.insert %[[R2]] diff --git a/third_party/xla/xla/service/gpu/fusions/tests/loop/variadic_reduce.hlo b/third_party/xla/xla/service/gpu/fusions/tests/loop/variadic_reduce.hlo index 87521bdda33d62..dabad318f4d7af 100644 --- a/third_party/xla/xla/service/gpu/fusions/tests/loop/variadic_reduce.hlo +++ b/third_party/xla/xla/service/gpu/fusions/tests/loop/variadic_reduce.hlo @@ -20,11 +20,11 @@ fusion { } // CHECK: func @main( -// CHECK: %[[L1:.*]], %[[L2:.*]] = xla_gpu.loop ({{.*}})[{{.*}}] -> (%[[RA:.*]]) in -// CHECK: %[[SCALARS_0:.*]], %[[SCALARS_1:.*]] = xla_gpu.pure_call @fusion_d_1 +// CHECK: %[[L1:.*]], %[[L2:.*]] = xla.loop ({{.*}})[{{.*}}] -> (%[[RA:.*]]) in +// CHECK: %[[SCALARS_0:.*]], %[[SCALARS_1:.*]] = xla.pure_call @fusion_d_1 // CHECK: %[[INSERTED_1:.*]] = tensor.insert %[[SCALARS_0]] into %{{.*}}[%[[RA]]] // CHECK: %[[INSERTED_2:.*]] = tensor.insert %[[SCALARS_1]] into %{{.*}}[%[[RA]]] -// CHECK: xla_gpu.yield %[[INSERTED_1]], %[[INSERTED_2]] +// CHECK: xla.yield %[[INSERTED_1]], %[[INSERTED_2]] // CHECK: return %[[L1]], %[[L2]] // CHECK: func private @fusion_d_1 diff --git 
a/third_party/xla/xla/service/gpu/fusions/tests/loop/vectorize_x1_too_small.hlo b/third_party/xla/xla/service/gpu/fusions/tests/loop/vectorize_x1_too_small.hlo index 66dc6d99441980..7a5280658cee8a 100644 --- a/third_party/xla/xla/service/gpu/fusions/tests/loop/vectorize_x1_too_small.hlo +++ b/third_party/xla/xla/service/gpu/fusions/tests/loop/vectorize_x1_too_small.hlo @@ -1,4 +1,4 @@ -// RUN: fusion_to_mlir %s | mlir_fusions_opt -xla-gpu-test-optimize \ +// RUN: fusion_to_mlir %s | emitters_opt -xla-gpu-test-optimize \ // RUN: -xla-gpu-test-transform-loops | FileCheck %s // RUN: test_correctness %s --bijection_inputs=neg:0 --bijection_outputs=neg diff --git a/third_party/xla/xla/service/gpu/fusions/tests/loop/vectorize_x4.hlo b/third_party/xla/xla/service/gpu/fusions/tests/loop/vectorize_x4.hlo index 5196b7e81ba2a6..302f6a21ec7c06 100644 --- a/third_party/xla/xla/service/gpu/fusions/tests/loop/vectorize_x4.hlo +++ b/third_party/xla/xla/service/gpu/fusions/tests/loop/vectorize_x4.hlo @@ -1,4 +1,4 @@ -// RUN: fusion_to_mlir %s | mlir_fusions_opt -xla-gpu-test-optimize \ +// RUN: fusion_to_mlir %s | emitters_opt -xla-gpu-test-optimize \ // RUN: -xla-gpu-test-transform-loops | FileCheck %s // RUN: test_correctness %s --bijection_inputs=neg:0 --bijection_outputs=neg diff --git a/third_party/xla/xla/service/gpu/fusions/tests/reduce_column/odd_sizes.hlo b/third_party/xla/xla/service/gpu/fusions/tests/reduce_column/odd_sizes.hlo index 831af2b1402710..908664406c6a4f 100644 --- a/third_party/xla/xla/service/gpu/fusions/tests/reduce_column/odd_sizes.hlo +++ b/third_party/xla/xla/service/gpu/fusions/tests/reduce_column/odd_sizes.hlo @@ -1,4 +1,4 @@ -// RUN: fusion_to_mlir %s | mlir_fusions_opt -xla-gpu-test-optimize |\ +// RUN: fusion_to_mlir %s | emitters_opt -xla-gpu-test-optimize |\ // RUN: FileCheck %s // RUN: test_correctness %s --bijection_inputs=reduce:0 \ // RUN: --bijection_outputs=reduce @@ -15,7 +15,7 @@ fusion { ROOT reduce = f32[13,321] reduce(param_0, param_1), dimensions={1}, to_apply=add } -// CHECK: xla_gpu.pure_call @add_add +// CHECK: xla.pure_call @add_add // CHECK: allocate_shared // CHECK: tensor.insert // CHECK: sync_threads diff --git a/third_party/xla/xla/service/gpu/fusions/tests/reduce_column/vectorize_x1_c64.hlo b/third_party/xla/xla/service/gpu/fusions/tests/reduce_column/vectorize_x1_c64.hlo index 30dcd937472695..06147827d84e4c 100644 --- a/third_party/xla/xla/service/gpu/fusions/tests/reduce_column/vectorize_x1_c64.hlo +++ b/third_party/xla/xla/service/gpu/fusions/tests/reduce_column/vectorize_x1_c64.hlo @@ -1,4 +1,4 @@ -// RUN: fusion_to_mlir %s | mlir_fusions_opt -xla-gpu-test-optimize |\ +// RUN: fusion_to_mlir %s | emitters_opt -xla-gpu-test-optimize |\ // RUN: FileCheck %s // RUN: test_correctness %s --bijection_inputs=reduce:0 \ // RUN: --bijection_outputs=reduce diff --git a/third_party/xla/xla/service/gpu/fusions/tests/reduce_column/vectorize_x1_f64.hlo b/third_party/xla/xla/service/gpu/fusions/tests/reduce_column/vectorize_x1_f64.hlo index 981f2dea75b0a8..28f24ef158c1f4 100644 --- a/third_party/xla/xla/service/gpu/fusions/tests/reduce_column/vectorize_x1_f64.hlo +++ b/third_party/xla/xla/service/gpu/fusions/tests/reduce_column/vectorize_x1_f64.hlo @@ -1,4 +1,4 @@ -// RUN: fusion_to_mlir %s | mlir_fusions_opt -xla-gpu-test-optimize |\ +// RUN: fusion_to_mlir %s | emitters_opt -xla-gpu-test-optimize |\ // RUN: FileCheck %s // RUN: test_correctness %s --bijection_inputs=reduce:0 \ // RUN: --bijection_outputs=reduce diff --git 
a/third_party/xla/xla/service/gpu/fusions/tests/reduce_column/vectorize_x2_f32.hlo b/third_party/xla/xla/service/gpu/fusions/tests/reduce_column/vectorize_x2_f32.hlo index 470597e4f23bc9..f70a7f251aef6f 100644 --- a/third_party/xla/xla/service/gpu/fusions/tests/reduce_column/vectorize_x2_f32.hlo +++ b/third_party/xla/xla/service/gpu/fusions/tests/reduce_column/vectorize_x2_f32.hlo @@ -1,4 +1,4 @@ -// RUN: fusion_to_mlir %s | mlir_fusions_opt -xla-gpu-test-optimize |\ +// RUN: fusion_to_mlir %s | emitters_opt -xla-gpu-test-optimize |\ // RUN: FileCheck %s // RUN: test_correctness %s --bijection_inputs=reduce:0 \ // RUN: --bijection_outputs=reduce diff --git a/third_party/xla/xla/service/gpu/fusions/tests/reduce_column/vectorize_x4_s16.hlo b/third_party/xla/xla/service/gpu/fusions/tests/reduce_column/vectorize_x4_s16.hlo index 4aa1e85008892e..f1b5e0370f038c 100644 --- a/third_party/xla/xla/service/gpu/fusions/tests/reduce_column/vectorize_x4_s16.hlo +++ b/third_party/xla/xla/service/gpu/fusions/tests/reduce_column/vectorize_x4_s16.hlo @@ -1,4 +1,4 @@ -// RUN: fusion_to_mlir %s | mlir_fusions_opt -xla-gpu-test-optimize |\ +// RUN: fusion_to_mlir %s | emitters_opt -xla-gpu-test-optimize |\ // RUN: FileCheck %s // RUN: test_correctness %s --bijection_inputs=reduce:0 \ // RUN: --bijection_outputs=reduce diff --git a/third_party/xla/xla/service/gpu/fusions/tests/reduce_column_small/f32_2.hlo b/third_party/xla/xla/service/gpu/fusions/tests/reduce_column_small/f32_2.hlo index a751264414914a..81c6bebf03f680 100644 --- a/third_party/xla/xla/service/gpu/fusions/tests/reduce_column_small/f32_2.hlo +++ b/third_party/xla/xla/service/gpu/fusions/tests/reduce_column_small/f32_2.hlo @@ -1,5 +1,5 @@ // RUN: test_correctness %s --bijection_inputs=reduce:0 --bijection_outputs=reduce -// RUN: fusion_to_mlir %s | mlir_fusions_opt -cse -xla-gpu-simplify-arith -canonicalize | FileCheck %s --dump-input=always +// RUN: fusion_to_mlir %s | emitters_opt -cse -xla-gpu-simplify-arith -canonicalize | FileCheck %s --dump-input=always add { %p0 = f32[] parameter(0) diff --git a/third_party/xla/xla/service/gpu/fusions/tests/reduce_column_small/f32_32_v2.hlo b/third_party/xla/xla/service/gpu/fusions/tests/reduce_column_small/f32_32_v2.hlo index 16b2b71ed36994..923f19b24e196f 100644 --- a/third_party/xla/xla/service/gpu/fusions/tests/reduce_column_small/f32_32_v2.hlo +++ b/third_party/xla/xla/service/gpu/fusions/tests/reduce_column_small/f32_32_v2.hlo @@ -1,5 +1,5 @@ -// RUN: fusion_to_mlir %s | mlir_fusions_opt -cse -xla-gpu-simplify-arith -canonicalize | FileCheck %s --dump-input=always +// RUN: fusion_to_mlir %s | emitters_opt -cse -xla-gpu-simplify-arith -canonicalize | FileCheck %s --dump-input=always add { %p0 = f32[] parameter(0) diff --git a/third_party/xla/xla/service/gpu/fusions/tests/reduce_column_small/f32_8_v2.hlo b/third_party/xla/xla/service/gpu/fusions/tests/reduce_column_small/f32_8_v2.hlo index 9dab4a3313330a..175ced82445ca3 100644 --- a/third_party/xla/xla/service/gpu/fusions/tests/reduce_column_small/f32_8_v2.hlo +++ b/third_party/xla/xla/service/gpu/fusions/tests/reduce_column_small/f32_8_v2.hlo @@ -1,4 +1,4 @@ -// RUN: fusion_to_mlir %s | mlir_fusions_opt -cse -xla-gpu-simplify-arith -canonicalize | FileCheck %s --dump-input=always +// RUN: fusion_to_mlir %s | emitters_opt -cse -xla-gpu-simplify-arith -canonicalize | FileCheck %s --dump-input=always // RUN: test_correctness %s --bijection_inputs=reduce:0 --bijection_outputs=reduce add { diff --git 
a/third_party/xla/xla/service/gpu/fusions/tests/reduce_column_small/s8_f32_32_v4.hlo b/third_party/xla/xla/service/gpu/fusions/tests/reduce_column_small/s8_f32_32_v4.hlo index 9b584bc5c4ed2f..781449df049c8e 100644 --- a/third_party/xla/xla/service/gpu/fusions/tests/reduce_column_small/s8_f32_32_v4.hlo +++ b/third_party/xla/xla/service/gpu/fusions/tests/reduce_column_small/s8_f32_32_v4.hlo @@ -1,4 +1,4 @@ -// RUN: fusion_to_mlir %s | mlir_fusions_opt -cse -xla-gpu-simplify-arith -canonicalize | FileCheck %s --dump-input=always +// RUN: fusion_to_mlir %s | emitters_opt -cse -xla-gpu-simplify-arith -canonicalize | FileCheck %s --dump-input=always add { %p0 = f32[] parameter(0) diff --git a/third_party/xla/xla/service/gpu/fusions/tests/reduce_multirow/f16_v4.hlo b/third_party/xla/xla/service/gpu/fusions/tests/reduce_multirow/f16_v4.hlo index d2e3928bdfd564..7be02bd4bf1087 100644 --- a/third_party/xla/xla/service/gpu/fusions/tests/reduce_multirow/f16_v4.hlo +++ b/third_party/xla/xla/service/gpu/fusions/tests/reduce_multirow/f16_v4.hlo @@ -1,4 +1,4 @@ -// RUN: fusion_to_mlir %s | mlir_fusions_opt -xla-gpu-test-optimize \ +// RUN: fusion_to_mlir %s | emitters_opt -xla-gpu-test-optimize \ // RUN: -xla-gpu-test-transform-loops | FileCheck %s // The reference implementation reduces in f64, so we need a larger tolerance. diff --git a/third_party/xla/xla/service/gpu/fusions/tests/reduce_multirow/maximum_vector_size.hlo b/third_party/xla/xla/service/gpu/fusions/tests/reduce_multirow/maximum_vector_size.hlo index 91be1f47593f4f..508ecf5da65154 100644 --- a/third_party/xla/xla/service/gpu/fusions/tests/reduce_multirow/maximum_vector_size.hlo +++ b/third_party/xla/xla/service/gpu/fusions/tests/reduce_multirow/maximum_vector_size.hlo @@ -1,4 +1,4 @@ -// RUN: fusion_to_mlir %s | mlir_fusions_opt -xla-gpu-test-optimize \ +// RUN: fusion_to_mlir %s | emitters_opt -xla-gpu-test-optimize \ // RUN: -xla-gpu-test-transform-loops | FileCheck %s // RUN: test_correctness %s diff --git a/third_party/xla/xla/service/gpu/fusions/tests/reduce_multirow/pred_mof_x2_v4.hlo b/third_party/xla/xla/service/gpu/fusions/tests/reduce_multirow/pred_mof_x2_v4.hlo index a7ca6eb1bfc7e9..81109b91b9381e 100644 --- a/third_party/xla/xla/service/gpu/fusions/tests/reduce_multirow/pred_mof_x2_v4.hlo +++ b/third_party/xla/xla/service/gpu/fusions/tests/reduce_multirow/pred_mof_x2_v4.hlo @@ -1,4 +1,4 @@ -// RUN: fusion_to_mlir %s | mlir_fusions_opt -xla-gpu-test-optimize \ +// RUN: fusion_to_mlir %s | emitters_opt -xla-gpu-test-optimize \ // RUN: -xla-gpu-test-transform-loops | FileCheck %s // RUN: test_correctness %s --bijection_inputs=tmp_5:0 \ // RUN: --bijection_inputs=tmp_8:0 --bijection_outputs=tmp_5 \ diff --git a/third_party/xla/xla/service/gpu/fusions/tests/reduce_row/epilogue.hlo b/third_party/xla/xla/service/gpu/fusions/tests/reduce_row/epilogue.hlo index 9757370d6dd524..a2e83ce2f77a4c 100644 --- a/third_party/xla/xla/service/gpu/fusions/tests/reduce_row/epilogue.hlo +++ b/third_party/xla/xla/service/gpu/fusions/tests/reduce_row/epilogue.hlo @@ -1,4 +1,4 @@ -// RUN: fusion_to_mlir %s | mlir_fusions_opt -xla-gpu-test-optimize \ +// RUN: fusion_to_mlir %s | emitters_opt -xla-gpu-test-optimize \ // RUN: --inline="default-pipeline='cse'" | FileCheck %s // RUN: test_correctness %s --bijection_inputs=reduce:0 \ // RUN: --bijection_outputs=reduce diff --git a/third_party/xla/xla/service/gpu/fusions/tests/reduce_row/inefficient_codegen.hlo b/third_party/xla/xla/service/gpu/fusions/tests/reduce_row/inefficient_codegen.hlo index 
6786bc93498486..d0e4120d62187d 100644 --- a/third_party/xla/xla/service/gpu/fusions/tests/reduce_row/inefficient_codegen.hlo +++ b/third_party/xla/xla/service/gpu/fusions/tests/reduce_row/inefficient_codegen.hlo @@ -1,4 +1,4 @@ -// RUN: fusion_to_mlir %s | mlir_fusions_opt -xla-gpu-test-optimize \ +// RUN: fusion_to_mlir %s | emitters_opt -xla-gpu-test-optimize \ // RUN: --inline="default-pipeline='cse'" | FileCheck %s // RUN: test_correctness %s --bijection_inputs=reduce1:0 \ // RUN: --bijection_inputs=reduce2:0 --bijection_outputs=reduce1 \ @@ -34,7 +34,7 @@ fusion { // CHECK-DAG: %[[CST0:.*]] = arith.constant 0xFFF0000000000000 // CHECK-DAG: %[[TID_X:.*]] = gpu.thread_id x -// CHECK: xla_gpu.loop +// CHECK: xla.loop // CHECK-SAME: -> (%[[RA:.*]], %[[RA:.*]], %[[RC:.*]]) in // reduce0 in the context of reduce2 and reduce1's prologue: // CHECK: scf.for %[[I:.*]] = %[[C0]] diff --git a/third_party/xla/xla/service/gpu/fusions/tests/reduce_row/layout.hlo b/third_party/xla/xla/service/gpu/fusions/tests/reduce_row/layout.hlo index 102e32b861e648..f683df0b7eea5b 100644 --- a/third_party/xla/xla/service/gpu/fusions/tests/reduce_row/layout.hlo +++ b/third_party/xla/xla/service/gpu/fusions/tests/reduce_row/layout.hlo @@ -14,4 +14,4 @@ fusion { ROOT reduce = f32[17,19]{0,1} reduce(%input, %c0), dimensions={2}, to_apply=add } -// CHECK: xla_gpu.predicated_insert {{.*}} : tensor<17x19xf32, dense<[0, 1]> : tensor<2xi64>> +// CHECK: xla.predicated_insert {{.*}} : tensor<17x19xf32, dense<[0, 1]> : tensor<2xi64>> diff --git a/third_party/xla/xla/service/gpu/fusions/tests/reduce_row/minor_and_major.hlo b/third_party/xla/xla/service/gpu/fusions/tests/reduce_row/minor_and_major.hlo index 249372d1347abf..3a50ef26a7fd07 100644 --- a/third_party/xla/xla/service/gpu/fusions/tests/reduce_row/minor_and_major.hlo +++ b/third_party/xla/xla/service/gpu/fusions/tests/reduce_row/minor_and_major.hlo @@ -1,4 +1,4 @@ -// RUN: fusion_to_mlir %s | mlir_fusions_opt -xla-gpu-test-optimize \ +// RUN: fusion_to_mlir %s | emitters_opt -xla-gpu-test-optimize \ // RUN: --inline="default-pipeline='cse'" | FileCheck %s // RUN: test_correctness %s --bijection_inputs=reduce:0 \ // RUN: --bijection_outputs=reduce diff --git a/third_party/xla/xla/service/gpu/fusions/tests/reduce_row/mof_epilogue.hlo b/third_party/xla/xla/service/gpu/fusions/tests/reduce_row/mof_epilogue.hlo index d2c5885b9d6eca..b648b6fdcdf246 100644 --- a/third_party/xla/xla/service/gpu/fusions/tests/reduce_row/mof_epilogue.hlo +++ b/third_party/xla/xla/service/gpu/fusions/tests/reduce_row/mof_epilogue.hlo @@ -1,4 +1,4 @@ -// RUN: fusion_to_mlir %s | mlir_fusions_opt -xla-gpu-test-optimize \ +// RUN: fusion_to_mlir %s | emitters_opt -xla-gpu-test-optimize \ // RUN: --inline="default-pipeline='cse'" | FileCheck %s // RUN: test_correctness %s --bijection_inputs=reduce1:0 \ // RUN: --bijection_inputs=reduce2:0 --bijection_outputs=reduce1 \ @@ -38,6 +38,6 @@ fusion { // CHECK-DAG: %[[LOG:.*]] = math.log %[[SUM]] // CHECK-DAG: %[[ABS:.*]] = math.absf %[[SUM]] // CHECK-DAG: %[[NEG:.*]] = arith.negf %[[PROD]] -// CHECK-DAG: xla_gpu.predicated_insert %[[LOG]] -// CHECK-DAG: xla_gpu.predicated_insert %[[ABS]] -// CHECK-DAG: xla_gpu.predicated_insert %[[NEG]] +// CHECK-DAG: xla.predicated_insert %[[LOG]] +// CHECK-DAG: xla.predicated_insert %[[ABS]] +// CHECK-DAG: xla.predicated_insert %[[NEG]] diff --git a/third_party/xla/xla/service/gpu/fusions/tests/reduce_row/side_output.hlo b/third_party/xla/xla/service/gpu/fusions/tests/reduce_row/side_output.hlo index 
641372c73697ec..8576fd677a48cc 100644 --- a/third_party/xla/xla/service/gpu/fusions/tests/reduce_row/side_output.hlo +++ b/third_party/xla/xla/service/gpu/fusions/tests/reduce_row/side_output.hlo @@ -1,4 +1,4 @@ -// RUN: fusion_to_mlir %s | mlir_fusions_opt -xla-gpu-test-optimize |\ +// RUN: fusion_to_mlir %s | emitters_opt -xla-gpu-test-optimize |\ // RUN: FileCheck %s // RUN: test_correctness %s --bijection_inputs=reduce --bijection_outputs=exp @@ -17,7 +17,7 @@ fusion { } // CHECK: @fusion -// CHECK: xla_gpu.loop -// CHECK: %[[SIDE_OUTPUT:.*]] = xla_gpu.pure_call @fusion_exp +// CHECK: xla.loop +// CHECK: %[[SIDE_OUTPUT:.*]] = xla.pure_call @fusion_exp // CHECK-NEXT: tensor.insert %[[SIDE_OUTPUT]] -// CHECK: xla_gpu.yield +// CHECK: xla.yield diff --git a/third_party/xla/xla/service/gpu/fusions/tests/reduce_row/variadic_f32.hlo b/third_party/xla/xla/service/gpu/fusions/tests/reduce_row/variadic_f32.hlo index 6452af04027333..2f95bc12e467f3 100644 --- a/third_party/xla/xla/service/gpu/fusions/tests/reduce_row/variadic_f32.hlo +++ b/third_party/xla/xla/service/gpu/fusions/tests/reduce_row/variadic_f32.hlo @@ -1,4 +1,4 @@ -// RUN: fusion_to_mlir %s | mlir_fusions_opt -xla-gpu-test-optimize |\ +// RUN: fusion_to_mlir %s | emitters_opt -xla-gpu-test-optimize |\ // RUN: FileCheck %s // RUN: test_correctness %s --bijection_inputs=reduce:0,1 \ // RUN: --bijection_outputs=reduce diff --git a/third_party/xla/xla/service/gpu/fusions/tests/scatter/add.hlo b/third_party/xla/xla/service/gpu/fusions/tests/scatter/add.hlo index f20127e163bd02..db06318cde44ae 100644 --- a/third_party/xla/xla/service/gpu/fusions/tests/scatter/add.hlo +++ b/third_party/xla/xla/service/gpu/fusions/tests/scatter/add.hlo @@ -1,4 +1,4 @@ -// RUN: fusion_to_mlir %s | mlir_fusions_opt -xla-gpu-test-optimize |\ +// RUN: fusion_to_mlir %s | emitters_opt -xla-gpu-test-optimize |\ // RUN: FileCheck %s // RUN: test_correctness %s --bijection_inputs=scatter:2 @@ -32,11 +32,11 @@ scatter { // CHECK: %[[IN_BOUNDS:.*]] = arith.cmpi ule // CHECK: scf.if %[[IN_BOUNDS]] -> (tensor<10x5xf32>) { -// CHECK: xla_gpu.loop -// CHECK: %[[UPD_ELEM:.*]] = xla_gpu.pure_call @scatter_update -// CHECK: %[[RMW:.*]] = xla_gpu.atomic_rmw %{{[a-z0-9]+}} +// CHECK: xla.loop +// CHECK: %[[UPD_ELEM:.*]] = xla.pure_call @scatter_update +// CHECK: %[[RMW:.*]] = xla.atomic_rmw %{{[a-z0-9]+}} // CHECK: ^bb0(%[[CUR_VALUE:.*]]: f32): // CHECK: %[[SUM:.*]] = arith.addf %[[CUR_VALUE]], %[[UPD_ELEM]] -// CHECK: xla_gpu.yield %[[SUM]] : f32 +// CHECK: xla.yield %[[SUM]] : f32 // CHECK: } -// CHECK: xla_gpu.yield %[[RMW]] : tensor<10x5xf32> \ No newline at end of file +// CHECK: xla.yield %[[RMW]] : tensor<10x5xf32> \ No newline at end of file diff --git a/third_party/xla/xla/service/gpu/fusions/tests/scatter/overwrite.hlo b/third_party/xla/xla/service/gpu/fusions/tests/scatter/overwrite.hlo index 28e28c532da04b..3738fab5afccff 100644 --- a/third_party/xla/xla/service/gpu/fusions/tests/scatter/overwrite.hlo +++ b/third_party/xla/xla/service/gpu/fusions/tests/scatter/overwrite.hlo @@ -1,4 +1,4 @@ -// RUN: fusion_to_mlir %s | mlir_fusions_opt -xla-gpu-test-optimize |\ +// RUN: fusion_to_mlir %s | emitters_opt -xla-gpu-test-optimize |\ // RUN: FileCheck %s // RUN: test_correctness %s @@ -31,9 +31,9 @@ scatter { // CHECK: %[[IN_BOUNDS:.*]] = arith.cmpi ule // CHECK: scf.if %[[IN_BOUNDS]] -> (tensor<10x5xf32>) { -// CHECK: %[[UPD_ELEM:.*]] = xla_gpu.pure_call @scatter_update -// CHECK: %[[RMW:.*]] = xla_gpu.atomic_rmw %{{[a-z0-9]+}} +// CHECK: %[[UPD_ELEM:.*]] = 
xla.pure_call @scatter_update +// CHECK: %[[RMW:.*]] = xla.atomic_rmw %{{[a-z0-9]+}} // CHECK: ^bb0(%[[CUR_VALUE:.*]]: f32): -// CHECK: xla_gpu.yield %[[UPD_ELEM]] : f32 +// CHECK: xla.yield %[[UPD_ELEM]] : f32 // CHECK: } -// CHECK: xla_gpu.yield %[[RMW]] : tensor<10x5xf32> +// CHECK: xla.yield %[[RMW]] : tensor<10x5xf32> diff --git a/third_party/xla/xla/service/gpu/fusions/tests/scatter/unique_indices.hlo b/third_party/xla/xla/service/gpu/fusions/tests/scatter/unique_indices.hlo index 8abb0d548d1c06..c44a4fe879900d 100644 --- a/third_party/xla/xla/service/gpu/fusions/tests/scatter/unique_indices.hlo +++ b/third_party/xla/xla/service/gpu/fusions/tests/scatter/unique_indices.hlo @@ -1,4 +1,4 @@ -// RUN: fusion_to_mlir %s | mlir_fusions_opt -xla-gpu-test-optimize |\ +// RUN: fusion_to_mlir %s | emitters_opt -xla-gpu-test-optimize |\ // RUN: FileCheck %s // RUN: test_correctness %s @@ -24,7 +24,7 @@ scatter { unique_indices=true, to_apply=add } -// CHECK: #[[$MAP:.*]] = #xla_gpu.indexing_map<"(th_x) -> (th_x floordiv 2) +// CHECK: #[[$MAP:.*]] = #xla.indexing_map<"(th_x) -> (th_x floordiv 2) // CHECK-LABEL: func.func @main( // CHECK-SAME: %[[OPERAND:[a-zA-Z0-9]*]]: tensor<10x5xf32> @@ -36,9 +36,9 @@ scatter { // CHECK: %[[TH_X:.*]] = gpu.thread_id x -// CHECK: %[[SLICE_ID:.*]] = xla_gpu.apply_indexing #[[$MAP]](%[[TH_X]] +// CHECK: %[[SLICE_ID:.*]] = xla.apply_indexing #[[$MAP]](%[[TH_X]] -// CHECK: %[[IND0_I32:.*]] = xla_gpu.pure_call @scatter_indices +// CHECK: %[[IND0_I32:.*]] = xla.pure_call @scatter_indices // CHECK-SAME: (%[[OPERAND]], %[[UPDATES]], %[[SLICE_ID]], %[[C0]]) @@ -46,18 +46,18 @@ scatter { // CHECK: %[[IN_BOUNDS:.*]] = arith.cmpi ule // CHECK: scf.if %[[IN_BOUNDS]] -> (tensor<10x5xf32>) { -// CHECK: xla_gpu.loop +// CHECK: xla.loop // CHECK: -> (%[[RA:.*]], %[[RB:.*]], %[[RC:.*]]) in -// CHECK: %[[UPD_ELEM:.*]] = xla_gpu.pure_call @scatter_update( +// CHECK: %[[UPD_ELEM:.*]] = xla.pure_call @scatter_update( // CHECK-SAME: %[[OPERAND]], %[[UPDATES]], // CHECK-SAME: %[[RA]], %[[RB]], %[[RC]]) // CHECK: %[[IND0_RB:.*]] = arith.addi %[[RB]], %[[IND0]] -// CHECK: %[[CURRENT:.*]] = xla_gpu.pure_call @scatter_operand( +// CHECK: %[[CURRENT:.*]] = xla.pure_call @scatter_operand( // CHECK-SAME: %[[OPERAND]], %[[UPDATES]], %[[IND0_RB]], // CHECK-SAME: %[[RC]]) // CHECK: %[[COMBINED:.*]] = arith.addf %[[CURRENT]], %[[UPD_ELEM]] // CHECK: %[[UPDATED:.*]] = tensor.insert %[[COMBINED]] // CHECK-SAME: into %{{[a-z0-9]+}}[%{{.*}}, %[[RC]]] : tensor<10x5xf32> -// CHECK: xla_gpu.yield %[[UPDATED]] : tensor<10x5xf32> +// CHECK: xla.yield %[[UPDATED]] : tensor<10x5xf32> diff --git a/third_party/xla/xla/service/gpu/fusions/tests/scatter/unsigned.hlo b/third_party/xla/xla/service/gpu/fusions/tests/scatter/unsigned.hlo index 56857b64ce91b0..62d1a7c91d5a30 100644 --- a/third_party/xla/xla/service/gpu/fusions/tests/scatter/unsigned.hlo +++ b/third_party/xla/xla/service/gpu/fusions/tests/scatter/unsigned.hlo @@ -1,4 +1,4 @@ -// RUN: fusion_to_mlir %s | mlir_fusions_opt -xla-gpu-test-optimize |\ +// RUN: fusion_to_mlir %s | emitters_opt -xla-gpu-test-optimize |\ // RUN: FileCheck %s // RUN: test_correctness %s @@ -19,6 +19,6 @@ scatter { index_vector_dim=1, to_apply=add } -// CHECK: %[[PARAM:.*]] = xla_gpu.pure_call @scatter_indices +// CHECK: %[[PARAM:.*]] = xla.pure_call @scatter_indices // CHECK: %[[INDEX:.*]] = arith.index_castui %[[PARAM]] // CHECK: arith.cmpi ule, %[[INDEX]] \ No newline at end of file diff --git a/third_party/xla/xla/service/gpu/fusions/tests/transpose/epilogue.hlo 
b/third_party/xla/xla/service/gpu/fusions/tests/transpose/epilogue.hlo index 25695c8212f7d0..8199acbd52ece4 100644 --- a/third_party/xla/xla/service/gpu/fusions/tests/transpose/epilogue.hlo +++ b/third_party/xla/xla/service/gpu/fusions/tests/transpose/epilogue.hlo @@ -1,4 +1,4 @@ -// RUN: fusion_to_mlir %s | mlir_fusions_opt -xla-gpu-test-optimize |\ +// RUN: fusion_to_mlir %s | emitters_opt -xla-gpu-test-optimize |\ // RUN: FileCheck %s // RUN: test_correctness %s @@ -17,7 +17,7 @@ fusion { // CHECK: %[[SYNC:.*]] = xla_gpu.sync_threads %[[SHMEM_WITH_VALS]] -// CHECK: xla_gpu.loop +// CHECK: xla.loop // CHECK-SAME: iter_args(%[[OUT_:.*]] = %[[OUT]]) -// CHECK: %[[ABS:.*]] = xla_gpu.pure_call @fusion__epilogue__ +// CHECK: %[[ABS:.*]] = xla.pure_call @fusion__epilogue__ // CHECK: tensor.insert %[[ABS]] into %[[OUT_]] \ No newline at end of file diff --git a/third_party/xla/xla/service/gpu/fusions/tests/transpose/fused_transpose_021.hlo b/third_party/xla/xla/service/gpu/fusions/tests/transpose/fused_transpose_021.hlo index 7c2d63c78a47ef..69cac60e4c4aa5 100644 --- a/third_party/xla/xla/service/gpu/fusions/tests/transpose/fused_transpose_021.hlo +++ b/third_party/xla/xla/service/gpu/fusions/tests/transpose/fused_transpose_021.hlo @@ -1,4 +1,4 @@ -// RUN: fusion_to_mlir %s | mlir_fusions_opt -xla-gpu-test-optimize |\ +// RUN: fusion_to_mlir %s | emitters_opt -xla-gpu-test-optimize |\ // RUN: FileCheck %s // RUN: test_correctness %s @@ -20,7 +20,7 @@ fusion { // CHECK: %[[SYNC:.*]] = xla_gpu.sync_threads %[[SHMEM_WITH_VALS]] -// CHECK: xla_gpu.loop +// CHECK: xla.loop // CHECK-SAME: iter_args(%[[OUT_:.*]] = %[[OUT]]) -// CHECK: %[[ABS:.*]] = xla_gpu.pure_call @fusion__epilogue__ +// CHECK: %[[ABS:.*]] = xla.pure_call @fusion__epilogue__ // CHECK: tensor.insert %[[ABS]] into %[[OUT_]] \ No newline at end of file diff --git a/third_party/xla/xla/service/gpu/fusions/tests/transpose/fused_transpose_102.hlo b/third_party/xla/xla/service/gpu/fusions/tests/transpose/fused_transpose_102.hlo index 2fc3855efe4c11..7ebd717c8f6c63 100644 --- a/third_party/xla/xla/service/gpu/fusions/tests/transpose/fused_transpose_102.hlo +++ b/third_party/xla/xla/service/gpu/fusions/tests/transpose/fused_transpose_102.hlo @@ -1,4 +1,4 @@ -// RUN: fusion_to_mlir %s | mlir_fusions_opt -xla-gpu-test-optimize |\ +// RUN: fusion_to_mlir %s | emitters_opt -xla-gpu-test-optimize |\ // RUN: FileCheck %s // RUN: test_correctness %s @@ -16,8 +16,8 @@ fusion { // CHECK: %[[SYNC:.*]] = xla_gpu.sync_threads %[[SHMEM_WITH_VALS]] -// CHECK: xla_gpu.loop +// CHECK: xla.loop // CHECK-SAME: iter_args(%[[OUT_:.*]] = %[[OUT]]) // CHECK: %[[EXTRACTED:.*]] = tensor.extract %[[SYNC]] -// CHECK: %[[RES:.*]] = xla_gpu.pure_call @fusion__epilogue__transpose +// CHECK: %[[RES:.*]] = xla.pure_call @fusion__epilogue__transpose // CHECK: tensor.insert %[[RES]] into %[[OUT_]] \ No newline at end of file diff --git a/third_party/xla/xla/service/gpu/fusions/tests/transpose/fused_transpose_210.hlo b/third_party/xla/xla/service/gpu/fusions/tests/transpose/fused_transpose_210.hlo index 97e23f171b713a..44cb50208169d1 100644 --- a/third_party/xla/xla/service/gpu/fusions/tests/transpose/fused_transpose_210.hlo +++ b/third_party/xla/xla/service/gpu/fusions/tests/transpose/fused_transpose_210.hlo @@ -1,4 +1,4 @@ -// RUN: fusion_to_mlir %s | mlir_fusions_opt -xla-gpu-test-optimize |\ +// RUN: fusion_to_mlir %s | emitters_opt -xla-gpu-test-optimize |\ // RUN: FileCheck %s // RUN: test_correctness %s @@ -18,7 +18,7 @@ fusion { // CHECK: %[[SYNC:.*]] = 
xla_gpu.sync_threads %[[SHMEM_WITH_VALS]] -// CHECK: xla_gpu.loop +// CHECK: xla.loop // CHECK-SAME: iter_args(%[[OUT_:.*]] = %[[OUT]]) -// CHECK: %[[ABS:.*]] = xla_gpu.pure_call @fusion__epilogue__ +// CHECK: %[[ABS:.*]] = xla.pure_call @fusion__epilogue__ // CHECK: tensor.insert %[[ABS]] into %[[OUT_]] \ No newline at end of file diff --git a/third_party/xla/xla/service/gpu/fusions/tests/transpose/mixed_indexing.hlo b/third_party/xla/xla/service/gpu/fusions/tests/transpose/mixed_indexing.hlo index d16b99adc054d2..96e34eaebec53e 100644 --- a/third_party/xla/xla/service/gpu/fusions/tests/transpose/mixed_indexing.hlo +++ b/third_party/xla/xla/service/gpu/fusions/tests/transpose/mixed_indexing.hlo @@ -1,4 +1,4 @@ -// RUN: fusion_to_mlir %s | mlir_fusions_opt -xla-gpu-test-optimize |\ +// RUN: fusion_to_mlir %s | emitters_opt -xla-gpu-test-optimize |\ // RUN: FileCheck %s // RUN: test_correctness %s diff --git a/third_party/xla/xla/service/gpu/fusions/tests/transpose/multiple_roots.hlo b/third_party/xla/xla/service/gpu/fusions/tests/transpose/multiple_roots.hlo index 0bdc4aa0e8ab84..20c7b42a5b829c 100644 --- a/third_party/xla/xla/service/gpu/fusions/tests/transpose/multiple_roots.hlo +++ b/third_party/xla/xla/service/gpu/fusions/tests/transpose/multiple_roots.hlo @@ -1,4 +1,4 @@ -// RUN: fusion_to_mlir %s | mlir_fusions_opt -xla-gpu-test-optimize |\ +// RUN: fusion_to_mlir %s | emitters_opt -xla-gpu-test-optimize |\ // RUN: FileCheck %s // RUN: test_correctness %s diff --git a/third_party/xla/xla/service/gpu/fusions/tests/transpose/partial_tile.hlo b/third_party/xla/xla/service/gpu/fusions/tests/transpose/partial_tile.hlo index b4962525235094..6000693e7ba9d2 100644 --- a/third_party/xla/xla/service/gpu/fusions/tests/transpose/partial_tile.hlo +++ b/third_party/xla/xla/service/gpu/fusions/tests/transpose/partial_tile.hlo @@ -1,4 +1,4 @@ -// RUN: fusion_to_mlir %s | mlir_fusions_opt -xla-gpu-test-optimize |\ +// RUN: fusion_to_mlir %s | emitters_opt -xla-gpu-test-optimize |\ // RUN: FileCheck %s // RUN: test_correctness %s diff --git a/third_party/xla/xla/service/gpu/fusions/tests/transpose/prefer_large_vectors_021.hlo b/third_party/xla/xla/service/gpu/fusions/tests/transpose/prefer_large_vectors_021.hlo index b18e4ef3c34ab6..4f1858c81db9ff 100644 --- a/third_party/xla/xla/service/gpu/fusions/tests/transpose/prefer_large_vectors_021.hlo +++ b/third_party/xla/xla/service/gpu/fusions/tests/transpose/prefer_large_vectors_021.hlo @@ -1,4 +1,4 @@ -// RUN: fusion_to_mlir %s | mlir_fusions_opt -xla-gpu-test-optimize |\ +// RUN: fusion_to_mlir %s | emitters_opt -xla-gpu-test-optimize |\ // RUN: FileCheck %s // RUN: test_correctness %s diff --git a/third_party/xla/xla/service/gpu/fusions/tests/transpose/prefer_large_vectors_210.hlo b/third_party/xla/xla/service/gpu/fusions/tests/transpose/prefer_large_vectors_210.hlo index 382c3969a1b1f7..b70586064c084d 100644 --- a/third_party/xla/xla/service/gpu/fusions/tests/transpose/prefer_large_vectors_210.hlo +++ b/third_party/xla/xla/service/gpu/fusions/tests/transpose/prefer_large_vectors_210.hlo @@ -1,4 +1,4 @@ -// RUN: fusion_to_mlir %s | mlir_fusions_opt -xla-gpu-test-optimize |\ +// RUN: fusion_to_mlir %s | emitters_opt -xla-gpu-test-optimize |\ // RUN: FileCheck %s // RUN: test_correctness %s diff --git a/third_party/xla/xla/service/gpu/fusions/tests/transpose/side_outputs.hlo b/third_party/xla/xla/service/gpu/fusions/tests/transpose/side_outputs.hlo index 688bf08ed7f059..e8707d36a89f87 100644 --- 
a/third_party/xla/xla/service/gpu/fusions/tests/transpose/side_outputs.hlo +++ b/third_party/xla/xla/service/gpu/fusions/tests/transpose/side_outputs.hlo @@ -1,4 +1,4 @@ -// RUN: fusion_to_mlir %s | mlir_fusions_opt -xla-gpu-test-optimize |\ +// RUN: fusion_to_mlir %s | emitters_opt -xla-gpu-test-optimize |\ // RUN: FileCheck %s // RUN: test_correctness %s diff --git a/third_party/xla/xla/service/gpu/fusions/tests/transpose/side_outputs_inplace.hlo b/third_party/xla/xla/service/gpu/fusions/tests/transpose/side_outputs_inplace.hlo index b17954109d1702..bbb17c2a0acd11 100644 --- a/third_party/xla/xla/service/gpu/fusions/tests/transpose/side_outputs_inplace.hlo +++ b/third_party/xla/xla/service/gpu/fusions/tests/transpose/side_outputs_inplace.hlo @@ -1,4 +1,4 @@ -// RUN: fusion_to_mlir %s | mlir_fusions_opt -xla-gpu-test-optimize |\ +// RUN: fusion_to_mlir %s | emitters_opt -xla-gpu-test-optimize |\ // RUN: FileCheck %s // RUN: test_correctness %s diff --git a/third_party/xla/xla/service/gpu/fusions/tests/transpose/transpose_021_vectorized.hlo b/third_party/xla/xla/service/gpu/fusions/tests/transpose/transpose_021_vectorized.hlo index 7f30c9dd6eb455..92d4562d1cfc80 100644 --- a/third_party/xla/xla/service/gpu/fusions/tests/transpose/transpose_021_vectorized.hlo +++ b/third_party/xla/xla/service/gpu/fusions/tests/transpose/transpose_021_vectorized.hlo @@ -1,4 +1,4 @@ -// RUN: fusion_to_mlir %s | mlir_fusions_opt -xla-gpu-test-optimize |\ +// RUN: fusion_to_mlir %s | emitters_opt -xla-gpu-test-optimize |\ // RUN: FileCheck %s // RUN: test_correctness %s diff --git a/third_party/xla/xla/service/gpu/fusions/tests/transpose/transpose_10.hlo b/third_party/xla/xla/service/gpu/fusions/tests/transpose/transpose_10.hlo index 385f2003d045d6..0acfd703dffeeb 100644 --- a/third_party/xla/xla/service/gpu/fusions/tests/transpose/transpose_10.hlo +++ b/third_party/xla/xla/service/gpu/fusions/tests/transpose/transpose_10.hlo @@ -1,4 +1,4 @@ -// RUN: fusion_to_mlir %s | mlir_fusions_opt -xla-gpu-test-optimize |\ +// RUN: fusion_to_mlir %s | emitters_opt -xla-gpu-test-optimize |\ // RUN: FileCheck %s // RUN: test_correctness %s --bijection_inputs=transpose --bijection_outputs=transpose diff --git a/third_party/xla/xla/service/gpu/fusions/tests/transpose/transpose_1302.hlo b/third_party/xla/xla/service/gpu/fusions/tests/transpose/transpose_1302.hlo index 97815ab8673fa5..8cf09e6f493dff 100644 --- a/third_party/xla/xla/service/gpu/fusions/tests/transpose/transpose_1302.hlo +++ b/third_party/xla/xla/service/gpu/fusions/tests/transpose/transpose_1302.hlo @@ -1,4 +1,4 @@ -// RUN: fusion_to_mlir %s | mlir_fusions_opt -xla-gpu-test-optimize |\ +// RUN: fusion_to_mlir %s | emitters_opt -xla-gpu-test-optimize |\ // RUN: FileCheck %s // RUN: test_correctness %s diff --git a/third_party/xla/xla/service/gpu/fusions/tests/transpose/transpose_210_vectorized.hlo b/third_party/xla/xla/service/gpu/fusions/tests/transpose/transpose_210_vectorized.hlo index d18fd959124367..4f20c856bd7096 100644 --- a/third_party/xla/xla/service/gpu/fusions/tests/transpose/transpose_210_vectorized.hlo +++ b/third_party/xla/xla/service/gpu/fusions/tests/transpose/transpose_210_vectorized.hlo @@ -1,4 +1,4 @@ -// RUN: fusion_to_mlir %s | mlir_fusions_opt -xla-gpu-test-optimize |\ +// RUN: fusion_to_mlir %s | emitters_opt -xla-gpu-test-optimize |\ // RUN: FileCheck %s // RUN: test_correctness %s diff --git a/third_party/xla/xla/service/gpu/fusions/transforms/BUILD b/third_party/xla/xla/service/gpu/fusions/transforms/BUILD index 
d4db611668dfa2..9a67d32682236c 100644 --- a/third_party/xla/xla/service/gpu/fusions/transforms/BUILD +++ b/third_party/xla/xla/service/gpu/fusions/transforms/BUILD @@ -57,6 +57,7 @@ cc_library( "//xla:shape_util", "//xla:util", "//xla:xla_data_proto_cc", + "//xla/codegen/ir:xla", "//xla/hlo/analysis:indexing_analysis", "//xla/mlir_hlo", "//xla/mlir_hlo:map_mhlo_to_scalar_op", diff --git a/third_party/xla/xla/service/gpu/fusions/transforms/passes.td b/third_party/xla/xla/service/gpu/fusions/transforms/passes.td index 96c1e0f85da133..1b5ffbdb24636e 100644 --- a/third_party/xla/xla/service/gpu/fusions/transforms/passes.td +++ b/third_party/xla/xla/service/gpu/fusions/transforms/passes.td @@ -44,7 +44,7 @@ def ConvertPureCallOpsPass }]; let dependentDialects = [ "mlir::func::FuncDialect", - "xla::gpu::XlaGpuDialect" + "xla::XlaDialect" ]; let constructor = "CreateConvertPureCallOpsPass()"; } @@ -60,6 +60,7 @@ def FlattenTensorsPass : Pass<"xla-gpu-flatten-tensors", "mlir::ModuleOp"> { "mlir::func::FuncDialect", "mlir::tensor::TensorDialect", "xla::gpu::XlaGpuDialect", + "xla::XlaDialect", ]; let constructor = "CreateFlattenTensorsPass()"; } @@ -79,6 +80,7 @@ def LowerTensorsPass : Pass<"xla-gpu-lower-tensors", "mlir::ModuleOp"> { "mlir::scf::SCFDialect", "mlir::tensor::TensorDialect", "xla::gpu::XlaGpuDialect", + "xla::XlaDialect", ]; let options = [ Option<"gpu_device_info_", "gpu_device_info", "std::string", /*default=*/"", @@ -173,7 +175,8 @@ def LowerXlaGpuToScfPass : let dependentDialects = [ "mlir::gpu::GPUDialect", "mlir::LLVM::LLVMDialect", "mlir::scf::SCFDialect", - "mlir::tensor::TensorDialect", "xla::gpu::XlaGpuDialect", "mlir::vector::VectorDialect", + "mlir::tensor::TensorDialect", "xla::gpu::XlaGpuDialect", + "xla::XlaDialect", "mlir::vector::VectorDialect", ]; let options = [ @@ -194,7 +197,9 @@ def LowerXlaGpuLoopsToScfPass : Pass< let dependentDialects = [ "mlir::scf::SCFDialect", - "mlir::tensor::TensorDialect", "xla::gpu::XlaGpuDialect", + "mlir::tensor::TensorDialect", + "xla::gpu::XlaGpuDialect", + "xla::XlaDialect", ]; let constructor = "CreateLowerXlaGpuLoopsToScfPass()"; @@ -209,7 +214,8 @@ def EraseDeadFunctionsPass : Pass<"xla-erase-dead-functions", "mlir::ModuleOp"> let dependentDialects = [ "mlir::func::FuncDialect", - "xla::gpu::XlaGpuDialect" + "xla::gpu::XlaGpuDialect", + "xla::XlaDialect", ]; let constructor = "CreateEraseDeadFunctionsPass()"; @@ -281,7 +287,7 @@ def FuseLoopsPass : Pass<"xla-gpu-fuse-loops", "mlir::func::FuncOp"> { xla_gpu.yield %inserted } }]; - let dependentDialects = ["xla::gpu::XlaGpuDialect"]; + let dependentDialects = ["xla::gpu::XlaGpuDialect", "xla::XlaDialect"]; let constructor = "CreateFuseLoopsPass()"; } @@ -292,7 +298,7 @@ def PeelLoopsPass : Pass<"xla-gpu-peel-loops", "mlir::func::FuncOp"> { as [0, NUM_ITERATIONS - 1) and [NUM_ITERATIONS - 1, NUM_ITERATIONS) if it removes a constraint. 
}]; - let dependentDialects = ["xla::gpu::XlaGpuDialect"]; + let dependentDialects = ["xla::gpu::XlaGpuDialect", "xla::XlaDialect"]; let constructor = "CreatePeelLoopsPass()"; } @@ -308,6 +314,7 @@ def OptimizeLoopsPass : let dependentDialects = [ "mlir::vector::VectorDialect", "xla::gpu::XlaGpuDialect", + "xla::XlaDialect", ]; let constructor = "CreateOptimizeLoopsPass()"; diff --git a/third_party/xla/xla/service/gpu/fusions/transforms/simplify_affine.cc b/third_party/xla/xla/service/gpu/fusions/transforms/simplify_affine.cc index 7d890ffc7f27f7..bee8dc383a0848 100644 --- a/third_party/xla/xla/service/gpu/fusions/transforms/simplify_affine.cc +++ b/third_party/xla/xla/service/gpu/fusions/transforms/simplify_affine.cc @@ -41,6 +41,7 @@ limitations under the License. #include "mlir/Support/LLVM.h" #include "mlir/Support/LogicalResult.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "xla/codegen/ir/xla_ops.h" #include "xla/hlo/analysis/indexing_map.h" #include "xla/service/gpu/fusions/ir/xla_gpu_ops.h" #include "xla/service/gpu/fusions/transforms/passes.h" @@ -358,7 +359,7 @@ std::optional GetIVRange(mlir::Value iv) { return {{lb.getSExtValue(), ub.getSExtValue() - 1}}; } } - if (auto loop_op = mlir::dyn_cast(parent)) { + if (auto loop_op = mlir::dyn_cast(parent)) { const auto& indexing_map = loop_op.getIndexingMap(); if (bbarg.getArgNumber() >= loop_op.getNumInductionVars() && bbarg.getArgNumber() < diff --git a/third_party/xla/xla/service/gpu/fusions/transforms/tests/BUILD b/third_party/xla/xla/service/gpu/fusions/transforms/tests/BUILD index 381d5a3220b1df..10228fcc460af8 100644 --- a/third_party/xla/xla/service/gpu/fusions/transforms/tests/BUILD +++ b/third_party/xla/xla/service/gpu/fusions/transforms/tests/BUILD @@ -10,7 +10,7 @@ lit_test_suite( srcs = glob(["*.mlir"]), cfg = "//xla:lit.cfg.py", tools = [ - "//xla/service/gpu/fusions/tools:mlir_fusions_opt", + "//xla/codegen/tools:emitters_opt", "@llvm-project//llvm:FileCheck", ], ) diff --git a/third_party/xla/xla/service/gpu/fusions/transforms/tests/convert_float_nvidia.mlir b/third_party/xla/xla/service/gpu/fusions/transforms/tests/convert_float_nvidia.mlir index 27b0c083d06ec5..c6eda9e8363817 100644 --- a/third_party/xla/xla/service/gpu/fusions/transforms/tests/convert_float_nvidia.mlir +++ b/third_party/xla/xla/service/gpu/fusions/transforms/tests/convert_float_nvidia.mlir @@ -1,4 +1,4 @@ -// RUN: mlir_fusions_opt %s -split-input-file -xla-gpu-convert-float-nvidia -canonicalize | FileCheck %s +// RUN: emitters_opt %s -split-input-file -xla-gpu-convert-float-nvidia -canonicalize | FileCheck %s module { func.func @intr_f16_to_f8(%arg0: f16) -> (f8E4M3FN, f8E5M2) { diff --git a/third_party/xla/xla/service/gpu/fusions/transforms/tests/convert_xla_gpu_pure_calls.mlir b/third_party/xla/xla/service/gpu/fusions/transforms/tests/convert_xla_gpu_pure_calls.mlir index a6296414b9c62f..356f0fab167ce3 100644 --- a/third_party/xla/xla/service/gpu/fusions/transforms/tests/convert_xla_gpu_pure_calls.mlir +++ b/third_party/xla/xla/service/gpu/fusions/transforms/tests/convert_xla_gpu_pure_calls.mlir @@ -1,5 +1,5 @@ -// RUN: mlir_fusions_opt %s -xla-gpu-convert-pure-call-ops | FileCheck %s -// RUN: mlir_fusions_opt %s -cse -xla-gpu-convert-pure-call-ops \ +// RUN: emitters_opt %s -xla-gpu-convert-pure-call-ops | FileCheck %s +// RUN: emitters_opt %s -cse -xla-gpu-convert-pure-call-ops \ // RUN: | FileCheck %s -check-prefixes=CHECK-CSE func.func private @callee() -> f32 { @@ -11,9 +11,9 @@ func.func @caller() -> f32 { %c0 = 
arith.constant 0 : index %c1 = arith.constant 1 : index %c10 = arith.constant 10 : index - %call0 = xla_gpu.pure_call @callee() : () -> (f32) + %call0 = xla.pure_call @callee() : () -> (f32) %v = scf.for %i = %c0 to %c10 step %c1 iter_args(%r = %call0) -> f32 { - %call1 = xla_gpu.pure_call @callee() : () -> (f32) + %call1 = xla.pure_call @callee() : () -> (f32) %new_v = arith.addf %call1, %r : f32 scf.yield %new_v : f32 } @@ -39,7 +39,7 @@ func.func private @arg_callee(%arg0: f32, %arg1: f32) -> f32 { func.func @arg_caller() -> f32 { %cst0 = arith.constant 0.0 : f32 %cst1 = arith.constant 1.0 : f32 - %call = xla_gpu.pure_call @arg_callee(%cst0, %cst1) : (f32, f32) -> (f32) + %call = xla.pure_call @arg_callee(%cst0, %cst1) : (f32, f32) -> (f32) return %call : f32 } diff --git a/third_party/xla/xla/service/gpu/fusions/transforms/tests/expand_float_ops.mlir b/third_party/xla/xla/service/gpu/fusions/transforms/tests/expand_float_ops.mlir index ff97c1b46f708d..442fe5e9291572 100644 --- a/third_party/xla/xla/service/gpu/fusions/transforms/tests/expand_float_ops.mlir +++ b/third_party/xla/xla/service/gpu/fusions/transforms/tests/expand_float_ops.mlir @@ -1,4 +1,4 @@ -// RUN: mlir_fusions_opt %s -split-input-file -xla-gpu-expand-float-ops -canonicalize | FileCheck %s +// RUN: emitters_opt %s -split-input-file -xla-gpu-expand-float-ops -canonicalize | FileCheck %s module { func.func @tanh(%arg0: f32) -> f32 { diff --git a/third_party/xla/xla/service/gpu/fusions/transforms/tests/flatten_tensors.mlir b/third_party/xla/xla/service/gpu/fusions/transforms/tests/flatten_tensors.mlir index d35dc71ddad023..a64112eee9b671 100644 --- a/third_party/xla/xla/service/gpu/fusions/transforms/tests/flatten_tensors.mlir +++ b/third_party/xla/xla/service/gpu/fusions/transforms/tests/flatten_tensors.mlir @@ -1,4 +1,4 @@ -// RUN: mlir_fusions_opt %s -split-input-file -xla-gpu-flatten-tensors \ +// RUN: emitters_opt %s -split-input-file -xla-gpu-flatten-tensors \ // RUN: --verify-diagnostics | FileCheck %s func.func @tensor_extract( @@ -8,12 +8,12 @@ func.func @tensor_extract( : tensor<2x3xf32, dense<[0, 1]> : tensor<2xi64>> func.return %v : f32 } -// CHECK: #[[$MAP:.+]] = #xla_gpu.indexing_map<"(d0, d1) -> (d1 * 2 + d0), domain: d0 in [0, 1], d1 in [0, 2]"> +// CHECK: #[[$MAP:.+]] = #xla.indexing_map<"(d0, d1) -> (d1 * 2 + d0), domain: d0 in [0, 1], d1 in [0, 2]"> // CHECK-LABEL: func.func @tensor_extract( // CHECK-SAME: %[[SRC:.*]]: tensor<6xf32>, // CHECK-SAME: %[[I:.*]]: index, %[[J:.*]]: index) -> f32 { -// CHECK: %[[INDEX:.*]] = xla_gpu.apply_indexing #[[$MAP]](%[[I]], %[[J]]) +// CHECK: %[[INDEX:.*]] = xla.apply_indexing #[[$MAP]](%[[I]], %[[J]]) // CHECK: tensor.extract %[[SRC]][%[[INDEX]]] : tensor<6xf32> // ----- @@ -45,13 +45,13 @@ func.func @update(%arg0: tensor<10x24xf32>) -> tensor<10x24xf32> { } func.func @pure_call(%arg0: tensor<10x24xf32>) -> tensor<10x24xf32> { - %updated_tensor = xla_gpu.pure_call @update(%arg0) + %updated_tensor = xla.pure_call @update(%arg0) : (tensor<10x24xf32>) -> (tensor<10x24xf32>) func.return %updated_tensor : tensor<10x24xf32> } // CHECK-LABEL: func.func @pure_call( // CHECK-SAME: %[[TENSOR:.*]]: tensor<240xf32>) -> tensor<240xf32> { -// CHECK-NEXT: xla_gpu.pure_call @update(%[[TENSOR]]) +// CHECK-NEXT: xla.pure_call @update(%[[TENSOR]]) // CHECK-SAME: : (tensor<240xf32>) -> tensor<240xf32> // CHECK-NEXT: return @@ -59,21 +59,21 @@ func.func @pure_call(%arg0: tensor<10x24xf32>) -> tensor<10x24xf32> { func.func @atomic_rmw(%in: tensor<2x4xf32>, %i: index, %j: index) -> 
(tensor<2x4xf32>) { - %ret = xla_gpu.atomic_rmw %in[%i, %j] : tensor<2x4xf32> { + %ret = xla.atomic_rmw %in[%i, %j] : tensor<2x4xf32> { ^bb0(%current : f32): %c42 = arith.constant 1.0 : f32 %add = arith.minimumf %current, %c42 : f32 - xla_gpu.yield %add : f32 + xla.yield %add : f32 } return %ret : tensor<2x4xf32> } -// CHECK: #[[$MAP:.+]] = #xla_gpu.indexing_map<"(d0, d1) -> (d0 * 4 + d1), domain: d0 in [0, 1], d1 in [0, 3]"> +// CHECK: #[[$MAP:.+]] = #xla.indexing_map<"(d0, d1) -> (d0 * 4 + d1), domain: d0 in [0, 1], d1 in [0, 3]"> // CHECK-LABEL: func.func @atomic_rmw( // CHECK-SAME: %[[TENSOR:.*]]: tensor<8xf32>, %[[I:.*]]: index, // CHECK-SAME: %[[J:.*]]: index) -> tensor<8xf32> { -// CHECK: %[[INDEX:.*]] = xla_gpu.apply_indexing #[[$MAP]] +// CHECK: %[[INDEX:.*]] = xla.apply_indexing #[[$MAP]] // CHECK-SAME: (%[[I]], %[[J]]) -// CHECK: xla_gpu.atomic_rmw %[[TENSOR]][%[[INDEX]]] : tensor<8xf32> +// CHECK: xla.atomic_rmw %[[TENSOR]][%[[INDEX]]] : tensor<8xf32> // ----- @@ -93,8 +93,8 @@ func.func @for_loop(%t0: tensor<32x1024xf32>, %t1: tensor<64x8x4xf32>) } {some_attr} return %for#0, %for#1, %c0_f32 : tensor<32x1024xf32>, tensor<64x8x4xf32>, f32 } -// CHECK: #[[$MAP0:.+]] = #xla_gpu.indexing_map<"(d0) -> (d0 + 1024) -// CHECK: #[[$MAP1:.+]] = #xla_gpu.indexing_map<"(d0) -> (d0 * 32 + 5) +// CHECK: #[[$MAP0:.+]] = #xla.indexing_map<"(d0) -> (d0 + 1024) +// CHECK: #[[$MAP1:.+]] = #xla.indexing_map<"(d0) -> (d0 * 32 + 5) // CHECK-LABEL: func.func @for_loop( // CHECK-SAME: %[[T0:.*]]: tensor<32768xf32>, // CHECK-SAME: %[[T1:.*]]: tensor<2048xf32>) -> (tensor<32768xf32>, tensor<2048xf32>, f32) { @@ -106,17 +106,17 @@ func.func @for_loop(%t0: tensor<32x1024xf32>, %t1: tensor<64x8x4xf32>) // CHECK: %[[FOR:.*]]:2 = scf.for %[[I:.*]] = %[[C0]] to %[[C64]] // CHECK-SAME: step %[[C32]] // CHECK-SAME: iter_args(%[[T0_:.*]] = %[[T0]], %[[T1_:.*]] = %[[T1]]) -// CHECK: %[[IND0:.*]] = xla_gpu.apply_indexing #[[$MAP0]](%[[I]]) +// CHECK: %[[IND0:.*]] = xla.apply_indexing #[[$MAP0]](%[[I]]) // CHECK: %[[UPD0:.*]] = tensor.insert %[[F32]] into %[[T0_]][%[[IND0]]] -// CHECK: %[[IND1:.*]] = xla_gpu.apply_indexing #[[$MAP1]](%[[I]]) +// CHECK: %[[IND1:.*]] = xla.apply_indexing #[[$MAP1]](%[[I]]) // CHECK: %[[UPD1:.*]] = tensor.insert %[[F32]] into %[[T1_]][%[[IND1]]] // CHECK: scf.yield %[[UPD0]], %[[UPD1]] : tensor<32768xf32>, tensor<2048xf32> // ----- -#map = #xla_gpu.indexing_map<"(d0, d1) -> ((d1 * 128 + d0) floordiv 36), domain: d0 in [0, 127], d1 in [0, 393749]"> -#map1 = #xla_gpu.indexing_map<"(d0, d1) -> (((d1 * 128 + d0) floordiv 9) mod 4), domain: d0 in [0, 127], d1 in [0, 393749]"> -#map2 = #xla_gpu.indexing_map<"(d0, d1) -> ((d1 * 128 + d0) mod 9), domain: d0 in [0, 127], d1 in [0, 393749]"> +#map = #xla.indexing_map<"(d0, d1) -> ((d1 * 128 + d0) floordiv 36), domain: d0 in [0, 127], d1 in [0, 393749]"> +#map1 = #xla.indexing_map<"(d0, d1) -> (((d1 * 128 + d0) floordiv 9) mod 4), domain: d0 in [0, 127], d1 in [0, 393749]"> +#map2 = #xla.indexing_map<"(d0, d1) -> ((d1 * 128 + d0) mod 9), domain: d0 in [0, 127], d1 in [0, 393749]"> func.func @if_op(%arg0: tensor<4000x4x9xf32>, %arg1: tensor<1400x1xi32>, %arg2: tensor<1400x1x4x9xf32>, %arg3: tensor<4000x4x9xf32>) -> tensor<4000x4x9xf32> { @@ -124,18 +124,18 @@ func.func @if_op(%arg0: tensor<4000x4x9xf32>, %arg1: tensor<1400x1xi32>, %c3999 = arith.constant 3999 : index %th_x = gpu.thread_id x {xla.range = [0 : index, 127 : index]} %bl_x = gpu.block_id x {xla.range = [0 : index, 393749 : index]} - %0 = xla_gpu.apply_indexing #map(%th_x, %bl_x) + %0 
= xla.apply_indexing #map(%th_x, %bl_x) %extracted = tensor.extract %arg1[%0, %c0] : tensor<1400x1xi32> %1 = arith.index_cast %extracted : i32 to index %2 = arith.cmpi ule, %1, %c3999 : index %3 = scf.if %2 -> (tensor<4000x4x9xf32>) { - %4 = xla_gpu.apply_indexing #map1(%th_x, %bl_x) - %5 = xla_gpu.apply_indexing #map2(%th_x, %bl_x) + %4 = xla.apply_indexing #map1(%th_x, %bl_x) + %5 = xla.apply_indexing #map2(%th_x, %bl_x) %elem = tensor.extract %arg2[%0, %c0, %4, %5] : tensor<1400x1x4x9xf32> - %atomic_rmw = xla_gpu.atomic_rmw %arg3[%1, %4, %5] : tensor<4000x4x9xf32> { + %atomic_rmw = xla.atomic_rmw %arg3[%1, %4, %5] : tensor<4000x4x9xf32> { ^bb0(%arg4: f32): %6 = arith.addf %arg4, %elem : f32 - xla_gpu.yield %6 : f32 + xla.yield %6 : f32 } scf.yield %atomic_rmw : tensor<4000x4x9xf32> } else { @@ -222,12 +222,12 @@ func.func @vector_extract(%arg0: vector<2x3xf32>, %arg1: index) -> f32 { %v = vector.extract %arg0[%arg1, 2] : f32 from vector<2x3xf32> func.return %v : f32 } -// CHECK: #[[$MAP:.+]] = #xla_gpu.indexing_map<"(d0) -> (d0 * 3 + 2), +// CHECK: #[[$MAP:.+]] = #xla.indexing_map<"(d0) -> (d0 * 3 + 2), // CHECK-SAME: domain: d0 in [0, 1] // CHECK-LABEL: func.func @vector_extract( // CHECK-SAME: %[[SRC:.*]]: vector<6xf32>, %[[I:.*]]: index) -> f32 { -// CHECK: %[[INDEX:.*]] = xla_gpu.apply_indexing #[[$MAP]](%[[I]]) +// CHECK: %[[INDEX:.*]] = xla.apply_indexing #[[$MAP]](%[[I]]) // CHECK: vector.extract %[[SRC]][%[[INDEX]]] : f32 from vector<6xf32> // ----- @@ -238,12 +238,12 @@ func.func @vector_insert(%arg0: vector<10x24xf32>, %i: index) %out = vector.insert %scalar, %arg0 [1, %i] : f32 into vector<10x24xf32> func.return %out : vector<10x24xf32> } -// CHECK: #[[$MAP:.+]] = #xla_gpu.indexing_map<"(d0) -> (d0 + 24), +// CHECK: #[[$MAP:.+]] = #xla.indexing_map<"(d0) -> (d0 + 24), // CHECK-SAME: domain: d0 in [0, 23] // CHECK-LABEL: func.func @vector_insert( // CHECK-SAME: %[[VECTOR:.*]]: vector<240xf32>, %[[I:.*]]: index) -> // CHECK-SAME: vector<240xf32> { -// CHECK: %[[INDEX:.*]] = xla_gpu.apply_indexing #[[$MAP]](%[[I]]) +// CHECK: %[[INDEX:.*]] = xla.apply_indexing #[[$MAP]](%[[I]]) // CHECK: vector.insert {{.*}}, %[[VECTOR]] [%[[INDEX]]] // CHECK-SAME: : f32 into vector<240xf32> @@ -257,13 +257,13 @@ func.func @update(%arg0: vector<10x24xf32>) -> vector<10x24xf32> { } func.func @pure_call_vector(%arg0: vector<10x24xf32>) -> vector<10x24xf32> { - %updated_vector = xla_gpu.pure_call @update(%arg0) + %updated_vector = xla.pure_call @update(%arg0) : (vector<10x24xf32>) -> (vector<10x24xf32>) func.return %updated_vector : vector<10x24xf32> } // CHECK-LABEL: func.func @pure_call_vector( // CHECK-SAME: %[[VECTOR:.*]]: vector<240xf32>) -> vector<240xf32> { -// CHECK-NEXT: xla_gpu.pure_call @update(%[[VECTOR]]) +// CHECK-NEXT: xla.pure_call @update(%[[VECTOR]]) // CHECK-SAME: : (vector<240xf32>) -> vector<240xf32> // CHECK-NEXT: return @@ -287,8 +287,8 @@ func.func @for_loop_vector(%t0: vector<32x1024xf32>, %t1: vector<64x8x4xf32>) return %for#0, %for#1, %c0_f32 : vector<32x1024xf32>, vector<64x8x4xf32>, f32 } -// CHECK: #[[$MAP0:.+]] = #xla_gpu.indexing_map<"(d0) -> (d0 + 1024) -// CHECK: #[[$MAP1:.+]] = #xla_gpu.indexing_map<"(d0) -> (d0 * 32 + 5) +// CHECK: #[[$MAP0:.+]] = #xla.indexing_map<"(d0) -> (d0 + 1024) +// CHECK: #[[$MAP1:.+]] = #xla.indexing_map<"(d0) -> (d0 * 32 + 5) // CHECK-LABEL: func.func @for_loop_vector( // CHECK-SAME: %[[V0:.*]]: vector<32768xf32>, // CHECK-SAME: %[[V1:.*]]: vector<2048xf32>) -> @@ -301,9 +301,9 @@ func.func @for_loop_vector(%t0: vector<32x1024xf32>, 
%t1: vector<64x8x4xf32>) // CHECK: %[[FOR:.*]]:2 = scf.for %[[I:.*]] = %[[C0]] to %[[C64]] // CHECK-SAME: step %[[C32]] // CHECK-SAME: iter_args(%[[V0_:.*]] = %[[V0]], %[[V1_:.*]] = %[[V1]]) -// CHECK: %[[IND0:.*]] = xla_gpu.apply_indexing #[[$MAP0]](%[[I]]) +// CHECK: %[[IND0:.*]] = xla.apply_indexing #[[$MAP0]](%[[I]]) // CHECK: %[[UPD0:.*]] = vector.insert %[[F32]], %[[V0_]] [%[[IND0]]] -// CHECK: %[[IND1:.*]] = xla_gpu.apply_indexing #[[$MAP1]](%[[I]]) +// CHECK: %[[IND1:.*]] = xla.apply_indexing #[[$MAP1]](%[[I]]) // CHECK: %[[UPD1:.*]] = vector.insert %[[F32]], %[[V1_]] [%[[IND1]]] // CHECK: scf.yield %[[UPD0]], %[[UPD1]] : // CHECK-SAME: vector<32768xf32>, vector<2048xf32> diff --git a/third_party/xla/xla/service/gpu/fusions/transforms/tests/fuse_loops.mlir b/third_party/xla/xla/service/gpu/fusions/transforms/tests/fuse_loops.mlir index f2926856623ca7..95af80f7a857d3 100644 --- a/third_party/xla/xla/service/gpu/fusions/transforms/tests/fuse_loops.mlir +++ b/third_party/xla/xla/service/gpu/fusions/transforms/tests/fuse_loops.mlir @@ -1,7 +1,7 @@ -// RUN: mlir_fusions_opt -split-input-file %s -xla-gpu-fuse-loops \ +// RUN: emitters_opt -split-input-file %s -xla-gpu-fuse-loops \ // RUN: | FileCheck %s -#indexing_map = #xla_gpu.indexing_map<"(d0, d1)[s0, s1] ->" +#indexing_map = #xla.indexing_map<"(d0, d1)[s0, s1] ->" " (d1 floordiv 30," " ((d1 floordiv 6) mod 5) * 32 + s0 * 4 + d0 floordiv 32," " (d1 mod 6) * 32 + d0 mod 32)," @@ -9,7 +9,7 @@ " d0 in [0, 127], d1 in [0, 599]," " s0 in [0, 7], s1 in [0, 0]," " (d1 mod 6) * 32 + d0 mod 32 in [0, 169]"> -#indexing_map1 = #xla_gpu.indexing_map<"(th_x, d1)[s0, s1] ->" +#indexing_map1 = #xla.indexing_map<"(th_x, d1)[s0, s1] ->" " (0," " th_x mod 32," " th_x floordiv 32 + s0 * 4)," @@ -23,31 +23,31 @@ func.func @fuse_loops(%arg0: tensor<20x160x170xf32>) -> tensor<1x32x33xf32> { %tid = gpu.thread_id x {xla.range = [0 : index, 127 : index]} %bid = gpu.block_id x {xla.range = [0 : index, 599 : index]} %shmem = xla_gpu.allocate_shared : tensor<1x32x33xf32> - %xla_loop = xla_gpu.loop (%tid, %bid)[%i, %j] + %xla_loop = xla.loop (%tid, %bid)[%i, %j] -> (%ra, %rb, %rc) in #indexing_map iter_args(%iter = %cst) -> (vector<8x1xf32>) { %extracted = tensor.extract %arg0[%ra, %rb, %rc] : tensor<20x160x170xf32> %0 = math.exp %extracted : f32 %1 = vector.insert %0, %iter [%i, %j] : f32 into vector<8x1xf32> - xla_gpu.yield %1 : vector<8x1xf32> + xla.yield %1 : vector<8x1xf32> } - %xla_loop_0 = xla_gpu.loop (%tid, %bid)[%i, %j] + %xla_loop_0 = xla.loop (%tid, %bid)[%i, %j] -> (%ra, %rb, %rc) in #indexing_map1 iter_args(%iter = %shmem) -> (tensor<1x32x33xf32>) { %0 = vector.extract %xla_loop[%i, %j] : f32 from vector<8x1xf32> %inserted = tensor.insert %0 into %iter[%ra, %rb, %rc] : tensor<1x32x33xf32> - xla_gpu.yield %inserted : tensor<1x32x33xf32> + xla.yield %inserted : tensor<1x32x33xf32> } %synced_tensor = xla_gpu.sync_threads %xla_loop_0 : tensor<1x32x33xf32> return %synced_tensor : tensor<1x32x33xf32> } -// CHECK: #[[$FUSED_MAP:.*]] = #xla_gpu.indexing_map<"(d0, d1)[s0, s1] -> +// CHECK: #[[$FUSED_MAP:.*]] = #xla.indexing_map<"(d0, d1)[s0, s1] -> // CHECK-SAME: (d1 floordiv 30, ((d1 floordiv 6) mod 5) * 32 + s0 * 4 + d0 floordiv 32, // CHECK-SAME: (d1 mod 6) * 32 + d0 mod 32, 0, d0 mod 32, d0 floordiv 32 + s0 * 4), // CHECK-SAME: domain: d0 in [0, 127], d1 in [0, 599], // CHECK-SAME: s0 in [0, 7], s1 in [0, 0], (d1 mod 6) * 32 + d0 mod 32 in [0, 169] -// CHECK: %[[FUSED_LOOP:.*]] = xla_gpu.loop {{.*}} in #[[$FUSED_MAP]] +// CHECK: %[[FUSED_LOOP:.*]] = 
xla.loop {{.*}} in #[[$FUSED_MAP]] // CHECK-NOT: vector.insert // CHECK-NOT: vector.extract // CHECK: %[[EXTRACTED:.*]] = tensor.extract @@ -58,7 +58,7 @@ func.func @fuse_loops(%arg0: tensor<20x160x170xf32>) -> tensor<1x32x33xf32> { // ----- -#indexing_map = #xla_gpu.indexing_map<"(d0, d1)[s0, s1] ->" +#indexing_map = #xla.indexing_map<"(d0, d1)[s0, s1] ->" " (d1 floordiv 30," " ((d1 floordiv 6) mod 5) * 32 + s0 * 4 + d0 floordiv 32," " (d1 mod 6) * 32 + d0 mod 32)," @@ -66,7 +66,7 @@ func.func @fuse_loops(%arg0: tensor<20x160x170xf32>) -> tensor<1x32x33xf32> { " d0 in [0, 127], d1 in [0, 599]," " s0 in [0, 7], s1 in [0, 0]," " (d1 mod 6) * 32 + d0 mod 32 in [0, 169]"> -#indexing_map1 = #xla_gpu.indexing_map<"(d0, d1)[s0, s1] ->" +#indexing_map1 = #xla.indexing_map<"(d0, d1)[s0, s1] ->" " (0," " d0 mod 32," " d0 floordiv 32 + s0 * 4)," @@ -80,31 +80,31 @@ func.func @do_not_fuse_index_mismatch(%arg0: tensor<20x160x170xf32>) -> tensor<1 %tid = gpu.thread_id x {xla.range = [0 : index, 127 : index]} %bid = gpu.block_id x {xla.range = [0 : index, 599 : index]} %shmem = xla_gpu.allocate_shared : tensor<1x32x33xf32> - %xla_loop = xla_gpu.loop (%tid, %bid)[%i, %j] + %xla_loop = xla.loop (%tid, %bid)[%i, %j] -> (%ra, %rb, %rc) in #indexing_map iter_args(%iter = %cst) -> (vector<8x1xf32>) { %extracted = tensor.extract %arg0[%ra, %rb, %rc] : tensor<20x160x170xf32> %0 = math.exp %extracted : f32 %1 = vector.insert %0, %iter [%j, %i] : f32 into vector<8x1xf32> - xla_gpu.yield %1 : vector<8x1xf32> + xla.yield %1 : vector<8x1xf32> } - %xla_loop_0 = xla_gpu.loop (%tid, %bid)[%i, %j] + %xla_loop_0 = xla.loop (%tid, %bid)[%i, %j] -> (%ra, %rb, %rc) in #indexing_map1 iter_args(%iter = %shmem) -> (tensor<1x32x33xf32>) { %0 = vector.extract %xla_loop[%i, %j] : f32 from vector<8x1xf32> %inserted = tensor.insert %0 into %iter[%ra, %rb, %rc] : tensor<1x32x33xf32> - xla_gpu.yield %inserted : tensor<1x32x33xf32> + xla.yield %inserted : tensor<1x32x33xf32> } return %xla_loop_0 : tensor<1x32x33xf32> } // CHECK-LABEL: @do_not_fuse_index_mismatch -// CHECK: xla_gpu.loop +// CHECK: xla.loop // CHECK: vector.insert -// CHECK: xla_gpu.loop +// CHECK: xla.loop // CHECK: vector.extract // ----- -#indexing_map = #xla_gpu.indexing_map<"(d0, d1)[s0, s1] ->" +#indexing_map = #xla.indexing_map<"(d0, d1)[s0, s1] ->" " (d1 floordiv 30," " ((d1 floordiv 6) mod 5) * 32 + s0 * 4 + d0 floordiv 32," " (d1 mod 6) * 32 + d0 mod 32)," @@ -112,7 +112,7 @@ func.func @do_not_fuse_index_mismatch(%arg0: tensor<20x160x170xf32>) -> tensor<1 " d0 in [0, 127], d1 in [0, 599]," " s0 in [0, 7], s1 in [0, 0]," " (d1 mod 6) * 32 + d0 mod 32 in [0, 169]"> -#indexing_map1 = #xla_gpu.indexing_map<"(d0, d1)[s0, s1] ->" +#indexing_map1 = #xla.indexing_map<"(d0, d1)[s0, s1] ->" " (0," " d0 mod 32," " d0 floordiv 32 + s0 * 4)," @@ -126,18 +126,18 @@ func.func @do_not_fuse_multiple_uses(%arg0: tensor<20x160x170xf32>) -> tensor<1x %tid = gpu.thread_id x {xla.range = [0 : index, 127 : index]} %bid = gpu.block_id x {xla.range = [0 : index, 599 : index]} %shmem = xla_gpu.allocate_shared : tensor<1x32x33xf32> - %xla_loop = xla_gpu.loop (%tid, %bid)[%i, %j] + %xla_loop = xla.loop (%tid, %bid)[%i, %j] -> (%ra, %rb, %rc) in #indexing_map iter_args(%iter = %cst) -> (vector<8x1xf32>) { %extracted = tensor.extract %arg0[%ra, %rb, %rc] : tensor<20x160x170xf32> %0 = math.exp %extracted : f32 %1 = vector.insert %0, %iter [%i, %j] : f32 into vector<8x1xf32> - xla_gpu.yield %1 : vector<8x1xf32> + xla.yield %1 : vector<8x1xf32> } - %xla_loop_0 = xla_gpu.loop (%tid, %bid)[%i, 
%j] + %xla_loop_0 = xla.loop (%tid, %bid)[%i, %j] -> (%ra, %rb, %rc) in #indexing_map1 iter_args(%iter = %shmem) -> (tensor<1x32x33xf32>) { %0 = vector.extract %xla_loop[%i, %j] : f32 from vector<8x1xf32> %inserted = tensor.insert %0 into %iter[%ra, %rb, %rc] : tensor<1x32x33xf32> - xla_gpu.yield %inserted : tensor<1x32x33xf32> + xla.yield %inserted : tensor<1x32x33xf32> } %synced_tensor = xla_gpu.sync_threads %xla_loop_0 : tensor<1x32x33xf32> %0 = vector.extract %xla_loop [2, 0] : f32 from vector<8x1xf32> @@ -145,14 +145,14 @@ func.func @do_not_fuse_multiple_uses(%arg0: tensor<20x160x170xf32>) -> tensor<1x } // CHECK-LABEL: @do_not_fuse_multiple_uses -// CHECK: xla_gpu.loop +// CHECK: xla.loop // CHECK: vector.insert -// CHECK: xla_gpu.loop +// CHECK: xla.loop // CHECK: vector.extract // ----- -#indexing_map = #xla_gpu.indexing_map<"(d0, d1)[s0, s1] ->" +#indexing_map = #xla.indexing_map<"(d0, d1)[s0, s1] ->" " (d1 floordiv 30," " ((d1 floordiv 6) mod 5) * 32 + s0 * 4 + d0 floordiv 32," " (d1 mod 6) * 32 + d0 mod 32)," @@ -160,7 +160,7 @@ func.func @do_not_fuse_multiple_uses(%arg0: tensor<20x160x170xf32>) -> tensor<1x " d0 in [0, 127], d1 in [0, 599]," " s0 in [0, 7], s1 in [0, 0]," " (d1 mod 6) * 32 + d0 mod 32 in [0, 169]"> -#indexing_map1 = #xla_gpu.indexing_map<"(d0, d1)[s0, s1] ->" +#indexing_map1 = #xla.indexing_map<"(d0, d1)[s0, s1] ->" " (0," " d0 mod 32," " d0 floordiv 32 + s0 * 4)," @@ -174,32 +174,32 @@ func.func @do_not_fuse_map_domain_mismatch(%arg0: tensor<20x160x170xf32>) -> ten %tid = gpu.thread_id x {xla.range = [0 : index, 127 : index]} %bid = gpu.block_id x {xla.range = [0 : index, 599 : index]} %shmem = xla_gpu.allocate_shared : tensor<1x32x33xf32> - %xla_loop = xla_gpu.loop (%tid, %bid)[%i, %j] + %xla_loop = xla.loop (%tid, %bid)[%i, %j] -> (%ra, %rb, %rc) in #indexing_map iter_args(%iter = %cst) -> (vector<8x1xf32>) { %extracted = tensor.extract %arg0[%ra, %rb, %rc] : tensor<20x160x170xf32> %0 = math.exp %extracted : f32 %1 = vector.insert %0, %iter [%i, %j] : f32 into vector<8x1xf32> - xla_gpu.yield %1 : vector<8x1xf32> + xla.yield %1 : vector<8x1xf32> } - %xla_loop_0 = xla_gpu.loop (%tid, %bid)[%i, %j] + %xla_loop_0 = xla.loop (%tid, %bid)[%i, %j] -> (%ra, %rb, %rc) in #indexing_map1 iter_args(%iter = %shmem) -> (tensor<1x32x33xf32>) { %0 = vector.extract %xla_loop[%i, %j] : f32 from vector<8x1xf32> %inserted = tensor.insert %0 into %iter[%ra, %rb, %rc] : tensor<1x32x33xf32> - xla_gpu.yield %inserted : tensor<1x32x33xf32> + xla.yield %inserted : tensor<1x32x33xf32> } %synced_tensor = xla_gpu.sync_threads %xla_loop_0 : tensor<1x32x33xf32> return %synced_tensor : tensor<1x32x33xf32> } // CHECK-LABEL: @do_not_fuse_map_domain_mismatch -// CHECK: xla_gpu.loop +// CHECK: xla.loop // CHECK: vector.insert -// CHECK: xla_gpu.loop +// CHECK: xla.loop // CHECK: vector.extract // ----- -#indexing_map = #xla_gpu.indexing_map<"(d0, d1)[s0, s1] ->" +#indexing_map = #xla.indexing_map<"(d0, d1)[s0, s1] ->" " (d1 floordiv 30," " ((d1 floordiv 6) mod 5) * 32 + s0 * 4 + d0 floordiv 32," " (d1 mod 6) * 32 + d0 mod 32)," @@ -207,7 +207,7 @@ func.func @do_not_fuse_map_domain_mismatch(%arg0: tensor<20x160x170xf32>) -> ten " d0 in [0, 127], d1 in [0, 599]," " s0 in [0, 7], s1 in [0, 0]," " (d1 mod 6) * 32 + d0 mod 32 in [0, 169]"> -#indexing_map1 = #xla_gpu.indexing_map<"(d0, d1)[s0, s1] ->" +#indexing_map1 = #xla.indexing_map<"(d0, d1)[s0, s1] ->" " (0," " d0 mod 32," " d0 floordiv 32 + s0 * 4)," @@ -221,32 +221,32 @@ func.func @do_not_fuse_map_constraint_mismatch(%arg0: 
tensor<20x160x170xf32>) -> %tid = gpu.thread_id x {xla.range = [0 : index, 127 : index]} %bid = gpu.block_id x {xla.range = [0 : index, 599 : index]} %shmem = xla_gpu.allocate_shared : tensor<1x32x33xf32> - %xla_loop = xla_gpu.loop (%tid, %bid)[%i, %j] + %xla_loop = xla.loop (%tid, %bid)[%i, %j] -> (%ra, %rb, %rc) in #indexing_map iter_args(%iter = %cst) -> (vector<8x1xf32>) { %extracted = tensor.extract %arg0[%ra, %rb, %rc] : tensor<20x160x170xf32> %0 = math.exp %extracted : f32 %1 = vector.insert %0, %iter [%i, %j] : f32 into vector<8x1xf32> - xla_gpu.yield %1 : vector<8x1xf32> + xla.yield %1 : vector<8x1xf32> } - %xla_loop_0 = xla_gpu.loop (%tid, %bid)[%i, %j] + %xla_loop_0 = xla.loop (%tid, %bid)[%i, %j] -> (%ra, %rb, %rc) in #indexing_map1 iter_args(%iter = %shmem) -> (tensor<1x32x33xf32>) { %0 = vector.extract %xla_loop[%i, %j] : f32 from vector<8x1xf32> %inserted = tensor.insert %0 into %iter[%ra, %rb, %rc] : tensor<1x32x33xf32> - xla_gpu.yield %inserted : tensor<1x32x33xf32> + xla.yield %inserted : tensor<1x32x33xf32> } %synced_tensor = xla_gpu.sync_threads %xla_loop_0 : tensor<1x32x33xf32> return %synced_tensor : tensor<1x32x33xf32> } // CHECK-LABEL: @do_not_fuse_map_constraint_mismatch -// CHECK: xla_gpu.loop +// CHECK: xla.loop // CHECK: vector.insert -// CHECK: xla_gpu.loop +// CHECK: xla.loop // CHECK: vector.extract // ----- -#indexing_map = #xla_gpu.indexing_map<"(d0, d1)[s0, s1, s2] ->" +#indexing_map = #xla.indexing_map<"(d0, d1)[s0, s1, s2] ->" " (d1 floordiv 30," " ((d1 floordiv 6) mod 5) * 32 + s0 * 4 + d0 floordiv 32," " (d1 mod 6) * 32 + d0 mod 32)," @@ -254,7 +254,7 @@ func.func @do_not_fuse_map_constraint_mismatch(%arg0: tensor<20x160x170xf32>) -> " d0 in [0, 127], d1 in [0, 599]," " s0 in [0, 7], s1 in [0, 0], s2 in [0, 1]," " (d1 mod 6) * 32 + d0 mod 32 in [0, 169]"> -#indexing_map1 = #xla_gpu.indexing_map<"(d0, d1)[s0, s1, s2] ->" +#indexing_map1 = #xla.indexing_map<"(d0, d1)[s0, s1, s2] ->" " (0," " d0 mod 32," " d0 floordiv 32 + s0 * 4)," @@ -268,38 +268,38 @@ func.func @do_not_fuse_unused_loop_iv(%arg0: tensor<20x160x170xf32>) -> tensor<1 %tid = gpu.thread_id x {xla.range = [0 : index, 127 : index]} %bid = gpu.block_id x {xla.range = [0 : index, 599 : index]} %shmem = xla_gpu.allocate_shared : tensor<1x32x33xf32> - %xla_loop = xla_gpu.loop (%tid, %bid)[%i, %j, %k] + %xla_loop = xla.loop (%tid, %bid)[%i, %j, %k] -> (%ra, %rb, %rc) in #indexing_map iter_args(%iter = %cst) -> (vector<8x1xf32>) { %extracted = tensor.extract %arg0[%ra, %rb, %rc] : tensor<20x160x170xf32> %0 = math.exp %extracted : f32 %1 = vector.insert %0, %iter [%i, %j] : f32 into vector<8x1xf32> - xla_gpu.yield %1 : vector<8x1xf32> + xla.yield %1 : vector<8x1xf32> } - %xla_loop_0 = xla_gpu.loop (%tid, %bid)[%i, %j, %k] + %xla_loop_0 = xla.loop (%tid, %bid)[%i, %j, %k] -> (%ra, %rb, %rc) in #indexing_map1 iter_args(%iter = %shmem) -> (tensor<1x32x33xf32>) { %0 = vector.extract %xla_loop[%i, %j] : f32 from vector<8x1xf32> %inserted = tensor.insert %0 into %iter[%ra, %rb, %rc] : tensor<1x32x33xf32> - xla_gpu.yield %inserted : tensor<1x32x33xf32> + xla.yield %inserted : tensor<1x32x33xf32> } %synced_tensor = xla_gpu.sync_threads %xla_loop_0 : tensor<1x32x33xf32> return %synced_tensor : tensor<1x32x33xf32> } // CHECK-LABEL: @do_not_fuse_unused_loop_iv -// CHECK: xla_gpu.loop +// CHECK: xla.loop // CHECK: vector.insert -// CHECK: xla_gpu.loop +// CHECK: xla.loop // CHECK: vector.extract // ----- -#indexing_map = #xla_gpu.indexing_map<"(d0, d1)[s0, s1] ->" +#indexing_map = #xla.indexing_map<"(d0, 
d1)[s0, s1] ->" " ((d0 floordiv 32) * 8192 + d1 * 8 + s0 * 32768 + (d0 floordiv 4) mod 8," " d0 mod 4)," " domain:" " d0 in [0, 127], d1 in [0, 1023]," " s0 in [0, 2], s1 in [0, 0]"> -#indexing_map1 = #xla_gpu.indexing_map<"(d0) -> " +#indexing_map1 = #xla.indexing_map<"(d0) -> " " ((d0 floordiv 4) mod 8192," " d0 mod 4)," " domain:" @@ -311,21 +311,21 @@ func.func @fuse_identical_independent_loops(%arg0: tensor<8192x4xf64>, %bid = gpu.block_id x {xla.range = [0 : index, 1023 : index]} %cst_2 = arith.constant 0.50000000000000089 : f64 %cst = arith.constant 0 : index - %xla_loop = xla_gpu.loop (%tid, %bid)[%i, %j] -> (%ra, %rb) in #indexing_map + %xla_loop = xla.loop (%tid, %bid)[%i, %j] -> (%ra, %rb) in #indexing_map iter_args(%iter = %arg1) -> (tensor<98304x4xf64>) { - %0:2 = xla_gpu.apply_indexing #indexing_map1(%ra) + %0:2 = xla.apply_indexing #indexing_map1(%ra) %extracted = tensor.extract %arg0[%0#0, %0#1] : tensor<8192x4xf64> %3 = arith.mulf %extracted, %cst_2 : f64 %inserted = tensor.insert %3 into %iter[%ra, %rb] : tensor<98304x4xf64> - xla_gpu.yield %inserted : tensor<98304x4xf64> + xla.yield %inserted : tensor<98304x4xf64> } - %xla_loop_1 = xla_gpu.loop (%tid, %bid)[%i, %j] -> (%ra, %rb) in #indexing_map + %xla_loop_1 = xla.loop (%tid, %bid)[%i, %j] -> (%ra, %rb) in #indexing_map iter_args(%iter = %arg2) -> (tensor<98304x4xf64>) { - %0:2 = xla_gpu.apply_indexing #indexing_map1(%ra) + %0:2 = xla.apply_indexing #indexing_map1(%ra) %extracted = tensor.extract %arg0[%0#0, %0#1] : tensor<8192x4xf64> %inserted = tensor.insert %extracted into %iter[%ra, %rb] : tensor<98304x4xf64> - xla_gpu.yield %inserted : tensor<98304x4xf64> + xla.yield %inserted : tensor<98304x4xf64> } return %xla_loop_1 : tensor<98304x4xf64> } @@ -334,22 +334,22 @@ func.func @fuse_identical_independent_loops(%arg0: tensor<8192x4xf64>, // CHECK-SAME: (%[[ARG0:.*]]: tensor<8192x4xf64>, // CHECK-SAME: %[[ARG1:.*]]: tensor<98304x4xf64>, // CHECK-SAME: %[[ARG2:.*]]: tensor<98304x4xf64>) -// CHECK: %[[LOOP0:.*]], %[[LOOP1:.*]] = xla_gpu.loop +// CHECK: %[[LOOP0:.*]], %[[LOOP1:.*]] = xla.loop // CHECK-SAME: -> (%[[RA:.*]], %[[RB:.*]]) in // CHECK-SAME: iter_args(%[[ITER0:.*]] = %[[ARG1]], %[[ITER1:.*]] = %[[ARG2]]) // CHECK: tensor.insert {{.*}} into %[[ITER0]][%[[RA]], %[[RB]]] // CHECK: tensor.insert {{.*}} into %[[ITER1]][%[[RA]], %[[RB]]] -// CHECK: xla_gpu.yield {{.*}} : tensor<98304x4xf64>, tensor<98304x4xf64> +// CHECK: xla.yield {{.*}} : tensor<98304x4xf64>, tensor<98304x4xf64> // ----- -#indexing_map = #xla_gpu.indexing_map<"(d0, d1)[s0, s1] ->" +#indexing_map = #xla.indexing_map<"(d0, d1)[s0, s1] ->" " ((d0 floordiv 32) * 8192 + d1 * 8 + s0 * 32768 + (d0 floordiv 4) mod 8," " d0 mod 4)," " domain:" " d0 in [0, 127], d1 in [0, 1023]," " s0 in [0, 2], s1 in [0, 0]"> -#indexing_map1 = #xla_gpu.indexing_map<"(d0) -> " +#indexing_map1 = #xla.indexing_map<"(d0) -> " " ((d0 floordiv 4) mod 8192," " d0 mod 4)," " domain:" @@ -360,26 +360,26 @@ func.func @do_not_fuse_dependent_loops(%arg0: tensor<8192x4xf64>, %bid = gpu.block_id x {xla.range = [0 : index, 1023 : index]} %cst_2 = arith.constant 0.50000000000000089 : f64 %cst = arith.constant 0 : index - %xla_loop = xla_gpu.loop (%tid, %bid)[%i, %j] -> (%ra, %rb) in #indexing_map + %xla_loop = xla.loop (%tid, %bid)[%i, %j] -> (%ra, %rb) in #indexing_map iter_args(%iter = %arg1) -> (tensor<98304x4xf64>) { - %0:2 = xla_gpu.apply_indexing #indexing_map1(%ra) + %0:2 = xla.apply_indexing #indexing_map1(%ra) %extracted = tensor.extract %arg0[%0#0, %0#1] : tensor<8192x4xf64> %3 = 
arith.mulf %extracted, %cst_2 : f64 %inserted = tensor.insert %3 into %iter[%ra, %rb] : tensor<98304x4xf64> - xla_gpu.yield %inserted : tensor<98304x4xf64> + xla.yield %inserted : tensor<98304x4xf64> } %dependency = tensor.insert %cst_2 into %xla_loop[%cst, %cst] : tensor<98304x4xf64> - %xla_loop_1 = xla_gpu.loop (%tid, %bid)[%i, %j] -> (%ra, %rb) in #indexing_map + %xla_loop_1 = xla.loop (%tid, %bid)[%i, %j] -> (%ra, %rb) in #indexing_map iter_args(%iter = %dependency) -> (tensor<98304x4xf64>) { - %0:2 = xla_gpu.apply_indexing #indexing_map1(%ra) + %0:2 = xla.apply_indexing #indexing_map1(%ra) %extracted = tensor.extract %arg0[%0#0, %0#1] : tensor<8192x4xf64> %inserted = tensor.insert %extracted into %iter[%ra, %rb] : tensor<98304x4xf64> - xla_gpu.yield %inserted : tensor<98304x4xf64> + xla.yield %inserted : tensor<98304x4xf64> } return %xla_loop_1 : tensor<98304x4xf64> } // CHECK-LABEL: @do_not_fuse_dependent_loops -// CHECK-COUNT-2: xla_gpu.loop \ No newline at end of file +// CHECK-COUNT-2: xla.loop \ No newline at end of file diff --git a/third_party/xla/xla/service/gpu/fusions/transforms/tests/inlining.mlir b/third_party/xla/xla/service/gpu/fusions/transforms/tests/inlining.mlir index f15b37b040b848..fac0ff321778f4 100644 --- a/third_party/xla/xla/service/gpu/fusions/transforms/tests/inlining.mlir +++ b/third_party/xla/xla/service/gpu/fusions/transforms/tests/inlining.mlir @@ -1,4 +1,4 @@ -// RUN: mlir_fusions_opt %s -split-input-file -xla-erase-dead-functions -inline | FileCheck %s +// RUN: emitters_opt %s -split-input-file -xla-erase-dead-functions -inline | FileCheck %s module { func.func private @mul(%a: f32, %b: f32) -> f32 { @@ -8,21 +8,21 @@ module { func.func private @add(%a: f32, %b: f32) -> f32 { %add = arith.addf %a, %b : f32 - %ret = xla_gpu.pure_call @mul(%add, %add) : (f32, f32) -> (f32) + %ret = xla.pure_call @mul(%add, %add) : (f32, f32) -> (f32) return %ret : f32 } func.func @caller(%a: f32, %b: f32) -> f32 { - %ret = xla_gpu.pure_call @add(%a, %b) : (f32, f32) -> (f32) + %ret = xla.pure_call @add(%a, %b) : (f32, f32) -> (f32) return %ret : f32 } } // CHECK-LABEL: module { // CHECK: @caller -// CHECK-NOT: xla_gpu.pure_call @add +// CHECK-NOT: xla.pure_call @add // CHECK: arith.addf -// CHECK-NOT: xla_gpu.pure_call @mul +// CHECK-NOT: xla.pure_call @mul // CHECK: arith.mulf // ----- @@ -30,7 +30,7 @@ module { module { func.func @fused_computation(%arg0: tensor<2xf32> {xla.slice_index = 0 : index}, %arg1: tensor<2xf32> {xla.slice_index = 1 : index}, %arg2: tensor<2xf32> {xla.slice_index = 2 : index}) -> tensor<2xf32> attributes {xla.entry} { %0 = gpu.thread_id x {xla.range = [0 : index, 1 : index]} - %1 = xla_gpu.pure_call @fused_computation_atan2(%arg0, %arg1, %0) : (tensor<2xf32>, tensor<2xf32>, index) -> f32 + %1 = xla.pure_call @fused_computation_atan2(%arg0, %arg1, %0) : (tensor<2xf32>, tensor<2xf32>, index) -> f32 %inserted = tensor.insert %1 into %arg2[%0] : tensor<2xf32> return %inserted : tensor<2xf32> } @@ -48,7 +48,7 @@ module { // CHECK-LABEL: module { // CHECK: @fused_computation -// CHECK-NOT: xla_gpu.pure_call @add +// CHECK-NOT: xla.pure_call @add // CHECK: gpu.thread_id // CHECK-NEXT: tensor.extract // CHECK-NEXT: tensor.extract @@ -80,13 +80,13 @@ module { func.func private @add(%a: f32, %b: f32) -> f32 { %add = arith.addf %a, %b : f32 - %ret = xla_gpu.pure_call @large(%add, %add) : (f32, f32) -> (f32) + %ret = xla.pure_call @large(%add, %add) : (f32, f32) -> (f32) return %ret : f32 } func.func @caller(%a: f32, %b: f32) -> f32 { - %add = 
xla_gpu.pure_call @add(%a, %b) : (f32, f32) -> (f32) - %ret = xla_gpu.pure_call @large(%add, %add) : (f32, f32) -> (f32) + %add = xla.pure_call @add(%a, %b) : (f32, f32) -> (f32) + %ret = xla.pure_call @large(%add, %add) : (f32, f32) -> (f32) return %ret : f32 } } @@ -94,8 +94,8 @@ module { // CHECK-LABEL: module { // CHECK: @caller // CHECK: arith.addf -// CHECK: xla_gpu.pure_call @large -// CHECK: xla_gpu.pure_call @large +// CHECK: xla.pure_call @large +// CHECK: xla.pure_call @large // ----- @@ -106,15 +106,15 @@ module { } func.func @caller(%a: f32, %b: f32) -> f32 { - %add = xla_gpu.pure_call @add(%a, %b) : (f32, f32) -> (f32) - %ret = xla_gpu.pure_call @add(%add, %add) : (f32, f32) -> (f32) + %add = xla.pure_call @add(%a, %b) : (f32, f32) -> (f32) + %ret = xla.pure_call @add(%add, %add) : (f32, f32) -> (f32) return %ret : f32 } } // CHECK-LABEL: module { // CHECK: @caller -// CHECK-NOT: xla_gpu.pure_call +// CHECK-NOT: xla.pure_call // CHECK: arith.addf // CHECK: arith.addf @@ -129,48 +129,48 @@ module { return %start : f32 } func.func private @fib2(%start : f32) -> f32 { - %a = xla_gpu.pure_call @fib0(%start) : (f32) -> (f32) - %b = xla_gpu.pure_call @fib1(%start) : (f32) -> (f32) + %a = xla.pure_call @fib0(%start) : (f32) -> (f32) + %b = xla.pure_call @fib1(%start) : (f32) -> (f32) %ret = arith.addf %a, %b : f32 return %ret : f32 } func.func private @fib3(%start : f32) -> f32 { - %a = xla_gpu.pure_call @fib1(%start) : (f32) -> (f32) - %b = xla_gpu.pure_call @fib2(%start) : (f32) -> (f32) + %a = xla.pure_call @fib1(%start) : (f32) -> (f32) + %b = xla.pure_call @fib2(%start) : (f32) -> (f32) %ret = arith.addf %a, %b : f32 return %ret : f32 } func.func private @fib4(%start : f32) -> f32 { - %a = xla_gpu.pure_call @fib2(%start) : (f32) -> (f32) - %b = xla_gpu.pure_call @fib3(%start) : (f32) -> (f32) + %a = xla.pure_call @fib2(%start) : (f32) -> (f32) + %b = xla.pure_call @fib3(%start) : (f32) -> (f32) %ret = arith.addf %a, %b : f32 return %ret : f32 } // When inlining the other functions into @fib5, this function exceeds the // threshold for inlining. func.func private @fib5(%start : f32) -> f32 { - %a = xla_gpu.pure_call @fib3(%start) : (f32) -> (f32) - %b = xla_gpu.pure_call @fib4(%start) : (f32) -> (f32) + %a = xla.pure_call @fib3(%start) : (f32) -> (f32) + %b = xla.pure_call @fib4(%start) : (f32) -> (f32) %ret = arith.addf %a, %b : f32 return %ret : f32 } // As we do not inline @fib5 into @fib6, this function stays below the // threshold for inlining. 
func.func private @fib6(%start : f32) -> f32 { - %a = xla_gpu.pure_call @fib4(%start) : (f32) -> (f32) - %b = xla_gpu.pure_call @fib5(%start) : (f32) -> (f32) + %a = xla.pure_call @fib4(%start) : (f32) -> (f32) + %b = xla.pure_call @fib5(%start) : (f32) -> (f32) %ret = arith.addf %a, %b : f32 return %ret : f32 } func.func private @fib7(%start : f32) -> f32 { - %a = xla_gpu.pure_call @fib5(%start) : (f32) -> (f32) - %b = xla_gpu.pure_call @fib6(%start) : (f32) -> (f32) + %a = xla.pure_call @fib5(%start) : (f32) -> (f32) + %b = xla.pure_call @fib6(%start) : (f32) -> (f32) %ret = arith.addf %a, %b : f32 return %ret : f32 } func.func @caller(%a: f32) -> f32 { - %ret = xla_gpu.pure_call @fib7(%a) : (f32) -> (f32) + %ret = xla.pure_call @fib7(%a) : (f32) -> (f32) return %ret : f32 } } @@ -178,12 +178,12 @@ module { // CHECK-LABEL: module { // CHECK: @caller // CHECK: arith.constant 0.000000e+00 -// CHECK: xla_gpu.pure_call @fib5 +// CHECK: xla.pure_call @fib5 // CHECK: arith.addf // CHECK: arith.addf // CHECK: arith.addf // CHECK: arith.addf -// CHECK: xla_gpu.pure_call @fib5 +// CHECK: xla.pure_call @fib5 // CHECK: arith.addf // CHECK: arith.addf @@ -196,7 +196,7 @@ module { } func.func @caller(%a: f32, %b: f32) -> complex<f32> { - %ret = xla_gpu.pure_call @complex(%a, %b) : (f32, f32) -> (complex<f32>) + %ret = xla.pure_call @complex(%a, %b) : (f32, f32) -> (complex<f32>) return %ret : complex<f32> } } @@ -214,7 +214,7 @@ module { } func.func private @callee1(%a: f32) -> f32 { - %c1 = xla_gpu.pure_call @callee2(%a) : (f32) -> (f32) + %c1 = xla.pure_call @callee2(%a) : (f32) -> (f32) %b0 = arith.addf %a, %a : f32 %b1 = arith.addf %b0, %a : f32 %b2 = arith.addf %b1, %a : f32 @@ -223,18 +223,18 @@ module { %b5 = arith.addf %b4, %a : f32 %b6 = arith.addf %b5, %a : f32 %b7 = arith.addf %b6, %a : f32 - %c2 = xla_gpu.pure_call @callee2(%b7) : (f32) -> (f32) + %c2 = xla.pure_call @callee2(%b7) : (f32) -> (f32) %ret = arith.addf %c1, %c2 : f32 return %ret : f32 } func.func private @dead(%a: f32) -> f32 { - %ret = xla_gpu.pure_call @callee1(%a) : (f32) -> (f32) + %ret = xla.pure_call @callee1(%a) : (f32) -> (f32) return %ret : f32 } func.func @caller(%a: f32, %b: f32) -> f32 { - %ret = xla_gpu.pure_call @callee1(%a) : (f32) -> (f32) + %ret = xla.pure_call @callee1(%a) : (f32) -> (f32) return %ret : f32 } } @@ -242,7 +242,7 @@ module { // CHECK-LABEL: module { // CHECK-NOT: func.func // CHECK: func.func @caller -// CHECK-NOT: xla_gpu.pure_call +// CHECK-NOT: xla.pure_call // CHECK-NOT: func.func // ----- @@ -265,7 +265,7 @@ module { } func.func private @callee2(%a: f32) -> f32 { - %call = xla_gpu.pure_call @callee1(%a) : (f32) -> (f32) + %call = xla.pure_call @callee1(%a) : (f32) -> (f32) %b0 = arith.addf %a, %a : f32 %b1 = arith.addf %b0, %a : f32 %b2 = arith.addf %b1, %a : f32 @@ -281,8 +281,8 @@ module { } func.func @caller(%a: f32, %b: f32) -> f32 { - %call1 = xla_gpu.pure_call @callee2(%a) : (f32) -> (f32) - %call2 = xla_gpu.pure_call @callee1(%a) : (f32) -> (f32) + %call1 = xla.pure_call @callee2(%a) : (f32) -> (f32) + %call2 = xla.pure_call @callee1(%a) : (f32) -> (f32) %ret = arith.addf %call1, %call2 : f32 return %ret : f32 } } diff --git a/third_party/xla/xla/service/gpu/fusions/transforms/tests/lower_tensors.mlir b/third_party/xla/xla/service/gpu/fusions/transforms/tests/lower_tensors.mlir index ec53d9f64a86ca..455837a698bee0 100644 --- a/third_party/xla/xla/service/gpu/fusions/transforms/tests/lower_tensors.mlir +++ b/third_party/xla/xla/service/gpu/fusions/transforms/tests/lower_tensors.mlir @@ -1,24 +1,24 @@ -//
RUN: mlir_fusions_opt %s --allow-unregistered-dialect -split-input-file \ +// RUN: emitters_opt %s --allow-unregistered-dialect -split-input-file \ // RUN: -xla-gpu-lower-tensors="gpu_device_info='cuda_compute_capability {major: 6}'" \ // RUN: | FileCheck %s -// RUN: mlir_fusions_opt %s --allow-unregistered-dialect -split-input-file \ +// RUN: emitters_opt %s --allow-unregistered-dialect -split-input-file \ // RUN: -xla-gpu-lower-tensors="gpu_device_info='cuda_compute_capability {major: 7}'" \ // RUN: | FileCheck %s --check-prefix=CHECK-VOLTA -// RUN: mlir_fusions_opt %s --allow-unregistered-dialect -split-input-file \ +// RUN: emitters_opt %s --allow-unregistered-dialect -split-input-file \ // RUN: -xla-gpu-lower-tensors="gpu_device_info='cuda_compute_capability {major: 8}'" \ // RUN: | FileCheck %s --check-prefix=CHECK-AMPERE -// RUN: mlir_fusions_opt %s --allow-unregistered-dialect -split-input-file \ +// RUN: emitters_opt %s --allow-unregistered-dialect -split-input-file \ // RUN: -xla-gpu-lower-tensors="gpu_device_info='cuda_compute_capability {major: 9}'" \ // RUN: | FileCheck %s --check-prefix=CHECK-HOPPER -// RUN: mlir_fusions_opt %s --allow-unregistered-dialect -split-input-file \ +// RUN: emitters_opt %s --allow-unregistered-dialect -split-input-file \ // RUN: -xla-gpu-lower-tensors="gpu_device_info='rocm_compute_capability {gcn_arch_name: \"gfx908:sramecc+:xnack\"}'" \ // RUN: | FileCheck %s --check-prefix=CHECK-GFX908-MI100 -// RUN: mlir_fusions_opt %s --allow-unregistered-dialect -split-input-file \ +// RUN: emitters_opt %s --allow-unregistered-dialect -split-input-file \ // RUN: -xla-gpu-lower-tensors="gpu_device_info='rocm_compute_capability {gcn_arch_name: \"gfx90a:sramecc+:xnack\"}'" \ // RUN: | FileCheck %s --check-prefix=CHECK-GFX90A-MI200 @@ -232,11 +232,11 @@ func.func @transpose_shared(%in: tensor<1024xf32>, // ----- func.func @atomic_rmw_f32(%in: tensor<8xf32>, %i: index) -> (tensor<8xf32>) { - %ret = xla_gpu.atomic_rmw %in[%i] : tensor<8xf32> { + %ret = xla.atomic_rmw %in[%i] : tensor<8xf32> { ^bb0(%current : f32): %c42 = arith.constant 1.0 : f32 %add = arith.minimumf %current, %c42 : f32 - xla_gpu.yield %add : f32 + xla.yield %add : f32 } return %ret : tensor<8xf32> } @@ -251,11 +251,11 @@ func.func @atomic_rmw_f32(%in: tensor<8xf32>, %i: index) -> (tensor<8xf32>) { func.func @atomic_rmw_f16(%in: tensor<8xf16>, %i: index) -> (tensor<8xf16>) { - %ret = xla_gpu.atomic_rmw %in[%i] : tensor<8xf16> { + %ret = xla.atomic_rmw %in[%i] : tensor<8xf16> { ^bb0(%current : f16): %c1 = arith.constant 1.0 : f16 %add = arith.addf %current, %c1 : f16 - xla_gpu.yield %add : f16 + xla.yield %add : f16 } return %ret : tensor<8xf16> } @@ -282,9 +282,9 @@ func.func @atomic_rmw_f16(%in: tensor<8xf16>, %i: index) func.func @atomic_rmw_overwrite(%in: tensor<8xf16>, %i: index) -> (tensor<8xf16>) { %c1 = arith.constant 1.0 : f16 - %ret = xla_gpu.atomic_rmw %in[%i] : tensor<8xf16> { + %ret = xla.atomic_rmw %in[%i] : tensor<8xf16> { ^bb0(%current : f16): - xla_gpu.yield %c1 : f16 + xla.yield %c1 : f16 } return %ret : tensor<8xf16> } @@ -353,9 +353,9 @@ func.func @i4_load_store(%arg: tensor<10xi4>, %i: index, %j: index) func.func @direct_atomic_rmw_overwrite(%in: tensor<8xi32>, %i: index) -> (tensor<8xi32>) { %c2 = arith.constant 2 : i32 - %ret = xla_gpu.atomic_rmw %in[%i] : tensor<8xi32> { + %ret = xla.atomic_rmw %in[%i] : tensor<8xi32> { ^bb0(%current : i32): - xla_gpu.yield %c2 : i32 + xla.yield %c2 : i32 } return %ret : tensor<8xi32> } @@ -369,10 +369,10 @@ func.func 
@direct_atomic_rmw_overwrite(%in: tensor<8xi32>, func.func @direct_atomic_rmw_addi(%in: tensor<8xi32>, %i: index) -> (tensor<8xi32>) { %c2 = arith.constant 2 : i32 - %ret = xla_gpu.atomic_rmw %in[%i] : tensor<8xi32> { + %ret = xla.atomic_rmw %in[%i] : tensor<8xi32> { ^bb0(%current : i32): %min = arith.addi %current, %c2 : i32 - xla_gpu.yield %c2 : i32 + xla.yield %c2 : i32 } return %ret : tensor<8xi32> } @@ -386,10 +386,10 @@ func.func @direct_atomic_rmw_addi(%in: tensor<8xi32>, func.func @direct_atomic_rmw_maxsi(%in: tensor<8xi32>, %i: index) -> (tensor<8xi32>) { %c2 = arith.constant 2 : i32 - %ret = xla_gpu.atomic_rmw %in[%i] : tensor<8xi32> { + %ret = xla.atomic_rmw %in[%i] : tensor<8xi32> { ^bb0(%current : i32): %min = arith.maxsi %current, %c2 : i32 - xla_gpu.yield %c2 : i32 + xla.yield %c2 : i32 } return %ret : tensor<8xi32> } @@ -403,10 +403,10 @@ func.func @direct_atomic_rmw_maxsi(%in: tensor<8xi32>, func.func @direct_atomic_rmw_maxui(%in: tensor<8xi32>, %i: index) -> (tensor<8xi32>) { %c2 = arith.constant 2 : i32 - %ret = xla_gpu.atomic_rmw %in[%i] : tensor<8xi32> { + %ret = xla.atomic_rmw %in[%i] : tensor<8xi32> { ^bb0(%current : i32): %min = arith.maxui %current, %c2 : i32 - xla_gpu.yield %c2 : i32 + xla.yield %c2 : i32 } return %ret : tensor<8xi32> } @@ -420,10 +420,10 @@ func.func @direct_atomic_rmw_maxui(%in: tensor<8xi32>, func.func @direct_atomic_rmw_minsi(%in: tensor<8xi32>, %i: index) -> (tensor<8xi32>) { %c2 = arith.constant 2 : i32 - %ret = xla_gpu.atomic_rmw %in[%i] : tensor<8xi32> { + %ret = xla.atomic_rmw %in[%i] : tensor<8xi32> { ^bb0(%current : i32): %min = arith.minsi %current, %c2 : i32 - xla_gpu.yield %c2 : i32 + xla.yield %c2 : i32 } return %ret : tensor<8xi32> } @@ -437,10 +437,10 @@ func.func @direct_atomic_rmw_minsi(%in: tensor<8xi32>, func.func @direct_atomic_rmw_minui(%in: tensor<8xi32>, %i: index) -> (tensor<8xi32>) { %c2 = arith.constant 2 : i32 - %ret = xla_gpu.atomic_rmw %in[%i] : tensor<8xi32> { + %ret = xla.atomic_rmw %in[%i] : tensor<8xi32> { ^bb0(%current : i32): %min = arith.minui %current, %c2 : i32 - xla_gpu.yield %c2 : i32 + xla.yield %c2 : i32 } return %ret : tensor<8xi32> } @@ -454,10 +454,10 @@ func.func @direct_atomic_rmw_minui(%in: tensor<8xi32>, func.func @direct_atomic_rmw_fadd_f32(%in: tensor<8xf32>, %i: index) -> (tensor<8xf32>) { %c2 = arith.constant 2.0 : f32 - %ret = xla_gpu.atomic_rmw %in[%i] : tensor<8xf32> { + %ret = xla.atomic_rmw %in[%i] : tensor<8xf32> { ^bb0(%current : f32): %min = arith.addf %current, %c2 : f32 - xla_gpu.yield %c2 : f32 + xla.yield %c2 : f32 } return %ret : tensor<8xf32> } @@ -493,10 +493,10 @@ func.func @direct_atomic_rmw_fadd_f32(%in: tensor<8xf32>, func.func @direct_atomic_rmw_fadd_f16(%in: tensor<8xf16>, %i: index) -> (tensor<8xf16>) { %c2 = arith.constant 2.0 : f16 - %ret = xla_gpu.atomic_rmw %in[%i] : tensor<8xf16> { + %ret = xla.atomic_rmw %in[%i] : tensor<8xf16> { ^bb0(%current : f16): %min = arith.addf %current, %c2 : f16 - xla_gpu.yield %c2 : f16 + xla.yield %c2 : f16 } return %ret : tensor<8xf16> } @@ -527,10 +527,10 @@ func.func @direct_atomic_rmw_fadd_f16(%in: tensor<8xf16>, func.func @direct_atomic_rmw_fadd_bf16(%in: tensor<8xbf16>, %i: index) -> (tensor<8xbf16>) { %c2 = arith.constant 2.0 : bf16 - %ret = xla_gpu.atomic_rmw %in[%i] : tensor<8xbf16> { + %ret = xla.atomic_rmw %in[%i] : tensor<8xbf16> { ^bb0(%current : bf16): %min = arith.addf %current, %c2 : bf16 - xla_gpu.yield %c2 : bf16 + xla.yield %c2 : bf16 } return %ret : tensor<8xbf16> } @@ -547,10 +547,10 @@ func.func 
@direct_atomic_rmw_fadd_bf16(%in: tensor<8xbf16>, func.func @direct_atomic_rmw_fadd_f64(%in: tensor<8xf64>, %i: index) -> (tensor<8xf64>) { %c2 = arith.constant 2.0 : f64 - %ret = xla_gpu.atomic_rmw %in[%i] : tensor<8xf64> { + %ret = xla.atomic_rmw %in[%i] : tensor<8xf64> { ^bb0(%current : f64): %min = arith.addf %current, %c2 : f64 - xla_gpu.yield %c2 : f64 + xla.yield %c2 : f64 } return %ret : tensor<8xf64> } @@ -580,10 +580,10 @@ func.func @direct_atomic_rmw_fadd_f64(%in: tensor<8xf64>, func.func @direct_atomic_rmw_maximumf(%in: tensor<8xf32>, %i: index) -> (tensor<8xf32>) { %c2 = arith.constant 2.0 : f32 - %ret = xla_gpu.atomic_rmw %in[%i] : tensor<8xf32> { + %ret = xla.atomic_rmw %in[%i] : tensor<8xf32> { ^bb0(%current : f32): %min = arith.maximumf %current, %c2 : f32 - xla_gpu.yield %c2 : f32 + xla.yield %c2 : f32 } return %ret : tensor<8xf32> } @@ -616,10 +616,10 @@ func.func @direct_atomic_rmw_maximumf(%in: tensor<8xf32>, func.func @atomic_rmw_c32(%in: tensor<8xcomplex<f32>>, %i: index) -> (tensor<8xcomplex<f32>>) { - %ret = xla_gpu.atomic_rmw %in[%i] : tensor<8xcomplex<f32>> { + %ret = xla.atomic_rmw %in[%i] : tensor<8xcomplex<f32>> { ^bb0(%current : complex<f32>): %a = complex.add %current, %current : complex<f32> - xla_gpu.yield %a : complex<f32> + xla.yield %a : complex<f32> } return %ret : tensor<8xcomplex<f32>> } diff --git a/third_party/xla/xla/service/gpu/fusions/transforms/tests/lower_xla_gpu_loops_to_scf.mlir b/third_party/xla/xla/service/gpu/fusions/transforms/tests/lower_xla_gpu_loops_to_scf.mlir index f981cef83029d8..8829e74a675f35 100644 --- a/third_party/xla/xla/service/gpu/fusions/transforms/tests/lower_xla_gpu_loops_to_scf.mlir +++ b/third_party/xla/xla/service/gpu/fusions/transforms/tests/lower_xla_gpu_loops_to_scf.mlir @@ -1,22 +1,22 @@ -// RUN: mlir_fusions_opt %s -xla-gpu-lower-xla-gpu-loops-to-scf \ +// RUN: emitters_opt %s -xla-gpu-lower-xla-gpu-loops-to-scf \ // RUN: --split-input-file | FileCheck %s -#map = #xla_gpu.indexing_map<"(d0)[s0, s1] -> (s0 + 1, s1 - 1)," +#map = #xla.indexing_map<"(d0)[s0, s1] -> (s0 + 1, s1 - 1)," "domain: d0 in [0, 3], s0 in [0, 1024], s1 in [0, 32], s0 + s1 in [0, 90]"> func.func @loop_op(%input: tensor<1024x32xf32>, %init: f32, %dim: index) -> (f32) { - %sum = xla_gpu.loop (%dim)[%i, %j] -> (%ra, %rb) + %sum = xla.loop (%dim)[%i, %j] -> (%ra, %rb) in #map iter_args(%sum_ = %init) -> (f32) { %t = tensor.extract %input[%ra, %rb] : tensor<1024x32xf32> %add = arith.addf %sum_, %t : f32 - xla_gpu.yield %add : f32 + xla.yield %add : f32 } {xla.range = [0 : index, 42 : index]} func.return %sum : f32 } -// CHECK-DAG: #[[$MAP:.*]] = #xla_gpu.indexing_map<"(d0)[s0, s1] -> (s0 + s1), -// CHECK-DAG: #[[$MAPA:.*]] = #xla_gpu.indexing_map<"(d0)[s0, s1] -> (s0 + 1), -// CHECK-DAG: #[[$MAPB:.*]] = #xla_gpu.indexing_map<"(d0)[s0, s1] -> (s1 - 1), +// CHECK-DAG: #[[$MAP:.*]] = #xla.indexing_map<"(d0)[s0, s1] -> (s0 + s1), +// CHECK-DAG: #[[$MAPA:.*]] = #xla.indexing_map<"(d0)[s0, s1] -> (s0 + 1), +// CHECK-DAG: #[[$MAPB:.*]] = #xla.indexing_map<"(d0)[s0, s1] -> (s1 - 1), // CHECK-LABEL: func.func @loop_op( // CHECK-SAME: %[[IN:.*]]: tensor<1024x32xf32>, @@ -35,7 +35,7 @@ func.func @loop_op(%input: tensor<1024x32xf32>, %init: f32, %dim: index) -> (f32 // CHECK: %[[INNER_FOR:.*]] = scf.for %[[J:.*]] = %[[C0]] to %[[C33]] // CHECK-SAME: step %[[C1]] iter_args(%[[INIT__:.*]] = %[[INIT_]]) -> (f32) { -// CHECK: %[[INDEX:.*]] = xla_gpu.apply_indexing +// CHECK: %[[INDEX:.*]] = xla.apply_indexing // CHECK-SAME: #[[$MAP]](%[[DIM]])[%[[I]], %[[J]]] // CHECK: %[[VAL1:.*]] = arith.cmpi sge,
%[[INDEX]], %[[C0]] : index // CHECK: %[[VAL2:.*]] = arith.cmpi sle, %[[INDEX]], %[[C90]] : index @@ -45,8 +45,8 @@ func.func @loop_op(%input: tensor<1024x32xf32>, %init: f32, %dim: index) -> (f32 // CHECK: %[[VAL6:.*]] = arith.andi %[[VAL4]], %[[VAL5]] : i1 // CHECK: %[[INBOUNDS:.*]] = arith.andi %[[VAL3]], %[[VAL6]] : i1 // CHECK: %[[IF_RESULT:.*]] = scf.if %[[INBOUNDS]] -> (f32) { -// CHECK: %[[RA:.*]] = xla_gpu.apply_indexing #[[$MAPA]](%[[DIM]])[%[[I]], %[[J]]] -// CHECK: %[[RB:.*]] = xla_gpu.apply_indexing #[[$MAPB]](%[[DIM]])[%[[I]], %[[J]]] +// CHECK: %[[RA:.*]] = xla.apply_indexing #[[$MAPA]](%[[DIM]])[%[[I]], %[[J]]] +// CHECK: %[[RB:.*]] = xla.apply_indexing #[[$MAPB]](%[[DIM]])[%[[I]], %[[J]]] // CHECK: %[[ELEM:.*]] = tensor.extract %[[IN]][%[[RA]], %[[RB]]] // CHECK: %[[SUM:.*]] = arith.addf %[[INIT__]], %[[ELEM]] : f32 // CHECK: scf.yield %[[SUM]] : f32 @@ -59,14 +59,14 @@ func.func @loop_op(%input: tensor<1024x32xf32>, %init: f32, %dim: index) -> (f32 // ----- -#map = #xla_gpu.indexing_map<"(d0)[s0, s1] -> (s0 + 1, s1 - 1)," +#map = #xla.indexing_map<"(d0)[s0, s1] -> (s0 + 1, s1 - 1)," "domain: d0 in [0, 3], s0 in [0, 1024], s1 in [0, 32], s0 + s1 in [0, 90]"> func.func @loop_yields_value_from_above(%input: tensor<1024x32xf32>, %init: f32, %dim: index) -> (f32) { - %sum = xla_gpu.loop (%dim)[%i, %j] -> (%ra, %rb) + %sum = xla.loop (%dim)[%i, %j] -> (%ra, %rb) in #map iter_args(%sum_ = %init) -> (f32) { - xla_gpu.yield %init : f32 + xla.yield %init : f32 } func.return %sum : f32 } diff --git a/third_party/xla/xla/service/gpu/fusions/transforms/tests/lower_xla_gpu_to_scf.mlir b/third_party/xla/xla/service/gpu/fusions/transforms/tests/lower_xla_gpu_to_scf.mlir index f53ccc1e8ae54f..12e395ff58ae40 100644 --- a/third_party/xla/xla/service/gpu/fusions/transforms/tests/lower_xla_gpu_to_scf.mlir +++ b/third_party/xla/xla/service/gpu/fusions/transforms/tests/lower_xla_gpu_to_scf.mlir @@ -1,4 +1,4 @@ -// RUN: mlir_fusions_opt %s -xla-gpu-lower-xla-gpu-to-scf --split-input-file \ +// RUN: emitters_opt %s -xla-gpu-lower-xla-gpu-to-scf --split-input-file \ // RUN: | FileCheck %s func.func @combiner(%a: f32, %b: i32, %c: f32, %d: i32) -> (f32, i32) { @@ -16,13 +16,13 @@ func.func @shuffler(%a: f32, %b: i32) -> (f32, i32) { // CHECK-DAG: %[[C32:.*]] = arith.constant 32 // CHECK: %[[A4H:.*]], {{.*}} = gpu.shuffle down %[[A]], %[[C4]], %[[C32]] // CHECK: %[[B4H:.*]], {{.*}} = gpu.shuffle down %[[B]], %[[C4]], %[[C32]] -// CHECK: %[[AB4_0:.*]], %[[AB4_1:.*]] = xla_gpu.pure_call @combiner(%[[A]], %[[B]], %[[A4H]], %[[B4H]]) +// CHECK: %[[AB4_0:.*]], %[[AB4_1:.*]] = xla.pure_call @combiner(%[[A]], %[[B]], %[[A4H]], %[[B4H]]) // CHECK: %[[A2H:.*]], {{.*}} = gpu.shuffle down %[[AB4_0]], %[[C2]], %[[C32]] // CHECK: %[[B2H:.*]], {{.*}} = gpu.shuffle down %[[AB4_1]], %[[C2]], %[[C32]] -// CHECK: %[[AB2_0:.*]], %[[AB2_1:.*]] = xla_gpu.pure_call @combiner(%[[AB4_0]], %[[AB4_1]], %[[A2H]], %[[B2H]]) +// CHECK: %[[AB2_0:.*]], %[[AB2_1:.*]] = xla.pure_call @combiner(%[[AB4_0]], %[[AB4_1]], %[[A2H]], %[[B2H]]) // CHECK: %[[A1H:.*]], {{.*}} = gpu.shuffle down %[[AB2_0]], %[[C1]], %[[C32]] // CHECK: %[[B1H:.*]], {{.*}} = gpu.shuffle down %[[AB2_1]], %[[C1]], %[[C32]] -// CHECK: %[[AB1_0:.*]], %[[AB1_1:.*]] = xla_gpu.pure_call @combiner(%[[AB2_0]], %[[AB2_1]], %[[A1H]], %[[B1H]]) +// CHECK: %[[AB1_0:.*]], %[[AB1_1:.*]] = xla.pure_call @combiner(%[[AB2_0]], %[[AB2_1]], %[[A1H]], %[[B1H]]) // CHECK: return %[[AB1_0]], %[[AB1_1]] // ----- @@ -85,7 +85,7 @@ func.func @shuffler_i8(%a: i8) -> i8 { func.func 
@predicated_insert( %v: i32, %tensor: tensor<2xi32>, %index: index, %cond: i1) -> tensor<2xi32> { - %ret = xla_gpu.predicated_insert %v into %tensor[%index] if %cond + %ret = xla.predicated_insert %v into %tensor[%index] if %cond : tensor<2xi32> return %ret : tensor<2xi32> } @@ -105,7 +105,7 @@ func.func @predicated_insert( func.func @predicated_extract( %v: i32, %tensor: tensor<2xi32>, %index: index, %cond: i1) -> i32 { - %ret = xla_gpu.predicated_extract %tensor[%index] if %cond else %v + %ret = xla.predicated_extract %tensor[%index] if %cond else %v : tensor<2xi32> return %ret : i32 } @@ -124,8 +124,8 @@ func.func @predicated_extract( func.func private @exp(%p0: tensor<32x64xf32>, %i: index, %j: index) -> f32 -#map = #xla_gpu.indexing_map<"(d0, d1)[s0, s1] -> (d1*32+d0*2+s0, s1), domain: d0 in [0, 32], d1 in [0, 8], s0 in [0, 1], s1 in [0, 1]"> -#map1 = #xla_gpu.indexing_map<"(d0, d1)[s0, s1] -> (d0*2+s0, s1), domain: d0 in [0, 32], d1 in [0, 2], s0 in [0, 1], s1 in [0, 1]"> +#map = #xla.indexing_map<"(d0, d1)[s0, s1] -> (d1*32+d0*2+s0, s1), domain: d0 in [0, 32], d1 in [0, 8], s0 in [0, 1], s1 in [0, 1]"> +#map1 = #xla.indexing_map<"(d0, d1)[s0, s1] -> (d0*2+s0, s1), domain: d0 in [0, 32], d1 in [0, 2], s0 in [0, 1], s1 in [0, 1]"> func.func @materialize(%input: tensor<32x64xf32>, %i: index, %j: index) -> !xla_gpu.indexed_vector<32x2x2xf32, #map1> { @@ -133,24 +133,24 @@ func.func @materialize(%input: tensor<32x64xf32>, %i: index, %j: index) : (tensor<32x64xf32>) -> !xla_gpu.indexed_vector<32x2x2xf32, #map1> func.return %0 : !xla_gpu.indexed_vector<32x2x2xf32, #map1> } -// CHECK-DAG: #[[$MAP:.*]] = #xla_gpu.indexing_map<"(d0, d1)[s0, s1] -> (d1 * 32 + d0 * 2 + s0, s1) -// CHECK-DAG: #[[$MAP1:.*]] = #xla_gpu.indexing_map<"(d0, d1)[s0, s1] -> (d0 * 2 + s0, s1) +// CHECK-DAG: #[[$MAP:.*]] = #xla.indexing_map<"(d0, d1)[s0, s1] -> (d1 * 32 + d0 * 2 + s0, s1) +// CHECK-DAG: #[[$MAP1:.*]] = #xla.indexing_map<"(d0, d1)[s0, s1] -> (d0 * 2 + s0, s1) // CHECK: @materialize(%[[INPUT:.*]]: tensor<32x64xf32>, %[[INDEX1:.*]]: index, %[[INDEX2:.*]]: index) // CHECK: %[[INIT_VEC:.*]] = arith.constant {{.*}} : vector<2x2xf32> -// CHECK: xla_gpu.loop (%[[INDEX1]], %[[INDEX2]])[%[[S0:.*]], %[[S1:.*]]] +// CHECK: xla.loop (%[[INDEX1]], %[[INDEX2]])[%[[S0:.*]], %[[S1:.*]]] // CHECK-SAME: -> (%[[MAP_RESULT1:.*]], %[[MAP_RESULT2:.*]]) in // CHECK-SAME: #[[$MAP]] iter_args(%[[ITER_ARG:.*]] = %[[INIT_VEC]]) -// CHECK: %[[PURE_CALL:.*]] = xla_gpu.pure_call @exp(%[[INPUT]], %[[MAP_RESULT1]], %[[MAP_RESULT2]]) +// CHECK: %[[PURE_CALL:.*]] = xla.pure_call @exp(%[[INPUT]], %[[MAP_RESULT1]], %[[MAP_RESULT2]]) // CHECK: vector.insert %[[PURE_CALL]], %[[ITER_ARG]] [%[[S0]], %[[S1]]] -// CHECK xla_gpu.yield %{{.*}} : vector<2x2xf32> +// CHECK xla.yield %{{.*}} : vector<2x2xf32> // ----- -#map = #xla_gpu.indexing_map<"(d0, d1)[s0, s1] -> (d1*32+d0*2+s0, s1), domain: d0 in [0, 32], d1 in [0, 8], s0 in [0, 1], s1 in [0, 1]"> -#map1 = #xla_gpu.indexing_map<"(d0, d1) -> (d0 mod 16, d1), domain: d0 in [0, 32], d1 in [0, 2]"> +#map = #xla.indexing_map<"(d0, d1)[s0, s1] -> (d1*32+d0*2+s0, s1), domain: d0 in [0, 32], d1 in [0, 8], s0 in [0, 1], s1 in [0, 1]"> +#map1 = #xla.indexing_map<"(d0, d1) -> (d0 mod 16, d1), domain: d0 in [0, 32], d1 in [0, 2]"> func.func @insert(%input: !xla_gpu.indexed_vector<32x64xf32, #map>, %i: index, %j: index, %output: tensor<32x64xf32>) -> tensor<32x64xf32> { @@ -158,32 +158,32 @@ func.func @insert(%input: !xla_gpu.indexed_vector<32x64xf32, #map>, : !xla_gpu.indexed_vector<32x64xf32, #map> -> 
tensor<32x64xf32> func.return %0 : tensor<32x64xf32> } -// CHECK-DAG: #[[$MAP:.*]] = #xla_gpu.indexing_map<"(d0, d1)[s0, s1] -> (d1 * 32 + d0 * 2 + s0, s1) -// CHECK-DAG: #[[$MAP1:.*]] = #xla_gpu.indexing_map<"(d0, d1) -> (d0 mod 16, d1) +// CHECK-DAG: #[[$MAP:.*]] = #xla.indexing_map<"(d0, d1)[s0, s1] -> (d1 * 32 + d0 * 2 + s0, s1) +// CHECK-DAG: #[[$MAP1:.*]] = #xla.indexing_map<"(d0, d1) -> (d0 mod 16, d1) // CHECK: @insert(%[[INPUT:.*]]: !xla_gpu.indexed_vector<32x64xf32, #[[$MAP]]>, // CHECK-SAME: %[[I:.*]]: index, %[[J:.*]]: index, // CHECK-SAME: %[[OUTPUT:.*]]: tensor<32x64xf32>) -// CHECK: xla_gpu.loop (%[[I]], %[[J]])[%[[S0:.*]], %[[S1:.*]]] -> +// CHECK: xla.loop (%[[I]], %[[J]])[%[[S0:.*]], %[[S1:.*]]] -> // CHECK-SAME: (%[[MAP_RESULT1:.*]], %[[MAP_RESULT2:.*]]) in #[[$MAP]] // CHECK-SAME: iter_args(%[[TENSOR:.*]] = %[[OUTPUT]]) // CHECK: %[[SCALAR:.*]] = vector.extract %{{.*}}[%[[S0]], %[[S1]]] // CHECK-SAME: : f32 from vector<2x2xf32> -// CHECK: %[[MAP1_RESULT:.*]]:2 = xla_gpu.apply_indexing +// CHECK: %[[MAP1_RESULT:.*]]:2 = xla.apply_indexing // CHECK-SAME: #[[$MAP1]](%[[MAP_RESULT1]], %[[MAP_RESULT2]]) // CHECK: %[[NEW_TENSOR:.*]] = tensor.insert %[[SCALAR]] // CHECK-SAME: into %[[TENSOR]][%[[MAP1_RESULT]]#0, %[[MAP1_RESULT]]#1] -// CHECK: xla_gpu.yield %[[NEW_TENSOR]] +// CHECK: xla.yield %[[NEW_TENSOR]] // ----- func.func private @exp(%p0: tensor<32x64xf32>, %i: index, %j: index) -> f32 -#map = #xla_gpu.indexing_map<"(d0, d1)[s0, s1] -> (d1*32+d0*2+s0, s1), domain: d0 in [0, 32], d1 in [0, 8], s0 in [0, 1], s1 in [0, 1]"> -#map1 = #xla_gpu.indexing_map<"(d0, d1)[s0, s1] -> (d0*2+s0, s1), domain: d0 in [0, 32], d1 in [0, 2], s0 in [0, 1], s1 in [0, 1]"> -#map2 = #xla_gpu.indexing_map<"(d0, d1) -> (d0, d1), domain: d0 in [0, 32], d1 in [0, 2]"> +#map = #xla.indexing_map<"(d0, d1)[s0, s1] -> (d1*32+d0*2+s0, s1), domain: d0 in [0, 32], d1 in [0, 8], s0 in [0, 1], s1 in [0, 1]"> +#map1 = #xla.indexing_map<"(d0, d1)[s0, s1] -> (d0*2+s0, s1), domain: d0 in [0, 32], d1 in [0, 2], s0 in [0, 1], s1 in [0, 1]"> +#map2 = #xla.indexing_map<"(d0, d1) -> (d0, d1), domain: d0 in [0, 32], d1 in [0, 2]"> func.func @materialize_and_insert(%input: tensor<32x64xf32>, %i: index, %j: index, %output: tensor<32x64xf32>) -> tensor<32x64xf32> { @@ -199,8 +199,8 @@ func.func @materialize_and_insert(%input: tensor<32x64xf32>, %i: index, func.func private @exp(%p0: tensor<32x64xcomplex<f32>>, %i: index, %j: index) -> complex<f32> -#map = #xla_gpu.indexing_map<"(d0, d1)[s0, s1] -> (d1*32+d0*2+s0, s1), domain: d0 in [0, 32], d1 in [0, 8], s0 in [0, 2], s1 in [0, 3]"> -#map1 = #xla_gpu.indexing_map<"(d0, d1)[s0, s1] -> (d0*2+s0, s1), domain: d0 in [0, 32], d1 in [0, 2], s0 in [0, 2], s1 in [0, 3]"> +#map = #xla.indexing_map<"(d0, d1)[s0, s1] -> (d1*32+d0*2+s0, s1), domain: d0 in [0, 32], d1 in [0, 8], s0 in [0, 2], s1 in [0, 3]"> +#map1 = #xla.indexing_map<"(d0, d1)[s0, s1] -> (d0*2+s0, s1), domain: d0 in [0, 32], d1 in [0, 2], s0 in [0, 2], s1 in [0, 3]"> func.func @materialize_complex( %input: tensor<32x64xcomplex<f32>>, %output: tensor<32x64xcomplex<f32>>, @@ -215,20 +215,20 @@ func.func @materialize_complex( // CHECK: %[[C0:.*]] = arith.constant 0 : index // CHECK: %[[C1:.*]] = arith.constant 1 : index -// CHECK: xla_gpu.loop ({{.*}})[%[[I:.*]], %[[J:.*]]] +// CHECK: xla.loop ({{.*}})[%[[I:.*]], %[[J:.*]]] // CHECK-SAME: iter_args(%[[ITER:.*]] = {{.*}}) -// CHECK: %[[PURE_CALL:.*]] = xla_gpu.pure_call +// CHECK: %[[PURE_CALL:.*]] = xla.pure_call // CHECK-SAME: complex // CHECK: %[[REAL:.*]] = complex.re %[[PURE_CALL]]
// CHECK: %[[IMAG:.*]] = complex.im %[[PURE_CALL]] // CHECK: %[[TEMP:.*]] = vector.insert %[[REAL]], %[[ITER]] [%[[C0]], %[[I]], %[[J]]] // CHECK: %[[FINAL:.*]] = vector.insert %[[IMAG]], %[[TEMP]] [%[[C1]], %[[I]], %[[J]]] -// CHECK: xla_gpu.yield %[[FINAL]] : vector<2x3x4xf32> +// CHECK: xla.yield %[[FINAL]] : vector<2x3x4xf32> // ----- -#map1 = #xla_gpu.indexing_map<"(d0, d1)[s0, s1] -> (d0*2+s0, s1), domain: d0 in [0, 32], d1 in [0, 2], s0 in [0, 2], s1 in [0, 3]"> -#map2 = #xla_gpu.indexing_map<"(d0, d1) -> (d0, d1), domain: d0 in [0, 32], d1 in [0, 2]"> +#map1 = #xla.indexing_map<"(d0, d1)[s0, s1] -> (d0*2+s0, s1), domain: d0 in [0, 32], d1 in [0, 2], s0 in [0, 2], s1 in [0, 3]"> +#map2 = #xla.indexing_map<"(d0, d1) -> (d0, d1), domain: d0 in [0, 32], d1 in [0, 2]"> func.func @insert_complex( %input: !xla_gpu.indexed_vector<32x3x4xcomplex<f32>, #map1>, %output: tensor<32x64xcomplex<f32>>, @@ -247,10 +247,10 @@ func.func @insert_complex( // CHECK: %[[C1:.*]] = arith.constant 1 : index // CHECK: %[[VECTOR:.*]] = builtin.unrealized_conversion_cast %[[INPUT]] // CHECK-SAME: to vector<2x3x4xf32> -// CHECK: xla_gpu.loop ({{.*}})[%[[I:.*]], %[[J:.*]]] +// CHECK: xla.loop ({{.*}})[%[[I:.*]], %[[J:.*]]] // CHECK-SAME: iter_args(%[[ITER:.*]] = {{.*}}) // CHECK: %[[REAL:.*]] = vector.extract %[[VECTOR]][%[[C0]], %[[I]], %[[J]]] // CHECK: %[[IMAG:.*]] = vector.extract %[[VECTOR]][%[[C1]], %[[I]], %[[J]]] // CHECK: %[[COMPLEX:.*]] = complex.create %[[REAL]], %[[IMAG]] // CHECK: %[[INSERTED:.*]] = tensor.insert %[[COMPLEX]] into %[[ITER]] -// CHECK: xla_gpu.yield %[[INSERTED]] : tensor<32x64xcomplex<f32>> +// CHECK: xla.yield %[[INSERTED]] : tensor<32x64xcomplex<f32>> diff --git a/third_party/xla/xla/service/gpu/fusions/transforms/tests/merge_pointers_to_same_slice.mlir b/third_party/xla/xla/service/gpu/fusions/transforms/tests/merge_pointers_to_same_slice.mlir index f3670a42d3a3b4..9d6cf8bbbd805a 100644 --- a/third_party/xla/xla/service/gpu/fusions/transforms/tests/merge_pointers_to_same_slice.mlir +++ b/third_party/xla/xla/service/gpu/fusions/transforms/tests/merge_pointers_to_same_slice.mlir @@ -1,4 +1,4 @@ -// RUN: mlir_fusions_opt %s -split-input-file -xla-gpu-lower-tensors -xla-gpu-merge-pointers | FileCheck %s +// RUN: emitters_opt %s -split-input-file -xla-gpu-lower-tensors -xla-gpu-merge-pointers | FileCheck %s module { func.func private @tensorargs(%arg0: tensor<43xf32> {xla.slice_index = 0}, diff --git a/third_party/xla/xla/service/gpu/fusions/transforms/tests/optimize_loops.mlir b/third_party/xla/xla/service/gpu/fusions/transforms/tests/optimize_loops.mlir index 1094b51a2a6841..8fe920c3abfcd4 100644 --- a/third_party/xla/xla/service/gpu/fusions/transforms/tests/optimize_loops.mlir +++ b/third_party/xla/xla/service/gpu/fusions/transforms/tests/optimize_loops.mlir @@ -1,8 +1,8 @@ -// RUN: mlir_fusions_opt %s -split-input-file -xla-gpu-optimize-loops | FileCheck %s +// RUN: emitters_opt %s -split-input-file -xla-gpu-optimize-loops | FileCheck %s -#map = #xla_gpu.indexing_map<"(d0) -> (d0 floordiv 8), domain: d0 in [0, 31]"> -#map1 = #xla_gpu.indexing_map<"(d0) -> (d0 mod 8), domain: d0 in [0, 31]"> -#map2 = #xla_gpu.indexing_map<"(d0, d1)[s0] -> (d1 * 2 + d0 + s0 * 512), domain: d0 in [0, 1], d1 in [0, 255], s0 in [0, 7]"> +#map = #xla.indexing_map<"(d0) -> (d0 floordiv 8), domain: d0 in [0, 31]"> +#map1 = #xla.indexing_map<"(d0) -> (d0 mod 8), domain: d0 in [0, 31]"> +#map2 = #xla.indexing_map<"(d0, d1)[s0] -> (d1 * 2 + d0 + s0 * 512), domain: d0 in [0, 1], d1 in [0, 255], s0 in [0, 7]"> module { func.func
@fully_unroll(%arg0: tensor<4x8x4096xf32>, %arg1: tensor<4096xbf16>, %arg2: tensor<4x8xf32>, %arg3: tensor<4096xbf16>, @@ -21,23 +21,23 @@ module { %1 = arith.cmpi eq, %0, %c0 : index %2 = arith.divui %thread_id_x, %c32 : index %3 = arith.cmpi ult, %thread_id_x, %c8 : index - %4 = xla_gpu.apply_indexing #map(%block_id_x) - %5 = xla_gpu.apply_indexing #map1(%block_id_x) + %4 = xla.apply_indexing #map(%block_id_x) + %5 = xla.apply_indexing #map1(%block_id_x) %extracted = tensor.extract %arg2[%4, %5] : tensor<4x8xf32> %6 = arith.mulf %extracted, %cst : f32 %7 = arith.addf %6, %cst : f32 %8 = math.rsqrt %7 : f32 %9:2 = scf.for %arg7 = %c0 to %c8 step %c1 iter_args(%arg8 = %arg6, %arg9 = %cst) -> (tensor<4x8x4096xf32>, f32) { - %18 = xla_gpu.apply_indexing #map2(%c0, %thread_id_x)[%arg7] + %18 = xla.apply_indexing #map2(%c0, %thread_id_x)[%arg7] %19 = vector.transfer_read %arg1[%18], %cst_1 {in_bounds = [true]} : tensor<4096xbf16>, vector<2xbf16> - %20 = xla_gpu.apply_indexing #map2(%c0, %thread_id_x)[%arg7] + %20 = xla.apply_indexing #map2(%c0, %thread_id_x)[%arg7] %21 = vector.transfer_read %arg3[%20], %cst_1 {in_bounds = [true]} : tensor<4096xbf16>, vector<2xbf16> - %22 = xla_gpu.apply_indexing #map2(%c0, %thread_id_x)[%arg7] + %22 = xla.apply_indexing #map2(%c0, %thread_id_x)[%arg7] %23 = vector.transfer_read %arg4[%4, %5, %22], %cst_1 {in_bounds = [true]} : tensor<4x8x4096xbf16>, vector<2xbf16> - %24 = xla_gpu.apply_indexing #map2(%c0, %thread_id_x)[%arg7] + %24 = xla.apply_indexing #map2(%c0, %thread_id_x)[%arg7] %25 = vector.transfer_read %arg0[%4, %5, %24], %cst {in_bounds = [true]} : tensor<4x8x4096xf32>, vector<2xf32> %26:2 = scf.for %arg10 = %c0 to %c2 step %c1 iter_args(%arg11 = %arg8, %arg12 = %arg9) -> (tensor<4x8x4096xf32>, f32) { - %27 = xla_gpu.apply_indexing #map2(%arg10, %thread_id_x)[%arg7] + %27 = xla.apply_indexing #map2(%arg10, %thread_id_x)[%arg7] %28 = vector.extract %25[%arg10] : f32 from vector<2xf32> %29 = vector.extract %23[%arg10] : bf16 from vector<2xbf16> %30 = arith.extf %29 : bf16 to f32 @@ -124,14 +124,14 @@ module { } } -// CHECK: #[[$MAP:.*]] = #xla_gpu.indexing_map<"(d0) -> (d0 + 1), +// CHECK: #[[$MAP:.*]] = #xla.indexing_map<"(d0) -> (d0 + 1), // CHECK-LABEL: @pipeline_extract // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index // CHECK-DAG: %[[C30:.*]] = arith.constant 30 : index // CHECK: %[[VAL0:.*]] = tensor.extract %[[ARG0:.*]][%[[C0]]] // CHECK: scf.for %[[I:.*]] = %[[C0]] {{.*}} iter_args(%[[ITER:.*]] = {{.*}}, %[[VAL:.*]] = %[[VAL0]]) // CHECK-DAG: %[[NEXT_I_EXISTS:.*]] = arith.cmpi ult, %[[I]], %[[C30]] -// CHECK-DAG: %[[NEXT_I:.*]] = xla_gpu.apply_indexing #[[$MAP]](%[[I]] +// CHECK-DAG: %[[NEXT_I:.*]] = xla.apply_indexing #[[$MAP]](%[[I]] // CHECK: %[[NEXT_VAL:.*]] = scf.if %[[NEXT_I_EXISTS]] // CHECK-NEXT: tensor.extract %[[ARG0]][%[[NEXT_I]]] // CHECK-NEXT: yield @@ -151,7 +151,7 @@ module { %cst = arith.constant dense<[0.0, 0.0]> : vector<2xf32> %cst0 = arith.constant 0.0 : f32 %ret = scf.for %i = %c0 to %c17 step %c1 iter_args (%iter = %cst) -> (vector<2xf32>) { - %base = xla_gpu.apply_indexing #xla_gpu.indexing_map<"(d0) -> (d0 * 2), domain: d0 in [0, 15]">(%i) + %base = xla.apply_indexing #xla.indexing_map<"(d0) -> (d0 * 2), domain: d0 in [0, 15]">(%i) %val = vector.transfer_read %arg[%base], %cst0 : tensor<34xf32>, vector<2xf32> %log = math.log %val : vector<2xf32> %add = arith.addf %log, %iter : vector<2xf32> @@ -161,17 +161,17 @@ module { } } -// CHECK-DAG: #[[$MAP0:.*]] = #xla_gpu.indexing_map<"(d0) -> (d0 * 2), -// CHECK-DAG: 
#[[$MAP1:.*]] = #xla_gpu.indexing_map<"(d0) -> (d0 + 1), +// CHECK-DAG: #[[$MAP0:.*]] = #xla.indexing_map<"(d0) -> (d0 * 2), +// CHECK-DAG: #[[$MAP1:.*]] = #xla.indexing_map<"(d0) -> (d0 + 1), // CHECK-LABEL: @pipeline_transfer // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index // CHECK-DAG: %[[C16:.*]] = arith.constant 16 : index -// CHECK: %[[BASE0:.*]] = xla_gpu.apply_indexing #[[$MAP0]](%[[C0]] +// CHECK: %[[BASE0:.*]] = xla.apply_indexing #[[$MAP0]](%[[C0]] // CHECK: %[[VAL0:.*]] = vector.transfer_read %[[ARG0:.*]][%[[BASE0]]] // CHECK: scf.for %[[I:.*]] = %[[C0]] {{.*}} iter_args(%[[ITER:.*]] = {{.*}}, %[[VAL:.*]] = %[[VAL0]]) // CHECK-DAG: %[[NEXT_I_EXISTS:.*]] = arith.cmpi ult, %[[I]], %[[C16]] -// CHECK-DAG: %[[NEXT_I:.*]] = xla_gpu.apply_indexing #[[$MAP1]](%[[I]] -// CHECK-DAG: %[[NEXT_BASE:.*]] = xla_gpu.apply_indexing #[[$MAP0]](%[[NEXT_I]] +// CHECK-DAG: %[[NEXT_I:.*]] = xla.apply_indexing #[[$MAP1]](%[[I]] +// CHECK-DAG: %[[NEXT_BASE:.*]] = xla.apply_indexing #[[$MAP0]](%[[NEXT_I]] // CHECK: %[[NEXT_VAL:.*]] = scf.if %[[NEXT_I_EXISTS]] // CHECK-NEXT: vector.transfer_read %[[ARG0]][%[[NEXT_BASE]]] // CHECK-NEXT: yield diff --git a/third_party/xla/xla/service/gpu/fusions/transforms/tests/peel_loops.mlir b/third_party/xla/xla/service/gpu/fusions/transforms/tests/peel_loops.mlir index 9ffd7bdc0fbfd1..c799f409c7a474 100644 --- a/third_party/xla/xla/service/gpu/fusions/transforms/tests/peel_loops.mlir +++ b/third_party/xla/xla/service/gpu/fusions/transforms/tests/peel_loops.mlir @@ -1,38 +1,38 @@ -// RUN: mlir_fusions_opt -split-input-file %s -xla-gpu-peel-loops \ +// RUN: emitters_opt -split-input-file %s -xla-gpu-peel-loops \ // RUN: | FileCheck %s -#map = #xla_gpu.indexing_map<"(d0)[s0, s1] -> (s0, s1), domain:" +#map = #xla.indexing_map<"(d0)[s0, s1] -> (s0, s1), domain:" "d0 in [0, 3], s0 in [0, 7], s1 in [0, 10], d0 + s0 in [0, 9]," "d0 + s1 in [0, 12]"> func.func @peel_both_loops(%input: tensor<16x32xf32>, %init: f32, %dim: index) -> (f32) { - %sum = xla_gpu.loop (%dim)[%i, %j] -> (%r0, %r1) + %sum = xla.loop (%dim)[%i, %j] -> (%r0, %r1) in #map iter_args(%sum_ = %init) -> (f32) { %t = tensor.extract %input[%i, %j] : tensor<16x32xf32> %add = arith.addf %sum_, %t : f32 - xla_gpu.yield %add : f32 + xla.yield %add : f32 } func.return %sum : f32 } -// CHECK: #[[$PEELED_MAP:.*]] = #xla_gpu.indexing_map<"(d0)[s0, s1] -> (s0, s1), domain: d0 in [0, 3], s0 in [0, 6], s1 in [0, 9]"> -// CHECK: #[[$TAIL_MAP0:.*]] = #xla_gpu.indexing_map<"(d0)[s0, s1] -> (7, s1), domain: d0 in [0, 2], s0 in [7, 7], s1 in [0, 9]"> -// CHECK: #[[$TAIL_MAP1:.*]] = #xla_gpu.indexing_map<"(d0)[s0, s1] -> (s0, 10), domain: d0 in [0, 2], s0 in [0, 7], s1 in [10, 10]"> +// CHECK: #[[$PEELED_MAP:.*]] = #xla.indexing_map<"(d0)[s0, s1] -> (s0, s1), domain: d0 in [0, 3], s0 in [0, 6], s1 in [0, 9]"> +// CHECK: #[[$TAIL_MAP0:.*]] = #xla.indexing_map<"(d0)[s0, s1] -> (7, s1), domain: d0 in [0, 2], s0 in [7, 7], s1 in [0, 9]"> +// CHECK: #[[$TAIL_MAP1:.*]] = #xla.indexing_map<"(d0)[s0, s1] -> (s0, 10), domain: d0 in [0, 2], s0 in [0, 7], s1 in [10, 10]"> // CHECK-LABEL: func.func @peel_both_loops( // CHECK-SAME: %[[INPUT:.*]]: tensor<16x32xf32>, // CHECK-SAME: %[[INIT:.*]]: f32, %[[DIM:.*]]: index) -// CHECK: %[[PEELED:.*]] = xla_gpu.loop (%[[DIM]])[%[[I:.*]], %[[J:.*]]] -> +// CHECK: %[[PEELED:.*]] = xla.loop (%[[DIM]])[%[[I:.*]], %[[J:.*]]] -> // CHECK-SAME: in #[[$PEELED_MAP]] iter_args(%[[INIT_:.*]] = %[[INIT]]) // CHECK: tensor.extract %[[INPUT]][%[[I]], %[[J]]] : tensor<16x32xf32> // CHECK: arith.addf 
%[[INIT_]] -// CHECK: %[[TAIL0:.*]] = xla_gpu.loop (%[[DIM]])[%[[I:.*]], %[[J:.*]]] +// CHECK: %[[TAIL0:.*]] = xla.loop (%[[DIM]])[%[[I:.*]], %[[J:.*]]] // CHECK-SAME: in #[[$TAIL_MAP0]] iter_args(%[[INIT_:.*]] = %[[PEELED]]) // CHECK: tensor.extract %[[INPUT]][%[[I]], %[[J]]] // CHECK: arith.addf %[[INIT_]] -// CHECK: %[[TAIL1:.*]] = xla_gpu.loop (%[[DIM]])[%[[I:.*]], %[[J:.*]]] +// CHECK: %[[TAIL1:.*]] = xla.loop (%[[DIM]])[%[[I:.*]], %[[J:.*]]] // CHECK-SAME: in #[[$TAIL_MAP1]] iter_args(%[[INIT_:.*]] = %[[TAIL0]]) // CHECK: tensor.extract %[[INPUT]][%[[I]], %[[J]]] // CHECK: arith.addf %[[INIT_]] @@ -41,25 +41,25 @@ func.func @peel_both_loops(%input: tensor<16x32xf32>, // ----- -#map = #xla_gpu.indexing_map<"(d0)[s0] -> (s0)," +#map = #xla.indexing_map<"(d0)[s0] -> (s0)," "domain: d0 in [0, 3], s0 in [0, 7]"> func.func @not_constrained_symbol(%input: tensor<16xf32>, %init: f32, %dim: index) -> (f32) { - %sum = xla_gpu.loop (%dim)[%i] -> (%r0) + %sum = xla.loop (%dim)[%i] -> (%r0) in #map iter_args(%sum_ = %init) -> (f32) { %t = tensor.extract %input[%i] : tensor<16xf32> %add = arith.addf %sum_, %t : f32 - xla_gpu.yield %add : f32 + xla.yield %add : f32 } func.return %sum : f32 } // CHECK-LABEL: func.func @not_constrained_symbol -// CHECK: xla_gpu.loop -// CHECK-NOT: xla_gpu.loop +// CHECK: xla.loop +// CHECK-NOT: xla.loop // ----- -#map = #xla_gpu.indexing_map< +#map = #xla.indexing_map< " (d0)[s0] -> (s0)," " domain:" " d0 in [0, 3]," @@ -67,14 +67,14 @@ func.func @not_constrained_symbol(%input: tensor<16xf32>, %init: f32, " s0 mod 5 in [0, 1]"> func.func @constraint_exists_after_peeling(%input: tensor<16xf32>, %init: f32, %dim: index) -> (f32) { - %sum = xla_gpu.loop (%dim)[%i] -> (%r0) + %sum = xla.loop (%dim)[%i] -> (%r0) in #map iter_args(%sum_ = %init) -> (f32) { %t = tensor.extract %input[%i] : tensor<16xf32> %add = arith.addf %sum_, %t : f32 - xla_gpu.yield %add : f32 + xla.yield %add : f32 } func.return %sum : f32 } // CHECK-LABEL: func.func @constraint_exists_after_peeling -// CHECK: xla_gpu.loop -// CHECK-NOT: xla_gpu.loop +// CHECK: xla.loop +// CHECK-NOT: xla.loop diff --git a/third_party/xla/xla/service/gpu/fusions/transforms/tests/propagate_slice_indices.mlir b/third_party/xla/xla/service/gpu/fusions/transforms/tests/propagate_slice_indices.mlir index b776337bda0653..9bf77fdc2e2913 100644 --- a/third_party/xla/xla/service/gpu/fusions/transforms/tests/propagate_slice_indices.mlir +++ b/third_party/xla/xla/service/gpu/fusions/transforms/tests/propagate_slice_indices.mlir @@ -1,4 +1,4 @@ -// RUN: mlir_fusions_opt %s -split-input-file -xla-gpu-propagate-slice-indices | FileCheck %s +// RUN: emitters_opt %s -split-input-file -xla-gpu-propagate-slice-indices | FileCheck %s module { func.func private @add(%arg0: f32, %arg1: f32) -> f32 { diff --git a/third_party/xla/xla/service/gpu/fusions/transforms/tests/simplify_affine.mlir b/third_party/xla/xla/service/gpu/fusions/transforms/tests/simplify_affine.mlir index e62a530de0e7db..4de35a8afffb8f 100644 --- a/third_party/xla/xla/service/gpu/fusions/transforms/tests/simplify_affine.mlir +++ b/third_party/xla/xla/service/gpu/fusions/transforms/tests/simplify_affine.mlir @@ -1,4 +1,4 @@ -// RUN: mlir_fusions_opt --allow-unregistered-dialect %s -split-input-file -xla-gpu-simplify-affine | FileCheck %s +// RUN: emitters_opt --allow-unregistered-dialect %s -split-input-file -xla-gpu-simplify-affine | FileCheck %s func.func @op_and_for_ranges(%arg0: !llvm.ptr, %arg1: !llvm.ptr, %arg2: !llvm.ptr) { %c0 = arith.constant 0 : index @@ -62,8 
+62,8 @@ func.func @op_and_for_ranges(%arg0: !llvm.ptr, %arg1: !llvm.ptr, %arg2: !llvm.pt %0 = gpu.thread_id x %1 = gpu.block_id x scf.for %i = %c0 to %c4 step %c1 { - %2 = xla_gpu.apply_indexing - #xla_gpu.indexing_map< + %2 = xla.apply_indexing + #xla.indexing_map< "()[s0, s1, s2] -> (s0 * 512 + s1 * 4 + s2 + (s1 floordiv 128) + (s2 floordiv 4))," "domain: s0 in [0, 3071], s1 in [0, 127], s2 in [0, 3]">[%1, %0, %i] %3 = arith.index_castui %2 : index to i64 @@ -92,8 +92,8 @@ func.func @op_and_for_ranges(%arg0: !llvm.ptr, %arg1: !llvm.ptr, %arg2: !llvm.pt // ----- func.func @arg_ranges(%arg0: index, %arg1: index) -> index { - %0 = xla_gpu.apply_indexing - #xla_gpu.indexing_map< + %0 = xla.apply_indexing + #xla.indexing_map< "()[s0, s1] -> (s0 floordiv 100 + s1 floordiv 100)," "domain: s0 in [0, 42], s1 in [0, 1000]">[%arg0, %arg1] return %0 : index @@ -107,8 +107,8 @@ func.func @arg_ranges(%arg0: index, %arg1: index) -> index { // ----- func.func @cant_lower(%arg0: index, %arg1: index) -> (index, index) { - %0:2 = xla_gpu.apply_indexing - #xla_gpu.indexing_map<"()[s0, s1] -> (s0 floordiv 100 + s1 floordiv 100, s0 + s1)," + %0:2 = xla.apply_indexing + #xla.indexing_map<"()[s0, s1] -> (s0 floordiv 100 + s1 floordiv 100, s0 + s1)," "domain: s0 in [-10, 42], s1 in [0, 1000]">[%arg0, %arg1] return %0#0, %0#1 : index, index } @@ -125,8 +125,8 @@ func.func @order_summands(%arg1: index) { %c4 = arith.constant 4 : index scf.for %arg2 = %c0 to %c4 step %c1 { scf.for %arg3 = %c0 to %c4 step %c1 { - %0 = xla_gpu.apply_indexing - #xla_gpu.indexing_map< + %0 = xla.apply_indexing + #xla.indexing_map< "()[s0, s1, s2] -> ((s0 + s1) floordiv 3 + s0 * 512 + s1 * 4 + s2 * 10)," "domain: s0 in [0, 3], s1 in [0, 3], s2 in [0, 3]">[%arg2, %arg1, %arg3] "dummy.op"(%0) : (index) -> () diff --git a/third_party/xla/xla/service/gpu/fusions/transforms/tests/simplify_arith.mlir b/third_party/xla/xla/service/gpu/fusions/transforms/tests/simplify_arith.mlir index e6fea946e6e827..bdecdfea064f78 100644 --- a/third_party/xla/xla/service/gpu/fusions/transforms/tests/simplify_arith.mlir +++ b/third_party/xla/xla/service/gpu/fusions/transforms/tests/simplify_arith.mlir @@ -1,4 +1,4 @@ -// RUN: mlir_fusions_opt %s -split-input-file -xla-gpu-simplify-arith -cse \ +// RUN: emitters_opt %s -split-input-file -xla-gpu-simplify-arith -cse \ // RUN: -canonicalize | FileCheck %s module { @@ -248,7 +248,7 @@ func.func @refine_constraints(%tensor: tensor<100xf32>) -> tensor<100xf32> { %c42_f32 = arith.constant 42.0 : f32 %loop = scf.for %i = %c0 to %c3 step %c1 iter_args(%in_ = %tensor) -> (tensor<100xf32>) { - %0 = xla_gpu.apply_indexing #xla_gpu.indexing_map<"(d0) -> (d0 mod 4)," + %0 = xla.apply_indexing #xla.indexing_map<"(d0) -> (d0 mod 4)," "domain: d0 in [0, 9]">(%i) %updated = tensor.insert %c42_f32 into %in_[%0] : tensor<100xf32> scf.yield %updated :tensor<100xf32> @@ -263,10 +263,10 @@ func.func @refine_constraints(%tensor: tensor<100xf32>) -> tensor<100xf32> { // ----- -#map = #xla_gpu.indexing_map< +#map = #xla.indexing_map< "(d0, d1)[s0, s1] -> (((d0 * 4 + d1 * 512 + s1) floordiv 9 + s0 * 32768) mod 2400000)," "domain: d0 in [0, 127], d1 in [0, 575], s0 in [0, 73], s1 in [0, 3]"> -#map1 = #xla_gpu.indexing_map<"(d0, d1)[s0] -> ((d0 * 4 + d1 * 512 + s0) mod 9)," +#map1 = #xla.indexing_map<"(d0, d1)[s0] -> ((d0 * 4 + d1 * 512 + s0) mod 9)," "domain: d0 in [0, 127], d1 in [0, 575], s0 in [0, 3]"> func.func @refine_constraints_for_symbol(%arg0: tensor<2400000x9xf32>, %arg1: tensor<2400000x9xf32>) -> tensor<2400000x9xf32> { @@ 
-281,8 +281,8 @@ func.func @refine_constraints_for_symbol(%arg0: tensor<2400000x9xf32>, -> (tensor<2400000x9xf32>) { %2 = scf.for %j = %c0 to %c4 step %c1 iter_args(%arg5 = %arg3) -> (tensor<2400000x9xf32>) { - %3 = xla_gpu.apply_indexing #map(%th_x, %bl_x)[%i, %j] - %4 = xla_gpu.apply_indexing #map1(%th_x, %bl_x)[%j] + %3 = xla.apply_indexing #map(%th_x, %bl_x)[%i, %j] + %4 = xla.apply_indexing #map1(%th_x, %bl_x)[%j] %inserted = tensor.insert %c42_f32 into %arg5[%3, %4] : tensor<2400000x9xf32> scf.yield %inserted : tensor<2400000x9xf32> @@ -291,12 +291,12 @@ func.func @refine_constraints_for_symbol(%arg0: tensor<2400000x9xf32>, } return %0 : tensor<2400000x9xf32> } -// CHECK: #[[$MAP:.*]] = #xla_gpu.indexing_map<"(d0, d1, d2, d3) -> (d2 * 32768 + (d0 * 4 + d1 * 512 + d3) floordiv 9), +// CHECK: #[[$MAP:.*]] = #xla.indexing_map<"(d0, d1, d2, d3) -> (d2 * 32768 + (d0 * 4 + d1 * 512 + d3) floordiv 9), // CHECK-LABEL: func.func @refine_constraints_for_symbol // ----- -#map = #xla_gpu.indexing_map< +#map = #xla.indexing_map< "(d0, d1, d2, d3, d4, d5)[s0] -> ((d0 * 4 + s0) floordiv 6, (d0 * 4 + s0) mod 6)," "domain:" "d0 in [0, 29]," @@ -323,12 +323,12 @@ func.func @dus(%arg0: tensor<20x30xf32>, %arg1: tensor<5x6xf32>, %arg2: i32, %ar %3 = arith.index_cast %arg3 : i32 to index %4 = arith.minsi %3, %c24 : index %5 = arith.maxsi %4, %c0 : index - %xla_loop = xla_gpu.loop (%thread_id_x, %thread_id_y, %thread_id_z, %block_id_x, %block_id_y, %block_id_z)[%i] -> (%ra, %rb) in #map iter_args(%iter = %arg4) -> (tensor<20x30xf32>) { + %xla_loop = xla.loop (%thread_id_x, %thread_id_y, %thread_id_z, %block_id_x, %block_id_y, %block_id_z)[%i] -> (%ra, %rb) in #map iter_args(%iter = %arg4) -> (tensor<20x30xf32>) { %6 = arith.addi %2, %ra : index %7 = arith.addi %5, %rb : index %extracted = tensor.extract %arg1[%ra, %rb] : tensor<5x6xf32> %inserted = tensor.insert %extracted into %iter[%6, %7] : tensor<20x30xf32> - xla_gpu.yield %inserted : tensor<20x30xf32> + xla.yield %inserted : tensor<20x30xf32> } return %xla_loop : tensor<20x30xf32> } @@ -342,7 +342,7 @@ func.func @dus(%arg0: tensor<20x30xf32>, %arg1: tensor<5x6xf32>, %arg2: i32, %ar // CHECK-SAME: xla.range = [-9223372036854775808 : index, 24 : index] // CHECK: arith.maxsi // CHECK-SAME: xla.range = [0 : index, 24 : index] -// CHECK: xla_gpu.loop +// CHECK: xla.loop // CHECK: arith.addi // CHECK-SAME: xla.range = [0 : index, 19 : index] // CHECK: arith.addi diff --git a/third_party/xla/xla/service/gpu/fusions/transforms/tests/unswitch_loops.mlir b/third_party/xla/xla/service/gpu/fusions/transforms/tests/unswitch_loops.mlir index da0c8d07f9a184..8220e8ed3d1e66 100644 --- a/third_party/xla/xla/service/gpu/fusions/transforms/tests/unswitch_loops.mlir +++ b/third_party/xla/xla/service/gpu/fusions/transforms/tests/unswitch_loops.mlir @@ -1,4 +1,4 @@ -// RUN: mlir_fusions_opt %s -split-input-file -xla-gpu-unswitch-loops | FileCheck %s +// RUN: emitters_opt %s -split-input-file -xla-gpu-unswitch-loops | FileCheck %s module { func.func @unswitchable( diff --git a/third_party/xla/xla/service/gpu/fusions/transforms/tests/vectorize_loads_stores.mlir b/third_party/xla/xla/service/gpu/fusions/transforms/tests/vectorize_loads_stores.mlir index b9c2b1b4086278..a3b7e816bb05fb 100644 --- a/third_party/xla/xla/service/gpu/fusions/transforms/tests/vectorize_loads_stores.mlir +++ b/third_party/xla/xla/service/gpu/fusions/transforms/tests/vectorize_loads_stores.mlir @@ -1,7 +1,7 @@ -// RUN: mlir_fusions_opt -allow-unregistered-dialect %s -split-input-file \ +// RUN: 
emitters_opt -allow-unregistered-dialect %s -split-input-file \ // RUN: -xla-gpu-vectorize-loads-stores -cse -canonicalize | FileCheck %s -#map = #xla_gpu.indexing_map<"(d0)[s0] -> (d0 * 2 + s0)," +#map = #xla.indexing_map<"(d0)[s0] -> (d0 * 2 + s0)," "domain: d0 in [0, 63], s0 in [0, 1]"> func.func @simple_read(%arg0: tensor<128xf32>) -> (f32) { %c0 = arith.constant 0 : index @@ -11,7 +11,7 @@ func.func @simple_read(%arg0: tensor<128xf32>) -> (f32) { %cst = arith.constant 0.0 : f32 %outer = scf.for %i = %c0 to %c64 step %c1 iter_args(%iter = %cst) -> f32 { %inner = scf.for %j = %c0 to %c2 step %c1 iter_args(%iter1 = %iter) -> f32 { - %idx = xla_gpu.apply_indexing #map(%i)[%j] + %idx = xla.apply_indexing #map(%i)[%j] %extracted = tensor.extract %arg0[%idx] : tensor<128xf32> %added = arith.addf %iter1, %extracted : f32 scf.yield %added : f32 @@ -20,7 +20,7 @@ func.func @simple_read(%arg0: tensor<128xf32>) -> (f32) { } return %outer : f32 } -// CHECK: #[[$MAP:.*]] = #xla_gpu.indexing_map<"(d0) -> (d0 * 2), domain: d0 in [0, 63]"> +// CHECK: #[[$MAP:.*]] = #xla.indexing_map<"(d0) -> (d0 * 2), domain: d0 in [0, 63]"> // CHECK-LABEL: @simple_read // CHECK-SAME: (%[[ARG0:.*]]: tensor // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index @@ -28,7 +28,7 @@ func.func @simple_read(%arg0: tensor<128xf32>) -> (f32) { // CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index // CHECK-DAG: %[[C64:.*]] = arith.constant 64 : index // CHECK: scf.for %[[I:.*]] = %[[C0]] to %[[C64]] step %[[C1]] iter_args(%[[ITER:.*]] = -// CHECK: %[[BASE:.*]] = xla_gpu.apply_indexing #[[$MAP]](%[[I]]) +// CHECK: %[[BASE:.*]] = xla.apply_indexing #[[$MAP]](%[[I]]) // CHECK-NEXT: %[[V:.*]] = vector.transfer_read %[[ARG0]][%[[BASE]]] // CHECK-NEXT: scf.for %[[J:.*]] = %[[C0]] // CHECK-NEXT: vector.extract %[[V]][%[[J]]] @@ -36,7 +36,7 @@ func.func @simple_read(%arg0: tensor<128xf32>) -> (f32) { // ----- -#map = #xla_gpu.indexing_map<"(d0)[s0] -> (d0 * 4 + s0)," +#map = #xla.indexing_map<"(d0)[s0] -> (d0 * 4 + s0)," "domain: d0 in [0, 63], s0 in [0, 3]"> func.func @simple_read(%arg0: tensor<256xf16>) -> (f16) { %c0 = arith.constant 0 : index @@ -46,7 +46,7 @@ func.func @simple_read(%arg0: tensor<256xf16>) -> (f16) { %cst = arith.constant 0.0 : f16 %outer = scf.for %i = %c0 to %c64 step %c1 iter_args(%iter = %cst) -> f16 { %inner = scf.for %j = %c0 to %c4 step %c1 iter_args(%iter1 = %iter) -> f16 { - %idx = xla_gpu.apply_indexing #map(%i)[%j] + %idx = xla.apply_indexing #map(%i)[%j] %extracted = tensor.extract %arg0[%idx] : tensor<256xf16> %added = arith.addf %iter1, %extracted : f16 scf.yield %added : f16 @@ -55,7 +55,7 @@ func.func @simple_read(%arg0: tensor<256xf16>) -> (f16) { } return %outer : f16 } -// CHECK: #[[$MAP:.*]] = #xla_gpu.indexing_map<"(d0) -> (d0 * 4), domain: d0 in [0, 63]"> +// CHECK: #[[$MAP:.*]] = #xla.indexing_map<"(d0) -> (d0 * 4), domain: d0 in [0, 63]"> // CHECK-LABEL: @simple_read // CHECK-SAME: (%[[ARG0:.*]]: tensor // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index @@ -63,7 +63,7 @@ func.func @simple_read(%arg0: tensor<256xf16>) -> (f16) { // CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index // CHECK-DAG: %[[C64:.*]] = arith.constant 64 : index // CHECK: scf.for %[[I:.*]] = %[[C0]] to %[[C64]] step %[[C1]] iter_args(%[[ITER:.*]] = -// CHECK: %[[BASE:.*]] = xla_gpu.apply_indexing #[[$MAP]](%[[I]]) +// CHECK: %[[BASE:.*]] = xla.apply_indexing #[[$MAP]](%[[I]]) // CHECK-NEXT: %[[V:.*]] = vector.transfer_read %[[ARG0]][%[[BASE]]] // CHECK-NEXT: scf.for %[[J:.*]] = %[[C0]] // CHECK-NEXT: vector.extract 
%[[V]][%[[J]]] @@ -71,7 +71,7 @@ func.func @simple_read(%arg0: tensor<256xf16>) -> (f16) { // ----- -#map = #xla_gpu.indexing_map<"(d0)[s0] -> (d0 * 8 + s0)," +#map = #xla.indexing_map<"(d0)[s0] -> (d0 * 8 + s0)," "domain: d0 in [0, 63], s0 in [0, 7]"> func.func @simple_read(%arg0: tensor<512xi8>) -> (i8) { %c0 = arith.constant 0 : index @@ -81,7 +81,7 @@ func.func @simple_read(%arg0: tensor<512xi8>) -> (i8) { %cst = arith.constant 0 : i8 %outer = scf.for %i = %c0 to %c64 step %c1 iter_args(%iter = %cst) -> i8 { %inner = scf.for %j = %c0 to %c8 step %c1 iter_args(%iter1 = %iter) -> i8 { - %idx = xla_gpu.apply_indexing #map(%i)[%j] + %idx = xla.apply_indexing #map(%i)[%j] %extracted = tensor.extract %arg0[%idx] : tensor<512xi8> %added = arith.addi %iter1, %extracted : i8 scf.yield %added : i8 @@ -90,7 +90,7 @@ func.func @simple_read(%arg0: tensor<512xi8>) -> (i8) { } return %outer : i8 } -// CHECK: #[[$MAP:.*]] = #xla_gpu.indexing_map<"(d0) -> (d0 * 8), domain: d0 in [0, 63]"> +// CHECK: #[[$MAP:.*]] = #xla.indexing_map<"(d0) -> (d0 * 8), domain: d0 in [0, 63]"> // CHECK-LABEL: @simple_read // CHECK-SAME: (%[[ARG0:.*]]: tensor // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index @@ -98,7 +98,7 @@ func.func @simple_read(%arg0: tensor<512xi8>) -> (i8) { // CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index // CHECK-DAG: %[[C64:.*]] = arith.constant 64 : index // CHECK: scf.for %[[I:.*]] = %[[C0]] to %[[C64]] step %[[C1]] iter_args(%[[ITER:.*]] = -// CHECK: %[[BASE:.*]] = xla_gpu.apply_indexing #[[$MAP]](%[[I]]) +// CHECK: %[[BASE:.*]] = xla.apply_indexing #[[$MAP]](%[[I]]) // CHECK-NEXT: %[[V:.*]] = vector.transfer_read %[[ARG0]][%[[BASE]]] // CHECK-NEXT: scf.for %[[J:.*]] = %[[C0]] // CHECK-NEXT: vector.extract %[[V]][%[[J]]] @@ -106,7 +106,7 @@ func.func @simple_read(%arg0: tensor<512xi8>) -> (i8) { // ----- -#map = #xla_gpu.indexing_map<"(d0)[s0] -> (d0 * 2 + s0 + 1)," +#map = #xla.indexing_map<"(d0)[s0] -> (d0 * 2 + s0 + 1)," "domain: d0 in [0, 63], s0 in [0, 1]"> func.func @misaligned_indexing_map(%arg0: tensor<128xf32>) -> (f32) { %c0 = arith.constant 0 : index @@ -116,7 +116,7 @@ func.func @misaligned_indexing_map(%arg0: tensor<128xf32>) -> (f32) { %cst = arith.constant 0.0 : f32 %outer = scf.for %i = %c0 to %c63 step %c1 iter_args(%iter = %cst) -> f32 { %inner = scf.for %j = %c0 to %c2 step %c1 iter_args(%iter1 = %iter) -> f32 { - %idx = xla_gpu.apply_indexing #map(%i)[%j] + %idx = xla.apply_indexing #map(%i)[%j] %extracted = tensor.extract %arg0[%idx] : tensor<128xf32> %added = arith.addf %iter1, %extracted : f32 scf.yield %added : f32 @@ -130,7 +130,7 @@ func.func @misaligned_indexing_map(%arg0: tensor<128xf32>) -> (f32) { // ----- -#map = #xla_gpu.indexing_map<"(d0)[s0] -> (d0 * 3 + s0)," +#map = #xla.indexing_map<"(d0)[s0] -> (d0 * 3 + s0)," "domain: d0 in [0, 63], s0 in [0, 1]"> func.func @misaligned_indexing_map_2(%arg0: tensor<128xf32>) -> (f32) { %c0 = arith.constant 0 : index @@ -140,7 +140,7 @@ func.func @misaligned_indexing_map_2(%arg0: tensor<128xf32>) -> (f32) { %cst = arith.constant 0.0 : f32 %outer = scf.for %i = %c0 to %c63 step %c1 iter_args(%iter = %cst) -> f32 { %inner = scf.for %j = %c0 to %c2 step %c1 iter_args(%iter1 = %iter) -> f32 { - %idx = xla_gpu.apply_indexing #map(%i)[%j] + %idx = xla.apply_indexing #map(%i)[%j] %extracted = tensor.extract %arg0[%idx] : tensor<128xf32> %added = arith.addf %iter1, %extracted : f32 scf.yield %added : f32 @@ -154,7 +154,7 @@ func.func @misaligned_indexing_map_2(%arg0: tensor<128xf32>) -> (f32) { // ----- -#map = 
#xla_gpu.indexing_map<"(d0)[s0] -> (3 * d0 + s0)," +#map = #xla.indexing_map<"(d0)[s0] -> (3 * d0 + s0)," "domain: d0 in [0, 63], s0 in [0, 1]"> func.func @misaligned_shape(%arg0: tensor<192xf32>) -> (f32) { %c0 = arith.constant 0 : index @@ -164,7 +164,7 @@ func.func @misaligned_shape(%arg0: tensor<192xf32>) -> (f32) { %cst = arith.constant 0.0 : f32 %outer = scf.for %i = %c0 to %c64 step %c1 iter_args(%iter = %cst) -> f32 { %inner = scf.for %j = %c0 to %c2 step %c1 iter_args(%iter1 = %iter) -> f32 { - %idx = xla_gpu.apply_indexing #map(%i)[%j] + %idx = xla.apply_indexing #map(%i)[%j] %extracted = tensor.extract %arg0[%idx] : tensor<192xf32> %added = arith.addf %iter1, %extracted : f32 scf.yield %added : f32 @@ -178,7 +178,7 @@ func.func @misaligned_shape(%arg0: tensor<192xf32>) -> (f32) { // ----- -#map = #xla_gpu.indexing_map<"(d0)[s0] -> (d0 + s0 * 2)," +#map = #xla.indexing_map<"(d0)[s0] -> (d0 + s0 * 2)," "domain: d0 in [0, 63], s0 in [0, 1]"> func.func @wrong_stride(%arg0: tensor<128xf32>) -> (f32) { %c0 = arith.constant 0 : index @@ -188,7 +188,7 @@ func.func @wrong_stride(%arg0: tensor<128xf32>) -> (f32) { %cst = arith.constant 0.0 : f32 %outer = scf.for %i = %c0 to %c63 step %c1 iter_args(%iter = %cst) -> f32 { %inner = scf.for %j = %c0 to %c2 step %c1 iter_args(%iter1 = %iter) -> f32 { - %idx = xla_gpu.apply_indexing #map(%i)[%j] + %idx = xla.apply_indexing #map(%i)[%j] %extracted = tensor.extract %arg0[%idx] : tensor<128xf32> %added = arith.addf %iter1, %extracted : f32 scf.yield %added : f32 @@ -204,7 +204,7 @@ func.func @wrong_stride(%arg0: tensor<128xf32>) -> (f32) { // We could vectorize this as a float vector load of double the size, but we // don't currently. -#map = #xla_gpu.indexing_map<"(d0)[s0] -> (2 * d0 + s0)," +#map = #xla.indexing_map<"(d0)[s0] -> (2 * d0 + s0)," "domain: d0 in [0, 127], s0 in [0, 1]"> func.func @simple_read_complex(%arg0: tensor<128xcomplex>, %i: index) -> (complex) { %c0 = arith.constant 0 : index @@ -212,7 +212,7 @@ func.func @simple_read_complex(%arg0: tensor<128xcomplex>, %i: index) -> (c %c2 = arith.constant 2 : index %cst = complex.constant [0.0 : f32, 0.0 : f32] : complex %loop = scf.for %j = %c0 to %c2 step %c1 iter_args(%iter = %cst) -> complex { - %idx = xla_gpu.apply_indexing #map(%i)[%j] + %idx = xla.apply_indexing #map(%i)[%j] %extracted = tensor.extract %arg0[%idx] : tensor<128xcomplex> %added = complex.add %iter, %extracted : complex scf.yield %added : complex @@ -320,9 +320,9 @@ func.func @write_not_yielded(%arg0: tensor<64xf32>) -> tensor<64xf32> { // ----- -#map = #xla_gpu.indexing_map<"(d0, d1)[s0] -> (d1 * 2 + d0 + s0 * 512)," +#map = #xla.indexing_map<"(d0, d1)[s0] -> (d1 * 2 + d0 + s0 * 512)," "domain: d0 in [0, 7], d1 in [0, 255], s0 in [0, 7]"> -#map1 = #xla_gpu.indexing_map< +#map1 = #xla.indexing_map< "(d0, d1, d2)[s0] -> (d0 * 32 + d2 * 2 + d1 + s0 * 512)," "domain: d0 in [0, 7], d1 in [0, 1], d2 in [0, 255], s0 in [0, 7]"> func.func @multiple(%arg0: tensor<131072xf32>, %arg1: tensor<4096xbf16>, @@ -336,8 +336,8 @@ func.func @multiple(%arg0: tensor<131072xf32>, %arg1: tensor<4096xbf16>, %extracted1 = tensor.extract %arg2[%arg4] : tensor<32xf32> %0:2 = scf.for %i = %c0 to %c8 step %c1 iter_args(%iter0 = %arg3, %iter1 = %cst) -> (tensor<131072xf32>, f32) { %1:2 = scf.for %j = %c0 to %c2 step %c1 iter_args(%iter2 = %iter0, %iter3 = %iter1) -> (tensor<131072xf32>, f32) { - %2 = xla_gpu.apply_indexing #map(%j, %arg4)[%i] - %idx = xla_gpu.apply_indexing #map1(%i, %j, %arg4)[%i] + %2 = xla.apply_indexing #map(%j, %arg4)[%i] + 
%idx = xla.apply_indexing #map1(%i, %j, %arg4)[%i] %extracted2 = tensor.extract %arg0[%idx] : tensor<131072xf32> %extracted3 = tensor.extract %arg1[%2] : tensor<4096xbf16> %3 = arith.extf %extracted3 : bf16 to f32 @@ -351,14 +351,14 @@ func.func @multiple(%arg0: tensor<131072xf32>, %arg1: tensor<4096xbf16>, } return %0#0, %0#1 : tensor<131072xf32>, f32 } -// CHECK-DAG: #[[$MAP:.*]] = #xla_gpu.indexing_map<"(d0, d1) -> (d0 * 2 + d1 * 512), domain: d0 in [0, 255], d1 in [0, 7]"> -// CHECK-DAG: #[[$MAP1:.*]] = #xla_gpu.indexing_map<"(d0, d1, d2) -> (d0 * 32 + d1 * 2 + d2 * 512), domain: d0 in [0, 7], d1 in [0, 255], d2 in [0, 7]"> +// CHECK-DAG: #[[$MAP:.*]] = #xla.indexing_map<"(d0, d1) -> (d0 * 2 + d1 * 512), domain: d0 in [0, 255], d1 in [0, 7]"> +// CHECK-DAG: #[[$MAP1:.*]] = #xla.indexing_map<"(d0, d1, d2) -> (d0 * 32 + d1 * 2 + d2 * 512), domain: d0 in [0, 7], d1 in [0, 255], d2 in [0, 7]"> // CHECK-LABEL: @multiple // CHECK-SAME: (%[[ARG0:.*]]: tensor{{.*}}, %[[ARG1:.*]]: tensor{{.*}}, %[[ARG2:.*]]: tensor{{.*}}, %[[ARG3:.*]]: tensor{{.*}}, %[[ARG4:.*]]: index) // CHECK: %[[C0:.*]] = arith.constant 0 : index // CHECK: scf.for %[[I:.*]] = %[[C0]] -// CHECK-DAG: %[[BASE:.*]] = xla_gpu.apply_indexing #[[$MAP]](%[[ARG4]], %[[I]]) -// CHECK-DAG: %[[IDX:.*]] = xla_gpu.apply_indexing #[[$MAP1]](%[[I]], %[[ARG4]], %[[I]]) +// CHECK-DAG: %[[BASE:.*]] = xla.apply_indexing #[[$MAP]](%[[ARG4]], %[[I]]) +// CHECK-DAG: %[[IDX:.*]] = xla.apply_indexing #[[$MAP1]](%[[I]], %[[ARG4]], %[[I]]) // CHECK: %[[READ1:.*]] = vector.transfer_read %[[ARG1]][%[[BASE]]] // CHECK: %[[READ2:.*]] = vector.transfer_read %[[ARG0]][%[[IDX]]] // CHECK: %[[INNER:.*]]:2 = scf.for %[[J:.*]] = %[[C0]] {{.*}} iter_args(%[[F:.*]] = {{.*}}, %[[V:.*]] = {{.*}}) -> (f32, vector<2xf32>) @@ -375,7 +375,7 @@ func.func @multiple(%arg0: tensor<131072xf32>, %arg1: tensor<4096xbf16>, // ----- -#map = #xla_gpu.indexing_map<"(d0)[s0] -> ((d0 * 4) mod 64 + s0)," +#map = #xla.indexing_map<"(d0)[s0] -> ((d0 * 4) mod 64 + s0)," "domain: d0 in [0, 63], s0 in [0, 1]"> func.func @remainder_with_modulo(%arg0: tensor<128xf32>) -> (f32) { %c0 = arith.constant 0 : index @@ -385,7 +385,7 @@ func.func @remainder_with_modulo(%arg0: tensor<128xf32>) -> (f32) { %cst = arith.constant 0.0 : f32 %outer = scf.for %i = %c0 to %c63 step %c1 iter_args(%iter = %cst) -> f32 { %inner = scf.for %j = %c0 to %c2 step %c1 iter_args(%iter1 = %iter) -> f32 { - %idx = xla_gpu.apply_indexing #map(%i)[%j] + %idx = xla.apply_indexing #map(%i)[%j] %extracted = tensor.extract %arg0[%idx] : tensor<128xf32> %added = arith.addf %iter1, %extracted : f32 scf.yield %added : f32 @@ -394,16 +394,16 @@ func.func @remainder_with_modulo(%arg0: tensor<128xf32>) -> (f32) { } return %outer : f32 } -// CHECK: #[[$MAP:.*]] = #xla_gpu.indexing_map<"(d0) -> ((d0 mod 16) * 4), +// CHECK: #[[$MAP:.*]] = #xla.indexing_map<"(d0) -> ((d0 mod 16) * 4), // CHECK-LABEL: @remainder_with_modulo // CHECK: %[[C0:.*]] = arith.constant 0 : index // CHECK: scf.for %[[I:.*]] = %[[C0]] -// CHECK: %[[BASE:.*]] = xla_gpu.apply_indexing #[[$MAP]](%[[I]] +// CHECK: %[[BASE:.*]] = xla.apply_indexing #[[$MAP]](%[[I]] // CHECK: vector.transfer_read {{.*}}[%[[BASE]]] // ----- -#map = #xla_gpu.indexing_map<"(d0)[s0] -> ((d0 * 4) mod 65 + s0)," +#map = #xla.indexing_map<"(d0)[s0] -> ((d0 * 4) mod 65 + s0)," "domain: d0 in [0, 63], s0 in [0, 1]"> func.func @remainder_with_modulo_misaligned(%arg0: tensor<128xf32>) -> (f32) { %c0 = arith.constant 0 : index @@ -413,7 +413,7 @@ func.func 
@remainder_with_modulo_misaligned(%arg0: tensor<128xf32>) -> (f32) { %cst = arith.constant 0.0 : f32 %outer = scf.for %i = %c0 to %c63 step %c1 iter_args(%iter = %cst) -> f32 { %inner = scf.for %j = %c0 to %c2 step %c1 iter_args(%iter1 = %iter) -> f32 { - %idx = xla_gpu.apply_indexing #map(%i)[%j] + %idx = xla.apply_indexing #map(%i)[%j] %extracted = tensor.extract %arg0[%idx] : tensor<128xf32> %added = arith.addf %iter1, %extracted : f32 scf.yield %added : f32 @@ -427,9 +427,9 @@ func.func @remainder_with_modulo_misaligned(%arg0: tensor<128xf32>) -> (f32) { // ----- -#map0 = #xla_gpu.indexing_map<"(d0) -> (d0 + 5)," +#map0 = #xla.indexing_map<"(d0) -> (d0 + 5)," "domain: d0 in [0, 63]"> -#map1 = #xla_gpu.indexing_map<"(d0)[s0] -> (d0 * 2 + s0)," +#map1 = #xla.indexing_map<"(d0)[s0] -> (d0 * 2 + s0)," "domain: d0 in [0, 63], s0 in [0, 1]"> module { func.func @apply_indexing_sequence(%arg0: tensor<128xf32>) -> (f32) { @@ -439,9 +439,9 @@ module { %c63 = arith.constant 63 : index %cst = arith.constant 0.0 : f32 %outer = scf.for %i = %c0 to %c63 step %c1 iter_args(%iter = %cst) -> f32 { - %offset = xla_gpu.apply_indexing #map0(%i) + %offset = xla.apply_indexing #map0(%i) %inner = scf.for %j = %c0 to %c2 step %c1 iter_args(%iter1 = %iter) -> f32 { - %idx = xla_gpu.apply_indexing #map1(%offset)[%j] + %idx = xla.apply_indexing #map1(%offset)[%j] %extracted = tensor.extract %arg0[%idx] : tensor<128xf32> %added = arith.addf %iter1, %extracted : f32 scf.yield %added : f32 @@ -452,18 +452,18 @@ module { } } -// CHECK: #[[$MAP0:.*]] = #xla_gpu.indexing_map<"(d0) -> (d0 * 2 + 10), +// CHECK: #[[$MAP0:.*]] = #xla.indexing_map<"(d0) -> (d0 * 2 + 10), // CHECK-SAME: domain: d0 in [0, 63]"> // CHECK-LABEL: @apply_indexing_sequence -// CHECK: %[[BASE:.*]] = xla_gpu.apply_indexing #[[$MAP0]] +// CHECK: %[[BASE:.*]] = xla.apply_indexing #[[$MAP0]] // CHECK: vector.transfer_read {{.*}}[%[[BASE]]] // ----- -#map0 = #xla_gpu.indexing_map<"(d0) -> (d0 + 5)," +#map0 = #xla.indexing_map<"(d0) -> (d0 + 5)," "domain: d0 in [0, 63]"> -#map1 = #xla_gpu.indexing_map<"(d0)[s0] -> (d0 * 2 + s0)," +#map1 = #xla.indexing_map<"(d0)[s0] -> (d0 * 2 + s0)," "domain: d0 in [0, 63], s0 in [0, 1]"> module { func.func @apply_indexing_sequence_same_block(%arg0: tensor<128xf32>) -> (f32) { @@ -476,8 +476,8 @@ module { %inner = scf.for %j = %c0 to %c2 step %c1 iter_args(%iter1 = %iter) -> f32 { // Usually, this will be hoisted by LICM or folded, so we do not detect // this pattern. 
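 // If %offset were hoisted out of this inner loop, #map0 and #map1 would
 // compose to a single map, (d0)[s0] -> ((d0 + 5) * 2 + s0) = (d0 * 2 + s0 + 10)
 // over d0 in [0, 63], s0 in [0, 1], and the read could become a
 // vector.transfer_read exactly as in @apply_indexing_sequence above (whose
 // CHECK shows the folded "(d0) -> (d0 * 2 + 10)" map); with both
 // apply_indexing ops kept in the same block, the tensor.extract below is
 // expected to stay a scalar load.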
- %offset = xla_gpu.apply_indexing #map0(%i) - %idx = xla_gpu.apply_indexing #map1(%offset)[%j] + %offset = xla.apply_indexing #map0(%i) + %idx = xla.apply_indexing #map1(%offset)[%j] %extracted = tensor.extract %arg0[%idx] : tensor<128xf32> %added = arith.addf %iter1, %extracted : f32 scf.yield %added : f32 diff --git a/third_party/xla/xla/service/gpu/fusions/triton/BUILD b/third_party/xla/xla/service/gpu/fusions/triton/BUILD index b28085692ad83a..ebcc8d2d2e4021 100644 --- a/third_party/xla/xla/service/gpu/fusions/triton/BUILD +++ b/third_party/xla/xla/service/gpu/fusions/triton/BUILD @@ -72,6 +72,7 @@ cc_library( ":passes", ":triton_fusion_emitter_legacy_matmul", ":triton_support", + ":xla_triton", ":xla_triton_passes", "//xla:autotuning_proto_cc", "//xla:permutation_util", @@ -80,6 +81,7 @@ cc_library( "//xla:util", "//xla:xla_data_proto_cc", "//xla:xla_proto_cc", + "//xla/codegen/ir:xla", "//xla/hlo/analysis:indexing_analysis", "//xla/hlo/ir:hlo", "//xla/hlo/translate/hlo_to_mhlo:hlo_function_importer", @@ -96,7 +98,6 @@ cc_library( "//xla/service/gpu/fusions/ir:xla_gpu", "//xla/service/gpu/fusions/mlir:elemental_hlo_to_mlir", "//xla/service/gpu/fusions/transforms:passes", - "//xla/service/gpu/fusions/triton:xla_triton", "//xla/service/gpu/model:symbolic_tile_analysis", "//xla/service/gpu/model:tiled_hlo_instruction_or_computation", "//xla/service/gpu/model:triton_emitter_constraints", diff --git a/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter.cc b/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter.cc index 8af747aedcda03..e0128f9d7a5cf1 100644 --- a/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter.cc +++ b/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter.cc @@ -81,6 +81,7 @@ limitations under the License. 
#include "mlir/Target/LLVMIR/Export.h" #include "mlir/Transforms/Passes.h" #include "xla/autotuning.pb.h" +#include "xla/codegen/ir/xla_ops.h" #include "xla/hlo/analysis/indexing_analysis.h" #include "xla/hlo/analysis/indexing_map.h" #include "xla/hlo/ir/hlo_casting_utils.h" @@ -1008,8 +1009,8 @@ void LoadMlirDialectsForTriton(mlir::MLIRContext& mlir_context) { mlir_context .loadDialect(); + mlir::LLVM::LLVMDialect, xla::XlaDialect, + xla::gpu::XlaGpuDialect, ttir::xla::XlaTritonDialect>(); mlir::DialectRegistry registry; mlir::func::registerInlinerExtension(registry); mlir::LLVM::registerInlinerInterface(registry); diff --git a/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter_device_test.cc b/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter_device_test.cc index 187c0442ed9314..209305d1c168b2 100644 --- a/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter_device_test.cc +++ b/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter_device_test.cc @@ -216,7 +216,7 @@ ENTRY main { "num_warps":"1"}}}})"; TF_EXPECT_OK(CreateTritonIrAndFileCheck(this, kHloText, "triton_softmax_computation", R"( -CHECK: #indexing_map = #xla_gpu.indexing_map<"(d0) -> (d0 * 127), domain: d0 in [0, 124]"> +CHECK: #indexing_map = #xla.indexing_map<"(d0) -> (d0 * 127), domain: d0 in [0, 124]"> CHECK: tt.func @triton_fn(%[[P0:[^:]*]]: !tt.ptr {tt.divisibility = 16 : i32}, %[[P1:[^:]*]]: !tt.ptr {tt.divisibility = 16 : i32}) { CHECK-DAG: %[[ZERO:.*]] = arith.constant 0 : i32 CHECK-DAG: %[[C125:.*]] = arith.constant 125 : i64 @@ -225,7 +225,7 @@ CHECK-DAG: %[[PID:.*]] = tt.get_program_id x : i32 CHECK-DAG: %[[PID_INDEX:.*]] = arith.index_castui %[[PID]] : i32 to index CHECK-DAG: %[[PID_I64:.*]] = arith.index_cast %[[PID_INDEX]] : index to i64 CHECK-DAG: %[[SUB:.*]] = arith.subi %[[C125]], %[[PID_I64]] : i64 -CHECK-DAG: %[[OFFSET_IDX:.*]] = xla_gpu.apply_indexing #indexing_map(%[[PID_INDEX]]) +CHECK-DAG: %[[OFFSET_IDX:.*]] = xla.apply_indexing #indexing_map(%[[PID_INDEX]]) CHECK-DAG: %[[OFFSET_I64:.*]] = arith.index_castui %[[OFFSET_IDX]] : index to i64 CHECK-DAG: %[[BASE_PTR_LOAD:.*]] = tt.addptr %[[P0]], %[[OFFSET_I64]] : !tt.ptr, i64 CHECK-DAG: tt.make_tensor_ptr %[[BASE_PTR_LOAD]], [%[[SUB]], %[[C127]]], {{.*}} [%[[ZERO]], %[[ZERO]]] {order = array} : > @@ -281,7 +281,7 @@ ENTRY main { "num_warps":"1"}}}})"; TF_EXPECT_OK(CreateTritonIrAndFileCheck(this, kHloText, "triton_softmax_computation", R"( -CHECK: #indexing_map = #xla_gpu.indexing_map<"(d0) -> (d0 * 127), domain: d0 in [0, 124]"> +CHECK: #indexing_map = #xla.indexing_map<"(d0) -> (d0 * 127), domain: d0 in [0, 124]"> CHECK: tt.func @triton_fn( CHECK-SAME: %[[P0:[A-Za-z0-9_]*]]: !tt.ptr CHECK-SAME: %[[P1:[A-Za-z0-9_]*]]: !tt.ptr @@ -294,7 +294,7 @@ CHECK-DAG: %[[PID:.*]] = tt.get_program_id x : i32 CHECK-DAG: %[[PID_INDEX:.*]] = arith.index_castui %[[PID]] : i32 to index CHECK-DAG: %[[PID_I64:.*]] = arith.index_cast %[[PID_INDEX]] : index to i64 CHECK-DAG: %[[SUB:.*]] = arith.subi %[[C125]], %[[PID_I64]] : i64 -CHECK-DAG: %[[OFFSET_IDX:.*]] = xla_gpu.apply_indexing #indexing_map(%[[PID_INDEX]]) +CHECK-DAG: %[[OFFSET_IDX:.*]] = xla.apply_indexing #indexing_map(%[[PID_INDEX]]) CHECK-DAG: %[[OFFSET_I64:.*]] = arith.index_castui %[[OFFSET_IDX]] : index to i64 CHECK-DAG: %[[BASE_PTR0_LOAD:.*]] = tt.addptr %[[P0]], %[[OFFSET_I64]] : !tt.ptr, i64 CHECK-DAG: tt.make_tensor_ptr %[[BASE_PTR0_LOAD]], [%[[SUB]], %[[C127]]], {{.*}} [%[[ZERO]], %[[ZERO]]] {order = array} : > @@ -352,9 +352,9 @@ ENTRY main { 
TF_EXPECT_OK(CreateTritonIrAndFileCheck(this, kHloText, "triton_softmax_computation", R"( -CHECK: #[[MAP:.*]] = #xla_gpu.indexing_map<"(d0) -> (d0 floordiv 125), domain: d0 in [0, 1249]"> -CHECK: #[[MAP1:.*]] = #xla_gpu.indexing_map<"(d0) -> (d0 mod 125), domain: d0 in [0, 1249]"> -CHECK: #[[MAP2:.*]] = #xla_gpu.indexing_map<"(d0) -> (d0 * 127), domain: d0 in [0, 1249]"> +CHECK: #[[MAP:.*]] = #xla.indexing_map<"(d0) -> (d0 floordiv 125), domain: d0 in [0, 1249]"> +CHECK: #[[MAP1:.*]] = #xla.indexing_map<"(d0) -> (d0 mod 125), domain: d0 in [0, 1249]"> +CHECK: #[[MAP2:.*]] = #xla.indexing_map<"(d0) -> (d0 * 127), domain: d0 in [0, 1249]"> CHECK: tt.func @triton_fn(%[[P0:[^:]*]]: !tt.ptr {tt.divisibility = 16 : i32}, %[[P1:[^:]*]]: !tt.ptr {tt.divisibility = 16 : i32}, %[[P2:[^:]*]]: !tt.ptr {tt.divisibility = 16 : i32}, %[[P3:[^:]*]]: !tt.ptr {tt.divisibility = 16 : i32}) { CHECK-DAG: %[[ZERO:.*]] = arith.constant 0 : i32 CHECK-DAG: %[[ZERO_64:.*]] = arith.constant 0 : i64 @@ -363,13 +363,13 @@ CHECK-DAG: %[[C125:.*]] = arith.constant 125 : i64 CHECK-DAG: %[[C127:.*]] = arith.constant 127 : i64 CHECK-DAG: %[[PID:.*]] = tt.get_program_id x : i32 CHECK-DAG: %[[PID_INDEX:.*]] = arith.index_castui %[[PID]] : i32 to index -CHECK-DAG: %[[ROW_INDEX:.*]] = xla_gpu.apply_indexing #[[MAP]](%[[PID_INDEX]] -CHECK-DAG: %[[COL_INDEX:.*]] = xla_gpu.apply_indexing #[[MAP1]](%[[PID_INDEX]] +CHECK-DAG: %[[ROW_INDEX:.*]] = xla.apply_indexing #[[MAP]](%[[PID_INDEX]] +CHECK-DAG: %[[COL_INDEX:.*]] = xla.apply_indexing #[[MAP1]](%[[PID_INDEX]] CHECK-DAG: %[[ROW_64:.*]] = arith.index_cast %[[ROW_INDEX]] : index to i64 CHECK-DAG: %[[COL_64:.*]] = arith.index_cast %[[COL_INDEX]] : index to i64 CHECK-DAG: %[[ROW_SUB:.*]] = arith.subi %[[C10]], %[[ROW_64]] : i64 CHECK-DAG: %[[COL_SUB:.*]] = arith.subi %[[C125]], %[[COL_64]] : i64 -CHECK-DAG: %[[OFFSET_IDX:.*]] = xla_gpu.apply_indexing #[[MAP2]](%[[PID_INDEX]]) +CHECK-DAG: %[[OFFSET_IDX:.*]] = xla.apply_indexing #[[MAP2]](%[[PID_INDEX]]) CHECK-DAG: %[[OFFSET_I64:.*]] = arith.index_castui %[[OFFSET_IDX]] : index to i64 CHECK-DAG: %[[BASE_PTR0_LOAD:.*]] = tt.addptr %[[P0]], %[[OFFSET_I64]] : !tt.ptr, i64 CHECK-DAG: tt.make_tensor_ptr %[[BASE_PTR0_LOAD]], [%[[ROW_SUB]], %[[COL_SUB]], %[[C127]]], {{.*}} [%[[ZERO]], %[[ZERO]], %[[ZERO]]] {order = array} : > @@ -545,8 +545,8 @@ ENTRY main { TF_ASSERT_OK(CreateTritonIrAndFileCheck(this, kHloText, "triton_softmax_computation", R"( -// CHECK: #xla_gpu.indexing_map<"(d0) -> (d0 floordiv 32), domain: d0 in [0, 2047]"> -// CHECK: #xla_gpu.indexing_map<"(d0) -> (d0 mod 32), domain: d0 in [0, 2047]"> +// CHECK: #xla.indexing_map<"(d0) -> (d0 floordiv 32), domain: d0 in [0, 2047]"> +// CHECK: #xla.indexing_map<"(d0) -> (d0 mod 32), domain: d0 in [0, 2047]"> // CHECK-LABEL: tt.func @triton_fn( // CHECK-SAME: %[[P0:[A-Za-z0-9_]*]]: !tt.ptr // CHECK-SAME: %[[P1:[A-Za-z0-9_]*]]: !tt.ptr