diff --git a/src/quic/src/msquic/stream.cpp b/src/quic/src/msquic/stream.cpp index 1f10f554271f2f3de0a21662e93225997802e39a..ec78f8a4f55236d8c2c08f2a696956578699539c 100644 --- a/src/quic/src/msquic/stream.cpp +++ b/src/quic/src/msquic/stream.cpp @@ -9,7 +9,6 @@ using namespace beyond_impl; - std::unique_ptr<MsQuicStream> MsQuicStream::FromRemotePeer(MsQuicContext* MsQuic, HQUIC hStream) { CHECK(MsQuic); @@ -63,6 +62,8 @@ QUIC_STATUS MsQuicStream::EventHandler(HQUIC hStream, void* Context, QUIC_STREAM MsQuicStream* Stream = static_cast<MsQuicStream*>(Context); auto* MsQuic = Stream->MsQuic; + auto Lock = std::unique_lock(Stream->Mtx); + switch (Event->Type) { case QUIC_STREAM_EVENT_START_COMPLETE: @@ -243,5 +244,6 @@ bool MsQuicStream::Write(nonstd::span<const QUIC_BUFFER> Buffers, void* Context, void MsQuicStream::SetStreamHandler(IMsQuicStreamHandler* ObserverIn) { + auto Lock = std::unique_lock(Mtx); Observer = ObserverIn; } diff --git a/src/quic/src/msquic/stream.hpp b/src/quic/src/msquic/stream.hpp index 2e0755326f865c3f454e37203e9f980ceb8dac4d..6bd507fcd5f1e2f26db15b96c4168c440953c408 100644 --- a/src/quic/src/msquic/stream.hpp +++ b/src/quic/src/msquic/stream.hpp @@ -47,7 +47,8 @@ public: */ void Abort(); - /** Inform source that reader is done with data. When less than received is consumed, . + /** Inform source that reader is done with data. When less than received is consumed, no further + * callbacks are called before EnableRecv() is called again. */ void Consume(int32_t BytesConsumed); @@ -58,6 +59,10 @@ public: * queued data. */ bool Write(nonstd::span<const QUIC_BUFFER> Buffers, void* Context, bool Delay = false); + /** Set stream handler. If handler is replaced once stream is opened, caller must ensure any pending writes + * are correctly handled (buffers not released before completion, new handler can process callbacks from + * previous writes) or that there are no pending writes. + */ void SetStreamHandler(IMsQuicStreamHandler*); /** Calls to Start() */ @@ -74,6 +79,7 @@ private: // races during StreamClose). HQUIC hStream; + std::mutex Mtx; // Necessary to allow changing observer after stream is started IMsQuicStreamHandler* Observer; std::atomic_int PendingSends; diff --git a/src/quic/src/quic_peer.cpp b/src/quic/src/quic_peer.cpp index e07e05b17e36bea7e2de28e972eec847111e2fe5..5d454e9908102041db4a4d0371e634effe057c68 100644 --- a/src/quic/src/quic_peer.cpp +++ b/src/quic/src/quic_peer.cpp @@ -6,14 +6,199 @@ #include "quic_websocket.cpp" -using namespace beyond_impl; - using ftl::protocol::NodeStatus; +#include <random> + +namespace beyond_impl +{ +/** WebSocket wrapper + * With RFC 9220 and HTTP/3 protocol this could be made to work directly with compatible webserver + */ +class WebSocketHandshake : public IMsQuicStreamHandler +{ + QuicPeerStream* peer_; + MsQuicStream* stream_; + + enum Status { kContinue, kError, kDone }; + std::map<std::string, std::string> headers_; + std::vector<char> http_reply_buffer_; + + std::string websocket_key_; + std::string http_req_; + QUIC_BUFFER http_req_buffer_; + + std::string prepare_http_request(const ftl::URI& uri) + { + std::string request; + auto generate_sec_websocket_key = []() + { + std::random_device rd; + std::mt19937 mt(rd()); + std::uniform_int_distribution<unsigned char> dist(32, 128); + std::array<unsigned char, 16> key; + for(auto& v : key) { v = dist(mt); } + return base64_encode(key.data(), key.size()); + }; + websocket_key_ = generate_sec_websocket_key(); + + request += "GET " + uri.getPath() + " HTTP/1.1\r\n"; + request += "Host: " + uri.getHost() + "\r\n"; + + if (uri.hasUserInfo()) { + request += "Authorization: Basic "; + request += base64_encode(uri.getUserInfo()) + "\r\n"; + } + + request += "Upgrade: websocket\r\n"; + request += "Connection: Upgrade\r\n"; + request += "Sec-WebSocket-Key: " + websocket_key_ + "\r\n"; + request += "Sec-WebSocket-Version: 13\r\n"; + request += "\r\n"; + + return request; + } + + Status process_headers(std::map<std::string, std::string>& headers) + { + if (headers.count("Connection") == 0 || headers.at("Connection") != "upgrade") + { + LOG(ERROR) << "[WebSocket] Wrong Connection header"; + return kError; + } + if (headers.count("Upgrade") == 0 || headers.at("Upgrade") != "websocket") + { + LOG(ERROR) << "[WebSocket] Wrong Upgrade header"; + return kError; + } + if (headers.count("Sec-WebSocket-Accept") == 0) + { + LOG(ERROR) << "[WebSocket] Missing Sec-WebSocket-Accept header"; + return kError; + } + + return kDone; + } + + Status process_recv(const std::string_view& line) + { + try { + if ((headers_.size() == 0)) + { + if (line.substr(0,9) != "HTTP/1.1 ") { return kError; } + auto code = atoi(line.substr(10).data()); + if (code != 101) { return kError; } + return kContinue; + } + + if (line == "\r\n") + { + // Returns either kError or kDone + return process_headers(headers_); + } + + uint32_t i = 1; + std::string_view key; + + // Find header name + for (; i < line.size(); i++) + { + if (line[i] == ':') + { + key = line.substr(0, i); + break; + } + } + // Return error if not a valid header + if (key.empty()) { return kError; } + + // Skip whitespace + for (; i < line.size() && std::isspace(line[i]); i++) + + headers_[std::string(key)] = std::string(line.substr(i)); + return kContinue; + } + catch(...) { /** Always return error in the end (no other status returned) */} + + return kError; + } + + void OnData(MsQuicStream* stream, nonstd::span<const QUIC_BUFFER> buffers) override + { + const unsigned int max_buffer_size = 1024; + uint32_t recvd = 0; + http_reply_buffer_.reserve(120); + + for (const auto& data : buffers) + { + for (uint32_t i = 0; i < data.Length; i++, recvd++) + { + if (http_reply_buffer_.size() > max_buffer_size) + { + LOG(ERROR) << "[WebSocket] HTTP response too large"; + goto error; + } + if ((data.Buffer[i] == '\n') && (http_reply_buffer_.size() > 0) && (http_reply_buffer_.back() == '\r')) + { + http_reply_buffer_[http_reply_buffer_.size() - 1] = '\0'; // Not strictly necessary + auto status = process_recv({http_reply_buffer_.data(), http_reply_buffer_.size() - 1}); + if (status == kDone) { goto done; } + else if (status == kError) { goto error; } + else if (status == kContinue) { http_reply_buffer_.clear(); } + } + else + { + http_reply_buffer_.push_back(data.Buffer[i]); + } + } + } + + done: + stream->Consume(recvd); + // Next EnableRecv() call will re-submit any unread data. + peer_->enable_quic_stream(); + + error: + stream->Consume(0); + stream->EnableRecv(false); + return; + } + +public: + WebSocketHandshake(MsQuicStream* stream, QuicPeerStream* peer) : + IMsQuicStreamHandler(), peer_(peer), stream_(stream) + { + stream->SetStreamHandler(this); // FIXME: not thread safe/correct + http_req_ = prepare_http_request(peer_->getURIObject()); + http_req_buffer_ = { (uint32_t)http_req_.size(), (uint8_t*)http_req_.c_str() }; + + stream->Write({&http_req_buffer_, 1}, nullptr, false); + stream->EnableRecv(); + } + + ~WebSocketHandshake(); + + void OnShutdown(MsQuicStream* stream) override + { + //peer_->OnShutdown(stream); + } + + void OnShutdownComplete(MsQuicStream* stream) override + { + //peer_->OnShutdownComplete(stream); + } +}; + +WebSocketHandshake::~WebSocketHandshake() {} + +} + +using namespace beyond_impl; + //////////////////////////////////////////////////////////////////////////////// QuicPeerStream::SendEvent::SendEvent(msgpack_buffer_t buffer_in) : - buffer(std::move(buffer_in)), pending(true), complete(false), t(0), n_buffers(1) + buffer(std::move(buffer_in)), pending(true), complete(false), n_buffers(1) { quic_vector[0].Buffer = (uint8_t*) buffer.data(); quic_vector[0].Length = buffer.size(); @@ -27,6 +212,7 @@ QuicPeerStream::QuicPeerStream(MsQuicConnection* connection, MsQuicStreamPtr str ftl::net::PeerBase(ftl::URI(), net, disp), connection_(connection), stream_(std::move(stream)), ws_frame_(true) { + ws_handshake_ = std::make_unique<WebSocketHandshake>(stream.get(), this); // TODO: remove connection_ (can't use with proxy) CHECK(stream_.get()); @@ -39,7 +225,6 @@ QuicPeerStream::QuicPeerStream(MsQuicConnection* connection, MsQuicStreamPtr str recv_buffer_.reserve_buffer(recv_buffer_default_size_); - #ifdef ENABLE_PROFILER profiler_name_ = PROFILER_RUNTIME_PERSISTENT_NAME("QuicStream[" + std::to_string(profiler_name_ctr_++) + "]"); profiler_id_.plt_pending_buffers = PROFILER_RUNTIME_PERSISTENT_NAME(getURI() + ": " + name_ + " pending buffers"); @@ -75,11 +260,16 @@ void QuicPeerStream::set_stream(MsQuicStreamPtr stream) while(recv_busy_) { lk.unlock(); std::this_thread::sleep_for(std::chrono::milliseconds(1)); lk.lock(); } } stream_ = std::move(stream); + enable_quic_stream(); +} + +void QuicPeerStream::enable_quic_stream() +{ + CHECK(stream_); stream_->SetStreamHandler(this); stream_->EnableRecv(); } - void QuicPeerStream::close(bool reconnect) { UNIQUE_LOCK_T(send_mtx_) lk_send(send_mtx_, std::defer_lock); @@ -153,14 +343,12 @@ msgpack_buffer_t QuicPeerStream::get_buffer_() << pending_sends_ << " writes pending (" << pending_bytes_ / 1024 << " KiB). Network performance degraded."; - send_cv_.wait(lk); + // What happens on disconnect/exit? + send_cv_.wait(lk, [&]() { return send_buffers_free_.size() > 0; }); - if (send_buffers_free_.size() > 0) - { - auto buffer = std::move(send_buffers_free_.back()); - send_buffers_free_.pop_back(); - return buffer; - } + auto buffer = std::move(send_buffers_free_.back()); + send_buffers_free_.pop_back(); + return buffer; } // create a new buffer diff --git a/src/quic/src/quic_peer.hpp b/src/quic/src/quic_peer.hpp index ae95c7ef5886d5424903793def8d5551286c6f3b..7993b14f229458d80db86885685fe5ecbb300dfe 100644 --- a/src/quic/src/quic_peer.hpp +++ b/src/quic/src/quic_peer.hpp @@ -20,6 +20,7 @@ public: virtual ~QuicPeerStream(); void set_stream(MsQuicStreamPtr stream); + void enable_quic_stream(); void set_type(ftl::protocol::NodeType t) { node_type_ = t; } @@ -63,6 +64,9 @@ private: MsQuicConnection* connection_; MsQuicStreamPtr stream_; + + std::unique_ptr<class WebSocketHandshake> ws_handshake_; + std::string name_; const bool ws_frame_; @@ -83,8 +87,6 @@ private: bool pending; bool complete; - int t; - uint8_t n_buffers; std::array<QUIC_BUFFER, 2> quic_vector; PROFILER_ASYNC_ZONE_CTX(profiler_ctx);