Skip to content

Commit

Permalink
refactor key-import, keystores
Browse files Browse the repository at this point in the history
  • Loading branch information
mikeplotean committed Sep 13, 2023
1 parent fcd2bf1 commit b8099fc
Show file tree
Hide file tree
Showing 9 changed files with 288 additions and 235 deletions.
12 changes: 7 additions & 5 deletions src/main/kotlin/id/walt/cli/did/DidCommand.kt
Original file line number Diff line number Diff line change
Expand Up @@ -65,11 +65,13 @@ class CreateDidCommand : CliktCommand(
echo("Creating did:${method.method} (key: ${keyId})")

val did = when (method) {
is WebMethodOption -> DidService.create(web, keyId, DidWebCreateOptions((method as WebMethodOption).domain, (method as WebMethodOption).path))
is EbsiMethodOption -> DidService.create(ebsi, keyId, DidEbsiCreateOptions((method as EbsiMethodOption).version))
is CheqdMethodOption -> DidService.create(cheqd, keyId, DidCheqdCreateOptions((method as CheqdMethodOption).network))
is KeyMethodOption -> DidService.create(key, keyId, DidKeyCreateOptions((method as KeyMethodOption).useJwkJcsPubMulticodec))
else -> DidService.create(DidMethod.valueOf(method.method), keyId)
is WebMethodOption -> DidWebCreateOptions((method as WebMethodOption).domain, (method as WebMethodOption).path)
is EbsiMethodOption -> DidEbsiCreateOptions((method as EbsiMethodOption).version)
is CheqdMethodOption -> DidCheqdCreateOptions((method as CheqdMethodOption).network)
is KeyMethodOption -> DidKeyCreateOptions((method as KeyMethodOption).useJwkJcsPubMulticodec)
else -> null
}.let{
DidService.create(DidMethod.valueOf(method.method), keyId, it)
}

echo("\nResults:\n")
Expand Down
25 changes: 15 additions & 10 deletions src/main/kotlin/id/walt/crypto/CryptFun.kt
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,12 @@ fun java.security.Key.toPEM(): String = when (this) {
else -> throw IllegalArgumentException()
}

fun java.security.Key.toBase64(): String = when (this) {
is PublicKey -> this.toBase64()
is PrivateKey -> this.toBase64()
else -> throw IllegalArgumentException()
}

fun PrivateKey.toPEM(): String =
"-----BEGIN PRIVATE KEY-----" +
System.lineSeparator() +
Expand All @@ -104,6 +110,8 @@ fun PrivateKey.toPEM(): String =

fun PrivateKey.toBase64(): String = String(Base64.getEncoder().encode(PKCS8EncodedKeySpec(this.encoded).encoded))

fun PublicKey.toBase64(): String = encBase64(X509EncodedKeySpec(this.encoded).encoded)

fun PublicKey.toPEM(): String =
"-----BEGIN PUBLIC KEY-----" +
System.lineSeparator() +
Expand All @@ -124,9 +132,6 @@ fun decBase64(base64: String): ByteArray = Base64.getDecoder().decode(base64)

fun toBase64Url(base64: String) = base64.replace("+", "-").replace("/", "_").replace("=", "")


fun PublicKey.toBase64(): String = encBase64(X509EncodedKeySpec(this.encoded).encoded)

fun decodePubKeyBase64(base64: String, kf: KeyFactory): PublicKey =
kf.generatePublic(X509EncodedKeySpec(decBase64(base64)))

Expand Down Expand Up @@ -162,7 +167,7 @@ fun buildKey(
keyId: String,
algorithm: String,
provider: String,
publicPart: String,
publicPart: String?,
privatePart: String?,
format: KeyFormat = KeyFormat.PEM
): Key {
Expand All @@ -174,16 +179,16 @@ fun buildKey(
}
val keyPair = when (format) {
KeyFormat.PEM -> KeyPair(
decodePubKeyPem(publicPart, keyFactory),
privatePart?.let { decodePrivKeyPem(privatePart, keyFactory) })
publicPart?.let { decodePubKeyPem(it, keyFactory) },
privatePart?.let { decodePrivKeyPem(it, keyFactory) })

KeyFormat.BASE64_DER -> KeyPair(
decodePubKeyBase64(publicPart, keyFactory),
privatePart?.let { decodePrivKeyBase64(privatePart, keyFactory) })
publicPart?.let { decodePubKeyBase64(it, keyFactory) },
privatePart?.let { decodePrivKeyBase64(it, keyFactory) })

KeyFormat.BASE64_RAW -> KeyPair(
decodeRawPubKeyBase64(publicPart, keyFactory),
privatePart?.let { decodeRawPrivKey(privatePart, keyFactory) })
publicPart?.let { decodeRawPubKeyBase64(it, keyFactory) },
privatePart?.let { decodeRawPrivKey(it, keyFactory) })
}

