Skip to content

Commit

Permalink
Intial commit to target all reshaper opsops after mlir_op TODO get th…
Browse files Browse the repository at this point in the history
…is to fuse for outputs
  • Loading branch information
TedThemistokleous committed Jan 10, 2025
1 parent 9bc94ab commit 000d5a4
Showing 1 changed file with 64 additions and 0 deletions.
64 changes: 64 additions & 0 deletions src/targets/gpu/fuse_mlir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,24 @@ get_fusable_input_op_stream(instruction_ref lower_input)
return {upper_input, op_stream};
}

std::tuple<instruction_ref, std::vector<operation>>
get_fusable_output_op_stream(instruction_ref upper_output)
{
instruction_ref lower_output = upper_output;
std::vector<operation> op_stream;
while(contains(reshaper_names(), lower_output->name()))
{
operation op = lower_output->get_operator();
op_stream.push_back(op);

if(lower_output->outputs().size() > 1)
break;

lower_output = lower_output->outputs().at(0);
}
return {lower_output, op_stream};
}

void fuse_input_ops(module_ref mm,
const std::vector<instruction_ref>& inputs,
std::unordered_map<instruction_ref, instruction_ref>* map_ins)
Expand All @@ -257,6 +275,20 @@ void fuse_input_ops(module_ref mm,
}
}

void fuse_output_ops(module_ref mm,
const std::vector<instruction_ref> &outputs)
{
size_t input_cnt = mm->get_parameters().size();

Check warning on line 281 in src/targets/gpu/fuse_mlir.cpp

View workflow job for this annotation

GitHub Actions / tidy

unused variable 'input_cnt' [clang-diagnostic-unused-variable,-warnings-as-errors]

Check warning on line 281 in src/targets/gpu/fuse_mlir.cpp

View workflow job for this annotation

GitHub Actions / cppcheck

style: Variable 'input_cnt' is assigned a value that is never used. [unreadVariable]
for(instruction_ref output : outputs)
{
auto [lower_output, op_stream] = get_fusable_output_op_stream(output);
for(const auto& op : (op_stream))
{
lower_output = mm->add_instruction(op, lower_output->inputs());
}
}
}

std::tuple<instruction_ref, std::vector<instruction_ref>>
fuse_input_ops_and_gemm_based_op(module_ref mm,
const std::vector<instruction_ref>& gemm_based_op_inputs,
Expand Down Expand Up @@ -1040,6 +1072,37 @@ struct find_unpack_int4_mlir_op
}
};

struct find_mlir_single_output_reshaper_op
{
auto matcher() const
{
auto reshapes = reshaper_names();

Check warning on line 1079 in src/targets/gpu/fuse_mlir.cpp

View workflow job for this annotation

GitHub Actions / tidy

the variable 'reshapes' is copy-constructed from a const reference but is only used as const reference; consider making it a const reference [performance-unnecessary-copy-initialization,-warnings-as-errors]
return match::name("gpu::mlir_op")(match::any_of[match::outputs()](match::name(reshapes).bind("output_reshaper")));
}

void apply(module_pass_manager& mpm, const match::matcher_result& r) const
{
static int find_count = 0;
auto ins = r.result;

if (ins->outputs().size() > 1)
return;

auto out_op = r.instructions["output_reshaper"];
auto* mlir_module = ins->module_inputs().front();
std::cout << "Found :" << find_count++ << " reshapers" << std::endl;

mpm.get_module().debug_print(ins);
mpm.get_module().debug_print(out_op);
mlir_module->debug_print();


fuse_output_ops(mlir_module, out_op>outputs());

Check warning on line 1100 in src/targets/gpu/fuse_mlir.cpp

View workflow job for this annotation

GitHub Actions / tidy

invalid operands to binary expression ('instruction_ref' (aka '_List_iterator<migraphx::instruction>') and '(lambda at src/include/migraphx/matcher.hpp:587:12)') [clang-diagnostic-error]

Check warning on line 1100 in src/targets/gpu/fuse_mlir.cpp

View workflow job for this annotation

GitHub Actions / tidy

use of undeclared identifier 'outputs'; did you mean 'match::outputs'? [clang-diagnostic-error]

mlir_module->debug_print();
}
};

} // namespace

#endif // MIGRAPHX_MLIR
Expand Down Expand Up @@ -1092,6 +1155,7 @@ void fuse_mlir::apply(module_pass_manager& mpm) const

match::find_matches(mpm, find_pointwise_mlir{});
match::find_matches(mpm, find_unpack_int4_mlir_op{});
match::find_matches(mpm, find_mlir_single_output_reshaper_op{});

#else
(void)mpm;
Expand Down

0 comments on commit 000d5a4

Please sign in to comment.