Skip to content
Snippets Groups Projects
socket.cpp 7.69 KiB
Newer Older
#include <glog/logging.h>

#include <ftl/uri.hpp>
#include <ftl/net/socket.hpp>

#ifndef WIN32
#include <unistd.h>
#include <sys/types.h>
#include <sys/socket.h>
#include <netinet/in.h>
#include <netdb.h>
#include <arpa/inet.h>
#define INVALID_SOCKET -1
#define SOCKET_ERROR -1
#endif

#ifdef WIN32
#include <windows.h>
#include <winsock.h>
typedef int socklen_t;
#define MSG_WAITALL 0
#endif

#include <exception>
#include <iostream>
#include <string>
#include <utility>

using namespace ftl;
using ftl::net::Socket;
using namespace std;

/*static std::string hexStr(const std::string &s)
{
	const char *data = s.data();
	int len = s.size();
    std::stringstream ss;
    ss << std::hex;
    for(int i=0;i<len;++i)
        ss << std::setw(2) << std::setfill('0') << (int)data[i];
    return ss.str();
int Socket::rpcid__ = 0;

static int tcpConnect(URI &uri) {
	int rc;
	sockaddr_in destAddr;

	//std::cerr << "TCP Connect: " << uri.getHost() << " : " << uri.getPort() << std::endl;

	#ifdef WIN32
	WSAData wsaData;
	if (WSAStartup(MAKEWORD(1,1), &wsaData) != 0) {
		//ERROR
		return INVALID_SOCKET;
	}
	#endif
	
	//We want a TCP socket
	int csocket = socket(AF_INET, SOCK_STREAM, 0);

	if (csocket == INVALID_SOCKET) {
		return INVALID_SOCKET;
	}

	#ifdef WIN32
	HOSTENT *host = gethostbyname(uri.getHost().c_str());
	#else
	hostent *host = gethostbyname(uri.getHost().c_str());
	#endif

	if (host == NULL) {
		#ifndef WIN32
		close(csocket);
		#else
		closesocket(csocket);
		#endif

		LOG(ERROR) << "Address not found : " << uri.getHost() << std::endl;

		return INVALID_SOCKET;
	}

	destAddr.sin_family = AF_INET;
	destAddr.sin_addr.s_addr = ((in_addr *)(host->h_addr))->s_addr;
	destAddr.sin_port = htons(uri.getPort());

	// Make nonblocking
	/*long arg = fcntl(csocket, F_GETFL, NULL));
	arg |= O_NONBLOCK;
	fcntl(csocket, F_SETFL, arg) < 0)*/
	
	rc = ::connect(csocket, (struct sockaddr*)&destAddr, sizeof(destAddr));

	if (rc < 0) {
		if (errno == EINPROGRESS) {

		} else {
			#ifndef WIN32
			close(csocket);
			#else
			closesocket(csocket);
			#endif

			LOG(ERROR) << "Could not connect to " << uri.getBaseURI();

			return INVALID_SOCKET;
		}
	}

	// Make blocking again
	/*rg = fcntl(csocket, F_GETFL, NULL));
	arg &= (~O_NONBLOCK);
	fcntl(csocket, F_SETFL, arg) < 0)*/
	
	// Handshake??

	return csocket;
}

/**
 * Placeholder for WebSocket connections, which are not implemented yet.
 *
 * @param uri Parsed ws:// URI (currently unused).
 * @return Always INVALID_SOCKET. The previous stub returned 1, which is a
 *         live descriptor (stdout) and could be mistaken for a socket.
 */
static int wsConnect(URI &uri) {
	(void)uri;
	return INVALID_SOCKET;
}

/**
 * Wrap an already-open socket descriptor (presumably one obtained from
 * accept() — confirm with callers).
 *
 * The socket starts valid but not connected; connected_ stays false until
 * the handshake completes. uri_ is derived from the peer address.
 *
 * @param s Open socket descriptor to take over.
 */
