From e8d65b548d8357337a4a824a047f5b8a90a8a29e Mon Sep 17 00:00:00 2001
From: Nicolas Pope <nwpope@utu.fi>
Date: Fri, 22 Feb 2019 13:50:33 +0200
Subject: [PATCH] Functioning and tested RPC bind and call

---
 net/include/ftl/net/dispatcher.hpp |   6 +
 net/include/ftl/net/protocol.hpp   |   1 +
 net/include/ftl/net/socket.hpp     |  19 ++--
 net/src/dispatcher.cpp             |  31 +++++-
 net/src/socket.cpp                 |  52 ++++++++-
 net/test/CMakeLists.txt            |  22 +++-
 net/test/rpc.cpp                   | 171 ++++++++++++++++++++++++++++-
 7 files changed, 278 insertions(+), 24 deletions(-)

diff --git a/net/include/ftl/net/dispatcher.hpp b/net/include/ftl/net/dispatcher.hpp
index 0173858de..840cc79ad 100644
--- a/net/include/ftl/net/dispatcher.hpp
+++ b/net/include/ftl/net/dispatcher.hpp
@@ -33,9 +33,11 @@ namespace internal {
 }
 
 namespace net {
+class Socket;
 
 class Dispatcher {
 	public:
+	Dispatcher(Socket *s) : sock_(s) {}
 	
 	void dispatch(const std::string &msg);
 	
@@ -115,8 +117,12 @@ class Dispatcher {
 
     //! \brief This is the type of notification messages.
     using notification_t = std::tuple<int8_t, std::string, msgpack::object>;
+    
+    using response_t =
+        std::tuple<uint32_t, uint32_t, msgpack::object, msgpack::object>;
 	
 	private:
+	ftl::net::Socket *sock_;
 	std::unordered_map<std::string, adaptor_type> funcs_;
 	
 	static void enforce_arg_count(std::string const &func, std::size_t found,
diff --git a/net/include/ftl/net/protocol.hpp b/net/include/ftl/net/protocol.hpp
index 3e85f9c1f..000e983de 100644
--- a/net/include/ftl/net/protocol.hpp
+++ b/net/include/ftl/net/protocol.hpp
@@ -2,6 +2,7 @@
 #define _FTL_NET_PROTOCOL_HPP_
 
 #define FTL_PROTOCOL_RPC		0x0100
+#define FTL_PROTOCOL_RPCRETURN	0x0101
 #define FTL_PROTOCOL_P2P		0x1000
 
 #endif // _FTL_NET_PROTOCOL_HPP_
diff --git a/net/include/ftl/net/socket.hpp b/net/include/ftl/net/socket.hpp
index bebcade26..01eb87f03 100644
--- a/net/include/ftl/net/socket.hpp
+++ b/net/include/ftl/net/socket.hpp
@@ -21,6 +21,11 @@
 namespace ftl {
 namespace net {
 
+struct Header {
+	uint32_t size;
+	uint32_t service;
+};
+
 class Socket {
 	public:
 	Socket(const char *uri);
@@ -49,13 +54,13 @@ class Socket {
 		     typename ftl::internal::func_kind_info<F>::args_kind());
 	}
 	
-	template <typename... ARGS>
-	msgpack::object_handle call(const std::string &name, ARGS... args) {
+	template <typename T, typename... ARGS>
+	T call(const std::string &name, ARGS... args) {
 		bool hasreturned = false;
-		msgpack::object_handle result;
-		async_call(name, [result,hasreturned](msgpack::object_handle r) {
+		T result;
+		async_call(name, [&result,&hasreturned](msgpack::object &r) {
 			hasreturned = true;
-			result = r;
+			result = r.as<T>();
 		}, std::forward<ARGS>(args)...);
 		
 		// Loop the network
@@ -71,7 +76,7 @@ class Socket {
 	template <typename... ARGS>
 	void async_call(
 			const std::string &name,
-			std::function<void(msgpack::object_handle)> cb,
+			std::function<void(msgpack::object&)> cb,
 			ARGS... args) {
 		auto args_obj = std::make_tuple(args...);
 		auto rpcid = rpcid__++;
@@ -107,7 +112,7 @@ class Socket {
 	char *m_buffer;
 	sockdatahandler_t m_handler;
 	bool m_valid;
-	std::map<int, std::function<void(msgpack::object_handle)>> callbacks_;
+	std::map<int, std::function<void(msgpack::object&)>> callbacks_;
 	ftl::net::Dispatcher disp_;
 	
 	static int rpcid__;
diff --git a/net/src/dispatcher.cpp b/net/src/dispatcher.cpp
index 294d1398b..9a3e815ea 100644
--- a/net/src/dispatcher.cpp
+++ b/net/src/dispatcher.cpp
@@ -1,6 +1,20 @@
 #include <ftl/net/dispatcher.hpp>
+#include <ftl/net/socket.hpp>
+#include <iostream>
+
+/*static std::string hexStr(const std::string &s)
+{
+	const char *data = s.data();
+	int len = s.size();
+    std::stringstream ss;
+    ss << std::hex;
+    for(int i=0;i<len;++i)
+        ss << std::setw(2) << std::setfill('0') << (int)data[i];
+    return ss.str();
+}*/
 
 void ftl::net::Dispatcher::dispatch(const std::string &msg) {
+	//std::cout << "Received dispatch : " << hexStr(msg) << std::endl;
     auto unpacked = msgpack::unpack(msg.data(), msg.size());
     dispatch(unpacked.get());
 }
@@ -8,10 +22,11 @@ void ftl::net::Dispatcher::dispatch(const std::string &msg) {
 void ftl::net::Dispatcher::dispatch(const msgpack::object &msg) {
     switch (msg.via.array.size) {
     case 3:
-        dispatch_notification(msg);
+        dispatch_notification(msg); break;
     case 4:
-        dispatch_call(msg);
+        dispatch_call(msg); break;
     default:
+    	std::cout << "Unrecognised msgpack : " << msg.via.array.size << std::endl;
         return;
     }
 }
@@ -24,7 +39,7 @@ void ftl::net::Dispatcher::dispatch_call(const msgpack::object &msg) {
     // auto &&type = std::get<0>(the_call);
     // assert(type == 0);
 
-   // auto &&id = std::get<1>(the_call);
+    auto &&id = std::get<1>(the_call);
     auto &&name = std::get<2>(the_call);
     auto &&args = std::get<3>(the_call);
 
@@ -32,8 +47,12 @@ void ftl::net::Dispatcher::dispatch_call(const msgpack::object &msg) {
 
     if (it_func != end(funcs_)) {
         try {
-            auto result = (it_func->second)(args);
-            // TODO SEND RESULTS
+            auto result = (it_func->second)(args)->get();
+			auto res_obj = std::make_tuple(1,id,msgpack::object(),result);
+			std::stringstream buf;
+			msgpack::pack(buf, res_obj);
+			
+			sock_->send(FTL_PROTOCOL_RPCRETURN, buf.str());
         } catch (...) {
 			throw;
 		}
@@ -50,6 +69,8 @@ void ftl::net::Dispatcher::dispatch_notification(msgpack::object const &msg) {
 
     auto &&name = std::get<1>(the_call);
     auto &&args = std::get<2>(the_call);
+    
+    std::cout << "RPC NOTIFY" << name << std::endl;
 
     auto it_func = funcs_.find(name);
 
diff --git a/net/src/socket.cpp b/net/src/socket.cpp
index b81c7151c..2fd886031 100644
--- a/net/src/socket.cpp
+++ b/net/src/socket.cpp
@@ -25,6 +25,8 @@ using namespace ftl;
 using ftl::net::Socket;
 using namespace std;
 
+int Socket::rpcid__ = 0;
+
 static int tcpConnect(URI &uri) {
 	int rc;
 	sockaddr_in destAddr;
@@ -105,13 +107,13 @@ static int wsConnect(URI &uri) {
 	return 1;
 }
 
-Socket::Socket(int s) : m_sock(s), m_pos(0) {
+Socket::Socket(int s) : m_sock(s), m_pos(0), disp_(this) {
 	// TODO Get the remote address.
 	m_valid = true;
 	m_buffer = new char[BUFFER_SIZE];
 }
 
-Socket::Socket(const char *pUri) : m_uri(pUri), m_pos(0) {
+Socket::Socket(const char *pUri) : m_uri(pUri), m_pos(0), disp_(this) {
 	// Allocate buffer
 	m_buffer = new char[BUFFER_SIZE];
 	
@@ -169,6 +171,7 @@ bool Socket::data() {
 	while (m_pos < len+4) {
 		if (len > MAX_MESSAGE) {
 			close();
+			std::cout << "Length is too big" << std::endl;
 			return false; // Prevent DoS
 		}
 
@@ -185,8 +188,10 @@ bool Socket::data() {
 			}
 		} else if (rc == EWOULDBLOCK || rc == 0) {
 			// Data not yet available
+			std::cout << "No data to read" << std::endl;
 			return false;
 		} else {
+			std::cout << "Socket error" << std::endl;
 			// Close socket due to error
 			close();
 			return false;
@@ -194,18 +199,55 @@ bool Socket::data() {
 	}
 
 	// All data available
-	if (m_handler) {
+	//if (m_handler) {
 		uint32_t service = ((uint32_t*)m_buffer)[1];
 		auto d = std::string(m_buffer+8, len-4);
 		//std::cerr << "DATA : " << service << " -> " << d << std::endl;
-		m_handler(service, d); 
-	}
+		
+		if (service == FTL_PROTOCOL_RPC) {
+			dispatch(d);
+		} else if (service == FTL_PROTOCOL_RPCRETURN) {
+			auto unpacked = msgpack::unpack(d.data(), d.size());
+			Dispatcher::response_t the_result;
+			unpacked.get().convert(the_result);
+
+			// TODO: proper validation of protocol (and responding to it)
+			// auto &&type = std::get<0>(the_call);
+			// assert(type == 0);
+
+			// auto &&id = std::get<1>(the_call);
+			auto &&id = std::get<1>(the_result);
+			//auto &&err = std::get<2>(the_result);
+			auto &&res = std::get<3>(the_result);
+
+			if (callbacks_.count(id) > 0) callbacks_[id](res);
+			else std::cout << "NO CALLBACK FOUND FOR RPC RESULT" << std::endl;
+		} else {
+			if (m_handler) m_handler(service, d);
+		} 
+	//}
 
 	m_pos = 0;
 
 	return true;
 }
 
+int Socket::send(uint32_t service, const std::string &data) {
+	ftl::net::Header h;
+	h.size = data.size()+4;
+	h.service = service;
+	
+	iovec vec[2];
+	vec[0].iov_base = &h;
+	vec[0].iov_len = sizeof(h);
+	vec[1].iov_base = const_cast<char*>(data.data());
+	vec[1].iov_len = data.size();
+	
+	::writev(m_sock, &vec[0], 2);
+	
+	return 0;
+}
+
 Socket::~Socket() {
 	close();
 	
diff --git a/net/test/CMakeLists.txt b/net/test/CMakeLists.txt
index e565783b9..f8c7f42cc 100644
--- a/net/test/CMakeLists.txt
+++ b/net/test/CMakeLists.txt
@@ -1,4 +1,16 @@
-add_executable(tests EXCLUDE_FROM_ALL
+include(CTest)
+enable_testing()
+
+add_executable(rpc_test EXCLUDE_FROM_ALL
+	./tests.cpp
+	./rpc.cpp
+	../src/dispatcher.cpp
+	../src/socket.cpp
+)
+target_include_directories(rpc_test PUBLIC ${PROJECT_SOURCE_DIR}/include)
+target_link_libraries(rpc_test uriparser)
+
+add_executable(socket_test EXCLUDE_FROM_ALL
 	./tests.cpp
 	./net_raw.cpp
 	../src/net.cpp
@@ -7,9 +19,11 @@ add_executable(tests EXCLUDE_FROM_ALL
 	./ice.cpp
 	../src/ice.cpp
 	./uri.cpp
-	./rpc.cpp
 	../src/dispatcher.cpp
 )
-target_include_directories(tests PUBLIC ${PROJECT_SOURCE_DIR}/include)
-target_link_libraries(tests uriparser)
+target_include_directories(socket_test PUBLIC ${PROJECT_SOURCE_DIR}/include)
+target_link_libraries(socket_test uriparser)
+
+add_custom_target(tests)
+add_dependencies(tests rpc_test socket_test)
 
diff --git a/net/test/rpc.cpp b/net/test/rpc.cpp
index 9d48684fc..1b83ce8c0 100644
--- a/net/test/rpc.cpp
+++ b/net/test/rpc.cpp
@@ -1,9 +1,71 @@
 #include "catch.hpp"
 #include <ftl/net/socket.hpp>
 #include <iostream>
+#include <memory>
+#include <map>
+
+/*struct FakeHeader {
+	uint32_t size;
+	uint32_t service;
+};*/
+
+static std::map<int, std::string> fakedata;
+
+/*static std::string hexStr(const std::string &s)
+{
+	const char *data = s.data();
+	int len = s.size();
+    std::stringstream ss;
+    ss << std::hex;
+    for(int i=0;i<len;++i)
+        ss << std::setw(2) << std::setfill('0') << (int)data[i];
+    return ss.str();
+}*/
+
+void fake_send(int sd, uint32_t  service, const std::string &data) {
+	//std::cout << "HEX SEND: " << hexStr(data) << std::endl;
+	char buf[8+data.size()];
+	ftl::net::Header *h = (ftl::net::Header*)&buf;
+	h->size = data.size()+4;
+	h->service = service;
+	std::memcpy(&buf[8],data.data(),data.size());
+	fakedata[sd] = std::string(&buf[0], 8+data.size());
+	
+	//std::cout << "HEX SEND2: " << hexStr(fakedata[sd]) << std::endl;
+}
+
+extern ssize_t recv(int sd, void *buf, size_t n, int f) {
+	if (fakedata.count(sd) == 0) {
+		std::cout << "Unrecognised socket" << std::endl;
+		return 0;
+	}
+	
+	int l = fakedata[sd].size();
+	
+	std::memcpy(buf, fakedata[sd].c_str(), l);
+	fakedata.erase(sd);
+	return l;
+}
+
+extern ssize_t writev(int sd, const struct iovec *v, int cnt) {
+	size_t len = v[0].iov_len+v[1].iov_len;
+	char buf[len];
+	std::memcpy(&buf[0],v[0].iov_base,v[0].iov_len);
+	std::memcpy(&buf[v[0].iov_len],v[1].iov_base,v[1].iov_len);
+	fakedata[sd] = std::string(&buf[0], len);
+	return 0;
+}
+
+static std::function<void()> waithandler;
+
+bool ftl::net::wait() {
+	if (waithandler) waithandler();
+	//waithandler = nullptr;
+	return true;
+}
 
 TEST_CASE("Socket::bind()", "[rpc]") {
-	SECTION("no argument bind") {
+	SECTION("no argument bind fake dispatch") {
 		auto s = new ftl::net::Socket(0);
 		bool called = false;
 		
@@ -20,7 +82,7 @@ TEST_CASE("Socket::bind()", "[rpc]") {
 		REQUIRE( called );
 	}
 	
-	SECTION("one argument bind") {
+	SECTION("one argument bind fake dispatch") {
 		auto s = new ftl::net::Socket(0);
 		bool called = false;
 		
@@ -38,7 +100,7 @@ TEST_CASE("Socket::bind()", "[rpc]") {
 		REQUIRE( called );
 	}
 	
-	SECTION("two argument bind") {
+	SECTION("two argument bind fake dispatch") {
 		auto s = new ftl::net::Socket(0);
 		bool called = false;
 		
@@ -56,5 +118,108 @@ TEST_CASE("Socket::bind()", "[rpc]") {
 		s->dispatch(buf.str());
 		REQUIRE( called );
 	}
+	
+	SECTION("no argument bind fake data") {
+		auto s = new ftl::net::Socket(0);
+		bool called = false;
+		
+		s->bind("test1", [&]() {
+			called = true;
+		});
+		
+		auto args_obj = std::make_tuple();
+		auto call_obj = std::make_tuple(0,0,"test1",args_obj);
+		std::stringstream buf;
+		msgpack::pack(buf, call_obj);
+		
+		fake_send(0, FTL_PROTOCOL_RPC, buf.str());
+		REQUIRE( s->data() );
+		REQUIRE( called );
+	}
+	
+	SECTION("non-void bind fake data") {
+		auto s = new ftl::net::Socket(0);
+		bool called = false;
+		
+		s->bind("test1", [&]() -> int {
+			called = true;
+			return 55;
+		});
+		
+		auto args_obj = std::make_tuple();
+		auto call_obj = std::make_tuple(0,0,"test1",args_obj);
+		std::stringstream buf;
+		msgpack::pack(buf, call_obj);
+		
+		fake_send(0, FTL_PROTOCOL_RPC, buf.str());
+		REQUIRE( s->data() );
+		REQUIRE( called );
+		
+		// TODO Require that a writev occurred with result value
+	}
+}
+
+TEST_CASE("Socket::call()", "[rpc]") {
+	SECTION("no argument call") {
+		auto s = new ftl::net::Socket(0);
+		
+		waithandler = [&]() {
+			// Read fakedata sent
+			// TODO Validate data
+			
+			// Do a fake send
+			auto res_obj = std::make_tuple(1,0,msgpack::object(),66);
+			std::stringstream buf;
+			msgpack::pack(buf, res_obj);
+		
+			fake_send(0, FTL_PROTOCOL_RPCRETURN, buf.str());
+			s->data();
+		};
+		
+		int res = s->call<int>("test1");
+		
+		REQUIRE( res == 66 );
+	}
+	
+	SECTION("one argument call") {
+		auto s = new ftl::net::Socket(0);
+		
+		waithandler = [&]() {
+			// Read fakedata sent
+			// TODO Validate data
+			
+			// Do a fake send
+			auto res_obj = std::make_tuple(1,1,msgpack::object(),43);
+			std::stringstream buf;
+			msgpack::pack(buf, res_obj);
+		
+			fake_send(0, FTL_PROTOCOL_RPCRETURN, buf.str());
+			s->data();
+		};
+		
+		int res = s->call<int>("test1", 78);
+		
+		REQUIRE( res == 43 );
+	}
+	
+	waithandler = nullptr;
+}
+
+TEST_CASE("Socket::call+bind loop", "[rpc]") {
+	auto s = new ftl::net::Socket(0);
+
+	// Just loop the send back to the recv
+	waithandler = [&]() {
+		s->data();
+	};
+	
+	s->bind("test1", [](int a) -> int {
+		std::cout << "Bind test1 called" << std::endl;
+		return a*2;
+	});
+	
+	int res = s->call<int>("test1", 5);
+	
+	REQUIRE( res == 10 );
 }
 
-- 
GitLab