From cc6b15d2172626913aab3406a871b5414425d43c Mon Sep 17 00:00:00 2001
From: Nicolas Pope <nwpope@utu.fi>
Date: Tue, 26 Feb 2019 12:40:20 +0200
Subject: [PATCH] Rework protocol handshake

---
 net/include/ftl/net.hpp          |  3 +
 net/include/ftl/net/protocol.hpp | 18 +++---
 net/include/ftl/net/socket.hpp   |  4 +-
 net/src/listener.cpp             |  2 +
 net/src/net.cpp                  |  8 +++
 net/src/protocol.cpp             |  6 +-
 net/src/socket.cpp               | 97 +++++++++++++++-----------------
 net/test/net_integration.cpp     |  5 +-
 net/test/protocol_unit.cpp       |  2 +-
 net/test/socket_unit.cpp         |  6 +-
 10 files changed, 78 insertions(+), 73 deletions(-)

diff --git a/net/include/ftl/net.hpp b/net/include/ftl/net.hpp
index 9d1178100..0647fb3f6 100644
--- a/net/include/ftl/net.hpp
+++ b/net/include/ftl/net.hpp
@@ -2,6 +2,7 @@
 #define _FTL_NET_HPP_
 
 #include <memory>
+#include <functional>
 
 namespace ftl {
 namespace net {
@@ -39,6 +40,8 @@ bool run(bool async=false);
  */
 bool wait();
 
+void wait(std::function<bool(void)>);
+
 /**
  * Check and process any waiting messages, but do not block if there are none.
  */
diff --git a/net/include/ftl/net/protocol.hpp b/net/include/ftl/net/protocol.hpp
index 18d6a5918..fc4925e29 100644
--- a/net/include/ftl/net/protocol.hpp
+++ b/net/include/ftl/net/protocol.hpp
@@ -28,13 +28,15 @@ struct Header {
 };
 
 struct Handshake {
-	uint64_t proto;			// The protocol the other party is expected to use.
-	char peerid[16];		// GUID for the origin peer.
-	char reserved_[32];		// RESERVED, must be 0.
+	uint64_t magic;
+	uint32_t name_size;
+	uint32_t proto_size;
 };
 
 #pragma pack(pop)
 
+static const uint64_t MAGIC = 0x1099340053640912;
+
 /**
  * Each instance of this Protocol class represents a specific protocol. A
  * protocol is a set of RPC bindings and raw message handlers. A protocol is
@@ -47,7 +49,7 @@ class Protocol {
 	friend class Socket;
 	
 	public:
-	Protocol(uint64_t id);
+	Protocol(const std::string &id);
 	~Protocol();
 	
 	/**
@@ -63,9 +65,9 @@ class Protocol {
 			
 	// broadcast?
 	
-	uint64_t id() const { return id_; }
+	const std::string &id() const { return id_; }
 	
-	static Protocol *find(uint64_t id);
+	static Protocol *find(const std::string &id);
 			
 	protected:
 	void dispatchRPC(Socket &, const std::string &d);
@@ -78,9 +80,9 @@ class Protocol {
 	private:
 	ftl::net::Dispatcher disp_;
 	std::map<uint32_t,std::function<void(uint32_t,Socket&)>> handlers_;
-	uint64_t id_;
+	std::string id_;
 	
-	static std::map<uint64_t,Protocol*> protocols__;
+	static std::map<std::string,Protocol*> protocols__;
 };
 
 // --- Template Implementations ------------------------------------------------
diff --git a/net/include/ftl/net/socket.hpp b/net/include/ftl/net/socket.hpp
index 9bbd37565..09d0852e3 100644
--- a/net/include/ftl/net/socket.hpp
+++ b/net/include/ftl/net/socket.hpp
@@ -126,8 +126,8 @@ class Socket {
 	 * is current here for testing purposes.
 	 * @{
 	 */
-	void handshake1(const std::string &d);
-	void handshake2(const std::string &d);
+	void handshake1();
+	void handshake2();
 	/** @} */
 	
 	private: // Functions
diff --git a/net/src/listener.cpp b/net/src/listener.cpp
index 4eb83458f..c7cc7ed02 100644
--- a/net/src/listener.cpp
+++ b/net/src/listener.cpp
@@ -113,6 +113,8 @@ Listener::~Listener() {
 void Listener::connection(shared_ptr<Socket> &s) {
 	if (default_proto_) {
 		s->setProtocol(default_proto_);
+	} else {
+		s->setProtocol(NULL);
 	}
 	for (auto h : handler_connect_) h(s);
 }
diff --git a/net/src/net.cpp b/net/src/net.cpp
index cd15614f2..d6c120a1b 100644
--- a/net/src/net.cpp
+++ b/net/src/net.cpp
@@ -4,8 +4,10 @@
 
 #include <vector>
 #include <iostream>
+#include <chrono>
 
 using namespace std;
+using namespace std::chrono;
 using ftl::net::Listener;
 using ftl::net::Socket;
 
@@ -189,6 +191,12 @@ bool ftl::net::wait() {
 	return _run(false,false);
 }
 
+void ftl::net::wait(std::function<bool(void)> f) {
+	auto start = steady_clock::now();
+	while (!f() && duration<float>(steady_clock::now() - start).count() < 3.0)
+		_run(false,false);
+}
+
 bool ftl::net::run(bool async) {
 	if (async) {
 		// TODO Start thread
diff --git a/net/src/protocol.cpp b/net/src/protocol.cpp
index 1e511df90..59444e69e 100644
--- a/net/src/protocol.cpp
+++ b/net/src/protocol.cpp
@@ -6,14 +6,14 @@
 using ftl::net::Socket;
 using ftl::net::Protocol;
 
-std::map<uint64_t,Protocol*> Protocol::protocols__;
+std::map<std::string,Protocol*> Protocol::protocols__;
 
-Protocol *Protocol::find(uint64_t id) {
+Protocol *Protocol::find(const std::string &id) {
 	if (protocols__.count(id) > 0) return protocols__[id];
 	else return NULL;
 }
 
-Protocol::Protocol(uint64_t id) : id_(id) {
+Protocol::Protocol(const std::string &id) : id_(id) {
 	protocols__[id] = this;
 }
 
diff --git a/net/src/socket.cpp b/net/src/socket.cpp
index 338b44719..d6765b1ec 100644
--- a/net/src/socket.cpp
+++ b/net/src/socket.cpp
@@ -207,14 +207,25 @@ int Socket::close() {
 }
 
 void Socket::setProtocol(Protocol *p) {
-	if (proto_ == p) return;
-	if (proto_ && proto_->id() == p->id()) return;
-	
-	proto_ = p;
-	ftl::net::Handshake hs1;
-	hs1.proto = p->id();
-	send(FTL_PROTOCOL_HS1, std::string((char*)&hs1, sizeof(hs1)));
-	LOG(INFO) << "Handshake initiated with " << uri_;
+	if (p != NULL) {
+		if (proto_ == p) return;
+		if (proto_ && proto_->id() == p->id()) return;
+		
+		proto_ = p;
+		Handshake hs1;
+		hs1.magic = ftl::net::MAGIC;
+		hs1.name_size = 0;
+		hs1.proto_size = p->id().size();
+		send(FTL_PROTOCOL_HS1, hs1, p->id());
+		LOG(INFO) << "Handshake initiated with " << uri_;
+	} else {
+		Handshake hs1;
+		hs1.magic = ftl::net::MAGIC;
+		hs1.name_size = 0;
+		hs1.proto_size = 0;
+		send(FTL_PROTOCOL_HS1, hs1);
+		LOG(INFO) << "Handshake initiated with " << uri_;
+	}
 }
 
 void Socket::error() {
@@ -278,9 +289,9 @@ bool Socket::data() {
 	gpos_ = 0;
 	
 	if (service == FTL_PROTOCOL_HS1 && !connected_) {
-		handshake1(d);
+		handshake1();
 	} else if (service == FTL_PROTOCOL_HS2 && !connected_) {
-		handshake2(d);
+		handshake2();
 	} else if (service == FTL_PROTOCOL_RPC) {
 		if (proto_) proto_->dispatchRPC(*this, d);
 		else LOG(WARNING) << "No protocol set for socket " << uri_;
@@ -309,53 +320,33 @@ int Socket::read(std::string &s, size_t count) {
 	return count;
 }
 
-void Socket::handshake1(const std::string &d) {
-	ftl::net::Handshake *hs;
-	if (d.size() != sizeof(ftl::net::Handshake)) {
-		LOG(ERROR) << "Handshake failed for " << uri_;
-		close();
-		return;
-	}
-	
-	hs = (ftl::net::Handshake*)d.data();
-	auto proto = Protocol::find(hs->proto);
-	if (proto == NULL) {
-		LOG(ERROR) << "Protocol (" << hs->proto << ") not found during handshake for " << uri_;
-		close();
-		return;
-	} else {
-		proto_ = proto;
+void Socket::handshake1() {
+	Handshake header;
+	read(header);
+
+	std::string peer;
+	if (header.name_size > 0) read(peer,header.name_size);
+
+	std::string protouri;
+	if (header.proto_size > 0) read(protouri,header.proto_size);
+
+	if (protouri.size() > 0) {
+		auto proto = Protocol::find(protouri);
+		if (proto == NULL) {
+			LOG(ERROR) << "Protocol (" << protouri << ") not found during handshake for " << uri_;
+			close();
+			return;
+		} else {
+			proto_ = proto;
+		}
 	}
-	peerid_ = std::string(&hs->peerid[0],16);
-	
-	ftl::net::Handshake hs2;
-	//hs2.magic = ftl::net::MAGIC;
-	//hs2.version = version_;
-	// TODO Set peerid;
-	send(FTL_PROTOCOL_HS2, std::string((char*)&hs2, sizeof(hs2)));
-	LOG(INFO) << "Handshake" << " confirmed from " << uri_;
+
+	send(FTL_PROTOCOL_HS2); // TODO Counterpart protocol.
+	LOG(INFO) << "Handshake (" << protouri << ") confirmed from " << uri_;
 	_connected();
 }
 
-void Socket::handshake2(const std::string &d) {
-	ftl::net::Handshake *hs;
-	if (d.size() != sizeof(ftl::net::Handshake)) {
-		LOG(ERROR) << "Handshake failed for " << uri_;
-		close();
-		return;
-	}
-	
-	hs = (ftl::net::Handshake*)d.data();
-	/*if (hs->magic != ftl::net::MAGIC) {
-		LOG(ERROR) << "Handshake magic failed for " << uri_;
-		close();
-		return;
-	}
-	
-	version_ = (hs->version > ftl::net::version()) ?
-			ftl::net::version() :
-			hs->version;*/
-	peerid_ = std::string(&hs->peerid[0],16);
+void Socket::handshake2() {
 	LOG(INFO) << "Handshake finalised for " << uri_;
 	_connected();
 }
diff --git a/net/test/net_integration.cpp b/net/test/net_integration.cpp
index c63b63569..0989f6772 100644
--- a/net/test/net_integration.cpp
+++ b/net/test/net_integration.cpp
@@ -162,8 +162,7 @@ TEST_CASE("net::listen()", "[net]") {
 		SECTION("can connect to listening socket") {
 			auto sock = ftl::net::connect("tcp://127.0.0.1:9001");
 			REQUIRE(sock->isValid());
-			ftl::net::wait(); // Handshake 1
-			ftl::net::wait(); // Handshake 2
+			ftl::net::wait([&sock]() { return sock->isConnected(); });
 			REQUIRE(sock->isConnected());
 
 			// TODO Need way of knowing about connection
@@ -194,7 +193,7 @@ TEST_CASE("net::listen()", "[net]") {
 TEST_CASE("Net Integration", "[integrate]") {
 	std::string data;
 	
-	Protocol p(143);
+	Protocol p("ftl://utu.fi");
 	
 	p.bind("add", [](int a, int b) {
 		return a + b;
diff --git a/net/test/protocol_unit.cpp b/net/test/protocol_unit.cpp
index 33ae20ffc..65fca0b0e 100644
--- a/net/test/protocol_unit.cpp
+++ b/net/test/protocol_unit.cpp
@@ -21,7 +21,7 @@ using ftl::net::Socket;
 
 class MockProtocol : public Protocol {
 	public:
-	MockProtocol() : Protocol(33) {}
+	MockProtocol() : Protocol("ftl://utu.fi") {}
 	void mock_dispatchRPC(Socket &s, const std::string &d) { dispatchRPC(s,d); }
 	void mock_dispatchReturn(Socket &s, const std::string &d) { dispatchReturn(s,d); }
 	void mock_dispatchRaw(uint32_t msg, Socket &s) { dispatchRaw(msg,s); }
diff --git a/net/test/socket_unit.cpp b/net/test/socket_unit.cpp
index 27005e21b..19cd3a0d4 100644
--- a/net/test/socket_unit.cpp
+++ b/net/test/socket_unit.cpp
@@ -27,13 +27,13 @@ void ftl::net::Protocol::dispatchRaw(uint32_t service, Socket &s) {
 
 }
 
-ftl::net::Protocol::Protocol(uint64_t id) {
+ftl::net::Protocol::Protocol(const std::string &id) {
 }
 
 ftl::net::Protocol::~Protocol() {
 }
 
-ftl::net::Protocol *ftl::net::Protocol::find(uint64_t p) {
+ftl::net::Protocol *ftl::net::Protocol::find(const std::string &p) {
 	return NULL;
 }
 
@@ -168,7 +168,7 @@ TEST_CASE("Socket::call()", "[rpc]") {
 
 TEST_CASE("Socket receive RPC", "[rpc]") {
 	MockSocket s;
-	auto p = new Protocol(444);
+	auto p = new Protocol("ftl://utu.fi");
 	s.setProtocol(p);
 	
 	SECTION("no argument call") {		
-- 
GitLab