浏览代码

完善协程支持stop_token相关功能

tags/2.9.10
tearshark 4 年前
父节点
当前提交
ba29351ddd

+ 1
- 1
benchmark/benchmark_async_mem.cpp 查看文件

#include "librf.h" #include "librf.h"
const size_t N = 10000000;
const size_t N = 5000000;
const size_t LOOP_COUNT = 50; const size_t LOOP_COUNT = 50;
std::atomic<size_t> globalValue{0}; std::atomic<size_t> globalValue{0};

+ 11
- 8
librf/src/awaitable.h 查看文件

*/ */
void set_exception(std::exception_ptr e) const void set_exception(std::exception_ptr e) const
{ {
this->_state->set_exception(std::move(e));
this->_state = nullptr;
counted_ptr<state_type> cp(std::move(this->_state));
cp->set_exception(std::move(e));
} }


/** /**
struct [[nodiscard]] awaitable_t : public awaitable_impl_t<_Ty> struct [[nodiscard]] awaitable_t : public awaitable_impl_t<_Ty>
{ {
using typename awaitable_impl_t<_Ty>::value_type; using typename awaitable_impl_t<_Ty>::value_type;
using typename awaitable_impl_t<_Ty>::state_type;
using awaitable_impl_t<_Ty>::awaitable_impl_t; using awaitable_impl_t<_Ty>::awaitable_impl_t;


/** /**
template<class U> template<class U>
void set_value(U&& value) const void set_value(U&& value) const
{ {
this->_state->set_value(std::forward<U>(value));
this->_state = nullptr;
counted_ptr<state_type> cp(std::move(this->_state));
cp->set_value(std::forward<U>(value));
} }
}; };


struct [[nodiscard]] awaitable_t<_Ty&> : public awaitable_impl_t<_Ty&> struct [[nodiscard]] awaitable_t<_Ty&> : public awaitable_impl_t<_Ty&>
{ {
using typename awaitable_impl_t<_Ty&>::value_type; using typename awaitable_impl_t<_Ty&>::value_type;
using typename awaitable_impl_t<_Ty&>::state_type;
using awaitable_impl_t<_Ty&>::awaitable_impl_t; using awaitable_impl_t<_Ty&>::awaitable_impl_t;


void set_value(_Ty& value) const void set_value(_Ty& value) const
{ {
this->_state->set_value(value);
this->_state = nullptr;
counted_ptr<state_type> cp(std::move(this->_state));
cp->set_value(value);
} }
}; };


template<> template<>
struct [[nodiscard]] awaitable_t<void> : public awaitable_impl_t<void> struct [[nodiscard]] awaitable_t<void> : public awaitable_impl_t<void>
{ {
using awaitable_impl_t<void>::state_type;
using awaitable_impl_t<void>::awaitable_impl_t; using awaitable_impl_t<void>::awaitable_impl_t;


void set_value() const void set_value() const
{ {
this->_state->set_value();
this->_state = nullptr;
counted_ptr<state_type> cp(std::move(this->_state));
cp->set_value();
} }
}; };
#endif //DOXYGEN_SKIP_PROPERTY #endif //DOXYGEN_SKIP_PROPERTY

+ 2
- 2
librf/src/channel_v2.h 查看文件

#ifndef DOXYGEN_SKIP_PROPERTY #ifndef DOXYGEN_SKIP_PROPERTY
RESUMEF_REQUIRES(std::is_constructible_v<_Ty, U&&>) RESUMEF_REQUIRES(std::is_constructible_v<_Ty, U&&>)
#endif //DOXYGEN_SKIP_PROPERTY #endif //DOXYGEN_SKIP_PROPERTY
write_awaiter operator << (U&& val) const noexcept(std::is_move_constructible_v<U>);
write_awaiter operator << (U&& val) const noexcept(std::is_nothrow_move_constructible_v<U>);


/** /**
* @brief 在协程中向channel_t里写入一个数据。 * @brief 在协程中向channel_t里写入一个数据。
#ifndef DOXYGEN_SKIP_PROPERTY #ifndef DOXYGEN_SKIP_PROPERTY
RESUMEF_REQUIRES(std::is_constructible_v<_Ty, U&&>) RESUMEF_REQUIRES(std::is_constructible_v<_Ty, U&&>)
#endif //DOXYGEN_SKIP_PROPERTY #endif //DOXYGEN_SKIP_PROPERTY
write_awaiter write(U&& val) const noexcept(std::is_move_constructible_v<U>);
write_awaiter write(U&& val) const noexcept(std::is_nothrow_move_constructible_v<U>);




#ifndef DOXYGEN_SKIP_PROPERTY #ifndef DOXYGEN_SKIP_PROPERTY

+ 2
- 2
librf/src/channel_v2.inl 查看文件

template<class _Ty, bool _Optional, bool _OptimizationThread> template<class _Ty, bool _Optional, bool _OptimizationThread>
template<class U COMMA_RESUMEF_ENABLE_IF_TYPENAME()> RESUMEF_REQUIRES(std::is_constructible_v<_Ty, U&&>) template<class U COMMA_RESUMEF_ENABLE_IF_TYPENAME()> RESUMEF_REQUIRES(std::is_constructible_v<_Ty, U&&>)
typename channel_t<_Ty, _Optional, _OptimizationThread>::write_awaiter typename channel_t<_Ty, _Optional, _OptimizationThread>::write_awaiter
channel_t<_Ty, _Optional, _OptimizationThread>::write(U&& val) const noexcept(std::is_move_constructible_v<U>)
channel_t<_Ty, _Optional, _OptimizationThread>::write(U&& val) const noexcept(std::is_nothrow_move_constructible_v<U>)
{ {
return write_awaiter{ _chan.get(), std::forward<U>(val) }; return write_awaiter{ _chan.get(), std::forward<U>(val) };
} }
template<class _Ty, bool _Optional, bool _OptimizationThread> template<class _Ty, bool _Optional, bool _OptimizationThread>
template<class U COMMA_RESUMEF_ENABLE_IF_TYPENAME()> RESUMEF_REQUIRES(std::is_constructible_v<_Ty, U&&>) template<class U COMMA_RESUMEF_ENABLE_IF_TYPENAME()> RESUMEF_REQUIRES(std::is_constructible_v<_Ty, U&&>)
typename channel_t<_Ty, _Optional, _OptimizationThread>::write_awaiter typename channel_t<_Ty, _Optional, _OptimizationThread>::write_awaiter
channel_t<_Ty, _Optional, _OptimizationThread>::operator << (U&& val) const noexcept(std::is_move_constructible_v<U>)
channel_t<_Ty, _Optional, _OptimizationThread>::operator << (U&& val) const noexcept(std::is_nothrow_move_constructible_v<U>)
{ {
return write_awaiter{ _chan.get(), std::forward<U>(val) }; return write_awaiter{ _chan.get(), std::forward<U>(val) };
} }

+ 23
- 8
librf/src/counted_ptr.h 查看文件

/** /**
* @brief 拷贝构造函数。 * @brief 拷贝构造函数。
*/ */
counted_ptr(const counted_ptr& cp) : _p(cp._p)
counted_ptr(const counted_ptr& cp) noexcept : _p(cp._p)
{ {
_lock(); _lock();
} }
/** /**
* @brief 通过裸指针构造一个计数指针。 * @brief 通过裸指针构造一个计数指针。
*/ */
counted_ptr(T* p) : _p(p)
counted_ptr(T* p) noexcept : _p(p)
{ {
_lock(); _lock();
} }
/** /**
* @brief 移动构造函数。 * @brief 移动构造函数。
*/ */
counted_ptr(counted_ptr&& cp) noexcept
counted_ptr(counted_ptr&& cp) noexcept : _p(std::exchange(cp._p, nullptr))
{ {
std::swap(_p, cp._p);
} }
/** /**
{ {
if (&cp != this) if (&cp != this)
{ {
counted_ptr t = cp;
counted_ptr t(cp);
std::swap(_p, t._p); std::swap(_p, t._p);
} }
return *this; return *this;
/** /**
* @brief 移动赋值函数。 * @brief 移动赋值函数。
*/ */
counted_ptr& operator=(counted_ptr&& cp) noexcept
counted_ptr& operator=(counted_ptr&& cp)
{ {
if (&cp != this) if (&cp != this)
{
std::swap(_p, cp._p); std::swap(_p, cp._p);
cp._unlock();
}
return *this; return *this;
} }
void swap(counted_ptr& cp) noexcept
{
std::swap(_p, cp._p);
}
/** /**
* @brief 析构函数中自动做一个计数减一操作。计数减为0,则删除state对象。 * @brief 析构函数中自动做一个计数减一操作。计数减为0,则删除state对象。
*/ */
t->unlock(); t->unlock();
} }
} }
void _lock(T* p)
void _lock(T* p) noexcept
{ {
if (p != nullptr) if (p != nullptr)
p->lock(); p->lock();
_p = p; _p = p;
} }
void _lock()
void _lock() noexcept
{ {
if (_p != nullptr) if (_p != nullptr)
_p->lock(); _p->lock();
} }
} }
namespace std
{
template<typename T>
inline void swap(resumef::counted_ptr<T>& a, resumef::counted_ptr<T>& b) noexcept
{
a.swap(b);
}
}

