diff --git a/lib/include/srslte/common/choice_type.h b/lib/include/srslte/common/choice_type.h index db3150109..6e9417319 100644 --- a/lib/include/srslte/common/choice_type.h +++ b/lib/include/srslte/common/choice_type.h @@ -211,10 +211,14 @@ struct tagged_union_t using base_t::get_buffer; using base_t::get_unsafe; - auto construct_unsafe() -> decltype(default_type(), void()) + template + void construct_emplace_unsafe(Args2&&... args) { - type_id = sizeof...(Args) - 1; - new (get_buffer()) default_type(); + using U2 = typename std::decay::type; + static_assert(type_indexer::index != invalid_idx, + "The provided type to ctor is not part of the list of possible types"); + new (get_buffer()) U2(std::forward(args)...); + type_id = type_indexer::index; } template @@ -254,12 +258,45 @@ struct tagged_union_t throw choice_details::bad_choice_access{"in get"}; } + template + const T& get() const + { + if (is()) { + return get_unsafe(); + } + throw choice_details::bad_choice_access{"in get"}; + } + + template ::type> + T& get() + { + if (is()) { + return get_unsafe(); + } + throw choice_details::bad_choice_access{"in get"}; + } + + template ::type> + const T& get() const + { + if (is()) { + return get_unsafe(); + } + throw choice_details::bad_choice_access{"in get"}; + } + template T* get_if() { return (is()) ? &get_unsafe() : nullptr; } + template + const T* get_if() const + { + return (is()) ? &get_unsafe() : nullptr; + } + template constexpr static bool can_hold_type() { @@ -283,14 +320,20 @@ public: using base_t::get_if; using base_t::is; - choice_t() noexcept { base_t::construct_unsafe(); } + template < + typename... Args2, + typename = typename std::enable_if::value>::type> + explicit choice_t(Args2&&... args) noexcept + { + base_t::template construct_emplace_unsafe(std::forward(args)...); + } choice_t(const choice_t& other) noexcept { base_t::copy_unsafe(other); } choice_t(choice_t&& other) noexcept { base_t::move_unsafe(std::move(other)); } template > - explicit choice_t(U&& u) noexcept + choice_t(U&& u) noexcept { base_t::construct_unsafe(std::forward(u)); } @@ -307,10 +350,11 @@ public: return *this; } - template - void emplace(U&& u) noexcept + template + void emplace(Args2&&... args) noexcept { - *this = std::forward(u); + base_t::dtor_unsafe(); + base_t::template construct_emplace_unsafe(std::forward(args)...); } choice_t& operator=(const choice_t& other) noexcept @@ -346,6 +390,30 @@ public: private: }; +template +bool holds_alternative(const Choice& u) +{ + return u.template is(); +} + +template +T* get_if(Choice& c) +{ + return c.template get_if(); +} + +template +const T* get_if(const choice_t& c) +{ + return c.template get_if(); +} + +template +auto get(const choice_t& c) -> decltype(c.template get()) +{ + return c.template get(); +} + template void visit(choice_t& u, Functor&& f) { diff --git a/lib/include/srslte/common/fsm.h b/lib/include/srslte/common/fsm.h new file mode 100644 index 000000000..eb210935e --- /dev/null +++ b/lib/include/srslte/common/fsm.h @@ -0,0 +1,239 @@ +/* + * Copyright 2013-2020 Software Radio Systems Limited + * + * This file is part of srsLTE. + * + * srsLTE is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as + * published by the Free Software Foundation, either version 3 of + * the License, or (at your option) any later version. + * + * srsLTE is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * A copy of the GNU Affero General Public License can be found in + * the LICENSE file in the top-level directory of this distribution + * and at http://www.gnu.org/licenses/. + * + */ + +#ifndef SRSLTE_FSM_H +#define SRSLTE_FSM_H + +#include "choice_type.h" +#include +#include +#include +#include +#include +#include + +namespace srslte { + +// using same_state = mpark::monostate; +struct same_state { +}; +template +using state_list = choice_t; + +namespace fsm_details { + +//! Visitor to get a state's name string +struct state_name_visitor { + template + void operator()(State&& s) + { + name = s.name(); + } + const char* name = "invalid state"; +}; + +template +struct variant_convert { + template + void operator()(State&& s) + { + static_assert(not std::is_same::type, typename std::decay::type>::value, + "State cannot transition to itself.\n"); + *v = s; + } + TargetVariant* v; + PrevState* p; +}; + +struct fsm_helper { + //! Stayed in same state + template + static void handle_state_transition(FSM* f, same_state s, PrevState* p) + { + // do nothing + } + //! TargetState is type-erased. Apply its stored type to the fsm current state + template + static void handle_state_transition(FSM* f, choice_t& s, PrevState* p) + { + fsm_details::variant_convertstates), PrevState> visitor{.v = &f->states, .p = p}; + s.visit(visitor); + } + //! Simple state transition in FSM + template + static auto handle_state_transition(FSM* f, State& s, PrevState* p) -> decltype(f->states = s, void()) + { + static_assert(not std::is_same::value, "State cannot transition to itself.\n"); + f->states = s; + } + //! State not present in current FSM. Attempt state transition in parent FSM in the case of NestedFSM + template + static void handle_state_transition(FSM* f, Args&&... args) + { + static_assert(FSM::is_nested, "State is not present in the FSM list of valid states"); + handle_state_transition(f->parent_fsm()->derived(), args...); + } + + //! Trigger Event, that will result in a state transition + template + struct trigger_visitor { + trigger_visitor(FSM* f_, Event&& ev_) : f(f_), ev(std::forward(ev_)) {} + + template + void operator()(State& s) + { + call_trigger(s); + } + + template + using NextState = decltype(std::declval().react(std::declval(), std::declval())); + + template + auto call_trigger(State& s) -> NextState + { + using next_state = NextState; + static_assert(not std::is_same::value, "State cannot transition to itself.\n"); + auto target_state = f->react(s, std::move(ev)); + fsm_helper::handle_state_transition(f, target_state, &s); + return target_state; + } + template + auto call_trigger(State& s) -> decltype(std::declval().trigger(std::declval())) + { + s.trigger(std::move(ev)); + } + same_state call_trigger(...) + { + // do nothing if no react was found + return same_state{}; + } + + FSM* f; + Event ev; + }; +}; + +} // namespace fsm_details + +//! Base class for states and FSMs +class state_t +{ +public: + state_t() = default; + // // forbid copies, allow move + // state_t(const state_t&) = delete; + // state_t(state_t&&) noexcept = default; + // state_t& operator=(const state_t&) = delete; + // state_t& operator=(state_t&&) noexcept = default; + + virtual const char* name() const = 0; +}; + +template +class fsm_t +{ +public: + // get access to derived protected members from the base + class derived_view : public Derived + { + public: + using Derived::react; + using Derived::states; + }; + + static const bool is_nested = false; + + virtual const char* name() const = 0; + + // Push Events to FSM + template + void trigger(Ev&& e) + { + fwd_trigger(std::forward(e)); + } + + template + bool is_in_state() const + { + return derived()->states.template is(); + } + + template + const State* get_state() const + { + return srslte::get_if(derived()->states); + } + + const char* get_state_name() const + { + fsm_details::state_name_visitor visitor{}; + derived()->states.visit(visitor); + return visitor.name; + } + +protected: + friend struct fsm_details::fsm_helper; + + // Forward an event to FSM states and handle transition return + template + void fwd_trigger(Ev&& e) + { + fsm_details::fsm_helper::trigger_visitor visitor{derived(), std::forward(e)}; + derived()->states.visit(visitor); + } + + template + void change_state(State& s) + { + derived()->states = std::move(s); + } + + // Access to CRTP derived class + derived_view* derived() { return static_cast(this); } + const derived_view* derived() const { return static_cast(this); } +}; + +template +class nested_fsm_t : public fsm_t +{ + using base_t = fsm_t; + using parent_t = ParentFSM; + using parent_view = typename parent_t::derived_view; + +public: + static const bool is_nested = true; + + explicit nested_fsm_t(ParentFSM* parent_fsm_) : fsm_ptr(parent_fsm_) {} + + // Get pointer to outer FSM in case of HSM + const parent_t* parent_fsm() const { return fsm_ptr; } + parent_t* parent_fsm() { return fsm_ptr; } + +protected: + friend struct fsm_details::fsm_helper; + using parent_fsm_t = ParentFSM; + + ParentFSM* fsm_ptr = nullptr; +}; + +} // namespace srslte + +#endif // SRSLTE_FSM_H diff --git a/lib/test/common/CMakeLists.txt b/lib/test/common/CMakeLists.txt index f1613faa0..9e6319f1a 100644 --- a/lib/test/common/CMakeLists.txt +++ b/lib/test/common/CMakeLists.txt @@ -96,6 +96,10 @@ add_executable(tti_point_test tti_point_test.cc) target_link_libraries(tti_point_test srslte_common) add_test(tti_point_test tti_point_test) +add_executable(fsm_test fsm_test.cc) +target_link_libraries(fsm_test srslte_common) +add_test(fsm_test fsm_test) + add_executable(choice_type_test choice_type_test.cc) target_link_libraries(choice_type_test srslte_common) add_test(choice_type_test choice_type_test) diff --git a/lib/test/common/choice_type_test.cc b/lib/test/common/choice_type_test.cc index 880dabfba..dde9cdc61 100644 --- a/lib/test/common/choice_type_test.cc +++ b/lib/test/common/choice_type_test.cc @@ -115,6 +115,7 @@ int test_choice() choice_t c, c2{i}, c3{c0}; TESTASSERT(c.is()); TESTASSERT(c2.is() and c2.get() == i and *c2.get_if() == i); + TESTASSERT(c2.get<1>() == c2.get()); TESTASSERT(c3.is()); TESTASSERT(C::counter == 2); diff --git a/lib/test/common/fsm_test.cc b/lib/test/common/fsm_test.cc new file mode 100644 index 000000000..fe48d0cd9 --- /dev/null +++ b/lib/test/common/fsm_test.cc @@ -0,0 +1,183 @@ +/* + * Copyright 2013-2020 Software Radio Systems Limited + * + * This file is part of srsLTE. + * + * srsLTE is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as + * published by the Free Software Foundation, either version 3 of + * the License, or (at your option) any later version. + * + * srsLTE is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * A copy of the GNU Affero General Public License can be found in + * the LICENSE file in the top-level directory of this distribution + * and at http://www.gnu.org/licenses/. + * + */ + +#include "srslte/common/fsm.h" +#include "srslte/common/test_common.h" + +srslte::log_ref test_log{"TEST"}; + +///////////////////////////// + +// Events +struct ev1 { +}; +struct ev2 { +}; + +class fsm1 : public srslte::fsm_t +{ +public: + uint32_t idle_enter_counter = 0, state1_enter_counter = 0; + uint32_t foo_counter = 0; + + fsm1() : states(idle_st{this}) {} + + const char* name() const override { return "fsm1"; } + + // idle state + struct idle_st : srslte::state_t { + idle_st(fsm1* f) + { + test_log->info("fsm1::%s::enter called\n", name()); + f->idle_enter_counter++; + } + ~idle_st() { test_log->info("fsm1::%s::exit called\n", name()); } + const char* name() const final { return "idle"; } + }; + + // simple state + class state1 : public srslte::state_t + { + public: + state1(fsm1* f) + { + test_log->info("fsm1::%s::enter called\n", name()); + f->state1_enter_counter++; + } + const char* name() const final { return "state1"; } + }; + + // this state is another FSM + class fsm2 : public srslte::nested_fsm_t + { + public: + struct state_inner : public srslte::state_t { + const char* name() const final { return "state_inner"; } + state_inner() { test_log->info("fsm2::%s::enter called\n", name()); } + ~state_inner() { test_log->info("fsm2::%s::exit called\n", name()); } + }; + + fsm2(fsm1* f_) : nested_fsm_t(f_) { test_log->info("%s::enter called\n", name()); } + ~fsm2() { test_log->info("%s::exit called\n", name()); } + const char* name() const final { return "fsm2"; } + + protected: + // FSM2 transitions + auto react(state_inner& s, ev1 e) -> srslte::same_state; + auto react(state_inner& s, ev2 e) -> state1; + + // list of states + srslte::state_list states; + }; + +protected: + // transitions + auto react(idle_st& s, ev1 e) -> state1; + auto react(state1& s, ev1 e) -> fsm2; + auto react(state1& s, ev2 e) -> srslte::state_list; + + void foo(ev1 e) { foo_counter++; } + + // list of states + srslte::state_list states{idle_st{this}}; +}; + +// FSM event handlers +auto fsm1::fsm2::react(state_inner& s, ev1 e) -> srslte::same_state +{ + test_log->info("fsm2::state_inner::react called\n"); + return {}; +} + +auto fsm1::fsm2::react(state_inner& s, ev2 e) -> state1 +{ + test_log->info("fsm2::state_inner::react called\n"); + return state1{parent_fsm()}; +} + +auto fsm1::react(idle_st& s, ev1 e) -> state1 +{ + test_log->info("fsm1::%s::react called\n", s.name()); + foo(e); + return state1{this}; +} +auto fsm1::react(state1& s, ev1 e) -> fsm2 +{ + test_log->info("fsm1::%s::react called\n", s.name()); + return fsm2{this}; +} +auto fsm1::react(state1& s, ev2 e) -> srslte::state_list +{ + test_log->info("fsm1::%s::react called\n", s.name()); + return idle_st{this}; +} + +int test_hsm() +{ + test_log->set_level(srslte::LOG_LEVEL_INFO); + + fsm1 f; + TESTASSERT(std::string{f.name()} == "fsm1"); + TESTASSERT(std::string{f.get_state_name()} == "idle"); + TESTASSERT(f.is_in_state()); + TESTASSERT(f.foo_counter == 0); + TESTASSERT(f.idle_enter_counter == 1); + + // Moving Idle -> State1 + ev1 e; + f.trigger(e); + TESTASSERT(std::string{f.get_state_name()} == "state1"); + TESTASSERT(f.is_in_state()); + + // Moving State1 -> fsm2 + f.trigger(e); + TESTASSERT(std::string{f.get_state_name()} == "fsm2"); + TESTASSERT(f.is_in_state()); + TESTASSERT(std::string{f.get_state()->get_state_name()} == "state_inner"); + + // Fsm2 does not listen to ev1 + f.trigger(e); + TESTASSERT(std::string{f.get_state_name()} == "fsm2"); + TESTASSERT(f.is_in_state()); + TESTASSERT(std::string{f.get_state()->get_state_name()} == "state_inner"); + + // Moving fsm2 -> state1 + f.trigger(ev2{}); + TESTASSERT(std::string{f.get_state_name()} == "state1"); + TESTASSERT(f.is_in_state()); + TESTASSERT(f.state1_enter_counter == 2); + + // Moving state1 -> idle + f.trigger(ev2{}); + TESTASSERT(std::string{f.get_state_name()} == "idle"); + TESTASSERT(f.is_in_state()); + TESTASSERT(f.foo_counter == 1); + TESTASSERT(f.idle_enter_counter == 2); + + return SRSLTE_SUCCESS; +} + +int main() +{ + TESTASSERT(test_hsm() == SRSLTE_SUCCESS); + + return SRSLTE_SUCCESS; +}