diff --git a/lib/include/srslte/common/task.h b/lib/include/srslte/common/task.h index 2bd6c97de..ecdba51de 100644 --- a/lib/include/srslte/common/task.h +++ b/lib/include/srslte/common/task.h @@ -28,9 +28,9 @@ namespace srslte { -constexpr size_t default_buffer_size = 256; +constexpr size_t default_buffer_size = 32; -template +template class inplace_task; namespace task_details { @@ -39,30 +39,42 @@ template struct oper_table_t { using call_oper_t = R (*)(void* src, Args&&... args); using move_oper_t = void (*)(void* src, void* dest); - // using copy_oper_t = void (*)(void* src, void* dest); using dtor_oper_t = void (*)(void* src); - static oper_table_t* get_empty() noexcept + const static oper_table_t* get_empty() noexcept { - static oper_table_t t; - t.call = [](void* src, Args&&... args) -> R { throw std::bad_function_call(); }; - t.move = [](void*, void*) {}; - // t.copy = [](void*, void*) {}; - t.dtor = [](void*) {}; + const static oper_table_t t{true, + [](void* src, Args&&... args) -> R { throw std::bad_function_call(); }, + [](void*, void*) {}, + [](void*) {}}; return &t; } template - static oper_table_t* get() noexcept + const static oper_table_t* get_small() noexcept { - static oper_table_t t{}; - t.call = [](void* src, Args&&... args) -> R { return (*static_cast(src))(std::forward(args)...); }; - t.move = [](void* src, void* dest) -> void { - ::new (dest) Func{std::move(*static_cast(src))}; - static_cast(src)->~Func(); - }; - // t.copy = [](void* src, void* dest) -> void { ::new (dest) Func{*static_cast(src)}; }; - t.dtor = [](void* src) -> void { static_cast(src)->~Func(); }; + const static oper_table_t t{ + true, + [](void* src, Args&&... args) -> R { return (*static_cast(src))(std::forward(args)...); }, + [](void* src, void* dest) -> void { + ::new (dest) Func{std::move(*static_cast(src))}; + static_cast(src)->~Func(); + }, + [](void* src) -> void { static_cast(src)->~Func(); }}; + return &t; + } + + template + const static oper_table_t* get_big() noexcept + { + const static oper_table_t t{ + false, + [](void* src, Args&&... args) -> R { return (*static_cast(src))(std::forward(args)...); }, + [](void* src, void* dest) -> void { + *static_cast(dest) = *static_cast(src); + *static_cast(src) = nullptr; + }, + [](void* src) -> void { static_cast(src)->~Func(); }}; return &t; } @@ -72,101 +84,123 @@ struct oper_table_t { oper_table_t& operator=(oper_table_t&&) = delete; ~oper_table_t() = default; + bool is_in_buffer; call_oper_t call; move_oper_t move; - // copy_oper_t copy; dtor_oper_t dtor; - static oper_table_t* empty_oper; - private: oper_table_t() = default; + oper_table_t(bool is_in_buffer_, call_oper_t call_, move_oper_t move_, dtor_oper_t dtor_) : + is_in_buffer(is_in_buffer_), + call(call_), + move(move_), + dtor(dtor_) + {} }; template struct is_inplace_task : std::false_type {}; -template -struct is_inplace_task > : std::true_type {}; - -template -oper_table_t* oper_table_t::empty_oper = oper_table_t::get_empty(); +template +struct is_inplace_task > : std::true_type {}; } // namespace task_details -template -class inplace_task +template +class inplace_task { - using storage_t = typename std::aligned_storage::type; - using oper_table_t = task_details::oper_table_t; + static constexpr size_t capacity = Capacity >= sizeof(void*) ? Capacity : sizeof(void*); + using storage_t = typename std::aligned_storage::type; + using oper_table_t = task_details::oper_table_t; public: - inplace_task() noexcept { oper_ptr = oper_table_t::empty_oper; } + inplace_task() noexcept { oper_ptr = oper_table_t::get_empty(); } template ::type, + typename = typename std::enable_if::type, typename = typename std::enable_if::value>::type> inplace_task(T&& function) { - static_assert(sizeof(FunT) <= sizeof(buffer), "inplace_task cannot store object with given size.\n"); - static_assert(Alignment % alignof(FunT) == 0, "inplace_task cannot store object with given alignment.\n"); - + oper_ptr = oper_table_t::template get_small(); ::new (&buffer) FunT{std::forward(function)}; - oper_ptr = oper_table_t::template get(); + } + + template ::type, + typename = typename std::enable_if::value and + (sizeof(FunT) > capacity)>::type> + inplace_task(T&& function) + { + oper_ptr = oper_table_t::template get_big(); + ptr = static_cast(new FunT{std::forward(function)}); } inplace_task(inplace_task&& other) noexcept { oper_ptr = other.oper_ptr; - other.oper_ptr = oper_table_t::empty_oper; - oper_ptr->move(&other.buffer, &buffer); + other.oper_ptr = oper_table_t::get_empty(); + if (oper_ptr->is_in_buffer) { + oper_ptr->move(&other.buffer, &buffer); + } else { + oper_ptr->move(&other.ptr, &ptr); + } } - // inplace_task(const inplace_task& other) noexcept - // { - // oper_ptr = other.oper_ptr; - // oper_ptr->copy(&other.buffer, &buffer); - // } - - ~inplace_task() { oper_ptr->dtor(&buffer); } + ~inplace_task() { oper_ptr->dtor(get_buffer()); } inplace_task& operator=(inplace_task&& other) noexcept { - oper_ptr->dtor(&buffer); + oper_ptr->dtor(get_buffer()); oper_ptr = other.oper_ptr; - other.oper_ptr = oper_table_t::empty_oper; - oper_ptr->move(&other.buffer, &buffer); + other.oper_ptr = oper_table_t::get_empty(); + if (oper_ptr->is_in_buffer) { + oper_ptr->move(&other.buffer, &buffer); + } else { + oper_ptr->move(&other.ptr, &ptr); + } return *this; } - // inplace_task& operator=(const inplace_task& other) noexcept - // { - // if (this != &other) { - // oper_ptr->dtor(&buffer); - // oper_ptr = other.oper_ptr; - // oper_ptr->copy(&other.buffer, &buffer); - // } - // return *this; - // } + R operator()(Args&&... args) { return oper_ptr->call(get_buffer(), std::forward(args)...); } - R operator()(Args&&... args) { return oper_ptr->call(&buffer, std::forward(args)...); } - - bool is_empty() const { return oper_ptr == oper_table_t::empty_oper; } + bool is_empty() const { return oper_ptr == oper_table_t::get_empty(); } + bool is_in_small_buffer() const { return oper_ptr->is_in_buffer; } void swap(inplace_task& other) noexcept { if (this == &other) return; - storage_t tmp; - oper_ptr->move(&buffer, &tmp); - other.oper_ptr->move(&other.buffer, &buffer); - oper_ptr->move(&tmp, &other.buffer); + if (oper_ptr->is_in_buffer and other.oper_ptr->is_in_buffer) { + storage_t tmp; + oper_ptr->move(&buffer, &tmp); + other.oper_ptr->move(&other.buffer, &buffer); + oper_ptr->move(&tmp, &other.buffer); + } else if (oper_ptr->is_in_buffer and not other.oper_ptr->is_in_buffer) { + void* tmpptr = other.ptr; + oper_ptr->move(&buffer, &other.buffer); + ptr = tmpptr; + } else if (not oper_ptr->is_in_buffer and other.oper_ptr->is_in_buffer) { + void* tmpptr = ptr; + other.oper_ptr->move(&other.buffer, &buffer); + oper_ptr->move(&tmpptr, &other.ptr); + } else { + std::swap(ptr, other.ptr); + } std::swap(oper_ptr, other.oper_ptr); } + friend void swap(inplace_task& lhs, inplace_task& rhs) noexcept { lhs.swap(rhs); } + private: - storage_t buffer; - oper_table_t* oper_ptr; + union { + storage_t buffer; + void* ptr; + }; + const oper_table_t* oper_ptr; + + void* get_buffer() { return oper_ptr->is_in_buffer ? &buffer : ptr; } }; } // namespace srslte diff --git a/lib/test/common/queue_test.cc b/lib/test/common/queue_test.cc index b6c66019c..f38ff8e0e 100644 --- a/lib/test/common/queue_test.cc +++ b/lib/test/common/queue_test.cc @@ -324,6 +324,10 @@ int test_task_thread_pool3() struct C { std::unique_ptr val{new int{5}}; }; +struct D { + std::array big_val; + D() { big_val[0] = 6; } +}; int test_inplace_task() { @@ -350,6 +354,36 @@ int test_inplace_task() TESTASSERT(v == 5); } + D d; + srslte::inplace_task t6 = [&v, d]() { v = d.big_val[0]; }; + { + srslte::inplace_task t7; + t6(); + TESTASSERT(v == 6); + v = 0; + t7 = std::move(t6); + t7(); + TESTASSERT(v == 6); + } + + auto l1 = std::bind([&v](C& c) { v = *c.val; }, C{}); + auto l2 = [&v, d]() { v = d.big_val[0]; }; + t = std::move(l1); + t2 = l2; + v = 0; + t(); + TESTASSERT(v == 5); + t2(); + TESTASSERT(v == 6); + TESTASSERT(t.is_in_small_buffer() and not t2.is_in_small_buffer()); + swap(t, t2); + TESTASSERT(t2.is_in_small_buffer() and not t.is_in_small_buffer()); + v = 0; + t(); + TESTASSERT(v == 6); + t2(); + TESTASSERT(v == 5); + std::cout << "outcome: Success\n"; std::cout << "========================================\n"; return 0;