return Key(KeyId(keyId), KeyAlgorithm.valueOf(algorithm), CryptoProvider.valueOf(provider), keyPair)
Expand Down
63 changes: 63 additions & 0 deletions src/main/kotlin/id/walt/services/key/import/JwkKeyImport.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
package id.walt.services.key.import

import com.nimbusds.jose.jwk.Curve
import com.nimbusds.jose.jwk.JWK
import com.nimbusds.jose.jwk.KeyType
import id.walt.crypto.*
import id.walt.services.CryptoProvider
import id.walt.services.keystore.KeyStoreService

class JwkKeyImport(private val keyString: String) : KeyImportStrategy {

override fun import(keyStore: KeyStoreService): KeyId {
val key = parseJwkKey(keyString)
keyStore.store(key)
return key.keyId
}

private fun parseJwkKey(jwkKeyStr: String): Key {
val jwk = JWK.parse(jwkKeyStr)

val key = when (jwk.keyType) {
KeyType.RSA -> Key(
keyId = KeyId(jwk.keyID ?: newKeyId().id),
algorithm = KeyAlgorithm.RSA,
cryptoProvider = CryptoProvider.SUN,
keyPair = jwk.toRSAKey().toKeyPair()
)

KeyType.EC -> {
val alg = when (jwk.toECKey().curve) {
Curve.P_256 -> KeyAlgorithm.ECDSA_Secp256r1
Curve.SECP256K1 -> KeyAlgorithm.ECDSA_Secp256k1
else -> throw IllegalArgumentException("EC key with curve ${jwk.toECKey().curve} not suppoerted")
}
Key(
keyId = KeyId(jwk.keyID ?: newKeyId().id),
algorithm = alg,
cryptoProvider = CryptoProvider.SUN,
keyPair = jwk.toECKey().toKeyPair()
)
}

KeyType.OKP -> {
val alg = when (jwk.toOctetKeyPair().curve) {
Curve.Ed25519 -> KeyAlgorithm.EdDSA_Ed25519
else -> throw IllegalArgumentException("OKP key with curve ${jwk.toOctetKeyPair().curve} not supported")
}
buildKey(
keyId = jwk.keyID ?: newKeyId().id,
algorithm = alg.name,
provider = CryptoProvider.SUN.name,
publicPart = jwk.toOctetKeyPair().x.toString(),
privatePart = jwk.toOctetKeyPair().d?.let { jwk.toOctetKeyPair().d.toString() },
format = KeyFormat.BASE64_RAW
)
}

else -> throw IllegalArgumentException("KeyType ${jwk.keyType} / Algorithm ${jwk.algorithm} not supported")
}
return key
}

}
160 changes: 3 additions & 157 deletions src/main/kotlin/id/walt/services/key/import/KeyImportStrategy.kt
Original file line number Diff line number Diff line change
@@ -1,25 +1,7 @@
package id.walt.services.key.import

import com.nimbusds.jose.jwk.Curve
import com.nimbusds.jose.jwk.JWK
import com.nimbusds.jose.jwk.KeyType
import id.walt.crypto.*
import id.walt.services.CryptoProvider
import id.walt.crypto.KeyId
import id.walt.services.keystore.KeyStoreService
import mu.KotlinLogging
import org.bouncycastle.asn1.ASN1ObjectIdentifier
import org.bouncycastle.asn1.pkcs.PKCSObjectIdentifiers
import org.bouncycastle.asn1.pkcs.PrivateKeyInfo
import org.bouncycastle.asn1.x509.SubjectPublicKeyInfo
import org.bouncycastle.openssl.PEMKeyPair
import org.bouncycastle.openssl.PEMParser
import java.io.StringReader
import java.security.KeyFactory
import java.security.KeyPair
import java.security.PrivateKey
import java.security.PublicKey
import java.security.spec.PKCS8EncodedKeySpec
import java.security.spec.X509EncodedKeySpec

