Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Drtii 1609 perform connectivity test with es #246

Open
wants to merge 5 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ lazy val root = project.in(file(".")).
)

lazy val akkaVersion = "2.8.5"
lazy val akkaHttpVersion = "10.5.2"
lazy val jodaVersion = "2.12.5"
lazy val upickleVersion = "3.1.3"
lazy val sparkMlLibVersion = "3.5.0"
Expand Down Expand Up @@ -55,6 +56,8 @@ lazy val cross = crossProject(JVMPlatform, JSPlatform)
"com.typesafe.akka" %% "akka-actor" % akkaVersion,
"com.typesafe.akka" %% "akka-persistence" % akkaVersion,
"com.typesafe.akka" %% "akka-persistence-query" % akkaVersion,
"com.typesafe.akka" %% "akka-http" % akkaHttpVersion,
"com.typesafe.akka" %% "akka-http-spray-json" % akkaHttpVersion,
"com.typesafe.akka" %% "akka-slf4j" % akkaVersion,
"joda-time" % "joda-time" % jodaVersion,
"org.apache.spark" %% "spark-mllib" % sparkMlLibVersion,
Expand Down
127 changes: 127 additions & 0 deletions jvm/src/main/scala/uk/gov/homeoffice/drt/keycloak/KeyCloakAuth.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
package uk.gov.homeoffice.drt.keycloak

import akka.http.scaladsl.marshallers.sprayjson.SprayJsonSupport
import akka.http.scaladsl.model._
import akka.http.scaladsl.model.headers.Accept
import akka.http.scaladsl.unmarshalling.Unmarshal
import akka.stream.Materializer
import org.slf4j.{Logger, LoggerFactory}
import spray.json.{DefaultJsonProtocol, JsNumber, JsObject, JsString, JsValue, RootJsonFormat}

import scala.concurrent.{ExecutionContext, Future}

case class KeyCloakAuth(tokenUrl: String, clientId: String, clientSecret: String, sendHttpRequest: HttpRequest => Future[HttpResponse])
(implicit ec: ExecutionContext, mat: Materializer)
extends KeyCloakAuthTokenParserProtocol {

val log: Logger = LoggerFactory.getLogger(getClass)

def formData(username: String, password: String, clientId: String, clientSecret: String) = FormData(Map(
"username" -> username,
"password" -> password,
"client_id" -> clientId,
"client_secret" -> clientSecret,
"grant_type" -> "password"
))

def getToken(username: String, password: String): Future[KeyCloakAuthResponse] = {
val request = HttpRequest(
method = HttpMethods.POST,
uri = Uri(tokenUrl),
headers = List(Accept(MediaTypes.`application/json`)),
entity = formData(username, password, clientId, clientSecret).toEntity)

val requestWithHeaders = request.addHeader(Accept(MediaTypes.`application/json`))

sendHttpRequest(requestWithHeaders).flatMap { r =>
Unmarshal(r).to[KeyCloakAuthResponse]
}
}
}

sealed trait KeyCloakAuthResponse

case class KeyCloakAuthToken(accessToken: String,
expiresIn: Int,
refreshExpiresIn: Int,
refreshToken: String,
tokenType: String,
notBeforePolicy: Int,
sessionState: String,
scope: String) extends KeyCloakAuthResponse

case class KeyCloakAuthError(error: String, errorDescription: String) extends KeyCloakAuthResponse