+ 1
- 7
librf/src/rf_task.cpp 查看文件

task_t::~task_t() task_t::~task_t()
{ {
///TODO : 这里有线程安全问题(2020/05/09)
_stop.clear_callback();
///TODO : 这里有线程安全问题(2020/05/09)
} }
const stop_source & task_t::get_stop_source() const stop_source & task_t::get_stop_source()
{ {
///TODO : 这里有线程安全问题(2020/05/09)
_stop.make_possible();
///TODO : 这里有线程安全问题(2020/05/09)
_stop.make_sure_possible();
return _stop; return _stop;
} }
} }

+ 25
- 3
librf/src/rf_task.h 查看文件

task_t(); task_t();
virtual ~task_t(); virtual ~task_t();
/// TODO : 存在BUG(2020/05/09)
/**
* @brief 获取stop_source,第一次获取时,会生成一个有效的stop_source。
* @return stop_source
*/
const stop_source & get_stop_source(); const stop_source & get_stop_source();
/// TODO : 存在BUG(2020/05/09)
/**
* @brief 获取一个跟stop_source绑定的,新的stop_token。
* @return stop_token
*/
stop_token get_stop_token() stop_token get_stop_token()
{ {
return get_stop_source().get_token(); return get_stop_source().get_token();
} }
/// TODO : 存在BUG(2020/05/09)
/**
* @brief 要求停止协程。
* @return bool 返回操作成功与否。
*/
bool request_stop() bool request_stop()
{ {
return get_stop_source().request_stop(); return get_stop_source().request_stop();
} }
/**
* @brief 要求停止协程。
* @return bool 返回操作成功与否。
*/
bool request_stop_if_possible()
{
if (_stop.stop_possible())
return _stop.request_stop();
return false;
}
protected: protected:
friend scheduler_t; friend scheduler_t;
counted_ptr<state_base_t> _state; counted_ptr<state_base_t> _state;