interface KeyImportStrategy {
fun import(keyStore: KeyStoreService): KeyId
Expand All @@ -28,146 +10,10 @@ interface KeyImportStrategy {
abstract class KeyImportFactory {
companion object {
fun create(keyString: String) = when (isPEM(keyString)) {
true -> PEMImportImpl(keyString)
false -> JWKImportImpl(keyString)
true -> PemKeyImport(keyString)
false -> JwkKeyImport(keyString)
}

private fun isPEM(keyString: String) = keyString.startsWith("-----")
}
}

class PEMImportImpl(val keyString: String) : KeyImportStrategy {

private val log = KotlinLogging.logger {}

override fun import(keyStore: KeyStoreService) = importPem(keyString, keyStore)

/**
* Imports the given PEM encoded key string
* @param keyStr the key string
*
* - for RSA keys: the PEM private key file
* - for other key types: concatenated public and private key in PEM format
* @return the imported key id
*/
private fun importPem(keyStr: String, keyStore: KeyStoreService): KeyId {
val parser = PEMParser(StringReader(keyStr))
val parsedPemObject = mutableListOf<Any>()
try {
var currentPEMObject: Any?
do {
currentPEMObject = parser.readObject()
log.debug { "PEM parser next object: $currentPEMObject" }
if (currentPEMObject != null) {
parsedPemObject.add(currentPEMObject)
}
} while (currentPEMObject != null)
} catch (e: Exception) {
log.error(e) { "Error while importing PEM key!" }
}

val kid = newKeyId()
val keyPair = getKeyPair(parsedPemObject)
keyStore.store(Key(kid, KeyAlgorithm.fromString(keyPair.public.algorithm), CryptoProvider.SUN, keyPair))

return kid
}

/**
* Parses a keypair out of a one or multiple objects
*/
private fun getKeyPair(objs: List<Any>): KeyPair {
lateinit var pubKey: PublicKey
lateinit var privKey: PrivateKey

objs.toList()

log.debug { "Searching key pair in: $objs" }
for (obj in objs) {
if (obj is SubjectPublicKeyInfo) {
pubKey = getPublicKey(obj)
}
if (obj is PrivateKeyInfo) {
privKey = getPrivateKey(obj)
}
if (obj is PEMKeyPair) {
pubKey = getPublicKey(obj.publicKeyInfo)
privKey = getPrivateKey(obj.privateKeyInfo)
break
}
}
return KeyPair(pubKey, privKey)
}

private fun getPublicKey(key: SubjectPublicKeyInfo): PublicKey {
val kf = getKeyFactory(key.algorithm.algorithm)
return kf.generatePublic(X509EncodedKeySpec(key.encoded))
}

private fun getPrivateKey(key: PrivateKeyInfo): PrivateKey {
val kf = getKeyFactory(key.privateKeyAlgorithm.algorithm)
return kf.generatePrivate(PKCS8EncodedKeySpec(key.encoded))
}

private fun getKeyFactory(alg: ASN1ObjectIdentifier): KeyFactory = when (alg) {
PKCSObjectIdentifiers.rsaEncryption -> KeyFactory.getInstance("RSA")
ASN1ObjectIdentifier("1.3.101.112") -> KeyFactory.getInstance("Ed25519")
ASN1ObjectIdentifier("1.2.840.10045.2.1") -> KeyFactory.getInstance("ECDSA")
else -> throw IllegalArgumentException("Algorithm not supported")
}
}

class JWKImportImpl(val keyString: String) : KeyImportStrategy {

override fun import(keyStore: KeyStoreService): KeyId {
val key = parseJwkKey(keyString)
keyStore.store(key)
return key.keyId
}

private fun parseJwkKey(jwkKeyStr: String): Key {
val jwk = JWK.parse(jwkKeyStr)

val key = when (jwk.keyType) {
KeyType.RSA -> Key(
keyId = KeyId(jwk.keyID ?: newKeyId().id),
algorithm = KeyAlgorithm.RSA,
cryptoProvider = CryptoProvider.SUN,
keyPair = jwk.toRSAKey().toKeyPair()
)

KeyType.EC -> {
val alg = when (jwk.toECKey().curve) {
Curve.P_256 -> KeyAlgorithm.ECDSA_Secp256r1
Curve.SECP256K1 -> KeyAlgorithm.ECDSA_Secp256k1
else -> throw IllegalArgumentException("EC key with curve ${jwk.toECKey().curve} not suppoerted")
}
Key(
keyId = KeyId(jwk.keyID ?: newKeyId().id),
algorithm = alg,
cryptoProvider = CryptoProvider.SUN,
keyPair = jwk.toECKey().toKeyPair()
)
}

KeyType.OKP -> {
val alg = when (jwk.toOctetKeyPair().curve) {
Curve.Ed25519 -> KeyAlgorithm.EdDSA_Ed25519
else -> throw IllegalArgumentException("OKP key with curve ${jwk.toOctetKeyPair().curve} not supported")
}
buildKey(
keyId = jwk.keyID ?: newKeyId().id,
algorithm = alg.name,
provider = CryptoProvider.SUN.name,
publicPart = jwk.toOctetKeyPair().x.toString(),
privatePart = jwk.toOctetKeyPair().d?.let { jwk.toOctetKeyPair().d.toString() },
format = KeyFormat.BASE64_RAW
)
}

else -> throw IllegalArgumentException("KeyType ${jwk.keyType} / Algorithm ${jwk.algorithm} not supported")
}
return key
}

}
Loading

0 comments on commit b8099fc

Please sign in to comment.