Skip to content

Commit

Permalink
Merge pull request #180 from lasanthaS/pkce-fix
Browse files Browse the repository at this point in the history
Enable PKCE in OIDC federated login flow
  • Loading branch information
ThaminduR authored Apr 17, 2024
2 parents 80d2408 + 8e23d68 commit 55ddd90
Show file tree
Hide file tree
Showing 4 changed files with 126 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,7 @@
--add-opens java.base/sun.nio.fs=ALL-UNNAMED
--add-opens java.base/sun.nio.cs=ALL-UNNAMED
--add-opens java.base/sun.net.www.protocol.https=ALL-UNNAMED
--add-opens java.base/java.security=ALL-UNNAMED
</argLine>
<suiteXmlFiles>
<suiteXmlFile>src/test/resources/testng.xml</suiteXmlFile>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,10 @@ private OIDCAuthenticatorConstants() {
/**
* This class holds the constants related to authenticator configuration parameters.
*/

public static final String PKCE_CODE_VERIFIER = "PKCE_CODE_VERIFIER";
public static final String IS_PKCE_ENABLED = "IsPKCEEnabled";

public class AuthenticatorConfParams {

private AuthenticatorConfParams() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,9 @@
import java.net.URLDecoder;
import java.net.URLEncoder;
import java.nio.charset.StandardCharsets;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.security.SecureRandom;
import java.text.ParseException;
import java.util.ArrayList;
import java.util.Arrays;
Expand Down Expand Up @@ -138,6 +141,7 @@ public class OpenIDConnectAuthenticator extends AbstractApplicationAuthenticator

private static final Log LOG = LogFactory.getLog(OpenIDConnectAuthenticator.class);
private static final String OIDC_DIALECT = "http://wso2.org/oidc/claim";
private static final String PKCE_CODE_CHALLENGE_METHOD = "S256";

private static final String DYNAMIC_PARAMETER_LOOKUP_REGEX = "\\$\\{(\\w+)\\}";
private static final String IS_API_BASED = "IS_API_BASED";
Expand All @@ -150,6 +154,11 @@ public class OpenIDConnectAuthenticator extends AbstractApplicationAuthenticator
private static final String[] NON_USER_ATTRIBUTES = new String[]{"at_hash", "iss", "iat", "exp", "aud", "azp"};
private static final String AUTHENTICATOR_MESSAGE = "authenticatorMessage";

private static final String IS_PKCE_ENABLED_NAME = "isPKCEEnabled";
private static final String IS_PKCE_ENABLED_DISPLAY_NAME = "Enable PKCE";
private static final String IS_PKCE_ENABLED_DESCRIPTION = "Specifies that PKCE should be used for client authentication";
private static final String TYPE_BOOLEAN = "boolean";

@Override
public AuthenticatorFlowStatus process(HttpServletRequest request, HttpServletResponse response,
AuthenticationContext context)
Expand Down Expand Up @@ -514,6 +523,8 @@ protected String prepareLoginPage(HttpServletRequest request, AuthenticationCont
context.setProperty(OIDCAuthenticatorConstants.AUTHENTICATOR_NAME + STATE_PARAM_SUFFIX, state);
String nonce = UUID.randomUUID().toString();
context.setProperty(OIDC_FEDERATION_NONCE, nonce);
boolean isPKCEEnabled = Boolean.parseBoolean(
authenticatorProperties.get(OIDCAuthenticatorConstants.IS_PKCE_ENABLED));

OAuthClientRequest authzRequest;

Expand Down Expand Up @@ -585,6 +596,15 @@ protected String prepareLoginPage(HttpServletRequest request, AuthenticationCont
loginPage = loginPage + "&fidp=" + domain;
}

// If PKCE is enabled, add code_challenge and code_challenge_method to the request.
if (isPKCEEnabled) {
String codeVerifier = generateCodeVerifier();
context.setProperty(OIDCAuthenticatorConstants.PKCE_CODE_VERIFIER, codeVerifier);
String codeChallenge = generateCodeChallenge(codeVerifier);
loginPage += "&code_challenge=" + codeChallenge + "&code_challenge_method="
+ PKCE_CODE_CHALLENGE_METHOD;
}

if (StringUtils.isNotBlank(queryString)) {
if (!queryString.startsWith("&")) {
loginPage = loginPage + "&" + queryString;
Expand Down Expand Up @@ -1467,6 +1487,9 @@ protected OAuthClientRequest getAccessTokenRequest(AuthenticationContext context
String clientId = authenticatorProperties.get(OIDCAuthenticatorConstants.CLIENT_ID);
String clientSecret = authenticatorProperties.get(OIDCAuthenticatorConstants.CLIENT_SECRET);
String tokenEndPoint = getTokenEndpoint(authenticatorProperties);
boolean isPKCEEnabled = Boolean.parseBoolean(
authenticatorProperties.get(OIDCAuthenticatorConstants.IS_PKCE_ENABLED));
String codeVerifier = (String) context.getProperty(OIDCAuthenticatorConstants.PKCE_CODE_VERIFIER);

String callbackUrl = getCallbackUrlFromInitialRequestParamMap(context);
if (StringUtils.isBlank(callbackUrl)) {
Expand All @@ -1489,9 +1512,20 @@ protected OAuthClientRequest getAccessTokenRequest(AuthenticationContext context
"authentication scheme.");
}

accessTokenRequest = OAuthClientRequest.tokenLocation(tokenEndPoint).setGrantType(GrantType
.AUTHORIZATION_CODE).setRedirectURI(callbackUrl).setCode(authzResponse.getCode())
.buildBodyMessage();
OAuthClientRequest.TokenRequestBuilder tokenRequestBuilder = OAuthClientRequest
.tokenLocation(tokenEndPoint)
.setGrantType(GrantType.AUTHORIZATION_CODE)
.setRedirectURI(callbackUrl)
.setCode(authzResponse.getCode());

if (isPKCEEnabled) {
if (StringUtils.isEmpty(codeVerifier)) {
throw new AuthenticationFailedException("PKCE is enabled, but the code verifier is not found.");
}
tokenRequestBuilder.setParameter("code_verifier", codeVerifier);
}

accessTokenRequest = tokenRequestBuilder.buildBodyMessage();
String base64EncodedCredential = new String(Base64.encodeBase64((clientId + ":" +
clientSecret).getBytes()));
accessTokenRequest.addHeader(OAuth.HeaderType.AUTHORIZATION, "Basic " + base64EncodedCredential);
Expand All @@ -1501,10 +1535,22 @@ protected OAuthClientRequest getAccessTokenRequest(AuthenticationContext context
LOG.debug("Authenticating to token endpoint: " + tokenEndPoint + " including client credentials "
+ "in request body.");
}
accessTokenRequest = OAuthClientRequest.tokenLocation(tokenEndPoint).setGrantType(GrantType
.AUTHORIZATION_CODE).setClientId(clientId).setClientSecret(clientSecret).setRedirectURI
(callbackUrl).setCode(authzResponse.getCode()).buildBodyMessage();
OAuthClientRequest.TokenRequestBuilder tokenRequestBuilder = OAuthClientRequest
.tokenLocation(tokenEndPoint)
.setGrantType(GrantType.AUTHORIZATION_CODE)
.setClientId(clientId)
.setClientSecret(clientSecret)
.setRedirectURI(callbackUrl)
.setCode(authzResponse.getCode());
if (isPKCEEnabled) {
if (StringUtils.isEmpty(codeVerifier)) {
throw new AuthenticationFailedException("PKCE is enabled, but the code verifier is not found.");
}
tokenRequestBuilder.setParameter("code_verifier", codeVerifier);
}
accessTokenRequest = tokenRequestBuilder.buildBodyMessage();
}
context.removeProperty(OIDCAuthenticatorConstants.PKCE_CODE_VERIFIER);
// set 'Origin' header to access token request.
if (accessTokenRequest != null) {
// fetch the 'Hostname' configured in carbon.xml
Expand All @@ -1522,7 +1568,6 @@ protected OAuthClientRequest getAccessTokenRequest(AuthenticationContext context
} catch (URLBuilderException e) {
throw new RuntimeException("Error occurred while building URL in tenant qualified mode.", e);
}

return accessTokenRequest;
}

Expand Down Expand Up @@ -1692,6 +1737,15 @@ public List<Property> getConfigurationProperties() {
enableBasicAuth.setDisplayOrder(10);
configProperties.add(enableBasicAuth);

Property enablePKCE = new Property();
enablePKCE.setName(IS_PKCE_ENABLED_NAME);
enablePKCE.setDisplayName(IS_PKCE_ENABLED_DISPLAY_NAME);
enablePKCE.setRequired(false);
enablePKCE.setDescription(IS_PKCE_ENABLED_DESCRIPTION);
enablePKCE.setType(TYPE_BOOLEAN);
enablePKCE.setDisplayOrder(10);
configProperties.add(enablePKCE);

return configProperties;
}

Expand Down Expand Up @@ -2147,4 +2201,35 @@ private String getFederatedAuthenticatorName(AuthenticationContext context) {
}
return context.getExternalIdP().getIdPName();
}

/**
* Generate code verifier for PKCE
*
* @return code verifier
*/
private String generateCodeVerifier() {
SecureRandom secureRandom = new SecureRandom();
byte[] codeVerifier = new byte[32];
secureRandom.nextBytes(codeVerifier);
return java.util.Base64.getUrlEncoder().withoutPadding().encodeToString(codeVerifier);
}

/**
* Generate code challenge for PKCE
*
* @param codeVerifier code verifier
* @return code challenge
* @throws AuthenticationFailedException
*/
private String generateCodeChallenge(String codeVerifier) throws AuthenticationFailedException {
try {
byte[] bytes = codeVerifier.getBytes("US-ASCII");
MessageDigest messageDigest = MessageDigest.getInstance("SHA-256");
messageDigest.update(bytes, 0, bytes.length);
byte[] digest = messageDigest.digest();
return java.util.Base64.getUrlEncoder().withoutPadding().encodeToString(digest);
} catch (UnsupportedEncodingException | NoSuchAlgorithmException e) {
throw new AuthenticationFailedException("Error while generating code challenge", e);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
import org.apache.oltu.oauth2.common.exception.OAuthSystemException;
import org.mockito.Matchers;
import org.mockito.Mock;
import org.powermock.core.classloader.annotations.PowerMockIgnore;
import org.powermock.core.classloader.annotations.PrepareForTest;
import org.powermock.core.classloader.annotations.SuppressStaticInitializationFor;
import org.powermock.modules.testng.PowerMockTestCase;
Expand Down Expand Up @@ -123,6 +124,7 @@
LoggerUtils.class, OIDCTokenValidationUtil.class, IdentityProviderManager.class})
@SuppressStaticInitializationFor({"org.wso2.carbon.idp.mgt.IdentityProviderManager",
"org.wso2.carbon.identity.application.authentication.framework.exception.AuthenticationFailedException"})
@PowerMockIgnore("jdk.internal.reflect.*")
public class OpenIDConnectAuthenticatorTest extends PowerMockTestCase {

private static final String OIDC_PARAM_MAP_STRING = "oidc:param.map";
Expand Down Expand Up @@ -480,7 +482,7 @@ public void testInitiateAuthenticationRequestNullProperties() throws OAuthSystem
public void testPassProcessAuthenticationResponse() throws Exception {

setupTest();

authenticatorProperties.put(OIDCAuthenticatorConstants.IS_PKCE_ENABLED, "false");
IdentityProviderProperty property = new IdentityProviderProperty();
property.setName(IdPManagementConstants.IS_TRUSTED_TOKEN_ISSUER);
property.setValue("false");
Expand Down Expand Up @@ -521,6 +523,7 @@ public void testPassProcessAuthenticationResponseWithNonce() throws Exception {
when(mockAuthenticationContext.getExternalIdP()).thenReturn(externalIdPConfig);
when(externalIdPConfig.getIdentityProvider()).thenReturn(identityProvider);
when(identityProvider.getIdpProperties()).thenReturn(identityProviderProperties);
authenticatorProperties.put(OIDCAuthenticatorConstants.IS_PKCE_ENABLED, "false");
when(openIDConnectAuthenticatorDataHolder.getClaimMetadataManagementService()).thenReturn
(claimMetadataManagementService);
when(mockAuthenticationContext.getExternalIdP()).thenReturn(externalIdPConfig);
Expand All @@ -539,6 +542,29 @@ public void testPassProcessAuthenticationResponseWithNonce() throws Exception {
"Invalid Id token in the authentication context.");
}

/**
* Test whether the token request contains the code verifier when PKCE is enabled.
*
* @throws URLBuilderException
* @throws AuthenticationFailedException
*/
@Test()
public void testGetAccessTokenRequestWithPKCE() throws URLBuilderException, AuthenticationFailedException {
mockAuthenticationRequestContext(mockAuthenticationContext);
authenticatorProperties.put(OIDCAuthenticatorConstants.IS_PKCE_ENABLED, "true");
when(mockAuthenticationContext.getProperty(OIDCAuthenticatorConstants.PKCE_CODE_VERIFIER))
.thenReturn("sample_code_verifier");
when(mockOAuthzResponse.getCode()).thenReturn("abc");
mockStatic(ServiceURLBuilder.class);
ServiceURLBuilder serviceURLBuilder = mock(ServiceURLBuilder.class);
when(ServiceURLBuilder.create()).thenReturn(serviceURLBuilder);
when(serviceURLBuilder.build()).thenReturn(serviceURL);
when(serviceURL.getAbsolutePublicURL()).thenReturn("http://localhost:9443");
OAuthClientRequest request = openIDConnectAuthenticator
.getAccessTokenRequest(mockAuthenticationContext, mockOAuthzResponse);
assertTrue(request.getBody().contains("code_verifier=sample_code_verifier"));
}

@Test
public void testPassProcessAuthenticationResponseWithoutAccessToken() throws Exception {

Expand All @@ -558,6 +584,7 @@ public void testPassProcessAuthenticationWithBlankCallBack() throws Exception {

setupTest();
authenticatorProperties.put("callbackUrl", " ");
authenticatorProperties.put(OIDCAuthenticatorConstants.IS_PKCE_ENABLED, "false");
mockStatic(IdentityUtil.class);
when(IdentityUtil.getServerURL(FrameworkConstants.COMMONAUTH, true, true))
.thenReturn("http:/localhost:9443/oauth2/callback");
Expand Down Expand Up @@ -618,6 +645,7 @@ public void testPassProcessAuthenticationWithParamValue() throws Exception {
setupTest();
when(LoggerUtils.isDiagnosticLogsEnabled()).thenReturn(true);
authenticatorProperties.put("callbackUrl", "http://localhost:8080/playground2/oauth2client");
authenticatorProperties.put(OIDCAuthenticatorConstants.IS_PKCE_ENABLED, "false");
Map<String, String> paramMap = new HashMap<>();
paramMap.put("redirect_uri", "http:/localhost:9443/oauth2/redirect");
when(mockAuthenticationContext.getProperty(OIDC_PARAM_MAP_STRING)).thenReturn(paramMap);
Expand Down

0 comments on commit 55ddd90

Please sign in to comment.