+ 10
- 10
librf/src/state.h 查看文件

private: private:
std::atomic<intptr_t> _count{0}; std::atomic<intptr_t> _count{0};
public: public:
void lock()
void lock() noexcept
{ {
++_count; ++_count;
} }
virtual bool has_handler() const noexcept = 0; virtual bool has_handler() const noexcept = 0;
virtual state_base_t* get_parent() const noexcept; virtual state_base_t* get_parent() const noexcept;
void set_scheduler(scheduler_t* sch)
void set_scheduler(scheduler_t* sch) noexcept
{ {
_scheduler = sch; _scheduler = sch;
} }
coroutine_handle<> get_handler() const
coroutine_handle<> get_handler() const noexcept
{ {
return _coro; return _coro;
} }
state_base_t* get_root() const noexcept
state_base_t* get_root() const
{ {
state_base_t* root = const_cast<state_base_t*>(this); state_base_t* root = const_cast<state_base_t*>(this);
state_base_t* next = root->get_parent(); state_base_t* next = root->get_parent();
bool switch_scheduler_await_suspend(scheduler_t* sch); bool switch_scheduler_await_suspend(scheduler_t* sch);
void set_initial_suspend(coroutine_handle<> handler)
void set_initial_suspend(coroutine_handle<> handler) noexcept
{ {
_coro = handler; _coro = handler;
} }
static_assert(sizeof(std::atomic<initor_type>) == 1); static_assert(sizeof(std::atomic<initor_type>) == 1);
static_assert(alignof(std::atomic<initor_type>) == 1); static_assert(alignof(std::atomic<initor_type>) == 1);
protected: protected:
explicit state_future_t(bool awaitor)
explicit state_future_t(bool awaitor) noexcept
{ {
#if RESUMEF_DEBUG_COUNTER #if RESUMEF_DEBUG_COUNTER
_id = ++g_resumef_state_id; _id = ++g_resumef_state_id;
return _alloc_size; return _alloc_size;
} }
inline bool future_await_ready() noexcept
inline bool future_await_ready() const noexcept
{ {
//scoped_lock<lock_type> __guard(this->_mtx); //scoped_lock<lock_type> __guard(this->_mtx);
return _has_value.load(std::memory_order_acquire) != result_type::None; return _has_value.load(std::memory_order_acquire) != result_type::None;
using state_future_t::lock_type; using state_future_t::lock_type;
using value_type = _Ty; using value_type = _Ty;
private: private:
explicit state_t(bool awaitor) :state_future_t(awaitor) {}
explicit state_t(bool awaitor) noexcept :state_future_t(awaitor) {}
public: public:
~state_t() ~state_t()
{ {
using value_type = _Ty; using value_type = _Ty;
using reference_type = _Ty&; using reference_type = _Ty&;
private: private:
explicit state_t(bool awaitor) :state_future_t(awaitor) {}
explicit state_t(bool awaitor) noexcept :state_future_t(awaitor) {}
public: public:
~state_t() ~state_t()
{ {
friend state_future_t; friend state_future_t;
using state_future_t::lock_type; using state_future_t::lock_type;
private: private:
explicit state_t(bool awaitor) :state_future_t(awaitor) {}
explicit state_t(bool awaitor) noexcept :state_future_t(awaitor) {}
public: public:
void future_await_resume(); void future_await_resume();
template<class _PromiseT, typename = std::enable_if_t<traits::is_promise_v<_PromiseT>>> template<class _PromiseT, typename = std::enable_if_t<traits::is_promise_v<_PromiseT>>>

+ 8
- 29
librf/src/stop_token.hpp 查看文件

* @ V1.0 * @ V1.0
*************************************************/ *************************************************/


//librf注:暂时使用一个网友提供的stop_token实现。
//等待C++20的stop_token被各个编译器都实现后,再使用STL里的stop_token来完成对应功能。
#pragma once #pragma once


namespace milk namespace milk
{ {
if (state_.fetch_sub(klockAndTokenRefIncrement, std::memory_order_acq_rel) < kLockedAndZeroRef) if (state_.fetch_sub(klockAndTokenRefIncrement, std::memory_order_acq_rel) < kLockedAndZeroRef)
{ {
clear_callback();
delete this; delete this;
} }
} }
auto old_state = state_.fetch_sub(kTokenRefIncrement, std::memory_order_acq_rel); auto old_state = state_.fetch_sub(kTokenRefIncrement, std::memory_order_acq_rel);
if (old_state < kZeroRef) if (old_state < kZeroRef)
{ {
clear_callback();
delete this; delete this;
} }
} }
auto old_state = state_.fetch_sub(kSourceRefIncrement, std::memory_order_acq_rel); auto old_state = state_.fetch_sub(kSourceRefIncrement, std::memory_order_acq_rel);
if (old_state < kZeroRef) if (old_state < kZeroRef)
{ {
clear_callback();
delete this; delete this;
} }
} }
remove_token_reference(); remove_token_reference();
} }


void clear_callback() noexcept
{
lock();
stop_callback_base* cb = head_;
head_ = nullptr;

while (cb)
{
stop_callback_base* tmp = cb->next;
cb->prev = nullptr;
cb->next = nullptr;
cb = tmp;
}
unlock();
}

}; };


}//namespace details }//namespace details
return state_ != nullptr; return state_ != nullptr;
} }


