Execute the task in the Task::wait() method to avoid deadlock. (#436)
Co-authored-by: domrjchen <[email protected]>
kevingpqi123 and domchen authored Jan 21, 2025
1 parent 759f268 commit ce19888
Showing 5 changed files with 112 additions and 81 deletions.
include/tgfx/core/Task.h (48 changes: 30 additions & 18 deletions)
@@ -18,6 +18,7 @@

#pragma once

#include <atomic>
#include <condition_variable>
#include <functional>
#include <memory>
@@ -27,31 +28,38 @@ namespace tgfx {
class TaskGroup;

/**
* The Task class manages the concurrent execution of one or more code blocks.
* Defines the possible states of a Task.
*/
class Task {
public:
enum class TaskStatus {
/**
* Submits a code block for asynchronous execution immediately and returns a Task wraps the code
* block. Hold a reference to the returned Task if you want to cancel it or wait for it to finish
* execution. Returns nullptr if the block is nullptr.
* The Task is waiting to be executed.
*/
static std::shared_ptr<Task> Run(std::function<void()> block);

Queueing,
/**
* Returns true if the Task is currently executing its code block.
* The Task is currently executing.
*/
bool executing();

Executing,
/**
* Returns true if the Task has been cancelled
* The Task has finished executing.
*/
bool cancelled();
Finished,
/**
* The Task has been canceled.
*/
Canceled
};

/**
* The Task class manages the concurrent execution of one or more code blocks.
*/
class Task {
public:
/**
* Returns true if the Task has finished executing its code block.
* Submits a code block for asynchronous execution immediately and returns a Task wraps the code
* block. Hold a reference to the returned Task if you want to cancel it or wait for it to finish
* execution. Returns nullptr if the block is nullptr.
*/
bool finished();
static std::shared_ptr<Task> Run(std::function<void()> block);

/**
* Advises the Task that it should stop executing its code block. Cancellation does not affect the
@@ -62,16 +70,20 @@ class Task {
/**
* Blocks the current thread until the Task finishes its execution. Returns immediately if the
* Task is finished or canceled. The task may be executed on the calling thread if it is not
* cancelled and still in the queue.
* canceled and still in the queue.
*/
void wait();

/**
* Return the current status of the Task.
*/
TaskStatus status() const;

private:
std::mutex locker = {};
std::condition_variable condition = {};
bool _executing = true;
bool _cancelled = false;
std::function<void()> block = nullptr;
std::atomic<TaskStatus> _status = TaskStatus::Queueing;

explicit Task(std::function<void()> block);
void execute();
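The new header replaces the executing(), cancelled(), and finished() accessors with a single status() call that returns the namespace-level TaskStatus enum. The following is a minimal usage sketch, not part of the commit, assuming the public header is included as "tgfx/core/Task.h":

#include <cstdio>
#include "tgfx/core/Task.h"

// Illustrative caller code for the revised Task API (assumed usage, not from this commit).
int main() {
  auto task = tgfx::Task::Run([] {
    // Background work goes here.
  });
  if (task == nullptr) {
    return 0;  // Run() returns nullptr when the block is nullptr.
  }
  task->wait();  // May run the block on this thread if it is still queued.
  if (task->status() == tgfx::TaskStatus::Finished) {
    printf("task finished\n");
  }
  return 0;
}
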
src/core/utils/LockFreeQueue.h (49 changes: 27 additions & 22 deletions)
@@ -33,21 +33,27 @@

namespace tgfx {
static constexpr size_t CACHELINE_SIZE = 64;
// QUEUE_SIZE needs to be a power of 2, otherwise the implementation of GetIndex needs to be changed.
static constexpr uint32_t QUEUE_SIZE = 1024;

inline uint32_t GetIndex(uint32_t position) {
return position & (QUEUE_SIZE - 1);
}

template <typename T>
class LockFreeQueue {
public:
LockFreeQueue() {
queuePool = reinterpret_cast<T*>(std::calloc(QUEUE_SIZE, sizeof(T)));
/**
* The capacity needs to be a power of 2, otherwise, it will be automatically set to the nearest
* power of 2 larger than the capacity.
* @param capacity
*/
explicit LockFreeQueue(uint32_t capacity) {
if ((capacity & (capacity - 1)) != 0) {
_capacity = 1;
while (_capacity < capacity) {
_capacity <<= 1;
}
} else {
_capacity = capacity;
}
queuePool = reinterpret_cast<T*>(std::calloc(_capacity, sizeof(T)));
if (queuePool == nullptr) {
LOGE("LockFreeQueue init Failed!\n");
return;
ABORT("LockFreeQueue init Failed!\n");
}
}

@@ -59,9 +65,6 @@ class LockFreeQueue {
}

T dequeue() {
if (queuePool == nullptr) {
return nullptr;
}
uint32_t newHead = 0;
uint32_t oldHead = head.load(std::memory_order_relaxed);
T element = nullptr;
@@ -71,11 +74,11 @@ class LockFreeQueue {
if (newHead == tailPosition.load(std::memory_order_acquire)) {
return nullptr;
}
element = queuePool[GetIndex(newHead)];
element = queuePool[getIndex(newHead)];
} while (!head.compare_exchange_weak(oldHead, newHead, std::memory_order_acq_rel,
std::memory_order_relaxed));

queuePool[GetIndex(newHead)] = nullptr;
queuePool[getIndex(newHead)] = nullptr;

uint32_t newHeadPosition = 0;
uint32_t oldHeadPosition = headPosition.load(std::memory_order_relaxed);
@@ -87,22 +90,19 @@ class LockFreeQueue {
}

bool enqueue(const T& element) {
if (queuePool == nullptr) {
return false;
}
uint32_t newTail = 0;
uint32_t oldTail = tail.load(std::memory_order_relaxed);

do {
newTail = oldTail + 1;
if (GetIndex(newTail) == GetIndex(headPosition.load(std::memory_order_acquire))) {
LOGI("The queue has reached its maximum capacity, capacity: %u!\n", QUEUE_SIZE);
if (getIndex(oldTail) == getIndex(headPosition.load(std::memory_order_acquire))) {
LOGI("The queue has reached its maximum capacity, capacity: %u!\n", _capacity);
return false;
}
newTail = oldTail + 1;
} while (!tail.compare_exchange_weak(oldTail, newTail, std::memory_order_acq_rel,
std::memory_order_relaxed));

queuePool[GetIndex(oldTail)] = std::move(element);
queuePool[getIndex(oldTail)] = std::move(element);

uint32_t newTailPosition = 0;
uint32_t oldTailPosition = tailPosition.load(std::memory_order_relaxed);
@@ -115,6 +115,7 @@

private:
T* queuePool = nullptr;
uint32_t _capacity = 0;
#ifdef DISABLE_ALIGNAS
// head indicates the position after requesting space.
std::atomic<uint32_t> head = {0};
@@ -134,6 +135,10 @@
// tailPosition indicates the position after filling data.
alignas(CACHELINE_SIZE) std::atomic<uint32_t> tailPosition = {1};
#endif

uint32_t getIndex(uint32_t position) {
return position & (_capacity - 1);
}
};

} // namespace tgfx
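
The new constructor rounds any non-power-of-2 capacity up so that getIndex() can mask with (_capacity - 1) instead of computing a modulo. A standalone sketch of that rounding, written for illustration and mirroring (not copied from) the constructor logic:

#include <cstdint>
#include <cstdio>

// Rounds a capacity up to the next power of 2, as the LockFreeQueue constructor does.
static uint32_t RoundUpToPowerOfTwo(uint32_t capacity) {
  if ((capacity & (capacity - 1)) == 0) {
    return capacity;  // Already a power of 2.
  }
  uint32_t result = 1;
  while (result < capacity) {
    result <<= 1;
  }
  return result;
}

int main() {
  printf("1000 -> %u\n", RoundUpToPowerOfTwo(1000));  // 1000 -> 1024
  printf("1024 -> %u\n", RoundUpToPowerOfTwo(1024));  // 1024 -> 1024
  // With a capacity of 1024, position 1025 wraps to slot 1: 1025 & 1023 == 1.
  printf("1025 wraps to %u\n", 1025u & (1024u - 1u));
  return 0;
}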
src/core/utils/Task.cpp (66 changes: 39 additions & 27 deletions)
@@ -34,42 +34,54 @@ std::shared_ptr<Task> Task::Run(std::function<void()> block) {
Task::Task(std::function<void()> block) : block(std::move(block)) {
}

bool Task::executing() {
std::lock_guard<std::mutex> autoLock(locker);
return _executing;
}

bool Task::cancelled() {
std::lock_guard<std::mutex> autoLock(locker);
return _cancelled;
}

bool Task::finished() {
std::lock_guard<std::mutex> autoLock(locker);
return !_executing && !_cancelled;
}

void Task::wait() {
std::unique_lock<std::mutex> autoLock(locker);
if (!_executing) {
auto oldStatus = _status.load(std::memory_order_relaxed);
if (oldStatus == TaskStatus::Canceled || oldStatus == TaskStatus::Finished) {
return;
}
condition.wait(autoLock);
// If wait() is called from the thread pool, all threads might block, leaving no thread to execute
// this task. To avoid deadlock, execute the task directly on the current thread if it's queued.
if (oldStatus == TaskStatus::Queueing) {
if (_status.compare_exchange_weak(oldStatus, TaskStatus::Executing, std::memory_order_acq_rel,
std::memory_order_relaxed)) {
block();
oldStatus = TaskStatus::Executing;
while (!_status.compare_exchange_weak(oldStatus, TaskStatus::Finished,
std::memory_order_acq_rel, std::memory_order_relaxed))
;
return;
}
}
std::unique_lock<std::mutex> autoLock(locker);
if (_status.load(std::memory_order_acquire) == TaskStatus::Executing) {
condition.wait(autoLock);
}
}

void Task::cancel() {
std::unique_lock<std::mutex> autoLock(locker);
if (!_executing) {
return;
auto currentStatus = _status.load(std::memory_order_relaxed);
if (currentStatus == TaskStatus::Queueing) {
_status.compare_exchange_weak(currentStatus, TaskStatus::Canceled, std::memory_order_acq_rel,
std::memory_order_relaxed);
}
_executing = false;
_cancelled = true;
}

void Task::execute() {
block();
std::lock_guard<std::mutex> auoLock(locker);
_executing = false;
condition.notify_all();
auto oldStatus = _status.load(std::memory_order_relaxed);
if (oldStatus == TaskStatus::Queueing &&
_status.compare_exchange_weak(oldStatus, TaskStatus::Executing, std::memory_order_acq_rel,
std::memory_order_relaxed)) {
block();
oldStatus = TaskStatus::Executing;
while (!_status.compare_exchange_weak(oldStatus, TaskStatus::Finished,
std::memory_order_acq_rel, std::memory_order_relaxed))
;
std::unique_lock<std::mutex> autoLock(locker);
condition.notify_all();
}
}

TaskStatus Task::status() const {
return _status.load(std::memory_order_relaxed);
}
} // namespace tgfx
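
The rewritten wait() is what lets a task safely wait on another task from inside the thread pool: if the waited-on task is still Queueing, the caller claims it with a compare-exchange and runs it inline instead of blocking. A sketch of that scenario (assumed caller code, not from this commit):

#include "tgfx/core/Task.h"

// A task that waits on another task from a pool thread. Before this commit, every
// pool thread could end up blocked in wait() with no thread left to run the inner
// task; now the waiting thread executes the still-queued inner task itself.
void NestedWaitExample() {
  auto outer = tgfx::Task::Run([] {
    auto inner = tgfx::Task::Run([] {
      // Inner work.
    });
    if (inner != nullptr) {
      inner->wait();  // Runs inner on this pool thread if it is still queued.
    }
  });
  if (outer != nullptr) {
    outer->wait();
  }
}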
src/core/utils/TaskGroup.cpp (26 changes: 14 additions & 12 deletions)
@@ -28,6 +28,8 @@

namespace tgfx {
static constexpr auto THREAD_TIMEOUT = std::chrono::seconds(10);
static constexpr uint32_t THREAD_POOL_SIZE = 32;
static constexpr uint32_t TASK_QUEUE_SIZE = 1024;

int GetCPUCores() {
int cpuCores = 0;
@@ -44,9 +46,6 @@ int GetCPUCores() {
return cpuCores;
}

static const int CPUCores = GetCPUCores();
static const int MaxThreads = CPUCores > 16 ? 16 : CPUCores;

TaskGroup* TaskGroup::GetInstance() {
static auto& taskGroup = *new TaskGroup();
return &taskGroup;
@@ -76,15 +75,18 @@ void OnAppExit() {
}

TaskGroup::TaskGroup() {
threads = new LockFreeQueue<std::thread*>(THREAD_POOL_SIZE);
tasks = new LockFreeQueue<std::shared_ptr<Task>>(TASK_QUEUE_SIZE);
std::atexit(OnAppExit);
threads.resize(static_cast<size_t>(MaxThreads), nullptr);
}

bool TaskGroup::checkThreads() {
static const int CPUCores = GetCPUCores();
static const int MaxThreads = CPUCores > 16 ? 16 : CPUCores;
if (waitingThreads == 0 && totalThreads < MaxThreads) {
auto thread = new (std::nothrow) std::thread(&TaskGroup::RunLoop, this);
if (thread) {
threads[static_cast<size_t>(totalThreads)] = thread;
threads->enqueue(thread);
totalThreads++;
}
} else {
Expand All @@ -100,7 +102,7 @@ bool TaskGroup::pushTask(std::shared_ptr<Task> task) {
if (exited || !checkThreads()) {
return false;
}
if (!tasks.enqueue(std::move(task))) {
if (!tasks->enqueue(std::move(task))) {
return false;
}
if (waitingThreads > 0) {
@@ -112,7 +114,7 @@ std::shared_ptr<Task> TaskGroup::popTask() {
std::shared_ptr<Task> TaskGroup::popTask() {
std::unique_lock<std::mutex> autoLock(locker);
while (!exited) {
auto task = tasks.dequeue();
auto task = tasks->dequeue();
if (task) {
return task;
}
@@ -129,12 +131,12 @@ std::shared_ptr<Task> TaskGroup::popTask() {
void TaskGroup::exit() {
exited = true;
condition.notify_all();
for (int i = 0; i < totalThreads; i++) {
auto thread = threads[static_cast<size_t>(i)];
if (thread) {
ReleaseThread(thread);
}
std::thread* thread = nullptr;
while ((thread = threads->dequeue()) != nullptr) {
ReleaseThread(thread);
}
delete threads;
delete tasks;
totalThreads = 0;
waitingThreads = 0;
}
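checkThreads() now computes the core count lazily as function-local statics while keeping the cap of 16 workers. A standalone sketch of that cap, for illustration only; the fallback value used when detection fails is an assumption here, the real default lives in GetCPUCores():

#include <algorithm>
#include <cstdio>
#include <thread>

// Mirrors the MaxThreads calculation in checkThreads(): no more than 16 workers,
// and no more than the detected core count.
int main() {
  int cpuCores = static_cast<int>(std::thread::hardware_concurrency());
  if (cpuCores <= 0) {
    cpuCores = 8;  // Assumed fallback for this sketch; GetCPUCores() has its own default.
  }
  int maxThreads = std::min(cpuCores, 16);
  printf("max pool threads: %d\n", maxThreads);
  return 0;
}
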
src/core/utils/TaskGroup.h (4 changes: 2 additions & 2 deletions)
@@ -34,8 +34,8 @@ class TaskGroup {
std::atomic_int totalThreads = 0;
std::atomic_bool exited = false;
std::atomic_int waitingThreads = 0;
LockFreeQueue<std::shared_ptr<Task>> tasks = {};
std::vector<std::thread*> threads = {};
LockFreeQueue<std::shared_ptr<Task>>* tasks = nullptr;
LockFreeQueue<std::thread*>* threads = nullptr;
static TaskGroup* GetInstance();
static void RunLoop(TaskGroup* taskGroup);

