Fuse reshapes on pointwise inputs for mlir output fusion #3569

Status: Open. Wants to merge 24 commits into base branch develop.
126 changes: 59 additions & 67 deletions src/targets/gpu/fuse_mlir.cpp
@@ -205,7 +205,6 @@ const auto& reshaper_names()
 {
     // clang-format off
     static const std::unordered_set<std::string> names = {
-        "slice",
         "transpose",
         "multibroadcast",
         "broadcast",
@@ -220,12 +219,17 @@ const auto& reshaper_names()
     return names;
 }
 
+bool is_fusable_input_op(const std::string& name)
+{
+    return contains(reshaper_names(), name) or contains({"slice"}, name);
+}
+
 std::tuple<instruction_ref, std::vector<operation>>
 get_fusable_input_op_stream(instruction_ref lower_input)
 {
     instruction_ref upper_input = lower_input;
     std::vector<operation> op_stream;
-    while(contains(reshaper_names(), upper_input->name()))
+    while(is_fusable_input_op(upper_input->name()))
     {
         operation op = upper_input->get_operator();
         op_stream.push_back(op);
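
For readers of this hunk, a minimal self-contained sketch of the walk that is_fusable_input_op now gates. The node type is a hypothetical stand-in for migraphx's instruction_ref, and the reshaper list is abbreviated; this is an illustration, not the real API.

#include <cassert>
#include <string>
#include <unordered_set>
#include <utility>
#include <vector>

struct node
{
    std::string name;
    node* input = nullptr; // single producer, like inputs().front()
};

bool is_fusable_input_op(const std::string& name)
{
    // abbreviated reshaper list for the sketch
    static const std::unordered_set<std::string> reshapers = {
        "transpose", "multibroadcast", "broadcast", "contiguous", "reshape", "squeeze"};
    return reshapers.count(name) > 0 or name == "slice"; // slice is now fusable here
}

// Mirrors get_fusable_input_op_stream: climb producers, recording each fusable
// op, and stop at the first non-fusable instruction (the gemm/conv anchor).
std::pair<node*, std::vector<std::string>> get_fusable_stream(node* lower)
{
    std::vector<std::string> ops;
    node* upper = lower;
    while(is_fusable_input_op(upper->name) and upper->input != nullptr)
    {
        ops.push_back(upper->name);
        upper = upper->input;
    }
    return {upper, ops};
}

int main()
{
    node dot{"dot"};
    node slice{"slice", &dot};
    node transpose{"transpose", &slice};
    auto [anchor, ops] = get_fusable_stream(&transpose);
    assert(anchor->name == "dot");
    assert((ops == std::vector<std::string>{"transpose", "slice"}));
}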
@@ -364,6 +368,18 @@ create_param_map_with_literals(module_ref mm, const module* pm, const shape& sha
     return ins_map;
 }
 
+instruction_ref insert_pointwise(module& m,
+                                 instruction_ref ins,
+                                 const operation& op,
+                                 const std::vector<instruction_ref>& inputs,
+                                 const std::vector<module_ref>& mod_args)
+{
+    // Only used in assert
+    (void)mod_args;
+    assert(mod_args.empty());
+    return insert_common_op(m, ins, op, inputs, {.common_type = false});
+}
+
 instruction_ref unroll_pointwise(module& main_mod,
                                  instruction_ref pos,
                                  const operation& op,
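
insert_pointwise is hoisted to file scope here (it previously lived inside find_pointwise_mlir, removed further down) so find_mlir_fused_ops can also pass it to module::fuse. A rough sketch of that inserter-callback pattern, using simplified stand-in types rather than the real migraphx signatures:

#include <cassert>
#include <functional>
#include <string>
#include <vector>

struct op
{
    std::string name;
};
using module_t = std::vector<std::string>; // toy module: a list of op names
using inserter = std::function<std::string(module_t&, const op&)>;

// Analogue of module::fuse taking a custom inserter, letting the caller
// decide how each copied instruction is materialized in the target module.
std::vector<std::string> fuse(module_t& dst, const std::vector<op>& src, const inserter& ins)
{
    std::vector<std::string> results;
    for(const auto& o : src)
        results.push_back(ins(dst, o));
    return results;
}

int main()
{
    module_t mm;
    // Plays the role of insert_pointwise: add the op as-is, without the
    // implicit conversions a common_type=true insertion would introduce.
    auto insert_pointwise = [](module_t& m, const op& o) {
        m.push_back(o.name);
        return m.back();
    };
    auto rins = fuse(mm, {{"add"}, {"mul"}}, insert_pointwise);
    assert(rins.size() == 2);
    assert(mm.front() == "add");
}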
@@ -501,9 +517,7 @@ MIGRAPHX_PRED_MATCHER(mlir_split_reduce, instruction_ref ins)
 {
     if(ins->name() != "split_fused_reduce")
         return false;
-    auto* mod_arg           = ins->module_inputs().front();
-    auto supported_reshapes = reshaper_names();
-    supported_reshapes.erase("slice");
+    auto* mod_arg = ins->module_inputs().front();
     std::unordered_set<std::string> builtins = {"@param", "@literal", "@return"};
     for(const auto i : iterator_for(*mod_arg))
     {
@@ -629,10 +643,7 @@ struct find_mlir_fused_ops
     mlir_mode dot_mode = mlir_mode::none;
     auto matcher() const
     {
-        auto reshapes = reshaper_names();
-        // slice is not supported
-        reshapes.erase("slice");
-        auto dot_or_conv = match::skip(match::name(reshapes))(
+        auto dot_or_conv = match::skip(match::name(reshaper_names()))(
             match::any_of(is_mlir_dot(dot_mode), is_mlir_conv(conv_mode)).bind("gemm_based_op"));
         return mlir_pointwise()(match::any_of[match::inputs()](dot_or_conv.bind("x")));
     }
@@ -650,68 +661,62 @@ struct find_mlir_fused_ops
                return i != x_ins and reaches(gemm_based_op, i);
            }))
             return;
-        auto names = pm->get_parameter_names();
-        std::sort(names.begin(), names.end());
-
+        std::unordered_map<instruction_ref, instruction_ref> map_ins;
         module_ref mm = mpm.create_module("mlir_" + pm->name());
         mm->set_bypass();
-        auto [anchor_op, top_inputs] = fuse_input_ops_and_gemm_based_op(
-            mm, gemm_based_op->inputs(), gemm_based_op->get_operator());
-        std::unordered_map<instruction_ref, instruction_ref> param_map =
-            create_param_map_with_literals(mm, pm, pw_ins->get_shape());
-        auto [upper_input, op_stream] = get_fusable_input_op_stream(x_ins);
-        assert(upper_input == gemm_based_op);
-        auto prev_input = anchor_op;
-        for(const auto& op : reverse(op_stream))
-        {
-            prev_input = mm->add_instruction(op, {prev_input});
-        }
-        assert(prev_input->get_shape().lens() == x_ins->get_shape().lens());
-        param_map[x_ins] = prev_input; // this is to avoid adding parameter for gemm/conv reshaped
-                                       // input to pointwise in new fused module
+        fuse_input_ops(mm, gemm_based_op->inputs(), &map_ins);
 
         bool gemm_has_multi_outs = gemm_based_op->outputs().size() > 1;
Comment on lines +669 to 670
Contributor:
Might be tangential to this PR, but what happens if one of the reshape ops has multiple outputs? I.e., something like:

dot -> reshape -> pointwise -> .... -> dot
          |_____________________________|

@pfultz2 (Collaborator, Author) commented on Dec 10, 2024:

The loop below checks that all the reshapes between dot and pointwise have only one output, and it will output both if any reshape is used more than once. For the aux inputs to pointwise, it doesn't really matter:

┌───┐┌─┐      
│dot││x│      
└┬──┘└┬┘      
 │┌───▽──────┐
 ││reshape   │
 │└┬───────┬─┘
┌▽─▽──────┐│  
│pointwise││  
└┬────────┘│  
┌▽─────────▽┐ 
│convolution│ 
└───────────┘ 

After mlir fusion it just becomes:

┌──────────┐                         
│x         │                         
└┬────────┬┘                         
┌▽──────┐┌▽─────────────────────────┐
│reshape││mlir_dot_reshape_pointwise│
└┬──────┘└┬─────────────────────────┘
┌▽────────▽─┐                        
│convolution│                        
└───────────┘ 

DCE won't remove the reshape since it's used, and the duplicated reshape doesn't really matter since it is potentially an aliasing-like operator.
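
A toy sketch of the check described above, mirroring the loop that follows in the diff. The node type is a hypothetical stand-in (not the migraphx API): each node knows its single producer and how many consumers it has.

#include <cassert>
#include <cstddef>

struct node
{
    std::size_t outputs = 1; // number of consumers
    node* input         = nullptr; // single producer
};

int main()
{
    node dot;              // the gemm: one consumer (the reshape)
    node reshape{2, &dot}; // feeds both pointwise and convolution, as in the diagram

    node* x_ins              = &reshape; // the pointwise input, top of the chain
    bool gemm_has_multi_outs = dot.outputs > 1;
    for(node* r = x_ins; r != &dot; r = r->input)
        gemm_has_multi_outs = gemm_has_multi_outs or (r->outputs > 1);

    // The reshape has a second consumer (the convolution), so the fused
    // module must also return the dot result.
    assert(gemm_has_multi_outs);
}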

-        auto reshaped_gemm = x_ins;
-        std::vector<instruction_ref> reshapes_vec;
-        while(reshaped_gemm != gemm_based_op)
+        std::vector<instruction_ref> inss_to_insert;
+        auto reshape_ins = x_ins;
+        for(; reshape_ins != gemm_based_op; reshape_ins = reshape_ins->inputs().front())
         {
-            reshapes_vec.push_back(reshaped_gemm);
-            gemm_has_multi_outs = gemm_has_multi_outs or reshaped_gemm->outputs().size() > 1;
-            reshaped_gemm = reshaped_gemm->inputs().at(0);
+            inss_to_insert.push_back(reshape_ins);
+            gemm_has_multi_outs |= reshape_ins->outputs().size() > 1;
         }
-        reshapes_vec.push_back(reshaped_gemm);
+        inss_to_insert.push_back(gemm_based_op);
+        std::reverse(inss_to_insert.begin(), inss_to_insert.end());
+        mm->add_instructions(inss_to_insert, &map_ins);
 
-        auto return_vals = mm->fuse(*pm, pw_ins->inputs(), &param_map);
+        fuse_input_ops(mm, pw_ins->inputs(), &map_ins);
+        auto rins = mm->fuse(*pm, pw_ins->inputs(), &map_ins, &insert_pointwise);
         if(gemm_has_multi_outs)
         {
-            return_vals.insert(return_vals.begin(), anchor_op);
+            rins.push_back(map_ins.at(gemm_based_op));
         }
-        mm->add_return(return_vals);
+        mm->add_return(rins);
 
-        std::vector<instruction_ref> inputs;
-        std::copy_if(pw_ins->inputs().begin(),
-                     pw_ins->inputs().end(),
-                     std::back_inserter(inputs),
-                     [&](auto input) { return input != x_ins; });
-        inputs.insert(inputs.end(), top_inputs.begin(), top_inputs.end());
+        auto inputs = find_inputs(map_ins, &mpm.get_module(), mm);
+        auto fused_ins = mpm.get_module().insert_instruction(
+            pw_ins, mlir_op{gemm_based_op->get_operator()}, mlir_contiguous(mpm, inputs), {mm});
         if(gemm_has_multi_outs)
         {
-            auto fused_ins = mpm.get_module().insert_instruction(
-                pw_ins, mlir_op{gemm_based_op->get_operator()}, mlir_contiguous(mpm, inputs), {mm});
-            mpm.get_module().replace_instruction(
-                pw_ins, migraphx::make_op("get_tuple_elem", {{"index", 1}}), fused_ins);
             auto dot_ins = mpm.get_module().insert_instruction(
-                pw_ins, migraphx::make_op("get_tuple_elem", {{"index", 0}}), fused_ins);
-            // move all the reshape instructions and original GEMM instruction after the fused op to
-            // avoid generating invalid migraphx program
-            for(const auto& orig_i : reverse(reshapes_vec))
+                pw_ins,
+                migraphx::make_op("get_tuple_elem", {{"index", rins.size() - 1}}),
+                fused_ins);
+
+            // move all the reshape instructions after the fused op to avoid
+            // generating invalid migraphx program since the reshapes can be
+            // used by the replaced dot_ins
+            for(instruction_ref x : inss_to_insert)
             {
-                mpm.get_module().move_instruction(orig_i, pw_ins);
+                if(x == gemm_based_op)
+                    continue;
+                mpm.get_module().move_instruction(x, pw_ins);
             }
+
             mpm.get_module().replace_instruction(gemm_based_op, dot_ins);
+            if(rins.size() == 2)
+            {
+                mpm.get_module().replace_instruction(
+                    pw_ins, migraphx::make_op("get_tuple_elem", {{"index", 0}}), fused_ins);
+            }
         }
         else
         {
-            mpm.get_module().replace_instruction(
-                pw_ins, mlir_op{gemm_based_op->get_operator()}, mlir_contiguous(mpm, inputs), {mm});
+            mpm.get_module().replace_instruction(pw_ins, fused_ins);
         }
     }
};
};
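
The tuple rewiring in this hunk relies on an ordering assumption visible in the diff: mm->fuse returns the pointwise output(s) first, and the gemm result is appended last, so the gemm's consumers read index rins.size() - 1 while the pointwise consumer reads index 0 when there are exactly two returns. A toy sketch of just that bookkeeping (plain ints standing in for instructions):

#include <cassert>
#include <vector>

int main()
{
    // rins as produced by mm->fuse: the pointwise output comes first.
    std::vector<int> rins = {42}; // 42 stands for the pointwise result

    bool gemm_has_multi_outs = true; // the gemm has other consumers
    if(gemm_has_multi_outs)
        rins.push_back(7); // 7 stands for the gemm result, appended last

    // get_tuple_elem(index = rins.size() - 1) rewires the gemm's consumers...
    assert(rins.at(rins.size() - 1) == 7);
    // ...and when there are exactly two returns, the pointwise consumer
    // reads index 0 (the rins.size() == 2 branch in the diff).
    if(rins.size() == 2)
        assert(rins.at(0) == 42);
}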
@@ -851,9 +856,8 @@ struct find_mlir_standalone_attention_op
         map_main_to_mattn[fused_reduce] = softmax;
 
         // all preceeding ops should be fusable ops
-        if(not std::all_of(m_gemm1, softmax, [](auto i) {
-               return (is_pointwise_op_supported_by_mlir(i) or
-                       contains(reshaper_names(), i.name()));
+        if(not std::all_of(m_gemm1, softmax, [](const instruction& i) {
+               return (is_pointwise_op_supported_by_mlir(i) or is_fusable_input_op(i.name()));
            }))
             return;
 
@@ -938,18 +942,6 @@ struct find_pointwise_mlir
         return contains(op_names, op_ins->name());
     }
 
-    static instruction_ref insert_pointwise(module& m,
-                                            instruction_ref ins,
-                                            const operation& op,
-                                            const std::vector<instruction_ref>& inputs,
-                                            const std::vector<module_ref>& mod_args)
-    {
-        // Only used in assert
-        (void)mod_args;
-        assert(mod_args.empty());
-        return insert_common_op(m, ins, op, inputs, {.common_type = false});
-    }
-
     void apply(module_pass_manager& mpm, const match::matcher_result& r) const
     {
         auto ins = r.result;