diff --git a/lib/include/srsran/common/network_utils.h b/lib/include/srsran/common/network_utils.h index 2bd86a2d7..9ec7a19c9 100644 --- a/lib/include/srsran/common/network_utils.h +++ b/lib/include/srsran/common/network_utils.h @@ -76,6 +76,7 @@ public: bool bind_addr(const char* bind_addr_str, int port); bool connect_to(const char* dest_addr_str, int dest_port, sockaddr_in* dest_sockaddr = nullptr); + bool start_listen(); bool open_socket(net_utils::addr_family ip, net_utils::socket_type socket_type, net_utils::protocol_type protocol); int get_socket() const { return sockfd; }; @@ -87,7 +88,6 @@ protected: namespace net_utils { bool sctp_init_socket(unique_socket* socket, net_utils::socket_type socktype, const char* bind_addr_str, int bind_port); -bool sctp_init_server(unique_socket* socket, net_utils::socket_type socktype, const char* bind_addr_str, int port); } // namespace net_utils diff --git a/lib/src/common/network_utils.cc b/lib/src/common/network_utils.cc index 232c091ea..5d829b70b 100644 --- a/lib/src/common/network_utils.cc +++ b/lib/src/common/network_utils.cc @@ -236,6 +236,22 @@ bool connect_to(int fd, const char* dest_addr_str, int dest_port, sockaddr_in* d return true; } +bool start_listen(int fd) +{ + if (fd < 0) { + srslog::fetch_basic_logger(LOGSERVICE).error("Tried to listen for connections with an invalid socket."); + return false; + } + + // Listen for connections + if (listen(fd, SOMAXCONN) != 0) { + srslog::fetch_basic_logger(LOGSERVICE).error("Failed to listen to incoming SCTP connections"); + perror("listen()"); + return false; + } + return true; +} + } // namespace net_utils /******************************************** @@ -260,6 +276,18 @@ unique_socket& unique_socket::operator=(unique_socket&& other) noexcept return *this; } +bool unique_socket::open_socket(net_utils::addr_family ip_type, + net_utils::socket_type socket_type, + net_utils::protocol_type protocol) +{ + if (is_open()) { + srslog::fetch_basic_logger(LOGSERVICE).error("Socket is already open."); + return false; + } + sockfd = net_utils::open_socket(ip_type, socket_type, protocol); + return is_open(); +} + void unique_socket::close() { if (sockfd >= 0) { @@ -279,16 +307,9 @@ bool unique_socket::connect_to(const char* dest_addr_str, int dest_port, sockadd return net_utils::connect_to(sockfd, dest_addr_str, dest_port, dest_sockaddr); } -bool unique_socket::open_socket(net_utils::addr_family ip_type, - net_utils::socket_type socket_type, - net_utils::protocol_type protocol) +bool unique_socket::start_listen() { - if (is_open()) { - srslog::fetch_basic_logger(LOGSERVICE).error("Socket is already open."); - return false; - } - sockfd = net_utils::open_socket(ip_type, socket_type, protocol); - return is_open(); + return net_utils::start_listen(sockfd); } /*********************************************************************** @@ -297,36 +318,18 @@ bool unique_socket::open_socket(net_utils::addr_family ip_type, namespace net_utils { -bool sctp_init_socket(unique_socket* socket, net_utils::socket_type socktype, const char* bind_addr_str, int port) +bool sctp_init_socket(unique_socket* socket, net_utils::socket_type socktype, const char* bind_addr_str, int bind_port) { if (not socket->open_socket(net_utils::addr_family::ipv4, socktype, net_utils::protocol_type::SCTP)) { return false; } - if (not socket->bind_addr(bind_addr_str, port)) { + if (not socket->bind_addr(bind_addr_str, bind_port)) { socket->close(); return false; } return true; } -bool sctp_init_client(unique_socket* socket, net_utils::socket_type socktype, const char* bind_addr_str, int bind_port) -{ - return sctp_init_socket(socket, socktype, bind_addr_str, bind_port); -} - -bool sctp_init_server(unique_socket* socket, net_utils::socket_type socktype, const char* bind_addr_str, int port) -{ - if (not sctp_init_socket(socket, socktype, bind_addr_str, port)) { - return false; - } - // Listen for connections - if (listen(socket->fd(), SOMAXCONN) != 0) { - srslog::fetch_basic_logger(LOGSERVICE).error("Failed to listen to incoming SCTP connections"); - return false; - } - return true; -} - } // namespace net_utils /*************************************************************** diff --git a/lib/test/common/network_utils_test.cc b/lib/test/common/network_utils_test.cc index 709799a45..c63fd97e3 100644 --- a/lib/test/common/network_utils_test.cc +++ b/lib/test/common/network_utils_test.cc @@ -52,7 +52,8 @@ int test_socket_handler() const char* server_addr = "127.0.100.1"; using namespace srsran::net_utils; - TESTASSERT(sctp_init_server(&server_socket, socket_type::seqpacket, server_addr, server_port)); + TESTASSERT(sctp_init_socket(&server_socket, socket_type::seqpacket, server_addr, server_port)); + TESTASSERT(server_socket.start_listen()); logger.info("Listening from fd=%d", server_socket.fd()); TESTASSERT(sctp_init_socket(&client_socket, socket_type::seqpacket, "127.0.0.1", 0)); @@ -114,6 +115,18 @@ int test_socket_handler() return 0; } +int test_sctp_bind_error() +{ + srsran::unique_socket sock; + TESTASSERT(not srsran::net_utils::sctp_init_socket( + &sock, srsran::net_utils::socket_type::seqpacket, "1.1.1.1", 8000)); // Bogus IP address + // should not be able to bind + TESTASSERT(srsran::net_utils::sctp_init_socket( + &sock, srsran::net_utils::socket_type::seqpacket, "127.0.0.1", 8000)); // Bogus IP address + // should not be able to bind + return SRSRAN_SUCCESS; +} + int main() { auto& logger = srslog::fetch_basic_logger("S1AP", false); @@ -123,6 +136,7 @@ int main() srslog::init(); TESTASSERT(test_socket_handler() == 0); + TESTASSERT(test_sctp_bind_error() == 0); return 0; }