File channel.hpp

File List > coroutine > channel.hpp

Go to the documentation of this file.

#pragma once
#ifndef LUNCLIFF_COROUTINE_CHANNEL_HPP
#define LUNCLIFF_COROUTINE_CHANNEL_HPP
#include <cstdint>
#include <memory>
#include <mutex>
#include <tuple>
#include <type_traits>
#include <utility>

#if __has_include(<coroutine/frame.h>) && !defined(USE_EXPERIMENTAL_COROUTINE)
#include <coroutine/frame.h>
namespace coro {
using std::coroutine_handle;
using std::suspend_always;
using std::suspend_never;

#elif __has_include(<experimental/coroutine>)
#include <experimental/coroutine>
namespace coro {
using std::experimental::coroutine_handle;
using std::experimental::suspend_always;
using std::experimental::suspend_never;

#else
#error "requires header <experimental/coroutine> or <coroutine/frame.h>"
#endif

// Lockable type whose operations do nothing: `lock`/`unlock` are empty and
// `try_lock` always succeeds. It is the default `M` of `channel`, which means
// a default channel performs no synchronization at all.
struct bypass_mutex final {
    constexpr void lock() noexcept {}
    constexpr void unlock() noexcept {}
    constexpr bool try_lock() noexcept { return true; }
};

namespace internal {

// Magic address stored into a waiter's `frame` by `channel`'s destructor.
// The awaiters compare against it in `await_resume` to learn that the
// channel is closing; it is only ever compared, never resumed.
//
// `inline` (rather than `static`) gives one definition across translation
// units and avoids "defined but not used" warnings in TUs that include this
// header without using it. The integer is converted through
// `std::uintptr_t` so the integer-to-pointer cast is explicit on every
// target (on 32-bit targets the value is truncated).
inline void* poison() noexcept(false) {
    return reinterpret_cast<void*>(
        static_cast<std::uintptr_t>(0xFADE'038C'BCFA'9E64));
}

// Intrusive singly-linked FIFO queue.
// `T` must have a `T* next` member. The list does not own its nodes; it only
// links them through `next`. `push` does not reset `node->next`, so the
// caller must clear it beforehand. `pop` leaves the popped node's `next`
// untouched.
template <typename T>
class list {
    T* head{};
    T* tail{};

  public:
    bool is_empty() const noexcept(false) {
        return head == nullptr;
    }
    // append `node` at the tail
    void push(T* node) noexcept(false) {
        if (tail) {
            tail->next = node;
            tail = node;
        } else
            head = tail = node;
    }
    // detach and return the head node (nullptr when the list is empty)
    auto pop() noexcept(false) -> T* {
        T* node = head;
        if (head == tail) // empty or 1
            head = tail = nullptr;
        else // 2 or more
            head = head->next;
        return node;
    }
};
} // namespace internal

// Awaitable handed out by `channel::read()`.
template <typename T, typename M>
class channel_reader;
// Awaitable handed out by `channel::write()`.
template <typename T, typename M>
class channel_writer;
// Non-suspending reader used by `select()`.
template <typename T, typename M>
class channel_peeker;

// With the default `bypass_mutex` the channel performs no locking, so it
// does not protect against races; supply a real mutex type as `M` when the
// channel is shared between threads.
template <typename T, typename M = bypass_mutex>
class channel;

// Awaitable returned by `channel::read()`.
//
// `co_await ch.read()` yields `std::tuple<T, bool>`: the received value and
// `true` on success, or a value-initialized `T` and `false` when the channel
// was destroyed while this reader was queued.
//
// Locking: `await_ready` locks `chan->mtx`. When it returns false the lock
// is still held, and `await_suspend` releases it only after this reader has
// been queued, so the check and the enqueue happen under one lock.
template <typename T, typename M>
class channel_reader {
  public:
    using value_type = T;
    using pointer = T*;
    using reference = T&;
    using channel_type = channel<T, M>;

  private:
    using reader_list = typename channel_type::reader_list;
    using writer = typename channel_type::writer;
    using writer_list = typename channel_type::writer_list;

    // the channel (destructor) and writers exchange `ptr`/`frame` with
    // readers; the intrusive list links readers through `next`
    friend channel_type;
    friend writer;
    friend reader_list;

  protected:
    // address of the value to receive; points at the matched writer's
    // value once a writer has been paired with this reader
    mutable pointer ptr;
    // coroutine address swapped with the matched writer: this reader's own
    // frame while it is queued, then the writer's frame (or nullptr) after
    // the match, or internal::poison() when the channel is being destroyed
    mutable void* frame;
    // `chan` is valid until `await_suspend`; after that the same storage is
    // reused as the `next` link of the channel's reader list
    union {
        channel_reader* next = nullptr;
        channel_type* chan;
    };

  protected:
    // only `channel` (a friend) and derived classes create readers
    explicit channel_reader(channel_type& ch) noexcept(false)
        : ptr{}, frame{nullptr}, chan{std::addressof(ch)} {
    }
    // neither copyable nor movable: the channel stores this object's address
    // while it is queued
    channel_reader(const channel_reader&) noexcept = delete;
    channel_reader& operator=(const channel_reader&) noexcept = delete;
    channel_reader(channel_reader&&) noexcept = delete;
    channel_reader& operator=(channel_reader&&) noexcept = delete;

  public:
    ~channel_reader() noexcept = default;

  public:
    // Locks the channel. If a writer is queued, pairs with it (swapping
    // `ptr` and `frame`), unlocks and returns true so the coroutine does not
    // suspend. Otherwise returns false with the mutex still LOCKED; it is
    // released by `await_suspend`.
    bool await_ready() const noexcept(false) {
        chan->mtx.lock();
        if (chan->writer_list::is_empty())
            // await_suspend will unlock in the case
            return false;

        writer* w = chan->writer_list::pop();
        // exchange address & resumeable_handle
        std::swap(this->ptr, w->ptr);
        std::swap(this->frame, w->frame);

        chan->mtx.unlock();
        return true;
    }
    // Queues this reader on the channel, then releases the mutex that
    // `await_ready` left locked.
    void await_suspend(coroutine_handle<void> coro) noexcept(false) {
        // notice that next & chan are sharing memory: read `chan` before
        // `next` is written below
        channel_type& ch = *(this->chan);
        // remember handle before push/unlock
        this->frame = coro.address();
        this->next = nullptr;
        // push to channel
        ch.reader_list::push(this);
        ch.mtx.unlock();
    }
    // Returns {T{}, false} if the channel poisoned `frame` during
    // destruction. Otherwise moves the value out of the writer, resumes the
    // writer if `frame` holds its handle, and returns {value, true}.
    auto await_resume() noexcept(false) -> std::tuple<value_type, bool> {
        auto t = std::make_tuple(value_type{}, false);
        // frame holds poison if the channel is under destruction
        if (this->frame == internal::poison())
            return t;
        // the resume operation can destroy the other coroutine
        // store before resume
        std::get<0>(t) = std::move(*ptr);
        if (auto coro = coroutine_handle<void>::from_address(frame))
            coro.resume();
        std::get<1>(t) = true;
        return t;
    }
};

// Awaitable returned by `channel::write(ref)`.
//
// `co_await ch.write(v)` yields `true` once a reader has taken the value,
// or `false` when the channel was destroyed while this writer was queued.
// Only the address of `v` is stored, so `v` must outlive the operation.
//
// Locking: `await_ready` locks `chan->mtx`. When it returns false the lock
// is still held, and `await_suspend` releases it only after this writer has
// been queued.
template <typename T, typename M>
class channel_writer {
  public:
    using value_type = T;
    using pointer = T*;
    using reference = T&;
    using channel_type = channel<T, M>;

  private:
    using reader = typename channel_type::reader;
    using reader_list = typename channel_type::reader_list;
    using writer_list = typename channel_type::writer_list;
    using peeker = typename channel_type::peeker;

    // readers, the peeker and the channel (destructor) touch `ptr`/`frame`;
    // the intrusive list links writers through `next`
    friend channel_type;
    friend reader;
    friend writer_list;
    friend peeker; // for `peek()` implementation

  private:
    // address of the value being sent; after pairing with a reader it holds
    // whatever the reader's `ptr` was
    mutable pointer ptr;
    // coroutine address swapped with the matched reader: this writer's own
    // frame while it is queued, then the reader's frame (or nullptr) after
    // the match, or internal::poison() when the channel is being destroyed
    mutable void* frame;
    // `chan` is valid until `await_suspend`; after that the same storage is
    // reused as the `next` link of the channel's writer list
    union {
        channel_writer* next = nullptr;
        channel_type* chan;
    };

  private:
    // only `channel` (a friend) creates writers
    explicit channel_writer(channel_type& ch, pointer pv) noexcept(false)
        : ptr{pv}, frame{nullptr}, chan{std::addressof(ch)} {
    }
    // neither copyable nor movable: the channel stores this object's address
    // while it is queued
    channel_writer(const channel_writer&) noexcept = delete;
    channel_writer& operator=(const channel_writer&) noexcept = delete;
    channel_writer(channel_writer&&) noexcept = delete;
    channel_writer& operator=(channel_writer&&) noexcept = delete;

  public:
    ~channel_writer() noexcept = default;

  public:
    // Locks the channel. If a reader is queued, pairs with it (swapping
    // `ptr` and `frame`), unlocks and returns true so the coroutine does not
    // suspend. Otherwise returns false with the mutex still LOCKED; it is
    // released by `await_suspend`.
    bool await_ready() const noexcept(false) {
        chan->mtx.lock();
        if (chan->reader_list::is_empty())
            // await_suspend will unlock in the case
            return false;

        reader* r = chan->reader_list::pop();
        // exchange address & resumeable_handle
        std::swap(this->ptr, r->ptr);
        std::swap(this->frame, r->frame);

        chan->mtx.unlock();
        return true;
    }
    // Queues this writer on the channel, then releases the mutex that
    // `await_ready` left locked.
    void await_suspend(coroutine_handle<void> coro) noexcept(false) {
        // notice that next & chan are sharing memory: read `chan` before
        // `next` is written below
        channel_type& ch = *(this->chan);

        this->frame = coro.address(); // remember handle before push/unlock
        this->next = nullptr;         // clear to prevent confusing

        ch.writer_list::push(this); // push to channel
        ch.mtx.unlock();
    }
    // Returns false if the channel poisoned `frame` during destruction.
    // Otherwise resumes the matched reader if `frame` holds its handle (the
    // reader then moves the value out) and returns true.
    bool await_resume() noexcept(false) {
        // frame holds poison if the channel is under destruction
        if (this->frame == internal::poison())
            return false;
        if (auto coro = coroutine_handle<void>::from_address(frame))
            coro.resume();
        return true;
    }
};

// Unbuffered (rendezvous) channel for coroutines.
//
// `co_await ch.write(v)` and `co_await ch.read()` complete in pairs: a
// writer waits until a reader takes its value and vice versa, unless the
// channel is destroyed first. Waiting readers and writers are kept in two
// intrusive FIFO lists (the private bases) and guarded by `mtx`. With the
// default `bypass_mutex` nothing is actually locked.
template <typename T, typename M>
class channel final : internal::list<channel_reader<T, M>>,
                      internal::list<channel_writer<T, M>> {
    static_assert(std::is_reference<T>::value == false,
                  "reference type can't be channel's value_type.");

  public:
    using value_type = T;
    using pointer = value_type*;
    using reference = value_type&;
    using mutex_type = M;

  private:
    using reader = channel_reader<value_type, mutex_type>;
    using reader_list = internal::list<reader>;
    using writer = channel_writer<value_type, mutex_type>;
    using writer_list = internal::list<writer>;
    using peeker = channel_peeker<value_type, mutex_type>;

    // the awaiters lock `mtx` and manipulate the waiter lists directly
    friend reader;
    friend writer;
    friend peeker; // for `peek()` implementation

  private:
    mutex_type mtx{};

  private:
    // waiters hold the channel's address, so it can be neither copied nor
    // moved
    channel(const channel&) noexcept(false) = delete;
    channel(channel&&) noexcept(false) = delete;
    channel& operator=(const channel&) noexcept(false) = delete;
    channel& operator=(channel&&) noexcept(false) = delete;

  public:
    channel() noexcept(false) : reader_list{}, writer_list{}, mtx{} {
    }

    // Wakes every queued writer and reader after storing internal::poison()
    // in its `frame`, so their `await_resume` reports failure (`false`).
    // NOTE(review): waiters are resumed while `mtx` is held; with a
    // non-recursive mutex (e.g. std::mutex) a resumed coroutine that touches
    // this channel again would try to lock `mtx` a second time.
    ~channel() noexcept(false) {
        void* closing = internal::poison();
        writer_list& writers = *this;
        reader_list& readers = *this;
        // even 5'000+ can be unsafe for hazard usage ...
        // `repeat` is post-decremented, so the drain below runs twice;
        // presumably to catch waiters that re-queue while being resumed
        size_t repeat = 1;
        do {
            std::unique_lock lck{mtx};
            while (writers.is_empty() == false) {
                writer* w = writers.pop();
                // take the handle before `frame` is overwritten with poison
                auto coro = coroutine_handle<void>::from_address(w->frame);
                w->frame = closing;

                coro.resume();
            }
            while (readers.is_empty() == false) {
                reader* r = readers.pop();
                // take the handle before `frame` is overwritten with poison
                auto coro = coroutine_handle<void>::from_address(r->frame);
                r->frame = closing;

                coro.resume();
            }
        } while (repeat--);
    }

  public:
    // Returns an awaitable that sends `ref`. Only the address of `ref` is
    // stored, so the referenced object must stay alive until it is awaited.
    decltype(auto) write(reference ref) noexcept(false) {
        return channel_writer{*this, std::addressof(ref)};
    }
    // Returns an awaitable that yields `std::tuple<T, bool>`.
    decltype(auto) read() noexcept(false) {
        return channel_reader{*this};
    }
};

// Non-suspending reader used by `select()`.
//
// `peek()` takes at most one queued writer off the channel, and `acquire()`
// moves that writer's value into caller storage and resumes the writer.
// Neither operation suspends the calling coroutine, and the peeker is never
// queued on the channel, so its `chan` member stays valid for its lifetime.
template <typename T, typename M>
class channel_peeker final : protected channel_reader<T, M> {
    using channel_type = channel<T, M>;
    using writer = typename channel_type::writer;
    using writer_list = typename channel_type::writer_list;

  private:
    channel_peeker(const channel_peeker&) noexcept(false) = delete;
    channel_peeker(channel_peeker&&) noexcept(false) = delete;
    channel_peeker& operator=(const channel_peeker&) noexcept(false) = delete;
    channel_peeker& operator=(channel_peeker&&) noexcept(false) = delete;

  public:
    explicit channel_peeker(channel_type& ch) noexcept(false)
        : channel_reader<T, M>{ch} {
    }
    ~channel_peeker() noexcept = default;

  public:
    // Pair with one queued writer, if any, by swapping `ptr`/`frame` with it.
    // If a writer is already matched and its value has not been acquired yet,
    // this is a no-op: popping a second writer would overwrite the first
    // match, and the first writer would report success although its value
    // was never read.
    void peek() const noexcept(false) {
        if (this->ptr != nullptr)
            return;
        std::unique_lock lck{this->chan->mtx};
        if (this->chan->writer_list::is_empty() == false) {
            writer* w = this->chan->writer_list::pop();
            std::swap(this->ptr, w->ptr);
            std::swap(this->frame, w->frame);
        }
    }
    // Move the matched writer's value into `storage` and resume that writer.
    // Returns false when `peek()` matched no writer or the value was already
    // acquired. The match is cleared before the writer is resumed, so a
    // repeated call can neither re-read the writer's (possibly destroyed)
    // value nor resume its (possibly finished) coroutine a second time.
    bool acquire(T& storage) noexcept(false) {
        // if there was a writer, take its value
        if (this->ptr == nullptr)
            return false;
        storage = std::move(*this->ptr);
        // the value belongs to the writer; never touch it after this point
        this->ptr = nullptr;
        // the resume operation can destroy the writer's frame: forget it first
        void* const writer_frame = std::exchange(this->frame, nullptr);
        if (auto coro = coroutine_handle<void>::from_address(writer_frame))
            coro.resume();
        return true;
    }
};

// Non-blocking receive from `ch`: if a writer is already waiting, move its
// value onto this call stack, resume the writer, then call `fn` with the
// value. If no writer is waiting, return without calling `fn`.
template <typename T, typename M, typename Fn>
void select(channel<T, M>& ch, Fn&& fn) noexcept(false) {
    // sanity check: the peeker adds no members on top of channel_reader
    static_assert(sizeof(channel_reader<T, M>) == sizeof(channel_peeker<T, M>));

    channel_peeker<T, M> probe{ch};
    T received{};

    probe.peek();                 // grab a waiting writer, if there is one
    if (!probe.acquire(received)) // nothing was waiting
        return;
    fn(received);                 // the writer has already been resumed
}

// Try several (channel, function) pairs in order:
// `select(ch1, fn1, ch2, fn2, ...)` runs `select(chN, fnN)` once for each
// pair, so every channel that has a waiting writer delivers one value to its
// function, and channels without a writer are skipped.
template <typename... Args, typename Ch, typename Fn>
void select(Ch& ch, Fn&& fn, Args&&... args) noexcept(false) {
    // qualified std::forward instead of a `using namespace std` directive in
    // a header; forward<Fn> yields the same reference type as forward<Fn&&>
    select(ch, std::forward<Fn>(fn));           // evaluate this pair
    return select(std::forward<Args>(args)...); // try next pair
}

} // namespace coro

#endif // LUNCLIFF_COROUTINE_CHANNEL_HPP