diff --git a/core/src/main/scala/org/thp/scalligraph/auth/MultiAuthSrv.scala b/core/src/main/scala/org/thp/scalligraph/auth/MultiAuthSrv.scala index 2875fbd8..256b9c9d 100644 --- a/core/src/main/scala/org/thp/scalligraph/auth/MultiAuthSrv.scala +++ b/core/src/main/scala/org/thp/scalligraph/auth/MultiAuthSrv.scala @@ -7,6 +7,7 @@ import org.thp.scalligraph.{ AuthenticationError, AuthorizationError, BadConfigurationError, + BadRequestError, EntityIdOrName, NotSupportedError, RichSeq, @@ -58,26 +59,34 @@ class MultiAuthSrv(configuration: Configuration, appConfig: ApplicationConfig, a override def capabilities: Set[AuthCapability.Value] = authProviders.flatMap(_.capabilities).toSet private def forAllAuthProviders[A](providers: Seq[AuthSrv])(body: AuthSrv => Try[A]): Try[A] = - providers.foldLeft[Either[Seq[(String, Throwable)], A]](Left(Seq())) { - case (right: Right[_, _], _) => right - case (Left(errors), auth) => + providers.foldLeft[MultiAuthProviderResponse[A]](MultiAuthProviderResponse.Errors(Seq.empty)) { + case (ok: MultiAuthProviderResponse.Ok[_], _) => ok + case (stopping: MultiAuthProviderResponse.StoppingError, _) => stopping + case (MultiAuthProviderResponse.Errors(errors), auth) => body(auth).fold( { - case _: NotSupportedError => Left(errors) - case error => Left(errors :+ ((auth.name, error))) + case bre: BadRequestError => MultiAuthProviderResponse.StoppingError(auth.name -> bre, errors) + case _: NotSupportedError => MultiAuthProviderResponse.Errors(errors) + case error => MultiAuthProviderResponse.Errors(errors :+ (auth.name -> error)) }, - success => Right(success) + success => MultiAuthProviderResponse.Ok(success) ) } match { - case Right(auth) => Success(auth) - case Left(errors) => - errors - .foreach { - case (authName, AuthenticationError(_, cause)) if cause != null => logAuthError(authName, cause) - case (authName, AuthorizationError(_, cause)) if cause != null => logAuthError(authName, cause) - case (authName, error) => logAuthError(authName, error) - } + case MultiAuthProviderResponse.Ok(value) => Success(value) + case MultiAuthProviderResponse.Errors(errors) => + errors.foreach { + case (authName, AuthenticationError(_, cause)) if cause != null => logAuthError(authName, cause) + case (authName, AuthorizationError(_, cause)) if cause != null => logAuthError(authName, cause) + case (authName, error) => logAuthError(authName, error) + } errors.headOption.fold(Failure(AuthorizationError("Operation not supported")))(e => Failure(e._2)) + case MultiAuthProviderResponse.StoppingError(error, previousErrors) => + (previousErrors :+ error).foreach { + case (authName, AuthenticationError(_, cause)) if cause != null => logAuthError(authName, cause) + case (authName, AuthorizationError(_, cause)) if cause != null => logAuthError(authName, cause) + case (authName, error) => logAuthError(authName, error) + } + Failure(error._2) } private def logAuthError(authName: String, error: Throwable): Unit = { @@ -121,6 +130,13 @@ class MultiAuthSrv(configuration: Configuration, appConfig: ApplicationConfig, a forAllAuthProviders(authProviders)(_.removeKey(username)) } +sealed trait MultiAuthProviderResponse[+A] +object MultiAuthProviderResponse { + case class Ok[A](value: A) extends MultiAuthProviderResponse[A] + case class Errors(errors: Seq[(String, Throwable)]) extends MultiAuthProviderResponse[Nothing] + case class StoppingError(error: (String, Throwable), previousErrors: Seq[(String, Throwable)]) extends MultiAuthProviderResponse[Nothing] +} + class MultiAuthSrvProvider(appConfig: ApplicationConfig, authProviders: Seq[AuthSrvProvider]) extends AuthSrvProvider { def this(app: ScalligraphApplication) = this(app.applicationConfig, app.authSrvProviders)