diff --git a/src/targets/gpu/fuse_mlir.cpp b/src/targets/gpu/fuse_mlir.cpp index 65a27a76ad1..abc2f41198e 100644 --- a/src/targets/gpu/fuse_mlir.cpp +++ b/src/targets/gpu/fuse_mlir.cpp @@ -234,6 +234,24 @@ get_fusable_input_op_stream(instruction_ref lower_input) return {upper_input, op_stream}; } +std::tuple> +get_fusable_output_op_stream(instruction_ref upper_output) +{ + instruction_ref lower_output = upper_output; + std::vector op_stream; + while(contains(reshaper_names(), lower_output->name())) + { + operation op = lower_output->get_operator(); + op_stream.push_back(op); + + if(lower_output->outputs().size() > 1) + break; + + lower_output = lower_output->outputs().at(0); + } + return {lower_output, op_stream}; +} + void fuse_input_ops(module_ref mm, const std::vector& inputs, std::unordered_map* map_ins) @@ -257,6 +275,20 @@ void fuse_input_ops(module_ref mm, } } +void fuse_output_ops(module_ref mm, + const std::vector &outputs) +{ + size_t input_cnt = mm->get_parameters().size(); + for(instruction_ref output : outputs) + { + auto [lower_output, op_stream] = get_fusable_output_op_stream(output); + for(const auto& op : (op_stream)) + { + lower_output = mm->add_instruction(op, lower_output->inputs()); + } + } +} + std::tuple> fuse_input_ops_and_gemm_based_op(module_ref mm, const std::vector& gemm_based_op_inputs, @@ -1040,6 +1072,37 @@ struct find_unpack_int4_mlir_op } }; +struct find_mlir_single_output_reshaper_op +{ + auto matcher() const + { + auto reshapes = reshaper_names(); + return match::name("gpu::mlir_op")(match::any_of[match::outputs()](match::name(reshapes).bind("output_reshaper"))); + } + + void apply(module_pass_manager& mpm, const match::matcher_result& r) const + { + static int find_count = 0; + auto ins = r.result; + + if (ins->outputs().size() > 1) + return; + + auto out_op = r.instructions["output_reshaper"]; + auto* mlir_module = ins->module_inputs().front(); + std::cout << "Found :" << find_count++ << " reshapers" << std::endl; + + mpm.get_module().debug_print(ins); + mpm.get_module().debug_print(out_op); + mlir_module->debug_print(); + + + fuse_output_ops(mlir_module, out_op>outputs()); + + mlir_module->debug_print(); + } +}; + } // namespace #endif // MIGRAPHX_MLIR @@ -1092,6 +1155,7 @@ void fuse_mlir::apply(module_pass_manager& mpm) const match::find_matches(mpm, find_pointwise_mlir{}); match::find_matches(mpm, find_unpack_int4_mlir_op{}); + match::find_matches(mpm, find_mlir_single_output_reshaper_op{}); #else (void)mpm;