diff --git a/src/main/java/com/amazon/dlic/auth/http/kerberos/HTTPSpnegoAuthenticator.java b/src/main/java/com/amazon/dlic/auth/http/kerberos/HTTPSpnegoAuthenticator.java index 44bff5c73e..125bbed073 100644 --- a/src/main/java/com/amazon/dlic/auth/http/kerberos/HTTPSpnegoAuthenticator.java +++ b/src/main/java/com/amazon/dlic/auth/http/kerberos/HTTPSpnegoAuthenticator.java @@ -39,6 +39,7 @@ import org.opensearch.common.settings.Settings; import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.common.xcontent.XContentType; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.env.Environment; import org.opensearch.security.auth.HTTPAuthenticator; @@ -284,10 +285,12 @@ public GSSCredential run() throws GSSException { public Optional reRequestAuthentication(final SecurityRequest request, AuthCredentials creds) { final Map headers = new HashMap<>(); String responseBody = ""; + String contentType = null; + SecurityResponse response; final String negotiateResponseBody = getNegotiateResponseBody(); if (negotiateResponseBody != null) { responseBody = negotiateResponseBody; - headers.putAll(SecurityResponse.CONTENT_TYPE_APP_JSON); + contentType = XContentType.JSON.mediaType(); } if (creds == null || creds.getNativeCredentials() == null) { @@ -296,7 +299,12 @@ public Optional reRequestAuthentication(final SecurityRequest headers.put("WWW-Authenticate", "Negotiate " + Base64.getEncoder().encodeToString((byte[]) creds.getNativeCredentials())); } - return Optional.of(new SecurityResponse(SC_UNAUTHORIZED, headers, responseBody)); + if (contentType != null) { + response = new SecurityResponse(SC_UNAUTHORIZED, headers, responseBody, contentType); + } else { + response = new SecurityResponse(SC_UNAUTHORIZED, headers, responseBody); + } + return Optional.of(response); } @Override diff --git a/src/main/java/com/amazon/dlic/auth/http/saml/AuthTokenProcessorHandler.java b/src/main/java/com/amazon/dlic/auth/http/saml/AuthTokenProcessorHandler.java index 41e9305ba6..32e01b9e2f 100644 --- a/src/main/java/com/amazon/dlic/auth/http/saml/AuthTokenProcessorHandler.java +++ b/src/main/java/com/amazon/dlic/auth/http/saml/AuthTokenProcessorHandler.java @@ -226,10 +226,10 @@ private Optional handleLowLevel(RestRequest restRequest) throw String responseBodyString = DefaultObjectMapper.objectMapper.writeValueAsString(responseBody); - return Optional.of(new SecurityResponse(HttpStatus.SC_OK, SecurityResponse.CONTENT_TYPE_APP_JSON, responseBodyString)); + return Optional.of(new SecurityResponse(HttpStatus.SC_OK, null, responseBodyString, XContentType.JSON.mediaType())); } catch (JsonProcessingException e) { log.warn("Error while parsing JSON for /_opendistro/_security/api/authtoken", e); - return Optional.of(new SecurityResponse(HttpStatus.SC_BAD_REQUEST, new Exception("JSON could not be parsed"))); + return Optional.of(new SecurityResponse(HttpStatus.SC_BAD_REQUEST, "JSON could not be parsed")); } } diff --git a/src/main/java/org/opensearch/security/auth/BackendRegistry.java b/src/main/java/org/opensearch/security/auth/BackendRegistry.java index 3f6aae0720..3ab9a2afc9 100644 --- a/src/main/java/org/opensearch/security/auth/BackendRegistry.java +++ b/src/main/java/org/opensearch/security/auth/BackendRegistry.java @@ -202,7 +202,7 @@ public boolean authenticate(final SecurityRequestChannel request) { log.debug("Rejecting REST request because of blocked address: {}", request.getRemoteAddress().orElse(null)); } - request.queueForSending(new SecurityResponse(SC_UNAUTHORIZED, new Exception("Authentication finally failed"))); + request.queueForSending(new SecurityResponse(SC_UNAUTHORIZED, "Authentication finally failed")); return false; } @@ -224,7 +224,7 @@ public boolean authenticate(final SecurityRequestChannel request) { if (!isInitialized()) { log.error("Not yet initialized (you may need to run securityadmin)"); - request.queueForSending(new SecurityResponse(SC_SERVICE_UNAVAILABLE, new Exception("OpenSearch Security not initialized."))); + request.queueForSending(new SecurityResponse(SC_SERVICE_UNAVAILABLE, "OpenSearch Security not initialized.")); return false; } @@ -354,11 +354,7 @@ public boolean authenticate(final SecurityRequestChannel request) { log.error("Cannot authenticate rest user because admin user is not permitted to login via HTTP"); auditLog.logFailedLogin(authenticatedUser.getName(), true, null, request); request.queueForSending( - new SecurityResponse( - SC_FORBIDDEN, - null, - "Cannot authenticate user because admin user is not permitted to login via HTTP" - ) + new SecurityResponse(SC_FORBIDDEN, "Cannot authenticate user because admin user is not permitted to login via HTTP") ); return false; } @@ -429,7 +425,7 @@ public boolean authenticate(final SecurityRequestChannel request) { notifyIpAuthFailureListeners(request, authCredentials); request.queueForSending( - challengeResponse.orElseGet(() -> new SecurityResponse(SC_UNAUTHORIZED, null, "Authentication finally failed")) + challengeResponse.orElseGet(() -> new SecurityResponse(SC_UNAUTHORIZED, "Authentication finally failed")) ); return false; } diff --git a/src/main/java/org/opensearch/security/filter/SecurityResponse.java b/src/main/java/org/opensearch/security/filter/SecurityResponse.java index 0dc833a440..5041936d2e 100644 --- a/src/main/java/org/opensearch/security/filter/SecurityResponse.java +++ b/src/main/java/org/opensearch/security/filter/SecurityResponse.java @@ -12,11 +12,15 @@ package org.opensearch.security.filter; import java.io.IOException; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; import java.util.Map; import org.apache.http.HttpHeaders; import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.common.xcontent.XContentType; import org.opensearch.core.rest.RestStatus; import org.opensearch.rest.BytesRestResponse; import org.opensearch.rest.RestResponse; @@ -26,26 +30,63 @@ public class SecurityResponse { public static final Map CONTENT_TYPE_APP_JSON = Map.of(HttpHeaders.CONTENT_TYPE, "application/json"); private final int status; - private final Map headers; + private Map> headers; private final String body; + private final String contentType; public SecurityResponse(final int status, final Exception e) { this.status = status; - this.headers = CONTENT_TYPE_APP_JSON; + populateHeaders(CONTENT_TYPE_APP_JSON); this.body = generateFailureMessage(e); + this.contentType = XContentType.JSON.mediaType(); + } + + public SecurityResponse(final int status, String body) { + this.status = status; + this.body = body; + this.contentType = null; } public SecurityResponse(final int status, final Map headers, final String body) { this.status = status; - this.headers = headers; + populateHeaders(headers); + this.body = body; + this.contentType = null; + } + + public SecurityResponse(final int status, final Map headers, final String body, String contentType) { + this.status = status; this.body = body; + this.contentType = contentType; + populateHeaders(headers); + } + + private void populateHeaders(Map headers) { + if (headers != null) { + headers.entrySet().forEach(entry -> addHeader(entry.getKey(), entry.getValue())); + } + } + + /** + * Add a custom header. + */ + public void addHeader(String name, String value) { + if (headers == null) { + headers = new HashMap<>(2); + } + List header = headers.get(name); + if (header == null) { + header = new ArrayList<>(); + headers.put(name, header); + } + header.add(value); } public int getStatus() { return status; } - public Map getHeaders() { + public Map> getHeaders() { return headers; } @@ -54,9 +95,14 @@ public String getBody() { } public RestResponse asRestResponse() { - final RestResponse restResponse = new BytesRestResponse(RestStatus.fromCode(getStatus()), getBody()); + final RestResponse restResponse; + if (this.contentType != null) { + restResponse = new BytesRestResponse(RestStatus.fromCode(getStatus()), this.contentType, getBody()); + } else { + restResponse = new BytesRestResponse(RestStatus.fromCode(getStatus()), getBody()); + } if (getHeaders() != null) { - getHeaders().forEach(restResponse::addHeader); + getHeaders().entrySet().forEach(entry -> { entry.getValue().forEach(value -> restResponse.addHeader(entry.getKey(), value)); }); } return restResponse; } diff --git a/src/main/java/org/opensearch/security/filter/SecurityRestFilter.java b/src/main/java/org/opensearch/security/filter/SecurityRestFilter.java index e4d087cfe3..d52c5109fc 100644 --- a/src/main/java/org/opensearch/security/filter/SecurityRestFilter.java +++ b/src/main/java/org/opensearch/security/filter/SecurityRestFilter.java @@ -245,7 +245,7 @@ void authorizeRequest(RestHandler original, SecurityRequestChannel request, User } log.debug(err); - request.queueForSending(new SecurityResponse(HttpStatus.SC_UNAUTHORIZED, null, err)); + request.queueForSending(new SecurityResponse(HttpStatus.SC_UNAUTHORIZED, err)); return; } } @@ -288,7 +288,7 @@ public void checkAndAuthenticateRequest(SecurityRequestChannel requestChannel) t } catch (SSLPeerUnverifiedException e) { log.error("No ssl info", e); auditLog.logSSLException(requestChannel, e); - requestChannel.queueForSending(new SecurityResponse(HttpStatus.SC_FORBIDDEN, new Exception("No ssl info"))); + requestChannel.queueForSending(new SecurityResponse(HttpStatus.SC_FORBIDDEN, e)); return; } diff --git a/src/main/java/org/opensearch/security/securityconf/impl/AllowlistingSettings.java b/src/main/java/org/opensearch/security/securityconf/impl/AllowlistingSettings.java index 63d9186e1f..2a25ad8795 100644 --- a/src/main/java/org/opensearch/security/securityconf/impl/AllowlistingSettings.java +++ b/src/main/java/org/opensearch/security/securityconf/impl/AllowlistingSettings.java @@ -20,6 +20,7 @@ import org.apache.http.HttpStatus; import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.common.xcontent.XContentType; import org.opensearch.core.rest.RestStatus; import org.opensearch.security.filter.SecurityRequest; import org.opensearch.security.filter.SecurityResponse; @@ -113,7 +114,7 @@ public Optional checkRequestIsAllowed(final SecurityRequest re // if allowlisting is enabled but the request is not allowlisted, then return false, otherwise true. if (this.enabled && !requestIsAllowlisted(request)) { return Optional.of( - new SecurityResponse(HttpStatus.SC_FORBIDDEN, SecurityResponse.CONTENT_TYPE_APP_JSON, generateFailureMessage(request)) + new SecurityResponse(HttpStatus.SC_FORBIDDEN, null, generateFailureMessage(request), XContentType.JSON.mediaType()) ); } return Optional.empty(); diff --git a/src/main/java/org/opensearch/security/securityconf/impl/WhitelistingSettings.java b/src/main/java/org/opensearch/security/securityconf/impl/WhitelistingSettings.java index ce643477c2..4cc16a7f00 100644 --- a/src/main/java/org/opensearch/security/securityconf/impl/WhitelistingSettings.java +++ b/src/main/java/org/opensearch/security/securityconf/impl/WhitelistingSettings.java @@ -18,6 +18,7 @@ import org.apache.http.HttpStatus; +import org.opensearch.common.xcontent.XContentType; import org.opensearch.security.filter.SecurityRequest; import org.opensearch.security.filter.SecurityResponse; @@ -111,7 +112,7 @@ public Optional checkRequestIsAllowed(final SecurityRequest re // if whitelisting is enabled but the request is not whitelisted, then return false, otherwise true. if (this.enabled && !requestIsWhitelisted(request)) { return Optional.of( - new SecurityResponse(HttpStatus.SC_FORBIDDEN, SecurityResponse.CONTENT_TYPE_APP_JSON, generateFailureMessage(request)) + new SecurityResponse(HttpStatus.SC_FORBIDDEN, null, generateFailureMessage(request), XContentType.JSON.mediaType()) ); } return Optional.empty(); diff --git a/src/test/java/com/amazon/dlic/auth/http/saml/HTTPSamlAuthenticatorTest.java b/src/test/java/com/amazon/dlic/auth/http/saml/HTTPSamlAuthenticatorTest.java index c76a1b546d..bba2ee8b5c 100644 --- a/src/test/java/com/amazon/dlic/auth/http/saml/HTTPSamlAuthenticatorTest.java +++ b/src/test/java/com/amazon/dlic/auth/http/saml/HTTPSamlAuthenticatorTest.java @@ -887,7 +887,7 @@ private AuthenticateHeaders getAutenticateHeaders(HTTPSamlAuthenticator samlAuth RestRequest restRequest = new FakeRestRequest(ImmutableMap.of(), new HashMap()); SecurityResponse response = sendToAuthenticator(samlAuthenticator, restRequest).orElseThrow(); - String wwwAuthenticateHeader = response.getHeaders().get("WWW-Authenticate"); + String wwwAuthenticateHeader = response.getHeaders().get("WWW-Authenticate").get(0); Assert.assertNotNull(wwwAuthenticateHeader); diff --git a/src/test/java/org/opensearch/security/filter/SecurityResponseTests.java b/src/test/java/org/opensearch/security/filter/SecurityResponseTests.java new file mode 100644 index 0000000000..7735a8a7cd --- /dev/null +++ b/src/test/java/org/opensearch/security/filter/SecurityResponseTests.java @@ -0,0 +1,155 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.security.filter; + +import java.util.List; +import java.util.Map; + +import org.apache.hc.core5.http.HttpHeaders; +import org.apache.http.HttpStatus; +import org.junit.Test; + +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.rest.BytesRestResponse; +import org.opensearch.rest.RestResponse; + +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.equalTo; + +public class SecurityResponseTests { + + /** + * This test should check whether a basic constructor with the JSON content type is successfully converted to RestResponse + */ + @Test + public void testSecurityResponseHasSingleContentType() { + final SecurityResponse response = new SecurityResponse(HttpStatus.SC_OK, null, "foo bar", XContentType.JSON.mediaType()); + final RestResponse restResponse = response.asRestResponse(); + assertThat(restResponse.status(), equalTo(RestStatus.OK)); + assertThat(restResponse.contentType(), equalTo(XContentType.JSON.mediaType())); + } + + /** + * This test should check whether adding a new HTTP Header for the content type takes the argument or the added header (should take arg.) + */ + @Test + public void testSecurityResponseMultipleContentTypesUsesPassed() { + final SecurityResponse response = new SecurityResponse(HttpStatus.SC_OK, null, "foo bar", XContentType.JSON.mediaType()); + response.addHeader(HttpHeaders.CONTENT_TYPE, BytesRestResponse.TEXT_CONTENT_TYPE); + assertThat(response.getHeaders().get("Content-Type"), equalTo(List.of(BytesRestResponse.TEXT_CONTENT_TYPE))); + final RestResponse restResponse = response.asRestResponse(); + assertThat(restResponse.contentType(), equalTo(XContentType.JSON.mediaType())); + assertThat(restResponse.status(), equalTo(RestStatus.OK)); + } + + /** + * This test should check whether specifying no content type correctly uses plain text + */ + @Test + public void testSecurityResponseDefaultContentTypeIsText() { + final SecurityResponse response = new SecurityResponse(HttpStatus.SC_OK, null, "foo bar"); + final RestResponse restResponse = response.asRestResponse(); + assertThat(restResponse.contentType(), equalTo(BytesRestResponse.TEXT_CONTENT_TYPE)); + assertThat(restResponse.status(), equalTo(RestStatus.OK)); + } + + /** + * This test checks whether adding a new ContentType header actually changes the converted content type header (it should not) + */ + @Test + public void testSecurityResponseSetHeaderContentTypeDoesNothing() { + final SecurityResponse response = new SecurityResponse(HttpStatus.SC_OK, null, "foo bar"); + response.addHeader(HttpHeaders.CONTENT_TYPE, XContentType.JSON.mediaType()); + final RestResponse restResponse = response.asRestResponse(); + assertThat(restResponse.contentType(), equalTo(BytesRestResponse.TEXT_CONTENT_TYPE)); + assertThat(restResponse.status(), equalTo(RestStatus.OK)); + } + + /** + * This test should check whether adding a multiple new HTTP Headers for the content type takes the argument or the added header (should take arg.) + */ + @Test + public void testSecurityResponseAddMultipleContentTypeHeaders() { + final SecurityResponse response = new SecurityResponse(HttpStatus.SC_OK, null, "foo bar", XContentType.JSON.mediaType()); + response.addHeader(HttpHeaders.CONTENT_TYPE, BytesRestResponse.TEXT_CONTENT_TYPE); + assertThat(response.getHeaders().get("Content-Type"), equalTo(List.of(BytesRestResponse.TEXT_CONTENT_TYPE))); + response.addHeader(HttpHeaders.CONTENT_TYPE, "newContentType"); + assertThat(response.getHeaders().get("Content-Type"), equalTo(List.of(BytesRestResponse.TEXT_CONTENT_TYPE, "newContentType"))); + final RestResponse restResponse = response.asRestResponse(); + assertThat(restResponse.status(), equalTo(RestStatus.OK)); + } + + /** + * This test confirms that fake content types work for conversion + */ + @Test + public void testSecurityResponseFakeContentTypeArgumentPasses() { + final SecurityResponse response = new SecurityResponse(HttpStatus.SC_OK, null, "foo bar", "testType"); + final RestResponse restResponse = response.asRestResponse(); + assertThat(restResponse.contentType(), equalTo("testType")); + assertThat(restResponse.status(), equalTo(RestStatus.OK)); + } + + /** + * This test checks that types passed as part of the Headers parameter in the argument do not overwrite actual Content Type + */ + @Test + public void testSecurityResponseContentTypeInConstructorHeader() { + final SecurityResponse response = new SecurityResponse(HttpStatus.SC_OK, Map.of("Content-Type", "testType"), "foo bar"); + assertThat(response.getHeaders().get("Content-Type"), equalTo(List.of("testType"))); + final RestResponse restResponse = response.asRestResponse(); + assertThat(restResponse.contentType(), equalTo(BytesRestResponse.TEXT_CONTENT_TYPE)); + assertThat(restResponse.status(), equalTo(RestStatus.OK)); + } + + /** + * This test confirms the same as above but with a conflicting content type arg + */ + @Test + public void testSecurityResponseContentTypeInConstructorHeaderConflicts() { + final SecurityResponse response = new SecurityResponse( + HttpStatus.SC_OK, + Map.of("Content-Type", "testType"), + "foo bar", + XContentType.JSON.mediaType() + ); + assertThat(response.getHeaders().get("Content-Type"), equalTo(List.of("testType"))); + final RestResponse restResponse = response.asRestResponse(); + assertThat(restResponse.contentType(), equalTo(XContentType.JSON.mediaType())); + assertThat(restResponse.status(), equalTo(RestStatus.OK)); + } + + /** + * This test should check whether unauthorized requests are converted properly + */ + @Test + public void testSecurityResponseUnauthorizedRequestWithPlainTextContentType() { + final SecurityResponse response = new SecurityResponse(HttpStatus.SC_UNAUTHORIZED, null, "foo bar"); + response.addHeader(HttpHeaders.CONTENT_TYPE, "application/json"); + final RestResponse restResponse = response.asRestResponse(); + assertThat(restResponse.contentType(), equalTo(BytesRestResponse.TEXT_CONTENT_TYPE)); + assertThat(restResponse.status(), equalTo(RestStatus.UNAUTHORIZED)); + } + + /** + * This test should check whether forbidden requests are converted properly + */ + @Test + public void testSecurityResponseForbiddenRequestWithPlainTextContentType() { + final SecurityResponse response = new SecurityResponse(HttpStatus.SC_FORBIDDEN, null, "foo bar"); + response.addHeader(HttpHeaders.CONTENT_TYPE, "application/json"); + final RestResponse restResponse = response.asRestResponse(); + assertThat(restResponse.contentType(), equalTo(BytesRestResponse.TEXT_CONTENT_TYPE)); + assertThat(restResponse.status(), equalTo(RestStatus.FORBIDDEN)); + } +}