diff --git a/src/protocol/connection.cpp b/src/protocol/connection.cpp index 0ef451a5371d613396cffdd62d22f5eb7a5c76e1..52136952157bba13a680178ba62310ba4e3ea0bc 100644 --- a/src/protocol/connection.cpp +++ b/src/protocol/connection.cpp @@ -19,7 +19,9 @@ using ftl::URI; // SocketConnection //////////////////////////////////////////////////////////// SocketConnection::~SocketConnection() { - sock_.close(); + if (sock_.is_valid()) { + sock_.close(); + } } socket_t SocketConnection::fd() { return sock_.fd(); } diff --git a/src/protocol/tls.cpp b/src/protocol/tls.cpp index abb8af32367e0f890c69895037acd8faef9ffebf..c694ae3da935e0f69155c9c23ad116c510c8005d 100644 --- a/src/protocol/tls.cpp +++ b/src/protocol/tls.cpp @@ -12,12 +12,34 @@ #include <iomanip> #include <string> +#include <thread> +#include <chrono> + #include <ftl/exception.hpp> #include <ftl/lib/loguru.hpp> using ftl::net::internal::Connection_TLS; using uchar = unsigned char; +void log_gnutls(int level, const char* msg) { + // msg contains newline + auto str = std::string(msg); + str.pop_back(); + LOG(INFO) << "gnutls: %s" << str; +} + +/* +static bool gnutls_debug_logging_enabled = false; + +void ftl::net::enableTlsDebugLogging(int level) +{ + gnutls_global_set_log_level(level); + if (gnutls_debug_logging_enabled) { return; } + gnutls_debug_logging_enabled = true; + gnutls_global_set_log_function(&log_gnutls); +} +*/ + /** get basic certificate info: Distinguished Name (DN), issuer DN, * certificate fingerprint */ std::string get_cert_info(gnutls_session_t session) { @@ -115,13 +137,26 @@ bool Connection_TLS::close() { } ssize_t Connection_TLS::recv(char *buffer, size_t len) { - auto recvd = gnutls_record_recv(session_, buffer, len); - if (recvd == 0) { - DLOG(1) << "recv returned 0 (buffer size " << len << "), closing connection"; - close(); + int tries = 30; + while (tries-- > 0) { + ssize_t recvd = gnutls_record_recv(session_, buffer, len); + if (recvd == 0) { + DLOG(1) << "recv returned 0 (buffer size " << len << "), closing connection"; + close(); + } + + if (recvd > 0) { + return recvd; + } + if (recvd == GNUTLS_E_AGAIN) { + std::this_thread::sleep_for(std::chrono::nanoseconds(100)); + continue; + } + + return check_gnutls_error_(recvd); } - return check_gnutls_error_(recvd); + return -1; } ssize_t Connection_TLS::send(const char* buffer, size_t len) { diff --git a/src/protocol/tls.hpp b/src/protocol/tls.hpp index ca206301b3a4c16379d262e09b96ab09c9a9bbeb..c54d092ddb70b6dc1a9cdc82d1f3dc573eab6313 100644 --- a/src/protocol/tls.hpp +++ b/src/protocol/tls.hpp @@ -23,6 +23,7 @@ typedef SSIZE_T ssize_t; namespace ftl { namespace net { + namespace internal { class Connection_TLS : public Connection_TCP { diff --git a/src/socket/socket.cpp b/src/socket/socket.cpp index ffdf4c3c34ab6612439f133cb9ace34165ad48ca..7ff465928bea7d0aac0385ed2cef477ae5ac8b39 100644 --- a/src/socket/socket.cpp +++ b/src/socket/socket.cpp @@ -14,6 +14,11 @@ #include "socket_linux.cpp" #endif +Socket::~Socket() { + LOG_IF(ERROR, !(is_valid() || is_closed())) << "socket wrapper destroyed before socket is closed"; + DCHECK(is_valid() || is_closed()); +} + bool Socket::is_open() { return status_ == STATUS::OPEN; } bool Socket::is_closed() { return status_ == STATUS::CLOSED; } diff --git a/src/socket/socket_windows.cpp b/src/socket/socket_windows.cpp index dabcc4f97e5f1e91cd38849687f0aed6bd4a0c2f..840f1087ccb6b6398b46bba87a01a275edc4f762 100644 --- a/src/socket/socket_windows.cpp +++ b/src/socket/socket_windows.cpp @@ -14,7 +14,6 @@ #include <ftl/exception.hpp> #include <ftl/lib/loguru.hpp> - #pragma comment(lib, "Ws2_32.lib") // winsock2 documentation @@ -41,28 +40,67 @@ bool ftl::net::internal::resolve_inet_address(const std::string& hostname, int p class WinSock { public: + WinSock() { + CHECK(!winsock_initialized_); if (WSAStartup(MAKEWORD(1, 1), &wsaData_) != 0) { - LOG(FATAL) << "could not initialize sockets"; + LOG(FATAL) << "WSAStartup() failed"; // is it possible to retry/recover? } + winsock_initialized_ = true; + LOG(INFO) << "WSAStartup() done"; + } + + ~WinSock() { + // is this safe in DLL? Documentation warns of deadlock if called from + // DllMain() (not clear if static initialization ok or not). + CHECK(winsock_initialized_); + if (WSACleanup() != 0) { + LOG(FATAL) << "WSACleanup() failed: " << getErrorMsg(WSAGetLastError()); + } + winsock_initialized_ = false; + LOG(INFO) << "WSACleanup() done"; + } + + static bool isInitialized() { return winsock_initialized_ ; } + + static std::string getErrorMsg(int code) { + wchar_t* s = NULL; + + FormatMessageW(FORMAT_MESSAGE_ALLOCATE_BUFFER | FORMAT_MESSAGE_FROM_SYSTEM | FORMAT_MESSAGE_IGNORE_INSERTS, + NULL, code, MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT), (LPWSTR)&s, 0, NULL); + + if (!s) { + return "Unknown"; + } + + std::wstring ws(s); + std::string msg(ws.begin(), ws.end()); + LocalFree(s); + return msg; } private: + static bool winsock_initialized_; WSAData wsaData_; }; -// Do WSAStartup as static initialisation so no threads active. +bool WinSock::winsock_initialized_ = false; + +// Do WSAStartup as sttic initialisation so no threads active. static WinSock winSock; Socket::Socket(int domain, int type, int protocol) : - status_(STATUS::UNCONNECTED), fd_(-1), family_(domain) { + status_(STATUS::INVALID), fd_(INVALID_SOCKET), family_(domain) { + + CHECK(WinSock::isInitialized()); fd_ = ::socket(domain, type, protocol); if (fd_ == INVALID_SOCKET) { err_ = WSAGetLastError(); throw FTL_Error("socket() failed" + get_error_string()); } + status_ = STATUS::UNCONNECTED; } bool Socket::is_valid() { @@ -70,16 +108,19 @@ bool Socket::is_valid() { } ssize_t Socket::recv(char* buffer, size_t len, int flags) { + CHECK(WinSock::isInitialized()); auto err = ::recv(fd_, buffer, len, flags); if (err < 0) { err_ = WSAGetLastError(); } return err; } ssize_t Socket::send(const char* buffer, size_t len, int flags) { + CHECK(WinSock::isInitialized()); return ::send(fd_, buffer, len, flags); } ssize_t Socket::writev(const struct iovec* iov, int iovcnt) { + CHECK(WinSock::isInitialized()); std::vector<WSABUF> wsabuf(iovcnt); for (int i = 0; i < iovcnt; i++) { @@ -94,6 +135,9 @@ ssize_t Socket::writev(const struct iovec* iov, int iovcnt) { } int Socket::bind(const SocketAddress& addr) { + CHECK(WinSock::isInitialized()); + CHECK(status_ == STATUS::UNCONNECTED); + int retval = ::bind(fd_, reinterpret_cast<const sockaddr*>(&addr), sizeof(addr)); if (retval == 0) { status_ = STATUS::OPEN; @@ -108,6 +152,7 @@ int Socket::bind(const SocketAddress& addr) { } int Socket::listen(int backlog) { + CHECK(WinSock::isInitialized()); int retval = ::listen(fd_, backlog); if (retval == 0) { return 0; @@ -121,6 +166,9 @@ int Socket::listen(int backlog) { } Socket Socket::accept(SocketAddress& addr) { + CHECK(WinSock::isInitialized()); + CHECK(status_ == STATUS::OPEN); + Socket socket; int addrlen = sizeof(addr); int retval = ::accept(fd_, reinterpret_cast<sockaddr*>(&addr), &addrlen); @@ -138,6 +186,7 @@ Socket Socket::accept(SocketAddress& addr) { } int Socket::connect(const SocketAddress& address) { + CHECK(WinSock::isInitialized()); int err = 0; if (status_ != STATUS::UNCONNECTED) { return -1; @@ -163,19 +212,35 @@ int Socket::connect(const SocketAddress& address) { } int Socket::connect(const SocketAddress& address, int timeout) { + CHECK(WinSock::isInitialized()); // connect() blocks on Windows return connect(address); } bool Socket::close() { - bool retval = true; - if (is_valid() && status_ != STATUS::CLOSED) { - status_ = STATUS::CLOSED; - retval = closesocket(fd_) == 0; - err_ = errno; - } + CHECK(status_ != STATUS::CLOSED) << "socket status_: " << status_; + CHECK(fd_ != INVALID_SOCKET) << "not a valid socket"; + + auto fd = fd_; + status_ = STATUS::CLOSED; fd_ = INVALID_SOCKET; - return retval; + + if (!WinSock::isInitialized()) { + // Constructor would fail if WinSock was not started. It is possible + // that ~WinSock() is called before all connections are closed at + // program exit. + LOG(ERROR) << "WinSock stopped before socket was closed"; + return false; + } + + auto retval = closesocket(fd); + + if (retval != 0) { + err_ = WSAGetLastError(); + LOG(ERROR) << "closesocket() returned " << retval << ": " << WinSock::getErrorMsg(err_); + } + + return (retval == 0); } @@ -192,16 +257,7 @@ void Socket::set_blocking(bool val) { } std::string Socket::get_error_string(int code) { - wchar_t* s = NULL; - FormatMessageW(FORMAT_MESSAGE_ALLOCATE_BUFFER | FORMAT_MESSAGE_FROM_SYSTEM | FORMAT_MESSAGE_IGNORE_INSERTS, - NULL, (code != 0) ? code : err_, MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT), (LPWSTR)&s, 0, NULL); - if (!s) { - return "Unknown"; - } - std::wstring ws(s); - std::string msg(ws.begin(), ws.end()); - LocalFree(s); - return msg; + return WinSock::getErrorMsg((code == 0) ? err_ : code); } bool Socket::is_fatal(int code) { diff --git a/src/socketImpl.hpp b/src/socketImpl.hpp index 6fbdddeb6ef09bb49e672801cf06099ccc6ec69b..afd0f8d4356a331ce77429c55c53bfdd6e0d7d93 100644 --- a/src/socketImpl.hpp +++ b/src/socketImpl.hpp @@ -30,6 +30,7 @@ class Socket { public: Socket(int domain, int type, int protocol); + ~Socket(); bool is_valid(); bool is_open();