Socket::Socket(int s) : sock_(s), pos_(0), disp_(this) {
	connected_ = false;
	valid_ = true;

	// Receive buffer used to assemble incoming frames.
	buffer_ = new char[BUFFER_SIZE];

	_updateURI();
}

/**
 * Create a socket by connecting to the given URI.
 *
 * Only tcp:// URIs are supported. ws:// and unknown schemes are logged as
 * errors and leave the socket invalid. valid_ is true only if the TCP
 * connection actually succeeded.
 *
 * @param pUri Connection URI, e.g. "tcp://host:port". A null pointer is
 *             logged and leaves the socket invalid.
 */
Socket::Socket(const char *pUri) : uri_(pUri ? pUri : ""), pos_(0), disp_(this) {
	// Allocate buffer
	buffer_ = new char[BUFFER_SIZE];

	valid_ = false;
	connected_ = false;
	sock_ = INVALID_SOCKET;

	if (pUri == nullptr) {
		LOG(ERROR) << "Null connection URI";
		return;
	}

	URI uri(pUri);

	if (uri.getProtocol() == URI::SCHEME_TCP) {
		sock_ = tcpConnect(uri);
		// tcpConnect() returns INVALID_SOCKET on any failure.
		valid_ = (sock_ != INVALID_SOCKET);
	} else if (uri.getProtocol() == URI::SCHEME_WS) {
		wsConnect(uri);
		LOG(ERROR) << "Websocket currently unsupported";
	} else {
		LOG(ERROR) << "Unrecognised connection protocol: " << pUri;
	}
}

void Socket::_updateURI() {
	sockaddr_storage addr;
	int rsize = sizeof(sockaddr_storage);
	if (getpeername(sock_, (sockaddr*)&addr, (socklen_t*)&rsize) == 0) {
		char addrbuf[INET6_ADDRSTRLEN];
		int port;
		
		if (addr.ss_family == AF_INET) {
			struct sockaddr_in *s = (struct sockaddr_in *)&addr;
			//port = ntohs(s->sin_port);
			inet_ntop(AF_INET, &s->sin_addr, addrbuf, INET6_ADDRSTRLEN);
			port = s->sin_port;
		} else { // AF_INET6
			struct sockaddr_in6 *s = (struct sockaddr_in6 *)&addr;
			//port = ntohs(s->sin6_port);
			inet_ntop(AF_INET6, &s->sin6_addr, addrbuf, INET6_ADDRSTRLEN);
			port = s->sin6_port;
		}
		
		// TODO verify tcp or udp etc.
		
		uri_ = std::string("tcp://")+addrbuf;
		uri_ += ":";
		uri_ += std::to_string(port);
	}
}

/**
 * Close the underlying descriptor, if one is open, and mark the socket as
 * disconnected.
 *
 * The descriptor is released whenever sock_ holds one, not only after a
 * completed handshake. A socket closed before connecting therefore does not
 * leak its descriptor.
 *
 * @return Always 0.
 */
int Socket::close() {
	if (sock_ != INVALID_SOCKET) {
		// The global close() is needed here because this member is also
		// named close.
		#ifndef WIN32
		::close(sock_);
		#else
		closesocket(sock_);
		#endif
		sock_ = INVALID_SOCKET;
	}
	connected_ = false;

	// TODO: Attempt auto reconnect?
	return 0;
}

void Socket::error() {
	int err;
	uint32_t optlen = sizeof(err);
	getsockopt(sock_, SOL_SOCKET, SO_ERROR, &err, &optlen);
	LOG(ERROR) << "Socket: " << uri_ << " - error " << err;
}

/**
 * Read available bytes from sock_ into buffer_ and, once a complete frame is
 * buffered, route it by service id.
 *
 * Frame layout as read here:
 *   - bytes 0-3: len, the frame size after this field;
 *   - bytes 4-7: the service id;
 *   - from offset 8: the len-4 payload bytes.
 * This matches send(), which sets h.size to the payload size plus 4.
 *
 * @return false when a frame is rejected as larger than MAX_MESSAGE, or when
 *         recv() yields no data; other return paths are not visible here.
 *
 * NOTE(review): this listing appears to be missing lines:
 *   - braces are unbalanced;
 *   - rc is used without a visible success test;
 *   - the error branch has no return;
 *   - there is no final return.
 * Confirm against the repository before relying on these comments.
 */
