/**
 * @file websocket.cpp
 * @copyright Copyright (c) 2022 University of Turku, MIT License
 * @author Sebastian Hahta
 */

#include <string>
#include <unordered_map>
#include <algorithm>
#include "websocket.hpp"
#include <ftl/lib/loguru.hpp>

#include <ftl/utility/base64.hpp>

using uchar = unsigned char;

#ifdef HAVE_GNUTLS
#include <gnutls/crypto.h>

inline uint32_t secure_rnd() {
    uint32_t rnd;
    gnutls_rnd(GNUTLS_RND_NONCE, &rnd, sizeof(uint32_t));
    return rnd;
}
#else
#include <random>
static std::random_device rd_;
static std::uniform_int_distribution<uint32_t> dist_(0);

inline uint32_t secure_rnd() {
    // TODO(Seb)
    return dist_(rd_);
}
#endif

using ftl::URI;
using ftl::net::internal::WebSocketBase;
using ftl::net::internal::Connection_TCP;
using ftl::net::internal::Connection_TLS;

/* Taken from easywsclient */
struct wsheader_type {
    unsigned header_size;
    bool fin;
    bool mask;
    enum opcode_type {
        CONTINUATION = 0x0,
        TEXT_FRAME = 0x1,
        BINARY_FRAME = 0x2,
        CLOSE = 8,
        PING = 9,
        PONG = 0xa,
    } opcode;
    int N0;
    uint64_t N;
    uint8_t masking_key[4];
};

struct ws_options {
    std::string userinfo = "";
};

// prepare ws header
int ws_prepare(wsheader_type::opcode_type op, bool useMask, uint32_t mask,
                size_t len, char *data, size_t maxlen) {
    uint8_t* masking_key = reinterpret_cast<uint8_t*>(&mask);

    char *header = data;
    size_t header_size = 2 + (len >= 126 ? 2 : 0) + (len >= 65536 ? 6 : 0) + (useMask ? 4 : 0);
    if (header_size > maxlen) return -1;

    memset(header, 0, header_size);
    header[0] = 0x80 | op;

    if (len < 126) {
        header[1] = (len & 0xff) | (useMask ? 0x80 : 0);
        if (useMask) {
            header[2] = masking_key[0];
            header[3] = masking_key[1];
            header[4] = masking_key[2];
            header[5] = masking_key[3];
        }
    } else if (len < 65536) {
        header[1] = 126 | (useMask ? 0x80 : 0);
        header[2] = (len >> 8) & 0xff;
        header[3] = (len >> 0) & 0xff;
        if (useMask) {
            header[4] = masking_key[0];
            header[5] = masking_key[1];
            header[6] = masking_key[2];
            header[7] = masking_key[3];
        }
    } else {
        header[1] = 127 | (useMask ? 0x80 : 0);
        header[2] = (len >> 56) & 0xff;
        header[3] = (len >> 48) & 0xff;
        header[4] = (len >> 40) & 0xff;
        header[5] = (len >> 32) & 0xff;
        header[6] = (len >> 24) & 0xff;
        header[7] = (len >> 16) & 0xff;
        header[8] = (len >>  8) & 0xff;
        header[9] = (len >>  0) & 0xff;
        if (useMask) {
            header[10] = masking_key[0];
            header[11] = masking_key[1];
            header[12] = masking_key[2];
            header[13] = masking_key[3];
        }
    }

    return static_cast<int>(header_size);
}

