File channel.hpp
File List > coroutine > channel.hpp
Go to the documentation of this file.
#pragma once
#ifndef LUNCLIFF_COROUTINE_CHANNEL_HPP
#define LUNCLIFF_COROUTINE_CHANNEL_HPP
#include <cstddef>
#include <cstdint>
#include <memory>
#include <mutex>
#include <tuple>
#include <type_traits>
#include <utility>
#if __has_include(<coroutine/frame.h>) && !defined(USE_EXPERIMENTAL_COROUTINE)
#include <coroutine/frame.h>
namespace coro {
using std::coroutine_handle;
using std::suspend_always;
using std::suspend_never;
#elif __has_include(<experimental/coroutine>)
#include <experimental/coroutine>
namespace coro {
using std::experimental::coroutine_handle;
using std::experimental::suspend_always;
using std::experimental::suspend_never;
#else
#error "requires header <experimental/coroutine> or <coroutine/frame.h>"
#endif
// A lockable type whose operations do nothing.
// It is the default mutex parameter of `channel`, so a channel instantiated
// without an explicit mutex type performs no synchronization at all.
struct bypass_mutex final {
    // Nothing to acquire.
    constexpr void lock() noexcept {}

    // Nothing to acquire, so the attempt always reports success.
    constexpr bool try_lock() noexcept { return true; }