bool Socket::data() {
	//Read data from socket
	size_t n = 0;
	uint32_t len = 0;

	// Still reading the 4-byte length prefix.
	if (pos_ < 4) {
		n = 4 - pos_;
		// NOTE(review): len is read before 4 bytes are known to have arrived,
		// and in host byte order through an int* cast. Confirm intent.
		len = *(int*)buffer_;
		n = len+4-pos_;
	// Keep reading until the length prefix plus len bytes are buffered.
	while (pos_ < len+4) {
		// Oversized frames are rejected and the connection is closed.
		if (len > MAX_MESSAGE) {
			close();
			LOG(ERROR) << "Socket: " << uri_ << " - message attack";
			return false;
		const int rc = recv(sock_, buffer_+pos_, n, 0);
			pos_ += static_cast<size_t>(rc);
			if (pos_ < 4) {
				n = 4 - pos_;
				len = *(int*)buffer_;
				n = len+4-pos_;
			}
		// NOTE(review): recv() signals "would block" by returning -1 and setting
		// errno, so comparing rc to EWOULDBLOCK does not detect it. Also, rc == 0
		// means the peer closed the connection, not "no data yet".
		} else if (rc == EWOULDBLOCK || rc == 0) {
			// Data not yet available
			return false;
		} else {
			LOG(ERROR) << "Socket: " << uri_ << " - error " << rc;
	// Route the message...
	// The service id is the second 32-bit word of the frame.
	uint32_t service = ((uint32_t*)buffer_)[1];
	// NOTE(review): len-4 wraps around if a frame declares len < 4.
	auto d = std::string(buffer_+8, len-4);
	
	// Handshake messages are only honoured while not yet connected.
	if (service == FTL_PROTOCOL_HS1 && !connected_) {
		handshake1(d);
	} else if (service == FTL_PROTOCOL_HS2 && !connected_) {
		handshake2(d);
	} else if (service == FTL_PROTOCOL_RPC) {
		dispatchRPC(d);
	} else if (service == FTL_PROTOCOL_RPCRETURN) {
		dispatchReturn(d);
	} else {
		// Lookup raw message handler
		if (handlers_.count(service) > 0) {
			handlers_[service](*this, d);
			// NOTE(review): as listed, this error is logged right after a bound
			// handler runs. An `} else {` for the unbound case looks to be
			// missing here.
			LOG(ERROR) << "Unrecognised service request (" << service << ") from " << uri_;
		}
	} 
/**
 * Answer a peer's first handshake message: reply with HS2 ("HELLO"), log,
 * and run the connection-established logic.
 *
 * @param d HS1 payload from the peer (not yet verified).
 */
void Socket::handshake1(const std::string &d) {
	// TODO: verify the contents of d before accepting the peer.
	(void)d;

	std::string greeting = "HELLO";
	send(FTL_PROTOCOL_HS2, greeting);

	LOG(INFO) << "Handshake confirmed from " << uri_;
	_connected();
}

/**
 * Complete the handshake on receipt of the peer's HS2 reply: log and run
 * the connection-established logic.
 *
 * @param d HS2 payload from the peer (not yet verified).
 */
void Socket::handshake2(const std::string &d) {
	// TODO: verify the contents of d before accepting the connection.
	(void)d;

	LOG(INFO) << "Handshake finalised for " << uri_;
	_connected();
}

void Socket::dispatchReturn(const std::string &d) {
	auto unpacked = msgpack::unpack(d.data(), d.size());
	Dispatcher::response_t the_result;
	unpacked.get().convert(the_result);

	if (std::get<0>(the_result) != 1) {
		LOG(ERROR) << "Bad RPC return message";
		return;
	}

	auto &&id = std::get<1>(the_result);
	//auto &&err = std::get<2>(the_result);
	auto &&res = std::get<3>(the_result);
	
	// TODO Handle error reporting...
	
	if (callbacks_.count(id) > 0) {
		LOG(INFO) << "Received return RPC value";
		callbacks_[id](res);
		callbacks_.erase(id);
	} else {
		LOG(ERROR) << "Missing RPC callback for result";
	}
}

/**
 * Register f to run once this socket's connection is established.
 *
 * If connected_ is already set, f runs immediately. Otherwise f is queued in
 * connect_handlers_, which _connected() invokes.
 *
 * @param f Callback receiving this socket.
 */
void Socket::onConnect(std::function<void(Socket&)> f) {
	// NOTE(review): the condition line was missing from this listing; it is
	// restored with connected_, the flag data() uses to gate handshakes.
	if (connected_) {
		f(*this);
	} else {
		connect_handlers_.push_back(f);
	}
}

void Socket::_connected() {
	for (auto h : connect_handlers_) {
		h(*this);
	}
	//connect_handlers_.clear();
}

/**
 * Register func as the raw message handler for service.
 *
 * The first binding for a service wins. Later attempts are logged as errors
 * and ignored.
 *
 * @param service Service id to handle.
 * @param func    Handler receiving this socket and the message payload.
 */
void Socket::bind(uint32_t service, std::function<void(Socket&,
			const std::string&)> func) {
	const bool inserted = handlers_.emplace(service, func).second;
	if (!inserted) {
		LOG(ERROR) << "Message service " << service << " already bound";
	}
}

/**
 * Send one framed message: an ftl::net::Header followed by data, written
 * with a single gather write.
 *
 * @param service Service id placed in the header.
 * @param data    Payload bytes.
 *
 * NOTE(review): the writev() return value is ignored, so errors and short
 * writes go unnoticed. This listing also shows no return statement or closing
 * brace; confirm against the repository. writev()/iovec are POSIX APIs and
 * presumably do not build on WIN32.
 */
int Socket::send(uint32_t service, const std::string &data) {
	ftl::net::Header h;
	// Payload size plus 4 bytes; the receiver in data() treats those 4 bytes
	// as the service id.
	h.size = data.size()+4;
	h.service = service;
	
	// Header and payload go out in one writev() call.
	iovec vec[2];
	vec[0].iov_base = &h;
	vec[0].iov_len = sizeof(h);
	vec[1].iov_base = const_cast<char*>(data.data());
	vec[1].iov_len = data.size();
	
	::writev(sock_, &vec[0], 2);
/**
 * Send one framed message whose payload is data1 immediately followed by
 * data2, written with a single gather write.
 *
 * @param service Service id placed in the header.
 * @param data1   First part of the payload.
 * @param data2   Second part of the payload.
 *
 * NOTE(review): the writev() return value is ignored, so errors and short
 * writes go unnoticed. This listing also shows no return statement or closing
 * brace before the following (destructor) code; confirm against the
 * repository.
 */
int Socket::send2(uint32_t service, const std::string &data1, const std::string &data2) {
	ftl::net::Header h;
	// Combined payload size plus 4 bytes, mirroring send().
	h.size = data1.size()+4+data2.size();
	h.service = service;
	
	// Header, data1 and data2 go out in one writev() call.
	iovec vec[3];
	vec[0].iov_base = &h;
	vec[0].iov_len = sizeof(h);
	vec[1].iov_base = const_cast<char*>(data1.data());
	vec[1].iov_len = data1.size();
	vec[2].iov_base = const_cast<char*>(data2.data());
	vec[2].iov_len = data2.size();
	
	::writev(sock_, &vec[0], 3);
	std::cerr << "DESTROYING SOCKET" << std::endl;
	close();
	
	// Delete socket buffer
	if (buffer_) delete [] buffer_;
	buffer_ = NULL;