From 1ecc791a90bd95521c1d3f31c01eb3650d83d366 Mon Sep 17 00:00:00 2001
From: davidepianca98 <davidepianca98@gmail.com>
Date: Sat, 11 Jan 2025 13:48:37 +0100
Subject: [PATCH] Add EC private key support on JVM

---
 build.gradle.kts                              |  1 -
 gradle/libs.versions.toml                     |  2 --
 .../io/github/davidepianca98/MQTTClient.kt    |  2 +-
 .../github/davidepianca98/TLSClientSocket.kt  | 22 +++++++++++++++++--
 4 files changed, 21 insertions(+), 6 deletions(-)

diff --git a/build.gradle.kts b/build.gradle.kts
index fd7d381..8d1a35c 100644
--- a/build.gradle.kts
+++ b/build.gradle.kts
@@ -13,7 +13,6 @@ buildscript {
 
 plugins {
     alias(libs.plugins.kotlin.multiplatform) apply false
-    alias(libs.plugins.complete.kotlin)
     alias(libs.plugins.goncalossilva.resources)
 }
 
diff --git a/gradle/libs.versions.toml b/gradle/libs.versions.toml
index 539fdf8..d501a28 100644
--- a/gradle/libs.versions.toml
+++ b/gradle/libs.versions.toml
@@ -3,7 +3,6 @@ serialization = "1.7.3"
 coroutines = "1.9.0"
 atomicfu = "0.26.1"
 nodeWrapper = "20.11.30-pre.732"
-completeKotlin = "1.1.0"
 silvaResources = "0.4.0"
 kotlin = "2.1.0"
 shadow = "8.3.5"
@@ -20,6 +19,5 @@ goncalossilva-resources = { module = "com.goncalossilva:resources", version.ref
 [plugins]
 kotlin-multiplatform = { id = "org.jetbrains.kotlin.multiplatform", version.ref = "kotlin" }
 kotlin-serialization = { id = "org.jetbrains.kotlin.plugin.serialization", version.ref = "kotlin" }
-complete-kotlin = { id = "com.louiscad.complete-kotlin", version.ref = "completeKotlin" }
 goncalossilva-resources = { id = "com.goncalossilva.resources", version.ref = "silvaResources" }
 johnrengelman-shadow = { id = "com.gradleup.shadow", version.ref = "shadow" }
diff --git a/kmqtt-client/src/commonMain/kotlin/io/github/davidepianca98/MQTTClient.kt b/kmqtt-client/src/commonMain/kotlin/io/github/davidepianca98/MQTTClient.kt
index 4dce2ec..5a55bb6 100644
--- a/kmqtt-client/src/commonMain/kotlin/io/github/davidepianca98/MQTTClient.kt
+++ b/kmqtt-client/src/commonMain/kotlin/io/github/davidepianca98/MQTTClient.kt
@@ -96,7 +96,6 @@ import kotlinx.coroutines.yield
  * @param publishReceived called when a PUBLISH packet has been received
  */
 public class MQTTClient(
-    private val autoInit: Boolean = true,
     private val mqttVersion: MQTTVersion,
     private val address: String,
     private val port: Int,
@@ -115,6 +114,7 @@ public class MQTTClient(
     private val willQos: Qos = Qos.AT_MOST_ONCE,
     private val connackTimeout: Int = 30,
     private val connectTimeout: Int = 30,
+    private val autoInit: Boolean = true,
     private val enhancedAuthCallback: (authenticationData: UByteArray?) -> UByteArray? = { null },
     private val onConnected: (connack: MQTTConnack) -> Unit = {},
     private val onDisconnected: (disconnect: MQTTDisconnect?) -> Unit = {},
diff --git a/kmqtt-client/src/jvmMain/kotlin/io/github/davidepianca98/TLSClientSocket.kt b/kmqtt-client/src/jvmMain/kotlin/io/github/davidepianca98/TLSClientSocket.kt
index 49b53f0..d946b50 100644
--- a/kmqtt-client/src/jvmMain/kotlin/io/github/davidepianca98/TLSClientSocket.kt
+++ b/kmqtt-client/src/jvmMain/kotlin/io/github/davidepianca98/TLSClientSocket.kt
@@ -12,7 +12,9 @@ import java.security.KeyFactory
 import java.security.KeyStore
 import java.security.cert.CertificateFactory
 import java.security.cert.X509Certificate
+import java.security.interfaces.ECPrivateKey
 import java.security.interfaces.RSAPrivateKey
+import java.security.spec.InvalidKeySpecException
 import java.security.spec.PKCS8EncodedKeySpec
 import java.util.*
 import javax.net.ssl.KeyManagerFactory
@@ -77,7 +79,12 @@ public actual class TLSClientSocket actual constructor(
 
             val keyStore = KeyStore.getInstance(KeyStore.getDefaultType())
             keyStore.load(null, null)
-            val key = getPrivateKeyFromString(if (tlsSettings.clientCertificateKey!!.isValidPem()) tlsSettings.clientCertificateKey!! else FileInputStream(tlsSettings.clientCertificateKey!!).bufferedReader().readText())
+            val keyContent = if (tlsSettings.clientCertificateKey!!.isValidPem()) tlsSettings.clientCertificateKey!! else FileInputStream(tlsSettings.clientCertificateKey!!).bufferedReader().readText();
+            val key = try {
+                getRSAPrivateKeyFromString(keyContent)
+            } catch (e: InvalidKeySpecException) {
+                getECPrivateKeyFromString(keyContent)
+            }
             keyStore.setKeyEntry("client", key, tlsSettings.clientCertificatePassword?.toCharArray(), arrayOf(certificate))
 
             val kmf = KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm())
@@ -101,7 +108,7 @@ public actual class TLSClientSocket actual constructor(
     }
 
     public companion object {
-        private fun getPrivateKeyFromString(key: String): RSAPrivateKey {
+        private fun getRSAPrivateKeyFromString(key: String): RSAPrivateKey {
             val privateKeyPEM = key
                 .replace("-----BEGIN PRIVATE KEY-----", "")
                 .replace("-----END PRIVATE KEY-----", "")
@@ -111,6 +118,17 @@ public actual class TLSClientSocket actual constructor(
             val keySpec = PKCS8EncodedKeySpec(encoded)
             return kf.generatePrivate(keySpec) as RSAPrivateKey
         }
+
+        private fun getECPrivateKeyFromString(key: String): ECPrivateKey {
+            val privateKeyPEM = key
+                .replace("-----BEGIN EC PRIVATE KEY-----", "")
+                .replace("-----END EC PRIVATE KEY-----", "")
+                .replace("\n","")
+            val encoded = Base64.getDecoder().decode(privateKeyPEM)
+            val kf = KeyFactory.getInstance("EC")
+            val keySpec = PKCS8EncodedKeySpec(encoded)
+            return kf.generatePrivate(keySpec) as ECPrivateKey
+        }
     }
 
     override fun send(data: UByteArray) {
-- 
GitLab