Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enhancements for SAML2 bearer and IdP initiated SSO #3136

Merged
merged 15 commits into from
Nov 16, 2024
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -56,13 +56,14 @@
import java.util.HashMap;
import java.util.Map;

import static org.cloudfoundry.identity.uaa.oauth.token.ClaimConstants.GRANT_TYPE;
import static org.cloudfoundry.identity.uaa.oauth.token.TokenConstants.GRANT_TYPE_JWT_BEARER;
import static org.cloudfoundry.identity.uaa.oauth.token.TokenConstants.GRANT_TYPE_PASSWORD;
import static org.cloudfoundry.identity.uaa.oauth.token.TokenConstants.GRANT_TYPE_SAML2_BEARER;

/**
* Provides an implementation that sets the UserAuthentication
* prior to createAuthorizatioRequest is called.
* prior to createAuthorizationRequest is called.
* Backwards compatible with Spring Security Oauth2 v1
* This is a copy of the TokenEndpointAuthenticationFilter from Spring Security Oauth2 v2, but made to work with UAA
*/
Expand Down Expand Up @@ -157,18 +158,18 @@ public void doFilter(ServletRequest req, ServletResponse res, FilterChain chain)
.getContext()
.setAuthentication(new OAuth2Authentication(storedOAuth2Request, userAuthentication));

onSuccessfulAuthentication(request, response, userAuthentication);
onSuccessfulAuthentication();
}
} catch (AuthenticationException failed) {
log.debug("Authentication request failed: {}", failed.getMessage());
onUnsuccessfulAuthentication(request, response, failed);
onUnsuccessfulAuthentication();
authenticationEntryPoint.commence(request, response, failed);
return;
} catch (OAuth2Exception failed) {
String message = failed.getMessage();
log.debug("Authentication request failed with Oauth exception: {}", message);
InsufficientAuthenticationException ex = new InsufficientAuthenticationException(message, failed);
onUnsuccessfulAuthentication(request, response, ex);
onUnsuccessfulAuthentication();
authenticationEntryPoint.commence(request, response, ex);
return;
}
Expand All @@ -186,14 +187,11 @@ private Map<String, String> getSingleValueMap(HttpServletRequest request) {
return map;
}

