Skip to content

Commit

Permalink
use async and precomputation to improve performance
Browse files Browse the repository at this point in the history
  • Loading branch information
hbmartin committed Oct 18, 2023
1 parent 707ee34 commit 452b029
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 33 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,4 @@ package com.jump.sdk.amplifyframework
sealed class CognitoException(override val message: String) : Exception(message) {
data object BadSrpB : CognitoException("Bad server public value 'B'")
data object HashOfAAndSrpBCannotBeZero : CognitoException("Hash of A and B cannot be zero")
data object UserIdNotSet : CognitoException("Must call setUserPoolParams() before this")
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ package com.jump.sdk.amplifyframework

import com.ionspin.kotlin.bignum.integer.BigInteger
import com.ionspin.kotlin.bignum.integer.Sign
import com.ionspin.kotlin.bignum.integer.util.fromTwosComplementByteArray
import com.ionspin.kotlin.bignum.integer.util.toTwosComplementByteArray
import com.ionspin.kotlin.bignum.modular.ModularBigInteger
import io.ktor.utils.io.core.toByteArray
Expand Down Expand Up @@ -44,45 +45,47 @@ private const val HEX_N =
"CEE3D2261AD2EE6BF12FFA06D98A0864D87602733EC86A64521F2B18177B200CBBE117577A615D6C770988C0BAD946E208E2" +
"4FA074E5AB3143DB5BFCE0FD108E4B82D120A93AD2CAFFFFFFFFFFFFFFFF"

// precomputed: k = H(g|N)
private val kByteArray = byteArrayOf(83, -126, -126, -60, 53, 71, 66, -41, -53, -67, -30, 53, -97, -49, 103, -7, -11, -77, -90, -80, -121, -111, -27, 1, 27, 67, -72, -91, -74, 109, -98, -26)
// precomputed: N = BigInteger.parseString(HEX_N, 16)
private val nByteArray = byteArrayOf(0, -1, -1, -1, -1, -1, -1, -1, -1, -55, 15, -38, -94, 33, 104, -62, 52, -60, -58, 98, -117, -128, -36, 28, -47, 41, 2, 78, 8, -118, 103, -52, 116, 2, 11, -66, -90, 59, 19, -101, 34, 81, 74, 8, 121, -114, 52, 4, -35, -17, -107, 25, -77, -51, 58, 67, 27, 48, 43, 10, 109, -14, 95, 20, 55, 79, -31, 53, 109, 109, 81, -62, 69, -28, -123, -75, 118, 98, 94, 126, -58, -12, 76, 66, -23, -90, 55, -19, 107, 11, -1, 92, -74, -12, 6, -73, -19, -18, 56, 107, -5, 90, -119, -97, -91, -82, -97, 36, 17, 124, 75, 31, -26, 73, 40, 102, 81, -20, -28, 91, 61, -62, 0, 124, -72, -95, 99, -65, 5, -104, -38, 72, 54, 28, 85, -45, -102, 105, 22, 63, -88, -3, 36, -49, 95, -125, 101, 93, 35, -36, -93, -83, -106, 28, 98, -13, 86, 32, -123, 82, -69, -98, -43, 41, 7, 112, -106, -106, 109, 103, 12, 53, 78, 74, -68, -104, 4, -15, 116, 108, 8, -54, 24, 33, 124, 50, -112, 94, 70, 46, 54, -50, 59, -29, -98, 119, 44, 24, 14, -122, 3, -101, 39, -125, -94, -20, 7, -94, -113, -75, -59, 93, -16, 111, 76, 82, -55, -34, 43, -53, -10, -107, 88, 23, 24, 57, -107, 73, 124, -22, -107, 106, -27, 21, -46, 38, 24, -104, -6, 5, 16, 21, 114, -114, 90, -118, -86, -60, 45, -83, 51, 23, 13, 4, 80, 122, 51, -88, 85, 33, -85, -33, 28, -70, 100, -20, -5, -123, 4, 88, -37, -17, 10, -118, -22, 113, 87, 93, 6, 12, 125, -77, -105, 15, -123, -90, -31, -28, -57, -85, -11, -82, -116, -37, 9, 51, -41, 30, -116, -108, -32, 74, 37, 97, -99, -50, -29, -46, 38, 26, -46, -18, 107, -15, 47, -6, 6, -39, -118, 8, 100, -40, 118, 2, 115, 62, -56, 106, 100, 82, 31, 43, 24, 23, 123, 32, 12, -69, -31, 23, 87, 122, 97, 93, 108, 119, 9, -120, -64, -70, -39, 70, -30, 8, -30, 79, -96, 116, -27, -85, 49, 67, -37, 91, -4, -32, -3, 16, -114, 75, -126, -47, 32, -87, 58, -46, -54, -1, -1, -1, -1, -1, -1, -1, -1)

@OptIn(ExperimentalEncodingApi::class)
@Suppress("TooManyFunctions")
class SRPHelper(private val password: String, userPoolName: String) {
class SRPHelper(userPool: String) {
@Suppress("VariableNaming")
private val N = BigInteger.parseString(HEX_N, 16)
private val N = BigInteger.fromTwosComplementByteArray(nByteArray)

private val creator = ModularBigInteger.creatorForModulo(N)
private val g = creator.fromInt(2)

private val random = SecureRandom()

private val k: BigInteger
private var privateA: BigInteger
private var publicA: ModularBigInteger
private val k: BigInteger = BigInteger.fromTwosComplementByteArray(kByteArray)
lateinit var privateA: BigInteger
lateinit var publicA: ModularBigInteger
var timestamp: String = nowAsFormattedString()
internal set

private val digest = SHA256()
var userIdForSrp: String? = null
private val userPoolName: String

init {
if (userPoolName.contains("_")) {
this.userPoolName = userPoolName.split(Regex("_"), 2)[1]
if (userPool.contains("_")) {
this.userPoolName = userPool.split(Regex("_"), 2)[1]
} else {
this.userPoolName = userPoolName
this.userPoolName = userPool
}
}

// Generate client private 'a' and public 'A' values
// Generate client private 'a' and public 'A' values
suspend fun getPublicA(): String {
do {
privateA = BigInteger.fromByteArray(random.nextBytesOf(EPHEMERAL_KEY_LENGTH), Sign.POSITIVE).mod(N)
// A = (g ^ a) % N
publicA = g.pow(privateA)
} while (publicA.residue == BigInteger.ZERO)

// compute k = H(g|N)
digest.reset()
digest.update(N.toTwosComplementByteArray())
k = BigInteger.fromByteArray(digest.digest(g.toByteArray()), Sign.POSITIVE)
return publicA.toString(HEX)
}

// @TestOnly
Expand All @@ -94,8 +97,6 @@ class SRPHelper(private val password: String, userPoolName: String) {
this.publicA = creator.fromBigInteger(publicA)
}

fun getPublicA(): String = publicA.toString(HEX)

// u = H(A, B)
internal fun computeU(srpB: BigInteger): BigInteger {
digest.reset()
Expand All @@ -105,10 +106,10 @@ class SRPHelper(private val password: String, userPoolName: String) {

// x = H(salt | H(poolName | userId | ":" | password))
@Throws(CognitoException::class)
internal fun computeX(salt: BigInteger): BigInteger {
internal fun computeX(salt: BigInteger, userIdForSrp: String, password: String): BigInteger {
digest.reset()
digest.update(userPoolName.toByteArray())
digest.update(userIdForSrp?.toByteArray() ?: throw CognitoException.UserIdNotSet)
digest.update(userIdForSrp.toByteArray())
digest.update(":".toByteArray())
val userIdPasswordHash = digest.digest(password.toByteArray())

Expand All @@ -119,8 +120,12 @@ class SRPHelper(private val password: String, userPoolName: String) {

// verifier = (g ^ x) % N
@Throws(CognitoException::class)
internal fun computePasswordVerifier(salt: BigInteger): ModularBigInteger {
val xValue = computeX(salt)
internal fun computePasswordVerifier(
salt: BigInteger,
userIdForSrp: String,
password: String,
): ModularBigInteger {
val xValue = computeX(salt = salt, userIdForSrp = userIdForSrp, password = password)
return g.pow(xValue)
}

Expand Down Expand Up @@ -151,10 +156,14 @@ class SRPHelper(private val password: String, userPoolName: String) {

// M1 = MAC(poolId | userId | secret | timestamp, key)
@Throws(CognitoException::class)
internal fun generateM1Signature(key: ByteArray, secretBlock: String): ByteArray {
internal fun generateM1Signature(
key: ByteArray,
secretBlock: String,
userIdForSrp: String,
): ByteArray {
val mac = HmacSHA256(key)
mac.update(userPoolName.toByteArray())
mac.update(userIdForSrp?.toByteArray() ?: throw CognitoException.UserIdNotSet)
mac.update(userIdForSrp.toByteArray())
mac.update(Base64.decode(secretBlock))
return mac.doFinal(timestamp.toByteArray())
}
Expand All @@ -174,10 +183,18 @@ class SRPHelper(private val password: String, userPoolName: String) {
* @param srpB The SRP_B value provided by the Cognito service.
* @param secretBlock The secret block - should be passed into PASSWORD_CLAIM_SECRET_BLOCK
* for the subsequent call to AWSCognitoIdentityProviderService.RespondToAuthChallenge
* @param userIdForSrp The user ID used in the authentication process.
* @param password The password used in the authentication process.
* @return A string representing the PASSWORD_CLAIM_SIGNATURE for authentication.
*/
@Throws(CognitoException::class, CancellationException::class)
suspend fun getSignature(salt: String, srpB: String, secretBlock: String): String {
suspend fun getSignature(
salt: String,
srpB: String,
secretBlock: String,
userIdForSrp: String,
password: String,
): String {
val bigIntSRPB = BigInteger.parseString(srpB, HEX)
val bigIntSalt = BigInteger.parseString(salt, HEX)

Expand All @@ -191,10 +208,14 @@ class SRPHelper(private val password: String, userPoolName: String) {
throw CognitoException.HashOfAAndSrpBCannotBeZero
}

val xValue = computeX(bigIntSalt)
val xValue = computeX(salt = bigIntSalt, userIdForSrp = userIdForSrp, password = password)
val sValue = computeS(uValue, xValue, bigIntSRPB)
val key = computePasswordAuthenticationKey(sValue, uValue)
val m1Signature = generateM1Signature(key, secretBlock)
val m1Signature = generateM1Signature(
key = key,
secretBlock = secretBlock,
userIdForSrp = userIdForSrp,
)
return Base64.encode(m1Signature)
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import kotlin.test.BeforeTest
import kotlin.test.Test
import kotlin.test.assertEquals
import kotlin.test.assertNotEquals
import kotlinx.coroutines.runBlocking


@OptIn(ExperimentalEncodingApi::class)
Expand Down Expand Up @@ -101,18 +102,20 @@ class SRPHelperTests {
private val m1Expected = "QG7a57h+ndPBVasvx/OkmsJdy5uoMEVRshboEd4S+j8="

private lateinit var helper: SRPHelper
private val password = "Password123"
private val userIdForSrp = "username"

@BeforeTest
fun setUp() {
helper = SRPHelper(password = "Password123", userPoolName = "us-east-2_KO6fcefgd")
helper = SRPHelper(userPool = "us-east-2_KO6fcefgd")
helper.setAValues(privateA, publicA)
helper.userIdForSrp = "username"
}

@Test
fun testValidPublicA() {
val testHelper = SRPHelper(password = "Password123", userPoolName = "us-east-2_KO6fcefgd")
val bigA = BigInteger.parseString(testHelper.getPublicA(), 16)
val testHelper = SRPHelper(userPool = "us-east-2_KO6fcefgd")
val publicA = runBlocking { testHelper.getPublicA() }
val bigA = BigInteger.parseString(publicA, 16)
assertNotEquals(BigInteger.ZERO, testHelper.modN(bigA))
}

Expand All @@ -126,7 +129,7 @@ class SRPHelperTests {
fun testComputeX() {
val salt = BigInteger.parseString("e7dc204cebbfda6b62b8493e932f7f4c", 16)
println(salt)
val xActual = helper.computeX(salt)
val xActual = helper.computeX(salt = salt, userIdForSrp = userIdForSrp, password = password)
println(xActual)
assertEquals(xExpected, xActual)
}
Expand All @@ -146,7 +149,11 @@ class SRPHelperTests {
@Test
fun testGenerateM1() {
helper.timestamp = "Wed Sep 29 06:40:48 UTC 2021"
val m1Actual = helper.generateM1Signature(keyExpected.toByteArray(), secretBlock)
val m1Actual = helper.generateM1Signature(
key = keyExpected.toByteArray(),
secretBlock = secretBlock,
userIdForSrp = userIdForSrp,
)
assertEquals(m1Expected, Base64.encode(m1Actual))
}
}

0 comments on commit 452b029

Please sign in to comment.