// parse ws header, returns true on success
// TODO(Seb): return error code for different results (not enough bytes in buffer
//    to build the header vs corrputed/invalid header)
bool ws_parse(uchar *data, size_t len, wsheader_type *ws) {
    if (len < 2) return false;

    ws->fin = (data[0] & 0x80) == 0x80;
    ws->opcode = (wsheader_type::opcode_type) (data[0] & 0x0f);
    ws->mask = (data[1] & 0x80) == 0x80;
    ws->N0 = (data[1] & 0x7f);
    ws->header_size = 2 + (ws->N0 == 126? 2 : 0) + (ws->N0 == 127? 8 : 0) + (ws->mask? 4 : 0);

    if (len < ws->header_size) return false;

    // invalid opcode, corrupted header?
    if ((ws->opcode > 10) || ((ws->opcode > 2) && (ws->opcode < 8))) return false;

    int i = 0;
    if (ws->N0 < 126) {
        ws->N = ws->N0;
        i = 2;
    } else if (ws->N0 == 126) {
        ws->N = 0;
        ws->N |= ((uint64_t) data[2]) << 8;
        ws->N |= ((uint64_t) data[3]) << 0;
        i = 4;
    } else if (ws->N0 == 127) {
        ws->N = 0;
        ws->N |= ((uint64_t) data[2]) << 56;
        ws->N |= ((uint64_t) data[3]) << 48;
        ws->N |= ((uint64_t) data[4]) << 40;
        ws->N |= ((uint64_t) data[5]) << 32;
        ws->N |= ((uint64_t) data[6]) << 24;
        ws->N |= ((uint64_t) data[7]) << 16;
        ws->N |= ((uint64_t) data[8]) << 8;
        ws->N |= ((uint64_t) data[9]) << 0;
        i = 10;
    }

    if (ws->mask) {
        ws->masking_key[0] = ((uint8_t) data[i+0]) << 0;
        ws->masking_key[1] = ((uint8_t) data[i+1]) << 0;
        ws->masking_key[2] = ((uint8_t) data[i+2]) << 0;
        ws->masking_key[3] = ((uint8_t) data[i+3]) << 0;
    } else {
        ws->masking_key[0] = 0;
        ws->masking_key[1] = 0;
        ws->masking_key[2] = 0;
        ws->masking_key[3] = 0;
    }

    return true;
}

// same as above, pointer type casted to unsigned
bool ws_parse(char *data, size_t len, wsheader_type *ws) {
    return ws_parse(reinterpret_cast<unsigned char*>(data), len, ws);
}


int getPort(const ftl::URI &uri) {
    auto port = uri.getPort();

    if (port == 0) {
        if (uri.getScheme() == URI::scheme_t::SCHEME_WS) {
            port = 80;
        } else if (uri.getScheme() == URI::scheme_t::SCHEME_WSS) {
            port = 443;
        } else {
            throw FTL_Error("Bad WS uri:" + uri.to_string());
        }
    }

    return port;
}

////////////////////////////////////////////////////////////////////////////////

template<typename SocketT>
WebSocketBase<SocketT>::WebSocketBase() {}

template<typename SocketT>
void WebSocketBase<SocketT>::connect(const ftl::URI& uri, int timeout) {
    int port = getPort(uri);

    // connect via TCP/TLS
    if (!SocketT::connect(uri.getHost(), port, timeout)) {
        throw FTL_Error("WS: connect() failed");
    }

    std::string http = "";
    int status;
    int i;
    char line[256];

    http += "GET " + uri.getPath() + " HTTP/1.1\r\n";
    if (port == 80) {
        http += "Host: " + uri.getHost() + "\r\n";
    } else {
        // TODO(Seb): is this correct when connecting over TLS
        http += "Host: " + uri.getHost() + ":"
             + std::to_string(port) + "\r\n";
    }

    if (uri.hasUserInfo()) {
        http += "Authorization: Basic ";
        http += base64_encode(uri.getUserInfo()) + "\r\n";
        // https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Authorization
        if (uri.getProtocol() != URI::scheme_t::SCHEME_WSS) {
            LOG(WARNING) << "HTTP Basic Auth is being sent without TLS";
        }
    }

    http += "Upgrade: websocket\r\n";
    http += "Connection: Upgrade\r\n";
    http += "Sec-WebSocket-Key: x3JJHMbDL1EzLkh9GBhXDw==\r\n";
    http += "Sec-WebSocket-Version: 13\r\n";
    http += "\r\n";

    int rc = SocketT::send(http.c_str(), static_cast<int>(http.length()));
    if (rc != static_cast<int>(http.length())) {
        throw FTL_Error("Could not send Websocket http request... ("
                        + std::to_string(rc) + ", "
                         + std::to_string(errno) + ")\n" + http);
    }

    for (i = 0; i < 2 || (i < 255 && line[i-2] != '\r' && line[i-1] != '\n'); ++i) {
        if (SocketT::recv(line + i, 1) == 0) {
            throw FTL_Error("Connection closed by remote");
        }
    }

    line[i] = 0;
    if (i == 255) {
        throw FTL_Error("Got invalid status line connecting to: " + uri.getHost());
    }
    if (sscanf(line, "HTTP/1.1 %d", &status) != 1 || status != 101) {
        throw FTL_Error("ERROR: Got bad status connecting to: "
                        + uri.getHost() + ": " + line);
    }

    std::unordered_map<std::string, std::string> headers;

    while (true) {
        for (i = 0; i < 2 || (i < 255 && line[i-2] != '\r' && line[i-1] != '\n'); ++i) {
            if (SocketT::recv(line+i, 1) == 0) {
                throw FTL_Error("Connection closed by remote");
            }
        }
        if (line[0] == '\r' && line[1] == '\n') { break; }

        // Split the headers into a map for checking
        line[i] = 0;
        const std::string cppline(line);
        const auto ix = cppline.find(":");
        const auto label = cppline.substr(0, ix);
        const auto value = cppline.substr(ix + 2, cppline.size() - ix - 4);
        headers[label] = value;
    }

    // Validate some of the headers
    if (headers.count("Connection") == 0 || headers.at("Connection") != "upgrade")
        throw FTL_Error("Missing WS connection header");
    if (headers.count("Upgrade") == 0 || headers.at("Upgrade") != "websocket")
        throw FTL_Error("Missing WS Upgrade");
    if (headers.count("Sec-WebSocket-Accept") == 0)
        throw FTL_Error("Missing WS accept header");
}

