diff --git a/src/main/scala/com/snowflake/snowpark/internal/ParameterUtils.scala b/src/main/scala/com/snowflake/snowpark/internal/ParameterUtils.scala index 196b876c..adb9fd06 100644 --- a/src/main/scala/com/snowflake/snowpark/internal/ParameterUtils.scala +++ b/src/main/scala/com/snowflake/snowpark/internal/ParameterUtils.scala @@ -2,7 +2,7 @@ package com.snowflake.snowpark.internal import net.snowflake.client.core.SFSessionProperty -import java.security.spec.RSAPrivateCrtKeySpec +import java.security.spec.{PKCS8EncodedKeySpec, RSAPrivateCrtKeySpec} import java.security.{GeneralSecurityException, KeyFactory, PrivateKey} import java.util.Properties import org.apache.commons.codec.binary.Base64 @@ -104,38 +104,55 @@ private[snowpark] object ParameterUtils extends Logging { } private[snowpark] def parsePrivateKey(key: String): PrivateKey = { + // try to parse pkcs#8 format first, + // if it fails, then try to parse pkcs#1 format. try { val decoded = Base64.decodeBase64(key) - val derReader = new DerInputStream(decoded) - val seq = derReader.getSequence(0) - - if (seq.length < 9) { - throw new GeneralSecurityException("Could not parse a PKCS1 private key.") - } - - // seq(0) is version, skip - val modulus = seq(1).getBigInteger - val publicExp = seq(2).getBigInteger - val privateExp = seq(3).getBigInteger - val prime1 = seq(4).getBigInteger - val prime2 = seq(5).getBigInteger - val exp1 = seq(6).getBigInteger - val exp2 = seq(7).getBigInteger - val crtCoef = seq(8).getBigInteger - val keySpec = new RSAPrivateCrtKeySpec( - modulus, - publicExp, - privateExp, - prime1, - prime2, - exp1, - exp2, - crtCoef) - val keyFactory = KeyFactory.getInstance("RSA") - keyFactory.generatePrivate(keySpec) + val kf = KeyFactory.getInstance("RSA") + val keySpec = new PKCS8EncodedKeySpec(decoded) + kf.generatePrivate(keySpec) } catch { - case e: Exception => - throw ErrorMessage.MISC_INVALID_RSA_PRIVATE_KEY(e.getMessage) + case pkcs8Exception: Exception => + // try to read PKCS#1 key + try { + val decoded = Base64.decodeBase64(key) + val derReader = new DerInputStream(decoded) + val seq = derReader.getSequence(0) + + if (seq.length < 9) { + throw new GeneralSecurityException("Could not parse a PKCS1 private key.") + } + + // seq(0) is version, skip + val modulus = seq(1).getBigInteger + val publicExp = seq(2).getBigInteger + val privateExp = seq(3).getBigInteger + val prime1 = seq(4).getBigInteger + val prime2 = seq(5).getBigInteger + val exp1 = seq(6).getBigInteger + val exp2 = seq(7).getBigInteger + val crtCoef = seq(8).getBigInteger + val keySpec = new RSAPrivateCrtKeySpec( + modulus, + publicExp, + privateExp, + prime1, + prime2, + exp1, + exp2, + crtCoef) + val keyFactory = KeyFactory.getInstance("RSA") + keyFactory.generatePrivate(keySpec) + } catch { + case pkcs1Exception: Exception => + val errorMessage = + s"""Failed to parse PKCS#8 RSA Private key + |${pkcs8Exception.getMessage} + |Failed to parse PKCS#1 RSA Private key + |${pkcs1Exception.getMessage} + |""".stripMargin + throw ErrorMessage.MISC_INVALID_RSA_PRIVATE_KEY(errorMessage) + } } } diff --git a/src/test/scala/com/snowflake/snowpark/ParameterSuite.scala b/src/test/scala/com/snowflake/snowpark/ParameterSuite.scala index d1814472..8ee5dddd 100644 --- a/src/test/scala/com/snowflake/snowpark/ParameterSuite.scala +++ b/src/test/scala/com/snowflake/snowpark/ParameterSuite.scala @@ -3,6 +3,10 @@ package com.snowflake.snowpark import com.snowflake.snowpark.internal.{ParameterUtils, ServerConnection} import net.snowflake.client.core.SFSessionProperty +import java.security.KeyPairGenerator +import java.security.spec.PKCS8EncodedKeySpec +import java.util.Base64 + class ParameterSuite extends SNTestBase { val options: Map[String, String] = Session.loadConfFromFile(defaultProfile) @@ -61,8 +65,26 @@ class ParameterSuite extends SNTestBase { // scalastyle:on } - assertThrows[Exception]( + val ex = intercept[SnowparkClientException] { ParameterUtils - .jdbcConfig(optionWithoutKey + ("privatekey" -> "wrong key"), isScalaAPI = true)) + .jdbcConfig(optionWithoutKey + ("privatekey" -> "wrong key"), isScalaAPI = true) + } + assert(ex.message.contains("Failed to parse PKCS#8 RSA Private key")) + assert(ex.message.contains("Failed to parse PKCS#1 RSA Private key")) + } + + test("enable to read PKCS#8 private keys") { + // no need to verify PKCS#1 format key additionally, + // since all Github Action tests use PKCS#1 key to authenticate with Snowflake server. + ParameterUtils.parsePrivateKey(generatePKCS8Key()) + } + + private def generatePKCS8Key(): String = { + val keyPairGenerator = KeyPairGenerator.getInstance("RSA") + keyPairGenerator.initialize(2048) + val keyPair = keyPairGenerator.generateKeyPair() + val privateKey = keyPair.getPrivate + val encodedKeySpec = new PKCS8EncodedKeySpec(privateKey.getEncoded) + Base64.getEncoder.encodeToString(encodedKeySpec.getEncoded) } }