trait KeyCloakAuthTokenParserProtocol extends SprayJsonSupport with DefaultJsonProtocol {
implicit val responseFormat: RootJsonFormat[KeyCloakAuthResponse] = new RootJsonFormat[KeyCloakAuthResponse] {
override def write(response: KeyCloakAuthResponse): JsValue = response match {
case KeyCloakAuthToken(token, expires, _, _, tokenType, _, _, _) => JsObject(
"access_token" -> JsString(token),
"expires_in" -> JsNumber(expires),
"token_type" -> JsString(tokenType)
)
case KeyCloakAuthError(error, desc) => JsObject(
"error" -> JsString(error),
"error_description" -> JsString(desc)
)
}

override def read(json: JsValue): KeyCloakAuthResponse = json match {
case JsObject(fields) if fields.contains("access_token") =>
KeyCloakAuthToken(
fields.get("access_token").map(_.convertTo[String]).getOrElse(""),
fields.get("expires_in").map(_.convertTo[Int]).getOrElse(0),
fields.get("refresh_expires_in").map(_.convertTo[Int]).getOrElse(0),
fields.get("refresh_token").map(_.convertTo[String]).getOrElse(""),
fields.get("token_type").map(_.convertTo[String]).getOrElse(""),
fields.get("not-before-policy").map(_.convertTo[Int]).getOrElse(0),
fields.get("session_state").map(_.convertTo[String]).getOrElse(""),
fields.get("scope").map(_.convertTo[String]).getOrElse("")
)
case JsObject(fields) =>
KeyCloakAuthError(
fields.get("error").map(_.convertTo[String]).getOrElse(""),
fields.get("error_description").map(_.convertTo[String]).getOrElse("")
)
}
}

implicit val tokenFormat: RootJsonFormat[KeyCloakAuthToken] = new RootJsonFormat[KeyCloakAuthToken] {
override def write(token: KeyCloakAuthToken): JsValue = JsObject(
"access_token" -> JsString(token.accessToken),
"expires_in" -> JsNumber(token.expiresIn),
"token_type" -> JsString(token.tokenType)
)

override def read(json: JsValue): KeyCloakAuthToken = json match {
case JsObject(fields) if fields.contains("access_token") =>
KeyCloakAuthToken(
fields.get("access_token").map(_.convertTo[String]).getOrElse(""),
fields.get("expires_in").map(_.convertTo[Int]).getOrElse(0),
fields.get("refresh_expires_in").map(_.convertTo[Int]).getOrElse(0),
fields.get("refresh_token").map(_.convertTo[String]).getOrElse(""),
fields.get("token_type").map(_.convertTo[String]).getOrElse(""),
fields.get("not-before-policy").map(_.convertTo[Int]).getOrElse(0),
fields.get("session_state").map(_.convertTo[String]).getOrElse(""),
fields.get("scope").map(_.convertTo[String]).getOrElse("")
)
}
}

implicit val errorFormat: RootJsonFormat[KeyCloakAuthError] = new RootJsonFormat[KeyCloakAuthError] {
override def write(error: KeyCloakAuthError): JsValue = JsObject(
"error" -> JsString(error.error),
"error_description" -> JsString(error.errorDescription)
)

override def read(json: JsValue): KeyCloakAuthError = json match {
case JsObject(fields) =>
KeyCloakAuthError(
fields.get("error").map(_.convertTo[String]).getOrElse(""),
fields.get("error_description").map(_.convertTo[String]).getOrElse("")
)
case _ => KeyCloakAuthError("", "")
}
}
}
148 changes: 148 additions & 0 deletions jvm/src/main/scala/uk/gov/homeoffice/drt/keycloak/KeyCloakClient.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
package uk.gov.homeoffice.drt.keycloak

import akka.http.scaladsl.marshallers.sprayjson.SprayJsonSupport
import akka.http.scaladsl.model._
import akka.http.scaladsl.model.headers.{Accept, Authorization, OAuth2BearerToken}
import akka.http.scaladsl.unmarshalling.Unmarshal
import akka.stream.Materializer
import akka.util.Timeout
import org.slf4j.{Logger, LoggerFactory}
import spray.json.{DefaultJsonProtocol, JsObject, JsValue, RootJsonFormat}

import scala.concurrent.duration._
import scala.concurrent.{Await, ExecutionContext, Future}
import scala.language.postfixOps

