Skip to content

Commit

Permalink
[fiber] Implement std concurrency interfaces
Browse files Browse the repository at this point in the history
  • Loading branch information
salkinium committed Apr 14, 2024
1 parent 7d284b0 commit 3f372b6
Show file tree
Hide file tree
Showing 14 changed files with 1,264 additions and 10 deletions.
80 changes: 80 additions & 0 deletions src/modm/processing/fiber/barrier.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
/*
* Copyright (c) 2023, Niklas Hauser
*
* This file is part of the modm project.
*
* This Source Code Form is subject to the terms of the Mozilla Public
* License, v. 2.0. If a copy of the MPL was not distributed with this
* file, You can obtain one at http://mozilla.org/MPL/2.0/.
*/
// ----------------------------------------------------------------------------

#pragma once

#include "fiber.hpp"

namespace modm::fiber
{

/// @ingroup modm_processing_fiber
/// @{

/// Implements the `std::barrier` interface for fibers.
/// @see https://en.cppreference.com/w/cpp/thread/barrier
template< class CompletionFunction = decltype([](){}) >
class barrier
{
barrier(const barrier&) = delete;
barrier& operator=(const barrier&) = delete;
using count_t = uint16_t;

const CompletionFunction completion;
volatile count_t expected;
volatile count_t count;
volatile count_t sequence{};
public:
using arrival_token = count_t;
constexpr explicit barrier(std::ptrdiff_t expected,
CompletionFunction f = CompletionFunction())
: completion(std::move(f)), expected(expected), count(expected) {}

[[nodiscard]] static constexpr std::ptrdiff_t
max() { return count_t(-1); }

[[nodiscard]] arrival_token
arrive(std::ptrdiff_t n=1)
{
count_t last_arrival{sequence};
count -= n;
if (count == 0)
{
count = expected;
sequence++;
completion();
}
return last_arrival;
}

void
wait(arrival_token&& arrival) const
{
while (arrival == sequence) modm::this_fiber::yield();
}

void
arrive_and_wait()
{
wait(arrive());
}

void
arrive_and_drop()
{
expected--;
arrive();
}
};

/// @}

}
161 changes: 161 additions & 0 deletions src/modm/processing/fiber/condition_variable.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
/*
* Copyright (c) 2023, Niklas Hauser
*
* This file is part of the modm project.
*
* This Source Code Form is subject to the terms of the Mozilla Public
* License, v. 2.0. If a copy of the MPL was not distributed with this
* file, You can obtain one at http://mozilla.org/MPL/2.0/.
*/
// ----------------------------------------------------------------------------

#pragma once

#include "fiber.hpp"
#include <atomic>
#include <condition_variable>

