diff --git a/lib/include/srsran/adt/circular_map.h b/lib/include/srsran/adt/circular_map.h index 50ca4b318..6c3823204 100644 --- a/lib/include/srsran/adt/circular_map.h +++ b/lib/include/srsran/adt/circular_map.h @@ -13,6 +13,7 @@ #ifndef SRSRAN_ID_MAP_H #define SRSRAN_ID_MAP_H +#include "detail/type_storage.h" #include "expected.h" #include "srsran/common/srsran_assert.h" #include @@ -24,8 +25,7 @@ class static_circular_map { static_assert(std::is_integral::value and std::is_unsigned::value, "Map key must be an unsigned integer"); - using obj_t = std::pair; - using obj_storage_t = typename std::aligned_storage::type; + using obj_t = std::pair; public: class iterator @@ -48,23 +48,23 @@ public: obj_t& operator*() { - srsran_assert(idx < ptr->buffer.size(), "Iterator out-of-bounds (%zd >= %zd)", idx, ptr->buffer.size()); + srsran_assert(idx < ptr->capacity(), "Iterator out-of-bounds (%zd >= %zd)", idx, ptr->capacity()); return ptr->get_obj_(idx); } obj_t* operator->() { - srsran_assert(idx < ptr->buffer.size(), "Iterator out-of-bounds (%zd >= %zd)", idx, ptr->buffer.size()); + srsran_assert(idx < ptr->capacity(), "Iterator out-of-bounds (%zd >= %zd)", idx, ptr->capacity()); return &ptr->get_obj_(idx); } const obj_t* operator*() const { - srsran_assert(idx < ptr->buffer.size(), "Iterator out-of-bounds (%zd >= %zd)", idx, ptr->buffer.size()); - return ptr->buffer[idx]; + srsran_assert(idx < ptr->capacity(), "Iterator out-of-bounds (%zd >= %zd)", idx, ptr->capacity()); + return &ptr->get_obj_(idx); } const obj_t* operator->() const { - srsran_assert(idx < ptr->buffer.size(), "Iterator out-of-bounds (%zd >= %zd)", idx, ptr->buffer.size()); - return ptr->buffer[idx]; + srsran_assert(idx < ptr->capacity(), "Iterator out-of-bounds (%zd >= %zd)", idx, ptr->capacity()); + return &ptr->get_obj_(idx); } bool operator==(const iterator& other) const { return ptr == other.ptr and idx == other.idx; } @@ -88,8 +88,8 @@ public: return *this; } - const obj_t* operator*() const { return ptr->buffer[idx]; } - const obj_t* operator->() const { return ptr->buffer[idx]; } + const obj_t* operator*() const { return &ptr->buffer[idx].get(); } + const obj_t* operator->() const { return &ptr->buffer[idx].get(); } bool operator==(const const_iterator& other) const { return ptr == other.ptr and idx == other.idx; } bool operator!=(const const_iterator& other) const { return not(*this == other); } @@ -103,17 +103,17 @@ public: static_circular_map() { std::fill(present.begin(), present.end(), false); } static_circular_map(const static_circular_map& other) : present(other.present), count(other.count) { - for (size_t idx = 0; idx < other.size(); ++idx) { + for (size_t idx = 0; idx < other.capacity(); ++idx) { if (present[idx]) { - new (&buffer[idx]) obj_t(other.get_obj_(idx)); + buffer[idx].template construct(other.get_obj_(idx)); } } } static_circular_map(static_circular_map&& other) noexcept : present(other.present), count(other.count) { - for (size_t idx = 0; idx < other.size(); ++idx) { + for (size_t idx = 0; idx < other.capacity(); ++idx) { if (present[idx]) { - new (&buffer[idx]) obj_t(std::move(other.get_obj_(idx))); + buffer[idx].template construct(std::move(other.get_obj_(idx))); } } other.clear(); @@ -124,26 +124,21 @@ public: if (this == &other) { return *this; } - clear(); + for (size_t idx = 0; idx < other.capacity(); ++idx) { + copy_if_present_helper(buffer[idx], other.buffer[idx], present[idx], other.present[idx]); + } count = other.count; present = other.present; - for (size_t idx = 0; idx < other.size(); ++idx) { - if (present[idx]) { - new (&buffer[idx]) obj_t(other.get_obj_(idx)); - } - } } static_circular_map& operator=(static_circular_map&& other) noexcept { - clear(); + for (size_t idx = 0; idx < other.capacity(); ++idx) { + move_if_present_helper(buffer[idx], other.buffer[idx], present[idx], other.present[idx]); + } count = other.count; present = other.present; - for (size_t idx = 0; idx < other.size(); ++idx) { - if (present[idx]) { - new (&buffer[idx]) obj_t(std::move(other.get_obj_(idx))); - } - } - clear(); + other.clear(); + return *this; } bool contains(K id) @@ -158,7 +153,7 @@ public: if (present[idx]) { return false; } - new (&buffer[idx]) obj_t(id, obj); + buffer[idx].template construct(id, obj); present[idx] = true; count++; return true; @@ -169,7 +164,7 @@ public: if (present[idx]) { return srsran::expected(std::move(obj)); } - new (&buffer[idx]) obj_t(id, std::move(obj)); + buffer[idx].template construct(id, std::move(obj)); present[idx] = true; count++; return iterator(this, idx); @@ -242,12 +237,12 @@ public: } private: - obj_t& get_obj_(size_t idx) { return reinterpret_cast(buffer[idx]); } - const obj_t& get_obj_(size_t idx) const { return reinterpret_cast(buffer[idx]); } + obj_t& get_obj_(size_t idx) { return buffer[idx].get(); } + const obj_t& get_obj_(size_t idx) const { return buffer[idx].get(); } - std::array buffer; - std::array present; - size_t count = 0; + std::array, N> buffer; + std::array present; + size_t count = 0; }; } // namespace srsran diff --git a/lib/include/srsran/adt/detail/type_storage.h b/lib/include/srsran/adt/detail/type_storage.h new file mode 100644 index 000000000..2645eaef1 --- /dev/null +++ b/lib/include/srsran/adt/detail/type_storage.h @@ -0,0 +1,76 @@ +/** + * + * \section COPYRIGHT + * + * Copyright 2013-2021 Software Radio Systems Limited + * + * By using this file, you agree to the terms and conditions set + * forth in the LICENSE file which can be found at the top level of + * the distribution. + * + */ + +#ifndef SRSRAN_TYPE_STORAGE_H +#define SRSRAN_TYPE_STORAGE_H + +#include +#include + +namespace srsran { + +template +struct type_storage { + template + void construct(Args&&... args) + { + new (&buffer) T(std::forward(args)...); + } + void destroy() { get().~T(); } + void copy_ctor(const type_storage& other) { buffer.get() = other.get(); } + void move_ctor(type_storage&& other) { buffer.get() = std::move(other.get()); } + void copy_assign(const type_storage& other) + { + if (this == &other) { + return; + } + get() = other.get(); + } + void move_assign(type_storage&& other) { get() = std::move(other.get()); } + + T& get() { return reinterpret_cast(buffer); } + const T& get() const { return reinterpret_cast(buffer); } + + typename std::aligned_storage::type buffer; +}; + +template +void copy_if_present_helper(type_storage& lhs, const type_storage& rhs, bool lhs_present, bool rhs_present) +{ + if (lhs_present and rhs_present) { + lhs.get() = rhs.get(); + } + if (lhs_present) { + lhs.destroy(); + } + if (rhs_present) { + lhs.template construct(rhs.get()); + } +} + +template +void move_if_present_helper(type_storage& lhs, type_storage& rhs, bool lhs_present, bool rhs_present) +{ + if (lhs_present and rhs_present) { + lhs.move_assign(std::move(rhs)); + } + if (lhs_present) { + lhs.destroy(); + } + if (rhs_present) { + lhs.template construct(std::move(rhs.get())); + } +} + +} // namespace srsran + +#endif // SRSRAN_TYPE_STORAGE_H diff --git a/lib/test/adt/circular_map_test.cc b/lib/test/adt/circular_map_test.cc index a8a79ab0a..aed0dd88a 100644 --- a/lib/test/adt/circular_map_test.cc +++ b/lib/test/adt/circular_map_test.cc @@ -15,7 +15,7 @@ namespace srsran { -int test_id_map() +void test_id_map() { static_circular_map myobj; TESTASSERT(myobj.size() == 0 and myobj.empty() and not myobj.full()); @@ -57,11 +57,9 @@ int test_id_map() TESTASSERT(myobj.size() == 2 and not myobj.empty() and not myobj.full()); myobj.clear(); TESTASSERT(myobj.size() == 0 and myobj.empty()); - - return SRSRAN_SUCCESS; } -int test_id_map_wraparound() +void test_id_map_wraparound() { static_circular_map mymap; @@ -81,8 +79,56 @@ int test_id_map_wraparound() TESTASSERT(not mymap.full()); TESTASSERT(mymap.insert(4, "4")); TESTASSERT(mymap.full()); +} - return SRSRAN_SUCCESS; +struct C { + C() { count++; } + ~C() { count--; } + C(C&&) { count++; } + C(const C&) = delete; + C& operator=(C&&) = default; + + static size_t count; +}; +size_t C::count = 0; + +void test_correct_destruction() +{ + TESTASSERT(C::count == 0); + { + static_circular_map circ_buffer; + TESTASSERT(C::count == 0); + TESTASSERT(circ_buffer.insert(0, C{})); + TESTASSERT(C::count == 1); + TESTASSERT(circ_buffer.insert(1, C{})); + TESTASSERT(circ_buffer.insert(2, C{})); + TESTASSERT(circ_buffer.insert(3, C{})); + TESTASSERT(C::count == 4); + TESTASSERT(not circ_buffer.insert(4, C{})); + TESTASSERT(C::count == 4); + TESTASSERT(circ_buffer.erase(1)); + TESTASSERT(C::count == 3); + TESTASSERT(not circ_buffer.contains(1)); + + std::array content{0, 2, 3}; + size_t i = 0; + for (auto& e : circ_buffer) { + TESTASSERT(content[i] == e.first); + i++; + } + + TESTASSERT(C::count == 3); + static_circular_map circ_buffer2; + circ_buffer2 = std::move(circ_buffer); + TESTASSERT(C::count == 3); + + static_circular_map circ_buffer3; + TESTASSERT(circ_buffer3.insert(1, C{})); + TESTASSERT(C::count == 4); + circ_buffer2 = std::move(circ_buffer3); + TESTASSERT(C::count == 1); + } + TESTASSERT(C::count == 0); } } // namespace srsran @@ -94,7 +140,10 @@ int main(int argc, char** argv) srsran::test_init(argc, argv); - TESTASSERT(srsran::test_id_map() == SRSRAN_SUCCESS); - TESTASSERT(srsran::test_id_map_wraparound() == SRSRAN_SUCCESS); + srsran::test_id_map(); + srsran::test_id_map_wraparound(); + srsran::test_correct_destruction(); + + printf("Success\n"); return SRSRAN_SUCCESS; }