template<typename SocketT>
bool WebSocketBase<SocketT>::prepare_next(char* data, size_t data_len, size_t& offset) {
    offset = 0;

    // Header may be smaller than 14 bytes. If there isn't enough data,
    // do not process before receiving more data.
    if (data_len < 14) { return false; }

    wsheader_type header;
    if (!ws_parse(data, data_len, &header)) {
        throw FTL_Error("corrupted WS header");
    }

    if ((header.N + header.header_size) > data_len) {
        /*LOG(WARNING) << "buffered: " << data_len
                     << ", ws frame size: " << (header.N + header.header_size)
                     << " (not enough data in buffer)"; */
        return false;
    }

    if (header.mask) {
        throw FTL_Error("masked WebSocket data not supported");  // TODO(Seb):
    }

    // payload/application data/extension of control frames should be ignored?
    // fragments are OK (data is be in order and frames are not interleaved)

    offset = header.header_size;
    return true;
}

template<typename SocketT>
ssize_t WebSocketBase<SocketT>::writev(const struct iovec *iov, int iovcnt) {

    if ((iovcnt + 1) >= ssize_t(iovecs_.size())) { iovecs_.resize(iovcnt + 1); }
    // copy iovecs to local buffer, first iovec entry reserved for header
    std::copy(iov, iov + iovcnt, iovecs_.data() + 1);

    // masking
    size_t msglen = 0;
    uint32_t mask = secure_rnd();
    uint8_t* masking_key = reinterpret_cast<uint8_t*>(&mask);

    // calculate total size of message and mask it.
    for (int i = 1; i < iovcnt + 1; i++) {
        const size_t mlen = iovecs_[i].iov_len;
        char *buf = reinterpret_cast<char*>(iovecs_[i].iov_base);

        // TODO(Seb): Make this more efficient.
        for (size_t j = 0; j != mlen; ++j) {
            buf[j] ^= masking_key[(msglen + j)&0x3];
        }
        msglen += mlen;
    }

    // create header
    constexpr size_t kHSize = 20;
    char h_buffer[kHSize];

    auto rc = ws_prepare(wsheader_type::BINARY_FRAME, true, mask, msglen, h_buffer, kHSize);
    if (rc < 0) { return -1; }

    // send header + data
    iovecs_[0].iov_base = h_buffer;
    iovecs_[0].iov_len = rc;

    auto sent = SocketT::writev(iovecs_.data(), iovcnt + 1);
    if (sent > 0) {
        // do not report sent header size
        return sent - rc;
    }
    return sent;
}

template<typename SocketT>
ftl::URI::scheme_t WebSocketBase<SocketT>::scheme() const {return ftl::URI::SCHEME_TCP; }

// explicit instantiation
template class WebSocketBase<Connection_TCP>;  // Connection_WS
#ifdef HAVE_GNUTLS
template class WebSocketBase<Connection_TLS>;  // Connection_WSS
#endif