namespace modm::fiber
{

/// @ingroup modm_processing_fiber
/// @{

/// Implements the `std::condition_variable_any` interface for fibers.
/// @see https://en.cppreference.com/w/cpp/thread/condition_variable
class condition_variable_any
{
condition_variable_any(const condition_variable_any&) = delete;
condition_variable_any& operator=(const condition_variable_any&) = delete;

std::atomic<uint16_t> sequence{};
public:
constexpr condition_variable_any() = default;

void inline
notify_one()
{
sequence++;
}

void inline
notify_any()
{
sequence++;
}


template< class Lock >
void
wait(Lock& lock)
{
lock.unlock();
const auto my_sequence = sequence.load();
while(my_sequence == sequence.load()) modm::this_fiber::yield();
lock.lock();
}

template< class Lock, class Predicate >
void
wait(Lock& lock, Predicate&& pred)
{
while (not pred()) wait(lock);
}

template< class Lock, class Predicate >
bool
wait(Lock& lock, std::stop_token stoken, Predicate pred)
{
while (not stoken.stop_requested())
{
if (pred()) return true;
wait(lock);
}
return pred();
}


template< class Lock, class Rep, class Period >
std::cv_status
wait_for(Lock& lock, std::chrono::duration<Rep, Period> rel_time)
{
lock.unlock();
const auto condition = [this, my_sequence = sequence.load()]()
{ return my_sequence != sequence.load(); };
const bool result = this_fiber::poll_for(rel_time, condition);
lock.lock();
return result ? std::cv_status::no_timeout : std::cv_status::timeout;
}

template< class Lock, class Rep, class Period, class Predicate >
bool
wait_for(Lock& lock, std::chrono::duration<Rep, Period> rel_time, Predicate&& pred)
{
while (not pred())
{
if (wait_for(lock, rel_time) == std::cv_status::timeout)
return pred();
}
return true;
}

template< class Lock, class Rep, class Period, class Predicate >
bool
wait_for(Lock& lock, std::stop_token stoken,
std::chrono::duration<Rep, Period> rel_time, Predicate&& pred)
{
while (not stoken.stop_requested())
{
if (pred()) return true;
if (wait_for(lock, rel_time) == std::cv_status::timeout)
return pred();
}
return pred();
}


template< class Lock, class Clock, class Duration >
std::cv_status
wait_until(Lock& lock, std::chrono::time_point<Clock, Duration> abs_time)
{
lock.unlock();
const auto condition = [this, my_sequence = sequence.load()]()
{ return my_sequence != sequence.load(); };
const bool result = this_fiber::poll_until(abs_time, condition);
lock.lock();
return result ? std::cv_status::no_timeout : std::cv_status::timeout;
}

template< class Lock, class Clock, class Duration, class Predicate >
bool
wait_until(Lock& lock, std::chrono::time_point<Clock, Duration> abs_time, Predicate&& pred)
{
while (not pred())
{
if (wait_until(lock, abs_time) == std::cv_status::timeout)
return pred();
}
return true;
}

template< class Lock, class Clock, class Duration, class Predicate >
bool
wait_until(Lock& lock, std::stop_token stoken,
std::chrono::time_point<Clock, Duration> abs_time, Predicate&& pred)
{
while (not stoken.stop_requested())
{
if (pred()) return true;
if (wait_until(lock, abs_time) == std::cv_status::timeout)
return pred();
}
return pred();
}
};

// There is no specialization for std::unique_lock.
using condition_variable = condition_variable_any;

/// @}

}
68 changes: 68 additions & 0 deletions src/modm/processing/fiber/latch.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
/*
* Copyright (c) 2023, Niklas Hauser
*
* This file is part of the modm project.
*
* This Source Code Form is subject to the terms of the Mozilla Public
* License, v. 2.0. If a copy of the MPL was not distributed with this
* file, You can obtain one at http://mozilla.org/MPL/2.0/.
*/
// ----------------------------------------------------------------------------

#pragma once

#include "fiber.hpp"
#include <atomic>

namespace modm::fiber
{

/// @ingroup modm_processing_fiber
/// @{

/// Implements the `std::latch` interface for fibers.
/// @see https://en.cppreference.com/w/cpp/thread/latch
class latch
{
latch(const latch&) = delete;
latch& operator=(const latch&) = delete;

using count_t = uint16_t;
std::atomic<count_t> count;
public:
constexpr explicit
latch(std::ptrdiff_t expected)
: count(expected) {}

[[nodiscard]] static constexpr std::ptrdiff_t
max() { return count_t(-1); }

void inline
count_down(std::ptrdiff_t n=1)
{
count -= n;
}

[[nodiscard]] bool inline
try_wait() const
{
return count.load() == 0;
}

void inline
wait() const
{
while(not try_wait()) modm::this_fiber::yield();
}

void inline
arrive_and_wait(std::ptrdiff_t n=1)
{
count_down(n);
wait();
}
};

/// @}

}
9 changes: 8 additions & 1 deletion src/modm/processing/fiber/module.lb
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def is_enabled(env):
not env.has_module(":processing:protothread")

def prepare(module, options):
module.depends(":processing:timer")
module.depends(":processing:timer", ":architecture:atomic")

module.add_query(
EnvironmentQuery(name="__enabled", factory=is_enabled))
Expand Down Expand Up @@ -77,3 +77,10 @@ def build(env):
env.copy("task.hpp")
env.copy("functions.hpp")
env.copy("fiber.hpp")

env.copy("mutex.hpp")
env.copy("shared_mutex.hpp")
env.copy("semaphore.hpp")
env.copy("latch.hpp")
env.copy("barrier.hpp")
env.copy("condition_variable.hpp")
Loading

0 comments on commit 3f372b6

Please sign in to comment.