From d7572fbded104028285d19d698f3d29e149db7f0 Mon Sep 17 00:00:00 2001
From: Davide Pianca <davidepianca98@gmail.com>
Date: Thu, 8 Aug 2024 10:26:16 +0200
Subject: [PATCH] Prevent concurrent access to the send buffer on JVM socket

---
 .../src/jvmMain/kotlin/socket/tcp/Socket.kt   | 51 +++++++++++--------
 1 file changed, 31 insertions(+), 20 deletions(-)

diff --git a/kmqtt-common/src/jvmMain/kotlin/socket/tcp/Socket.kt b/kmqtt-common/src/jvmMain/kotlin/socket/tcp/Socket.kt
index da56d74..45da346 100644
--- a/kmqtt-common/src/jvmMain/kotlin/socket/tcp/Socket.kt
+++ b/kmqtt-common/src/jvmMain/kotlin/socket/tcp/Socket.kt
@@ -8,6 +8,8 @@ import java.nio.BufferOverflowException
 import java.nio.ByteBuffer
 import java.nio.channels.SelectionKey
 import java.nio.channels.SocketChannel
+import java.util.concurrent.locks.ReentrantLock
+import kotlin.concurrent.withLock
 
 public actual open class Socket(
     protected val channel: SocketChannel,
@@ -16,31 +18,40 @@ public actual open class Socket(
     private val receiveBuffer: ByteBuffer
 ) : SocketInterface {
 
+    private val lock = ReentrantLock()
+
     actual override fun send(data: UByteArray) {
-        val byteArray = data.toByteArray()
-        try {
-            sendBuffer.put(byteArray)
-        } catch (e: BufferOverflowException) {
-            sendBuffer = ByteBuffer.allocate(sendBuffer.capacity() + data.size)
-            sendBuffer.put(byteArray)
+        lock.withLock {
+            val byteArray = data.toByteArray()
+            try {
+                sendBuffer.put(byteArray)
+            } catch (e: BufferOverflowException) {
+                sendBuffer = ByteBuffer.allocate(sendBuffer.capacity() + data.size)
+                sendBuffer.put(byteArray)
+            }
+            sendFromBuffer()
         }
-        sendFromBuffer()
     }
 
     protected fun sendFromBuffer() {
-        sendBuffer.flip()
-        val size = sendBuffer.remaining()
-        try {
-            val count = channel.write(sendBuffer)
-            if (count < size) {
-                key?.interestOps(SelectionKey.OP_WRITE)
-            } else {
-                key?.interestOps(SelectionKey.OP_READ)
-            }
-            sendBuffer.compact()
-        } catch (e: java.io.IOException) {
-            close()
-            throw IOException(e.message)
+        lock.withLock {
+            sendBuffer.flip()
+            val size = sendBuffer.remaining()
+            try {
+                val count = channel.write(sendBuffer)
+                if (count < size) {
+                    key?.interestOps(SelectionKey.OP_WRITE)
+                } else {
+                    key?.interestOps(SelectionKey.OP_READ)
+                }
+                sendBuffer.compact()
+            } catch (e: java.io.IOException) {
+                close()
+                throw IOException(e.message)
+            }/* catch (e: IllegalArgumentException) {
+                close()
+                throw IOException(e.message)
+            }*/
         }
     }
 
-- 
GitLab