Skip to content

Commit

Permalink
Loop Partitioning Policy through Stage::partition(VarOrRVar, LoopPart…
Browse files Browse the repository at this point in the history
…itionPolicy) (#7914)

* Loop Partitioning Policy through Stage::partition(VarOrRVar, LoopPartitionPolicy)

* Renamed LoopPartitionPolicy to Partition. Added tests in boundary_conditions to verify correctness of the code with and without loop partitioning. Added tests that validates that disabling loop partitioning works.

* Include error-test for when partitioning is always requested, but none was performed.
  • Loading branch information
mcourteaux authored Oct 31, 2023
1 parent 0134c40 commit 1865101
Show file tree
Hide file tree
Showing 43 changed files with 354 additions and 115 deletions.
2 changes: 1 addition & 1 deletion src/AsyncProducers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ class NoOpCollapsingMutator : public IRMutator {
if (is_no_op(body)) {
return body;
} else {
return For::make(op->name, op->min, op->extent, op->for_type, op->device_api, body);
return For::make(op->name, op->min, op->extent, op->for_type, op->partition_policy, op->device_api, body);
}
}

Expand Down
2 changes: 1 addition & 1 deletion src/Bounds.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3720,7 +3720,7 @@ void bounds_test() {
Buffer<int32_t> in(10);
in.set_name("input");

Stmt loop = For::make("x", 3, 10, ForType::Serial, DeviceAPI::Host,
Stmt loop = For::make("x", 3, 10, ForType::Serial, Partition::Auto, DeviceAPI::Host,
Provide::make("output",
{Add::make(Call::make(in, input_site_1),
Call::make(in, input_site_2))},
Expand Down
4 changes: 2 additions & 2 deletions src/BoundsInference.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1318,7 +1318,7 @@ class BoundsInference : public IRMutator {
}
}

return For::make(op->name, op->min, op->extent, op->for_type, op->device_api, body);
return For::make(op->name, op->min, op->extent, op->for_type, op->partition_policy, op->device_api, body);
}

Scope<> let_vars_in_scope;
Expand Down Expand Up @@ -1389,7 +1389,7 @@ Stmt bounds_inference(Stmt s,
s = Block::make(Evaluate::make(marker), s);

// Add a synthetic outermost loop to act as 'root'.
s = For::make("<outermost>", 0, 1, ForType::Serial, DeviceAPI::None, s);
s = For::make("<outermost>", 0, 1, ForType::Serial, Partition::Never, DeviceAPI::None, s);

s = BoundsInference(funcs, fused_func_groups, fused_pairs_in_groups,
outputs, func_bounds, target)
Expand Down
2 changes: 1 addition & 1 deletion src/CanonicalizeGPUVars.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ class CanonicalizeGPUVars : public IRMutator {
body.same_as(op->body)) {
return op;
} else {
return For::make(name, min, extent, op->for_type, op->device_api, body);
return For::make(name, min, extent, op->for_type, op->partition_policy, op->device_api, body);
}
}

Expand Down
4 changes: 2 additions & 2 deletions src/CodeGen_Hexagon.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -362,7 +362,7 @@ class InjectHVXLocks : public IRMutator {
body = acquire_hvx_context(body, target);
body = substitute("uses_hvx", true, body);
Stmt new_for = For::make(op->name, op->min, op->extent, op->for_type,
op->device_api, body);
op->partition_policy, op->device_api, body);
Stmt prolog =
IfThenElse::make(uses_hvx_var, call_halide_qurt_hvx_unlock());
Stmt epilog =
Expand Down Expand Up @@ -407,7 +407,7 @@ class InjectHVXLocks : public IRMutator {
// halide_qurt_unlock
// }
s = For::make(op->name, op->min, op->extent, op->for_type,
op->device_api, body);
op->partition_policy, op->device_api, body);
}

uses_hvx = old_uses_hvx;
Expand Down
19 changes: 18 additions & 1 deletion src/Deserialization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,8 @@ class Deserializer {

DeviceAPI deserialize_device_api(Serialize::DeviceAPI device_api);

Partition deserialize_partition(Serialize::Partition partition);

Call::CallType deserialize_call_type(Serialize::CallType call_type);

VectorReduce::Operator deserialize_vector_reduce_op(Serialize::VectorReduceOp vector_reduce_op);
Expand Down Expand Up @@ -201,6 +203,20 @@ ForType Deserializer::deserialize_for_type(Serialize::ForType for_type) {
}
}

Partition Deserializer::deserialize_partition(Serialize::Partition partition) {
switch (partition) {
case Serialize::Partition::Partition_Auto:
return Halide::Partition::Auto;
case Serialize::Partition::Partition_Never:
return Halide::Partition::Never;
case Serialize::Partition::Partition_Always:
return Halide::Partition::Always;
default:
user_error << "unknown loop partition policy " << partition << "\n";
return Halide::Partition::Auto;
}
}

DeviceAPI Deserializer::deserialize_device_api(Serialize::DeviceAPI device_api) {
switch (device_api) {
case Serialize::DeviceAPI::DeviceAPI_None:
Expand Down Expand Up @@ -505,9 +521,10 @@ Stmt Deserializer::deserialize_stmt(Serialize::Stmt type_code, const void *stmt)
const auto min = deserialize_expr(for_stmt->min_type(), for_stmt->min());
const auto extent = deserialize_expr(for_stmt->extent_type(), for_stmt->extent());
const ForType for_type = deserialize_for_type(for_stmt->for_type());
const Partition partition_policy = deserialize_partition(for_stmt->partition_policy());
const DeviceAPI device_api = deserialize_device_api(for_stmt->device_api());
const auto body = deserialize_stmt(for_stmt->body_type(), for_stmt->body());
return For::make(name, min, extent, for_type, device_api, body);
return For::make(name, min, extent, for_type, partition_policy, device_api, body);
}
case Serialize::Stmt_Store: {
const auto *store_stmt = (const Serialize::Store *)stmt;
Expand Down
26 changes: 25 additions & 1 deletion src/Func.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -941,7 +941,7 @@ Func Stage::rfactor(vector<pair<RVar, Var>> preserved) {
const auto &iter = std::find_if(dims.begin(), dims.end(),
[&v](const Dim &dim) { return var_name_match(dim.var, v.name()); });
if (iter == dims.end()) {
Dim d = {v.name(), ForType::Serial, DeviceAPI::None, DimType::PureVar};
Dim d = {v.name(), ForType::Serial, DeviceAPI::None, DimType::PureVar, Partition::Auto};
dims.insert(dims.end() - 1, d);
}
}
Expand Down Expand Up @@ -1631,6 +1631,24 @@ Stage &Stage::unroll(const VarOrRVar &var, const Expr &factor, TailStrategy tail
return *this;
}

Stage &Stage::partition(const VarOrRVar &var, Partition policy) {
definition.schedule().touched() = true;
bool found = false;
vector<Dim> &dims = definition.schedule().dims();
for (auto &dim : dims) {
if (var_name_match(dim.var, var.name())) {
found = true;
dim.partition_policy = policy;
}
}
user_assert(found)
<< "In schedule for " << name()
<< ", could not find var " << var.name()
<< " to set loop partition policy.\n"
<< dump_argument_list();
return *this;
}

Stage &Stage::tile(const VarOrRVar &x, const VarOrRVar &y,
const VarOrRVar &xo, const VarOrRVar &yo,
const VarOrRVar &xi, const VarOrRVar &yi,
Expand Down Expand Up @@ -2318,6 +2336,12 @@ Func &Func::unroll(const VarOrRVar &var, const Expr &factor, TailStrategy tail)
return *this;
}

Func &Func::partition(const VarOrRVar &var, Partition policy) {
invalidate_cache();
Stage(func, func.definition(), 0).partition(var, policy);
return *this;
}

Func &Func::bound(const Var &var, Expr min, Expr extent) {
user_assert(!min.defined() || Int(32).can_represent(min.type())) << "Can't represent min bound in int32\n";
user_assert(extent.defined()) << "Extent bound of a Func can't be undefined\n";
Expand Down
8 changes: 8 additions & 0 deletions src/Func.h
Original file line number Diff line number Diff line change
Expand Up @@ -348,6 +348,7 @@ class Stage {
Stage &parallel(const VarOrRVar &var, const Expr &task_size, TailStrategy tail = TailStrategy::Auto);
Stage &vectorize(const VarOrRVar &var, const Expr &factor, TailStrategy tail = TailStrategy::Auto);
Stage &unroll(const VarOrRVar &var, const Expr &factor, TailStrategy tail = TailStrategy::Auto);
Stage &partition(const VarOrRVar &var, Partition partition_policy);
Stage &tile(const VarOrRVar &x, const VarOrRVar &y,
const VarOrRVar &xo, const VarOrRVar &yo,
const VarOrRVar &xi, const VarOrRVar &yi, const Expr &xfactor, const Expr &yfactor,
Expand Down Expand Up @@ -1442,6 +1443,13 @@ class Func {
* dimension of the split. 'factor' must be an integer. */
Func &unroll(const VarOrRVar &var, const Expr &factor, TailStrategy tail = TailStrategy::Auto);

/** Set the loop partition policy. Loop partitioning can be useful to
* optimize boundary conditions (such as clamp_edge). Loop partitioning
* splits a for loop into three for loops: a prologue, a steady-state,
* and an epilogue.
* The default policy is Auto. */
Func &partition(const VarOrRVar &var, Partition partition_policy);

/** Statically declare that the range over which a function should
* be evaluated is given by the second and third arguments. This
* can let Halide perform some optimizations. E.g. if you know
Expand Down
21 changes: 11 additions & 10 deletions src/FuseGPUThreadLoops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ class NormalizeDimensionality : public IRMutator {
}
while (max_depth < block_size.threads_dimensions()) {
string name = thread_names[max_depth];
s = For::make("." + name, 0, 1, ForType::GPUThread, device_api, s);
s = For::make("." + name, 0, 1, ForType::GPUThread, Partition::Never, device_api, s);
max_depth++;
}
return s;
Expand Down Expand Up @@ -398,7 +398,7 @@ class ExtractSharedAndHeapAllocations : public IRMutator {
Expr v = Variable::make(Int(32), loop_name);
host_side_preamble = substitute(op->name, v, host_side_preamble);
host_side_preamble = For::make(loop_name, new_min, new_extent,
ForType::Serial, DeviceAPI::None, host_side_preamble);
ForType::Serial, Partition::Never, DeviceAPI::None, host_side_preamble);
if (old_preamble.defined()) {
host_side_preamble = Block::make(old_preamble, host_side_preamble);
}
Expand All @@ -407,7 +407,8 @@ class ExtractSharedAndHeapAllocations : public IRMutator {
}

return For::make(op->name, new_min, new_extent,
op->for_type, op->device_api, body);
op->for_type, op->partition_policy,
op->device_api, body);
}

Stmt visit(const Block *op) override {
Expand Down Expand Up @@ -1101,7 +1102,7 @@ class ExtractRegisterAllocations : public IRMutator {
allocations.swap(old);
}

return For::make(op->name, mutate(op->min), mutate(op->extent), op->for_type, op->device_api, body);
return For::make(op->name, mutate(op->min), mutate(op->extent), op->for_type, op->partition_policy, op->device_api, body);
}
}

Expand Down Expand Up @@ -1262,7 +1263,7 @@ class InjectThreadBarriers : public IRMutator {
body = Block::make(body, make_barrier(0));
}
return For::make(op->name, op->min, op->extent,
op->for_type, op->device_api, body);
op->for_type, op->partition_policy, op->device_api, body);
} else {
return IRMutator::visit(op);
}
Expand Down Expand Up @@ -1413,14 +1414,14 @@ class FuseGPUThreadLoopsSingleKernel : public IRMutator {
string thread_id = "." + thread_names[0];
// Add back in any register-level allocations
body = register_allocs.rewrap(body, thread_id);
body = For::make(thread_id, 0, block_size_x, innermost_loop_type, op->device_api, body);
body = For::make(thread_id, 0, block_size_x, innermost_loop_type, op->partition_policy, op->device_api, body);

// Rewrap the whole thing in other loops over threads
for (int i = 1; i < block_size.threads_dimensions(); i++) {
thread_id = "." + thread_names[i];
body = register_allocs.rewrap(body, thread_id);
body = For::make("." + thread_names[i], 0, block_size.num_threads(i),
ForType::GPUThread, op->device_api, body);
ForType::GPUThread, op->partition_policy, op->device_api, body);
}
thread_id.clear();
body = register_allocs.rewrap(body, thread_id);
Expand All @@ -1436,7 +1437,7 @@ class FuseGPUThreadLoopsSingleKernel : public IRMutator {
if (body.same_as(op->body)) {
return op;
} else {
return For::make(op->name, op->min, op->extent, op->for_type, op->device_api, body);
return For::make(op->name, op->min, op->extent, op->for_type, op->partition_policy, op->device_api, body);
}
} else {
return IRMutator::visit(op);
Expand Down Expand Up @@ -1505,7 +1506,7 @@ class ZeroGPULoopMins : public IRMutator {
internal_assert(op);
Expr adjusted = Variable::make(Int(32), op->name) + op->min;
Stmt body = substitute(op->name, adjusted, op->body);
stmt = For::make(op->name, 0, op->extent, op->for_type, op->device_api, body);
stmt = For::make(op->name, 0, op->extent, op->for_type, op->partition_policy, op->device_api, body);
}
return stmt;
}
Expand Down Expand Up @@ -1587,7 +1588,7 @@ class AddConditionToALoop : public IRMutator {
return IRMutator::visit(op);
}

return For::make(op->name, op->min, op->extent, op->for_type, op->device_api,
return For::make(op->name, op->min, op->extent, op->for_type, op->partition_policy, op->device_api,
IfThenElse::make(condition, op->body, Stmt()));
}

Expand Down
2 changes: 1 addition & 1 deletion src/HexagonOffload.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -743,7 +743,7 @@ class InjectHexagonRpc : public IRMutator {
if (is_const_one(loop->extent)) {
body = LetStmt::make(loop->name, loop->min, loop->body);
} else {
body = For::make(loop->name, loop->min, loop->extent, loop->for_type,
body = For::make(loop->name, loop->min, loop->extent, loop->for_type, loop->partition_policy,
DeviceAPI::None, loop->body);
}

Expand Down
7 changes: 6 additions & 1 deletion src/IR.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -342,7 +342,11 @@ Stmt ProducerConsumer::make_consume(const std::string &name, Stmt body) {
return ProducerConsumer::make(name, false, std::move(body));
}

Stmt For::make(const std::string &name, Expr min, Expr extent, ForType for_type, DeviceAPI device_api, Stmt body) {
Stmt For::make(const std::string &name,
Expr min, Expr extent,
ForType for_type, Partition partition_policy,
DeviceAPI device_api,
Stmt body) {
internal_assert(min.defined()) << "For of undefined\n";
internal_assert(extent.defined()) << "For of undefined\n";
internal_assert(min.type() == Int(32)) << "For with non-integer min\n";
Expand All @@ -354,6 +358,7 @@ Stmt For::make(const std::string &name, Expr min, Expr extent, ForType for_type,
node->min = std::move(min);
node->extent = std::move(extent);
node->for_type = for_type;
node->partition_policy = partition_policy;
node->device_api = device_api;
node->body = std::move(body);
return node;
Expand Down
8 changes: 7 additions & 1 deletion src/IR.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include "Buffer.h"
#include "Expr.h"
#include "FunctionPtr.h"
#include "LoopPartitioningDirective.h"
#include "ModulusRemainder.h"
#include "Parameter.h"
#include "PrefetchDirective.h"
Expand Down Expand Up @@ -807,8 +808,13 @@ struct For : public StmtNode<For> {
ForType for_type;
DeviceAPI device_api;
Stmt body;
Partition partition_policy;

static Stmt make(const std::string &name, Expr min, Expr extent, ForType for_type, DeviceAPI device_api, Stmt body);
static Stmt make(const std::string &name,
Expr min, Expr extent,
ForType for_type, Partition partition_policy,
DeviceAPI device_api,
Stmt body);

bool is_unordered_parallel() const {
return Halide::Internal::is_unordered_parallel(for_type);
Expand Down
2 changes: 1 addition & 1 deletion src/IRMutator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ Stmt IRMutator::visit(const For *op) {
return op;
}
return For::make(op->name, std::move(min), std::move(extent),
op->for_type, op->device_api, std::move(body));
op->for_type, op->partition_policy, op->device_api, std::move(body));
}

Stmt IRMutator::visit(const Store *op) {
Expand Down
19 changes: 17 additions & 2 deletions src/IRPrinter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,21 @@ std::ostream &operator<<(std::ostream &out, const TailStrategy &t) {
return out;
}

std::ostream &operator<<(std::ostream &out, const Partition &p) {
switch (p) {
case Partition::Auto:
out << "Auto";
break;
case Partition::Never:
out << "Never";
break;
case Partition::Always:
out << "Always";
break;
}
return out;
}

ostream &operator<<(ostream &stream, const LoopLevel &loop_level) {
return stream << "loop_level("
<< (loop_level.defined() ? loop_level.to_string() : "undefined")
Expand All @@ -206,12 +221,12 @@ void IRPrinter::test() {
internal_assert(expr_source.str() == "((x + 3)*((y/2) + 17))");

Stmt store = Store::make("buf", (x * 17) / (x - 3), y - 1, Parameter(), const_true(), ModulusRemainder());
Stmt for_loop = For::make("x", -2, y + 2, ForType::Parallel, DeviceAPI::Host, store);
Stmt for_loop = For::make("x", -2, y + 2, ForType::Parallel, Partition::Auto, DeviceAPI::Host, store);
vector<Expr> args(1);
args[0] = x % 3;
Expr call = Call::make(i32, "buf", args, Call::Extern);
Stmt store2 = Store::make("out", call + 1, x, Parameter(), const_true(), ModulusRemainder(3, 5));
Stmt for_loop2 = For::make("x", 0, y, ForType::Vectorized, DeviceAPI::Host, store2);
Stmt for_loop2 = For::make("x", 0, y, ForType::Vectorized, Partition::Auto, DeviceAPI::Host, store2);

Stmt producer = ProducerConsumer::make_produce("buf", for_loop);
Stmt consumer = ProducerConsumer::make_consume("buf", for_loop2);
Expand Down
5 changes: 4 additions & 1 deletion src/IRPrinter.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,10 @@ std::ostream &operator<<(std::ostream &stream, const DeviceAPI &);
std::ostream &operator<<(std::ostream &stream, const MemoryType &);

/** Emit a halide tail strategy in human-readable form */
std::ostream &operator<<(std::ostream &stream, const TailStrategy &t);
std::ostream &operator<<(std::ostream &stream, const TailStrategy &);

/** Emit a halide loop partitioning policy in human-readable form */
std::ostream &operator<<(std::ostream &stream, const Partition &);

/** Emit a halide LoopLevel in human-readable form */
std::ostream &operator<<(std::ostream &stream, const LoopLevel &);
Expand Down
Loading

0 comments on commit 1865101

Please sign in to comment.