diff --git a/src/protocol/websocket.cpp b/src/protocol/websocket.cpp index b1c64ea45cbdec1aaaf4bb4b98b65f49a60935fe..514ff6db61c91df4ead6e9f81b37f20988b2c865 100644 --- a/src/protocol/websocket.cpp +++ b/src/protocol/websocket.cpp @@ -5,6 +5,7 @@ */ #include <string> +#include <unordered_map> #include <algorithm> #include "websocket.hpp" #include <ftl/lib/loguru.hpp> @@ -228,7 +229,6 @@ void WebSocketBase<SocketT>::connect(const ftl::URI& uri, int timeout) { http += "Sec-WebSocket-Key: x3JJHMbDL1EzLkh9GBhXDw==\r\n"; http += "Sec-WebSocket-Version: 13\r\n"; http += "\r\n"; - // TODO(Seb): check/process HTTP response code int rc = SocketT::send(http.c_str(), static_cast<int>(http.length())); if (rc != static_cast<int>(http.length())) { @@ -252,7 +252,8 @@ void WebSocketBase<SocketT>::connect(const ftl::URI& uri, int timeout) { + uri.getHost() + ": " + line); } - // TODO(Seb): verify response headers, + std::unordered_map<std::string, std::string> headers; + while (true) { for (i = 0; i < 2 || (i < 255 && line[i-2] != '\r' && line[i-1] != '\n'); ++i) { if (SocketT::recv(line+i, 1) == 0) { @@ -260,7 +261,23 @@ void WebSocketBase<SocketT>::connect(const ftl::URI& uri, int timeout) { } } if (line[0] == '\r' && line[1] == '\n') { break; } + + // Split the headers into a map for checking + line[i] = 0; + const std::string cppline(line); + const auto ix = cppline.find(":"); + const auto label = cppline.substr(0, ix); + const auto value = cppline.substr(ix + 2, cppline.size() - ix - 4); + headers[label] = value; } + + // Validate some of the headers + if (headers.count("Connection") == 0 || headers.at("Connection") != "upgrade") + throw FTL_Error("Missing WS connection header"); + if (headers.count("Upgrade") == 0 || headers.at("Upgrade") != "websocket") + throw FTL_Error("Missing WS Upgrade"); + if (headers.count("Sec-WebSocket-Accept") == 0) + throw FTL_Error("Missing WS accept header"); } template<typename SocketT>