From 2ba9f10f1e906c34e4273b571ef6ffd7f3237902 Mon Sep 17 00:00:00 2001
From: Sebastian Hahta <joseha@utu.fi>
Date: Wed, 24 Jan 2024 18:10:48 +0200
Subject: [PATCH] HTTP websocket upgrade over quic (requires a proxy)

---
 src/quic/src/msquic/stream.cpp |   4 +-
 src/quic/src/msquic/stream.hpp |   8 +-
 src/quic/src/quic_peer.cpp     | 212 +++++++++++++++++++++++++++++++--
 src/quic/src/quic_peer.hpp     |   6 +-
 4 files changed, 214 insertions(+), 16 deletions(-)

diff --git a/src/quic/src/msquic/stream.cpp b/src/quic/src/msquic/stream.cpp
index 1f10f55..ec78f8a 100644
--- a/src/quic/src/msquic/stream.cpp
+++ b/src/quic/src/msquic/stream.cpp
@@ -9,7 +9,6 @@
 
 using namespace beyond_impl;
 
-
 std::unique_ptr<MsQuicStream> MsQuicStream::FromRemotePeer(MsQuicContext* MsQuic, HQUIC hStream)
 {
     CHECK(MsQuic);
@@ -63,6 +62,8 @@ QUIC_STATUS MsQuicStream::EventHandler(HQUIC hStream, void* Context, QUIC_STREAM
     MsQuicStream* Stream = static_cast<MsQuicStream*>(Context);
     auto* MsQuic = Stream->MsQuic;
 
+    auto Lock = std::unique_lock(Stream->Mtx);
+
     switch (Event->Type)
     {
     case QUIC_STREAM_EVENT_START_COMPLETE:
@@ -243,5 +244,6 @@ bool MsQuicStream::Write(nonstd::span<const QUIC_BUFFER> Buffers, void* Context,
 
 void MsQuicStream::SetStreamHandler(IMsQuicStreamHandler* ObserverIn) 
 {
+    auto Lock = std::unique_lock(Mtx);
     Observer = ObserverIn;
 }
diff --git a/src/quic/src/msquic/stream.hpp b/src/quic/src/msquic/stream.hpp
index 2e07553..6bd507f 100644
--- a/src/quic/src/msquic/stream.hpp
+++ b/src/quic/src/msquic/stream.hpp
@@ -47,7 +47,8 @@ public:
     */
     void Abort();
 
-    /** Inform source that reader is done with data. When less than received is consumed, .
+    /** Inform source that reader is done with data. When less than received is consumed, no further
+     *  callbacks are called before EnableRecv() is called again.
      */
     void Consume(int32_t BytesConsumed);
 
@@ -58,6 +59,10 @@ public:
      *  queued data. */
     bool Write(nonstd::span<const QUIC_BUFFER> Buffers, void* Context, bool Delay = false);
 
+    /** Set stream handler. If handler is replaced once stream is opened, caller must ensure any pending writes
+     *  are correctly handled (buffers not released before completion, new handler can process callbacks from 
+     *  previous writes) or that there are no pending writes.
+     */
     void SetStreamHandler(IMsQuicStreamHandler*);
 
     /** Calls to Start() */
@@ -74,6 +79,7 @@ private:
     // races during StreamClose).
     HQUIC hStream;
 
+    std::mutex Mtx; // Necessary to allow changing observer after stream is started
     IMsQuicStreamHandler* Observer;
 
     std::atomic_int PendingSends;
diff --git a/src/quic/src/quic_peer.cpp b/src/quic/src/quic_peer.cpp
index e07e05b..5d454e9 100644
--- a/src/quic/src/quic_peer.cpp
+++ b/src/quic/src/quic_peer.cpp
@@ -6,14 +6,199 @@
 
 #include "quic_websocket.cpp"
 
-using namespace beyond_impl;
-
 using ftl::protocol::NodeStatus;
 
+#include <random>
+
+namespace beyond_impl
+{
+/** WebSocket wrapper
+ *  With RFC 9220 and HTTP/3 protocol this could be made to work directly with compatible webserver
+ */
+class WebSocketHandshake : public IMsQuicStreamHandler
+{
+    QuicPeerStream* peer_;
+    MsQuicStream* stream_;
+    
+    enum Status { kContinue, kError, kDone };
+    std::map<std::string, std::string> headers_;
+    std::vector<char> http_reply_buffer_;
+
+    std::string websocket_key_;
+    std::string http_req_;
+    QUIC_BUFFER http_req_buffer_;
+
+    std::string prepare_http_request(const ftl::URI& uri)
+    {
+        std::string request;
+        auto generate_sec_websocket_key = []()
+        {
+            std::random_device rd;
+            std::mt19937 mt(rd());
+            std::uniform_int_distribution<unsigned char> dist(32, 128);
+            std::array<unsigned char, 16> key;
+            for(auto& v : key) { v = dist(mt); }
+            return base64_encode(key.data(), key.size());
+        };
+        websocket_key_ = generate_sec_websocket_key();
+
+        request += "GET " + uri.getPath() + " HTTP/1.1\r\n";
+        request += "Host: " + uri.getHost() + "\r\n";
+
+        if (uri.hasUserInfo()) {
+            request += "Authorization: Basic ";
+            request += base64_encode(uri.getUserInfo()) + "\r\n";
+        }
+
+        request += "Upgrade: websocket\r\n";
+        request += "Connection: Upgrade\r\n";
+        request += "Sec-WebSocket-Key: " + websocket_key_ + "\r\n";
+        request += "Sec-WebSocket-Version: 13\r\n";
+        request += "\r\n";
+
+        return request;
+    }
+
+    Status process_headers(std::map<std::string, std::string>& headers)
+    {
+        if (headers.count("Connection") == 0 || headers.at("Connection") != "upgrade")
+        {
+            LOG(ERROR) << "[WebSocket] Wrong Connection header";
+            return kError;
+        }
+        if (headers.count("Upgrade") == 0 || headers.at("Upgrade") != "websocket")
+        {
+            LOG(ERROR) << "[WebSocket] Wrong Upgrade header";
+            return kError;
+        }
+        if (headers.count("Sec-WebSocket-Accept") == 0)
+        {
+            LOG(ERROR) << "[WebSocket] Missing Sec-WebSocket-Accept header";
+            return kError;
+        }
+
+        return kDone;
+    }
+
+    Status process_recv(const std::string_view& line)
+    {
+        try {
+            if ((headers_.size() == 0))
+            {   
+                if (line.substr(0,9) != "HTTP/1.1 ") { return kError; }
+                auto code = atoi(line.substr(10).data());
+                if (code != 101) { return kError; }
+                return kContinue;
+            }
+
+            if (line == "\r\n")
+            {
+                // Returns either kError or kDone
+                return process_headers(headers_);
+            }
+
+            uint32_t i = 1;
+            std::string_view key;
+
+            // Find header name
+            for (; i < line.size(); i++)
+            {
+                if (line[i] == ':')
+                {
+                    key = line.substr(0, i);
+                    break;
+                }
+            }
+            // Return error if not a valid header
+            if (key.empty()) { return kError; }
+
+            // Skip whitespace
+            for (; i < line.size() && std::isspace(line[i]); i++)
+
+            headers_[std::string(key)] = std::string(line.substr(i));
+            return kContinue;
+        }
+        catch(...) { /** Always return error in the end (no other status returned) */}
+
+        return kError;
+    }
+
+    void OnData(MsQuicStream* stream, nonstd::span<const QUIC_BUFFER> buffers) override
+    {
+        const unsigned int max_buffer_size = 1024;
+        uint32_t recvd = 0;
+        http_reply_buffer_.reserve(120);
+        
+        for (const auto& data : buffers)
+        {
+            for (uint32_t i = 0; i < data.Length; i++, recvd++)
+            {
+                if (http_reply_buffer_.size() > max_buffer_size)
+                {
+                    LOG(ERROR) << "[WebSocket] HTTP response too large";
+                    goto error; 
+                }
+                if ((data.Buffer[i] == '\n') && (http_reply_buffer_.size() > 0) && (http_reply_buffer_.back() == '\r'))
+                {
+                    http_reply_buffer_[http_reply_buffer_.size() - 1] = '\0'; // Not strictly necessary
+                    auto status = process_recv({http_reply_buffer_.data(), http_reply_buffer_.size() - 1});
+                    if      (status == kDone)     { goto done; }
+                    else if (status == kError)    { goto error; }
+                    else if (status == kContinue) { http_reply_buffer_.clear(); }
+                }
+                else
+                {
+                    http_reply_buffer_.push_back(data.Buffer[i]);
+                }
+            }
+        }
+
+        done:
+        stream->Consume(recvd);
+        // Next EnableRecv() call will re-submit any unread data.
+        peer_->enable_quic_stream();
+
+        error:
+        stream->Consume(0);
+        stream->EnableRecv(false);
+        return;
+    }
+
+public:
+    WebSocketHandshake(MsQuicStream* stream, QuicPeerStream* peer) : 
+        IMsQuicStreamHandler(), peer_(peer), stream_(stream)
+    {
+        stream->SetStreamHandler(this); // FIXME: not thread safe/correct
+        http_req_ = prepare_http_request(peer_->getURIObject());
+        http_req_buffer_ = { (uint32_t)http_req_.size(),  (uint8_t*)http_req_.c_str() };
+
+        stream->Write({&http_req_buffer_, 1}, nullptr, false);
+        stream->EnableRecv();
+    }
+
+    ~WebSocketHandshake();
+
+    void OnShutdown(MsQuicStream* stream) override
+    {
+        //peer_->OnShutdown(stream);
+    }
+
+    void OnShutdownComplete(MsQuicStream* stream) override
+    {
+        //peer_->OnShutdownComplete(stream);
+    }
+};
+
+WebSocketHandshake::~WebSocketHandshake() {}
+
+}
+
+using namespace beyond_impl;
+
 ////////////////////////////////////////////////////////////////////////////////
 
 QuicPeerStream::SendEvent::SendEvent(msgpack_buffer_t buffer_in) :
-    buffer(std::move(buffer_in)), pending(true), complete(false), t(0), n_buffers(1)
+    buffer(std::move(buffer_in)), pending(true), complete(false),  n_buffers(1)
 {
     quic_vector[0].Buffer = (uint8_t*) buffer.data();
     quic_vector[0].Length = buffer.size();
@@ -27,6 +212,7 @@ QuicPeerStream::QuicPeerStream(MsQuicConnection* connection, MsQuicStreamPtr str
     ftl::net::PeerBase(ftl::URI(), net, disp),
     connection_(connection), stream_(std::move(stream)), ws_frame_(true)
 {
+    ws_handshake_ = std::make_unique<WebSocketHandshake>(stream.get(), this);
     // TODO: remove connection_ (can't use with proxy)
     CHECK(stream_.get());
 
@@ -39,7 +225,6 @@ QuicPeerStream::QuicPeerStream(MsQuicConnection* connection, MsQuicStreamPtr str
 
     recv_buffer_.reserve_buffer(recv_buffer_default_size_);
     
-    
     #ifdef ENABLE_PROFILER
     profiler_name_ = PROFILER_RUNTIME_PERSISTENT_NAME("QuicStream[" + std::to_string(profiler_name_ctr_++) + "]");
     profiler_id_.plt_pending_buffers =  PROFILER_RUNTIME_PERSISTENT_NAME(getURI() + ": " + name_ + " pending buffers");
@@ -75,11 +260,16 @@ void QuicPeerStream::set_stream(MsQuicStreamPtr stream)
         while(recv_busy_) { lk.unlock(); std::this_thread::sleep_for(std::chrono::milliseconds(1)); lk.lock(); }
     }
     stream_ = std::move(stream);
+    enable_quic_stream();
+}
+
+void QuicPeerStream::enable_quic_stream()
+{
+    CHECK(stream_);
     stream_->SetStreamHandler(this);
     stream_->EnableRecv();
 }
 
-
 void QuicPeerStream::close(bool reconnect)
 {
     UNIQUE_LOCK_T(send_mtx_) lk_send(send_mtx_, std::defer_lock);
@@ -153,14 +343,12 @@ msgpack_buffer_t QuicPeerStream::get_buffer_()
             << pending_sends_ << " writes pending ("
             << pending_bytes_ / 1024 << " KiB). Network performance degraded.";
 
-        send_cv_.wait(lk);
+        // What happens on disconnect/exit?
+        send_cv_.wait(lk, [&]() { return send_buffers_free_.size() > 0; });
 
-        if (send_buffers_free_.size() > 0)
-        {
-            auto buffer = std::move(send_buffers_free_.back());
-            send_buffers_free_.pop_back();
-            return buffer;
-        }
+        auto buffer = std::move(send_buffers_free_.back());
+        send_buffers_free_.pop_back();
+        return buffer;
     }
     
     // create a new buffer
diff --git a/src/quic/src/quic_peer.hpp b/src/quic/src/quic_peer.hpp
index ae95c7e..7993b14 100644
--- a/src/quic/src/quic_peer.hpp
+++ b/src/quic/src/quic_peer.hpp
@@ -20,6 +20,7 @@ public:
     virtual ~QuicPeerStream();
 
     void set_stream(MsQuicStreamPtr stream);
+    void enable_quic_stream();
 
     void set_type(ftl::protocol::NodeType t) { node_type_ = t; } 
 
@@ -63,6 +64,9 @@ private:
 
     MsQuicConnection* connection_;
     MsQuicStreamPtr stream_;
+
+    std::unique_ptr<class WebSocketHandshake> ws_handshake_;
+
     std::string name_;
     const bool ws_frame_;
 
@@ -83,8 +87,6 @@ private:
         bool pending;
         bool complete;
 
-        int t;
-
         uint8_t n_buffers;
         std::array<QUIC_BUFFER, 2> quic_vector;
         PROFILER_ASYNC_ZONE_CTX(profiler_ctx);
-- 
GitLab