Skip to content

Commit

Permalink
use copy userCredentialHash for jwtCachekey
Browse files Browse the repository at this point in the history
  • Loading branch information
jakelandis committed Nov 3, 2023
1 parent 4fc58af commit ddadcca
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 16 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.lucene.util.BytesRef;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.bytes.BytesArray;
Expand Down Expand Up @@ -283,7 +284,9 @@ public void authenticate(final AuthenticationToken authenticationToken, final Ac
return; // FAILED (secret is missing or mismatched)
}

final BytesArray jwtCacheKey = isCacheEnabled() ? new BytesArray(jwtAuthenticationToken.getUserCredentialsHash()) : null;
final BytesArray jwtCacheKey = isCacheEnabled()
? new BytesArray(new BytesRef(jwtAuthenticationToken.getUserCredentialsHash()), true)
: null;
if (jwtCacheKey != null) {
final User cachedUser = tryAuthenticateWithCache(tokenPrincipal, jwtCacheKey);
if (cachedUser != null) {
Expand Down Expand Up @@ -483,6 +486,11 @@ private boolean isCacheEnabled() {
return jwtCache != null && jwtCacheHelper != null;
}

// package private for testing
Cache<BytesArray, ExpiringUser> getJwtCache() {
return jwtCache;
}

/**
* Format and filter JWT contents as user metadata.
* @param claimsSet Claims are supported. Claim keys are prefixed by "jwt_claim_".
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,17 @@ public void testJwtAuthcRealmAuthcAuthzWithEmptyRoles() throws Exception {
doMultipleAuthcAuthzAndVerifySuccess(jwtIssuerAndRealm.realm(), user, jwt, clientSecret, jwtAuthcCount);
}

// this test is mostly a duplicate of others tests but this test guarantees the cache is used which invoke downstream assertions
// https://github.com/elastic/elasticsearch/issues/101752
public void testJwtCache() throws Exception {
jwtIssuerAndRealms = generateJwtIssuerRealmPairs(1, 1, 1, 1, 1, 1, 99, false);
final JwtIssuerAndRealm jwtIssuerAndRealm = randomJwtIssuerRealmPair();
final User user = randomUser(jwtIssuerAndRealm.issuer());
final SecureString jwt = randomJwt(jwtIssuerAndRealm, user);
final SecureString clientSecret = JwtRealmInspector.getClientAuthenticationSharedSecret(jwtIssuerAndRealm.realm());
doMultipleAuthcAuthzAndVerifySuccess(jwtIssuerAndRealm.realm(), user, jwt, clientSecret, 10);
}

/**
* Test with no authz realms.
* @throws Exception Unexpected test failure
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
import com.nimbusds.openid.connect.sdk.Nonce;

import org.elasticsearch.action.support.PlainActionFuture;
import org.elasticsearch.common.bytes.BytesArray;
import org.elasticsearch.common.cache.Cache;
import org.elasticsearch.common.settings.MockSecureSettings;
import org.elasticsearch.common.settings.SecureString;
import org.elasticsearch.common.settings.Settings;
Expand Down Expand Up @@ -46,6 +48,7 @@
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Date;
import java.util.HexFormat;
import java.util.List;
import java.util.Map;
import java.util.TreeSet;
Expand Down Expand Up @@ -290,7 +293,7 @@ protected JwtRealmSettingsBuilder createJwtRealmSettingsBuilder(final JwtIssuer
if (randomBoolean()) {
authcSettings.put(
RealmSettings.getFullSettingKey(authcRealmName, JwtRealmSettings.JWT_CACHE_TTL),
randomIntBetween(10, 120) + randomFrom("s", "m", "h")
randomIntBetween(10, 120) + randomFrom("m", "h")
);
}
authcSettings.put(RealmSettings.getFullSettingKey(authcRealmName, JwtRealmSettings.JWT_CACHE_SIZE), jwtCacheSize);
Expand Down Expand Up @@ -378,11 +381,12 @@ protected void doMultipleAuthcAuthzAndVerifySuccess(
final int jwtAuthcRepeats
) {
final List<JwtRealm> jwtRealmsList = jwtIssuerAndRealms.stream().map(p -> p.realm).toList();

BytesArray firstCacheKeyFound = null;
// Select different test JWKs from the JWT realm, and generate test JWTs for the test user. Run the JWT through the chain.
for (int authcRun = 1; authcRun <= jwtAuthcRepeats; authcRun++) {

final ThreadContext requestThreadContext = createThreadContext(jwt, sharedSecret);
logger.info("REQ[" + authcRun + "/" + jwtAuthcRepeats + "] HEADERS=" + requestThreadContext.getHeaders());
logger.debug("REQ[" + authcRun + "/" + jwtAuthcRepeats + "] HEADERS=" + requestThreadContext.getHeaders());

// Any JWT realm can recognize and extract the request headers.
final var jwtAuthenticationToken = (JwtAuthenticationToken) randomFrom(jwtRealmsList).token(requestThreadContext);
Expand All @@ -393,11 +397,11 @@ protected void doMultipleAuthcAuthzAndVerifySuccess(
// Loop through all authc/authz realms. Confirm user is returned with expected principal and roles.
User authenticatedUser = null;
realmLoop: for (final JwtRealm candidateJwtRealm : jwtRealmsList) {
logger.info("TRY AUTHC: expected=[" + jwtRealm.name() + "], candidate[" + candidateJwtRealm.name() + "].");
logger.debug("TRY AUTHC: expected=[" + jwtRealm.name() + "], candidate[" + candidateJwtRealm.name() + "].");
final PlainActionFuture<AuthenticationResult<User>> authenticateFuture = PlainActionFuture.newFuture();
candidateJwtRealm.authenticate(jwtAuthenticationToken, authenticateFuture);
final AuthenticationResult<User> authenticationResult = authenticateFuture.actionGet();
logger.info("Authentication result with realm [{}]: [{}]", candidateJwtRealm.name(), authenticationResult);
logger.debug("Authentication result with realm [{}]: [{}]", candidateJwtRealm.name(), authenticationResult);
switch (authenticationResult.getStatus()) {
case SUCCESS:
assertThat("Unexpected realm SUCCESS status", candidateJwtRealm.name(), equalTo(jwtRealm.name()));
Expand Down Expand Up @@ -430,20 +434,41 @@ protected void doMultipleAuthcAuthzAndVerifySuccess(
equalTo(Map.of("jwt_token_type", JwtRealmInspector.getTokenType(jwtRealm).value()))
);
}
// if the cache is enabled ensure the cache is used and does not change for the provided jwt
if (jwtRealm.getJwtCache() != null) {
Cache<BytesArray, JwtRealm.ExpiringUser> cache = jwtRealm.getJwtCache();
if (firstCacheKeyFound == null) {
assertNotNull("could not find cache keys", cache.keys());
firstCacheKeyFound = cache.keys().iterator().next();
}
jwtAuthenticationToken.clearCredentials(); // simulates the realm's context closing which clears the credential
boolean foundInCache = false;
for (BytesArray key : cache.keys()) {
logger.trace("cache key: " + HexFormat.of().formatHex(key.array()));
if (key.equals(firstCacheKeyFound)) {
foundInCache = true;
}
assertFalse(
"cache key should not be nulled out",
IntStream.range(0, key.array().length).map(idx -> key.array()[idx]).allMatch(b -> b == 0)
);
}
assertTrue("cache key was not found in cache", foundInCache);
}
}
logger.info("Test succeeded");
logger.debug("Test succeeded");
}

protected User randomUser(final JwtIssuer jwtIssuer) {
final User user = randomFrom(jwtIssuer.principals.values());
logger.info("USER[" + user.principal() + "]: roles=[" + String.join(",", user.roles()) + "].");
logger.debug("USER[" + user.principal() + "]: roles=[" + String.join(",", user.roles()) + "].");
return user;
}

protected SecureString randomJwt(final JwtIssuerAndRealm jwtIssuerAndRealm, User user) throws Exception {
final JwtIssuer.AlgJwkPair algJwkPair = randomFrom(jwtIssuerAndRealm.issuer.algAndJwksAll);
final JWK jwk = algJwkPair.jwk();
logger.info(
logger.debug(
"ALG["
+ algJwkPair.alg()
+ "]. JWK: kty=["
Expand Down Expand Up @@ -491,7 +516,7 @@ protected void printJwtRealmAndIssuer(JwtIssuerAndRealm jwtIssuerAndRealm) throw
}

protected void printJwtRealm(final JwtRealm jwtRealm) {
logger.info(
logger.debug(
"REALM["
+ jwtRealm.name()
+ ","
Expand Down Expand Up @@ -527,15 +552,15 @@ protected void printJwtRealm(final JwtRealm jwtRealm) {
+ "]."
);
for (final JWK jwk : JwtRealmInspector.getJwksAlgsHmac(jwtRealm).jwks()) {
logger.info("REALM HMAC: jwk=[{}]", jwk);
logger.debug("REALM HMAC: jwk=[{}]", jwk);
}
for (final JWK jwk : JwtRealmInspector.getJwksAlgsPkc(jwtRealm).jwks()) {
logger.info("REALM PKC: jwk=[{}]", jwk);
logger.debug("REALM PKC: jwk=[{}]", jwk);
}
}

protected void printJwtIssuer(final JwtIssuer jwtIssuer) {
logger.info(
logger.debug(
"ISSUER: iss=["
+ jwtIssuer.issuerClaimValue
+ "], aud=["
Expand All @@ -549,13 +574,13 @@ protected void printJwtIssuer(final JwtIssuer jwtIssuer) {
+ "]."
);
if (jwtIssuer.algAndJwkHmacOidc != null) {
logger.info("ISSUER HMAC OIDC: alg=[{}] jwk=[{}]", jwtIssuer.algAndJwkHmacOidc.alg(), jwtIssuer.encodedKeyHmacOidc);
logger.debug("ISSUER HMAC OIDC: alg=[{}] jwk=[{}]", jwtIssuer.algAndJwkHmacOidc.alg(), jwtIssuer.encodedKeyHmacOidc);
}
for (final JwtIssuer.AlgJwkPair pair : jwtIssuer.algAndJwksHmac) {
logger.info("ISSUER HMAC: alg=[{}] jwk=[{}]", pair.alg(), pair.jwk());
logger.debug("ISSUER HMAC: alg=[{}] jwk=[{}]", pair.alg(), pair.jwk());
}
for (final JwtIssuer.AlgJwkPair pair : jwtIssuer.algAndJwksPkc) {
logger.info("ISSUER PKC: alg=[{}] jwk=[{}]", pair.alg(), pair.jwk());
logger.debug("ISSUER PKC: alg=[{}] jwk=[{}]", pair.alg(), pair.jwk());
}
}
}

0 comments on commit ddadcca

Please sign in to comment.