From 3f92ca5c47bebda67f7c0ba5a77340c2bd2bd527 Mon Sep 17 00:00:00 2001
From: Nicolas Pope <nwpope@utu.fi>
Date: Sat, 23 Feb 2019 09:45:39 +0200
Subject: [PATCH] Refactor net code, fix tests and fix connection bugs

---
 net/include/ftl/net/socket.hpp |  48 ++++---
 net/src/net.cpp                |  14 +-
 net/src/socket.cpp             | 228 ++++++++++++++++++---------------
 net/test/net_raw.cpp           |  33 ++++-
 net/test/rpc.cpp               |   7 +-
 5 files changed, 188 insertions(+), 142 deletions(-)

diff --git a/net/include/ftl/net/socket.hpp b/net/include/ftl/net/socket.hpp
index 76fbb58c9..cea77f25a 100644
--- a/net/include/ftl/net/socket.hpp
+++ b/net/include/ftl/net/socket.hpp
@@ -22,10 +22,12 @@
 namespace ftl {
 namespace net {
 
+#pragma pack(push,1)
 struct Header {
 	uint32_t size;
 	uint32_t service;
 };
+#pragma pack(pop)
 
 class Socket {
 	public:
@@ -36,27 +38,29 @@ class Socket {
 	int close();
 
 	int send(uint32_t service, const std::string &data);
-	int send(uint32_t service, std::stringstream &data) { return send(service, data.str()); };
+	int send(uint32_t service, std::stringstream &data) {
+		return send(service, data.str()); };
 	int send(uint32_t service, void *data, int length);
 	
-	int send2(uint32_t service, const std::string &data1, const std::string &data2);
+	int send2(uint32_t service, const std::string &data1,
+			const std::string &data2);
 
 	//friend bool ftl::net::run(bool);
 
-	int _socket() const { return m_sock; };
+	int _socket() const { return sock_; };
 
-	bool isConnected() const { return m_sock != INVALID_SOCKET; };
-	bool isValid() const { return m_valid; };
-	std::string getURI() const { return m_uri; };
+	bool isConnected() const { return sock_ != INVALID_SOCKET && connected_; };
+	bool isValid() const { return valid_ && sock_ != INVALID_SOCKET; };
+	std::string getURI() const { return uri_; };
 	
 	/**
 	 * Bind a function to a RPC call name.
 	 */
 	template <typename F>
 	void bind(const std::string &name, F func) {
-		//disp_.enforce_unique_name(name);
-		disp_.bind(name, func, typename ftl::internal::func_kind_info<F>::result_kind(),
-		     typename ftl::internal::func_kind_info<F>::args_kind());
+		disp_.bind(name, func,
+			typename ftl::internal::func_kind_info<F>::result_kind(),
+		    typename ftl::internal::func_kind_info<F>::args_kind());
 	}
 	
 	/**
@@ -96,7 +100,7 @@ class Socket {
 		auto rpcid = rpcid__++;
 		auto call_obj = std::make_tuple(0,rpcid,name,args_obj);
 		
-		LOG(INFO) << "RPC " << name << "() -> " << m_uri;
+		LOG(INFO) << "RPC " << name << "() -> " << uri_;
 		
 		std::stringstream buf;
 		msgpack::pack(buf, call_obj);
@@ -107,9 +111,16 @@ class Socket {
 		send(FTL_PROTOCOL_RPC, buf.str());
 	}
 	
-	void dispatch(const std::string &b) { disp_.dispatch(b); }
+	/**
+	 * Internal handlers for specific event types.
+	 * @{
+	 */
+	void dispatchRPC(const std::string &d) { disp_.dispatch(d); }
+	void dispatchReturn(const std::string &d);
+	void handshake1(const std::string &d);
+	void handshake2(const std::string &d);
+	/** @} */
 
-	//void onMessage(sockdatahandler_t handler) { m_handler = handler; }
 	void onError(sockerrorhandler_t handler) {}
 	void onConnect(std::function<void(Socket&)> f);
 	void onDisconnect(sockdisconnecthandler_t handler) {}
@@ -118,18 +129,19 @@ class Socket {
 	void error();
 
 	private:
-	std::string m_uri;
-	int m_sock;
-	size_t m_pos;
-	char *m_buffer;
+	std::string uri_;
+	int sock_;
+	size_t pos_;
+	char *buffer_;
 	std::map<uint32_t,std::function<void(Socket&,const std::string&)>> handlers_;
 	std::vector<std::function<void(Socket&)>> connect_handlers_;
-	bool m_valid;
-	bool m_connected;
+	bool valid_;
+	bool connected_;
 	std::map<int, std::function<void(msgpack::object&)>> callbacks_;
 	ftl::net::Dispatcher disp_;
 	
 	void _connected();
+	void _updateURI();
 	
 	static int rpcid__;
 
diff --git a/net/src/net.cpp b/net/src/net.cpp
index dafb471a0..160a47ef6 100644
--- a/net/src/net.cpp
+++ b/net/src/net.cpp
@@ -57,7 +57,7 @@ static int setDescriptors() {
 
 	//Set the file descriptors for each client
 	for (auto s : sockets) {
-		if (s != nullptr && s->isConnected()) {
+		if (s != nullptr && s->isValid()) {
 			
 			if (s->_socket() > n) {
 				n = s->_socket();
@@ -78,7 +78,7 @@ shared_ptr<Listener> ftl::net::listen(const char *uri) {
 }
 
 shared_ptr<Socket> ftl::net::connect(const char *uri) {
-	shared_ptr<Socket> s(new Socket(uri));
+	shared_ptr<Socket> s(new Socket((uri == NULL) ? "" : uri));
 	sockets.push_back(s);
 	return s;
 }
@@ -142,12 +142,10 @@ bool _run(bool blocking, bool nodelay) {
 
 						if (csock != INVALID_SOCKET) {
 							auto sock = make_shared<Socket>(csock);
-							//sockets[freeclient] = sock;
+							sockets.push_back(sock);
 							
 							// Call connection handlers
 							l->connection(sock);
-							
-							sockets.push_back(std::move(sock));
 						}
 					//}
 				}
@@ -156,12 +154,12 @@ bool _run(bool blocking, bool nodelay) {
 
 		//Also check each clients socket to see if any messages or errors are waiting
 		for (auto s : sockets) {
-			if (s != NULL && s->isConnected()) {
+			if (s != NULL && s->isValid()) {
 				//If message received from this client then deal with it
 				if (FD_ISSET(s->_socket(), &sfdread)) {
 					repeat |= s->data();
-				//An error occured with this client.
-				} else if (FD_ISSET(s->_socket(), &sfderror)) {
+				}
+				if (FD_ISSET(s->_socket(), &sfderror)) {
 					s->error();
 				}
 			}
diff --git a/net/src/socket.cpp b/net/src/socket.cpp
index cae4efde7..10610cf6c 100644
--- a/net/src/socket.cpp
+++ b/net/src/socket.cpp
@@ -27,7 +27,7 @@ using namespace ftl;
 using ftl::net::Socket;
 using namespace std;
 
-static std::string hexStr(const std::string &s)
+/*static std::string hexStr(const std::string &s)
 {
 	const char *data = s.data();
 	int len = s.size();
@@ -36,7 +36,7 @@ static std::string hexStr(const std::string &s)
     for(int i=0;i<len;++i)
         ss << std::setw(2) << std::setfill('0') << (int)data[i];
     return ss.str();
-}
+}*/
 
 int Socket::rpcid__ = 0;
 
@@ -120,14 +120,39 @@ static int wsConnect(URI &uri) {
 	return 1;
 }
 
-Socket::Socket(int s) : m_sock(s), m_pos(0), disp_(this) {
-	m_valid = true;
-	m_buffer = new char[BUFFER_SIZE];
-	m_connected = false;
+Socket::Socket(int s) : sock_(s), pos_(0), disp_(this) {
+	valid_ = true;
+	buffer_ = new char[BUFFER_SIZE];
+	connected_ = false;
+	
+	_updateURI();
+}
+
+Socket::Socket(const char *pUri) : uri_(pUri), pos_(0), disp_(this) {
+	// Allocate buffer
+	buffer_ = new char[BUFFER_SIZE];
+	
+	URI uri(pUri);
 	
+	valid_ = false;
+	connected_ = false;
+	sock_ = INVALID_SOCKET;
+
+	if (uri.getProtocol() == URI::SCHEME_TCP) {
+		sock_ = tcpConnect(uri);
+		valid_ = true;
+	} else if (uri.getProtocol() == URI::SCHEME_WS) {
+		wsConnect(uri);
+		LOG(ERROR) << "Websocket currently unsupported";
+	} else {
+		LOG(ERROR) << "Unrecognised connection protocol: " << pUri;
+	}
+}
+
+void Socket::_updateURI() {
 	sockaddr_storage addr;
 	int rsize = sizeof(sockaddr_storage);
-	if (getpeername(s, (sockaddr*)&addr, (socklen_t*)&rsize) == 0) {
+	if (getpeername(sock_, (sockaddr*)&addr, (socklen_t*)&rsize) == 0) {
 		char addrbuf[INET6_ADDRSTRLEN];
 		int port;
 		
@@ -143,39 +168,23 @@ Socket::Socket(int s) : m_sock(s), m_pos(0), disp_(this) {
 			port = s->sin6_port;
 		}
 		
-		m_uri = std::string("tcp://")+addrbuf;
-		m_uri += ":";
-		m_uri += std::to_string(port);
-	}
-}
-
-Socket::Socket(const char *pUri) : m_uri(pUri), m_pos(0), disp_(this) {
-	// Allocate buffer
-	m_buffer = new char[BUFFER_SIZE];
-	
-	URI uri(pUri);
-	
-	m_valid = false;
-	m_connected = false;
-	m_sock = INVALID_SOCKET;
-
-	if (uri.getProtocol() == URI::SCHEME_TCP) {
-		m_sock = tcpConnect(uri);
-		m_valid = true;
-	} else if (uri.getProtocol() == URI::SCHEME_WS) {
-		wsConnect(uri);
-	} else {
+		// TODO verify tcp or udp etc.
+		
+		uri_ = std::string("tcp://")+addrbuf;
+		uri_ += ":";
+		uri_ += std::to_string(port);
 	}
 }
 
 int Socket::close() {
 	if (isConnected()) {
 		#ifndef WIN32
-		::close(m_sock);
+		::close(sock_);
 		#else
-		closesocket(m_sock);
+		closesocket(sock_);
 		#endif
-		m_sock = INVALID_SOCKET;
+		sock_ = INVALID_SOCKET;
+		connected_ = false;
 
 		// Attempt auto reconnect?
 	}
@@ -185,108 +194,117 @@ int Socket::close() {
 void Socket::error() {
 	int err;
 	uint32_t optlen = sizeof(err);
-	getsockopt(m_sock, SOL_SOCKET, SO_ERROR, &err, &optlen);
-	LOG(ERROR) << "Socket: " << m_uri << " - error " << err;
+	getsockopt(sock_, SOL_SOCKET, SO_ERROR, &err, &optlen);
+	LOG(ERROR) << "Socket: " << uri_ << " - error " << err;
 }
 
 bool Socket::data() {
-	//std::cerr << "GOT SOCKET DATA" << std::endl;
-
 	//Read data from socket
 	size_t n = 0;
 	uint32_t len = 0;
 
-	if (m_pos < 4) {
-		n = 4 - m_pos;
+	if (pos_ < 4) {
+		n = 4 - pos_;
 	} else {
-		len = *(int*)m_buffer;
-		n = len+4-m_pos;
+		len = *(int*)buffer_;
+		n = len+4-pos_;
 	}
 
-	while (m_pos < len+4) {
+	while (pos_ < len+4) {
 		if (len > MAX_MESSAGE) {
 			close();
-			LOG(ERROR) << "Socket: " << m_uri << " - message attack";
-			return false; // Prevent DoS
+			LOG(ERROR) << "Socket: " << uri_ << " - message attack";
+			return false;
 		}
 
-		const int rc = recv(m_sock, m_buffer+m_pos, n, 0);
+		const int rc = recv(sock_, buffer_+pos_, n, 0);
 
 		if (rc > 0) {
-			m_pos += static_cast<size_t>(rc);
+			pos_ += static_cast<size_t>(rc);
 
-			if (m_pos < 4) {
-				n = 4 - m_pos;
+			if (pos_ < 4) {
+				n = 4 - pos_;
 			} else {
-				len = *(int*)m_buffer;
-				n = len+4-m_pos;
+				len = *(int*)buffer_;
+				n = len+4-pos_;
 			}
 		} else if (rc == EWOULDBLOCK || rc == 0) {
 			// Data not yet available
-			//std::cout << "No data to read" << std::endl;
 			return false;
 		} else {
-			LOG(ERROR) << "Socket: " << m_uri << " - error " << rc;
-			// Close socket due to error
+			LOG(ERROR) << "Socket: " << uri_ << " - error " << rc;
 			close();
 			return false;
 		}
 	}
 
-	// All data available
-	//if (m_handler) {
-		uint32_t service = ((uint32_t*)m_buffer)[1];
-		auto d = std::string(m_buffer+8, len-4);
-		//std::cerr << "DATA : " << service << " -> " << d << std::endl;
-		
-		if (service == FTL_PROTOCOL_HS1 && !m_connected) {
-			// TODO Verify data
-			std::string hs2("HELLO");
-			send(FTL_PROTOCOL_HS2, hs2);
-			LOG(INFO) << "Handshake confirmed from " << m_uri;
-			_connected();
-		} else if (service == FTL_PROTOCOL_HS2 && !m_connected) {
-			// TODO Verify data
-			LOG(INFO) << "Handshake finalised for " << m_uri;
-			_connected();
-		} else if (service == FTL_PROTOCOL_RPC) {
-			dispatch(d);
-		} else if (service == FTL_PROTOCOL_RPCRETURN) {
-			auto unpacked = msgpack::unpack(d.data(), d.size());
-			Dispatcher::response_t the_result;
-			unpacked.get().convert(the_result);
-
-			// TODO: proper validation of protocol (and responding to it)
-			// auto &&type = std::get<0>(the_call);
-			// assert(type == 0);
-
-			// auto &&id = std::get<1>(the_call);
-			auto &&id = std::get<1>(the_result);
-			//auto &&err = std::get<2>(the_result);
-			auto &&res = std::get<3>(the_result);
-			
-			std::cout << " ROSULT " << hexStr(d) << std::endl;
-			
-			if (callbacks_.count(id) > 0) {
-				LOG(INFO) << "Received return RPC value";
-				callbacks_[id](res);
-				callbacks_.erase(id);
-			} else {
-				LOG(ERROR) << "Missing RPC callback for result";
-			}
+	// Route the message...
+	uint32_t service = ((uint32_t*)buffer_)[1];
+	auto d = std::string(buffer_+8, len-4);
+	
+	if (service == FTL_PROTOCOL_HS1 && !connected_) {
+		handshake1(d);
+	} else if (service == FTL_PROTOCOL_HS2 && !connected_) {
+		handshake2(d);
+	} else if (service == FTL_PROTOCOL_RPC) {
+		dispatchRPC(d);
+	} else if (service == FTL_PROTOCOL_RPCRETURN) {
+		dispatchReturn(d);
+	} else {
+		// Lookup raw message handler
+		if (handlers_.count(service) > 0) {
+			handlers_[service](*this, d);
 		} else {
-			// Lookup raw message handler
-			if (handlers_.count(service) > 0) handlers_[service](*this, d);
-		} 
-	//}
+			LOG(ERROR) << "Unrecognised service request (" << service << ") from " << uri_;
+		}
+	} 
 
-	m_pos = 0;
+	pos_ = 0;
 
 	return true;
 }
 
+void Socket::handshake1(const std::string &d) {
+	// TODO Verify data
+	std::string hs2("HELLO");
+	send(FTL_PROTOCOL_HS2, hs2);
+	LOG(INFO) << "Handshake confirmed from " << uri_;
+	_connected();
+}
+
+void Socket::handshake2(const std::string &d) {
+	// TODO Verify data
+	LOG(INFO) << "Handshake finalised for " << uri_;
+	_connected();
+}
+
+void Socket::dispatchReturn(const std::string &d) {
+	auto unpacked = msgpack::unpack(d.data(), d.size());
+	Dispatcher::response_t the_result;
+	unpacked.get().convert(the_result);
+
+	if (std::get<0>(the_result) != 1) {
+		LOG(ERROR) << "Bad RPC return message";
+		return;
+	}
+
+	auto &&id = std::get<1>(the_result);
+	//auto &&err = std::get<2>(the_result);
+	auto &&res = std::get<3>(the_result);
+	
+	// TODO Handle error reporting...
+	
+	if (callbacks_.count(id) > 0) {
+		LOG(INFO) << "Received return RPC value";
+		callbacks_[id](res);
+		callbacks_.erase(id);
+	} else {
+		LOG(ERROR) << "Missing RPC callback for result";
+	}
+}
+
 void Socket::onConnect(std::function<void(Socket&)> f) {
-	if (m_connected) {
+	if (connected_) {
 		f(*this);
 	} else {
 		connect_handlers_.push_back(f);
@@ -294,11 +312,11 @@ void Socket::onConnect(std::function<void(Socket&)> f) {
 }
 
 void Socket::_connected() {
-	m_connected = true;
+	connected_ = true;
 	for (auto h : connect_handlers_) {
 		h(*this);
 	}
-	connect_handlers_.clear();
+	//connect_handlers_.clear();
 }
 
 void Socket::bind(uint32_t service, std::function<void(Socket&,
@@ -321,7 +339,7 @@ int Socket::send(uint32_t service, const std::string &data) {
 	vec[1].iov_base = const_cast<char*>(data.data());
 	vec[1].iov_len = data.size();
 	
-	::writev(m_sock, &vec[0], 2);
+	::writev(sock_, &vec[0], 2);
 	
 	return 0;
 }
@@ -339,7 +357,7 @@ int Socket::send2(uint32_t service, const std::string &data1, const std::string
 	vec[2].iov_base = const_cast<char*>(data2.data());
 	vec[2].iov_len = data2.size();
 	
-	::writev(m_sock, &vec[0], 3);
+	::writev(sock_, &vec[0], 3);
 	
 	return 0;
 }
@@ -349,7 +367,7 @@ Socket::~Socket() {
 	close();
 	
 	// Delete socket buffer
-	if (m_buffer) delete [] m_buffer;
-	m_buffer = NULL;
+	if (buffer_) delete [] buffer_;
+	buffer_ = NULL;
 }
 
diff --git a/net/test/net_raw.cpp b/net/test/net_raw.cpp
index 3f48834a0..5fc57e492 100644
--- a/net/test/net_raw.cpp
+++ b/net/test/net_raw.cpp
@@ -157,6 +157,18 @@ int setDescriptors() {
 	return n;
 }
 
+void mocksend(int sd, uint32_t  service, const std::string &data) {
+	//std::cout << "HEX SEND: " << hexStr(data) << std::endl;
+	char buf[8+data.size()];
+	ftl::net::Header *h = (ftl::net::Header*)&buf;
+	h->size = data.size()+4;
+	h->service = service;
+	std::memcpy(&buf[8],data.data(),data.size());
+	
+	//std::cout << "HEX SEND2: " << hexStr(fakedata[sd]) << std::endl;
+	::send(sd, buf, 8+data.size(), 0);
+}
+
 void accept_connection() {
 	int n = setDescriptors();
 
@@ -174,6 +186,7 @@ void accept_connection() {
 
 		//Finally accept this client connection.
 		csock = accept(ssock, (sockaddr*)&addr, (socklen_t*)&rsize);
+		mocksend(csock, FTL_PROTOCOL_HS1, "HELLO");
 	} else {
 
 	}
@@ -190,6 +203,7 @@ TEST_CASE("net::connect()", "[net]") {
 	SECTION("valid tcp connection using ipv4") {
 		sock = ftl::net::connect("tcp://127.0.0.1:7077");
 		REQUIRE(sock != nullptr);
+		REQUIRE(sock->isValid());
 		accept_connection();
 	}
 
@@ -212,6 +226,7 @@ TEST_CASE("net::connect()", "[net]") {
 	SECTION("null uri") {
 		sock = ftl::net::connect(NULL);
 		REQUIRE(!sock->isValid());
+		sock = nullptr;
 	}
 
 	// Disabled due to long timeout
@@ -224,12 +239,14 @@ TEST_CASE("net::connect()", "[net]") {
 
 	SECTION("incorrect dns address") {
 		sock = ftl::net::connect("tcp://xryyrrgrtgddgr.com:7077");
-		REQUIRE(sock->isValid());
+		REQUIRE(!sock->isValid());
 		REQUIRE(sock->isConnected() == false);
 		sock = nullptr;
 	}
 
 	if (sock && sock->isValid()) {
+		ftl::net::wait();
+		//sock->data();
 		REQUIRE(sock->isConnected());
 		REQUIRE(csock != INVALID_SOCKET);
 		sock->close();
@@ -243,10 +260,11 @@ TEST_CASE("net::listen()", "[net]") {
 		REQUIRE( ftl::net::listen("tcp://*:7078")->isListening() );
 
 		SECTION("can connect to listening socket") {
-			shared_ptr<Socket> sock = ftl::net::connect("tcp://127.0.0.1:7078");
+			auto sock = ftl::net::connect("tcp://127.0.0.1:7078");
 			REQUIRE(sock->isValid());
+			ftl::net::wait(); // Handshake 1
+			ftl::net::wait(); // Handshake 2
 			REQUIRE(sock->isConnected());
-			ftl::net::wait();
 
 			// TODO Need way of knowing about connection
 		}
@@ -261,6 +279,7 @@ TEST_CASE("net::listen()", "[net]") {
 		bool connected = false;
 		
 		l->onConnection([&](shared_ptr<Socket> s) {
+			ftl::net::wait(); // Wait for handshake
 			REQUIRE( s->isConnected() );
 			connected = true;
 		});
@@ -268,19 +287,18 @@ TEST_CASE("net::listen()", "[net]") {
 		auto sock = ftl::net::connect("tcp://127.0.0.1:7078");
 		ftl::net::wait();
 		REQUIRE( connected );
-		std::cout << "PRE STOP" << std::endl;
 		ftl::net::stop();
-		std::cout << "POST STOP" << std::endl;
 	}
 }
 
-TEST_CASE("Socket.onMessage()", "[net]") {
+TEST_CASE("Socket.bind(int)", "[net]") {
 	// Need a fake server...
 	init_server();
 	shared_ptr<Socket> sock = ftl::net::connect("tcp://127.0.0.1:7077");
 	REQUIRE(sock->isValid());
-	REQUIRE(sock->isConnected());
 	accept_connection();
+	ftl::net::wait(); // Wait for handshake
+	REQUIRE(sock->isConnected());
 
 	SECTION("small valid message") {
 		send_json(1, "{message: \"Hello\"}");
@@ -322,6 +340,7 @@ TEST_CASE("Socket.onMessage()", "[net]") {
 			msg++;	
 		});
 
+		ftl::net::wait();
 		ftl::net::wait();
 		REQUIRE(msg == 2);
 	}
diff --git a/net/test/rpc.cpp b/net/test/rpc.cpp
index 1f4f9d04c..d07067a83 100644
--- a/net/test/rpc.cpp
+++ b/net/test/rpc.cpp
@@ -78,7 +78,7 @@ TEST_CASE("Socket::bind()", "[rpc]") {
 		std::stringstream buf;
 		msgpack::pack(buf, call_obj);
 		
-		s->dispatch(buf.str());
+		s->dispatchRPC(buf.str());
 		REQUIRE( called );
 	}
 	
@@ -96,7 +96,7 @@ TEST_CASE("Socket::bind()", "[rpc]") {
 		std::stringstream buf;
 		msgpack::pack(buf, call_obj);
 		
-		s->dispatch(buf.str());
+		s->dispatchRPC(buf.str());
 		REQUIRE( called );
 	}
 	
@@ -115,7 +115,7 @@ TEST_CASE("Socket::bind()", "[rpc]") {
 		std::stringstream buf;
 		msgpack::pack(buf, call_obj);
 		
-		s->dispatch(buf.str());
+		s->dispatchRPC(buf.str());
 		REQUIRE( called );
 	}
 	
@@ -220,7 +220,6 @@ TEST_CASE("Socket::call+bind loop", "[rpc]") {
 		});
 		
 		int res = s->call<int>("test1", 5);
-		
 		REQUIRE( res == 10 );
 	}
 	
-- 
GitLab