diff --git a/pan-domain-auth-core/src/main/scala/com/gu/pandomainauth/PanDomainAuthSettingsRefresher.scala b/pan-domain-auth-core/src/main/scala/com/gu/pandomainauth/PanDomainAuthSettingsRefresher.scala index f7d9928..79d512c 100644 --- a/pan-domain-auth-core/src/main/scala/com/gu/pandomainauth/PanDomainAuthSettingsRefresher.scala +++ b/pan-domain-auth-core/src/main/scala/com/gu/pandomainauth/PanDomainAuthSettingsRefresher.scala @@ -1,12 +1,10 @@ package com.gu.pandomainauth -import java.util.concurrent.atomic.AtomicReference -import java.util.concurrent.{Executors, ScheduledExecutorService, TimeUnit} - -import com.amazonaws.services.s3.AmazonS3 import com.gu.pandomainauth.model.PanDomainAuthSettings +import com.gu.pandomainauth.service.CryptoConf import org.slf4j.LoggerFactory +import java.util.concurrent.{Executors, ScheduledExecutorService} import scala.language.postfixOps /** @@ -14,46 +12,25 @@ import scala.language.postfixOps * * @param domain the domain you are authenticating against * @param system the identifier for your app, typically the same as the subdomain your app runs on - * @param bucketName the bucket where the settings are stored - * @param settingsFileKey the name of the file that contains the private settings for the given domain - * @param s3Client the AWS S3 client that will be used to download the settings from the bucket * @param scheduler optional scheduler that will be used to run the code that updates the bucket */ class PanDomainAuthSettingsRefresher( val domain: String, val system: String, - val bucketName: String, settingsFileKey: String, - val s3Client: AmazonS3, + val s3BucketLoader: S3BucketLoader, scheduler: ScheduledExecutorService = Executors.newScheduledThreadPool(1) ) { private val logger = LoggerFactory.getLogger(this.getClass) - // This is deliberately designed to throw an exception during construction if we cannot immediately read the settings - private val authSettings: AtomicReference[PanDomainAuthSettings] = new AtomicReference[PanDomainAuthSettings](loadSettings() match { - case Right(settings) => PanDomainAuthSettings(settings) - case Left(err) => throw Settings.errorToThrowable(err) - }) - - scheduler.scheduleAtFixedRate(() => refresh(), 1, 1, TimeUnit.MINUTES) - - def settings: PanDomainAuthSettings = authSettings.get() - - private def loadSettings(): Either[SettingsFailure, Map[String, String]] = { - Settings.fetchSettings(settingsFileKey, bucketName, s3Client).flatMap(Settings.extractSettings) - } - - private def refresh(): Unit = { - loadSettings() match { - case Right(settings) => - logger.debug(s"Updated pan-domain settings for $domain") - authSettings.set(PanDomainAuthSettings(settings)) + private val settingsRefresher = new Settings.Refresher[PanDomainAuthSettings]( + new Settings.Loader(s3BucketLoader, settingsFileKey), + PanDomainAuthSettings.extractFrom, + scheduler + ) + settingsRefresher.start(1) - case Left(err) => - logger.error(s"Failed to update pan-domain settings for $domain") - Settings.logError(err, logger) - } - } + def settings: PanDomainAuthSettings = settingsRefresher.get() } diff --git a/pan-domain-auth-core/src/main/scala/com/gu/pandomainauth/model/PanDomainAuthSettings.scala b/pan-domain-auth-core/src/main/scala/com/gu/pandomainauth/model/PanDomainAuthSettings.scala index 8ef55f7..a2a839b 100644 --- a/pan-domain-auth-core/src/main/scala/com/gu/pandomainauth/model/PanDomainAuthSettings.scala +++ b/pan-domain-auth-core/src/main/scala/com/gu/pandomainauth/model/PanDomainAuthSettings.scala @@ -1,8 +1,7 @@ package com.gu.pandomainauth.model -import com.gu.pandomainauth.service.Crypto - -import java.security.KeyPair +import com.gu.pandomainauth.SettingsFailure.SettingsResult +import com.gu.pandomainauth.service.{CryptoConf, KeyPair} case class PanDomainAuthSettings( signingKeyPair: KeyPair, @@ -32,7 +31,7 @@ case class Google2FAGroupSettings( object PanDomainAuthSettings{ private val legacyCookieNameSetting = "assymCookieName" - def apply(settingMap: Map[String, String]): PanDomainAuthSettings = { + def extractFrom(settingMap: Map[String, String]): SettingsResult[PanDomainAuthSettings] = { val cookieSettings = CookieSettings( cookieName = settingMap.getOrElse(legacyCookieNameSetting, settingMap("cookieName")) ) @@ -49,12 +48,12 @@ object PanDomainAuthSettings{ serviceAccountCert <- settingMap.get("googleServiceAccountCert"); adminUser <- settingMap.get("google2faUser"); group <- settingMap.get("multifactorGroupId") - ) yield { - Google2FAGroupSettings(serviceAccountId, serviceAccountCert, adminUser, group) - } + ) yield Google2FAGroupSettings(serviceAccountId, serviceAccountCert, adminUser, group) - PanDomainAuthSettings( - Crypto.keyPairFrom(settingMap), + for { + activeKeyPair <- CryptoConf.SettingsReader(settingMap).activeKeyPair + } yield PanDomainAuthSettings( + activeKeyPair, cookieSettings, oAuthSettings, google2faSettings diff --git a/pan-domain-auth-core/src/main/scala/com/gu/pandomainauth/service/Google2FAGroupChecker.scala b/pan-domain-auth-core/src/main/scala/com/gu/pandomainauth/service/Google2FAGroupChecker.scala index 262577c..831d515 100644 --- a/pan-domain-auth-core/src/main/scala/com/gu/pandomainauth/service/Google2FAGroupChecker.scala +++ b/pan-domain-auth-core/src/main/scala/com/gu/pandomainauth/service/Google2FAGroupChecker.scala @@ -1,21 +1,20 @@ package com.gu.pandomainauth.service -import com.amazonaws.services.s3.AmazonS3 import com.google.api.client.googleapis.javanet.GoogleNetHttpTransport import com.google.api.client.googleapis.json.GoogleJsonResponseException import com.google.api.client.json.gson.GsonFactory import com.google.api.client.util.SecurityUtils -import com.google.api.services.directory.Directory import com.google.api.services.directory.model.Groups -import com.google.api.services.directory.DirectoryScopes +import com.google.api.services.directory.{Directory, DirectoryScopes} import com.google.auth.http.HttpCredentialsAdapter import com.google.auth.oauth2.ServiceAccountCredentials - -import scala.jdk.CollectionConverters._ +import com.gu.pandomainauth.S3BucketLoader import com.gu.pandomainauth.model.{AuthenticatedUser, Google2FAGroupSettings} import org.slf4j.LoggerFactory -class GroupChecker(config: Google2FAGroupSettings, bucketName: String, s3Client: AmazonS3, appName: String) { +import scala.jdk.CollectionConverters._ + +class GroupChecker(config: Google2FAGroupSettings, s3BucketLoader: S3BucketLoader, appName: String) { private val logger = LoggerFactory.getLogger(this.getClass) private val transport = GoogleNetHttpTransport.newTrustedTransport() @@ -36,14 +35,13 @@ class GroupChecker(config: Google2FAGroupSettings, bucketName: String, s3Client: .build private def loadServiceAccountPrivateKey = { - val certInputStream = s3Client.getObject(bucketName, config.serviceAccountCert).getObjectContent val serviceAccountPrivateKey = SecurityUtils.loadPrivateKeyFromKeyStore( SecurityUtils.getPkcs12KeyStore, - certInputStream, + s3BucketLoader.inputStreamFetching(config.serviceAccountCert), "notasecret", "privatekey", "notasecret" ) - try { certInputStream.close() } catch { case _ : Throwable => } + try { s3BucketLoader.inputStreamFetching(config.serviceAccountCert).close() } catch { case _ : Throwable => } serviceAccountPrivateKey } @@ -72,11 +70,11 @@ class GroupChecker(config: Google2FAGroupSettings, bucketName: String, s3Client: private def hasMoreGroups(groupsResponse: Groups): Boolean = { val token = groupsResponse.getNextPageToken - token != null && token.length > 0 + token != null && token.nonEmpty } } -class GoogleGroupChecker(config: Google2FAGroupSettings, bucketName: String, s3Client: AmazonS3, appName: String) extends GroupChecker(config, bucketName, s3Client, appName) { +class GoogleGroupChecker(config: Google2FAGroupSettings, s3BucketLoader: S3BucketLoader, appName: String) extends GroupChecker(config, s3BucketLoader, appName) { def checkGroups(authenticatedUser: AuthenticatedUser, groupIds: List[String]): Either[String, Boolean] = { val query = directory.groups().list().setUserKey(authenticatedUser.user.email) @@ -86,10 +84,9 @@ class GoogleGroupChecker(config: Google2FAGroupSettings, bucketName: String, s3C } -class Google2FAGroupChecker(config: Google2FAGroupSettings, bucketName: String, s3Client: AmazonS3, appName: String) extends GroupChecker(config, bucketName, s3Client, appName) { +class Google2FAGroupChecker(config: Google2FAGroupSettings, s3BucketLoader: S3BucketLoader, appName: String) extends GroupChecker(config, s3BucketLoader, appName) { - def checkMultifactor(authenticatedUser: AuthenticatedUser): Boolean = { + def checkMultifactor(authenticatedUser: AuthenticatedUser): Boolean = hasGroup(authenticatedUser.user.email, config.multifactorGroupId) - } } diff --git a/pan-domain-auth-example/app/VerifyExample.scala b/pan-domain-auth-example/app/VerifyExample.scala index eb18271..75a0081 100644 --- a/pan-domain-auth-example/app/VerifyExample.scala +++ b/pan-domain-auth-example/app/VerifyExample.scala @@ -1,8 +1,9 @@ import com.amazonaws.auth.DefaultAWSCredentialsProviderChain import com.amazonaws.regions.Regions import com.amazonaws.services.s3.AmazonS3ClientBuilder +import com.gu.pandomainauth.S3BucketLoader.forAwsSdkV1 import com.gu.pandomainauth.model.{Authenticated, AuthenticatedUser, GracePeriod} -import com.gu.pandomainauth.{PanDomain, PublicSettings} +import com.gu.pandomainauth.{PanDomain, PublicSettings, Settings} object VerifyExample { // Change this to point to the S3 bucket and key for the settings file @@ -14,16 +15,13 @@ object VerifyExample { val credentials = DefaultAWSCredentialsProviderChain.getInstance() val s3Client = AmazonS3ClientBuilder.standard().withRegion(region).withCredentials(credentials).build() - val publicSettings = new PublicSettings(settingsFileKey, bucketName, s3Client) + val loader = new Settings.Loader(forAwsSdkV1(s3Client, bucketName), settingsFileKey) + val publicSettings = new PublicSettings(loader) // Call the start method when your application starts up to ensure the settings are kept up to date publicSettings.start() - // You can integrate with your own scheduler by calling refresh() which will synchronously update the settings - publicSettings.refresh() - - // `publicKey` will return None if a value has not been successfully obtained - val publicKey = publicSettings.publicKey.get + val publicKey = publicSettings.publicKey // The name of this particular application val system = "test" diff --git a/pan-domain-auth-example/app/di.scala b/pan-domain-auth-example/app/di.scala index 2f38870..ee50db0 100644 --- a/pan-domain-auth-example/app/di.scala +++ b/pan-domain-auth-example/app/di.scala @@ -3,6 +3,7 @@ import com.amazonaws.auth.{AWSCredentialsProviderChain, DefaultAWSCredentialsPro import com.amazonaws.regions.Regions import com.amazonaws.services.s3.AmazonS3ClientBuilder import com.gu.pandomainauth.PanDomainAuthSettingsRefresher +import com.gu.pandomainauth.S3BucketLoader.forAwsSdkV1 import controllers.AdminController import play.api.ApplicationLoader.Context import play.api.libs.ws.ahc.AhcWSComponents @@ -37,9 +38,8 @@ class AppComponents(context: Context) extends BuiltInComponentsFromContext(conte val panDomainSettings = new PanDomainAuthSettingsRefresher( domain = "local.dev-gutools.co.uk", system = "example", - bucketName = bucketName, settingsFileKey = "local.dev-gutools.co.uk.settings", - s3Client = s3Client + s3BucketLoader = forAwsSdkV1(s3Client, bucketName) ) val controller = new AdminController(controllerComponents, configuration, wsClient, panDomainSettings) diff --git a/pan-domain-auth-play/src/main/scala/com/gu/pandomainauth/action/Actions.scala b/pan-domain-auth-play/src/main/scala/com/gu/pandomainauth/action/Actions.scala index 611cd86..403b75c 100644 --- a/pan-domain-auth-play/src/main/scala/com/gu/pandomainauth/action/Actions.scala +++ b/pan-domain-auth-play/src/main/scala/com/gu/pandomainauth/action/Actions.scala @@ -90,7 +90,7 @@ trait AuthActions { val applicationName: String = s"pan-domain-authentication-$system" val multifactorChecker: Option[Google2FAGroupChecker] = settings.google2FAGroupSettings.map { - new Google2FAGroupChecker(_, panDomainSettings.bucketName, panDomainSettings.s3Client, applicationName) + new Google2FAGroupChecker(_, panDomainSettings.s3BucketLoader, applicationName) } /** @@ -198,7 +198,7 @@ trait AuthActions { } def readAuthenticatedUser(request: RequestHeader): Option[AuthenticatedUser] = readCookie(request) map { cookie => - CookieUtils.parseCookieData(cookie.cookie.value, settings.signingKeyPair.getPublic) + CookieUtils.parseCookieData(cookie.cookie.value, settings.signingKeyPair.publicKey) } def readCookie(request: RequestHeader): Option[PandomainCookie] = { @@ -211,7 +211,7 @@ trait AuthActions { def generateCookie(authedUser: AuthenticatedUser): Cookie = Cookie( name = settings.cookieSettings.cookieName, - value = CookieUtils.generateCookieData(authedUser, settings.signingKeyPair.getPrivate), + value = CookieUtils.generateCookieData(authedUser, settings.signingKeyPair.privateKey), domain = Some(domain), secure = true, httpOnly = true @@ -237,7 +237,7 @@ trait AuthActions { */ def extractAuth(request: RequestHeader): AuthenticationStatus = { readCookie(request).map { cookie => - PanDomain.authStatus(cookie.cookie.value, settings.signingKeyPair.getPublic, validateUser, apiGracePeriod, system, cacheValidation, cookie.forceExpiry) + PanDomain.authStatus(cookie.cookie.value, settings.signingKeyPair.publicKey, validateUser, apiGracePeriod, system, cacheValidation, cookie.forceExpiry) } getOrElse NotAuthenticated } diff --git a/pan-domain-auth-verification/src/main/scala/com/gu/pandomainauth/PublicSettings.scala b/pan-domain-auth-verification/src/main/scala/com/gu/pandomainauth/PublicSettings.scala index 0bd96df..5a1b84a 100644 --- a/pan-domain-auth-verification/src/main/scala/com/gu/pandomainauth/PublicSettings.scala +++ b/pan-domain-auth-verification/src/main/scala/com/gu/pandomainauth/PublicSettings.scala @@ -1,51 +1,30 @@ package com.gu.pandomainauth -import java.util.concurrent.atomic.AtomicReference -import java.util.concurrent.{Executors, ScheduledExecutorService, TimeUnit} -import com.amazonaws.services.s3.AmazonS3 -import com.gu.pandomainauth.PublicSettings.validateAndParseKeyText -import com.gu.pandomainauth.service.Crypto -import org.slf4j.LoggerFactory +import com.gu.pandomainauth.SettingsFailure.SettingsResult +import com.gu.pandomainauth.service.CryptoConf import java.security.PublicKey -import java.util.regex.Pattern -import scala.concurrent.ExecutionContext +import java.util.concurrent.{Executors, ScheduledExecutorService} import scala.concurrent.duration._ /** * Class that contains the static public settings and includes mechanism for fetching the public key. Once you have an * instance, call the `start()` method to load the public data. - * - * @param settingsFileKey the settings file for the domain in the S3 bucket (eg local.dev.gutools.co.uk.public.settings) - * @param bucketName the name of the S3 bucket (eg pan-domain-auth-settings) - * @param s3Client the AWS S3 client that will be used to download the settings from the bucket - * @param scheduler optional scheduler that will be used to run the code that updates the bucket + * + * @param scheduler optional scheduler that will be used to run the code that updates the bucket */ -class PublicSettings(settingsFileKey: String, bucketName: String, s3Client: AmazonS3, +class PublicSettings(loader: Settings.Loader, scheduler: ScheduledExecutorService = Executors.newScheduledThreadPool(1)) { - private val agent = new AtomicReference[Option[PublicKey]](None) + private val settingsRefresher = new Settings.Refresher[PublicKey]( + loader, + CryptoConf.SettingsReader(_).activePublicKey, + scheduler + ) - private val logger = LoggerFactory.getLogger(this.getClass) - implicit private val executionContext: ExecutionContext = ExecutionContext.fromExecutor(scheduler) + def start(interval: FiniteDuration = 60.seconds): Unit = settingsRefresher.start(interval.toMinutes.toInt) - def start(interval: FiniteDuration = 60.seconds): Unit = { - scheduler.scheduleAtFixedRate(() => refresh(), 0, interval.toMillis, TimeUnit.MILLISECONDS) - } - - def refresh(): Unit = { - PublicSettings.getPublicKey(settingsFileKey, bucketName, s3Client) match { - case Right(publicKey) => - agent.set(Some(publicKey)) - logger.debug("Successfully updated pan-domain public settings") - - case Left(err) => - logger.error("Failed to update pan-domain public settings") - Settings.logError(err, logger) - } - } - - def publicKey: Option[PublicKey] = agent.get() + def publicKey: PublicKey = settingsRefresher.get() } /** @@ -59,18 +38,7 @@ object PublicSettings { * Fetches the public key from the public S3 bucket * * @param domain the domain to fetch the public key for - * @param client implicit dispatch.Http to use for fetching the key - * @param ec implicit execution context to use for fetching the key */ - def getPublicKey(settingsFileKey: String, bucketName: String, s3Client: AmazonS3): Either[SettingsFailure, PublicKey] = { - fetchSettings(settingsFileKey, bucketName, s3Client) flatMap extractSettings flatMap extractPublicKey - } - - private[pandomainauth] def extractPublicKey(settings: Map[String, String]): Either[SettingsFailure, PublicKey] = - settings.get("publicKey").toRight(PublicKeyNotFoundFailure).flatMap(validateAndParseKeyText) - - private val KeyPattern: Pattern = "[a-zA-Z0-9+/\n]+={0,3}".r.pattern - - private[pandomainauth] def validateAndParseKeyText(pubKeyText: String): Either[SettingsFailure, PublicKey] = - Either.cond(KeyPattern.matcher(pubKeyText).matches, Crypto.publicKeyFor(pubKeyText), PublicKeyFormatFailure) + def getPublicKey(loader: Loader): SettingsResult[PublicKey] = + loader.loadAndParseSettingsMap().flatMap(CryptoConf.SettingsReader(_).activePublicKey) } diff --git a/pan-domain-auth-verification/src/main/scala/com/gu/pandomainauth/S3BucketLoader.scala b/pan-domain-auth-verification/src/main/scala/com/gu/pandomainauth/S3BucketLoader.scala new file mode 100644 index 0000000..fd6154a --- /dev/null +++ b/pan-domain-auth-verification/src/main/scala/com/gu/pandomainauth/S3BucketLoader.scala @@ -0,0 +1,25 @@ +package com.gu.pandomainauth + +import com.amazonaws.services.s3.AmazonS3 + +import java.io.InputStream + +/** + * This trait provides a way to download a file from an S3 bucket, in a way that's agnostic of which + * AWS SDK (v1 or v2) is being used. An instance of S3BucketLoader is *specific* to a particular S3 bucket. + */ +trait S3BucketLoader { + /** + * @param key the key of the file in the S3 bucket, not including the bucket name or a starting "/" + */ + def inputStreamFetching(key: String): InputStream +} + +object S3BucketLoader { + /** + * A convenience method to create an S3BucketLoader using AWS SDK v1, the version used by most of our existing code. + * However, codebases that want to use AWS SDK v2 are able to provide their own implementation of S3BucketLoader. + */ + def forAwsSdkV1(s3Client: AmazonS3, bucketName: String): S3BucketLoader = + s3Client.getObject(bucketName, _).getObjectContent +} diff --git a/pan-domain-auth-verification/src/main/scala/com/gu/pandomainauth/Settings.scala b/pan-domain-auth-verification/src/main/scala/com/gu/pandomainauth/Settings.scala index 0e4f052..ccf848e 100644 --- a/pan-domain-auth-verification/src/main/scala/com/gu/pandomainauth/Settings.scala +++ b/pan-domain-auth-verification/src/main/scala/com/gu/pandomainauth/Settings.scala @@ -1,66 +1,105 @@ package com.gu.pandomainauth +import com.amazonaws.util.IOUtils +import com.gu.pandomainauth.SettingsFailure.SettingsResult +import org.slf4j.{Logger, LoggerFactory} + import java.io.ByteArrayInputStream import java.util.Properties +import java.util.concurrent.TimeUnit.MINUTES +import java.util.concurrent.atomic.AtomicReference +import java.util.concurrent.{Executors, ScheduledExecutorService} +import scala.jdk.CollectionConverters._ +import scala.util.control.NonFatal -import com.amazonaws.services.s3.AmazonS3 -import com.amazonaws.util.IOUtils -import org.slf4j.Logger +sealed trait SettingsFailure { + val description: String -import scala.util.control.NonFatal -import scala.jdk.CollectionConverters._ + def logError(logger: Logger): Unit = logger.error(description) + + def asThrowable(): Throwable = new IllegalStateException(description) +} + +trait FailureWithCause extends SettingsFailure { + val cause: Throwable + + override def logError(logger: Logger): Unit = logger.error(description, cause) + + override def asThrowable(): Throwable = new IllegalStateException(description, cause) +} + +case class SettingsDownloadFailure(cause: Throwable) extends FailureWithCause { + override val description: String = "Unable to download public key" +} + +case class MissingSetting(name: String) extends SettingsFailure { + override val description: String = s"Key '$name' not found in settings file" +} + +case class SettingsParseFailure(cause: Throwable) extends FailureWithCause { + override val description: String = "Unable to parse public key" +} + +case object PublicKeyFormatFailure extends SettingsFailure { + override val description: String = "Public key does not match expected format" +} + +case object InvalidBase64 extends SettingsFailure { + override val description: String = "Settings file value for cryptographic key is not valid base64" +} -sealed trait SettingsFailure -case class SettingsDownloadFailure(cause: Throwable) extends SettingsFailure -case class SettingsParseFailure(cause: Throwable) extends SettingsFailure -case object PublicKeyFormatFailure extends SettingsFailure -case object PublicKeyNotFoundFailure extends SettingsFailure +object SettingsFailure { + type SettingsResult[A] = Either[SettingsFailure, A] +} object Settings { - // internal functions for fetching and parsing the responses - def fetchSettings(settingsFileKey: String, bucketName: String, s3Client: AmazonS3): Either[SettingsFailure, String] = try { - val response = s3Client.getObject(bucketName, settingsFileKey) - Right(IOUtils.toString(response.getObjectContent)) - } catch { - case NonFatal(e) => - Left(SettingsDownloadFailure(e)) + /** + * @param settingsFileKey the name of the file that contains the private settings for the given domain + */ + class Loader(s3BucketLoader: S3BucketLoader, settingsFileKey: String) { + + def loadAndParseSettingsMap(): SettingsResult[Map[String, String]] = fetchSettings().flatMap(extractSettings) + + private def fetchSettings(): SettingsResult[String] = try { + Right(IOUtils.toString(s3BucketLoader.inputStreamFetching(settingsFileKey))) + } catch { case NonFatal(e) => Left(SettingsDownloadFailure(e)) } } - private[pandomainauth] def extractSettings(settingsBody: String): Either[SettingsFailure, Map[String, String]] = try { + private[pandomainauth] def extractSettings(settingsBody: String): SettingsResult[Map[String, String]] = try { val props = new Properties() props.load(new ByteArrayInputStream(settingsBody.getBytes("UTF-8"))) - Right(props.asScala.toMap) } catch { case NonFatal(e) => Left(SettingsParseFailure(e)) } - def logError(failure: SettingsFailure, logger: Logger) = failure match { - case SettingsDownloadFailure(cause) => - logger.error("Unable to download public key", cause) - - case SettingsParseFailure(cause) => - logger.error("Unable to parse public key", cause) + class Refresher[A]( + loader: Settings.Loader, + settingsParser: Map[String, String] => SettingsResult[A], + scheduler: ScheduledExecutorService = Executors.newScheduledThreadPool(1) + ) { + // This is deliberately designed to throw an exception during construction if we cannot immediately read the settings + private val store: AtomicReference[A] = new AtomicReference( + loadAndParseSettings().fold(fail => throw fail.asThrowable(), identity) + ) - case PublicKeyFormatFailure => - logger.error("Public key does not match expected format") - - case PublicKeyNotFoundFailure => - logger.error("Public key not found in settings file") - } + private val logger = LoggerFactory.getLogger(getClass) - def errorToThrowable(failure: SettingsFailure): Throwable = failure match { - case SettingsDownloadFailure(cause) => - new IllegalStateException("Unable to download public key", cause) + def start(interval: Int): Unit = scheduler.scheduleAtFixedRate(() => refresh(), 0, interval, MINUTES) - case SettingsParseFailure(cause) => - new IllegalStateException("Unable to parse public key", cause) + def loadAndParseSettings(): SettingsResult[A] = + loader.loadAndParseSettingsMap().flatMap(settingsParser) - case PublicKeyFormatFailure => - new IllegalStateException("Public key does not match expected format") + private def refresh(): Unit = loadAndParseSettings() match { + case Right(newSettings) => + // logger.debug(s"Updated pan-domain settings for $domain") + val oldSettings = store.getAndSet(newSettings) + case Left(err) => + logger.error("Failed to update pan-domain settings for $domain") + err.logError(logger) + } - case PublicKeyNotFoundFailure => - new IllegalStateException("Public key not found in settings file") + def get(): A = store.get() } } diff --git a/pan-domain-auth-verification/src/main/scala/com/gu/pandomainauth/service/Crypto.scala b/pan-domain-auth-verification/src/main/scala/com/gu/pandomainauth/service/Crypto.scala index 2926084..b21059d 100644 --- a/pan-domain-auth-verification/src/main/scala/com/gu/pandomainauth/service/Crypto.scala +++ b/pan-domain-auth-verification/src/main/scala/com/gu/pandomainauth/service/Crypto.scala @@ -1,9 +1,7 @@ package com.gu.pandomainauth.service -import org.apache.commons.codec.binary.Base64._ import org.bouncycastle.jce.provider.BouncyCastleProvider -import java.security.spec.{PKCS8EncodedKeySpec, X509EncodedKeySpec} import java.security._ @@ -37,12 +35,4 @@ object Crypto { rsa.update(data) rsa.verify(signature) } - - def publicKeyFor(base64EncodedKey: String): PublicKey = - keyFactory.generatePublic(new X509EncodedKeySpec(decodeBase64(base64EncodedKey))) - def privateKeyFor(base64EncodedKey: String): PrivateKey = - keyFactory.generatePrivate(new PKCS8EncodedKeySpec(decodeBase64(base64EncodedKey))) - - def keyPairFrom(settingMap: Map[String,String]): KeyPair = - new KeyPair(publicKeyFor(settingMap("publicKey")), privateKeyFor(settingMap("privateKey"))) } diff --git a/pan-domain-auth-verification/src/main/scala/com/gu/pandomainauth/service/CryptoConf.scala b/pan-domain-auth-verification/src/main/scala/com/gu/pandomainauth/service/CryptoConf.scala new file mode 100644 index 0000000..7e78cc1 --- /dev/null +++ b/pan-domain-auth-verification/src/main/scala/com/gu/pandomainauth/service/CryptoConf.scala @@ -0,0 +1,44 @@ +package com.gu.pandomainauth.service + +import com.gu.pandomainauth.SettingsFailure.SettingsResult +import com.gu.pandomainauth.service.Crypto.keyFactory +import com.gu.pandomainauth.service.CryptoConf.SettingsReader.{privateKeyFor, publicKeyFor} +import com.gu.pandomainauth.{InvalidBase64, MissingSetting, PublicKeyFormatFailure} +import org.apache.commons.codec.binary.Base64.{decodeBase64, isBase64} + +import java.security.spec.{InvalidKeySpecException, PKCS8EncodedKeySpec, X509EncodedKeySpec} +import java.security.{PrivateKey, PublicKey} +import scala.util.Try + + + +object CryptoConf { + case class SettingsReader(settingMap: Map[String,String]) { + def setting(key: String): SettingsResult[String] = settingMap.get(key).toRight(MissingSetting(key)) + + val activePublicKey: SettingsResult[PublicKey] = setting("publicKey").flatMap(publicKeyFor) + + def activeKeyPair: SettingsResult[KeyPair] = for { + publicKey <- activePublicKey + privateKey <- setting("privateKey").flatMap(privateKeyFor) + } yield KeyPair(publicKey, privateKey) + } + + object SettingsReader { + def publicKeyFor(data: Array[Byte]) = keyFactory.generatePublic(new X509EncodedKeySpec(data)) + def privateKeyFor(data: Array[Byte]) = keyFactory.generatePrivate(new PKCS8EncodedKeySpec(data)) + + def bytesFromBase64(base64Encoded: String): SettingsResult[Array[Byte]] = + Either.cond(isBase64(base64Encoded), decodeBase64(base64Encoded), InvalidBase64) + + private def keyFor[A](keyConstructor: Array[Byte] => A, base64EncodedKey: String): SettingsResult[A] = for { + bytes <- bytesFromBase64(base64EncodedKey) + key <- Try(keyConstructor(bytes)).map(Right(_)).recover { + case _: InvalidKeySpecException => Left(PublicKeyFormatFailure) + }.get + } yield key + + def publicKeyFor(base64EncodedKey: String): SettingsResult[PublicKey] = keyFor(publicKeyFor, base64EncodedKey) + def privateKeyFor(base64EncodedKey: String): SettingsResult[PrivateKey] = keyFor(privateKeyFor, base64EncodedKey) + } +} \ No newline at end of file diff --git a/pan-domain-auth-verification/src/main/scala/com/gu/pandomainauth/service/KeyPair.scala b/pan-domain-auth-verification/src/main/scala/com/gu/pandomainauth/service/KeyPair.scala new file mode 100644 index 0000000..b2d6dc9 --- /dev/null +++ b/pan-domain-auth-verification/src/main/scala/com/gu/pandomainauth/service/KeyPair.scala @@ -0,0 +1,9 @@ +package com.gu.pandomainauth.service + +import java.security.{PrivateKey, PublicKey} + +/** + * This class mainly exists because java.security.KeyPair does not implement a useful `.equals()`` method, + * and we're going to want to be able to check whether two key-pairs are equal. + */ +case class KeyPair(publicKey: PublicKey, privateKey: PrivateKey) diff --git a/pan-domain-auth-verification/src/test/scala/com/gu/pandomainauth/CryptoConfTest.scala b/pan-domain-auth-verification/src/test/scala/com/gu/pandomainauth/CryptoConfTest.scala new file mode 100644 index 0000000..b708f05 --- /dev/null +++ b/pan-domain-auth-verification/src/test/scala/com/gu/pandomainauth/CryptoConfTest.scala @@ -0,0 +1,34 @@ +package com.gu.pandomainauth + +import com.gu.pandomainauth.service.CryptoConf.SettingsReader +import com.gu.pandomainauth.service.TestKeys.testPublicKey +import org.scalatest.EitherValues +import org.scalatest.freespec.AnyFreeSpec +import org.scalatest.matchers.should.Matchers + + +class CryptoConfTest extends AnyFreeSpec with Matchers with EitherValues { + "CryptoConf.SettingsReader" - { + "returns an error if the key looks invalid" in { + SettingsReader.publicKeyFor("not a valid key").left.value shouldEqual PublicKeyFormatFailure + } + + "returns the key if it is valid" in { + SettingsReader.publicKeyFor(testPublicKey.base64Encoded) shouldEqual Right(testPublicKey.key) + } + } + + "CryptoConf.SettingsReader activePublicKey" - { + "will get a public key from a valid settings map" in { + SettingsReader(Map("publicKey" -> testPublicKey.base64Encoded)).activePublicKey shouldEqual Right(testPublicKey.key) + } + + "will reject a key that is not correctly formatted" in { + SettingsReader(Map("publicKey" -> "improperly formatted public key!!")).activePublicKey.left.value should be(InvalidBase64) + } + + "will fail if the key is not present in the settings" in { + SettingsReader(Map("another key" -> "bar")).activePublicKey.left.value should be(MissingSetting("publicKey")) + } + } +} diff --git a/pan-domain-auth-verification/src/test/scala/com/gu/pandomainauth/PublicSettingsTest.scala b/pan-domain-auth-verification/src/test/scala/com/gu/pandomainauth/PublicSettingsTest.scala index ad4c1a7..522883c 100644 --- a/pan-domain-auth-verification/src/test/scala/com/gu/pandomainauth/PublicSettingsTest.scala +++ b/pan-domain-auth-verification/src/test/scala/com/gu/pandomainauth/PublicSettingsTest.scala @@ -1,38 +1,11 @@ package com.gu.pandomainauth -import com.gu.pandomainauth.service.TestKeys.testPublicKey -import com.gu.pandomainauth.service.{Crypto, TestKeys} -import org.scalatest.concurrent.ScalaFutures -import org.scalatest.freespec.AnyFreeSpec import org.scalatest.EitherValues +import org.scalatest.freespec.AnyFreeSpec import org.scalatest.matchers.should.Matchers -class PublicSettingsTest extends AnyFreeSpec with Matchers with EitherValues with ScalaFutures { - "validateKey" - { - "returns an error if the key looks invalid" in { - val invalidKeyText = "not a valid key" - PublicSettings.validateAndParseKeyText(invalidKeyText).left.value shouldEqual PublicKeyFormatFailure - } - - "returns the key if it is valid" in { - PublicSettings.validateAndParseKeyText(testPublicKey.base64Encoded) shouldEqual Right(testPublicKey.key) - } - } - - "extractPublicKey" - { - "will get a public key from a valid settings map" in { - PublicSettings.extractPublicKey(Map("publicKey" -> testPublicKey.base64Encoded)) shouldEqual Right(testPublicKey.key) - } - - "will reject a key that is not correctly formatted" in { - PublicSettings.extractPublicKey(Map("publicKey" -> "improperly formatted public key!!")).left.value should be(PublicKeyFormatFailure) - } - - "will fail if the key is not present in the settings" in { - PublicSettings.extractPublicKey(Map("another key" -> "bar")).left.value should be(PublicKeyNotFoundFailure) - } - } +class PublicSettingsTest extends AnyFreeSpec with Matchers with EitherValues { "extractSettings" - { "extracts properties from a valid body" in { diff --git a/pan-domain-auth-verification/src/test/scala/com/gu/pandomainauth/service/CryptoTest.scala b/pan-domain-auth-verification/src/test/scala/com/gu/pandomainauth/service/CryptoTest.scala index 21e0f03..d1d57c2 100644 --- a/pan-domain-auth-verification/src/test/scala/com/gu/pandomainauth/service/CryptoTest.scala +++ b/pan-domain-auth-verification/src/test/scala/com/gu/pandomainauth/service/CryptoTest.scala @@ -1,6 +1,5 @@ package com.gu.pandomainauth.service -import com.gu.pandomainauth.service.Crypto.{privateKeyFor, publicKeyFor} import org.scalatest.funsuite.AnyFunSuite import org.scalatest.matchers.should.Matchers diff --git a/pan-domain-auth-verification/src/test/scala/com/gu/pandomainauth/service/TestKeys.scala b/pan-domain-auth-verification/src/test/scala/com/gu/pandomainauth/service/TestKeys.scala index fc7947c..ac31866 100644 --- a/pan-domain-auth-verification/src/test/scala/com/gu/pandomainauth/service/TestKeys.scala +++ b/pan-domain-auth-verification/src/test/scala/com/gu/pandomainauth/service/TestKeys.scala @@ -1,14 +1,15 @@ package com.gu.pandomainauth.service -import com.gu.pandomainauth.service.Crypto.{privateKeyFor, publicKeyFor} +import com.gu.pandomainauth.SettingsFailure.SettingsResult +import com.gu.pandomainauth.service.CryptoConf.SettingsReader.{privateKeyFor, publicKeyFor} import java.security.Key object TestKeys { case class Example[K <: Key](key: K, base64Encoded: String) - def example[K <: Key](f: String => K)(base64Encoded: String): Example[K] = - Example(f(base64Encoded), base64Encoded) + def example[K <: Key](f: String => SettingsResult[K])(base64Encoded: String): Example[K] = + Example(f(base64Encoded).toOption.get, base64Encoded) /** * A test public/private key-pair