Skip to content
Snippets Groups Projects
Commit 3f92ca5c authored by Nicolas Pope's avatar Nicolas Pope
Browse files

Refactor net code, fix tests and fix connection bugs

parent aaf9b03c
Branches
Tags
No related merge requests found
......@@ -22,10 +22,12 @@
namespace ftl {
namespace net {
#pragma pack(push,1)
struct Header {
uint32_t size;
uint32_t service;
};
#pragma pack(pop)
class Socket {
public:
......@@ -36,26 +38,28 @@ class Socket {
int close();
int send(uint32_t service, const std::string &data);
int send(uint32_t service, std::stringstream &data) { return send(service, data.str()); };
int send(uint32_t service, std::stringstream &data) {
return send(service, data.str()); };
int send(uint32_t service, void *data, int length);
int send2(uint32_t service, const std::string &data1, const std::string &data2);
int send2(uint32_t service, const std::string &data1,
const std::string &data2);
//friend bool ftl::net::run(bool);
int _socket() const { return m_sock; };
int _socket() const { return sock_; };
bool isConnected() const { return m_sock != INVALID_SOCKET; };
bool isValid() const { return m_valid; };
std::string getURI() const { return m_uri; };
bool isConnected() const { return sock_ != INVALID_SOCKET && connected_; };
bool isValid() const { return valid_ && sock_ != INVALID_SOCKET; };
std::string getURI() const { return uri_; };
/**
* Bind a function to a RPC call name.
*/
template <typename F>
void bind(const std::string &name, F func) {
//disp_.enforce_unique_name(name);
disp_.bind(name, func, typename ftl::internal::func_kind_info<F>::result_kind(),
disp_.bind(name, func,
typename ftl::internal::func_kind_info<F>::result_kind(),
typename ftl::internal::func_kind_info<F>::args_kind());
}
......@@ -96,7 +100,7 @@ class Socket {
auto rpcid = rpcid__++;
auto call_obj = std::make_tuple(0,rpcid,name,args_obj);
LOG(INFO) << "RPC " << name << "() -> " << m_uri;
LOG(INFO) << "RPC " << name << "() -> " << uri_;
std::stringstream buf;
msgpack::pack(buf, call_obj);
......@@ -107,9 +111,16 @@ class Socket {
send(FTL_PROTOCOL_RPC, buf.str());
}
void dispatch(const std::string &b) { disp_.dispatch(b); }
/**
* Internal handlers for specific event types.
* @{
*/
void dispatchRPC(const std::string &d) { disp_.dispatch(d); }
void dispatchReturn(const std::string &d);
void handshake1(const std::string &d);
void handshake2(const std::string &d);
/** @} */
//void onMessage(sockdatahandler_t handler) { m_handler = handler; }
void onError(sockerrorhandler_t handler) {}
void onConnect(std::function<void(Socket&)> f);
void onDisconnect(sockdisconnecthandler_t handler) {}
......@@ -118,18 +129,19 @@ class Socket {
void error();
private:
std::string m_uri;
int m_sock;
size_t m_pos;
char *m_buffer;
std::string uri_;
int sock_;
size_t pos_;
char *buffer_;
std::map<uint32_t,std::function<void(Socket&,const std::string&)>> handlers_;
std::vector<std::function<void(Socket&)>> connect_handlers_;
bool m_valid;
bool m_connected;
bool valid_;
bool connected_;
std::map<int, std::function<void(msgpack::object&)>> callbacks_;
ftl::net::Dispatcher disp_;
void _connected();
void _updateURI();
static int rpcid__;
......
......
......@@ -57,7 +57,7 @@ static int setDescriptors() {
//Set the file descriptors for each client
for (auto s : sockets) {
if (s != nullptr && s->isConnected()) {
if (s != nullptr && s->isValid()) {
if (s->_socket() > n) {
n = s->_socket();
......@@ -78,7 +78,7 @@ shared_ptr<Listener> ftl::net::listen(const char *uri) {
}
shared_ptr<Socket> ftl::net::connect(const char *uri) {
shared_ptr<Socket> s(new Socket(uri));
shared_ptr<Socket> s(new Socket((uri == NULL) ? "" : uri));
sockets.push_back(s);
return s;
}
......@@ -142,12 +142,10 @@ bool _run(bool blocking, bool nodelay) {
if (csock != INVALID_SOCKET) {
auto sock = make_shared<Socket>(csock);
//sockets[freeclient] = sock;
sockets.push_back(sock);
// Call connection handlers
l->connection(sock);
sockets.push_back(std::move(sock));
}
//}
}
......@@ -156,12 +154,12 @@ bool _run(bool blocking, bool nodelay) {
//Also check each clients socket to see if any messages or errors are waiting
for (auto s : sockets) {
if (s != NULL && s->isConnected()) {
if (s != NULL && s->isValid()) {
//If message received from this client then deal with it
if (FD_ISSET(s->_socket(), &sfdread)) {
repeat |= s->data();
//An error occured with this client.
} else if (FD_ISSET(s->_socket(), &sfderror)) {
}
if (FD_ISSET(s->_socket(), &sfderror)) {
s->error();
}
}
......
......
......@@ -27,7 +27,7 @@ using namespace ftl;
using ftl::net::Socket;
using namespace std;
static std::string hexStr(const std::string &s)
/*static std::string hexStr(const std::string &s)
{
const char *data = s.data();
int len = s.size();
......@@ -36,7 +36,7 @@ static std::string hexStr(const std::string &s)
for(int i=0;i<len;++i)
ss << std::setw(2) << std::setfill('0') << (int)data[i];
return ss.str();
}
}*/
int Socket::rpcid__ = 0;
......@@ -120,14 +120,39 @@ static int wsConnect(URI &uri) {
return 1;
}
Socket::Socket(int s) : m_sock(s), m_pos(0), disp_(this) {
m_valid = true;
m_buffer = new char[BUFFER_SIZE];
m_connected = false;
Socket::Socket(int s) : sock_(s), pos_(0), disp_(this) {
valid_ = true;
buffer_ = new char[BUFFER_SIZE];
connected_ = false;
_updateURI();
}
Socket::Socket(const char *pUri) : uri_(pUri), pos_(0), disp_(this) {
// Allocate buffer
buffer_ = new char[BUFFER_SIZE];
URI uri(pUri);
valid_ = false;
connected_ = false;
sock_ = INVALID_SOCKET;
if (uri.getProtocol() == URI::SCHEME_TCP) {
sock_ = tcpConnect(uri);
valid_ = true;
} else if (uri.getProtocol() == URI::SCHEME_WS) {
wsConnect(uri);
LOG(ERROR) << "Websocket currently unsupported";
} else {
LOG(ERROR) << "Unrecognised connection protocol: " << pUri;
}
}
void Socket::_updateURI() {
sockaddr_storage addr;
int rsize = sizeof(sockaddr_storage);
if (getpeername(s, (sockaddr*)&addr, (socklen_t*)&rsize) == 0) {
if (getpeername(sock_, (sockaddr*)&addr, (socklen_t*)&rsize) == 0) {
char addrbuf[INET6_ADDRSTRLEN];
int port;
......@@ -143,39 +168,23 @@ Socket::Socket(int s) : m_sock(s), m_pos(0), disp_(this) {
port = s->sin6_port;
}
m_uri = std::string("tcp://")+addrbuf;
m_uri += ":";
m_uri += std::to_string(port);
}
}
Socket::Socket(const char *pUri) : m_uri(pUri), m_pos(0), disp_(this) {
// Allocate buffer
m_buffer = new char[BUFFER_SIZE];
URI uri(pUri);
// TODO verify tcp or udp etc.
m_valid = false;
m_connected = false;
m_sock = INVALID_SOCKET;
if (uri.getProtocol() == URI::SCHEME_TCP) {
m_sock = tcpConnect(uri);
m_valid = true;
} else if (uri.getProtocol() == URI::SCHEME_WS) {
wsConnect(uri);
} else {
uri_ = std::string("tcp://")+addrbuf;
uri_ += ":";
uri_ += std::to_string(port);
}
}
int Socket::close() {
if (isConnected()) {
#ifndef WIN32
::close(m_sock);
::close(sock_);
#else
closesocket(m_sock);
closesocket(sock_);
#endif
m_sock = INVALID_SOCKET;
sock_ = INVALID_SOCKET;
connected_ = false;
// Attempt auto reconnect?
}
......@@ -185,87 +194,105 @@ int Socket::close() {
void Socket::error() {
int err;
uint32_t optlen = sizeof(err);
getsockopt(m_sock, SOL_SOCKET, SO_ERROR, &err, &optlen);
LOG(ERROR) << "Socket: " << m_uri << " - error " << err;
getsockopt(sock_, SOL_SOCKET, SO_ERROR, &err, &optlen);
LOG(ERROR) << "Socket: " << uri_ << " - error " << err;
}
bool Socket::data() {
//std::cerr << "GOT SOCKET DATA" << std::endl;
//Read data from socket
size_t n = 0;
uint32_t len = 0;
if (m_pos < 4) {
n = 4 - m_pos;
if (pos_ < 4) {
n = 4 - pos_;
} else {
len = *(int*)m_buffer;
n = len+4-m_pos;
len = *(int*)buffer_;
n = len+4-pos_;
}
while (m_pos < len+4) {
while (pos_ < len+4) {
if (len > MAX_MESSAGE) {
close();
LOG(ERROR) << "Socket: " << m_uri << " - message attack";
return false; // Prevent DoS
LOG(ERROR) << "Socket: " << uri_ << " - message attack";
return false;
}
const int rc = recv(m_sock, m_buffer+m_pos, n, 0);
const int rc = recv(sock_, buffer_+pos_, n, 0);
if (rc > 0) {
m_pos += static_cast<size_t>(rc);
pos_ += static_cast<size_t>(rc);
if (m_pos < 4) {
n = 4 - m_pos;
if (pos_ < 4) {
n = 4 - pos_;
} else {
len = *(int*)m_buffer;
n = len+4-m_pos;
len = *(int*)buffer_;
n = len+4-pos_;
}
} else if (rc == EWOULDBLOCK || rc == 0) {
// Data not yet available
//std::cout << "No data to read" << std::endl;
return false;
} else {
LOG(ERROR) << "Socket: " << m_uri << " - error " << rc;
// Close socket due to error
LOG(ERROR) << "Socket: " << uri_ << " - error " << rc;
close();
return false;
}
}
// All data available
//if (m_handler) {
uint32_t service = ((uint32_t*)m_buffer)[1];
auto d = std::string(m_buffer+8, len-4);
//std::cerr << "DATA : " << service << " -> " << d << std::endl;
// Route the message...
uint32_t service = ((uint32_t*)buffer_)[1];
auto d = std::string(buffer_+8, len-4);
if (service == FTL_PROTOCOL_HS1 && !connected_) {
handshake1(d);
} else if (service == FTL_PROTOCOL_HS2 && !connected_) {
handshake2(d);
} else if (service == FTL_PROTOCOL_RPC) {
dispatchRPC(d);
} else if (service == FTL_PROTOCOL_RPCRETURN) {
dispatchReturn(d);
} else {
// Lookup raw message handler
if (handlers_.count(service) > 0) {
handlers_[service](*this, d);
} else {
LOG(ERROR) << "Unrecognised service request (" << service << ") from " << uri_;
}
}
pos_ = 0;
return true;
}
if (service == FTL_PROTOCOL_HS1 && !m_connected) {
void Socket::handshake1(const std::string &d) {
// TODO Verify data
std::string hs2("HELLO");
send(FTL_PROTOCOL_HS2, hs2);
LOG(INFO) << "Handshake confirmed from " << m_uri;
LOG(INFO) << "Handshake confirmed from " << uri_;
_connected();
} else if (service == FTL_PROTOCOL_HS2 && !m_connected) {
}
void Socket::handshake2(const std::string &d) {
// TODO Verify data
LOG(INFO) << "Handshake finalised for " << m_uri;
LOG(INFO) << "Handshake finalised for " << uri_;
_connected();
} else if (service == FTL_PROTOCOL_RPC) {
dispatch(d);
} else if (service == FTL_PROTOCOL_RPCRETURN) {
}
void Socket::dispatchReturn(const std::string &d) {
auto unpacked = msgpack::unpack(d.data(), d.size());
Dispatcher::response_t the_result;
unpacked.get().convert(the_result);
// TODO: proper validation of protocol (and responding to it)
// auto &&type = std::get<0>(the_call);
// assert(type == 0);
if (std::get<0>(the_result) != 1) {
LOG(ERROR) << "Bad RPC return message";
return;
}
// auto &&id = std::get<1>(the_call);
auto &&id = std::get<1>(the_result);
//auto &&err = std::get<2>(the_result);
auto &&res = std::get<3>(the_result);
std::cout << " ROSULT " << hexStr(d) << std::endl;
// TODO Handle error reporting...
if (callbacks_.count(id) > 0) {
LOG(INFO) << "Received return RPC value";
......@@ -274,19 +301,10 @@ bool Socket::data() {
} else {
LOG(ERROR) << "Missing RPC callback for result";
}
} else {
// Lookup raw message handler
if (handlers_.count(service) > 0) handlers_[service](*this, d);
}
//}
m_pos = 0;
return true;
}
void Socket::onConnect(std::function<void(Socket&)> f) {
if (m_connected) {
if (connected_) {
f(*this);
} else {
connect_handlers_.push_back(f);
......@@ -294,11 +312,11 @@ void Socket::onConnect(std::function<void(Socket&)> f) {
}
void Socket::_connected() {
m_connected = true;
connected_ = true;
for (auto h : connect_handlers_) {
h(*this);
}
connect_handlers_.clear();
//connect_handlers_.clear();
}
void Socket::bind(uint32_t service, std::function<void(Socket&,
......@@ -321,7 +339,7 @@ int Socket::send(uint32_t service, const std::string &data) {
vec[1].iov_base = const_cast<char*>(data.data());
vec[1].iov_len = data.size();
::writev(m_sock, &vec[0], 2);
::writev(sock_, &vec[0], 2);
return 0;
}
......@@ -339,7 +357,7 @@ int Socket::send2(uint32_t service, const std::string &data1, const std::string
vec[2].iov_base = const_cast<char*>(data2.data());
vec[2].iov_len = data2.size();
::writev(m_sock, &vec[0], 3);
::writev(sock_, &vec[0], 3);
return 0;
}
......@@ -349,7 +367,7 @@ Socket::~Socket() {
close();
// Delete socket buffer
if (m_buffer) delete [] m_buffer;
m_buffer = NULL;
if (buffer_) delete [] buffer_;
buffer_ = NULL;
}
......@@ -157,6 +157,18 @@ int setDescriptors() {
return n;
}
void mocksend(int sd, uint32_t service, const std::string &data) {
//std::cout << "HEX SEND: " << hexStr(data) << std::endl;
char buf[8+data.size()];
ftl::net::Header *h = (ftl::net::Header*)&buf;
h->size = data.size()+4;
h->service = service;
std::memcpy(&buf[8],data.data(),data.size());
//std::cout << "HEX SEND2: " << hexStr(fakedata[sd]) << std::endl;
::send(sd, buf, 8+data.size(), 0);
}
void accept_connection() {
int n = setDescriptors();
......@@ -174,6 +186,7 @@ void accept_connection() {
//Finally accept this client connection.
csock = accept(ssock, (sockaddr*)&addr, (socklen_t*)&rsize);
mocksend(csock, FTL_PROTOCOL_HS1, "HELLO");
} else {
}
......@@ -190,6 +203,7 @@ TEST_CASE("net::connect()", "[net]") {
SECTION("valid tcp connection using ipv4") {
sock = ftl::net::connect("tcp://127.0.0.1:7077");
REQUIRE(sock != nullptr);
REQUIRE(sock->isValid());
accept_connection();
}
......@@ -212,6 +226,7 @@ TEST_CASE("net::connect()", "[net]") {
SECTION("null uri") {
sock = ftl::net::connect(NULL);
REQUIRE(!sock->isValid());
sock = nullptr;
}
// Disabled due to long timeout
......@@ -224,12 +239,14 @@ TEST_CASE("net::connect()", "[net]") {
SECTION("incorrect dns address") {
sock = ftl::net::connect("tcp://xryyrrgrtgddgr.com:7077");
REQUIRE(sock->isValid());
REQUIRE(!sock->isValid());
REQUIRE(sock->isConnected() == false);
sock = nullptr;
}
if (sock && sock->isValid()) {
ftl::net::wait();
//sock->data();
REQUIRE(sock->isConnected());
REQUIRE(csock != INVALID_SOCKET);
sock->close();
......@@ -243,10 +260,11 @@ TEST_CASE("net::listen()", "[net]") {
REQUIRE( ftl::net::listen("tcp://*:7078")->isListening() );
SECTION("can connect to listening socket") {
shared_ptr<Socket> sock = ftl::net::connect("tcp://127.0.0.1:7078");
auto sock = ftl::net::connect("tcp://127.0.0.1:7078");
REQUIRE(sock->isValid());
ftl::net::wait(); // Handshake 1
ftl::net::wait(); // Handshake 2
REQUIRE(sock->isConnected());
ftl::net::wait();
// TODO Need way of knowing about connection
}
......@@ -261,6 +279,7 @@ TEST_CASE("net::listen()", "[net]") {
bool connected = false;
l->onConnection([&](shared_ptr<Socket> s) {
ftl::net::wait(); // Wait for handshake
REQUIRE( s->isConnected() );
connected = true;
});
......@@ -268,19 +287,18 @@ TEST_CASE("net::listen()", "[net]") {
auto sock = ftl::net::connect("tcp://127.0.0.1:7078");
ftl::net::wait();
REQUIRE( connected );
std::cout << "PRE STOP" << std::endl;
ftl::net::stop();
std::cout << "POST STOP" << std::endl;
}
}
TEST_CASE("Socket.onMessage()", "[net]") {
TEST_CASE("Socket.bind(int)", "[net]") {
// Need a fake server...
init_server();
shared_ptr<Socket> sock = ftl::net::connect("tcp://127.0.0.1:7077");
REQUIRE(sock->isValid());
REQUIRE(sock->isConnected());
accept_connection();
ftl::net::wait(); // Wait for handshake
REQUIRE(sock->isConnected());
SECTION("small valid message") {
send_json(1, "{message: \"Hello\"}");
......@@ -322,6 +340,7 @@ TEST_CASE("Socket.onMessage()", "[net]") {
msg++;
});
ftl::net::wait();
ftl::net::wait();
REQUIRE(msg == 2);
}
......
......
......@@ -78,7 +78,7 @@ TEST_CASE("Socket::bind()", "[rpc]") {
std::stringstream buf;
msgpack::pack(buf, call_obj);
s->dispatch(buf.str());
s->dispatchRPC(buf.str());
REQUIRE( called );
}
......@@ -96,7 +96,7 @@ TEST_CASE("Socket::bind()", "[rpc]") {
std::stringstream buf;
msgpack::pack(buf, call_obj);
s->dispatch(buf.str());
s->dispatchRPC(buf.str());
REQUIRE( called );
}
......@@ -115,7 +115,7 @@ TEST_CASE("Socket::bind()", "[rpc]") {
std::stringstream buf;
msgpack::pack(buf, call_obj);
s->dispatch(buf.str());
s->dispatchRPC(buf.str());
REQUIRE( called );
}
......@@ -220,7 +220,6 @@ TEST_CASE("Socket::call+bind loop", "[rpc]") {
});
int res = s->call<int>("test1", 5);
REQUIRE( res == 10 );
}
......
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please to comment