diff --git a/service/src/main/scala/za/co/absa/loginsvc/rest/config/JwtConfig.scala b/service/src/main/scala/za/co/absa/loginsvc/rest/config/JwtConfig.scala index 0b7bb47..c98f798 100644 --- a/service/src/main/scala/za/co/absa/loginsvc/rest/config/JwtConfig.scala +++ b/service/src/main/scala/za/co/absa/loginsvc/rest/config/JwtConfig.scala @@ -117,7 +117,7 @@ case class FetchSecretConfig( val resultsMerge = results.foldLeft[ConfigValidationResult](ConfigValidationSuccess)(ConfigValidationResult.merge) val TimeResult = if (refreshTime < 1) { - ConfigValidationError(ConfigValidationException("refreshTime must be positive (hours)")) + ConfigValidationError(ConfigValidationException("refreshTime must be positive (minutes)")) } else ConfigValidationSuccess resultsMerge.merge(TimeResult) diff --git a/service/src/main/scala/za/co/absa/loginsvc/rest/service/JWTService.scala b/service/src/main/scala/za/co/absa/loginsvc/rest/service/JWTService.scala index 1f48da8..eb39e51 100644 --- a/service/src/main/scala/za/co/absa/loginsvc/rest/service/JWTService.scala +++ b/service/src/main/scala/za/co/absa/loginsvc/rest/service/JWTService.scala @@ -156,7 +156,7 @@ class JWTService @Autowired()(jwtConfigProvider: JwtConfigProvider) { private def refreshSecrets(secretConfig: FetchSecretConfig): Unit = { val scheduler = Executors.newSingleThreadScheduledExecutor() - scheduler.scheduleAtFixedRate(() => { + val scheduledFuture = scheduler.scheduleAtFixedRate(() => { try this.config = fetchSecrets(secretConfig) catch { case e: Throwable => @@ -167,6 +167,24 @@ class JWTService @Autowired()(jwtConfigProvider: JwtConfigProvider) { secretConfig.refreshTime, TimeUnit.MINUTES ) + + Runtime.getRuntime.addShutdownHook(new Thread(() => { + + scheduledFuture.cancel(false) + scheduler.shutdown() + + try { + // Wait for up to 5 seconds for the scheduler to terminate + if (!scheduler.awaitTermination(5, TimeUnit.SECONDS)) { + // If it doesn't terminate, forcefully shut it down + scheduler.shutdownNow() + } + } + catch { + case e: InterruptedException => + Thread.currentThread().interrupt() + } + })) } private case class keyConfig (keyPair: KeyPair, expTime: Int, algName: String) diff --git a/service/src/test/scala/za/co/absa/loginsvc/rest/config/JwtConfigTest.scala b/service/src/test/scala/za/co/absa/loginsvc/rest/config/JwtConfigTest.scala index 6c69036..a920f7a 100644 --- a/service/src/test/scala/za/co/absa/loginsvc/rest/config/JwtConfigTest.scala +++ b/service/src/test/scala/za/co/absa/loginsvc/rest/config/JwtConfigTest.scala @@ -24,8 +24,12 @@ import za.co.absa.loginsvc.rest.config.validation.ConfigValidationResult.{Config class JwtConfigTest extends AnyFlatSpec with Matchers { - val jwtConfig: JwtConfig = JwtConfig(Option(GenerateKeysConfig("RS256", 2)), None) + val jwtConfig: JwtConfig = JwtConfig( + Option(GenerateKeysConfig("RS256", 2)), + Option(FetchSecretConfig("Secret", "region", "private", "public", "exp", "alg", 30)) + ) val inMemoryConfig: GenerateKeysConfig = jwtConfig.generateInMemory.get + val awsSecretConfig: FetchSecretConfig = jwtConfig.fetchFromAws.get "JwtConfig" should "validate expected content" in { jwtConfig.validate() shouldBe ConfigValidationSuccess @@ -35,15 +39,28 @@ class JwtConfigTest extends AnyFlatSpec with Matchers { inMemoryConfig.validate() shouldBe ConfigValidationSuccess } + "awsSecretConfig" should "validate expected content" in { + awsSecretConfig.validate() shouldBe ConfigValidationSuccess + } + it should "fail on invalid algorithm" in { inMemoryConfig.copy(algName = "ABC").validate() shouldBe ConfigValidationError(ConfigValidationException("Invalid algName 'ABC' was given.")) } - it should "fail on non-negative expTime" in { + "inMemoryConfig" should "fail on non-negative expTime" in { inMemoryConfig.copy(expTime = -7).validate() shouldBe ConfigValidationError(ConfigValidationException("expTime must be positive (hours)")) } + "awsSecretConfig" should "fail on missing value" in { + awsSecretConfig.copy(secretName = null).validate() shouldBe + ConfigValidationError(ConfigValidationException("secretName is empty")) + } + + "awsSecretConfig" should "fail on non-negative refreshTime" in { + awsSecretConfig.copy(refreshTime = -7).validate() shouldBe + ConfigValidationError(ConfigValidationException("refreshTime must be positive (minutes)")) + } }