From c97c4f92a336c2e5dc0ea6c5e175ed9ad91d5347 Mon Sep 17 00:00:00 2001 From: Davide Pianca <davidepianca98@gmail.com> Date: Sun, 21 Apr 2024 15:39:03 +0200 Subject: [PATCH] Add connackTimeout, onSubscribed callback and debugLog option --- .../PublishSubscribeMultipleClientsTest.kt | 24 ++-- .../src/commonMain/kotlin/MQTTClient.kt | 125 ++++++++++++------ 2 files changed, 97 insertions(+), 52 deletions(-) diff --git a/kmqtt-broker/src/commonTest/kotlin/integration/PublishSubscribeMultipleClientsTest.kt b/kmqtt-broker/src/commonTest/kotlin/integration/PublishSubscribeMultipleClientsTest.kt index bbf9a3b..2e4a460 100644 --- a/kmqtt-broker/src/commonTest/kotlin/integration/PublishSubscribeMultipleClientsTest.kt +++ b/kmqtt-broker/src/commonTest/kotlin/integration/PublishSubscribeMultipleClientsTest.kt @@ -20,21 +20,29 @@ class PublishSubscribeMultipleClientsTest { val broker = Broker() - val clientSub = MQTTClient(MQTTVersion.MQTT5, "127.0.0.1", broker.port, null, clientId = "client2") { + val clientPub = MQTTClient(MQTTVersion.MQTT5, "127.0.0.1", broker.port, null, clientId = "client1") {} + + broker.step() + clientPub.step() + + val clientSub = MQTTClient( + MQTTVersion.MQTT5, + "127.0.0.1", + broker.port, + null, + clientId = "client2", + onSubscribed = { + clientPub.publish(false, qos, topic, payload) + } + ) { assertEquals(topic, it.topicName) assertContentEquals(payload, it.payload) assertEquals(qos, it.qos) received = true } - broker.step() - clientSub.subscribe(listOf(Subscription(topic, SubscriptionOptions(qos)))) - broker.step() - val clientPub = MQTTClient(MQTTVersion.MQTT5, "127.0.0.1", broker.port, null, clientId = "client1") {} - broker.step() - - clientPub.publish(false, qos, topic, payload) + clientSub.subscribe(listOf(Subscription(topic, SubscriptionOptions(qos)))) var i = 0 while (!received && i < 1000) { diff --git a/kmqtt-client/src/commonMain/kotlin/MQTTClient.kt b/kmqtt-client/src/commonMain/kotlin/MQTTClient.kt index 634c95a..d71297a 100644 --- a/kmqtt-client/src/commonMain/kotlin/MQTTClient.kt +++ b/kmqtt-client/src/commonMain/kotlin/MQTTClient.kt @@ -1,5 +1,7 @@ import kotlinx.atomicfu.AtomicBoolean import kotlinx.atomicfu.atomic +import kotlinx.atomicfu.locks.ReentrantLock +import kotlinx.atomicfu.locks.withLock import kotlinx.coroutines.CoroutineDispatcher import kotlinx.coroutines.CoroutineScope import kotlinx.coroutines.Dispatchers @@ -23,21 +25,29 @@ import socket.tls.TLSClientSettings /** * MQTT 3.1.1 and 5 client * - * @param mqttVersion sets the version of MQTT for this client (4 -> 3.1.1, 5 -> 5) - * @param address the URL of the server + * @param mqttVersion sets the version of MQTT for this client MQTTVersion.MQTT3_1_1 or MQTTVersion.MQTT5 + * @param address the URL of the server without ws/wss/mqtt/mqtts * @param port the port of the server - * @param tls TLS settings, null if no TLS + * @param tls TLS settings, null if no TLS, otherwise it must be set * @param keepAlive the MQTT keep alive of the connection in seconds * @param webSocket whether to use a WebSocket for the underlying connection, null if no WebSocket, otherwise the HTTP path, usually /mqtt + * @param cleanStart if set, the Client and Server MUST discard any existing session and start a new session + * @param clientId identifies the client to the server, but be unique on the server. If set to null then it will be auto generated * @param userName the username field of the CONNECT packet * @param password the password field of the CONNECT packet * @param properties the properties to be included in the CONNECT message (used only in MQTT5) * @param willProperties the properties to be included in the will PUBLISH message (used only in MQTT5) * @param willTopic the topic of the will PUBLISH message * @param willPayload the content of the will PUBLISH message + * @param willRetain set if the will PUBLISH must be retained by the server * @param willQos the QoS of the will PUBLISH message + * @param connackTimeout timeout in seconds after which the connection is closed if no CONNACK packet has been received * @param enhancedAuthCallback the callback called when authenticationData is received, it should return the data necessary to continue authentication or null if completed (used only in MQTT5 if authenticationMethod has been set in the CONNECT properties) - * @param publishReceived the callback called when a PUBLISH message is received by this client + * @param onConnected called when the CONNACK packet has been received and the connection has been established + * @param onDisconnected called when a DISCONNECT packet has been received or if the connection has been terminated + * @param onSubscribed called when a SUBACK packet has been received + * @param debugLog set to print the hex packets sent and received + * @param publishReceived called when a PUBLISH packet has been received */ public class MQTTClient( private val mqttVersion: MQTTVersion, @@ -56,9 +66,12 @@ public class MQTTClient( private val willPayload: UByteArray? = null, private val willRetain: Boolean = false, private val willQos: Qos = Qos.AT_MOST_ONCE, + private val connackTimeout: Int = 30, private val enhancedAuthCallback: (authenticationData: UByteArray?) -> UByteArray? = { null }, private val onConnected: (connack: MQTTConnack) -> Unit = {}, private val onDisconnected: (disconnect: MQTTDisconnect?) -> Unit = {}, + private val onSubscribed: (suback: MQTTSuback) -> Unit = {}, + private val debugLog: Boolean = false, private val publishReceived: (publish: MQTTPublish) -> Unit ) { @@ -74,14 +87,16 @@ public class MQTTClient( // Session private var packetIdentifier: UInt = 1u // QoS 1 and QoS 2 messages which have been sent to the Server, but have not been completely acknowledged - private val pendingAcknowledgeMessages = atomic(mutableMapOf<UInt, MQTTPublish>()) - private val pendingAcknowledgePubrel = atomic(mutableMapOf<UInt, MQTTPubrel>()) + private val pendingAcknowledgeMessages = mutableMapOf<UInt, MQTTPublish>() + private val pendingAcknowledgePubrel = mutableMapOf<UInt, MQTTPubrel>() // QoS 2 messages which have been received from the Server, but have not been completely acknowledged private val qos2ListReceived = mutableListOf<UInt>() // List of messages to be sent after CONNACK has been received private val pendingSendMessages = atomic(mutableListOf<UByteArray>()) + private val lock = ReentrantLock() + // Connection private val topicAliasesClient = mutableMapOf<UInt, String>() // TODO reset all these on reconnection private val maximumQos = atomic(Qos.EXACTLY_ONCE) @@ -134,6 +149,9 @@ public class MQTTClient( private fun send(data: UByteArray, force: Boolean = false) { if (connackReceived.value || force) { socket?.send(data) ?: throw SocketClosedException("MQTT send failed") + if (debugLog) { + println("Sent: " + data.toHexString()) + } lastActiveTimestamp.getAndSet(currentTimeMillis()) } else { pendingSendMessages.value += data @@ -170,22 +188,26 @@ public class MQTTClient( } private fun generatePacketId(): UInt { - do { // TODO make atomic - packetIdentifier++ - if (packetIdentifier > 65535u) - packetIdentifier = 1u - } while (isPacketIdInUse(packetIdentifier)) + lock.withLock { + do { + packetIdentifier++ + if (packetIdentifier > 65535u) + packetIdentifier = 1u + } while (isPacketIdInUse(packetIdentifier)) - return packetIdentifier + return packetIdentifier + } } private fun isPacketIdInUse(packetId: UInt): Boolean { - if (qos2ListReceived.contains(packetId)) - return true - if (pendingAcknowledgeMessages.value[packetId] != null) - return true - if (pendingAcknowledgePubrel.value[packetId] != null) - return true + lock.withLock { + if (qos2ListReceived.contains(packetId)) + return true + if (pendingAcknowledgeMessages[packetId] != null) + return true + if (pendingAcknowledgePubrel[packetId] != null) + return true + } return false } @@ -221,10 +243,12 @@ public class MQTTClient( MQTT5Publish(retain, qos, false, topic, packetId, properties, payload) } if (qos != Qos.AT_MOST_ONCE) { - if (pendingAcknowledgeMessages.value.size + pendingAcknowledgePubrel.value.size >= receiveMax.value.toInt()) { - throw Exception("Sending more PUBLISH with QoS > 0 than indicated by the server in receiveMax") + lock.withLock { + if (pendingAcknowledgeMessages.size + pendingAcknowledgePubrel.size >= receiveMax.value.toInt()) { + throw Exception("Sending more PUBLISH with QoS > 0 than indicated by the server in receiveMax") + } + pendingAcknowledgeMessages[packetId!!] = publish } - pendingAcknowledgeMessages.value[packetId!!] = publish } val data = publish.toByteArray() if (data.size > maximumServerPacketSize.value) { @@ -297,17 +321,19 @@ public class MQTTClient( } socket?.sendRemaining() if (connackReceived.value) { - val pending = pendingSendMessages.value + val pending = pendingSendMessages.getAndSet(mutableListOf()) for (data in pending) { send(data) } - pendingSendMessages.value.clear() } val data = socket?.read() if (data != null) { try { + if (debugLog) { + println("Received: " + data.toHexString()) + } currentReceivedPacket.addData(data).forEach { handlePacket(it) } @@ -342,7 +368,7 @@ public class MQTTClient( val lastActive = lastActiveTimestamp.value val isConnackReceived = connackReceived.value - if (!isConnackReceived && currentTime > lastActive + 30000) { + if (!isConnackReceived && currentTime > lastActive + (connackTimeout * 1000)) { close() lastException = Exception("CONNACK not received in 30 seconds") throw lastException!! @@ -456,16 +482,20 @@ public class MQTTClient( } else if (!cleanStart && !packet.connectAcknowledgeFlags.sessionPresentFlag) { // Session expired on the server, so clean the local session packetIdentifier = 1u - pendingAcknowledgeMessages.value.clear() - pendingAcknowledgePubrel.value.clear() - qos2ListReceived.clear() + lock.withLock { + pendingAcknowledgeMessages.clear() + pendingAcknowledgePubrel.clear() + qos2ListReceived.clear() + } } else if (!cleanStart && packet.connectAcknowledgeFlags.sessionPresentFlag) { // Resend pending publish and pubrel messages (with dup=1) - pendingAcknowledgeMessages.value.forEach { - send(it.value.setDuplicate().toByteArray()) - } - pendingAcknowledgePubrel.value.forEach { - send(it.value.toByteArray()) + lock.withLock { + pendingAcknowledgeMessages.forEach { + send(it.value.setDuplicate().toByteArray()) + } + pendingAcknowledgePubrel.forEach { + send(it.value.toByteArray()) + } } } onConnected(packet) @@ -535,21 +565,25 @@ public class MQTTClient( if (packet is MQTT5Puback && properties.requestProblemInformation == 0u && (packet.properties.reasonString != null || packet.properties.userProperty.isNotEmpty())) { throw MQTTException(ReasonCode.PROTOCOL_ERROR) } - pendingAcknowledgeMessages.value.remove(packet.packetId) + lock.withLock { + pendingAcknowledgeMessages.remove(packet.packetId) + } } private fun handlePubrec(packet: MQTTPubrec) { if (packet is MQTT5Pubrec && properties.requestProblemInformation == 0u && (packet.properties.reasonString != null || packet.properties.userProperty.isNotEmpty())) { throw MQTTException(ReasonCode.PROTOCOL_ERROR) } - pendingAcknowledgeMessages.value.remove(packet.packetId) - val pubrel = if (packet is MQTT4Pubrec) { - MQTT4Pubrel(packet.packetId) - } else { - MQTT5Pubrel(packet.packetId) + lock.withLock { + pendingAcknowledgeMessages.remove(packet.packetId) + val pubrel = if (packet is MQTT4Pubrec) { + MQTT4Pubrel(packet.packetId) + } else { + MQTT5Pubrel(packet.packetId) + } + pendingAcknowledgePubrel[packet.packetId] = pubrel + send(pubrel.toByteArray()) } - pendingAcknowledgePubrel.value[packet.packetId] = pubrel - send(pubrel.toByteArray()) } private fun handlePubrel(packet: MQTTPubrel) { @@ -571,8 +605,10 @@ public class MQTTClient( if (packet is MQTT5Pubcomp && properties.requestProblemInformation == 0u && (packet.properties.reasonString != null || packet.properties.userProperty.isNotEmpty())) { throw MQTTException(ReasonCode.PROTOCOL_ERROR) } - if (pendingAcknowledgePubrel.value.remove(packet.packetId) == null) { - throw MQTTException(ReasonCode.PACKET_IDENTIFIER_NOT_FOUND) + lock.withLock { + if (pendingAcknowledgePubrel.remove(packet.packetId) == null) { + throw MQTTException(ReasonCode.PACKET_IDENTIFIER_NOT_FOUND) + } } } @@ -593,6 +629,7 @@ public class MQTTClient( } } } + onSubscribed(packet) } private fun handleUnsuback(packet: MQTTUnsuback) { @@ -609,7 +646,7 @@ public class MQTTClient( if (packet.authenticateReasonCode == ReasonCode.CONTINUE_AUTHENTICATION) { val data = enhancedAuthCallback(packet.properties.authenticationData) val auth = MQTT5Auth(ReasonCode.CONTINUE_AUTHENTICATION, MQTT5Properties(authenticationMethod = packet.properties.authenticationMethod, authenticationData = data)) - send(auth.toByteArray()) + send(auth.toByteArray(), true) } } @@ -623,7 +660,7 @@ public class MQTTClient( ReasonCode.RE_AUTHENTICATE, MQTT5Properties(authenticationMethod = properties.authenticationMethod, authenticationData = data) ) - send(auth.toByteArray()) + send(auth.toByteArray(), true) } private fun handleDisconnect(disconnect: MQTTDisconnect) { -- GitLab