Skip to content

Commit

Permalink
Feature/75 refresh token simple extend (#79)
Browse files Browse the repository at this point in the history
* #75 back to payload-based refresh token implementation
  • Loading branch information
dk1844 authored Nov 13, 2023
1 parent f7d833e commit a232b72
Show file tree
Hide file tree
Showing 5 changed files with 52 additions and 94 deletions.
7 changes: 3 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,13 @@ To interact with the service, most notable endpoints are
Please, refer to the [API documentation](#api-documentation) below for details of the endpoints.

#### Generate tokens
Once you request your token at `/token/generate` endpoint, you will receive both an access token (in body)
Once you request your token at `/token/generate` endpoint, you will receive both an access token and a refresh token
```json
{
"token": "..."
"token": "...",
"refresh": "..."
}
```
and a refresh token (in Cookie named `refresh`).

Both tokens are signed by LS public key and carry the username (`sub`), `type` (`access`/`refresh`) and creation/expiry info (`iat`/`exp`).

#### Refresh access token
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,18 +23,16 @@ import io.swagger.v3.oas.annotations.security.SecurityRequirement
import io.swagger.v3.oas.annotations.tags.{Tag, Tags}
import io.swagger.v3.oas.annotations.{Operation, Parameter}
import org.springframework.beans.factory.annotation.Autowired
import org.springframework.http.{HttpHeaders, HttpStatus, MediaType, ResponseCookie, ResponseEntity}
import org.springframework.http.{HttpStatus, MediaType}
import org.springframework.security.core.Authentication
import org.springframework.web.bind.annotation._
import za.co.absa.loginsvc.model.User
import za.co.absa.loginsvc.rest.controller.TokenController.{RefreshCookieName, extractRefreshTokenFromRequest, refreshResponseCookieFromRefreshToken, responseEntityWithRefreshCookieHeader}
import za.co.absa.loginsvc.rest.model.{AccessToken, PublicKey, RefreshToken}
import za.co.absa.loginsvc.rest.model.{PublicKey, TokensWrapper}
import za.co.absa.loginsvc.rest.service.JWTService
import za.co.absa.loginsvc.utils.OptionUtils.ImplicitBuilderExt

import java.util.concurrent.CompletableFuture
import java.util.{Base64, Optional}
import javax.servlet.http.{Cookie, HttpServletRequest}
import scala.concurrent.Future
import scala.concurrent.duration.FiniteDuration

Expand All @@ -52,10 +50,10 @@ class TokenController @Autowired()(jwtService: JWTService) {
summary = "Generates access and refresh JWTs",
description = """Generates access and refresh JWTs signed by the private key, verifiable by the public key available at /token/public-key. RSA256 is used.""",
responses = Array(
new ApiResponse(responseCode = "200", description = "Access JWT is retrieved in the response body, the refresh JWT in Cookie 'refresh'.",
new ApiResponse(responseCode = "200", description = "JWTs are retrieved in the response body",
content = Array(new Content(
schema = new Schema(implementation = classOf[AccessToken]),
examples = Array(new ExampleObject(value = "{\n \"token\": \"abcd123.efgh456.ijkl789\"}")))
schema = new Schema(implementation = classOf[TokensWrapper]),
examples = Array(new ExampleObject(value = "{\n \"token\": \"abcd123.efgh456.ijkl789\",\n \"refresh\": \"ab12.cd34.ef56\"\n}")))
)
),
new ApiResponse(responseCode = "401", description = "Auth error",
Expand All @@ -70,8 +68,9 @@ class TokenController @Autowired()(jwtService: JWTService) {
path = Array("/generate"),
produces = Array(MediaType.APPLICATION_JSON_VALUE)
)
@ResponseStatus(HttpStatus.OK)
@SecurityRequirement(name = "basicAuth")
def generateToken(authentication: Authentication, @RequestParam("group-prefixes") groupPrefixes: Optional[String]): CompletableFuture[ResponseEntity[AccessToken]] = {
def generateToken(authentication: Authentication, @RequestParam("group-prefixes") groupPrefixes: Optional[String]): CompletableFuture[TokensWrapper] = {
val user = authentication.getPrincipal.asInstanceOf[User]
val groupPrefixesStrScala = groupPrefixes.toScalaOption

Expand All @@ -82,24 +81,19 @@ class TokenController @Autowired()(jwtService: JWTService) {

val accessJwt = jwtService.generateAccessToken(filteredGroupsUser)
val refreshJwt = jwtService.generateRefreshToken(filteredGroupsUser)

Future.successful(
responseEntityWithRefreshCookieHeader(refreshJwt, refreshExpDuration) {
accessJwt
}
)
Future.successful(TokensWrapper.fromTokens(accessJwt, refreshJwt))
}

@Tags(Array(new Tag(name = "token")))
@Operation(
summary = "Refreshes access JWT",
// note: further implementation, perhaps in https://github.com/AbsaOSS/login-service/issues/76, may issue new refresh tokens
description = """Refreshed access JWT and (currently original) refresh JWTs signed by the private key, verifiable by the public key available at /token/public-key. RSA256 is used. Make sure that the refresh token is present in Cookies (refresh=ab.123.cd).""",
description = """Refreshed access JWT and (currently original) refresh JWTs signed by the private key, verifiable by the public key available at /token/public-key. RSA256 is used.""",
responses = Array(
new ApiResponse(responseCode = "200", description = "Access JWT is retrieved in the response body, updated refresh JWT in Cookie 'refresh'.",
new ApiResponse(responseCode = "200", description = "JWTs are retrieved in the response body",
content = Array(new Content(
schema = new Schema(implementation = classOf[AccessToken]),
examples = Array(new ExampleObject(value = "{\n \"token\": \"abcd123.efgh456.ijkl789\"}")))
schema = new Schema(implementation = classOf[TokensWrapper]),
examples = Array(new ExampleObject(value = "{\n \"token\": \"abcd123.efgh456.ijkl789\",\n \"refresh\": \"ab12.cd34.ef56\"\n}")))
)
),
new ApiResponse(responseCode = "401", description = "Understood the supplied tokens, but cannot refresh with those", // specific JWT expcetions
Expand All @@ -119,22 +113,9 @@ class TokenController @Autowired()(jwtService: JWTService) {
produces = Array(MediaType.APPLICATION_JSON_VALUE)
)
@ResponseStatus(HttpStatus.OK)
def refreshToken(@RequestBody accessToken: AccessToken, request: HttpServletRequest): CompletableFuture[ResponseEntity[AccessToken]] = {

val response: Future[ResponseEntity[AccessToken]] = extractRefreshTokenFromRequest(request).map { refreshToken =>
val (refreshedAccessToken, refreshedRefreshToken) = jwtService.refreshTokens(accessToken, refreshToken)

Future.successful(
responseEntityWithRefreshCookieHeader(refreshedRefreshToken, refreshExpDuration) {
refreshedAccessToken
}
)

}.getOrElse(
Future.failed(new IllegalArgumentException("The expected refresh header not found, cannot refresh access token!"))
)

response
def refreshToken(@RequestBody tokens: TokensWrapper): CompletableFuture[TokensWrapper] = {
val (refreshedAccessToken, refreshedRefreshToken) = jwtService.refreshTokens(tokens.accessToken, tokens.refreshToken)
Future.successful(TokensWrapper.fromTokens(refreshedAccessToken, refreshedRefreshToken))
}

@Tags(Array(new Tag(name = "token")))
Expand Down Expand Up @@ -183,32 +164,5 @@ class TokenController @Autowired()(jwtService: JWTService) {
}

object TokenController {
val RefreshCookieName = "refresh"

def extractRefreshTokenFromRequest(request: HttpServletRequest): Option[RefreshToken] = {
Option(request.getCookies()) // getCookies returns null if there are no cookies
.getOrElse(Array.empty[Cookie])
.find(_.getName == RefreshCookieName)
.map(_.getValue)
.map(RefreshToken)
}

def responseEntityWithRefreshCookieHeader[T](refreshToken: RefreshToken, refreshExpDuration: FiniteDuration)(body: T): ResponseEntity[T] = {
val refreshCookie: ResponseCookie = refreshResponseCookieFromRefreshToken(refreshToken, refreshExpDuration)

ResponseEntity
.ok()
.header(HttpHeaders.SET_COOKIE, refreshCookie.toString())
.body(body)
}

def refreshResponseCookieFromRefreshToken(refreshToken: RefreshToken, refreshExpiryHint: FiniteDuration): ResponseCookie = {
ResponseCookie.from(RefreshCookieName, refreshToken.token)
.httpOnly(true)
.secure(true)
.path("/")
.maxAge(refreshExpiryHint.toSeconds)
.build()
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,29 @@ import com.fasterxml.jackson.annotation.{JsonIgnore, JsonProperty}
import io.swagger.v3.oas.annotations.media.Schema
import io.swagger.v3.oas.annotations.media.Schema.RequiredMode

case class AccessToken(
case class TokensWrapper(
@JsonProperty("token")
@Schema(example = "abcd123.efgh456.ijkl789", requiredMode = RequiredMode.REQUIRED)
token: String,
@JsonProperty("refresh")
@Schema(example = "ab12.cd34.ef56", requiredMode = RequiredMode.NOT_REQUIRED)
refresh: String
) {
def accessToken: AccessToken = AccessToken(token)
def refreshToken: RefreshToken = RefreshToken(refresh)
}

object TokensWrapper {
def fromTokens(accessToken: AccessToken, refreshToken: RefreshToken): TokensWrapper = {
TokensWrapper(accessToken.token, refreshToken.token)
}
}

case class AccessToken(
token: String
) extends Token

case class RefreshToken(
@JsonProperty("refresh")
@Schema(example = "ab12.cd34.ef56", requiredMode = RequiredMode.NOT_REQUIRED)
token: String,
) extends Token

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,6 @@ package za.co.absa.loginsvc.rest.service

import com.nimbusds.jose.JWSAlgorithm
import com.nimbusds.jose.jwk.{JWKSet, KeyUse, RSAKey}
import io.jsonwebtoken.security.Keys
import io.jsonwebtoken.{JwtBuilder, Jwts, SignatureAlgorithm}
import io.jsonwebtoken._
import org.slf4j.LoggerFactory
import org.springframework.beans.factory.annotation.Autowired
Expand All @@ -33,9 +31,8 @@ import za.co.absa.loginsvc.utils.OptionUtils.ImplicitBuilderExt
import java.security.interfaces.RSAPublicKey
import java.security.{KeyPair, PublicKey}
import java.time.Instant
import java.time.temporal.ChronoUnit
import java.util.concurrent.{Executors, TimeUnit}
import java.util.Date
import java.util.concurrent.{Executors, TimeUnit}
import scala.collection.JavaConverters._
import scala.compat.java8.DurationConverters._
import scala.concurrent.duration.FiniteDuration
Expand All @@ -55,8 +52,10 @@ class JWTService @Autowired()(jwtConfigProvider: JwtConfigProvider) {
scheduleSecretsRefresh(refreshTime)
}

def generateAccessToken(user: User): AccessToken = {
logger.info(s"Generating Token for user: ${user.name}")
def generateAccessToken(user: User, isRefresh: Boolean = false): AccessToken = {
val msgIntro = if (isRefresh) "Refreshing" else "Generating new"
logger.info(s"$msgIntro token for user: ${user.name}")

import scala.collection.JavaConverters._

val expiration = Date.from(
Expand Down Expand Up @@ -119,7 +118,7 @@ class JWTService @Autowired()(jwtConfigProvider: JwtConfigProvider) {
.parseClaimsJws(refreshToken.token) // checks username, validity, and signature.


val refreshedAccessToken = generateAccessToken(userFromOldAccessToken) // same process as with normal generation
val refreshedAccessToken = generateAccessToken(userFromOldAccessToken, isRefresh = true) // same process as with normal generation, but different msg

// we are giving the original still-valid refreshToken back - potentially making room here to revoke or regenerate refreshTokens later
(refreshedAccessToken, refreshToken)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@
package za.co.absa.loginsvc.rest.controller

import com.nimbusds.jose.jwk.{JWKSet, RSAKey}
import io.jsonwebtoken.security.Keys
import io.jsonwebtoken.{ExpiredJwtException, MalformedJwtException, SignatureAlgorithm}
import io.jsonwebtoken.security.{Keys, SignatureException}
import org.mockito.ArgumentMatchers.any
import org.mockito.Mockito.when
import org.scalatest.flatspec.AnyFlatSpec
Expand All @@ -29,12 +29,11 @@ import org.springframework.context.annotation.Import
import org.springframework.test.web.servlet.MockMvc
import za.co.absa.loginsvc.model.User
import za.co.absa.loginsvc.rest.model.{AccessToken, RefreshToken}
import za.co.absa.loginsvc.rest.{FakeAuthentication, RestResponseEntityExceptionHandler, SecurityConfig}
import za.co.absa.loginsvc.rest.service.JWTService
import za.co.absa.loginsvc.rest.{FakeAuthentication, RestResponseEntityExceptionHandler, SecurityConfig}

import java.security.interfaces.RSAPublicKey
import java.util.Base64
import javax.servlet.http.Cookie
import scala.concurrent.duration._

@Import(Array(classOf[SecurityConfig], classOf[RestResponseEntityExceptionHandler]))
Expand Down Expand Up @@ -69,8 +68,7 @@ class TokenControllerTest extends AnyFlatSpec
"/token/generate",
Post()
)(
expectedJsonBody = s"""{"token": "${fakeAccessJwt.token}"}""",
expectedHeaderContaining = Some(("Set-Cookie", Set(s"refresh=${fakeRefreshJwt.token}", "Secure; HttpOnly")))
expectedJsonBody = s"""{"token": "${fakeAccessJwt.token}", "refresh": "${fakeRefreshJwt.token}"}"""
)(Some(FakeAuthentication.fakeUserAuthentication))
}

Expand All @@ -85,8 +83,7 @@ class TokenControllerTest extends AnyFlatSpec
"/token/generate?group-prefixes=first",
Post()
)(
expectedJsonBody = s"""{"token": "${fakeAccessJwt.token}"}""",
expectedHeaderContaining = Some(("Set-Cookie", Set(s"refresh=${fakeRefreshJwt.token}", "Secure; HttpOnly")))
expectedJsonBody = s"""{"token": "${fakeAccessJwt.token}", "refresh": "${fakeRefreshJwt.token}"}"""
)(Some(FakeAuthentication.fakeUserAuthentication))
}

Expand All @@ -100,13 +97,12 @@ class TokenControllerTest extends AnyFlatSpec
"/token/generate?group-prefixes=second,third,nonexistent",
Post()
)(
expectedJsonBody = s"""{"token": "${fakeAccessJwt.token}"}""",
expectedHeaderContaining = Some(("Set-Cookie", Set(s"refresh=${fakeRefreshJwt.token}", "Secure; HttpOnly")))
expectedJsonBody = s"""{"token": "${fakeAccessJwt.token}", "refresh": "${fakeRefreshJwt.token}"}"""
)(Some(FakeAuthentication.fakeUserAuthentication))
}

it should "fail for anonymous (not authenticated) user" in {
when(jwtService.generateAccessToken(any[User]())).thenReturn(fakeAccessJwt)
when(jwtService.generateAccessToken(any[User], any[Boolean])).thenReturn(fakeAccessJwt)

assertNotAuthenticatedFailure(
"/token/generate",
Expand All @@ -126,13 +122,10 @@ class TokenControllerTest extends AnyFlatSpec
assertExpectedResponseFields(
"/token/refresh",
Post(
cookies = Some(Seq(new Cookie("refresh", fakeRefreshJwt.token))),
body = Some(s"""{"token": "${fakeAccessJwt.token}"}""")

body = Some(s"""{"token": "${fakeAccessJwt.token}", "refresh": "${fakeRefreshJwt.token}"}""")
)
)(
expectedJsonBody = s"""{"token": "${newFakeAccessJwt.token}"}""",
expectedHeaderContaining = Some(("Set-Cookie", Set(s"refresh=${newFakeRefreshJwt.token}", "Secure; HttpOnly")))
expectedJsonBody = s"""{"token": "${newFakeAccessJwt.token}", "refresh": "${newFakeRefreshJwt.token}"}"""
)(auth = None)
}

Expand All @@ -143,8 +136,7 @@ class TokenControllerTest extends AnyFlatSpec
assertErrorStatusAndResultBodyJsonEquals(
"/token/refresh",
Post(
cookies = Some(Seq(new Cookie("refresh", fakeRefreshJwt.token))),
body = Some(s"""{"token": "${fakeAccessJwt.token}"}""")
body = Some(s"""{"token": "${fakeAccessJwt.token}", "refresh": "${fakeRefreshJwt.token}"}""")
),
expectedStatus = 400,
expectedJson =
Expand All @@ -159,8 +151,8 @@ class TokenControllerTest extends AnyFlatSpec
assertErrorStatusAndResultBodyJsonEquals(
"/token/refresh",
Post(
cookies = Some(Seq(new Cookie("refresh", fakeRefreshJwt.token))),
body = Some(s"""{"token": "${fakeAccessJwt.token}"}""")),
body = Some(s"""{"token": "${fakeAccessJwt.token}", "refresh": "${fakeRefreshJwt.token}"}""")
),
expectedStatus = 401,
expectedJson = s"""{
| "message": "expired jwt"
Expand Down

0 comments on commit a232b72

Please sign in to comment.