From 3f92ca5c47bebda67f7c0ba5a77340c2bd2bd527 Mon Sep 17 00:00:00 2001 From: Nicolas Pope <nwpope@utu.fi> Date: Sat, 23 Feb 2019 09:45:39 +0200 Subject: [PATCH] Refactor net code, fix tests and fix connection bugs --- net/include/ftl/net/socket.hpp | 48 ++++--- net/src/net.cpp | 14 +- net/src/socket.cpp | 228 ++++++++++++++++++--------------- net/test/net_raw.cpp | 33 ++++- net/test/rpc.cpp | 7 +- 5 files changed, 188 insertions(+), 142 deletions(-) diff --git a/net/include/ftl/net/socket.hpp b/net/include/ftl/net/socket.hpp index 76fbb58c9..cea77f25a 100644 --- a/net/include/ftl/net/socket.hpp +++ b/net/include/ftl/net/socket.hpp @@ -22,10 +22,12 @@ namespace ftl { namespace net { +#pragma pack(push,1) struct Header { uint32_t size; uint32_t service; }; +#pragma pack(pop) class Socket { public: @@ -36,27 +38,29 @@ class Socket { int close(); int send(uint32_t service, const std::string &data); - int send(uint32_t service, std::stringstream &data) { return send(service, data.str()); }; + int send(uint32_t service, std::stringstream &data) { + return send(service, data.str()); }; int send(uint32_t service, void *data, int length); - int send2(uint32_t service, const std::string &data1, const std::string &data2); + int send2(uint32_t service, const std::string &data1, + const std::string &data2); //friend bool ftl::net::run(bool); - int _socket() const { return m_sock; }; + int _socket() const { return sock_; }; - bool isConnected() const { return m_sock != INVALID_SOCKET; }; - bool isValid() const { return m_valid; }; - std::string getURI() const { return m_uri; }; + bool isConnected() const { return sock_ != INVALID_SOCKET && connected_; }; + bool isValid() const { return valid_ && sock_ != INVALID_SOCKET; }; + std::string getURI() const { return uri_; }; /** * Bind a function to a RPC call name. */ template <typename F> void bind(const std::string &name, F func) { - //disp_.enforce_unique_name(name); - disp_.bind(name, func, typename ftl::internal::func_kind_info<F>::result_kind(), - typename ftl::internal::func_kind_info<F>::args_kind()); + disp_.bind(name, func, + typename ftl::internal::func_kind_info<F>::result_kind(), + typename ftl::internal::func_kind_info<F>::args_kind()); } /** @@ -96,7 +100,7 @@ class Socket { auto rpcid = rpcid__++; auto call_obj = std::make_tuple(0,rpcid,name,args_obj); - LOG(INFO) << "RPC " << name << "() -> " << m_uri; + LOG(INFO) << "RPC " << name << "() -> " << uri_; std::stringstream buf; msgpack::pack(buf, call_obj); @@ -107,9 +111,16 @@ class Socket { send(FTL_PROTOCOL_RPC, buf.str()); } - void dispatch(const std::string &b) { disp_.dispatch(b); } + /** + * Internal handlers for specific event types. + * @{ + */ + void dispatchRPC(const std::string &d) { disp_.dispatch(d); } + void dispatchReturn(const std::string &d); + void handshake1(const std::string &d); + void handshake2(const std::string &d); + /** @} */ - //void onMessage(sockdatahandler_t handler) { m_handler = handler; } void onError(sockerrorhandler_t handler) {} void onConnect(std::function<void(Socket&)> f); void onDisconnect(sockdisconnecthandler_t handler) {} @@ -118,18 +129,19 @@ class Socket { void error(); private: - std::string m_uri; - int m_sock; - size_t m_pos; - char *m_buffer; + std::string uri_; + int sock_; + size_t pos_; + char *buffer_; std::map<uint32_t,std::function<void(Socket&,const std::string&)>> handlers_; std::vector<std::function<void(Socket&)>> connect_handlers_; - bool m_valid; - bool m_connected; + bool valid_; + bool connected_; std::map<int, std::function<void(msgpack::object&)>> callbacks_; ftl::net::Dispatcher disp_; void _connected(); + void _updateURI(); static int rpcid__; diff --git a/net/src/net.cpp b/net/src/net.cpp index dafb471a0..160a47ef6 100644 --- a/net/src/net.cpp +++ b/net/src/net.cpp @@ -57,7 +57,7 @@ static int setDescriptors() { //Set the file descriptors for each client for (auto s : sockets) { - if (s != nullptr && s->isConnected()) { + if (s != nullptr && s->isValid()) { if (s->_socket() > n) { n = s->_socket(); @@ -78,7 +78,7 @@ shared_ptr<Listener> ftl::net::listen(const char *uri) { } shared_ptr<Socket> ftl::net::connect(const char *uri) { - shared_ptr<Socket> s(new Socket(uri)); + shared_ptr<Socket> s(new Socket((uri == NULL) ? "" : uri)); sockets.push_back(s); return s; } @@ -142,12 +142,10 @@ bool _run(bool blocking, bool nodelay) { if (csock != INVALID_SOCKET) { auto sock = make_shared<Socket>(csock); - //sockets[freeclient] = sock; + sockets.push_back(sock); // Call connection handlers l->connection(sock); - - sockets.push_back(std::move(sock)); } //} } @@ -156,12 +154,12 @@ bool _run(bool blocking, bool nodelay) { //Also check each clients socket to see if any messages or errors are waiting for (auto s : sockets) { - if (s != NULL && s->isConnected()) { + if (s != NULL && s->isValid()) { //If message received from this client then deal with it if (FD_ISSET(s->_socket(), &sfdread)) { repeat |= s->data(); - //An error occured with this client. - } else if (FD_ISSET(s->_socket(), &sfderror)) { + } + if (FD_ISSET(s->_socket(), &sfderror)) { s->error(); } } diff --git a/net/src/socket.cpp b/net/src/socket.cpp index cae4efde7..10610cf6c 100644 --- a/net/src/socket.cpp +++ b/net/src/socket.cpp @@ -27,7 +27,7 @@ using namespace ftl; using ftl::net::Socket; using namespace std; -static std::string hexStr(const std::string &s) +/*static std::string hexStr(const std::string &s) { const char *data = s.data(); int len = s.size(); @@ -36,7 +36,7 @@ static std::string hexStr(const std::string &s) for(int i=0;i<len;++i) ss << std::setw(2) << std::setfill('0') << (int)data[i]; return ss.str(); -} +}*/ int Socket::rpcid__ = 0; @@ -120,14 +120,39 @@ static int wsConnect(URI &uri) { return 1; } -Socket::Socket(int s) : m_sock(s), m_pos(0), disp_(this) { - m_valid = true; - m_buffer = new char[BUFFER_SIZE]; - m_connected = false; +Socket::Socket(int s) : sock_(s), pos_(0), disp_(this) { + valid_ = true; + buffer_ = new char[BUFFER_SIZE]; + connected_ = false; + + _updateURI(); +} + +Socket::Socket(const char *pUri) : uri_(pUri), pos_(0), disp_(this) { + // Allocate buffer + buffer_ = new char[BUFFER_SIZE]; + + URI uri(pUri); + valid_ = false; + connected_ = false; + sock_ = INVALID_SOCKET; + + if (uri.getProtocol() == URI::SCHEME_TCP) { + sock_ = tcpConnect(uri); + valid_ = true; + } else if (uri.getProtocol() == URI::SCHEME_WS) { + wsConnect(uri); + LOG(ERROR) << "Websocket currently unsupported"; + } else { + LOG(ERROR) << "Unrecognised connection protocol: " << pUri; + } +} + +void Socket::_updateURI() { sockaddr_storage addr; int rsize = sizeof(sockaddr_storage); - if (getpeername(s, (sockaddr*)&addr, (socklen_t*)&rsize) == 0) { + if (getpeername(sock_, (sockaddr*)&addr, (socklen_t*)&rsize) == 0) { char addrbuf[INET6_ADDRSTRLEN]; int port; @@ -143,39 +168,23 @@ Socket::Socket(int s) : m_sock(s), m_pos(0), disp_(this) { port = s->sin6_port; } - m_uri = std::string("tcp://")+addrbuf; - m_uri += ":"; - m_uri += std::to_string(port); - } -} - -Socket::Socket(const char *pUri) : m_uri(pUri), m_pos(0), disp_(this) { - // Allocate buffer - m_buffer = new char[BUFFER_SIZE]; - - URI uri(pUri); - - m_valid = false; - m_connected = false; - m_sock = INVALID_SOCKET; - - if (uri.getProtocol() == URI::SCHEME_TCP) { - m_sock = tcpConnect(uri); - m_valid = true; - } else if (uri.getProtocol() == URI::SCHEME_WS) { - wsConnect(uri); - } else { + // TODO verify tcp or udp etc. + + uri_ = std::string("tcp://")+addrbuf; + uri_ += ":"; + uri_ += std::to_string(port); } } int Socket::close() { if (isConnected()) { #ifndef WIN32 - ::close(m_sock); + ::close(sock_); #else - closesocket(m_sock); + closesocket(sock_); #endif - m_sock = INVALID_SOCKET; + sock_ = INVALID_SOCKET; + connected_ = false; // Attempt auto reconnect? } @@ -185,108 +194,117 @@ int Socket::close() { void Socket::error() { int err; uint32_t optlen = sizeof(err); - getsockopt(m_sock, SOL_SOCKET, SO_ERROR, &err, &optlen); - LOG(ERROR) << "Socket: " << m_uri << " - error " << err; + getsockopt(sock_, SOL_SOCKET, SO_ERROR, &err, &optlen); + LOG(ERROR) << "Socket: " << uri_ << " - error " << err; } bool Socket::data() { - //std::cerr << "GOT SOCKET DATA" << std::endl; - //Read data from socket size_t n = 0; uint32_t len = 0; - if (m_pos < 4) { - n = 4 - m_pos; + if (pos_ < 4) { + n = 4 - pos_; } else { - len = *(int*)m_buffer; - n = len+4-m_pos; + len = *(int*)buffer_; + n = len+4-pos_; } - while (m_pos < len+4) { + while (pos_ < len+4) { if (len > MAX_MESSAGE) { close(); - LOG(ERROR) << "Socket: " << m_uri << " - message attack"; - return false; // Prevent DoS + LOG(ERROR) << "Socket: " << uri_ << " - message attack"; + return false; } - const int rc = recv(m_sock, m_buffer+m_pos, n, 0); + const int rc = recv(sock_, buffer_+pos_, n, 0); if (rc > 0) { - m_pos += static_cast<size_t>(rc); + pos_ += static_cast<size_t>(rc); - if (m_pos < 4) { - n = 4 - m_pos; + if (pos_ < 4) { + n = 4 - pos_; } else { - len = *(int*)m_buffer; - n = len+4-m_pos; + len = *(int*)buffer_; + n = len+4-pos_; } } else if (rc == EWOULDBLOCK || rc == 0) { // Data not yet available - //std::cout << "No data to read" << std::endl; return false; } else { - LOG(ERROR) << "Socket: " << m_uri << " - error " << rc; - // Close socket due to error + LOG(ERROR) << "Socket: " << uri_ << " - error " << rc; close(); return false; } } - // All data available - //if (m_handler) { - uint32_t service = ((uint32_t*)m_buffer)[1]; - auto d = std::string(m_buffer+8, len-4); - //std::cerr << "DATA : " << service << " -> " << d << std::endl; - - if (service == FTL_PROTOCOL_HS1 && !m_connected) { - // TODO Verify data - std::string hs2("HELLO"); - send(FTL_PROTOCOL_HS2, hs2); - LOG(INFO) << "Handshake confirmed from " << m_uri; - _connected(); - } else if (service == FTL_PROTOCOL_HS2 && !m_connected) { - // TODO Verify data - LOG(INFO) << "Handshake finalised for " << m_uri; - _connected(); - } else if (service == FTL_PROTOCOL_RPC) { - dispatch(d); - } else if (service == FTL_PROTOCOL_RPCRETURN) { - auto unpacked = msgpack::unpack(d.data(), d.size()); - Dispatcher::response_t the_result; - unpacked.get().convert(the_result); - - // TODO: proper validation of protocol (and responding to it) - // auto &&type = std::get<0>(the_call); - // assert(type == 0); - - // auto &&id = std::get<1>(the_call); - auto &&id = std::get<1>(the_result); - //auto &&err = std::get<2>(the_result); - auto &&res = std::get<3>(the_result); - - std::cout << " ROSULT " << hexStr(d) << std::endl; - - if (callbacks_.count(id) > 0) { - LOG(INFO) << "Received return RPC value"; - callbacks_[id](res); - callbacks_.erase(id); - } else { - LOG(ERROR) << "Missing RPC callback for result"; - } + // Route the message... + uint32_t service = ((uint32_t*)buffer_)[1]; + auto d = std::string(buffer_+8, len-4); + + if (service == FTL_PROTOCOL_HS1 && !connected_) { + handshake1(d); + } else if (service == FTL_PROTOCOL_HS2 && !connected_) { + handshake2(d); + } else if (service == FTL_PROTOCOL_RPC) { + dispatchRPC(d); + } else if (service == FTL_PROTOCOL_RPCRETURN) { + dispatchReturn(d); + } else { + // Lookup raw message handler + if (handlers_.count(service) > 0) { + handlers_[service](*this, d); } else { - // Lookup raw message handler - if (handlers_.count(service) > 0) handlers_[service](*this, d); - } - //} + LOG(ERROR) << "Unrecognised service request (" << service << ") from " << uri_; + } + } - m_pos = 0; + pos_ = 0; return true; } +void Socket::handshake1(const std::string &d) { + // TODO Verify data + std::string hs2("HELLO"); + send(FTL_PROTOCOL_HS2, hs2); + LOG(INFO) << "Handshake confirmed from " << uri_; + _connected(); +} + +void Socket::handshake2(const std::string &d) { + // TODO Verify data + LOG(INFO) << "Handshake finalised for " << uri_; + _connected(); +} + +void Socket::dispatchReturn(const std::string &d) { + auto unpacked = msgpack::unpack(d.data(), d.size()); + Dispatcher::response_t the_result; + unpacked.get().convert(the_result); + + if (std::get<0>(the_result) != 1) { + LOG(ERROR) << "Bad RPC return message"; + return; + } + + auto &&id = std::get<1>(the_result); + //auto &&err = std::get<2>(the_result); + auto &&res = std::get<3>(the_result); + + // TODO Handle error reporting... + + if (callbacks_.count(id) > 0) { + LOG(INFO) << "Received return RPC value"; + callbacks_[id](res); + callbacks_.erase(id); + } else { + LOG(ERROR) << "Missing RPC callback for result"; + } +} + void Socket::onConnect(std::function<void(Socket&)> f) { - if (m_connected) { + if (connected_) { f(*this); } else { connect_handlers_.push_back(f); @@ -294,11 +312,11 @@ void Socket::onConnect(std::function<void(Socket&)> f) { } void Socket::_connected() { - m_connected = true; + connected_ = true; for (auto h : connect_handlers_) { h(*this); } - connect_handlers_.clear(); + //connect_handlers_.clear(); } void Socket::bind(uint32_t service, std::function<void(Socket&, @@ -321,7 +339,7 @@ int Socket::send(uint32_t service, const std::string &data) { vec[1].iov_base = const_cast<char*>(data.data()); vec[1].iov_len = data.size(); - ::writev(m_sock, &vec[0], 2); + ::writev(sock_, &vec[0], 2); return 0; } @@ -339,7 +357,7 @@ int Socket::send2(uint32_t service, const std::string &data1, const std::string vec[2].iov_base = const_cast<char*>(data2.data()); vec[2].iov_len = data2.size(); - ::writev(m_sock, &vec[0], 3); + ::writev(sock_, &vec[0], 3); return 0; } @@ -349,7 +367,7 @@ Socket::~Socket() { close(); // Delete socket buffer - if (m_buffer) delete [] m_buffer; - m_buffer = NULL; + if (buffer_) delete [] buffer_; + buffer_ = NULL; } diff --git a/net/test/net_raw.cpp b/net/test/net_raw.cpp index 3f48834a0..5fc57e492 100644 --- a/net/test/net_raw.cpp +++ b/net/test/net_raw.cpp @@ -157,6 +157,18 @@ int setDescriptors() { return n; } +void mocksend(int sd, uint32_t service, const std::string &data) { + //std::cout << "HEX SEND: " << hexStr(data) << std::endl; + char buf[8+data.size()]; + ftl::net::Header *h = (ftl::net::Header*)&buf; + h->size = data.size()+4; + h->service = service; + std::memcpy(&buf[8],data.data(),data.size()); + + //std::cout << "HEX SEND2: " << hexStr(fakedata[sd]) << std::endl; + ::send(sd, buf, 8+data.size(), 0); +} + void accept_connection() { int n = setDescriptors(); @@ -174,6 +186,7 @@ void accept_connection() { //Finally accept this client connection. csock = accept(ssock, (sockaddr*)&addr, (socklen_t*)&rsize); + mocksend(csock, FTL_PROTOCOL_HS1, "HELLO"); } else { } @@ -190,6 +203,7 @@ TEST_CASE("net::connect()", "[net]") { SECTION("valid tcp connection using ipv4") { sock = ftl::net::connect("tcp://127.0.0.1:7077"); REQUIRE(sock != nullptr); + REQUIRE(sock->isValid()); accept_connection(); } @@ -212,6 +226,7 @@ TEST_CASE("net::connect()", "[net]") { SECTION("null uri") { sock = ftl::net::connect(NULL); REQUIRE(!sock->isValid()); + sock = nullptr; } // Disabled due to long timeout @@ -224,12 +239,14 @@ TEST_CASE("net::connect()", "[net]") { SECTION("incorrect dns address") { sock = ftl::net::connect("tcp://xryyrrgrtgddgr.com:7077"); - REQUIRE(sock->isValid()); + REQUIRE(!sock->isValid()); REQUIRE(sock->isConnected() == false); sock = nullptr; } if (sock && sock->isValid()) { + ftl::net::wait(); + //sock->data(); REQUIRE(sock->isConnected()); REQUIRE(csock != INVALID_SOCKET); sock->close(); @@ -243,10 +260,11 @@ TEST_CASE("net::listen()", "[net]") { REQUIRE( ftl::net::listen("tcp://*:7078")->isListening() ); SECTION("can connect to listening socket") { - shared_ptr<Socket> sock = ftl::net::connect("tcp://127.0.0.1:7078"); + auto sock = ftl::net::connect("tcp://127.0.0.1:7078"); REQUIRE(sock->isValid()); + ftl::net::wait(); // Handshake 1 + ftl::net::wait(); // Handshake 2 REQUIRE(sock->isConnected()); - ftl::net::wait(); // TODO Need way of knowing about connection } @@ -261,6 +279,7 @@ TEST_CASE("net::listen()", "[net]") { bool connected = false; l->onConnection([&](shared_ptr<Socket> s) { + ftl::net::wait(); // Wait for handshake REQUIRE( s->isConnected() ); connected = true; }); @@ -268,19 +287,18 @@ TEST_CASE("net::listen()", "[net]") { auto sock = ftl::net::connect("tcp://127.0.0.1:7078"); ftl::net::wait(); REQUIRE( connected ); - std::cout << "PRE STOP" << std::endl; ftl::net::stop(); - std::cout << "POST STOP" << std::endl; } } -TEST_CASE("Socket.onMessage()", "[net]") { +TEST_CASE("Socket.bind(int)", "[net]") { // Need a fake server... init_server(); shared_ptr<Socket> sock = ftl::net::connect("tcp://127.0.0.1:7077"); REQUIRE(sock->isValid()); - REQUIRE(sock->isConnected()); accept_connection(); + ftl::net::wait(); // Wait for handshake + REQUIRE(sock->isConnected()); SECTION("small valid message") { send_json(1, "{message: \"Hello\"}"); @@ -322,6 +340,7 @@ TEST_CASE("Socket.onMessage()", "[net]") { msg++; }); + ftl::net::wait(); ftl::net::wait(); REQUIRE(msg == 2); } diff --git a/net/test/rpc.cpp b/net/test/rpc.cpp index 1f4f9d04c..d07067a83 100644 --- a/net/test/rpc.cpp +++ b/net/test/rpc.cpp @@ -78,7 +78,7 @@ TEST_CASE("Socket::bind()", "[rpc]") { std::stringstream buf; msgpack::pack(buf, call_obj); - s->dispatch(buf.str()); + s->dispatchRPC(buf.str()); REQUIRE( called ); } @@ -96,7 +96,7 @@ TEST_CASE("Socket::bind()", "[rpc]") { std::stringstream buf; msgpack::pack(buf, call_obj); - s->dispatch(buf.str()); + s->dispatchRPC(buf.str()); REQUIRE( called ); } @@ -115,7 +115,7 @@ TEST_CASE("Socket::bind()", "[rpc]") { std::stringstream buf; msgpack::pack(buf, call_obj); - s->dispatch(buf.str()); + s->dispatchRPC(buf.str()); REQUIRE( called ); } @@ -220,7 +220,6 @@ TEST_CASE("Socket::call+bind loop", "[rpc]") { }); int res = s->call<int>("test1", 5); - REQUIRE( res == 10 ); } -- GitLab