Parallelization of ConstProp compilation #3042

Merged · 18 commits · Feb 6, 2025
86 changes: 70 additions & 16 deletions src/Dialect/ONNX/ElementsAttr/ElementsAttrBuilder.cpp
@@ -9,8 +9,8 @@
//===----------------------------------------------------------------------===//

#include "src/Dialect/ONNX/ElementsAttr/ElementsAttrBuilder.hpp"

#include "mlir/Dialect/Traits.h"
#include "mlir/IR/Threading.h"
#include "llvm/ADT/STLExtras.h"

#include "src/Dialect/ONNX/ElementsAttr/DisposableElementsAttr.hpp"
@@ -849,6 +849,8 @@ ElementsAttr ElementsAttrBuilder::reduce(ElementsAttr elms,
if (axes.empty())
return elms;

Type elementType = elms.getElementType();
MLIRContext *ctx = elementType.getContext();
SmallVector<unsigned, 4> sortedAxes(axes);
std::sort(sortedAxes.begin(), sortedAxes.end());
assert(
@@ -885,22 +887,74 @@

ShapedType reducedType = type.clone(reducedShape);
return fromWideNums(reducedType, [&](MutableArrayRef<WideNum> dstNums) {
// Traverse and populate each element d in dstNums.
for (auto &idxoffs : StridesRange<1>(reducedShape, {reducedStrides})) {
WideNum &d = dstNums[idxoffs.flattenedIndex];
int64_t srcPos = idxoffs[0];
// Traverse all the elements that reduce together into d.
// srcNums elements may be repeated if there are zeros in axesStrides.
StridesRange<1> axesRange(axesShape, {axesStrides});
auto axesIter = axesRange.begin();
auto axesEnd = axesRange.end();
assert(axesIter->at(0) == 0 && "initial src offset must be zero");
d = srcNums.get()[srcPos];
while (++axesIter != axesEnd) {
int64_t srcOffset = axesIter->at(0);
d = reducer(d, srcNums.get()[srcPos + srcOffset]);
StridesRange<1> sRange(reducedShape, {reducedStrides});
StridesRange<1> axesRange(axesShape, {axesStrides});
SmallVector<std::pair<int64_t, uint64_t>, 4> batch;
for (auto &idxoffs : sRange)
batch.emplace_back(std::make_pair(idxoffs.flattenedIndex, idxoffs[0]));

auto fetchBatch = [&](size_t threadNumber, bool parallel) {
// Return all the data without splitting for sequential execution.
if (!parallel)
return llvm::make_range(batch.begin(), batch.end());
// Each thread fetches the same batch size. The leftovers go to the
// threads with the smallest thread numbers.
size_t tileSize = batch.size() / ctx->getNumThreads();
size_t leftovers = batch.size() % ctx->getNumThreads();
size_t beginOffset;
if (threadNumber < leftovers) {
// For the first `leftovers` threads, it is as if the tile size is
// larger by one.
tileSize++;
beginOffset = threadNumber * tileSize;
} else {
// For the remaining threads, it is as if the start is shifted by
// `leftovers`.
beginOffset = threadNumber * tileSize + leftovers;
}
size_t endOffset = beginOffset + tileSize;
return llvm::make_range(
batch.begin() + beginOffset, batch.begin() + endOffset);
};

auto work = [&](size_t threadNumber, bool parallel = true) {
auto tile = fetchBatch(threadNumber, parallel);
// Traverse and populate each element d in dstNums.
for (auto b : tile) {
WideNum &d = dstNums[b.first];
int64_t srcPos = b.second;
// Traverse all the elements that reduce together into d.
// srcNums elements may be repeated if there are zeros in axesStrides.
auto axesIter = axesRange.begin();
auto axesEnd = axesRange.end();
assert(axesIter->at(0) == 0 && "initial src offset must be zero");
d = srcNums.get()[srcPos];
while (++axesIter != axesEnd) {
int64_t srcOffset = axesIter->at(0);
d = reducer(d, srcNums.get()[srcPos + srcOffset]);
}
}
};
// Using 'parallelFor()' introduces a large overhead. The following are
// actual measurements on IBM z16 used to decide 'minCount'. We measured
// 'onnx.ReduceSum()' in 'test/mlir/onnx/onnx_constprop_parallel.mlir' with
// several input sizes. Based on these results, we chose 2000 as the
// 'minCount'.
//
// inputCounts | Sequential | Parallel with 2 threads
//             | (work())   | (parallelFor())
//             | (msec)     | (msec)
// --------------------------------------------------
// 400         | 0.065      | 0.153
// 800         | 0.115      | 0.164
// 1200        | 0.175      | 0.201
// 1600        | 0.226      | 0.228
// 2000        | 0.282      | 0.258
// 2400        | 0.336      | 0.284
constexpr size_t minCount = 2000;
size_t inputCount = batch.size() * axesRange.size();
if (inputCount < minCount)
work(0, /*parallel*/ false);
else
parallelFor(ctx, 0, ctx->getNumThreads(), work);
});
}
Collaborator comment:
I also assume that the work up there assumes there are batch.size() reductions that can all be done in parallel.

Since for quantization we have "whole tensor" quantization, we have cases with only one reduction.
That can also be done in parallel. Say you have 1000 elements and 10 threads: each thread processes its own 100 numbers and saves its result in its own slot in an array of 10 partial sums. Then, after the parallel region, just reduce these 10 values sequentially. You will still get a near 10x speedup.

Also, should we check whether to do things sequentially when batch.size() is small? That would probably help when we have a few very small tensors. You can easily print the sizes to stderr for a few benchmarks and see if you have such cases.


45 changes: 41 additions & 4 deletions src/Dialect/ONNX/ElementsAttr/ElementsAttrBuilder.hpp
@@ -10,6 +10,7 @@

#ifndef ONNX_MLIR_ELEM_ATTR_BUILDER_H
#define ONNX_MLIR_ELEM_ATTR_BUILDER_H
#include "mlir/IR/Threading.h"

#include "src/Dialect/ONNX/ElementsAttr/BType.hpp"
#include "src/Dialect/ONNX/ElementsAttr/DisposableElementsAttr.hpp"
@@ -244,10 +245,46 @@ class ElementsAttrBuilder {
// Constructs a transformer that changes every element to the result of
// applying the given function to the element.
template <typename Function = WideNum (*)(WideNum)>
static inline Transformer functionTransformer(Function fun) {
return [fun = std::move(fun)](llvm::MutableArrayRef<WideNum> data) -> void {
for (WideNum &n : data)
n = fun(n);
inline Transformer functionTransformer(Function fun) {
mlir::MLIRContext *ctx = disposablePool.getContext();
return [fun = std::move(fun), ctx](
llvm::MutableArrayRef<WideNum> data) -> void {
auto fetchBatch = [&](size_t threadNumber, bool parallel) {
// Return all the data without splitting for sequential execution.
if (!parallel)
return llvm::make_range(data.begin(), data.end());
// Each thread fetches the same data size. The leftovers go to the
// threads with the smallest thread numbers.
size_t tileSize = data.size() / ctx->getNumThreads();
size_t leftovers = data.size() % ctx->getNumThreads();
size_t beginOffset;
if (threadNumber < leftovers) {
// For the first `leftovers` threads, it is as if the tile size is
// larger by one.
tileSize++;
beginOffset = threadNumber * tileSize;
} else {
// For the remaining threads, it is as if the start is shifted by
// `leftovers`.
beginOffset = threadNumber * tileSize + leftovers;
}
size_t endOffset = beginOffset + tileSize;
return llvm::make_range(
data.begin() + beginOffset, data.begin() + endOffset);
};

auto work = [&](size_t threadNumber, bool parallel = true) {
auto tile = fetchBatch(threadNumber, parallel);
for (WideNum &n : tile)
n = fun(n);
};
// Using 'parallelFor()' introduces a large overhead. To avoid it, call
// work() directly if the input size is less than 'minCount'.
constexpr size_t minCount = 1000;
if (data.size() < minCount)
work(0, /*parallel*/ false);
else
parallelFor(ctx, 0, ctx->getNumThreads(), work);
};
}
