From 6f45fa6967d91b36cddf70a61c882293a17e3099 Mon Sep 17 00:00:00 2001
From: Nicolas Pope <nicolas.pope@utu.fi>
Date: Fri, 30 Sep 2022 18:02:29 +0000
Subject: [PATCH] #38 Check file is valid before writing

---
 include/ftl/protocol/packet.hpp |  8 ++++----
 src/socket/socket_linux.cpp     |  9 ++++++++-
 src/streams/filestream.cpp      | 10 ++++++++++
 src/streams/filestream.hpp      |  1 +
 test/filestream_unit.cpp        | 11 +++++++++++
 5 files changed, 34 insertions(+), 5 deletions(-)

diff --git a/include/ftl/protocol/packet.hpp b/include/ftl/protocol/packet.hpp
index 0227c71..4994194 100644
--- a/include/ftl/protocol/packet.hpp
+++ b/include/ftl/protocol/packet.hpp
@@ -101,10 +101,10 @@ struct StreamPacket {
     inline int frameNumber() const { return (version >= 4) ? frame_number : streamID; }
     inline size_t frameSetID() const { return (version >= 4) ? streamID : 0; }
 
-    int64_t localTimestamp;              // Not message packet / saved
-    mutable unsigned int hint_capability;        // Is this a video stream, for example
-    size_t hint_source_total;            // Number of tracks per frame to expect
-    int retry_count = 0;                 // Decode retry count
+    int64_t localTimestamp = 0;                 // Not message packet / saved
+    mutable unsigned int hint_capability = 0;   // Is this a video stream, for example
+    size_t hint_source_total = 0;               // Number of tracks per frame to expect
+    int retry_count = 0;                        // Decode retry count
     unsigned int hint_peerid = 0;
 
     operator std::string() const;
diff --git a/src/socket/socket_linux.cpp b/src/socket/socket_linux.cpp
index 7989d7e..d9b8b8f 100644
--- a/src/socket/socket_linux.cpp
+++ b/src/socket/socket_linux.cpp
@@ -205,7 +205,14 @@ Socket ftl::net::internal::create_tcp_socket() {
 
 std::string ftl::net::internal::get_host(const SocketAddress& addr) {
     char hbuf[1024];
-    int err = getnameinfo(reinterpret_cast<const sockaddr*>(&(addr.addr)), addr.len, hbuf, sizeof(hbuf), NULL, 0, NI_NAMEREQD);
+    int err = getnameinfo(
+        reinterpret_cast<const sockaddr*>(&(addr.addr)),
+        addr.len,
+        hbuf,
+        sizeof(hbuf),
+        NULL,
+        0,
+        NI_NAMEREQD);
     if (err == 0) { return std::string(hbuf); }
     else if (err == EAI_NONAME) return ftl::net::internal::get_ip(addr);
     else
diff --git a/src/streams/filestream.cpp b/src/streams/filestream.cpp
index db69504..ff09e65 100644
--- a/src/streams/filestream.cpp
+++ b/src/streams/filestream.cpp
@@ -10,6 +10,7 @@
 #include <utility>
 #include <limits>
 #include <algorithm>
+#include <filesystem>
 #include <thread>
 #include <chrono>
 #include "filestream.hpp"
@@ -478,6 +479,14 @@ bool File::run() {
     return true;
 }
 
+bool File::_validateFilename() const {
+    std::filesystem::path file = std::filesystem::u8path(uri_.toFilePath());
+    if (!std::filesystem::exists(file)) return true;
+    if (std::string(file.extension().u8string().c_str()) == ".ftl") return true;
+    // TODO(Nick): Could also check directory path
+    return false;
+}
+
 bool File::begin() {
     if (active_) return true;
     if (mode_ == Mode::Read) {
@@ -498,6 +507,7 @@ bool File::begin() {
         run();
     } else if (mode_ == Mode::Write) {
         if (!ostream_) ostream_ = new std::ofstream;
+        if (!_validateFilename()) return false;
         ostream_->open(uri_.toFilePath(), std::ifstream::out | std::ifstream::binary);
 
         if (!ostream_->good()) {
diff --git a/src/streams/filestream.hpp b/src/streams/filestream.hpp
index 269b866..89d2991 100644
--- a/src/streams/filestream.hpp
+++ b/src/streams/filestream.hpp
@@ -122,6 +122,7 @@ class File : public Stream {
 
     bool _open();
     bool _checkFile();
+    bool _validateFilename() const;
 
     /* Apply version patches etc... */
     void _patchPackets(ftl::protocol::StreamPacket *spkt, ftl::protocol::DataPacket *pkt);
diff --git a/test/filestream_unit.cpp b/test/filestream_unit.cpp
index 2cb0f30..64ed59d 100644
--- a/test/filestream_unit.cpp
+++ b/test/filestream_unit.cpp
@@ -1,5 +1,6 @@
 #include "catch.hpp"
 
+#include <fstream>
 #include <filesystem>
 #include <ftl/protocol/streams.hpp>
 #include <ftl/protocol.hpp>
@@ -124,3 +125,13 @@ TEST_CASE("File write and read", "[stream]") {
         REQUIRE( channels[2] == Channel::kScreen );
     }
 }
+
+TEST_CASE("File write fails for bad filename", "[stream]") {
+    std::string filename = (std::filesystem::temp_directory_path() / "badfile.exe").string();
+    std::ofstream out(filename);
+    out << "something";
+    out.close();
+    auto writer = ftl::createStream(filename);
+
+    REQUIRE(!writer->begin());
+}
-- 
GitLab