case class KeyCloakClient(token: String, keyCloakUrl: String, sendHttpRequest: HttpRequest => Future[HttpResponse])
(implicit val ec: ExecutionContext, mat: Materializer)
extends KeyCloakUserParserProtocol {

import KeyCloakUserFormatParser._

def log: Logger = LoggerFactory.getLogger(getClass)

implicit val timeout: Timeout = Timeout(1 minute)

def logResponse(requestName: String, resp: HttpResponse): HttpResponse = {
if (resp.status.isFailure)
log.error(s"Error when calling $requestName on KeyCloak API Status code: ${resp.status} Response:<${resp.entity.toString}>")

resp
}

def pipeline(method: HttpMethod, uri: String, requestName: String): Future[HttpResponse] = {
val request = HttpRequest(method, Uri(uri))
val requestWithHeaders = request
.addHeader(Accept(MediaTypes.`application/json`))
.addHeader(Authorization(OAuth2BearerToken(token)))
sendHttpRequest(requestWithHeaders).map { r =>
logResponse(requestName, r)
r
}
}

def getUserForEmail(email: String): Future[Option[KeyCloakUser]] = {
val uri = keyCloakUrl + s"/users?email=$email"
log.info(s"Calling key cloak: $uri")
pipeline(HttpMethods.GET, uri, "getUsersForEmail")
.flatMap { r => Unmarshal(r).to[List[KeyCloakUser]] }.map(_.headOption)
}

def getUsers(max: Int = 100, offset: Int = 0): Future[List[KeyCloakUser]] = {
val uri = keyCloakUrl + s"/users?max=$max&first=$offset"
log.info(s"Calling key cloak: $uri")
pipeline(HttpMethods.GET, uri, "getUsers").flatMap { r => Unmarshal(r).to[List[KeyCloakUser]] }
}

def getUserByUsername(username: String): Future[Option[KeyCloakUser]] = {
val uri = keyCloakUrl + s"/users?username=$username"
log.info(s"Calling key cloak: $uri")
pipeline(HttpMethods.GET, uri, "getUsersForUsername")
.flatMap { r => Unmarshal(r).to[List[KeyCloakUser]] }.map(_.headOption)
}

def getAllUsers(offset: Int = 0): Seq[KeyCloakUser] = {
val users = Await.result(getUsers(50, offset), 2 seconds)
if (users.isEmpty) Nil else users ++ getAllUsers(offset + 50)
}

def removeUser(userId: String): Future[HttpResponse] = {
log.info(s"Removing $userId")
val uri = s"$keyCloakUrl/users/$userId"
pipeline(HttpMethods.DELETE, uri, "removeUserFromGroup")
}

def logUserOut(userId: String): Future[HttpResponse] = {
log.info(s"Logout $userId")
val uri = s"$keyCloakUrl/users/$userId/logout"
pipeline(HttpMethods.POST, uri, "logoutUser")
}

def getUserGroups(userId: String): Future[List[KeyCloakGroup]] = {
val uri = keyCloakUrl + s"/users/$userId/groups"
log.info(s"Calling key cloak: $uri")
pipeline(HttpMethods.GET, uri, "getUserGroups").flatMap { r => Unmarshal(r).to[List[KeyCloakGroup]] }
}

def getGroups: Future[List[KeyCloakGroup]] = {
val uri = keyCloakUrl + "/groups"
log.info(s"Calling key cloak: $uri")
pipeline(HttpMethods.GET, uri, "getGroups").flatMap { r => Unmarshal(r).to[List[KeyCloakGroup]] }
}

def getUsersInGroup(groupName: String, max: Int = 1000): Future[List[KeyCloakUser]] = {
val futureMaybeId: Future[Option[String]] = getGroups.map(gs => gs.find(_.name == groupName).map(_.id))

futureMaybeId.flatMap {
case Some(id) =>
val uri = keyCloakUrl + s"/groups/$id/members?max=$max"
pipeline(HttpMethods.GET, uri, "getUsersInGroup").flatMap { r => Unmarshal(r).to[List[KeyCloakUser]] }
case None => Future(List())
}
}

def getUsersNotInGroup(groupName: String): Future[List[KeyCloakUser]] = {

val futureUsersInGroup: Future[List[KeyCloakUser]] = getUsersInGroup(groupName)
val futureAllUsers: Future[List[KeyCloakUser]] = getUsers()

for {
usersInGroup <- futureUsersInGroup
allUsers <- futureAllUsers
} yield allUsers.filterNot(usersInGroup.toSet)
}

def addUserToGroup(userId: String, groupId: String): Future[HttpResponse] = {
log.info(s"Adding $userId to $groupId")
val uri = s"$keyCloakUrl/users/$userId/groups/$groupId"
pipeline(HttpMethods.PUT, uri, "addUserToGroup")
}

def removeUserFromGroup(userId: String, groupId: String): Future[HttpResponse] = {
log.info(s"Removing $userId from $groupId")
val uri = s"$keyCloakUrl/users/$userId/groups/$groupId"
pipeline(HttpMethods.DELETE, uri, "removeUserFromGroup")
}
}