    // Nothing was acquired, so there is nothing to release.
    constexpr void unlock() noexcept {}
};
namespace internal {
// Returns the sentinel address that marks a waiting reader/writer as
// "woken because the channel is being destroyed".
// The channel destructor stores this value in the awaiter's `frame` before
// resuming it, and the awaiters' `await_resume` compare `frame` against it.
// The value is only ever compared; it is never dereferenced or resumed.
static void* poison() noexcept(false) {
    // The tag is a 64-bit constant. Narrow it to pointer width explicitly so
    // that the integer-to-pointer cast below is always between equally sized
    // types, including on targets whose pointers are narrower than 64 bits.
    constexpr auto tag = static_cast<std::uintptr_t>(0xFADE'038C'BCFA'9E64ULL);
    return reinterpret_cast<void*>(tag);
}
// Minimal intrusive FIFO queue of `T` nodes.
// The queue owns nothing: it only links nodes through their `next` member,
// which `T` must provide (and make accessible to this class).
template <typename T>
class list {
    T* head{};
    T* tail{};

  public:
    // True when no node is queued.
    bool is_empty() const noexcept(false) {
        return !head;
    }

    // Appends `node` at the back of the queue.
    // `node->next` is not modified here; only the previous tail is relinked.
    void push(T* node) noexcept(false) {
        if (tail == nullptr) {
            // first element: it is both the front and the back
            head = node;
            tail = node;
            return;
        }
        tail->next = node;
        tail = node;
    }

    // Detaches and returns the front node, or nullptr when the queue is empty.
    auto pop() noexcept(false) -> T* {
        T* const front = head;
        if (head != tail)
            head = head->next; // two or more nodes: advance the front
        else
            head = tail = nullptr; // zero or one node: queue becomes empty
        return front;
    }
};
} // namespace internal
// Forward declarations of the channel and its helper types.
//
// channel<T, M>: passes values of type T from writer coroutines to reader
// coroutines. M is the mutex type guarding the channel's waiting lists.
// M defaults to bypass_mutex, whose lock/unlock do nothing, so a channel
// declared without an explicit mutex type is not protected against races.
template <typename T, typename M = bypass_mutex>
class channel; // by default, channel doesn't care about the race condition
// Awaitable returned by channel::read().
template <typename T, typename M>
class channel_reader;
// Awaitable returned by channel::write().
template <typename T, typename M>
class channel_writer;
// Non-suspending reader used by select().
template <typename T, typename M>
class channel_peeker;
// Awaitable that receives one value from a channel.
//
// Instances are created by channel::read() (the constructor is protected and
// the type is neither copyable nor movable). The result of awaiting it is a
// std::tuple<value_type, bool>: the flag is true when a value was taken from
// a writer, and false when the reader was woken by the channel's destructor.
//
// Locking protocol: await_ready() always locks the channel mutex. When a
// writer is already waiting, await_ready() unlocks before returning true.
// Otherwise it returns false with the mutex still locked, and await_suspend()
// unlocks it after the reader has been enqueued.
template <typename T, typename M>
class channel_reader {
public:
using value_type = T;
using pointer = T*;
using reference = T&;
using channel_type = channel<T, M>;
private:
using reader_list = typename channel_type::reader_list;
using writer = typename channel_type::writer;
using writer_list = typename channel_type::writer_list;
friend channel_type;
friend writer;
friend reader_list;
protected:
// Address of the value to receive. Starts as nullptr and is exchanged with
// the `ptr` of the writer this reader is paired with.
mutable pointer ptr;
// Coroutine frame address exchanged with the paired writer. May also be
// nullptr, or internal::poison() when the channel is being destroyed.
mutable void* frame;
// `chan` and `next` share storage. `chan` is set by the constructor and is
// valid until await_suspend() overwrites it with `next`, the link used while
// this reader sits in the channel's reader list.
union {
channel_reader* next = nullptr;
channel_type* chan;
};
protected:
// Binds the reader to channel `ch`; no value address or frame is known yet.
explicit channel_reader(channel_type& ch) noexcept(false)
: ptr{}, frame{nullptr}, chan{std::addressof(ch)} {
}
channel_reader(const channel_reader&) noexcept = delete;
channel_reader& operator=(const channel_reader&) noexcept = delete;
channel_reader(channel_reader&&) noexcept = delete;
channel_reader& operator=(channel_reader&&) noexcept = delete;
public:
~channel_reader() noexcept = default;
public:
// Locks the channel and tries to pair with a waiting writer.
// Returns true (mutex released) when a writer was popped and its `ptr` and
// `frame` were exchanged with this reader's.
// Returns false (mutex still held) when no writer is waiting.
bool await_ready() const noexcept(false) {
chan->mtx.lock();
if (chan->writer_list::is_empty())
// await_suspend will unlock in this case
return false;
writer* w = chan->writer_list::pop();
// exchange value address & coroutine frame address with the writer
std::swap(this->ptr, w->ptr);
std::swap(this->frame, w->frame);
chan->mtx.unlock();
return true;
}
// Enqueues this reader in the channel's reader list and releases the mutex
// that await_ready() left locked. `coro` is the suspending coroutine.
void await_suspend(coroutine_handle<void> coro) noexcept(false) {
// `next` and `chan` share memory: read `chan` before `next` is written
channel_type& ch = *(this->chan);
// remember handle before push/unlock
this->frame = coro.address();
this->next = nullptr;
// push to channel
ch.reader_list::push(this);
ch.mtx.unlock();
}
// Produces the result of the await.
// Returns {value_type{}, false} when `frame` holds the poison value.
// Otherwise moves the value out of `*ptr`, resumes the coroutine whose
// address is stored in `frame` (if it is not null), and returns {value, true}.
auto await_resume() noexcept(false) -> std::tuple<value_type, bool> {
auto t = std::make_tuple(value_type{}, false);
// frame holds poison if the channel is under destruction
if (this->frame == internal::poison())
return t;
// the resume operation can destroy the other coroutine
// store before resume
std::get<0>(t) = std::move(*ptr);
if (auto coro = coroutine_handle<void>::from_address(frame))
coro.resume();
std::get<1>(t) = true;
return t;
}
};
// Awaitable that sends one value through a channel.
//
// Instances are created by channel::write() (the constructor is private and
// the type is neither copyable nor movable). The result of awaiting it is a
// bool: true when the writer was paired with a reader, false when the writer
// was woken by the channel's destructor.
//
// Locking protocol: await_ready() always locks the channel mutex. When a
// reader is already waiting, await_ready() unlocks before returning true.
// Otherwise it returns false with the mutex still locked, and await_suspend()
// unlocks it after the writer has been enqueued.
template <typename T, typename M>
class channel_writer {
public:
using value_type = T;
using pointer = T*;
using reference = T&;
using channel_type = channel<T, M>;
private:
using reader = typename channel_type::reader;
using reader_list = typename channel_type::reader_list;
using writer_list = typename channel_type::writer_list;
using peeker = typename channel_type::peeker;
friend channel_type;
friend reader;
friend writer_list;
friend peeker; // for `peek()` implementation
private:
// Address of the value being sent; exchanged with the paired reader's `ptr`.
mutable pointer ptr;
// Coroutine frame address exchanged with the paired reader. May also be
// nullptr, or internal::poison() when the channel is being destroyed.
mutable void* frame;
// `chan` and `next` share storage. `chan` is set by the constructor and is
// valid until await_suspend() overwrites it with `next`, the link used while
// this writer sits in the channel's writer list.
union {
channel_writer* next = nullptr;
channel_type* chan;
};
private:
// Binds the writer to channel `ch` and to the value located at `pv`.
explicit channel_writer(channel_type& ch, pointer pv) noexcept(false)
: ptr{pv}, frame{nullptr}, chan{std::addressof(ch)} {
}
channel_writer(const channel_writer&) noexcept = delete;
channel_writer& operator=(const channel_writer&) noexcept = delete;
channel_writer(channel_writer&&) noexcept = delete;
channel_writer& operator=(channel_writer&&) noexcept = delete;
public:
~channel_writer() noexcept = default;
public:
// Locks the channel and tries to pair with a waiting reader.
// Returns true (mutex released) when a reader was popped and its `ptr` and
// `frame` were exchanged with this writer's.
// Returns false (mutex still held) when no reader is waiting.
bool await_ready() const noexcept(false) {
chan->mtx.lock();
if (chan->reader_list::is_empty())
// await_suspend will unlock in this case
return false;
reader* r = chan->reader_list::pop();
// exchange value address & coroutine frame address with the reader
std::swap(this->ptr, r->ptr);
std::swap(this->frame, r->frame);
chan->mtx.unlock();
return true;
}
// Enqueues this writer in the channel's writer list and releases the mutex
// that await_ready() left locked. `coro` is the suspending coroutine.
void await_suspend(coroutine_handle<void> coro) noexcept(false) {
// `next` and `chan` share memory: read `chan` before `next` is written
channel_type& ch = *(this->chan);
this->frame = coro.address(); // remember handle before push/unlock
this->next = nullptr; // clear the link before joining the list
ch.writer_list::push(this); // push to channel
ch.mtx.unlock();
}
// Produces the result of the await.
// Returns false when `frame` holds the poison value. Otherwise resumes the
// coroutine whose address is stored in `frame` (if not null) and returns true.
bool await_resume() noexcept(false) {
// frame holds poison if the channel is under destruction
if (this->frame == internal::poison())
return false;
if (auto coro = coroutine_handle<void>::from_address(frame))
coro.resume();
return true;
}
};
// Channel that pairs writer coroutines with reader coroutines.
//
// The channel privately inherits two intrusive lists: one of waiting readers
// and one of waiting writers. `mtx` (of type M) is locked by the awaiters and
// by the destructor around accesses to those lists. The channel is neither
// copyable nor movable, and T must not be a reference type.
template <typename T, typename M>
class channel final : internal::list<channel_reader<T, M>>,
internal::list<channel_writer<T, M>> {
static_assert(std::is_reference<T>::value == false,
"reference type can't be channel's value_type.");
public:
using value_type = T;
using pointer = value_type*;
using reference = value_type&;
using mutex_type = M;
private:
using reader = channel_reader<value_type, mutex_type>;
using reader_list = internal::list<reader>;
using writer = channel_writer<value_type, mutex_type>;
using writer_list = internal::list<writer>;
using peeker = channel_peeker<value_type, mutex_type>;
friend reader;
friend writer;
friend peeker; // for `peek()` implementation
private:
mutex_type mtx{};
private:
channel(const channel&) noexcept(false) = delete;
channel(channel&&) noexcept(false) = delete;
channel& operator=(const channel&) noexcept(false) = delete;
channel& operator=(channel&&) noexcept(false) = delete;
public:
// Creates a channel with no waiting readers or writers.
channel() noexcept(false) : reader_list{}, writer_list{}, mtx{} {
}
// Wakes every coroutine still waiting on the channel.
// Each waiting writer, then each waiting reader, is popped, has its `frame`
// replaced by the poison value, and is resumed through the handle that was
// stored in `frame`; the awaiter's await_resume() then reports failure.
// The coroutines are resumed while `lck` holds the channel mutex.
~channel() noexcept(false) {
void* closing = internal::poison();
writer_list& writers = *this;
reader_list& readers = *this;
// even 5'000+ can be unsafe for hazard usage ...
// With repeat = 1 the do/while body below runs twice.
// NOTE(review): presumably the extra pass is meant to catch coroutines
// enqueued after the first pass released the mutex -- confirm.
size_t repeat = 1;
do {
std::unique_lock lck{mtx};
while (writers.is_empty() == false) {
writer* w = writers.pop();
// take the handle first: `frame` is about to be overwritten
auto coro = coroutine_handle<void>::from_address(w->frame);
w->frame = closing;
coro.resume();
}
while (readers.is_empty() == false) {
reader* r = readers.pop();
// take the handle first: `frame` is about to be overwritten
auto coro = coroutine_handle<void>::from_address(r->frame);
r->frame = closing;
coro.resume();
}
} while (repeat--);
}
public:
// Returns an awaitable that sends the object referred to by `ref`.
// Only the address of `ref` is stored in the returned writer.
decltype(auto) write(reference ref) noexcept(false) {
return channel_writer{*this, std::addressof(ref)};
}
// Returns an awaitable that receives a value; awaiting it yields
// std::tuple<value_type, bool>.
decltype(auto) read() noexcept(false) {
return channel_reader{*this};
}
};
// Reader variant that never suspends.
// peek() takes a waiting writer from the channel if there is one, and
// acquire() then moves that writer's value out and resumes the writer.
// Used by select().
template <typename T, typename M>
class channel_peeker final : protected channel_reader<T, M> {
    using channel_type = channel<T, M>;
    using writer = typename channel_type::writer;
    using writer_list = typename channel_type::writer_list;