protected void onSuccessfulAuthentication(HttpServletRequest request,
HttpServletResponse response,
Authentication authResult) {
protected void onSuccessfulAuthentication() {
// do nothing
}

protected void onUnsuccessfulAuthentication(HttpServletRequest request,
HttpServletResponse response,
AuthenticationException failed) {
protected void onUnsuccessfulAuthentication() {
SecurityContextHolder.clearContext();
}

Expand All @@ -214,7 +212,7 @@ protected Authentication extractCredentials(HttpServletRequest request) {
}

protected Authentication attemptTokenAuthentication(HttpServletRequest request, HttpServletResponse response) {
String grantType = request.getParameter("grant_type");
String grantType = request.getParameter(GRANT_TYPE);
log.debug("Processing token user authentication for grant:{}", UaaStringUtils.getCleanedUserControlString(grantType));
Authentication authResult = null;
if (GRANT_TYPE_PASSWORD.equals(grantType)) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,18 +44,12 @@ public RelyingPartyRegistration findByRegistrationId(String registrationId) {
return createRelyingPartyRegistration(foundSamlIdentityProviderDefinition.getIdpEntityAlias(), foundSamlIdentityProviderDefinition, currentZone);
}

List<SamlIdentityProviderDefinition> identityProviderDefinitions = configurator.getIdentityProviderDefinitionsForZone(currentZone);
for (SamlIdentityProviderDefinition identityProviderDefinition : identityProviderDefinitions) {
for (SamlIdentityProviderDefinition identityProviderDefinition : configurator.getIdentityProviderDefinitionsForZone(currentZone)) {
if (registrationId.equals(identityProviderDefinition.getIdpEntityAlias()) || registrationId.equals(identityProviderDefinition.getIdpEntityId())) {
return createRelyingPartyRegistration(identityProviderDefinition.getIdpEntityAlias(), identityProviderDefinition, currentZone);
}
}

// TODO remove hack
if (!identityProviderDefinitions.isEmpty() && identityProviderDefinitions.size() == 1) {
SamlIdentityProviderDefinition identityProviderDefinition = identityProviderDefinitions.get(0);
return createRelyingPartyRegistration(identityProviderDefinition.getIdpEntityAlias(), identityProviderDefinition, currentZone);
}
return null;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,9 @@
import org.opensaml.saml.saml2.assertion.SAML2AssertionValidationParameters;
import org.opensaml.saml.saml2.core.Assertion;
import org.opensaml.saml.saml2.core.Issuer;
import org.opensaml.saml.saml2.core.Response;
import org.opensaml.saml.saml2.core.impl.AssertionUnmarshaller;
import org.opensaml.saml.saml2.core.impl.ResponseUnmarshaller;
import org.springframework.context.ApplicationEvent;
import org.springframework.context.ApplicationEventPublisher;
import org.springframework.context.ApplicationEventPublisherAware;
Expand Down Expand Up @@ -99,13 +101,16 @@ public final class Saml2BearerGrantAuthenticationConverter implements Authentica
}

private static final AssertionUnmarshaller assertionUnmarshaller;
private static final ResponseUnmarshaller responseUnMarshaller;

private static final ParserPool parserPool;

static {
XMLObjectProviderRegistry registry = ConfigurationService.get(XMLObjectProviderRegistry.class);
assertionUnmarshaller = (AssertionUnmarshaller) registry.getUnmarshallerFactory()
.getUnmarshaller(Assertion.DEFAULT_ELEMENT_NAME);
responseUnMarshaller = (ResponseUnmarshaller) registry.getUnmarshallerFactory()
.getUnmarshaller(Response.DEFAULT_ELEMENT_NAME);
parserPool = registry.getParserPool();
}

Expand Down Expand Up @@ -305,6 +310,22 @@ private static Assertion parseAssertion(String assertion) throws Saml2Exception,
}
}

protected static Response parseSamlResponse(String samlResponse) throws Saml2Exception, Saml2AuthenticationException {
try {
Document document = parserPool
.parse(new ByteArrayInputStream(samlResponse.getBytes(StandardCharsets.UTF_8)));
Element element = document.getDocumentElement();
return (Response) responseUnMarshaller.unmarshall(element);
} catch (Exception ex) {
throw OpenSaml4AuthenticationProvider.createAuthenticationException(Saml2ErrorCodes.INVALID_RESPONSE, "Unable to parse saml response", ex);
}
}

protected static String getIssuer(Response response) {
return Optional.ofNullable(response.getIssuer()).map(Issuer::getValue)
.orElseThrow(() -> new Saml2AuthenticationException(new Saml2Error(Saml2ErrorCodes.INVALID_RESPONSE, "Missing issuer in saml response")));
}

private static String getIssuer(Assertion assertion) {
return Optional.ofNullable(assertion.getIssuer()).map(Issuer::getValue)
.orElseThrow(() -> new Saml2AuthenticationException(new Saml2Error(Saml2ErrorCodes.INVALID_ASSERTION, "Missing issuer in bearer assertion")));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import org.cloudfoundry.identity.uaa.web.UaaSavedRequestAwareAuthenticationSuccessHandler;
import org.cloudfoundry.identity.uaa.zone.beans.IdentityZoneManager;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Qualifier;
import org.springframework.context.ApplicationEventPublisher;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
Expand All @@ -20,7 +21,6 @@
import org.springframework.security.saml2.provider.service.authentication.logout.Saml2LogoutRequestValidator;
import org.springframework.security.saml2.provider.service.authentication.logout.Saml2LogoutResponseValidator;
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository;
import org.springframework.security.saml2.provider.service.web.DefaultRelyingPartyRegistrationResolver;
import org.springframework.security.saml2.provider.service.web.RelyingPartyRegistrationResolver;
import org.springframework.security.saml2.provider.service.web.Saml2AuthenticationTokenConverter;
import org.springframework.security.saml2.provider.service.web.Saml2WebSsoAuthenticationRequestFilter;
Expand Down Expand Up @@ -112,13 +112,12 @@ AuthenticationProvider samlAuthenticationProvider(IdentityZoneManager identityZo
@Autowired
@Bean
Filter saml2WebSsoAuthenticationFilter(AuthenticationProvider samlAuthenticationProvider,
RelyingPartyRegistrationRepository relyingPartyRegistrationRepository,
UaaRelyingPartyRegistrationResolver relyingPartyRegistrationResolver,
SecurityContextRepository securityContextRepository,
SamlLoginAuthenticationFailureHandler samlLoginAuthenticationFailureHandler,
UaaSavedRequestAwareAuthenticationSuccessHandler samlLoginAuthenticationSuccessHandler) {

RelyingPartyRegistrationResolver relyingPartyRegistrationResolver = new DefaultRelyingPartyRegistrationResolver(relyingPartyRegistrationRepository);
Saml2AuthenticationTokenConverter saml2AuthenticationTokenConverter = new Saml2AuthenticationTokenConverter(relyingPartyRegistrationResolver);
Saml2AuthenticationTokenConverter saml2AuthenticationTokenConverter = new Saml2AuthenticationTokenConverter((RelyingPartyRegistrationResolver)relyingPartyRegistrationResolver);
Saml2WebSsoAuthenticationFilter saml2WebSsoAuthenticationFilter = new Saml2WebSsoAuthenticationFilter(saml2AuthenticationTokenConverter, BACKWARD_COMPATIBLE_ASSERTION_CONSUMER_FILTER_PROCESSES_URI);

ProviderManager authenticationManager = new ProviderManager(samlAuthenticationProvider);
Expand Down Expand Up @@ -223,10 +222,9 @@ Saml2LogoutResponseFilter saml2LogoutResponseFilter(RelyingPartyRegistrationReso
*/
@Autowired
@Bean
Saml2LogoutRequestFilter saml2LogoutRequestFilter(RelyingPartyRegistrationRepository relyingPartyRegistrationRepository,
Saml2LogoutRequestFilter saml2LogoutRequestFilter(UaaRelyingPartyRegistrationResolver relyingPartyRegistrationResolver,
UaaAuthenticationFailureHandler authenticationFailureHandler,
CookieBasedCsrfTokenRepository loginCookieCsrfRepository) {
RelyingPartyRegistrationResolver relyingPartyRegistrationResolver = new DefaultRelyingPartyRegistrationResolver(relyingPartyRegistrationRepository);
strehle marked this conversation as resolved.
Show resolved Hide resolved

// This validator ignores missing signatures in the SAML2 Logout Response
Saml2LogoutRequestValidator logoutRequestValidator = new SamlLogoutRequestValidator();
Expand All @@ -253,9 +251,8 @@ Saml2BearerGrantAuthenticationConverter samlBearerGrantAuthenticationProvider(Id
final JdbcIdentityProviderProvisioning identityProviderProvisioning,
SamlUaaAuthenticationUserManager samlUaaAuthenticationUserManager,
ApplicationEventPublisher applicationEventPublisher,
RelyingPartyRegistrationRepository relyingPartyRegistrationRepository) {
UaaRelyingPartyRegistrationResolver relyingPartyRegistrationResolver) {

RelyingPartyRegistrationResolver relyingPartyRegistrationResolver = new DefaultRelyingPartyRegistrationResolver(relyingPartyRegistrationRepository);
return new Saml2BearerGrantAuthenticationConverter(relyingPartyRegistrationResolver, identityZoneManager,
identityProviderProvisioning, samlUaaAuthenticationUserManager, applicationEventPublisher);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,6 @@
import org.springframework.security.saml2.provider.service.registration.InMemoryRelyingPartyRegistrationRepository;
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository;
import org.springframework.security.saml2.provider.service.web.DefaultRelyingPartyRegistrationResolver;
import org.springframework.security.saml2.provider.service.web.RelyingPartyRegistrationResolver;

import java.util.ArrayList;
import java.util.List;
Expand Down Expand Up @@ -103,7 +101,8 @@ RelyingPartyRegistrationRepository relyingPartyRegistrationRepository(SamlIdenti

@Autowired
@Bean
RelyingPartyRegistrationResolver relyingPartyRegistrationResolver(RelyingPartyRegistrationRepository relyingPartyRegistrationRepository) {
return new DefaultRelyingPartyRegistrationResolver(relyingPartyRegistrationRepository);
UaaRelyingPartyRegistrationResolver relyingPartyRegistrationResolver(RelyingPartyRegistrationRepository relyingPartyRegistrationRepository,
@Qualifier("samlEntityID") String samlEntityID) {
return new UaaRelyingPartyRegistrationResolver(relyingPartyRegistrationRepository, samlEntityID);
}
}
Loading
Loading