Skip to content

Commit

Permalink
#75 mergefix
Browse files Browse the repository at this point in the history
  • Loading branch information
dk1844 committed Nov 6, 2023
1 parent a428fe5 commit 64f509e
Show file tree
Hide file tree
Showing 12 changed files with 137 additions and 100 deletions.
3 changes: 2 additions & 1 deletion service/src/main/resources/example.application.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ loginsvc:
generate-in-memory:
access-exp-time: 15min
refresh-exp-time: 9h
rotation-time: 9h
key-rotation-time: 9h
alg-name: "RS256"
#Instead of generating the key in memory
#The Below Config allows for the application to fetch keys from AWS Secrets Manager.
Expand All @@ -16,6 +16,7 @@ loginsvc:
#private-key-field-name: "privateKey"
#public-key-field-name: "publicKey"
#access-exp-time: 15min
#refresh-exp-time: 9h
#poll-time: 5min
#alg-name: "RS256"
config:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
package za.co.absa.loginsvc.rest.config.jwt

import com.fasterxml.jackson.databind.{JsonNode, ObjectMapper}
import io.jsonwebtoken.SignatureAlgorithm
import org.slf4j.LoggerFactory
import software.amazon.awssdk.auth.credentials.DefaultCredentialsProvider
import software.amazon.awssdk.regions.Region
Expand All @@ -30,20 +29,22 @@ import java.security.{KeyFactory, KeyPair}
import java.security.spec.{PKCS8EncodedKeySpec, X509EncodedKeySpec}
import java.util.Base64
import scala.concurrent.duration.FiniteDuration
import scala.util.{Failure, Success, Try}