  private:
    channel_peeker(const channel_peeker&) noexcept(false) = delete;
    channel_peeker(channel_peeker&&) noexcept(false) = delete;
    channel_peeker& operator=(const channel_peeker&) noexcept(false) = delete;
    channel_peeker& operator=(channel_peeker&&) noexcept(false) = delete;

  public:
    // Binds the peeker to channel `ch`.
    explicit channel_peeker(channel_type& ch) noexcept(false)
        : channel_reader<T, M>{ch} {
    }
    ~channel_peeker() noexcept = default;

  public:
    // Under the channel mutex, pops the first waiting writer (if any) and
    // exchanges its value address and frame address with this peeker's.
    // When no writer is waiting, the peeker is left untouched.
    void peek() const noexcept(false) {
        channel_type& ch = *(this->chan);
        // no suspension happens here, so a scoped lock is sufficient
        std::unique_lock guard{ch.mtx};
        if (ch.writer_list::is_empty())
            return;
        writer* const pending = ch.writer_list::pop();
        std::swap(this->ptr, pending->ptr);
        std::swap(this->frame, pending->frame);
    }

    // Moves the value obtained by peek() into `storage` and resumes the
    // writer it came from. Returns false, leaving `storage` untouched, when
    // no value address is held.
    bool acquire(T& storage) noexcept(false) {
        if (this->ptr != nullptr) {
            storage = std::move(*this->ptr);
            // the value is stored; now the writer coroutine may continue
            auto coro = coroutine_handle<void>::from_address(this->frame);
            if (coro)
                coro.resume();
            return true;
        }
        return false;
    }
};
// Non-blocking receive on a single channel.
// If a writer is currently waiting on `ch`, its value is moved into a local
// object, the writer is resumed, and `fn` is invoked with that local object.
// If no writer is waiting, `fn` is not invoked.
template <typename T, typename M, typename Fn>
void select(channel<T, M>& ch, Fn&& fn) noexcept(false) {
    static_assert(sizeof(channel_reader<T, M>) == sizeof(channel_peeker<T, M>));
    channel_peeker peeker{ch}; // takes the element out of the channel
    T value{};                 // destination living on the call stack
    peeker.peek();             // is there a waiting writer?
    const bool received = peeker.acquire(value); // move value + resume writer
    if (!received)
        return;
    fn(value);
}
// Non-blocking receive on several channels.
// Arguments are (channel, function) pairs. The first pair is evaluated, then
// the remaining arguments are handed to select() again, so the pairs are
// processed in the order they were given.
template <typename... Args, typename Ch, typename Fn>
void select(Ch& ch, Fn&& fn, Args&&... args) noexcept(false) {
    select(ch, std::forward<Fn>(fn));    // evaluate the leading pair
    select(std::forward<Args>(args)...); // continue with the next pair
}
} // namespace coro
#endif // LUNCLIFF_COROUTINE_CHANNEL_HPP