[xla:cpu] Add XnnFusionThunk options to be able to run without a thread pool

PiperOrigin-RevId: 719931008
ezhulenev authored and Google-ML-Automation committed Jan 26, 2025
1 parent 1cbcb65 commit 1321030
Showing 11 changed files with 135 additions and 39 deletions.
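
In short, XnnFusionThunk and its XnnDotThunk subclass now take an Options struct whose use_threadpool flag picks between threaded and inline execution. A minimal sketch of the new call shape, mirroring the updated tests below (buffer, shape, and dot-dimension setup elided):

  // Run the dot without a thread pool: XNNPACK executes inline in the
  // calling thread.
  TF_ASSERT_OK_AND_ASSIGN(
      auto thunk,
      XnnDotThunk::Create(XnnDotThunk::Options{/*use_threadpool=*/false},
                          {"dot"}, dot_dimensions, lhs_slice, shape,
                          rhs_slice, shape, out_slice, shape));

  Thunk::ExecuteParams params;
  params.buffer_allocations = &allocations;
  params.intra_op_threadpool = nullptr;  // No Eigen device needed inline.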
8 changes: 7 additions & 1 deletion xla/backends/cpu/runtime/thunk.proto
@@ -136,8 +136,14 @@ message XnnDotThunkProto {
}

message XnnFusionThunkProto {
message Options {
bool use_threadpool = 1;
}

Options options = 1;

oneof impl {
XnnDotThunkProto xnn_dot_thunk = 1;
XnnDotThunkProto xnn_dot_thunk = 2;
}
}

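For illustration, the generated C++ API for the new field follows the usual protobuf pattern; a hypothetical setter sketch (not part of this commit):

  XnnFusionThunkProto proto;
  proto.mutable_options()->set_use_threadpool(true);
  proto.mutable_xnn_dot_thunk();  // Selects the impl oneof member.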
6 changes: 5 additions & 1 deletion xla/backends/cpu/runtime/thunk_serdes_proto.cc
@@ -1246,6 +1246,10 @@ static absl::StatusOr<std::unique_ptr<XnnDotThunk>> XnnDotThunkFromProto(
const ThunkProto& proto, const BufferAssignment& buffer_assignment) {
TF_ASSIGN_OR_RETURN(Thunk::Info info, ThunkInfoFromProto(proto.info()));

XnnDotThunk::Options options = {
proto.xnn_fusion_thunk().options().use_threadpool(),
};

TF_ASSIGN_OR_RETURN(
auto lhs_slice_shape,
DeserializeSliceShapeFromProto(
@@ -1268,7 +1272,7 @@ static absl::StatusOr<std::unique_ptr<XnnDotThunk>> XnnDotThunkFromProto(
const auto& [out_buffer, out_shape] = out_slice_shape;

return XnnDotThunk::Create(
std::move(info),
std::move(options), std::move(info),
proto.xnn_fusion_thunk().xnn_dot_thunk().dot_dimensions(), lhs_buffer,
lhs_shape, rhs_buffer, rhs_shape, out_buffer, out_shape);
}
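Note a subtlety in the round-trip: proto3 booleans default to false, so an XnnFusionThunkProto serialized without an explicit options field deserializes here with use_threadpool == false and runs inline, while the in-memory default declared in xnn_fusion_thunk.h below is true.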
4 changes: 4 additions & 0 deletions xla/backends/cpu/runtime/xnnpack/BUILD
@@ -169,11 +169,13 @@ xla_cc_test(
"//xla/backends/cpu/runtime:thunk",
"//xla/backends/cpu/runtime:thunk_testlib",
"//xla/tsl/concurrency:async_value",
"//xla/tsl/platform:env",
"//xla/tsl/platform:statusor",
"//xla/tsl/platform:test",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_googletest//:gtest_main",
"@eigen_archive//:eigen3",
],
)

@@ -229,12 +231,14 @@ xla_cc_test(
"//xla/backends/cpu/runtime:thunk",
"//xla/backends/cpu/runtime:thunk_testlib",
"//xla/tsl/concurrency:async_value",
"//xla/tsl/platform:env",
"//xla/tsl/platform:statusor",
"//xla/tsl/platform:test",
"@XNNPACK",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/types:span",
"@com_google_googletest//:gtest_main",
"@eigen_archive//:eigen3",
],
)
12 changes: 7 additions & 5 deletions xla/backends/cpu/runtime/xnnpack/xnn_dot_thunk.cc
@@ -80,7 +80,7 @@ absl::StatusOr<xnn_subgraph_t> XnnDotThunk::BuildDotSubgraph(
}

absl::StatusOr<std::unique_ptr<XnnDotThunk>> XnnDotThunk::Create(
Info info, DotDimensionNumbers dot_dimensions,
Options options, Info info, DotDimensionNumbers dot_dimensions,
BufferAllocation::Slice lhs_buffer, Shape lhs_shape,
BufferAllocation::Slice rhs_buffer, Shape rhs_shape,
BufferAllocation::Slice out_buffer, Shape out_shape) {
@@ -97,7 +97,8 @@ absl::StatusOr<std::unique_ptr<XnnDotThunk>> XnnDotThunk::Create(
out_buffer, std::move(out_shape)};

return absl::WrapUnique(
new XnnDotThunk(info, std::move(dot_dimensions), std::move(dot_slices),
new XnnDotThunk(std::move(options), std::move(info),
std::move(dot_dimensions), std::move(dot_slices),
std::move(dot_shape), std::move(dot_canonical_dims)));
}

@@ -111,11 +112,12 @@ static std::vector<XnnFusionThunk::Result> DotResults(const DotSlices& slices) {
return {XnnFusionThunk::Result{slices.out_buffer, slices.out_shape}};
}

XnnDotThunk::XnnDotThunk(Info info, DotDimensionNumbers dot_dimensions,
XnnDotThunk::XnnDotThunk(Options options, Info info,
DotDimensionNumbers dot_dimensions,
DotSlices dot_slices, DotShape dot_shape,
DotCanonicalDims dot_canonical_dims)
: XnnFusionThunk(std::move(info), DotArguments(dot_slices),
DotResults(dot_slices),
: XnnFusionThunk(std::move(options), std::move(info),
DotArguments(dot_slices), DotResults(dot_slices),
std::bind(&XnnDotThunk::BuildDotSubgraph, this,
std::placeholders::_1, std::placeholders::_2)),
dot_dimensions_(std::move(dot_dimensions)),
4 changes: 2 additions & 2 deletions xla/backends/cpu/runtime/xnnpack/xnn_dot_thunk.h
@@ -35,7 +35,7 @@ namespace xla::cpu {
class XnnDotThunk final : public XnnFusionThunk {
public:
static absl::StatusOr<std::unique_ptr<XnnDotThunk>> Create(
Info info, DotDimensionNumbers dot_dimensions,
Options options, Info info, DotDimensionNumbers dot_dimensions,
BufferAllocation::Slice lhs_buffer, Shape lhs_shape,
BufferAllocation::Slice rhs_buffer, Shape rhs_shape,
BufferAllocation::Slice out_buffer, Shape out_shape);
@@ -54,7 +54,7 @@ class XnnDotThunk final : public XnnFusionThunk {
std::string result_name(size_t index) const final;

private:
XnnDotThunk(Info info, DotDimensionNumbers dot_dimensions,
XnnDotThunk(Options options, Info info, DotDimensionNumbers dot_dimensions,
DotSlices dot_slices, DotShape dot_shape,
DotCanonicalDims dot_canonical_dims);

22 changes: 20 additions & 2 deletions xla/backends/cpu/runtime/xnnpack/xnn_dot_thunk_test.cc
@@ -22,13 +22,23 @@ limitations under the License.
#include "xla/shape.h"
#include "xla/shape_util.h"
#include "xla/tsl/concurrency/async_value_ref.h"
#include "xla/tsl/platform/env.h"
#include "xla/tsl/platform/statusor.h"
#include "xla/tsl/platform/test.h"
#include "xla/tsl/platform/threadpool.h"

#define EIGEN_USE_THREADS
#include "unsupported/Eigen/CXX11/Tensor"

namespace xla::cpu {
namespace {

TEST(XnnDotThunkTest, SimpleDot) {
class XnnDotThunkTest : public testing::TestWithParam<bool> {
protected:
bool use_threadpool() const { return GetParam(); }
};

TEST_P(XnnDotThunkTest, SimpleDot) {
auto lhs = LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
auto rhs = LiteralUtil::CreateR2<float>({{4.0, 3.0}, {2.0, 1.0}});
auto out = LiteralUtil::CreateR2<float>({{0.0, 0.0}, {0.0, 0.0}});
@@ -47,11 +57,17 @@ TEST(XnnDotThunkTest, SimpleDot) {
dot_dimensions.add_rhs_contracting_dimensions(0);

TF_ASSERT_OK_AND_ASSIGN(
auto thunk, XnnDotThunk::Create({"dot"}, dot_dimensions, lhs_slice, shape,
auto thunk, XnnDotThunk::Create(XnnDotThunk::Options{use_threadpool()},
{"dot"}, dot_dimensions, lhs_slice, shape,
rhs_slice, shape, out_slice, shape));

tsl::thread::ThreadPool threads(tsl::Env::Default(), "test", 8);
Eigen::ThreadPoolDevice device(threads.AsEigenThreadPool(),
threads.NumThreads());

Thunk::ExecuteParams params;
params.buffer_allocations = &allocations;
params.intra_op_threadpool = use_threadpool() ? &device : nullptr;

auto execute_event = thunk->Execute(params);
tsl::BlockUntilReady(execute_event);
@@ -60,5 +76,7 @@
EXPECT_EQ(out, LiteralUtil::CreateR2<float>({{8.0, 5.0}, {20.0, 13.0}}));
}

INSTANTIATE_TEST_SUITE_P(XnnDot, XnnDotThunkTest, testing::Values(true, false));

} // namespace
} // namespace xla::cpu
67 changes: 50 additions & 17 deletions xla/backends/cpu/runtime/xnnpack/xnn_fusion_thunk.cc
@@ -43,6 +43,26 @@ limitations under the License.

namespace xla::cpu {

namespace {
enum class ParallelizationMode { kInline, kParallelLoopRunner, kPThreadPool };

template <typename Sink>
void AbslStringify(Sink& sink, ParallelizationMode m) {
switch (m) {
case ParallelizationMode::kInline:
sink.Append("kInline");
break;
case ParallelizationMode::kParallelLoopRunner:
sink.Append("kParallelLoopRunner");
break;
case ParallelizationMode::kPThreadPool:
sink.Append("kPThreadPool");
break;
}
}

} // namespace

// XNNPACK runtime instantiated for the fusion operation.
struct XnnFusionThunk::XnnRuntime {
XnnRuntime() = default;
@@ -109,9 +129,16 @@ XnnFusionThunk::XnnRuntime::Invoke(const Eigen::ThreadPoolDevice* device,
XNN_RETURN_IF_ERROR(xnn_setup_runtime_v2(runtime, external_values.size(),
external_values.data()));

runner->set_device(device);
// Execute XNNPACK runtime using a parallel loop runner.
if (runner) {
runner->set_device(device);
XNN_RETURN_IF_ERROR(xnn_invoke_runtime(runtime));
return runner->ResetDoneEvent();
}

// Execute XNNPACK runtime in the caller thread.
XNN_RETURN_IF_ERROR(xnn_invoke_runtime(runtime));
return runner->ResetDoneEvent();
return OkExecuteEventSingleton();
}

void XnnFusionThunk::XnnRuntime::Destroy() {
@@ -125,24 +152,28 @@ void XnnFusionThunk::XnnRuntime::Destroy() {

absl::StatusOr<XnnFusionThunk::XnnRuntime> XnnFusionThunk::CreateXnnRuntime(
const Eigen::ThreadPoolDevice* device) {
bool use_custom_threadpool = device && IsCustomPthreadpoolEnabled();
ParallelizationMode parallelization_mode =
options_.use_threadpool ? (device && IsCustomPthreadpoolEnabled()
? ParallelizationMode::kParallelLoopRunner
: ParallelizationMode::kPThreadPool)
: ParallelizationMode::kInline;

VLOG(3) << absl::StreamFormat(
"Create XNN runtime for `%s` operation: num_created=%d, "
"use_custom_threadpool=%v",
info().op_name, xnn_runtime_pool_.num_created(), use_custom_threadpool);
"parallelization_mode=%v",
info().op_name, xnn_runtime_pool_.num_created(), parallelization_mode);

XnnRuntime runtime;

// Construct XNNPACK subgraph using user-provided builder function.
TF_ASSIGN_OR_RETURN(runtime.subgraph, builder_(arguments_, results_));

// If XLA is compiled with custom pthreadpool, use it in XNNPACK runtime,
// otherwise we'll run all XNNPACK operations in the default pthreadpool.
runtime.runner = std::make_unique<ParallelLoopRunner>(
device, /*worker_timeslice=*/absl::Microseconds(100));
if (use_custom_threadpool) {
// Configure XNNPACK runtime thread pool if parallelization is enabled.
if (parallelization_mode == ParallelizationMode::kParallelLoopRunner) {
runtime.runner = std::make_unique<ParallelLoopRunner>(
device, /*worker_timeslice=*/absl::Microseconds(100));
runtime.threadpool = CreateCustomPthreadpool(runtime.runner.get());
} else {
} else if (parallelization_mode == ParallelizationMode::kPThreadPool) {
runtime.threadpool = DefaultPthreadpool();
}

@@ -158,18 +189,20 @@ absl::StatusOr<XnnFusionThunk::XnnRuntime> XnnFusionThunk::CreateXnnRuntime(
}

absl::StatusOr<std::unique_ptr<XnnFusionThunk>> XnnFusionThunk::Create(
Info info, std::vector<Argument> arguments, std::vector<Result> results,
Builder builder) {
Options options, Info info, std::vector<Argument> arguments,
std::vector<Result> results, Builder builder) {
TF_RETURN_IF_ERROR(InitializeXnnPack());

return absl::WrapUnique(
new XnnFusionThunk(std::move(info), std::move(arguments),
std::move(results), std::move(builder)));
return absl::WrapUnique(new XnnFusionThunk(
std::move(options), std::move(info), std::move(arguments),
std::move(results), std::move(builder)));
}

XnnFusionThunk::XnnFusionThunk(Info info, std::vector<Argument> arguments,
XnnFusionThunk::XnnFusionThunk(Options options, Info info,
std::vector<Argument> arguments,
std::vector<Result> results, Builder builder)
: Thunk(Kind::kXnnFusion, std::move(info)),
options_(std::move(options)),
arguments_(std::move(arguments)),
results_(std::move(results)),
builder_(std::move(builder)),
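Spelled out, the mode selection in CreateXnnRuntime reduces to a small pure function; a self-contained sketch (the Select helper is hypothetical, its logic mirrors the diff above):

  enum class ParallelizationMode { kInline, kParallelLoopRunner, kPThreadPool };

  // Parallelism must be requested via Options; the ParallelLoopRunner path
  // additionally requires an intra-op device and XLA built with the custom
  // pthreadpool. Otherwise fall back to the default pthreadpool.
  ParallelizationMode Select(bool use_threadpool, bool has_device,
                             bool custom_pthreadpool_enabled) {
    if (!use_threadpool) return ParallelizationMode::kInline;
    if (has_device && custom_pthreadpool_enabled) {
      return ParallelizationMode::kParallelLoopRunner;
    }
    return ParallelizationMode::kPThreadPool;
  }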
14 changes: 11 additions & 3 deletions xla/backends/cpu/runtime/xnnpack/xnn_fusion_thunk.h
@@ -16,6 +16,8 @@ limitations under the License.
#ifndef XLA_BACKENDS_CPU_RUNTIME_XNNPACK_XNN_FUSION_THUNK_H_
#define XLA_BACKENDS_CPU_RUNTIME_XNNPACK_XNN_FUSION_THUNK_H_

#include <stdbool.h>

#include <cstddef>
#include <memory>
#include <string>
@@ -43,6 +45,10 @@ class XnnFusionThunk : public Thunk {
public:
~XnnFusionThunk() override;

struct Options {
bool use_threadpool = true;
};

struct Argument {
BufferAllocation::Slice slice;
Shape shape;
@@ -58,15 +64,15 @@ class XnnFusionThunk : public Thunk {
absl::Span<const Argument> arguments, absl::Span<const Result> results)>;

static absl::StatusOr<std::unique_ptr<XnnFusionThunk>> Create(
Info info, std::vector<Argument> arguments, std::vector<Result> results,
Builder builder);
Options options, Info info, std::vector<Argument> arguments,
std::vector<Result> results, Builder builder);

tsl::AsyncValueRef<ExecuteEvent> Execute(const ExecuteParams& params) final;

BufferUses buffer_uses() const final;

protected:
XnnFusionThunk(Info info, std::vector<Argument> arguments,
XnnFusionThunk(Options options, Info info, std::vector<Argument> arguments,
std::vector<Result> results, Builder builder);

// Extension points for subclasses to customize the logging behavior.
@@ -91,6 +97,8 @@ class XnnFusionThunk : public Thunk {
absl::StatusOr<XnnRuntime> CreateXnnRuntime(
const Eigen::ThreadPoolDevice* device);

Options options_;

std::vector<Argument> arguments_;
std::vector<Result> results_;
Builder builder_;
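Because use_threadpool defaults to true, a default-constructed Options preserves the old threaded behavior, and opting out is explicit; for example:

  XnnFusionThunk::Options threaded;                              // use_threadpool == true
  XnnFusionThunk::Options inline_only{/*use_threadpool=*/false};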
27 changes: 23 additions & 4 deletions xla/backends/cpu/runtime/xnnpack/xnn_fusion_thunk_test.cc
@@ -31,8 +31,13 @@ limitations under the License.
#include "xla/shape.h"
#include "xla/shape_util.h"
#include "xla/tsl/concurrency/async_value_ref.h"
#include "xla/tsl/platform/env.h"
#include "xla/tsl/platform/statusor.h"
#include "xla/tsl/platform/test.h"
#include "xla/tsl/platform/threadpool.h"

#define EIGEN_USE_THREADS
#include "unsupported/Eigen/CXX11/Tensor"

namespace xla::cpu {
namespace {
@@ -77,7 +82,12 @@ static absl::StatusOr<xnn_subgraph_t> CreateBinaryAdd(
return subgraph;
}

TEST(XnnFusionThunkTest, ElementwiseAdd) {
class XnnFusionThunkTest : public testing::TestWithParam<bool> {
protected:
bool use_threadpool() const { return GetParam(); }
};

TEST_P(XnnFusionThunkTest, ElementwiseAdd) {
auto lhs = LiteralUtil::CreateR1<float>({1.0, 2.0, 3.0, 4.0});
auto rhs = LiteralUtil::CreateR1<float>({4.0, 3.0, 2.0, 1.0});
auto out = LiteralUtil::CreateR1<float>({0.0, 0.0, 0.0, 0.0});
@@ -95,12 +105,18 @@ TEST(XnnFusionThunkTest, ElementwiseAdd) {
XnnFusionThunk::Argument rhs_arg = {rhs_slice, shape};
XnnFusionThunk::Result out_res = {out_slice, shape};

TF_ASSERT_OK_AND_ASSIGN(auto thunk,
XnnFusionThunk::Create({"fusion"}, {lhs_arg, rhs_arg},
{out_res}, &CreateBinaryAdd));
TF_ASSERT_OK_AND_ASSIGN(
auto thunk, XnnFusionThunk::Create(
XnnFusionThunk::Options{use_threadpool()}, {"fusion"},
{lhs_arg, rhs_arg}, {out_res}, &CreateBinaryAdd));

tsl::thread::ThreadPool threads(tsl::Env::Default(), "test", 8);
Eigen::ThreadPoolDevice device(threads.AsEigenThreadPool(),
threads.NumThreads());

Thunk::ExecuteParams params;
params.buffer_allocations = &allocations;
params.intra_op_threadpool = use_threadpool() ? &device : nullptr;

auto execute_event = thunk->Execute(params);
tsl::BlockUntilReady(execute_event);
@@ -109,5 +125,8 @@ TEST(XnnFusionThunkTest, ElementwiseAdd) {
EXPECT_EQ(out, LiteralUtil::CreateR1<float>({5.0, 5.0, 5.0, 5.0}));
}

INSTANTIATE_TEST_SUITE_P(XnnFusion, XnnFusionThunkTest,
testing::Values(true, false));

} // namespace
} // namespace xla::cpu
2 changes: 1 addition & 1 deletion xla/backends/cpu/runtime/xnnpack/xnn_threadpool.cc
@@ -154,7 +154,7 @@ static void DestroyCustomPthreadpool(pthreadpool_t threadpool) { // NOLINT

static size_t GetThreadsCount(pthreadpool_t threadpool) { // NOLINT
if (ABSL_PREDICT_FALSE(threadpool == nullptr)) {
return 0;
return 1;
}

return Cast(threadpool)->runner()->num_threads();
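The GetThreadsCount change supports the new inline mode: with no pthreadpool attached, XNNPACK now sees a single (calling) thread rather than zero, which presumably keeps its work-partitioning arithmetic well-defined when the runtime runs without a pool.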