diff --git a/lib/include/srslte/common/choice_type.h b/lib/include/srslte/common/choice_type.h new file mode 100644 index 000000000..3a34d9ffa --- /dev/null +++ b/lib/include/srslte/common/choice_type.h @@ -0,0 +1,345 @@ +/* + * Copyright 2013-2020 Software Radio Systems Limited + * + * This file is part of srsLTE. + * + * srsLTE is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as + * published by the Free Software Foundation, either version 3 of + * the License, or (at your option) any later version. + * + * srsLTE is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * A copy of the GNU Affero General Public License can be found in + * the LICENSE file in the top-level directory of this distribution + * and at http://www.gnu.org/licenses/. + * + */ + +#include +#include +#include + +#ifndef SRSLTE_CHOICE_TYPE_H +#define SRSLTE_CHOICE_TYPE_H + +namespace srslte { + +namespace choice_details { + +using size_idx_t = std::size_t; + +static constexpr size_idx_t invalid_idx = std::numeric_limits::max(); + +class bad_choice_access : public std::runtime_error +{ +public: + explicit bad_choice_access(const std::string& what_arg) : runtime_error(what_arg) {} + explicit bad_choice_access(const char* what_arg) : runtime_error(what_arg) {} +}; + +//! Get Index of a type in a list of types (reversed) +template +struct type_indexer; + +template +struct type_indexer { + static constexpr size_idx_t index = + std::is_same::value ? sizeof...(Types) : type_indexer::index; +}; + +template +struct type_indexer { + static constexpr size_idx_t index = invalid_idx; +}; + +//! Get a type of an index in a list of types +template +struct type_get; + +template +struct type_get { + using type = typename std::conditional::type>::type; +}; + +template +struct type_get { + using type = void; +}; + +//! Compute maximum at compile time +template +struct static_max; +template +struct static_max { + static const std::size_t value = arg; +}; +template +struct static_max { + static const std::size_t value = + arg1 >= arg2 ? static_max::value : static_max::value; +}; + +//! Holds one of the Args types +template +struct choice_storage_t { + static const std::size_t max_size = MaxSize; + static const std::size_t max_align = MaxAlign; + using buffer_t = typename std::aligned_storage::type; + buffer_t buffer; + + void* get_buffer() { return &buffer; } + + template + T& get_unsafe() + { + return *reinterpret_cast(&buffer); + } + + template + const T& get_unsafe() const + { + return *reinterpret_cast(&buffer); + } + + template + void destroy_unsafe() + { + get_unsafe().~U(); + }; +}; + +/************************* + * Tagged Union Helpers + ************************/ + +template +struct CopyCtorVisitor { + explicit CopyCtorVisitor(C* c_) : c(c_) {} + template + void operator()(const T& t) + { + c->construct_unsafe(t); + } + C* c; +}; + +template +struct MoveCtorVisitor { + explicit MoveCtorVisitor(C* c_) : c(c_) {} + template + void operator()(T&& t) + { + c->construct_unsafe(std::move(t)); + } + C* c; +}; + +template +struct DtorUnsafeVisitor { + explicit DtorUnsafeVisitor(C* c_) : c(c_) {} + template + void operator()(T& t) + { + c->template destroy_unsafe(); + } + C* c; +}; + +/** + * @brief visit pattern implementation + * @tparam F functor + * @tparam V tagged union type + * @tparam Types remaining types to iterate + */ +template +struct visit_impl; + +template +struct visit_impl { + static void apply(V& c, F&& f) + { + if (c.template is()) { + f(c.template get_unsafe()); + } else { + visit_impl::apply(c, std::forward(f)); + } + } + static void apply(const V& c, F&& f) + { + if (c.template is()) { + f(c.template get_unsafe()); + } else { + visit_impl::apply(c, std::forward(f)); + } + } +}; +template +struct visit_impl { + static void apply(V& c, F&& f) { f(c.template get_unsafe()); } + static void apply(const V& c, F&& f) { f(c.template get_unsafe()); } +}; + +template +struct tagged_union_t; + +template +void visit(tagged_union_t& u, Functor&& f) +{ + visit_impl, First, Args...>::apply(u, std::forward(f)); +} +template +void visit(const tagged_union_t& u, Functor&& f) +{ + visit_impl, First, Args...>::apply(u, std::forward(f)); +} + +template +struct tagged_union_t + : private choice_storage_t::value, static_max::value> { + using base_t = choice_storage_t::value, static_max::value>; + using buffer_t = typename base_t::buffer_t; + using this_type = tagged_union_t; + using default_type = typename type_get::type; + + std::size_t type_id; + + using base_t::destroy_unsafe; + using base_t::get_buffer; + using base_t::get_unsafe; + + auto construct_unsafe() -> decltype(default_type(), void()) + { + type_id = sizeof...(Args) - 1; + new (get_buffer()) default_type(); + } + + template + void construct_unsafe(U&& u) + { + using U2 = typename std::decay::type; + static_assert(type_indexer::index != invalid_idx, + "The provided type to ctor is not part of the list of possible types"); + type_id = type_indexer::index; + new (get_buffer()) U2(std::forward(u)); + } + + void copy_unsafe(const this_type& other) { visit(other, CopyCtorVisitor{this}); } + + void move_unsafe(this_type&& other) { visit(other, MoveCtorVisitor{this}); } + + void dtor_unsafe() { visit(*this, choice_details::DtorUnsafeVisitor{this}); } + + template ::type> + T& get_unsafe() + { + return get_unsafe(); + } + + template + bool is() const + { + return type_indexer::index == type_id; + } + + template + T& get() + { + if (is()) { + return get_unsafe(); + } + throw choice_details::bad_choice_access{"in get"}; + } + + template + T* get_if() + { + return (is()) ? &get_unsafe() : nullptr; + } + + template + constexpr static bool can_hold_type() + { + return type_indexer::index != invalid_idx; + } +}; + +} // namespace choice_details + +template +class choice_t : private choice_details::tagged_union_t +{ + using base_t = choice_details::tagged_union_t; + template + using enable_if_can_hold = + typename std::enable_if::type>()>::type; + +public: + using base_t::can_hold_type; + using base_t::get; + using base_t::get_if; + using base_t::is; + + choice_t() noexcept { this_union().construct_unsafe(); } + + choice_t(const choice_t& other) noexcept { base_t::copy_unsafe(other); } + + choice_t(choice_t&& other) noexcept { base_t::move_unsafe(std::move(other)); } + + template > + explicit choice_t(U&& u) noexcept + { + base_t::construct_unsafe(std::forward(u)); + } + + ~choice_t() { base_t::dtor_unsafe(); } + + template > + choice_t& operator=(U&& u) noexcept + { + if (not base_t::template is()) { + base_t::dtor_unsafe(); + } + base_t::construct_unsafe(std::forward(u)); + return *this; + } + + template + void emplace(U&& u) noexcept + { + *this = std::forward(u); + } + + choice_t& operator=(const choice_t& other) noexcept + { + if (this != &other) { + base_t::dtor_unsafe(); + base_t::copy_unsafe(other); + } + return *this; + } + + choice_t& operator=(choice_t&& other) noexcept + { + base_t::dtor_unsafe(); + base_t::move_unsafe(other); + return *this; + } + +private: + base_t& this_union() { return *this; } + const base_t& this_union() const { return *this; } +}; + +// template +// void visit(choice_t& u, Functor&& f) +//{ +// choice_details::visit_impl::template apply(u, f); +//} + +} // namespace srslte + +#endif // SRSLTE_CHOICE_TYPE_H diff --git a/lib/test/common/CMakeLists.txt b/lib/test/common/CMakeLists.txt index c3d2a2839..f1613faa0 100644 --- a/lib/test/common/CMakeLists.txt +++ b/lib/test/common/CMakeLists.txt @@ -68,10 +68,10 @@ target_link_libraries(pdu_test srslte_phy srslte_common ${CMAKE_THREAD_LIBS_INIT add_test(pdu_test pdu_test) if (ENABLE_5GNR) - add_executable(mac_nr_pdu_test mac_nr_pdu_test.cc) - target_link_libraries(mac_nr_pdu_test srslte_phy srslte_common ${CMAKE_THREAD_LIBS_INIT}) - add_test(mac_nr_pdu_test mac_nr_pdu_test) -endif(ENABLE_5GNR) + add_executable(mac_nr_pdu_test mac_nr_pdu_test.cc) + target_link_libraries(mac_nr_pdu_test srslte_phy srslte_common ${CMAKE_THREAD_LIBS_INIT}) + add_test(mac_nr_pdu_test mac_nr_pdu_test) +endif (ENABLE_5GNR) add_executable(stack_procedure_test stack_procedure_test.cc) add_test(stack_procedure_test stack_procedure_test) @@ -94,4 +94,8 @@ add_test(test_common_test test_common_test) add_executable(tti_point_test tti_point_test.cc) target_link_libraries(tti_point_test srslte_common) -add_test(tti_point_test tti_point_test) \ No newline at end of file +add_test(tti_point_test tti_point_test) + +add_executable(choice_type_test choice_type_test.cc) +target_link_libraries(choice_type_test srslte_common) +add_test(choice_type_test choice_type_test) diff --git a/lib/test/common/choice_type_test.cc b/lib/test/common/choice_type_test.cc new file mode 100644 index 000000000..937e98c42 --- /dev/null +++ b/lib/test/common/choice_type_test.cc @@ -0,0 +1,173 @@ +/* + * Copyright 2013-2020 Software Radio Systems Limited + * + * This file is part of srsLTE. + * + * srsLTE is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as + * published by the Free Software Foundation, either version 3 of + * the License, or (at your option) any later version. + * + * srsLTE is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * A copy of the GNU Affero General Public License can be found in + * the LICENSE file in the top-level directory of this distribution + * and at http://www.gnu.org/licenses/. + * + */ + +#include "srslte/common/choice_type.h" +#include "srslte/common/test_common.h" + +namespace srslte { +namespace choice_details { + +static_assert(static_max<1, 2>::value == 2, "StaticMax not working"); +static_assert(static_max<2, 1>::value == 2, "StaticMax not working"); +static_assert(type_indexer::index == 0, "Type indexer not working"); +static_assert(type_indexer::index == 4, "Type indexer not working"); +static_assert(type_indexer::index == invalid_idx, "Type Indexer not working"); +static_assert(sizeof(choice_storage_t<5, 4>) == 8, "Size of storage wrongly computed"); +static_assert(alignof(choice_storage_t<5, 4>) == 4, "Alignment of storage wrongly computed"); +static_assert(std::is_same::type, double>::value, + "type index-based search not working"); +static_assert(std::is_same::type, float>::value, + "type index-based search not working"); +static_assert(std::is_same::default_type, char>::value, + "Default type is incorrect\n"); +static_assert(tagged_union_t::can_hold_type(), "Can hold type implementation is incorrect\n"); +static_assert(not tagged_union_t::can_hold_type(), + "Can hold type implementation is incorrect\n"); + +struct C { + static int counter; + C() { counter++; } + C(C&&) { counter++; } + C(const C&) { counter++; } + C& operator=(C&& other) + { + counter++; + return *this; + } + C& operator=(const C& other) + { + if (this != &other) { + counter++; + } + return *this; + } + ~C() { counter--; } +}; +int C::counter = 0; + +struct D { + static int counter; + D() { counter++; } + D(D&&) { counter++; } + D(const D&) = delete; + D& operator=(const D&) = delete; + D& operator==(D&&) + { + counter++; + return *this; + } + ~D() { counter--; } +}; +int D::counter = 0; + +int test_tagged_union() +{ + tagged_union_t u; + u.construct_unsafe(5); + TESTASSERT(u.is()); + TESTASSERT(u.get_unsafe() == 5); + TESTASSERT(u.get_unsafe<1>() == 5); + u.destroy_unsafe(); + + TESTASSERT(C::counter == 0); + u.construct_unsafe(C{}); + TESTASSERT(C::counter == 1); + u.destroy_unsafe(); + TESTASSERT(C::counter == 0); + + return SRSLTE_SUCCESS; +} + +int test_choice() +{ + TESTASSERT(C::counter == 0); + TESTASSERT(D::counter == 0); + { + int i = 6; + C c0{}; + + // TEST: correct construction, holding the right type and value + choice_t c, c2{i}, c3{c0}; + TESTASSERT(c.is()); + TESTASSERT(c2.is() and c2.get() == i and *c2.get_if() == i); + TESTASSERT(c3.is()); + TESTASSERT(C::counter == 2); + + // TEST: Invalid member access. get<>() should throw + TESTASSERT(c2.get_if() == nullptr); + bool catched = false; + try { + char n = '1'; + n = c2.get(); + TESTASSERT(n == '1'); + } catch (choice_details::bad_choice_access& e) { + catched = true; + } + TESTASSERT(catched); + + // TEST: simple emplace after construction + c2 = 'c'; + TESTASSERT(c2.is() and c2.get() == 'c'); + + // TEST: copy ctor test. + choice_t c5{c3}; + TESTASSERT(C::counter == 3); + TESTASSERT(c5.is()); + TESTASSERT(c5.get_if() == &c5.get()); + + // TEST: copy assignment + c = c5; + TESTASSERT(C::counter == 4); + TESTASSERT(c.is() and c.get_if() != c5.get_if()); + c = c2; + TESTASSERT(C::counter == 3); + TESTASSERT(c2.is() and c.get() == 'c'); + } + TESTASSERT(C::counter == 0); + TESTASSERT(D::counter == 0); + { + choice_t c, c2{5.0}, c3{C{}}, c4{D{}}; + TESTASSERT(c.is()); + TESTASSERT(c2.is() and c2.get() == 5.0 and *c2.get_if() == 5.0); + TESTASSERT(c3.is()); + TESTASSERT(c4.is()); + TESTASSERT(C::counter == 1); + TESTASSERT(D::counter == 1); + + choice_t c5{std::move(c3)}; + TESTASSERT(C::counter == 2); + choice_t c6{std::move(c4)}; + TESTASSERT(D::counter == 2); + } + TESTASSERT(C::counter == 0); + TESTASSERT(D::counter == 0); + + return SRSLTE_SUCCESS; +} + +} // namespace choice_details +} // namespace srslte + +int main() +{ + TESTASSERT(srslte::choice_details::test_tagged_union() == SRSLTE_SUCCESS); + TESTASSERT(srslte::choice_details::test_choice() == SRSLTE_SUCCESS); +}