case class AwsSecretsManagerKeyConfig (secretName: String,
region: String,
privateKeyFieldName: String,
publicKeyFieldName: String,
algName: String,
accessExpTime: FiniteDuration,
pollTime: Option[FiniteDuration])
extends KeyConfig {

case class AwsSecretsManagerKeyConfig(
secretName: String,
region: String,
privateKeyFieldName: String,
publicKeyFieldName: String,
algName: String,
accessExpTime: FiniteDuration,
refreshExpTime: FiniteDuration,
pollTime: Option[FiniteDuration]
) extends KeyConfig {

private val logger = LoggerFactory.getLogger(classOf[AwsSecretsManagerKeyConfig])

override def refreshKeyTime : Option[FiniteDuration] = pollTime
override def keyRotationTime : Option[FiniteDuration] = pollTime
override def keyPair(): KeyPair = {

val default = DefaultCredentialsProvider.create
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,25 +19,25 @@ package za.co.absa.loginsvc.rest.config.jwt
import io.jsonwebtoken.SignatureAlgorithm
import io.jsonwebtoken.security.Keys
import org.slf4j.LoggerFactory
import za.co.absa.loginsvc.rest.config.validation.ConfigValidationResult.{ConfigValidationError, ConfigValidationSuccess}
import za.co.absa.loginsvc.rest.config.validation.{ConfigValidationException, ConfigValidationResult}

import scala.util.{Failure, Success, Try}
import java.security.KeyPair
import scala.concurrent.duration.FiniteDuration

case class InMemoryKeyConfig (algName: String,
accessExpTime: FiniteDuration,
rotationTime: Option[FiniteDuration])
extends KeyConfig {
case class InMemoryKeyConfig(
algName: String,
accessExpTime: FiniteDuration,
refreshExpTime: FiniteDuration,
keyRotationTime: Option[FiniteDuration]
) extends KeyConfig {

private val logger = LoggerFactory.getLogger(classOf[InMemoryKeyConfig])

override def refreshKeyTime : Option[FiniteDuration] = rotationTime
override def keyPair(): KeyPair = {
logger.info("Generating new keys")
logger.info(s"Generating new keys - every ${keyRotationTime.getOrElse("?")}")
Keys.keyPairFor(SignatureAlgorithm.valueOf(algName))
}

override def throwErrors(): Unit = this.validate().throwOnErrors()

}

Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
package za.co.absa.loginsvc.rest.config.jwt

import io.jsonwebtoken.SignatureAlgorithm
import za.co.absa.loginsvc.rest.config.JwtConfig.{minAccessExpTime, minRefreshExpTime}
import org.slf4j.LoggerFactory
import za.co.absa.loginsvc.rest.config.validation.{ConfigValidatable, ConfigValidationException, ConfigValidationResult}
import za.co.absa.loginsvc.rest.config.validation.ConfigValidationResult.{ConfigValidationError, ConfigValidationSuccess}

Expand All @@ -29,7 +29,8 @@ import scala.util.{Failure, Success, Try}
trait KeyConfig extends ConfigValidatable {
def algName: String
def accessExpTime: FiniteDuration
def refreshKeyTime: Option[FiniteDuration]
def refreshExpTime: FiniteDuration
def keyRotationTime: Option[FiniteDuration]
def keyPair(): KeyPair
def throwErrors(): Unit

Expand All @@ -45,6 +46,8 @@ trait KeyConfig extends ConfigValidatable {
})
}

private val logger = LoggerFactory.getLogger(classOf[KeyConfig])

override def validate(): ConfigValidationResult = {

val algValidation = Try {
Expand All @@ -60,19 +63,24 @@ trait KeyConfig extends ConfigValidatable {
ConfigValidationError(ConfigValidationException(s"accessExpTime must be at least ${KeyConfig.minAccessExpTime}"))
} else ConfigValidationSuccess

val refreshKeyTimeResult = if (refreshKeyTime.nonEmpty && refreshKeyTime.get < KeyConfig.minRefreshKeyTime) {
ConfigValidationError(ConfigValidationException(s"refreshKeyTime must be at least ${KeyConfig.minRefreshKeyTime}"))
val refreshExpTimeResult = if (refreshExpTime < KeyConfig.minRefreshExpTime) {
ConfigValidationError(ConfigValidationException(s"refreshExpTime must be at least ${KeyConfig.minRefreshExpTime}"))
} else ConfigValidationSuccess

val refreshExpTimeResult = if (refreshExpTime < minRefreshExpTime) {
ConfigValidationError(ConfigValidationException(s"refreshExpTime must be at least $minRefreshExpTime"))
val keyRotationTimeResult = if (keyRotationTime.nonEmpty && keyRotationTime.get < KeyConfig.minKeyRotationTime) {
ConfigValidationError(ConfigValidationException(s"keyRotationTime must be at least ${KeyConfig.minKeyRotationTime}"))
} else ConfigValidationSuccess

algValidation.merge(accessExpTimeResult).merge(refreshKeyTimeResult).merge(refreshExpTimeResult)
if (keyRotationTime.isEmpty) {
logger.warn("keyRotationTime is not set in config, key-pair will not be rotated!")
}

algValidation.merge(accessExpTimeResult).merge(refreshExpTimeResult).merge(keyRotationTimeResult)
}
}

object KeyConfig {
val minAccessExpTime: FiniteDuration = FiniteDuration(10, TimeUnit.MILLISECONDS)
val minRefreshKeyTime: FiniteDuration = FiniteDuration(10, TimeUnit.MILLISECONDS)
val minRefreshExpTime: FiniteDuration = FiniteDuration(10, TimeUnit.MILLISECONDS)
val minKeyRotationTime: FiniteDuration = FiniteDuration(10, TimeUnit.MILLISECONDS)
}
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ class ConfigProvider(@Value("${spring.config.location}") yamlPath: String)
getOrElse(BaseConfig(""))
}

def getJWTConfig : KeyConfig = {
def getJwtKeyConfig : KeyConfig = {
val inMemoryKeyConfig: Option[InMemoryKeyConfig] = createConfigClass[InMemoryKeyConfig]("loginsvc.rest.jwt.generate-in-memory")
val awsSecretsManagerKeyConfig: Option[AwsSecretsManagerKeyConfig] = createConfigClass[AwsSecretsManagerKeyConfig]("loginsvc.rest.jwt.aws-secrets-manager")

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,5 +19,5 @@ package za.co.absa.loginsvc.rest.config.provider
import za.co.absa.loginsvc.rest.config.jwt.KeyConfig

trait JwtConfigProvider {
def getJWTConfig : KeyConfig
def getJwtKeyConfig: KeyConfig
}
Original file line number Diff line number Diff line change
Expand Up @@ -46,16 +46,16 @@ class JWTService @Autowired()(jwtConfigProvider: JwtConfigProvider) {
private val logger = LoggerFactory.getLogger(classOf[JWTService])
private val scheduler = Executors.newSingleThreadScheduledExecutor()

private val jwtConfig = jwtConfigProvider.getJWTConfig
private val jwtConfig = jwtConfigProvider.getJwtKeyConfig
@volatile private var keyPair: KeyPair = jwtConfig.keyPair()

if(jwtConfig.refreshKeyTime.nonEmpty)
if(jwtConfig.keyRotationTime.nonEmpty)
{
val refreshTime = jwtConfig.refreshKeyTime.get
val refreshTime = jwtConfig.keyRotationTime.get
scheduleSecretsRefresh(refreshTime)
}

def generateToken(user: User): AccessToken = {
def generateAccessToken(user: User): AccessToken = {
logger.info(s"Generating Token for user: ${user.name}")
import scala.collection.JavaConverters._

Expand All @@ -73,8 +73,8 @@ class JWTService @Autowired()(jwtConfigProvider: JwtConfigProvider) {
.setIssuedAt(issuedAt)
.claim("kid", publicKeyThumbprint)
.claim("groups", groupsClaim)
.applyIfDefined(user.email, (builder, value: String) => builder.claim("email", value))
.applyIfDefined(user.displayName, (builder, value: String) => builder.claim("displayname", value))
.applyIfDefined(user.email)((builder, value: String) => builder.claim("email", value))
.applyIfDefined(user.displayName)((builder, value: String) => builder.claim("displayname", value))
.claim("type", Token.TokenType.Access.toString)
.signWith(keyPair.getPrivate)
.compact()
Expand All @@ -95,7 +95,7 @@ class JWTService @Autowired()(jwtConfigProvider: JwtConfigProvider) {
.setExpiration(expiration)
.setIssuedAt(issuedAt)
.claim("type", Token.TokenType.Refresh.toString)
.signWith(rsaKeyPair.getPrivate)
.signWith(keyPair.getPrivate)
.compact()

RefreshToken(tokenContent)
Expand All @@ -104,7 +104,7 @@ class JWTService @Autowired()(jwtConfigProvider: JwtConfigProvider) {
def refreshTokens(accessToken: AccessToken, refreshToken: RefreshToken): (AccessToken, RefreshToken) = {
val oldAccessJws: Jws[Claims] = Jwts.parserBuilder()
.require("type", Token.TokenType.Access.toString)
.setSigningKey(rsaKeyPair.getPublic)
.setSigningKey(keyPair.getPublic)
.setClock(() => Date.from(Instant.now().minus(jwtConfig.refreshExpTime.toJava))) // allowing expired access token - up to refresh token validity window
.build()
.parseClaimsJws(accessToken.token) // checks requirements: type=access, signature, custom validity window
Expand All @@ -114,7 +114,7 @@ class JWTService @Autowired()(jwtConfigProvider: JwtConfigProvider) {
Jwts.parserBuilder()
.require("type", Token.TokenType.Refresh.toString)
.requireSubject(userFromOldAccessToken.name)
.setSigningKey(rsaKeyPair.getPublic)
.setSigningKey(keyPair.getPublic)
.build()
.parseClaimsJws(refreshToken.token) // checks username, validity, and signature.

Expand Down Expand Up @@ -190,19 +190,12 @@ class JWTService @Autowired()(jwtConfigProvider: JwtConfigProvider) {
}

object JWTService {
// todo remove? not used?
implicit class JwtBuilderExt(val jwtBuilder: JwtBuilder) extends AnyVal {
def applyIfDefined[T](opt: Option[T], fn: (JwtBuilder, T) => JwtBuilder): JwtBuilder = {
OptionExt.applyIfDefined(jwtBuilder, opt, fn)
}
def extractUserFrom(claims: Claims): User = {
val name = claims.getSubject
val groups = claims.get("groups", classOf[java.util.List[String]]).asScala
val email = Option(claims.get("email", classOf[String]))
val displayName = Option(claims.get("displayname", classOf[String]))

def extractUserFrom(claims: Claims): User = {
val name = claims.getSubject
val groups = claims.get("groups", classOf[java.util.List[String]]).asScala
val email = Option(claims.get("email", classOf[String]))
val displayName = Option(claims.get("displayname", classOf[String]))

User(name, email, displayName, groups)
}
User(name, email, displayName, groups)
}
}
4 changes: 2 additions & 2 deletions service/src/test/resources/application.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@ loginsvc:
# Rest General Config
jwt:
generate-in-memory:
refresh-exp-time: 10h
access-exp-time: 15min
rotation-time: 5sec
refresh-exp-time: 10h
key-rotation-time: 5sec
alg-name: "RS256"
config:
some-key: "BETA"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,75 +14,52 @@
* limitations under the License.
*/

package za.co.absa.loginsvc.rest.config
package za.co.absa.loginsvc.rest.config.jwt

import org.scalatest.flatspec.AnyFlatSpec
import org.scalatest.matchers.should.Matchers
import za.co.absa.loginsvc.rest.config.jwt.{AwsSecretsManagerKeyConfig, InMemoryKeyConfig, KeyConfig}
import za.co.absa.loginsvc.rest.config.validation.ConfigValidationException
import za.co.absa.loginsvc.rest.config.validation.ConfigValidationResult.{ConfigValidationError, ConfigValidationSuccess}

import scala.concurrent.duration._

import java.util.concurrent.TimeUnit
import scala.concurrent.duration.FiniteDuration

class JwtConfigTest extends AnyFlatSpec with Matchers {

val inMemoryKeyConfig: InMemoryKeyConfig = InMemoryKeyConfig("RS256",
15.munites,
2.hours,
Option(30.minutes))
class AwsSecretsManagerKeyConfigTest extends AnyFlatSpec with Matchers {

val awsSecretsManagerKeyConfig: AwsSecretsManagerKeyConfig = AwsSecretsManagerKeyConfig("Secret",
"region",
"private",
"public",
"RS256",
FiniteDuration(15, TimeUnit.MINUTES),
FiniteDuration(9, TimeUnit.HOURS),
Option(FiniteDuration(30, TimeUnit.MINUTES)))

"inMemoryKeyConfig" should "validate expected content" in {
inMemoryKeyConfig.validate() shouldBe ConfigValidationSuccess
}

"awsSecretsManagerKeyConfig" should "validate expected content" in {
awsSecretsManagerKeyConfig.validate() shouldBe ConfigValidationSuccess
}

"inMemoryKeyConfig" should "fail on invalid algorithm" in {
inMemoryKeyConfig.copy(algName = "ABC").validate() shouldBe
ConfigValidationError(ConfigValidationException("Invalid algName 'ABC' was given."))
}

"inMemoryKeyConfig" should "fail on non-negative accessExpTime" in {
inMemoryKeyConfig.copy(accessExpTime = FiniteDuration(5, TimeUnit.MILLISECONDS)).validate() shouldBe
ConfigValidationError(ConfigValidationException(s"accessExpTime must be at least ${KeyConfig.minAccessExpTime}"))
}

"inMemoryKeyConfig" should "fail on non-negative refreshExpTime" in {
inMemoryKeyConfig.copy(rotationTime = Option(FiniteDuration(5, TimeUnit.MILLISECONDS))).validate() shouldBe
ConfigValidationError(ConfigValidationException(s"refreshKeyTime must be at least ${KeyConfig.minRefreshKeyTime}"))
}

// todo the third value check?

"awsSecretsManagerKeyConfig" should "fail on invalid algorithm" in {
it should "fail on invalid algorithm" in {
awsSecretsManagerKeyConfig.copy(algName = "ABC").validate() shouldBe
ConfigValidationError(ConfigValidationException("Invalid algName 'ABC' was given."))
}

"awsSecretsManagerKeyConfig" should "fail on non-negative accessExpTime" in {
it should "fail on non-negative accessExpTime" in {
awsSecretsManagerKeyConfig.copy(accessExpTime = FiniteDuration(5, TimeUnit.MILLISECONDS)).validate() shouldBe
ConfigValidationError(ConfigValidationException(s"accessExpTime must be at least ${KeyConfig.minAccessExpTime}"))
}

"awsSecretsManagerKeyConfig" should "fail on non-negative refreshExpTime" in {
it should "fail on non-negative refreshExpTime" in {
awsSecretsManagerKeyConfig.copy(refreshExpTime = FiniteDuration(5, TimeUnit.MILLISECONDS)).validate() shouldBe
ConfigValidationError(ConfigValidationException(s"refreshExpTime must be at least ${KeyConfig.minRefreshExpTime}"))
}

it should "fail on non-negative keyRotationTime" in {
awsSecretsManagerKeyConfig.copy(pollTime = Option(FiniteDuration(5, TimeUnit.MILLISECONDS))).validate() shouldBe
ConfigValidationError(ConfigValidationException(s"refreshKeyTime must be at least ${KeyConfig.minRefreshKeyTime}"))
ConfigValidationError(ConfigValidationException(s"keyRotationTime must be at least ${KeyConfig.minKeyRotationTime}"))
}

"awsSecretsManagerKeyConfig" should "fail on missing value" in {
it should "fail on missing value" in {
awsSecretsManagerKeyConfig.copy(secretName = null).validate() shouldBe
ConfigValidationError(ConfigValidationException("secretName is empty"))
}
Expand Down
Loading

0 comments on commit 64f509e

Please sign in to comment.