diff --git a/.vscode/settings.json b/.vscode/settings.json index 667db3fa5bb417969e810dc3dff8ddf1fc2324c3..eb0d177e2d04a38167d9679e93b1fec88be559c1 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -96,8 +96,10 @@ "__locale": "cpp", "ios": "cpp", "locale": "cpp", - "strstream": "cpp" + "strstream": "cpp", + "hash_map": "cpp", + "hash_set": "cpp", + "slist": "cpp" }, - "cmake.cmakePath": "cmake", "cmake.configureOnOpen": true } diff --git a/src/protocol/websocket.cpp b/src/protocol/websocket.cpp index 6a157f095cc33dfdb5038f044903c7352d8161cf..f5160ef1f4859054bd717c1f8973f63e7f9ab99f 100644 --- a/src/protocol/websocket.cpp +++ b/src/protocol/websocket.cpp @@ -7,7 +7,9 @@ #include <string> #include <unordered_map> #include <algorithm> + #include "websocket.hpp" + #include <ftl/lib/loguru.hpp> #include <ftl/utility/base64.hpp> @@ -17,6 +19,9 @@ using uchar = unsigned char; #ifdef HAVE_GNUTLS #include <gnutls/crypto.h> +// TODO: Duplicate code +#include "../quic/src/websocket.hpp" + inline uint32_t secure_rnd() { uint32_t rnd; gnutls_rnd(GNUTLS_RND_NONCE, &rnd, sizeof(uint32_t)); @@ -176,21 +181,7 @@ void ws_parse(char *data, size_t len, wsheader_type *ws) { } -int getPort(const ftl::URI &uri) { - auto port = uri.getPort(); - if (port == 0) { - if (uri.getScheme() == URI::scheme_t::SCHEME_WS) { - port = 80; - } else if (uri.getScheme() == URI::scheme_t::SCHEME_WSS) { - port = 443; - } else { - throw FTL_Error("Bad WS uri:" + uri.to_string()); - } - } - - return port; -} //////////////////////////////////////////////////////////////////////////////// @@ -199,90 +190,46 @@ WebSocketBase<SocketT>::WebSocketBase() {} template<typename SocketT> void WebSocketBase<SocketT>::connect(const ftl::URI& uri, int timeout) { - int port = getPort(uri); + int port = get_websocket_port(uri); // connect via TCP/TLS if (!SocketT::connect(uri.getHost(), port, timeout)) { throw FTL_Error("WS: connect() failed"); } - std::string http = ""; - int status; - int i; - char line[256]; - - http += "GET " + uri.getPath() + " HTTP/1.1\r\n"; - if (port == 80) { - http += "Host: " + uri.getHost() + "\r\n"; - } else { - // TODO(Seb): is this correct when connecting over TLS - http += "Host: " + uri.getHost() + ":" - + std::to_string(port) + "\r\n"; - } - - if (uri.hasUserInfo()) { - http += "Authorization: Basic "; - http += base64_encode(uri.getUserInfo()) + "\r\n"; - // https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Authorization - if (uri.getProtocol() != URI::scheme_t::SCHEME_WSS) { - DLOG(WARNING) << "HTTP Basic Auth is being sent without TLS"; - } - } - - http += "Upgrade: websocket\r\n"; - http += "Connection: Upgrade\r\n"; - http += "Sec-WebSocket-Key: x3JJHMbDL1EzLkh9GBhXDw==\r\n"; - http += "Sec-WebSocket-Version: 13\r\n"; - http += "\r\n"; + auto http = create_websocket_upgrade_http_request(uri); int rc = SocketT::send(http.c_str(), static_cast<int>(http.length())); if (rc != static_cast<int>(http.length())) { - throw FTL_Error("Could not send Websocket http request... (" - + std::to_string(rc) + ", " - + std::to_string(errno) + ")\n" + http); + throw FTL_Error("Could not send Websocket http request... (send returned: " + + std::to_string(rc) + ", data: " + + std::to_string(errno) + ")\n" + http); } - for (i = 0; i < 2 || (i < 255 && line[i-2] != '\r' && line[i-1] != '\n'); ++i) { - if (SocketT::recv(line + i, 1) == 0) { + char headers[1024]; + memset(headers, 0, sizeof(headers)); + int i; + + for (i = 0; i < 4 || ( i < int(sizeof(headers) - 1) + && std::string_view(headers, i).substr(i - 4, 4) != "\r\n\r\n"); + ++i) { + if (SocketT::recv(headers + i, 1) == 0) { throw FTL_Error("Connection closed by remote"); } } + headers[i] = 0; - line[i] = 0; - if (i == 255) { - throw FTL_Error("Got invalid status line connecting to: " + uri.getHost()); + if (i == sizeof(headers)) { + throw FTL_Error("Received too reply from server (headers) (" + uri.getHost() + ")"); } - if (sscanf(line, "HTTP/1.1 %d", &status) != 1 || status != 101) { - throw FTL_Error("ERROR: Got bad status connecting to: " - + uri.getHost() + ": " + line); - } - - std::unordered_map<std::string, std::string> headers; - while (true) { - for (i = 0; i < 2 || (i < 255 && line[i-2] != '\r' && line[i-1] != '\n'); ++i) { - if (SocketT::recv(line+i, 1) == 0) { - throw FTL_Error("Connection closed by remote"); - } - } - if (line[0] == '\r' && line[1] == '\n') { break; } - - // Split the headers into a map for checking - line[i] = 0; - const std::string cppline(line); - const auto ix = cppline.find(":"); - const auto label = cppline.substr(0, ix); - const auto value = cppline.substr(ix + 2, cppline.size() - ix - 4); - headers[label] = value; + auto status = server_accepted_websocket_upgrade(std::string_view(headers, i)); + if (status == ResponseStatus::kFailure) { + throw FTL_Error("Failed to connect to websocket on " + uri.to_string()); + } + if (status == ResponseStatus::kIncomplete || ResponseStatus::kIncomplete) { + throw FTL_Error("Connection closed(?) complete headers"); } - - // Validate some of the headers - if (headers.count("Connection") == 0 || headers.at("Connection") != "upgrade") - throw FTL_Error("Missing WS connection header"); - if (headers.count("Upgrade") == 0 || headers.at("Upgrade") != "websocket") - throw FTL_Error("Missing WS Upgrade"); - if (headers.count("Sec-WebSocket-Accept") == 0) - throw FTL_Error("Missing WS accept header"); } template<typename SocketT> diff --git a/src/quic/src/msquic/connection.cpp b/src/quic/src/msquic/connection.cpp index 1628a3e41ff8257fd93b741925752957b64a1ed2..3d3d175cec8230648a12f5c7b43e4c8e529f7081 100644 --- a/src/quic/src/msquic/connection.cpp +++ b/src/quic/src/msquic/connection.cpp @@ -98,6 +98,8 @@ MsQuicConnectionPtr MsQuicConnection::Connect(IMsQuicConnectionHandler* Observer auto Connection = MsQuicConnectionPtr(new MsQuicConnection(ObserverIn, MsQuic)); Connection->Self = Connection; + CHECK(Port != 0) << "No port passed to MsQuic"; // TODO: Pass default port; MsQuic error message not as clear. + CHECK_QUIC(MsQuic->Api->ConnectionOpen( MsQuic->hRegistration, MsQuicConnection::EventHandler, diff --git a/src/quic/src/quic_peer.hpp b/src/quic/src/quic_peer.hpp index a8739d084657c8239135154790503189e94e0745..01869c72055c6ea10266ea51a814d4b3eedc132f 100644 --- a/src/quic/src/quic_peer.hpp +++ b/src/quic/src/quic_peer.hpp @@ -10,6 +10,8 @@ #include <deque> #include <array> +#include <ftl/uri.hpp> + #include "websocket.hpp" namespace beyond_impl diff --git a/src/quic/src/quic_universe.cpp b/src/quic/src/quic_universe.cpp index 8bf7ac8e80dad7f47c784ce82bc58ef3fd336cbb..abb4b2bd10e65ac5cca9e11f4a3e6a1531157692 100644 --- a/src/quic/src/quic_universe.cpp +++ b/src/quic/src/quic_universe.cpp @@ -123,6 +123,46 @@ std::vector<ftl::URI> QuicUniverseImpl::GetListeningUris() return {}; } +// Workaround until HTTP server supports websockets over quic natively +struct WebsocketHandshake : public IMsQuicStreamHandler { + std::vector<uint8_t> buffer; + std::promise<std::string_view> reply; + + std::future<std::string_view> future() { return reply.get_future(); } + + void OnData(MsQuicStream* stream, nonstd::span<const QUIC_BUFFER> data) override { + uint32_t total = 0; + for (auto in : data) { buffer.insert(buffer.end(), in.Buffer, in.Buffer + in.Length); total += in.Length; } + auto view = std::string_view((char*)buffer.data(), buffer.size()); + auto end = view.find("\r\n\r\n"); + if (end == std::string::npos) { stream->Consume(total); } + else { + stream->EnableRecv(false); + view = view.substr(0, end); + stream->Consume(view.size()); + reply.set_value(view); + } + } +}; + +void PerformWebsocketHandshake(const ftl::URI& uri, MsQuicStream* stream) { + auto https_uri = uri.to_string(); + https_uri.replace(0, 4, "https"); + + auto request = create_websocket_upgrade_http_request(ftl::URI(https_uri)); + QUIC_BUFFER buffer{ uint32_t(request.size()), (uint8_t*)request.data() }; + WebsocketHandshake handler; + auto future = handler.future(); + stream->SetStreamHandler(&handler); + stream->Write({&buffer, 1}, nullptr); + if (future.wait_for(std::chrono::seconds(1)) != std::future_status::ready) { + throw FTL_Error("Webservice timed out"); + } + if (!server_accepted_websocket_upgrade(future.get())) { + throw FTL_Error("Webservice handshake failed over QUIC"); + } +} + PeerPtr QuicUniverseImpl::Connect(const ftl::URI& uri, bool is_webservice) { UNIQUE_LOCK_N(Lk, ConnectionMtx); @@ -148,10 +188,19 @@ PeerPtr QuicUniverseImpl::Connect(const ftl::URI& uri, bool is_webservice) } else if (uri.getScheme() == ftl::URI::SCHEME_FTL_QUIC) { LOG(INFO) << "[QUIC] Connecting to: " << uri.to_string() << (is_webservice ? " (webservice)" : ""); - auto Connection = Client->Connect(this, uri.getHost(), uri.getPort()); + auto port = uri.getPort(); + if (is_webservice && port == 0) { port = 9001; } + auto Connection = Client->Connect(this, uri.getHost(), port); auto Stream = Connection->OpenStream(); + auto* StreamPtr = Stream.get(); + if (is_webservice) { + PerformWebsocketHandshake(uri, StreamPtr); + } auto Ptr = std::make_unique<QuicPeerStream>(Connection.get(), std::move(Stream), net_, net_->dispatcher_()); - if (is_webservice) { Ptr->set_type(ftl::protocol::NodeType::kWebService); } + if (is_webservice) { + Ptr->set_type(ftl::protocol::NodeType::kWebService); + StreamPtr->EnableRecv(); // Disabled after handhsake + } // Expected to be called by net::Universe, which will insert peer to its internal list CHECK(Ptr->getType() == (is_webservice ? ftl::protocol::NodeType::kWebService : ftl::protocol::NodeType::kNode)); Connections.push_back(std::move(Connection)); diff --git a/src/quic/src/websocket.hpp b/src/quic/src/websocket.hpp index 8aba324100fd6c5f1bd660a56cebcbd031df0089..80104d57ab003779dea50980d9e04f7e9857ede0 100644 --- a/src/quic/src/websocket.hpp +++ b/src/quic/src/websocket.hpp @@ -326,3 +326,132 @@ public: return NextBuffer(); } }; + +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +// FIXME: move implementations to shared module + +#include <ftl/uri.hpp> + +enum ResponseStatus { + kIncomplete, + kSuccess, + kFailure + }; + +inline int get_websocket_port(const ftl::URI& uri, int default_port = 0) { + auto port = uri.getPort(); + + if (port == 0) { + if (uri.getScheme() == ftl::URI::scheme_t::SCHEME_WS) { + port = 80; + } else if (uri.getScheme() == ftl::URI::scheme_t::SCHEME_WSS) { + port = 443; + } else if (default_port == 0) { + throw std::runtime_error("Bad websocket uri: " + uri.to_string()); + } else { + return default_port; + } + } + + return port; +} + +inline std::string create_websocket_upgrade_http_request(const ftl::URI& uri) { + int port = get_websocket_port(uri, 443); + std::string http = ""; + + http += "GET " + uri.getPath() + " HTTP/1.1\r\n"; + if ((port == 80) || (port == 443)) { + http += "Host: " + uri.getHost() + "\r\n"; + } else { + http += "Host: " + uri.getHost() + ":" + std::to_string(port) + "\r\n"; + } + + if (uri.hasUserInfo()) { + http += "Authorization: Basic "; + http += base64_encode(uri.getUserInfo()) + "\r\n"; + // https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Authorization + if (uri.getProtocol() != ftl::URI::scheme_t::SCHEME_WSS) { + DLOG(WARNING) << "HTTP Basic Auth is being sent without TLS"; + } + } + + http += "Upgrade: websocket\r\n"; + http += "Connection: Upgrade\r\n"; + http += "Sec-WebSocket-Key: x3JJHMbDL1EzLkh9GBhXDw==\r\n"; + http += "Sec-WebSocket-Version: 13\r\n"; + http += "\r\n"; + return http; +} + +inline ResponseStatus server_accepted_websocket_upgrade(std::string_view response) { + if (response.size() < 4) { return ResponseStatus::kIncomplete; } + if (response.substr(response.size() - 4, 4) != "\r\n\r\n") { + LOG(WARNING) << "(websocket) Headers did not end in \\r\\n\\r\\n"; + return ResponseStatus::kFailure; + } + + auto response_full = response; + size_t pos = std::string::npos; + if ((pos = response.find("\r\n")) != std::string::npos) { + auto status = response.substr(0, pos); + response = response.substr(pos + 2); + + auto status_line = status; + std::vector<std::string_view> tokens; + while ((pos = status.find(" ")) != std::string::npos) { + tokens.push_back(status.substr(0, pos)); + status = status.substr(pos + 1); + } + if (tokens.size() < 2) { + LOG(WARNING) << "(websocket) Did not receive status code in " << status_line; + return ResponseStatus::kFailure; + } + if (tokens[1] != "101") { + LOG(WARNING) << "(websocket) Server replied: " << status_line + << ", expected HTTP/1.1 101 Switching Protocols"; + return ResponseStatus::kFailure; + } + } + + // Minimal header parsing, split by lines and then by ':' + std::unordered_map<std::string, std::string> headers; + auto trim = [](const std::string_view& s) { + size_t start = 0; + size_t end = s.size(); + for (; start < s.size() && std::isspace(s[start]); start++); + for (end = start; end < s.size() && !std::isspace(s[start]); end++) {} + return s.substr(start, end - start); + }; + + pos = std::string::npos; + size_t ix = std::string::npos; + while ((pos = response.find("\r\n")) != std::string::npos) { + auto line = response.substr(0, pos); + response = response.substr(pos + 2); + + // Split the headers into a map for checking (ignore if does not have key:value format) + if ((ix = line.find(":")) != std::string::npos) { + const auto label = line.substr(0, ix); + const auto value = line.substr(ix + 1, line.size() - ix); + headers[std::string(label)] = std::string(trim(value)); // Case sensitivity? + } + } + + // Validate some of the headers + int bad_header_count = 0; + if (headers.count("Connection") == 0 || headers.at("Connection") != "upgrade") { + LOG(WARNING) << "Missing connection header for websocket"; bad_header_count++; + } + if (headers.count("Upgrade") == 0 || headers.at("Upgrade") != "websocket") { + LOG(WARNING) << "Missing upgrade header for websocket"; bad_header_count++; + } + if (headers.count("Sec-WebSocket-Accept") == 0) { + LOG(WARNING) << "Missing accept header for websocket"; bad_header_count++; + } + if (bad_header_count != 0 ) { + LOG(WARNING) << "Got: " << response_full; + return ResponseStatus::kFailure; + } + return ResponseStatus::kSuccess; +}