diff --git a/src/integrationTest/java/org/opensearch/security/IpBruteForceAttacksPreventionTests.java b/src/integrationTest/java/org/opensearch/security/IpBruteForceAttacksPreventionTests.java index c881fa451e..f34c3f0912 100644 --- a/src/integrationTest/java/org/opensearch/security/IpBruteForceAttacksPreventionTests.java +++ b/src/integrationTest/java/org/opensearch/security/IpBruteForceAttacksPreventionTests.java @@ -12,7 +12,6 @@ import java.util.concurrent.TimeUnit; import com.carrotsearch.randomizedtesting.annotations.ThreadLeakScope; -import org.junit.ClassRule; import org.junit.Rule; import org.junit.Test; import org.junit.runner.RunWith; @@ -36,8 +35,8 @@ @RunWith(com.carrotsearch.randomizedtesting.RandomizedRunner.class) @ThreadLeakScope(ThreadLeakScope.Scope.NONE) public class IpBruteForceAttacksPreventionTests { - private static final User USER_1 = new User("simple-user-1").roles(ALL_ACCESS); - private static final User USER_2 = new User("simple-user-2").roles(ALL_ACCESS); + static final User USER_1 = new User("simple-user-1").roles(ALL_ACCESS); + static final User USER_2 = new User("simple-user-2").roles(ALL_ACCESS); public static final int ALLOWED_TRIES = 3; public static final int TIME_WINDOW_SECONDS = 3; @@ -51,7 +50,7 @@ public class IpBruteForceAttacksPreventionTests { public static final String CLIENT_IP_8 = "127.0.0.8"; public static final String CLIENT_IP_9 = "127.0.0.9"; - private static final AuthFailureListeners listener = new AuthFailureListeners().addRateLimit( + static final AuthFailureListeners listener = new AuthFailureListeners().addRateLimit( new RateLimiting("internal_authentication_backend_limiting").type("ip") .allowedTries(ALLOWED_TRIES) .timeWindowSeconds(TIME_WINDOW_SECONDS) @@ -60,13 +59,17 @@ public class IpBruteForceAttacksPreventionTests { .maxTrackedClients(500) ); - @ClassRule - public static final LocalCluster cluster = new LocalCluster.Builder().clusterManager(ClusterManager.SINGLENODE) - .anonymousAuth(false) - .authFailureListeners(listener) - .authc(AUTHC_HTTPBASIC_INTERNAL_WITHOUT_CHALLENGE) - .users(USER_1, USER_2) - .build(); + @Rule + public LocalCluster cluster = createCluster(); + + public LocalCluster createCluster() { + return new LocalCluster.Builder().clusterManager(ClusterManager.SINGLENODE) + .anonymousAuth(false) + .authFailureListeners(listener) + .authc(AUTHC_HTTPBASIC_INTERNAL_WITHOUT_CHALLENGE) + .users(USER_1, USER_2) + .build(); + } @Rule public LogsRule logsRule = new LogsRule("org.opensearch.security.auth.BackendRegistry"); @@ -151,7 +154,7 @@ public void shouldReleaseIpAddressLock() throws InterruptedException { } } - private static void authenticateUserWithIncorrectPassword(String sourceIpAddress, User user, int numberOfRequests) { + void authenticateUserWithIncorrectPassword(String sourceIpAddress, User user, int numberOfRequests) { var clientConfiguration = new TestRestClientConfiguration().username(user.getName()) .password("incorrect password") .sourceInetAddress(sourceIpAddress); diff --git a/src/integrationTest/java/org/opensearch/security/IpBruteForceAttacksPreventionWithDomainChallengeTests.java b/src/integrationTest/java/org/opensearch/security/IpBruteForceAttacksPreventionWithDomainChallengeTests.java new file mode 100644 index 0000000000..6159599119 --- /dev/null +++ b/src/integrationTest/java/org/opensearch/security/IpBruteForceAttacksPreventionWithDomainChallengeTests.java @@ -0,0 +1,32 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + */ + +package org.opensearch.security; + +import com.carrotsearch.randomizedtesting.annotations.ThreadLeakScope; +import org.junit.runner.RunWith; +import org.opensearch.test.framework.cluster.ClusterManager; +import org.opensearch.test.framework.cluster.LocalCluster; + +import static org.opensearch.test.framework.TestSecurityConfig.AuthcDomain.AUTHC_HTTPBASIC_INTERNAL; + +@RunWith(com.carrotsearch.randomizedtesting.RandomizedRunner.class) +@ThreadLeakScope(ThreadLeakScope.Scope.NONE) +public class IpBruteForceAttacksPreventionWithDomainChallengeTests extends IpBruteForceAttacksPreventionTests { + @Override + public LocalCluster createCluster() { + return new LocalCluster.Builder().clusterManager(ClusterManager.SINGLENODE) + .anonymousAuth(false) + .authFailureListeners(listener) + .authc(AUTHC_HTTPBASIC_INTERNAL) + .users(USER_1, USER_2) + .build(); + } +} diff --git a/src/integrationTest/java/org/opensearch/test/framework/cluster/LocalOpenSearchCluster.java b/src/integrationTest/java/org/opensearch/test/framework/cluster/LocalOpenSearchCluster.java index c09127e592..77890a4645 100644 --- a/src/integrationTest/java/org/opensearch/test/framework/cluster/LocalOpenSearchCluster.java +++ b/src/integrationTest/java/org/opensearch/test/framework/cluster/LocalOpenSearchCluster.java @@ -344,7 +344,7 @@ public String toString() { String clusterManagerNodes = nodeByTypeToString(CLUSTER_MANAGER); String dataNodes = nodeByTypeToString(DATA); String clientNodes = nodeByTypeToString(CLIENT); - return "\nES Cluster " + return "\nOS Cluster " + clusterName + "\ncluster manager nodes: " + clusterManagerNodes diff --git a/src/main/java/com/amazon/dlic/auth/http/jwt/AbstractHTTPJwtAuthenticator.java b/src/main/java/com/amazon/dlic/auth/http/jwt/AbstractHTTPJwtAuthenticator.java index 0c38245586..0f719a2073 100644 --- a/src/main/java/com/amazon/dlic/auth/http/jwt/AbstractHTTPJwtAuthenticator.java +++ b/src/main/java/com/amazon/dlic/auth/http/jwt/AbstractHTTPJwtAuthenticator.java @@ -36,7 +36,6 @@ import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.core.common.Strings; import org.opensearch.rest.BytesRestResponse; -import org.opensearch.rest.RestChannel; import org.opensearch.rest.RestRequest; import org.opensearch.core.rest.RestStatus; import org.opensearch.security.auth.HTTPAuthenticator; @@ -236,11 +235,10 @@ public String[] extractRoles(JwtClaims claims) { protected abstract KeyProvider initKeyProvider(Settings settings, Path configPath) throws Exception; @Override - public boolean reRequestAuthentication(RestChannel channel, AuthCredentials authCredentials) { + public BytesRestResponse reRequestAuthentication(RestRequest request, AuthCredentials authCredentials) { final BytesRestResponse wwwAuthenticateResponse = new BytesRestResponse(RestStatus.UNAUTHORIZED, ""); wwwAuthenticateResponse.addHeader("WWW-Authenticate", "Bearer realm=\"OpenSearch Security\""); - channel.sendResponse(wwwAuthenticateResponse); - return true; + return wwwAuthenticateResponse; } public String getRequiredAudience() { diff --git a/src/main/java/com/amazon/dlic/auth/http/jwt/HTTPJwtAuthenticator.java b/src/main/java/com/amazon/dlic/auth/http/jwt/HTTPJwtAuthenticator.java index 38bc006e92..1b394e20d0 100644 --- a/src/main/java/com/amazon/dlic/auth/http/jwt/HTTPJwtAuthenticator.java +++ b/src/main/java/com/amazon/dlic/auth/http/jwt/HTTPJwtAuthenticator.java @@ -39,7 +39,6 @@ import org.opensearch.common.settings.Settings; import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.rest.BytesRestResponse; -import org.opensearch.rest.RestChannel; import org.opensearch.rest.RestRequest; import org.opensearch.core.rest.RestStatus; import org.opensearch.security.auth.HTTPAuthenticator; @@ -224,11 +223,10 @@ private AuthCredentials extractCredentials0(final RestRequest request) { } @Override - public boolean reRequestAuthentication(final RestChannel channel, AuthCredentials creds) { + public BytesRestResponse reRequestAuthentication(RestRequest request, AuthCredentials credentials) { final BytesRestResponse wwwAuthenticateResponse = new BytesRestResponse(RestStatus.UNAUTHORIZED, ""); wwwAuthenticateResponse.addHeader("WWW-Authenticate", "Bearer realm=\"OpenSearch Security\""); - channel.sendResponse(wwwAuthenticateResponse); - return true; + return wwwAuthenticateResponse; } @Override diff --git a/src/main/java/com/amazon/dlic/auth/http/kerberos/HTTPSpnegoAuthenticator.java b/src/main/java/com/amazon/dlic/auth/http/kerberos/HTTPSpnegoAuthenticator.java index 1848e14879..f2320f0bed 100644 --- a/src/main/java/com/amazon/dlic/auth/http/kerberos/HTTPSpnegoAuthenticator.java +++ b/src/main/java/com/amazon/dlic/auth/http/kerberos/HTTPSpnegoAuthenticator.java @@ -49,7 +49,6 @@ import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.env.Environment; import org.opensearch.rest.BytesRestResponse; -import org.opensearch.rest.RestChannel; import org.opensearch.rest.RestRequest; import org.opensearch.core.rest.RestStatus; import org.opensearch.security.auth.HTTPAuthenticator; @@ -280,8 +279,7 @@ public GSSCredential run() throws GSSException { } @Override - public boolean reRequestAuthentication(final RestChannel channel, AuthCredentials creds) { - + public BytesRestResponse reRequestAuthentication(RestRequest request, AuthCredentials credentials) { final BytesRestResponse wwwAuthenticateResponse; XContentBuilder response = getNegotiateResponseBody(); @@ -291,16 +289,15 @@ public boolean reRequestAuthentication(final RestChannel channel, AuthCredential wwwAuthenticateResponse = new BytesRestResponse(RestStatus.UNAUTHORIZED, EMPTY_STRING); } - if (creds == null || creds.getNativeCredentials() == null) { + if (credentials == null || credentials.getNativeCredentials() == null) { wwwAuthenticateResponse.addHeader("WWW-Authenticate", "Negotiate"); } else { wwwAuthenticateResponse.addHeader( "WWW-Authenticate", - "Negotiate " + Base64.getEncoder().encodeToString((byte[]) creds.getNativeCredentials()) + "Negotiate " + Base64.getEncoder().encodeToString((byte[]) credentials.getNativeCredentials()) ); } - channel.sendResponse(wwwAuthenticateResponse); - return true; + return wwwAuthenticateResponse; } @Override diff --git a/src/main/java/com/amazon/dlic/auth/http/saml/AuthTokenProcessorHandler.java b/src/main/java/com/amazon/dlic/auth/http/saml/AuthTokenProcessorHandler.java index ffc9e3637b..24649562d9 100644 --- a/src/main/java/com/amazon/dlic/auth/http/saml/AuthTokenProcessorHandler.java +++ b/src/main/java/com/amazon/dlic/auth/http/saml/AuthTokenProcessorHandler.java @@ -23,7 +23,6 @@ import java.util.regex.Pattern; import java.util.stream.Collectors; -import javax.xml.parsers.ParserConfigurationException; import javax.xml.xpath.XPathExpressionException; import com.fasterxml.jackson.core.JsonParseException; @@ -32,7 +31,6 @@ import com.fasterxml.jackson.databind.node.ObjectNode; import com.google.common.base.Strings; import com.onelogin.saml2.authn.SamlResponse; -import com.onelogin.saml2.exception.SettingsException; import com.onelogin.saml2.exception.ValidationError; import com.onelogin.saml2.settings.Saml2Settings; import com.onelogin.saml2.util.Util; @@ -49,7 +47,6 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.joda.time.DateTime; -import org.xml.sax.SAXException; import org.opensearch.OpenSearchSecurityException; import org.opensearch.SpecialPermission; @@ -57,7 +54,6 @@ import org.opensearch.common.settings.Settings; import org.opensearch.common.xcontent.XContentType; import org.opensearch.rest.BytesRestResponse; -import org.opensearch.rest.RestChannel; import org.opensearch.rest.RestRequest; import org.opensearch.rest.RestRequest.Method; import org.opensearch.core.rest.RestStatus; @@ -122,7 +118,7 @@ class AuthTokenProcessorHandler { } @SuppressWarnings("removal") - boolean handle(RestRequest restRequest, RestChannel restChannel) throws Exception { + BytesRestResponse handle(RestRequest restRequest) throws Exception { try { final SecurityManager sm = System.getSecurityManager(); @@ -130,11 +126,10 @@ boolean handle(RestRequest restRequest, RestChannel restChannel) throws Exceptio sm.checkPermission(new SpecialPermission()); } - return AccessController.doPrivileged(new PrivilegedExceptionAction() { + return AccessController.doPrivileged(new PrivilegedExceptionAction() { @Override - public Boolean run() throws XPathExpressionException, SamlConfigException, IOException, ParserConfigurationException, - SAXException, SettingsException { - return handleLowLevel(restRequest, restChannel); + public BytesRestResponse run() throws SamlConfigException, IOException { + return handleLowLevel(restRequest); } }); } catch (PrivilegedActionException e) { @@ -147,13 +142,11 @@ public Boolean run() throws XPathExpressionException, SamlConfigException, IOExc } private AuthTokenProcessorAction.Response handleImpl( - RestRequest restRequest, - RestChannel restChannel, String samlResponseBase64, String samlRequestId, String acsEndpoint, Saml2Settings saml2Settings - ) throws XPathExpressionException, ParserConfigurationException, SAXException, IOException, SettingsException { + ) { if (token_log.isDebugEnabled()) { try { token_log.debug( @@ -188,8 +181,7 @@ private AuthTokenProcessorAction.Response handleImpl( } } - private boolean handleLowLevel(RestRequest restRequest, RestChannel restChannel) throws SamlConfigException, IOException, - XPathExpressionException, ParserConfigurationException, SAXException, SettingsException { + private BytesRestResponse handleLowLevel(RestRequest restRequest) throws SamlConfigException, IOException { try { if (restRequest.getMediaType() != XContentType.JSON) { @@ -234,31 +226,18 @@ private boolean handleLowLevel(RestRequest restRequest, RestChannel restChannel) acsEndpoint = getAbsoluteAcsEndpoint(((ObjectNode) jsonRoot).get("acsEndpoint").textValue()); } - AuthTokenProcessorAction.Response responseBody = this.handleImpl( - restRequest, - restChannel, - samlResponseBase64, - samlRequestId, - acsEndpoint, - saml2Settings - ); + AuthTokenProcessorAction.Response responseBody = this.handleImpl(samlResponseBase64, samlRequestId, acsEndpoint, saml2Settings); if (responseBody == null) { - return false; + return null; } String responseBodyString = DefaultObjectMapper.objectMapper.writeValueAsString(responseBody); - BytesRestResponse authenticateResponse = new BytesRestResponse(RestStatus.OK, "application/json", responseBodyString); - restChannel.sendResponse(authenticateResponse); - - return true; + return new BytesRestResponse(RestStatus.OK, "application/json", responseBodyString); } catch (JsonProcessingException e) { log.warn("Error while parsing JSON for /_opendistro/_security/api/authtoken", e); - - BytesRestResponse authenticateResponse = new BytesRestResponse(RestStatus.BAD_REQUEST, "JSON could not be parsed"); - restChannel.sendResponse(authenticateResponse); - return true; + return new BytesRestResponse(RestStatus.BAD_REQUEST, "JSON could not be parsed"); } } diff --git a/src/main/java/com/amazon/dlic/auth/http/saml/HTTPSamlAuthenticator.java b/src/main/java/com/amazon/dlic/auth/http/saml/HTTPSamlAuthenticator.java index cd6209952f..64d816fabf 100644 --- a/src/main/java/com/amazon/dlic/auth/http/saml/HTTPSamlAuthenticator.java +++ b/src/main/java/com/amazon/dlic/auth/http/saml/HTTPSamlAuthenticator.java @@ -56,7 +56,6 @@ import org.opensearch.common.settings.Settings; import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.rest.BytesRestResponse; -import org.opensearch.rest.RestChannel; import org.opensearch.rest.RestRequest; import org.opensearch.core.rest.RestStatus; import org.opensearch.security.auth.Destroyable; @@ -171,13 +170,15 @@ public String getType() { } @Override - public boolean reRequestAuthentication(RestChannel restChannel, AuthCredentials authCredentials) { + public BytesRestResponse reRequestAuthentication(RestRequest request, AuthCredentials credentials) { try { - RestRequest restRequest = restChannel.request(); - Matcher matcher = PATTERN_PATH_PREFIX.matcher(restRequest.path()); + Matcher matcher = PATTERN_PATH_PREFIX.matcher(request.path()); final String suffix = matcher.matches() ? matcher.group(2) : null; - if (API_AUTHTOKEN_SUFFIX.equals(suffix) && this.authTokenProcessorHandler.handle(restRequest, restChannel)) { - return true; + if (API_AUTHTOKEN_SUFFIX.equals(suffix)) { + final BytesRestResponse restResponse = this.authTokenProcessorHandler.handle(request); + if (restResponse != null) { + return restResponse; + } } Saml2Settings saml2Settings = this.saml2SettingsProvider.getCached(); @@ -185,12 +186,10 @@ public boolean reRequestAuthentication(RestChannel restChannel, AuthCredentials authenticateResponse.addHeader("WWW-Authenticate", getWwwAuthenticateHeader(saml2Settings)); - restChannel.sendResponse(authenticateResponse); - - return true; + return authenticateResponse; } catch (Exception e) { log.error("Error in reRequestAuthentication()", e); - return false; + return null; } } diff --git a/src/main/java/org/opensearch/security/auth/BackendRegistry.java b/src/main/java/org/opensearch/security/auth/BackendRegistry.java index 9721664c70..57a04a0f9e 100644 --- a/src/main/java/org/opensearch/security/auth/BackendRegistry.java +++ b/src/main/java/org/opensearch/security/auth/BackendRegistry.java @@ -229,7 +229,7 @@ && isBlocked(((InetSocketAddress) request.getHttpChannel().getRemoteAddress()).g User authenticatedUser = null; - AuthCredentials authCredenetials = null; + AuthCredentials authCredentials = null; HTTPAuthenticator firstChallengingHttpAuthenticator = null; @@ -271,7 +271,7 @@ && isBlocked(((InetSocketAddress) request.getHttpChannel().getRemoteAddress()).g continue; } - authCredenetials = ac; + authCredentials = ac; if (ac == null) { // no credentials found in request @@ -279,10 +279,18 @@ && isBlocked(((InetSocketAddress) request.getHttpChannel().getRemoteAddress()).g continue; } - if (authDomain.isChallenge() && httpAuthenticator.reRequestAuthentication(channel, null)) { - auditLog.logFailedLogin("", false, null, request); - log.warn("No 'Authorization' header, send 401 and 'WWW-Authenticate Basic'"); - return false; + if (authDomain.isChallenge()) { + final BytesRestResponse restResponse = httpAuthenticator.reRequestAuthentication(request, null); + if (restResponse != null) { + auditLog.logFailedLogin("", false, null, request); + if (isTraceEnabled) { + log.trace("No 'Authorization' header, send 401 and 'WWW-Authenticate Basic'"); + } + notifyIpAuthFailureListeners(request, authCredentials); + channel.sendResponse(restResponse); + return false; + } + } else { // no reRequest possible if (isTraceEnabled) { @@ -293,9 +301,12 @@ && isBlocked(((InetSocketAddress) request.getHttpChannel().getRemoteAddress()).g } else { org.apache.logging.log4j.ThreadContext.put("user", ac.getUsername()); if (!ac.isComplete()) { + final BytesRestResponse restResponse = httpAuthenticator.reRequestAuthentication(request, ac); // credentials found in request but we need another client challenge - if (httpAuthenticator.reRequestAuthentication(channel, ac)) { + if (restResponse != null) { // auditLog.logFailedLogin(ac.getUsername()+" ", request); --noauditlog + notifyIpAuthFailureListeners(request, ac); + channel.sendResponse(restResponse); return false; } else { // no reRequest possible @@ -373,7 +384,7 @@ && isBlocked(((InetSocketAddress) request.getHttpChannel().getRemoteAddress()).g log.debug("User still not authenticated after checking {} auth domains", restAuthDomains.size()); } - if (authCredenetials == null && anonymousAuthEnabled) { + if (authCredentials == null && anonymousAuthEnabled) { final String tenant = Utils.coalesce(request.header("securitytenant"), request.header("security_tenant")); User anonymousUser = new User(User.ANONYMOUS.getName(), new HashSet(User.ANONYMOUS.getRoles()), null); anonymousUser.setRequestedTenant(tenant); @@ -385,6 +396,7 @@ && isBlocked(((InetSocketAddress) request.getHttpChannel().getRemoteAddress()).g } return true; } + BytesRestResponse challengeResponse = null; if (firstChallengingHttpAuthenticator != null) { @@ -392,31 +404,28 @@ && isBlocked(((InetSocketAddress) request.getHttpChannel().getRemoteAddress()).g log.debug("Rerequest with {}", firstChallengingHttpAuthenticator.getClass()); } - if (firstChallengingHttpAuthenticator.reRequestAuthentication(channel, null)) { + challengeResponse = firstChallengingHttpAuthenticator.reRequestAuthentication(request, null); + if (challengeResponse != null) { if (isDebugEnabled) { log.debug("Rerequest {} failed", firstChallengingHttpAuthenticator.getClass()); } - - log.warn( - "Authentication finally failed for {} from {}", - authCredenetials == null ? null : authCredenetials.getUsername(), - remoteAddress - ); - auditLog.logFailedLogin(authCredenetials == null ? null : authCredenetials.getUsername(), false, null, request); - return false; } } log.warn( "Authentication finally failed for {} from {}", - authCredenetials == null ? null : authCredenetials.getUsername(), + authCredentials == null ? null : authCredentials.getUsername(), remoteAddress ); - auditLog.logFailedLogin(authCredenetials == null ? null : authCredenetials.getUsername(), false, null, request); + auditLog.logFailedLogin(authCredentials == null ? null : authCredentials.getUsername(), false, null, request); - notifyIpAuthFailureListeners(request, authCredenetials); + notifyIpAuthFailureListeners(request, authCredentials); - channel.sendResponse(new BytesRestResponse(RestStatus.UNAUTHORIZED, "Authentication finally failed")); + channel.sendResponse( + challengeResponse != null + ? challengeResponse + : new BytesRestResponse(RestStatus.UNAUTHORIZED, "Authentication finally failed") + ); return false; } diff --git a/src/main/java/org/opensearch/security/auth/HTTPAuthenticator.java b/src/main/java/org/opensearch/security/auth/HTTPAuthenticator.java index fa5065ef68..70b94d6dbf 100644 --- a/src/main/java/org/opensearch/security/auth/HTTPAuthenticator.java +++ b/src/main/java/org/opensearch/security/auth/HTTPAuthenticator.java @@ -28,7 +28,7 @@ import org.opensearch.OpenSearchSecurityException; import org.opensearch.common.util.concurrent.ThreadContext; -import org.opensearch.rest.RestChannel; +import org.opensearch.rest.BytesRestResponse; import org.opensearch.rest.RestRequest; import org.opensearch.security.user.AuthCredentials; @@ -71,15 +71,14 @@ public interface HTTPAuthenticator { /** * If the {@code extractCredentials()} call was not successful or the authentication flow needs another roundtrip this method - * will be called. If the custom HTTP authenticator does not support this method is a no-op and false should be returned. - * + * will be called. If the custom HTTP authenticator does not support this method is a no-op and null response should be returned. * If the custom HTTP authenticator does support re-request authentication or supports authentication flows with multiple roundtrips - * then the response should be sent (through the channel) and true must be returned. + * then the response will be returned which can then be sent via response channel. * - * @param channel The rest channel to sent back the response via {@code channel.sendResponse()} + * @param request * @param credentials The credentials from the prior authentication attempt - * @return false if re-request is not supported/necessary, true otherwise. - * If true is returned {@code channel.sendResponse()} must be called so that the request completes. + * @return null if re-request is not supported/necessary, response object otherwise. + * If an object is returned {@code channel.sendResponse()} must be called so that the request completes. */ - boolean reRequestAuthentication(final RestChannel channel, AuthCredentials credentials); + BytesRestResponse reRequestAuthentication(RestRequest request, AuthCredentials credentials); } diff --git a/src/main/java/org/opensearch/security/http/HTTPBasicAuthenticator.java b/src/main/java/org/opensearch/security/http/HTTPBasicAuthenticator.java index 4be83bc2e2..16a822df6d 100644 --- a/src/main/java/org/opensearch/security/http/HTTPBasicAuthenticator.java +++ b/src/main/java/org/opensearch/security/http/HTTPBasicAuthenticator.java @@ -34,7 +34,6 @@ import org.opensearch.common.settings.Settings; import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.rest.BytesRestResponse; -import org.opensearch.rest.RestChannel; import org.opensearch.rest.RestRequest; import org.opensearch.core.rest.RestStatus; import org.opensearch.security.auth.HTTPAuthenticator; @@ -65,11 +64,10 @@ public AuthCredentials extractCredentials(final RestRequest request, ThreadConte } @Override - public boolean reRequestAuthentication(final RestChannel channel, AuthCredentials creds) { + public BytesRestResponse reRequestAuthentication(RestRequest request, AuthCredentials credentials) { final BytesRestResponse wwwAuthenticateResponse = new BytesRestResponse(RestStatus.UNAUTHORIZED, "Unauthorized"); wwwAuthenticateResponse.addHeader("WWW-Authenticate", "Basic realm=\"OpenSearch Security\""); - channel.sendResponse(wwwAuthenticateResponse); - return true; + return wwwAuthenticateResponse; } @Override diff --git a/src/main/java/org/opensearch/security/http/HTTPClientCertAuthenticator.java b/src/main/java/org/opensearch/security/http/HTTPClientCertAuthenticator.java index b1e5d4ef40..cea59f7311 100644 --- a/src/main/java/org/opensearch/security/http/HTTPClientCertAuthenticator.java +++ b/src/main/java/org/opensearch/security/http/HTTPClientCertAuthenticator.java @@ -41,7 +41,7 @@ import org.opensearch.common.settings.Settings; import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.core.common.Strings; -import org.opensearch.rest.RestChannel; +import org.opensearch.rest.BytesRestResponse; import org.opensearch.rest.RestRequest; import org.opensearch.security.auth.HTTPAuthenticator; import org.opensearch.security.support.ConfigConstants; @@ -98,8 +98,8 @@ public AuthCredentials extractCredentials(final RestRequest request, final Threa } @Override - public boolean reRequestAuthentication(final RestChannel channel, AuthCredentials creds) { - return false; + public BytesRestResponse reRequestAuthentication(RestRequest request, AuthCredentials credentials) { + return null; } @Override diff --git a/src/main/java/org/opensearch/security/http/HTTPProxyAuthenticator.java b/src/main/java/org/opensearch/security/http/HTTPProxyAuthenticator.java index a58a842394..1db7b9769b 100644 --- a/src/main/java/org/opensearch/security/http/HTTPProxyAuthenticator.java +++ b/src/main/java/org/opensearch/security/http/HTTPProxyAuthenticator.java @@ -37,7 +37,7 @@ import org.opensearch.common.settings.Settings; import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.core.common.Strings; -import org.opensearch.rest.RestChannel; +import org.opensearch.rest.BytesRestResponse; import org.opensearch.rest.RestRequest; import org.opensearch.security.auth.HTTPAuthenticator; import org.opensearch.security.support.ConfigConstants; @@ -89,8 +89,8 @@ public AuthCredentials extractCredentials(final RestRequest request, ThreadConte } @Override - public boolean reRequestAuthentication(final RestChannel channel, AuthCredentials creds) { - return false; + public BytesRestResponse reRequestAuthentication(RestRequest request, AuthCredentials credentials) { + return null; } @Override diff --git a/src/main/java/org/opensearch/security/http/proxy/HTTPExtendedProxyAuthenticator.java b/src/main/java/org/opensearch/security/http/proxy/HTTPExtendedProxyAuthenticator.java index ef20374d69..0423fecefe 100644 --- a/src/main/java/org/opensearch/security/http/proxy/HTTPExtendedProxyAuthenticator.java +++ b/src/main/java/org/opensearch/security/http/proxy/HTTPExtendedProxyAuthenticator.java @@ -37,7 +37,6 @@ import org.opensearch.common.settings.Settings; import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.core.common.Strings; -import org.opensearch.rest.RestChannel; import org.opensearch.rest.RestRequest; import org.opensearch.security.http.HTTPProxyAuthenticator; import org.opensearch.security.user.AuthCredentials; @@ -84,11 +83,6 @@ public AuthCredentials extractCredentials(final RestRequest request, ThreadConte return credentials.markComplete(); } - @Override - public boolean reRequestAuthentication(final RestChannel channel, AuthCredentials creds) { - return false; - } - @Override public String getType() { return "extended-proxy"; diff --git a/src/test/java/com/amazon/dlic/auth/http/saml/HTTPSamlAuthenticatorTest.java b/src/test/java/com/amazon/dlic/auth/http/saml/HTTPSamlAuthenticatorTest.java index ff9ec19b09..2594388128 100644 --- a/src/test/java/com/amazon/dlic/auth/http/saml/HTTPSamlAuthenticatorTest.java +++ b/src/test/java/com/amazon/dlic/auth/http/saml/HTTPSamlAuthenticatorTest.java @@ -47,6 +47,7 @@ import org.opensearch.common.settings.Settings; import org.opensearch.core.xcontent.MediaType; import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.rest.BytesRestResponse; import org.opensearch.rest.RestChannel; import org.opensearch.rest.RestRequest; import org.opensearch.rest.RestRequest.Method; @@ -141,7 +142,8 @@ public void basicTest() throws Exception { RestRequest tokenRestRequest = buildTokenExchangeRestRequest(encodedSamlResponse, authenticateHeaders); TestRestChannel tokenRestChannel = new TestRestChannel(tokenRestRequest); - samlAuthenticator.reRequestAuthentication(tokenRestChannel, null); + final BytesRestResponse authenticatorResponse = samlAuthenticator.reRequestAuthentication(tokenRestRequest, null); + tokenRestChannel.sendResponse(authenticatorResponse); String responseJson = new String(BytesReference.toBytes(tokenRestChannel.response.content())); HashMap response = DefaultObjectMapper.objectMapper.readValue( @@ -188,7 +190,8 @@ public void decryptAssertionsTest() throws Exception { RestRequest tokenRestRequest = buildTokenExchangeRestRequest(encodedSamlResponse, authenticateHeaders); TestRestChannel tokenRestChannel = new TestRestChannel(tokenRestRequest); - samlAuthenticator.reRequestAuthentication(tokenRestChannel, null); + final BytesRestResponse authenticatorResponse = samlAuthenticator.reRequestAuthentication(tokenRestRequest, null); + tokenRestChannel.sendResponse(authenticatorResponse); String responseJson = new String(BytesReference.toBytes(tokenRestChannel.response.content())); HashMap response = DefaultObjectMapper.objectMapper.readValue( @@ -236,7 +239,8 @@ public void shouldUnescapeSamlEntitiesTest() throws Exception { RestRequest tokenRestRequest = buildTokenExchangeRestRequest(encodedSamlResponse, authenticateHeaders); TestRestChannel tokenRestChannel = new TestRestChannel(tokenRestRequest); - samlAuthenticator.reRequestAuthentication(tokenRestChannel, null); + final BytesRestResponse authenticatorResponse = samlAuthenticator.reRequestAuthentication(tokenRestRequest, null); + tokenRestChannel.sendResponse(authenticatorResponse); String responseJson = new String(BytesReference.toBytes(tokenRestChannel.response.content())); HashMap response = DefaultObjectMapper.objectMapper.readValue( @@ -287,7 +291,8 @@ public void shouldUnescapeSamlEntitiesTest2() throws Exception { RestRequest tokenRestRequest = buildTokenExchangeRestRequest(encodedSamlResponse, authenticateHeaders); TestRestChannel tokenRestChannel = new TestRestChannel(tokenRestRequest); - samlAuthenticator.reRequestAuthentication(tokenRestChannel, null); + final BytesRestResponse authenticatorResponse = samlAuthenticator.reRequestAuthentication(tokenRestRequest, null); + tokenRestChannel.sendResponse(authenticatorResponse); String responseJson = new String(BytesReference.toBytes(tokenRestChannel.response.content())); HashMap response = DefaultObjectMapper.objectMapper.readValue( @@ -338,7 +343,8 @@ public void shouldNotEscapeSamlEntities() throws Exception { RestRequest tokenRestRequest = buildTokenExchangeRestRequest(encodedSamlResponse, authenticateHeaders); TestRestChannel tokenRestChannel = new TestRestChannel(tokenRestRequest); - samlAuthenticator.reRequestAuthentication(tokenRestChannel, null); + final BytesRestResponse authenticatorResponse = samlAuthenticator.reRequestAuthentication(tokenRestRequest, null); + tokenRestChannel.sendResponse(authenticatorResponse); String responseJson = new String(BytesReference.toBytes(tokenRestChannel.response.content())); HashMap response = DefaultObjectMapper.objectMapper.readValue( @@ -389,7 +395,8 @@ public void shouldNotTrimWhitespaceInJwtRoles() throws Exception { RestRequest tokenRestRequest = buildTokenExchangeRestRequest(encodedSamlResponse, authenticateHeaders); TestRestChannel tokenRestChannel = new TestRestChannel(tokenRestRequest); - samlAuthenticator.reRequestAuthentication(tokenRestChannel, null); + final BytesRestResponse authenticatorResponse = samlAuthenticator.reRequestAuthentication(tokenRestRequest, null); + tokenRestChannel.sendResponse(authenticatorResponse); String responseJson = new String(BytesReference.toBytes(tokenRestChannel.response.content())); HashMap response = DefaultObjectMapper.objectMapper.readValue( @@ -436,7 +443,8 @@ public void testMetadataBody() throws Exception { RestRequest tokenRestRequest = buildTokenExchangeRestRequest(encodedSamlResponse, authenticateHeaders); TestRestChannel tokenRestChannel = new TestRestChannel(tokenRestRequest); - samlAuthenticator.reRequestAuthentication(tokenRestChannel, null); + final BytesRestResponse authenticatorResponse = samlAuthenticator.reRequestAuthentication(tokenRestRequest, null); + tokenRestChannel.sendResponse(authenticatorResponse); String responseJson = new String(BytesReference.toBytes(tokenRestChannel.response.content())); HashMap response = DefaultObjectMapper.objectMapper.readValue( @@ -501,7 +509,8 @@ public void unsolicitedSsoTest() throws Exception { ); TestRestChannel tokenRestChannel = new TestRestChannel(tokenRestRequest); - samlAuthenticator.reRequestAuthentication(tokenRestChannel, null); + final BytesRestResponse authenticatorResponse = samlAuthenticator.reRequestAuthentication(tokenRestRequest, null); + tokenRestChannel.sendResponse(authenticatorResponse); String responseJson = new String(BytesReference.toBytes(tokenRestChannel.response.content())); HashMap response = DefaultObjectMapper.objectMapper.readValue( @@ -552,7 +561,8 @@ public void badUnsolicitedSsoTest() throws Exception { ); TestRestChannel tokenRestChannel = new TestRestChannel(tokenRestRequest); - samlAuthenticator.reRequestAuthentication(tokenRestChannel, null); + final BytesRestResponse authenticatorResponse = samlAuthenticator.reRequestAuthentication(tokenRestRequest, null); + tokenRestChannel.sendResponse(authenticatorResponse); Assert.assertEquals(RestStatus.UNAUTHORIZED, tokenRestChannel.response.status()); } @@ -584,7 +594,8 @@ public void wrongCertTest() throws Exception { RestRequest tokenRestRequest = buildTokenExchangeRestRequest(encodedSamlResponse, authenticateHeaders); TestRestChannel tokenRestChannel = new TestRestChannel(tokenRestRequest); - samlAuthenticator.reRequestAuthentication(tokenRestChannel, null); + final BytesRestResponse authenticatorResponse = samlAuthenticator.reRequestAuthentication(tokenRestRequest, null); + tokenRestChannel.sendResponse(authenticatorResponse); Assert.assertEquals(401, tokenRestChannel.response.status().getStatus()); } @@ -613,7 +624,8 @@ public void noSignatureTest() throws Exception { RestRequest tokenRestRequest = buildTokenExchangeRestRequest(encodedSamlResponse, authenticateHeaders); TestRestChannel tokenRestChannel = new TestRestChannel(tokenRestRequest); - samlAuthenticator.reRequestAuthentication(tokenRestChannel, null); + final BytesRestResponse authenticatorResponse = samlAuthenticator.reRequestAuthentication(tokenRestRequest, null); + tokenRestChannel.sendResponse(authenticatorResponse); Assert.assertEquals(401, tokenRestChannel.response.status().getStatus()); } @@ -646,7 +658,8 @@ public void rolesTest() throws Exception { RestRequest tokenRestRequest = buildTokenExchangeRestRequest(encodedSamlResponse, authenticateHeaders); TestRestChannel tokenRestChannel = new TestRestChannel(tokenRestRequest); - samlAuthenticator.reRequestAuthentication(tokenRestChannel, null); + final BytesRestResponse authenticatorResponse = samlAuthenticator.reRequestAuthentication(tokenRestRequest, null); + tokenRestChannel.sendResponse(authenticatorResponse); String responseJson = new String(BytesReference.toBytes(tokenRestChannel.response.content())); HashMap response = DefaultObjectMapper.objectMapper.readValue( @@ -693,7 +706,8 @@ public void idpEndpointWithQueryStringTest() throws Exception { RestRequest tokenRestRequest = buildTokenExchangeRestRequest(encodedSamlResponse, authenticateHeaders); TestRestChannel tokenRestChannel = new TestRestChannel(tokenRestRequest); - samlAuthenticator.reRequestAuthentication(tokenRestChannel, null); + final BytesRestResponse authenticatorResponse = samlAuthenticator.reRequestAuthentication(tokenRestRequest, null); + tokenRestChannel.sendResponse(authenticatorResponse); String responseJson = new String(BytesReference.toBytes(tokenRestChannel.response.content())); HashMap response = DefaultObjectMapper.objectMapper.readValue( @@ -747,7 +761,8 @@ private void commaSeparatedRoles(final String rolesAsString, final Settings.Buil RestRequest tokenRestRequest = buildTokenExchangeRestRequest(encodedSamlResponse, authenticateHeaders); TestRestChannel tokenRestChannel = new TestRestChannel(tokenRestRequest); - samlAuthenticator.reRequestAuthentication(tokenRestChannel, null); + final BytesRestResponse authenticatorResponse = samlAuthenticator.reRequestAuthentication(tokenRestRequest, null); + tokenRestChannel.sendResponse(authenticatorResponse); String responseJson = new String(BytesReference.toBytes(tokenRestChannel.response.content())); HashMap response = DefaultObjectMapper.objectMapper.readValue( @@ -850,7 +865,8 @@ public void initialConnectionFailureTest() throws Exception { RestRequest restRequest = new FakeRestRequest(ImmutableMap.of(), new HashMap()); TestRestChannel restChannel = new TestRestChannel(restRequest); - samlAuthenticator.reRequestAuthentication(restChannel, null); + BytesRestResponse authenticatorResponse = samlAuthenticator.reRequestAuthentication(restRequest, null); + restChannel.sendResponse(authenticatorResponse); Assert.assertNull(restChannel.response); @@ -870,7 +886,8 @@ public void initialConnectionFailureTest() throws Exception { RestRequest tokenRestRequest = buildTokenExchangeRestRequest(encodedSamlResponse, authenticateHeaders); TestRestChannel tokenRestChannel = new TestRestChannel(tokenRestRequest); - samlAuthenticator.reRequestAuthentication(tokenRestChannel, null); + authenticatorResponse = samlAuthenticator.reRequestAuthentication(tokenRestRequest, null); + tokenRestChannel.sendResponse(authenticatorResponse); String responseJson = new String(BytesReference.toBytes(tokenRestChannel.response.content())); HashMap response = DefaultObjectMapper.objectMapper.readValue( @@ -893,7 +910,8 @@ private AuthenticateHeaders getAutenticateHeaders(HTTPSamlAuthenticator samlAuth RestRequest restRequest = new FakeRestRequest(ImmutableMap.of(), new HashMap()); TestRestChannel restChannel = new TestRestChannel(restRequest); - samlAuthenticator.reRequestAuthentication(restChannel, null); + final BytesRestResponse authenticatorResponse = samlAuthenticator.reRequestAuthentication(restRequest, null); + restChannel.sendResponse(authenticatorResponse); List wwwAuthenticateHeaders = restChannel.response.getHeaders().get("WWW-Authenticate"); diff --git a/src/test/java/org/opensearch/security/auth/limiting/AddressBasedRateLimiterTest.java b/src/test/java/org/opensearch/security/auth/limiting/AddressBasedRateLimiterTest.java index 827bfa24b6..69ddc5c03a 100644 --- a/src/test/java/org/opensearch/security/auth/limiting/AddressBasedRateLimiterTest.java +++ b/src/test/java/org/opensearch/security/auth/limiting/AddressBasedRateLimiterTest.java @@ -20,28 +20,27 @@ import org.junit.Test; import org.opensearch.common.settings.Settings; -import org.opensearch.security.user.AuthCredentials; + +import java.net.InetAddress; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertTrue; public class AddressBasedRateLimiterTest { - private final static byte[] PASSWORD = new byte[] { '1', '2', '3' }; - @Test public void simpleTest() throws Exception { Settings settings = Settings.builder().put("allowed_tries", 3).build(); - UserNameBasedRateLimiter rateLimiter = new UserNameBasedRateLimiter(settings, null); + AddressBasedRateLimiter rateLimiter = new AddressBasedRateLimiter(settings, null); - assertFalse(rateLimiter.isBlocked("a")); - rateLimiter.onAuthFailure(null, new AuthCredentials("a", PASSWORD), null); - assertFalse(rateLimiter.isBlocked("a")); - rateLimiter.onAuthFailure(null, new AuthCredentials("a", PASSWORD), null); - assertFalse(rateLimiter.isBlocked("a")); - rateLimiter.onAuthFailure(null, new AuthCredentials("a", PASSWORD), null); - assertTrue(rateLimiter.isBlocked("a")); + assertFalse(rateLimiter.isBlocked(InetAddress.getByAddress(new byte[] { 1, 2, 3, 4 }))); + rateLimiter.onAuthFailure(InetAddress.getByAddress(new byte[] { 1, 2, 3, 4 }), null, null); + assertFalse(rateLimiter.isBlocked(InetAddress.getByAddress(new byte[] { 1, 2, 3, 4 }))); + rateLimiter.onAuthFailure(InetAddress.getByAddress(new byte[] { 1, 2, 3, 4 }), null, null); + assertFalse(rateLimiter.isBlocked(InetAddress.getByAddress(new byte[] { 1, 2, 3, 4 }))); + rateLimiter.onAuthFailure(InetAddress.getByAddress(new byte[] { 1, 2, 3, 4 }), null, null); + assertTrue(rateLimiter.isBlocked(InetAddress.getByAddress(new byte[] { 1, 2, 3, 4 }))); } } diff --git a/src/test/java/org/opensearch/security/auth/limiting/UserNameBasedRateLimiterTest.java b/src/test/java/org/opensearch/security/auth/limiting/UserNameBasedRateLimiterTest.java index e42d2bd1b8..a8285c42a7 100644 --- a/src/test/java/org/opensearch/security/auth/limiting/UserNameBasedRateLimiterTest.java +++ b/src/test/java/org/opensearch/security/auth/limiting/UserNameBasedRateLimiterTest.java @@ -17,30 +17,31 @@ package org.opensearch.security.auth.limiting; -import java.net.InetAddress; - import org.junit.Test; import org.opensearch.common.settings.Settings; +import org.opensearch.security.user.AuthCredentials; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertTrue; public class UserNameBasedRateLimiterTest { + private final static byte[] PASSWORD = new byte[] { '1', '2', '3' }; + @Test public void simpleTest() throws Exception { Settings settings = Settings.builder().put("allowed_tries", 3).build(); - AddressBasedRateLimiter rateLimiter = new AddressBasedRateLimiter(settings, null); + UserNameBasedRateLimiter rateLimiter = new UserNameBasedRateLimiter(settings, null); - assertFalse(rateLimiter.isBlocked(InetAddress.getByAddress(new byte[] { 1, 2, 3, 4 }))); - rateLimiter.onAuthFailure(InetAddress.getByAddress(new byte[] { 1, 2, 3, 4 }), null, null); - assertFalse(rateLimiter.isBlocked(InetAddress.getByAddress(new byte[] { 1, 2, 3, 4 }))); - rateLimiter.onAuthFailure(InetAddress.getByAddress(new byte[] { 1, 2, 3, 4 }), null, null); - assertFalse(rateLimiter.isBlocked(InetAddress.getByAddress(new byte[] { 1, 2, 3, 4 }))); - rateLimiter.onAuthFailure(InetAddress.getByAddress(new byte[] { 1, 2, 3, 4 }), null, null); - assertTrue(rateLimiter.isBlocked(InetAddress.getByAddress(new byte[] { 1, 2, 3, 4 }))); + assertFalse(rateLimiter.isBlocked("a")); + rateLimiter.onAuthFailure(null, new AuthCredentials("a", PASSWORD), null); + assertFalse(rateLimiter.isBlocked("a")); + rateLimiter.onAuthFailure(null, new AuthCredentials("a", PASSWORD), null); + assertFalse(rateLimiter.isBlocked("a")); + rateLimiter.onAuthFailure(null, new AuthCredentials("a", PASSWORD), null); + assertTrue(rateLimiter.isBlocked("a")); } } diff --git a/src/test/java/org/opensearch/security/cache/DummyHTTPAuthenticator.java b/src/test/java/org/opensearch/security/cache/DummyHTTPAuthenticator.java index 55c2e789c6..37ac45080b 100644 --- a/src/test/java/org/opensearch/security/cache/DummyHTTPAuthenticator.java +++ b/src/test/java/org/opensearch/security/cache/DummyHTTPAuthenticator.java @@ -16,7 +16,7 @@ import org.opensearch.OpenSearchSecurityException; import org.opensearch.common.settings.Settings; import org.opensearch.common.util.concurrent.ThreadContext; -import org.opensearch.rest.RestChannel; +import org.opensearch.rest.BytesRestResponse; import org.opensearch.rest.RestRequest; import org.opensearch.security.auth.HTTPAuthenticator; import org.opensearch.security.user.AuthCredentials; @@ -39,8 +39,8 @@ public AuthCredentials extractCredentials(RestRequest request, ThreadContext con } @Override - public boolean reRequestAuthentication(RestChannel channel, AuthCredentials credentials) { - return false; + public BytesRestResponse reRequestAuthentication(RestRequest request, AuthCredentials credentials) { + return null; } public static long getCount() {