From 383c1f8c53f9cf6f8301294998e2f179d5f27679 Mon Sep 17 00:00:00 2001 From: Davide Pianca Date: Mon, 7 Feb 2022 15:51:50 +0100 Subject: [PATCH] Implement base64 and SHA1 --- build.gradle.kts | 1 - src/commonMain/kotlin/CommonUtils.kt | 74 ++++++++++++++++--- src/commonMain/kotlin/socket/tcp/WebSocket.kt | 5 +- src/commonTest/kotlin/mqtt/CommonUtilsTest.kt | 15 ++++ 4 files changed, 81 insertions(+), 14 deletions(-) diff --git a/build.gradle.kts b/build.gradle.kts index 38fc8c8..a50fc2e 100644 --- a/build.gradle.kts +++ b/build.gradle.kts @@ -67,7 +67,6 @@ kotlin { implementation(kotlin("stdlib-common")) implementation("org.jetbrains.kotlinx:kotlinx-serialization-core:$serializationVersion") implementation("org.jetbrains.kotlinx:kotlinx-serialization-protobuf:$serializationVersion") - implementation("com.soywiz.korlibs.krypto:krypto:2.4.12") } } val commonTest by getting { diff --git a/src/commonMain/kotlin/CommonUtils.kt b/src/commonMain/kotlin/CommonUtils.kt index 54a1c3a..864b5d3 100644 --- a/src/commonMain/kotlin/CommonUtils.kt +++ b/src/commonMain/kotlin/CommonUtils.kt @@ -1,3 +1,4 @@ +import socket.streams.ByteArrayOutputStream import kotlin.random.Random expect fun currentTimeMillis(): Long @@ -36,6 +37,10 @@ fun String.validateUTF8String(): Boolean { fun UByteArray.toHexString() = joinToString("") { it.toString(16).padStart(2, '0') } +fun UIntArray.toHexString() = joinToString("") { it.toString(16).padStart(8, '0') } + +fun String.fromHexString(): ByteArray = chunked(2).map { it.toInt(16).toByte() }.toByteArray() + fun MutableMap.removeIf(predicate: (MutableMap.MutableEntry) -> Boolean): Boolean { var removed = false val iterator = iterator() @@ -51,8 +56,7 @@ fun MutableMap.removeIf(predicate: (MutableMap.MutableEntry) private infix fun UInt.leftRotate(bits: Int): UInt = ((this shl bits) or (this shr (32 - bits))) -// TODO implement base64 -private fun ByteArray.sha1(): ByteArray { // TODO fix +fun ByteArray.sha1(): ByteArray { val hash = UIntArray(5) hash[0] = 0x67452301u hash[1] = 0xEFCDAB89u @@ -60,21 +64,28 @@ private fun ByteArray.sha1(): ByteArray { // TODO fix hash[3] = 0x10325476u hash[4] = 0xC3D2E1F0u - val ml = this.size * 8 + val ml = (this.size * 8).toULong() + + // Prepare the data + val outStream = ByteArrayOutputStream() - val chunks = UIntArray((((this.size + 8) shr 6) + 1) * 16) + outStream.write(this.toUByteArray()) + outStream.write(0x80u) - for (i in 0 until this.size) { - chunks[i shr 2] = chunks[i shr 2] or (this[i].toUInt() shl (24 - (i % 4) * 8)) + while ((outStream.size() + 8) % 64 != 0) { + outStream.write(0u) } + outStream.writeULong(ml) - chunks[this.size shr 2] = chunks[this.size shr 2] or (0x80u shl (24 - (this.size % 4) * 8)) - chunks[chunks.size - 1] = ml.toUInt() + val data = outStream.toByteArray() - for (j in chunks.indices step 16) { + for (j in data.indices step 64) { val w = UIntArray(80) for (i in 0 until 16) { - w[i] = chunks[j + i] + w[i] = (data[j + i * 4].toUInt() shl 24) or + (data[j + i * 4 + 1].toUInt() shl 16) or + (data[j + i * 4 + 2].toUInt() shl 8) or + data[j + i * 4 + 3].toUInt() } for (i in 16 until 80) { w[i] = (w[i - 3] xor w[i - 8] xor w[i - 14] xor w[i - 16]) leftRotate 1 @@ -122,5 +133,46 @@ private fun ByteArray.sha1(): ByteArray { // TODO fix hash[4] += e } - return hash.foldIndexed(ByteArray(hash.size)) { i, a, v -> a.apply { set(i, v.toByte()) } } + val hexString = hash.toHexString() + + return hexString.fromHexString() +} + +fun ByteArray.toBase64(): String { + val base64chars = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/" + var r = "" + var p = "" + var c = this.size % 3 + + val outStream = ByteArrayOutputStream() + outStream.write(this.toUByteArray()) + + if (c > 0) { + while (c < 3) { + p += "=" + outStream.write(0u) + c++ + } + } + + val s = outStream.toByteArray() + + c = 0 + while (c < this.size) { + if (c > 0 && (c / 3 * 4) % 76 == 0) { + r += "\r\n" + } + + val n = (s[c].toInt() shl 16) + (s[c + 1].toInt() shl 8) + s[c + 2].toInt() + + val n1 = n shr 18 and 63 + val n2 = n shr 12 and 63 + val n3 = n shr 6 and 63 + val n4 = n and 63 + + r += ("" + base64chars[n1] + base64chars[n2] + base64chars[n3] + base64chars[n4]) + c += 3 + } + + return r.substring(0, r.length - p.length) + p } diff --git a/src/commonMain/kotlin/socket/tcp/WebSocket.kt b/src/commonMain/kotlin/socket/tcp/WebSocket.kt index 6ee9894..cd83f58 100644 --- a/src/commonMain/kotlin/socket/tcp/WebSocket.kt +++ b/src/commonMain/kotlin/socket/tcp/WebSocket.kt @@ -1,10 +1,11 @@ package socket.tcp -import com.soywiz.krypto.sha1 +import sha1 import socket.SocketInterface import socket.streams.ByteArrayOutputStream import socket.streams.DynamicByteBuffer import socket.streams.EOFException +import toBase64 class WebSocket(private val socket: Socket) : SocketInterface { @@ -51,7 +52,7 @@ class WebSocket(private val socket: Socket) : SocketInterface { } val match = Regex("Sec-WebSocket-Key: (.*)") val key = match.find(string)?.groups?.get(1)?.value - val digest = (key + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11").encodeToByteArray().sha1().base64 + val digest = (key + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11").encodeToByteArray().sha1().toBase64() val response = ( "HTTP/1.1 101 Switching Protocols\r\n" + "Connection: Upgrade\r\n" diff --git a/src/commonTest/kotlin/mqtt/CommonUtilsTest.kt b/src/commonTest/kotlin/mqtt/CommonUtilsTest.kt index 402bdc9..9a9a024 100644 --- a/src/commonTest/kotlin/mqtt/CommonUtilsTest.kt +++ b/src/commonTest/kotlin/mqtt/CommonUtilsTest.kt @@ -1,7 +1,10 @@ package mqtt +import sha1 +import toBase64 import validateUTF8String import kotlin.test.Test +import kotlin.test.assertEquals import kotlin.test.assertFalse import kotlin.test.assertTrue @@ -34,4 +37,16 @@ class CommonUtilsTest { assertTrue { ubyteArrayOf(0xEFu, 0xBBu, 0xBFu).toByteArray().decodeToString().validateUTF8String() } } + + @Test + fun testSHA1() { + val str1 = "The quick brown fox jumps over the lazy dog".encodeToByteArray().sha1() + assertEquals("L9ThxnotKPzthJ7hu3bnORuT6xI=", str1.toBase64()) + + val str2 = "The quick brown fox jumps over the lazy cog".encodeToByteArray().sha1() + assertEquals("3p8sf9JeGzr60+haC9F9mxANtLM=", str2.toBase64()) + + val str3 = "".encodeToByteArray().sha1() + assertEquals("2jmj7l5rSw0yVb/vlWAYkK/YBwk=", str3.toBase64()) + } }