void make_possible()
void make_sure_possible()
{ {
if (state_ == nullptr) if (state_ == nullptr)
{ {
details::stop_state* st = new details::stop_state(); details::stop_state* st = new details::stop_state();
details::stop_state* tmp = nullptr; details::stop_state* tmp = nullptr;
if (!std::atomic_compare_exchange_strong_explicit( if (!std::atomic_compare_exchange_strong_explicit(
reinterpret_cast<std::atomic<details::stop_state*>*>(&state_), &tmp, st, std::memory_order_release, std::memory_order_acquire))
reinterpret_cast<std::atomic<details::stop_state*>*>(&state_),
&tmp,
st,
std::memory_order_release,
std::memory_order_acquire))
{ {
st->remove_source_reference(); st->remove_source_reference();
} }
return stop_token{state_}; return stop_token{state_};
} }


void clear_callback() const noexcept
{
if (state_)
{
state_->clear_callback();
}
}

void swap(stop_source& other) noexcept void swap(stop_source& other) noexcept
{ {
std::swap(state_, other.state_); std::swap(state_, other.state_);

+ 2
- 2
test_librf.cpp 查看文件

(void)argc; (void)argc;
(void)argv; (void)argv;


//resumable_main_mutex();
//resumable_main_stop_token();
//return 0; //return 0;


//if (argc > 1) //if (argc > 1)
resumable_main_sleep(); resumable_main_sleep();
resumable_main_when_all(); resumable_main_when_all();
resumable_main_switch_scheduler(); resumable_main_switch_scheduler();
//resumable_main_stop_token();
resumable_main_stop_token();
std::cout << "ALL OK!" << std::endl; std::cout << "ALL OK!" << std::endl;


benchmark_main_channel_passing_next(); //这是一个死循环测试 benchmark_main_channel_passing_next(); //这是一个死循环测试

+ 21
- 19
tutorial/test_async_stop_token.cpp 查看文件

using namespace resumef; using namespace resumef;
using namespace std::chrono; using namespace std::chrono;


//token触发停止后,将不再调用cb
template<class _Ctype>
//_Ctype签名:void(bool, int64_t)
template<class _Ctype, typename=std::enable_if_t<std::is_invocable_v<_Ctype, bool, int64_t>>>
static void callback_get_long_with_stop(stop_token token, int64_t val, _Ctype&& cb) static void callback_get_long_with_stop(stop_token token, int64_t val, _Ctype&& cb)
{ {
std::thread([val, token = std::move(token), cb = std::forward<_Ctype>(cb)] std::thread([val, token = std::move(token), cb = std::forward<_Ctype>(cb)]
for (int i = 0; i < 10; ++i) for (int i = 0; i < 10; ++i)
{ {
if (token.stop_requested()) if (token.stop_requested())
{
cb(false, 0);
return; return;
}
std::this_thread::sleep_for(10ms); std::this_thread::sleep_for(10ms);
} }


cb(val * val);
//有可能未检测到token的停止要求
//如果使用stop_callback来停止,则务必保证检测到的退出要求是唯一的,且线程安全的
//否则,多次调用cb,会导致协程在半退出状态下,外部的awaitable_t管理的state获取跟root出现错误。
cb(true, val * val);
}).detach(); }).detach();
} }


static future_t<int64_t> async_get_long_with_stop(stop_token token, int64_t val) static future_t<int64_t> async_get_long_with_stop(stop_token token, int64_t val)
{ {
awaitable_t<int64_t> awaitable; awaitable_t<int64_t> awaitable;
//保证stopptr的生存期,与callback_get_long_with_cancel()的回调参数的生存期一致。
//如果token已经被取消,则传入的lambda会立即被调用,则awaitable将不能再set_value
auto stopptr = make_stop_callback(token, [awaitable]
{
if (awaitable)
awaitable.throw_exception(canceled_exception(error_code::stop_requested));
});


if (awaitable) //处理已经被取消的情况
{
callback_get_long_with_stop(token, val, [awaitable, stopptr = std::move(stopptr)](int64_t val)
{
if (awaitable)
awaitable.set_value(val);
});
}
//在这里通过stop_callback来处理退出,并将退出转化为error_code::stop_requested异常。
//则必然会存在线程竞争问题,导致协程提前于callback_get_long_with_stop的回调之前而退出。
//同时,callback_get_long_with_stop还未必一定能检测到退出要求----毕竟只是一个要求,而不是强制。


callback_get_long_with_stop(token, val, [awaitable](bool ok, int64_t val)
{
if (ok)
awaitable.set_value(val);
else
awaitable.throw_exception(canceled_exception{error_code::stop_requested});
});
return awaitable.get_future(); return awaitable.get_future();
} }


srand((int)time(nullptr)); srand((int)time(nullptr));
for (int i = 0; i < 10; ++i) for (int i = 0; i < 10; ++i)
test_get_long_with_stop(i); test_get_long_with_stop(i);

std::cout << "OK - stop_token!" << std::endl;
} }

正在加载...
取消
保存