trait KeyCloakUserParserProtocol extends DefaultJsonProtocol with SprayJsonSupport {

implicit object KeyCloakUserFormatParser extends RootJsonFormat[KeyCloakUser] {
override def write(obj: KeyCloakUser): JsValue = throw new Exception("KeyCloakUser writer not implemented")

override def read(json: JsValue): KeyCloakUser = json match {
case JsObject(fields) =>
KeyCloakUser(
fields.get("id").map(_.convertTo[String]).getOrElse(""),
fields.get("username").map(_.convertTo[String]).getOrElse(""),
fields.get("enabled").exists(_.convertTo[Boolean]),
fields.get("emailVerified").exists(_.convertTo[Boolean]),
fields.get("firstName").map(_.convertTo[String]).getOrElse(""),
fields.get("lastName").map(_.convertTo[String]).getOrElse(""),
fields.get("email").map(_.convertTo[String]).getOrElse("")
)
}
}

implicit val keyCloakGroupFormat: RootJsonFormat[KeyCloakGroup] = jsonFormat3(KeyCloakGroup)
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
package uk.gov.homeoffice.drt.keycloak

import scala.concurrent.ExecutionContext.Implicits.global
import scala.concurrent.Future


case class KeyCloakGroups(groups: List[KeyCloakGroup], client: KeyCloakClient) {
def usersWithGroupsCsvContent: Future[String] = {
val usersWithGroupsFuture = allUsersWithGroups(groups)
usersWithGroupsToCsv(usersWithGroupsFuture)
}

def usersWithGroupsToCsv(usersWithGroupsFuture: Future[Map[KeyCloakUser, List[String]]]): Future[String] = {
val headerLine = "Email,First Name,Last Name,Enabled,Groups"
usersWithGroupsFuture
.map(usersToUsersWithGroups => {
val csvLines = usersToUsersWithGroups
.map {
case (user, userGroups) =>
val userGroupsCsvValue = userGroups.sorted.mkString(", ")
s"""${user.email},${user.firstName},${user.lastName},${user.enabled},"$userGroupsCsvValue""""
}
headerLine + "\n" + csvLines.mkString("\n")
})
}

def usersWithGroups(groups: List[KeyCloakGroup]): Future[List[(KeyCloakUser, String)]] = {
val eventualUsersWithGroupsByGroup: List[Future[List[(KeyCloakUser, String)]]] = groups.map(group => {
val eventualUsersWithGroups = client
.getUsersInGroup(group.name)
.map(_.map(user => (user, group.name)))
eventualUsersWithGroups
})
Future.sequence(eventualUsersWithGroupsByGroup).map(_.flatten)
}

def usersWithGroupsByUser(groups: List[KeyCloakGroup]): Future[Map[KeyCloakUser, List[String]]] =
usersWithGroups(groups).map(usersAndGroups => {
usersAndGroups.groupBy {
case (user, _) => user
}.view.mapValues(_.map {
case (_, group) => group
}).toMap
})

def allUsersWithGroups(groups: List[KeyCloakGroup]): Future[Map[KeyCloakUser, List[String]]] =
usersWithGroupsByUser(groups).map(groupsByUser => {
client.getAllUsers().map(u => {
u -> groupsByUser.getOrElse(u, List())
}).toMap
})
}
Loading