Skip to content
Snippets Groups Projects
Commit 1dacda50 authored by Nicolas Pope's avatar Nicolas Pope
Browse files

Merge branch 'bug/eagain' into 'main'

Handle GNUTLS_E_AGAIN in recv

See merge request beyondaka/beyond-protocol!77
parents 419ce850 9e04f809
No related branches found
No related tags found
No related merge requests found
...@@ -19,8 +19,10 @@ using ftl::URI; ...@@ -19,8 +19,10 @@ using ftl::URI;
// SocketConnection //////////////////////////////////////////////////////////// // SocketConnection ////////////////////////////////////////////////////////////
SocketConnection::~SocketConnection() { SocketConnection::~SocketConnection() {
if (sock_.is_valid()) {
sock_.close(); sock_.close();
} }
}
socket_t SocketConnection::fd() { return sock_.fd(); } socket_t SocketConnection::fd() { return sock_.fd(); }
......
...@@ -12,12 +12,34 @@ ...@@ -12,12 +12,34 @@
#include <iomanip> #include <iomanip>
#include <string> #include <string>
#include <thread>
#include <chrono>
#include <ftl/exception.hpp> #include <ftl/exception.hpp>
#include <ftl/lib/loguru.hpp> #include <ftl/lib/loguru.hpp>
using ftl::net::internal::Connection_TLS; using ftl::net::internal::Connection_TLS;
using uchar = unsigned char; using uchar = unsigned char;
void log_gnutls(int level, const char* msg) {
// msg contains newline
auto str = std::string(msg);
str.pop_back();
LOG(INFO) << "gnutls: %s" << str;
}
/*
static bool gnutls_debug_logging_enabled = false;
void ftl::net::enableTlsDebugLogging(int level)
{
gnutls_global_set_log_level(level);
if (gnutls_debug_logging_enabled) { return; }
gnutls_debug_logging_enabled = true;
gnutls_global_set_log_function(&log_gnutls);
}
*/
/** get basic certificate info: Distinguished Name (DN), issuer DN, /** get basic certificate info: Distinguished Name (DN), issuer DN,
* certificate fingerprint */ * certificate fingerprint */
std::string get_cert_info(gnutls_session_t session) { std::string get_cert_info(gnutls_session_t session) {
...@@ -115,15 +137,28 @@ bool Connection_TLS::close() { ...@@ -115,15 +137,28 @@ bool Connection_TLS::close() {
} }
ssize_t Connection_TLS::recv(char *buffer, size_t len) { ssize_t Connection_TLS::recv(char *buffer, size_t len) {
auto recvd = gnutls_record_recv(session_, buffer, len); int tries = 30;
while (tries-- > 0) {
ssize_t recvd = gnutls_record_recv(session_, buffer, len);
if (recvd == 0) { if (recvd == 0) {
DLOG(1) << "recv returned 0 (buffer size " << len << "), closing connection"; DLOG(1) << "recv returned 0 (buffer size " << len << "), closing connection";
close(); close();
} }
if (recvd > 0) {
return recvd;
}
if (recvd == GNUTLS_E_AGAIN) {
std::this_thread::sleep_for(std::chrono::nanoseconds(100));
continue;
}
return check_gnutls_error_(recvd); return check_gnutls_error_(recvd);
} }
return -1;
}
ssize_t Connection_TLS::send(const char* buffer, size_t len) { ssize_t Connection_TLS::send(const char* buffer, size_t len) {
return check_gnutls_error_(gnutls_record_send(session_, buffer, len)); return check_gnutls_error_(gnutls_record_send(session_, buffer, len));
} }
......
...@@ -23,6 +23,7 @@ typedef SSIZE_T ssize_t; ...@@ -23,6 +23,7 @@ typedef SSIZE_T ssize_t;
namespace ftl { namespace ftl {
namespace net { namespace net {
namespace internal { namespace internal {
class Connection_TLS : public Connection_TCP { class Connection_TLS : public Connection_TCP {
......
...@@ -14,6 +14,11 @@ ...@@ -14,6 +14,11 @@
#include "socket_linux.cpp" #include "socket_linux.cpp"
#endif #endif
Socket::~Socket() {
LOG_IF(ERROR, !(is_valid() || is_closed())) << "socket wrapper destroyed before socket is closed";
DCHECK(is_valid() || is_closed());
}
bool Socket::is_open() { return status_ == STATUS::OPEN; } bool Socket::is_open() { return status_ == STATUS::OPEN; }
bool Socket::is_closed() { return status_ == STATUS::CLOSED; } bool Socket::is_closed() { return status_ == STATUS::CLOSED; }
......
...@@ -14,7 +14,6 @@ ...@@ -14,7 +14,6 @@
#include <ftl/exception.hpp> #include <ftl/exception.hpp>
#include <ftl/lib/loguru.hpp> #include <ftl/lib/loguru.hpp>
#pragma comment(lib, "Ws2_32.lib") #pragma comment(lib, "Ws2_32.lib")
// winsock2 documentation // winsock2 documentation
...@@ -41,28 +40,67 @@ bool ftl::net::internal::resolve_inet_address(const std::string& hostname, int p ...@@ -41,28 +40,67 @@ bool ftl::net::internal::resolve_inet_address(const std::string& hostname, int p
class WinSock { class WinSock {
public: public:
WinSock() { WinSock() {
CHECK(!winsock_initialized_);
if (WSAStartup(MAKEWORD(1, 1), &wsaData_) != 0) { if (WSAStartup(MAKEWORD(1, 1), &wsaData_) != 0) {
LOG(FATAL) << "could not initialize sockets"; LOG(FATAL) << "WSAStartup() failed";
// is it possible to retry/recover? // is it possible to retry/recover?
} }
winsock_initialized_ = true;
LOG(INFO) << "WSAStartup() done";
}
~WinSock() {
// is this safe in DLL? Documentation warns of deadlock if called from
// DllMain() (not clear if static initialization ok or not).
CHECK(winsock_initialized_);
if (WSACleanup() != 0) {
LOG(FATAL) << "WSACleanup() failed: " << getErrorMsg(WSAGetLastError());
}
winsock_initialized_ = false;
LOG(INFO) << "WSACleanup() done";
}
static bool isInitialized() { return winsock_initialized_ ; }
static std::string getErrorMsg(int code) {
wchar_t* s = NULL;
FormatMessageW(FORMAT_MESSAGE_ALLOCATE_BUFFER | FORMAT_MESSAGE_FROM_SYSTEM | FORMAT_MESSAGE_IGNORE_INSERTS,
NULL, code, MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT), (LPWSTR)&s, 0, NULL);
if (!s) {
return "Unknown";
}
std::wstring ws(s);
std::string msg(ws.begin(), ws.end());
LocalFree(s);
return msg;
} }
private: private:
static bool winsock_initialized_;
WSAData wsaData_; WSAData wsaData_;
}; };
// Do WSAStartup as static initialisation so no threads active. bool WinSock::winsock_initialized_ = false;
// Do WSAStartup as sttic initialisation so no threads active.
static WinSock winSock; static WinSock winSock;
Socket::Socket(int domain, int type, int protocol) : Socket::Socket(int domain, int type, int protocol) :
status_(STATUS::UNCONNECTED), fd_(-1), family_(domain) { status_(STATUS::INVALID), fd_(INVALID_SOCKET), family_(domain) {
CHECK(WinSock::isInitialized());
fd_ = ::socket(domain, type, protocol); fd_ = ::socket(domain, type, protocol);
if (fd_ == INVALID_SOCKET) { if (fd_ == INVALID_SOCKET) {
err_ = WSAGetLastError(); err_ = WSAGetLastError();
throw FTL_Error("socket() failed" + get_error_string()); throw FTL_Error("socket() failed" + get_error_string());
} }
status_ = STATUS::UNCONNECTED;
} }
bool Socket::is_valid() { bool Socket::is_valid() {
...@@ -70,16 +108,19 @@ bool Socket::is_valid() { ...@@ -70,16 +108,19 @@ bool Socket::is_valid() {
} }
ssize_t Socket::recv(char* buffer, size_t len, int flags) { ssize_t Socket::recv(char* buffer, size_t len, int flags) {
CHECK(WinSock::isInitialized());
auto err = ::recv(fd_, buffer, len, flags); auto err = ::recv(fd_, buffer, len, flags);
if (err < 0) { err_ = WSAGetLastError(); } if (err < 0) { err_ = WSAGetLastError(); }
return err; return err;
} }
ssize_t Socket::send(const char* buffer, size_t len, int flags) { ssize_t Socket::send(const char* buffer, size_t len, int flags) {
CHECK(WinSock::isInitialized());
return ::send(fd_, buffer, len, flags); return ::send(fd_, buffer, len, flags);
} }
ssize_t Socket::writev(const struct iovec* iov, int iovcnt) { ssize_t Socket::writev(const struct iovec* iov, int iovcnt) {
CHECK(WinSock::isInitialized());
std::vector<WSABUF> wsabuf(iovcnt); std::vector<WSABUF> wsabuf(iovcnt);
for (int i = 0; i < iovcnt; i++) { for (int i = 0; i < iovcnt; i++) {
...@@ -94,6 +135,9 @@ ssize_t Socket::writev(const struct iovec* iov, int iovcnt) { ...@@ -94,6 +135,9 @@ ssize_t Socket::writev(const struct iovec* iov, int iovcnt) {
} }
int Socket::bind(const SocketAddress& addr) { int Socket::bind(const SocketAddress& addr) {
CHECK(WinSock::isInitialized());
CHECK(status_ == STATUS::UNCONNECTED);
int retval = ::bind(fd_, reinterpret_cast<const sockaddr*>(&addr), sizeof(addr)); int retval = ::bind(fd_, reinterpret_cast<const sockaddr*>(&addr), sizeof(addr));
if (retval == 0) { if (retval == 0) {
status_ = STATUS::OPEN; status_ = STATUS::OPEN;
...@@ -108,6 +152,7 @@ int Socket::bind(const SocketAddress& addr) { ...@@ -108,6 +152,7 @@ int Socket::bind(const SocketAddress& addr) {
} }
int Socket::listen(int backlog) { int Socket::listen(int backlog) {
CHECK(WinSock::isInitialized());
int retval = ::listen(fd_, backlog); int retval = ::listen(fd_, backlog);
if (retval == 0) { if (retval == 0) {
return 0; return 0;
...@@ -121,6 +166,9 @@ int Socket::listen(int backlog) { ...@@ -121,6 +166,9 @@ int Socket::listen(int backlog) {
} }
Socket Socket::accept(SocketAddress& addr) { Socket Socket::accept(SocketAddress& addr) {
CHECK(WinSock::isInitialized());
CHECK(status_ == STATUS::OPEN);
Socket socket; Socket socket;
int addrlen = sizeof(addr); int addrlen = sizeof(addr);
int retval = ::accept(fd_, reinterpret_cast<sockaddr*>(&addr), &addrlen); int retval = ::accept(fd_, reinterpret_cast<sockaddr*>(&addr), &addrlen);
...@@ -138,6 +186,7 @@ Socket Socket::accept(SocketAddress& addr) { ...@@ -138,6 +186,7 @@ Socket Socket::accept(SocketAddress& addr) {
} }
int Socket::connect(const SocketAddress& address) { int Socket::connect(const SocketAddress& address) {
CHECK(WinSock::isInitialized());
int err = 0; int err = 0;
if (status_ != STATUS::UNCONNECTED) { if (status_ != STATUS::UNCONNECTED) {
return -1; return -1;
...@@ -163,19 +212,35 @@ int Socket::connect(const SocketAddress& address) { ...@@ -163,19 +212,35 @@ int Socket::connect(const SocketAddress& address) {
} }
int Socket::connect(const SocketAddress& address, int timeout) { int Socket::connect(const SocketAddress& address, int timeout) {
CHECK(WinSock::isInitialized());
// connect() blocks on Windows // connect() blocks on Windows
return connect(address); return connect(address);
} }
bool Socket::close() { bool Socket::close() {
bool retval = true; CHECK(status_ != STATUS::CLOSED) << "socket status_: " << status_;
if (is_valid() && status_ != STATUS::CLOSED) { CHECK(fd_ != INVALID_SOCKET) << "not a valid socket";
auto fd = fd_;
status_ = STATUS::CLOSED; status_ = STATUS::CLOSED;
retval = closesocket(fd_) == 0;
err_ = errno;
}
fd_ = INVALID_SOCKET; fd_ = INVALID_SOCKET;
return retval;
if (!WinSock::isInitialized()) {
// Constructor would fail if WinSock was not started. It is possible
// that ~WinSock() is called before all connections are closed at
// program exit.
LOG(ERROR) << "WinSock stopped before socket was closed";
return false;
}
auto retval = closesocket(fd);
if (retval != 0) {
err_ = WSAGetLastError();
LOG(ERROR) << "closesocket() returned " << retval << ": " << WinSock::getErrorMsg(err_);
}
return (retval == 0);
} }
...@@ -192,16 +257,7 @@ void Socket::set_blocking(bool val) { ...@@ -192,16 +257,7 @@ void Socket::set_blocking(bool val) {
} }
std::string Socket::get_error_string(int code) { std::string Socket::get_error_string(int code) {
wchar_t* s = NULL; return WinSock::getErrorMsg((code == 0) ? err_ : code);
FormatMessageW(FORMAT_MESSAGE_ALLOCATE_BUFFER | FORMAT_MESSAGE_FROM_SYSTEM | FORMAT_MESSAGE_IGNORE_INSERTS,
NULL, (code != 0) ? code : err_, MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT), (LPWSTR)&s, 0, NULL);
if (!s) {
return "Unknown";
}
std::wstring ws(s);
std::string msg(ws.begin(), ws.end());
LocalFree(s);
return msg;
} }
bool Socket::is_fatal(int code) { bool Socket::is_fatal(int code) {
......
...@@ -30,6 +30,7 @@ class Socket { ...@@ -30,6 +30,7 @@ class Socket {
public: public:
Socket(int domain, int type, int protocol); Socket(int domain, int type, int protocol);
~Socket();
bool is_valid(); bool is_valid();
bool is_open(); bool is_open();
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment