Skip to content
This repository has been archived by the owner on Mar 8, 2024. It is now read-only.

Commit

Permalink
MultiAuthSrv: Fail fast in case of BadRequestError
Browse files Browse the repository at this point in the history
  • Loading branch information
vdebergue committed Jun 15, 2021
1 parent 2c16590 commit 9825a31
Showing 1 changed file with 29 additions and 10 deletions.
39 changes: 29 additions & 10 deletions core/src/main/scala/org/thp/scalligraph/auth/MultiAuthSrv.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ package org.thp.scalligraph.auth
import org.thp.scalligraph.controllers.AuthenticatedRequest
import org.thp.scalligraph.services.config.ApplicationConfig.configurationFormat
import org.thp.scalligraph.services.config.{ApplicationConfig, ConfigItem}
import org.thp.scalligraph.{AuthenticationError, AuthorizationError, BadConfigurationError, EntityIdOrName, RichSeq, ScalligraphApplication}
import org.thp.scalligraph.{AuthenticationError, AuthorizationError, BadConfigurationError, BadRequestError, EntityIdOrName, RichSeq, ScalligraphApplication}
import play.api.mvc.{ActionFunction, Request, RequestHeader, Result}
import play.api.{Configuration, Logger}

Expand Down Expand Up @@ -49,18 +49,22 @@ class MultiAuthSrv(configuration: Configuration, appConfig: ApplicationConfig, a

override def capabilities: Set[AuthCapability.Value] = authProviders.flatMap(_.capabilities).toSet

private def forAllAuthProviders[A](providers: Seq[AuthSrv])(body: AuthSrv => Try[A]): Try[A] =
providers.foldLeft[Either[Seq[(String, Throwable)], A]](Left(Seq())) {
case (right: Right[_, _], _) => right
case (Left(errors), auth) =>
private def forAllAuthProviders[A](providers: Seq[AuthSrv])(body: AuthSrv => Try[A]): Try[A] = {
providers.foldLeft[MultiAuthProviderResponse[A]](MultiAuthProviderResponse.Errors(Seq.empty)) {
case (ok: MultiAuthProviderResponse.Ok[_], _) => ok
case (stopping: MultiAuthProviderResponse.StoppingError, _) => stopping
case (MultiAuthProviderResponse.Errors(errors), auth) =>
body(auth).fold(
error => Left(errors :+ ((auth.name, error))),
success => Right(success)
{
case bre: BadRequestError => MultiAuthProviderResponse.StoppingError(auth.name -> bre, errors)
case error => MultiAuthProviderResponse.Errors(errors :+ (auth.name -> error))
},
success => MultiAuthProviderResponse.Ok(success)
)
} match {
case Right(auth) => Success(auth)
case Left(Seq()) => Failure(AuthorizationError("no authentication provider found"))
case Left(errors) =>
case MultiAuthProviderResponse.Ok(value) => Success(value)
case MultiAuthProviderResponse.Errors(Seq()) => Failure(AuthorizationError("no authentication provider found"))
case MultiAuthProviderResponse.Errors(errors) =>
errors.foreach {
case (authName, AuthenticationError(_, cause)) if cause != null => logAuthError(authName, cause)
case (authName, AuthorizationError(_, cause)) if cause != null => logAuthError(authName, cause)
Expand All @@ -70,7 +74,15 @@ class MultiAuthSrv(configuration: Configuration, appConfig: ApplicationConfig, a
Failure(AuthorizationError("Operation not permitted"))
else
Failure(AuthenticationError("Authentication failure"))
case MultiAuthProviderResponse.StoppingError(error, previousErrors) =>
(previousErrors :+ error).foreach {
case (authName, AuthenticationError(_, cause)) if cause != null => logAuthError(authName, cause)
case (authName, AuthorizationError(_, cause)) if cause != null => logAuthError(authName, cause)
case (authName, error) => logAuthError(authName, error)
}
Failure(error._2)
}
}

private def logAuthError(authName: String, error: Throwable): Unit = {
logger.warn(s"$authName fails: $error")
Expand Down Expand Up @@ -113,6 +125,13 @@ class MultiAuthSrv(configuration: Configuration, appConfig: ApplicationConfig, a
forAllAuthProviders(authProviders)(_.removeKey(username))
}

sealed trait MultiAuthProviderResponse[+A]
object MultiAuthProviderResponse {
case class Ok[A](value: A) extends MultiAuthProviderResponse[A]
case class Errors(errors: Seq[(String, Throwable)]) extends MultiAuthProviderResponse[Nothing]
case class StoppingError(error: (String, Throwable), previousErrors: Seq[(String, Throwable)]) extends MultiAuthProviderResponse[Nothing]
}

class MultiAuthSrvProvider(appConfig: ApplicationConfig, authProviders: Seq[AuthSrvProvider]) extends AuthSrvProvider {
def this(app: ScalligraphApplication) = this(app.applicationConfig, app.authSrvProviders)

Expand Down

0 comments on commit 9825a31

Please sign in to comment.