From 22f44d3bccf93a4af2562d26694094ec7151669f Mon Sep 17 00:00:00 2001 From: Paras Jain Date: Sat, 30 Sep 2023 04:11:10 +0530 Subject: [PATCH 1/8] Improve serialization speeds (#2802) Use custom serialization in security plugin. - Resolves https://github.com/opensearch-project/security/issues/2780 Signed-off-by: Paras Jain Signed-off-by: Peter Nied Co-authored-by: Paras Jain Co-authored-by: Peter Nied --- .github/workflows/ci.yml | 18 + bwc-test/build.gradle | 4 + .../SecurityBackwardsCompatibilityIT.java | 205 ---------- .../opensearch/security/bwc/ClusterType.java | 28 ++ .../bwc/SecurityBackwardsCompatibilityIT.java | 367 ++++++++++++++++++ .../org/opensearch/security/bwc/Song.java | 117 ++++++ .../security/bwc/helper/RestHelper.java | 90 +++++ .../com/amazon/dlic/auth/ldap/LdapUser.java | 15 + .../auditlog/impl/AbstractAuditLog.java | 6 +- .../security/auth/UserInjector.java | 15 +- .../configuration/DlsFlsValveImpl.java | 16 +- .../security/filter/SecurityFilter.java | 4 + .../transport/SecuritySSLRequestHandler.java | 6 + .../security/support/Base64CustomHelper.java | 225 +++++++++++ .../security/support/Base64Helper.java | 183 ++------- .../security/support/Base64JDKHelper.java | 156 ++++++++ .../security/support/ConfigConstants.java | 5 + .../security/support/HeaderHelper.java | 16 +- .../support/SafeSerializationUtils.java | 81 ++++ .../security/support/SourceFieldsContext.java | 26 +- .../security/support/StreamableRegistry.java | 134 +++++++ .../transport/SecurityInterceptor.java | 31 +- .../transport/SecurityRequestHandler.java | 6 +- .../org/opensearch/security/user/User.java | 9 +- .../support/Base64CustomHelperTest.java | 159 ++++++++ .../security/support/Base64HelperTest.java | 90 +---- .../security/support/Base64JDKHelperTest.java | 128 ++++++ .../support/StreamableRegistryTest.java | 29 ++ .../transport/SecurityInterceptorTests.java | 65 +++- .../SecuritySSLRequestHandlerTests.java | 80 ++++ 30 files changed, 1842 insertions(+), 472 deletions(-) delete mode 100644 bwc-test/src/test/java/SecurityBackwardsCompatibilityIT.java create mode 100644 bwc-test/src/test/java/org/opensearch/security/bwc/ClusterType.java create mode 100644 bwc-test/src/test/java/org/opensearch/security/bwc/SecurityBackwardsCompatibilityIT.java create mode 100644 bwc-test/src/test/java/org/opensearch/security/bwc/Song.java create mode 100644 bwc-test/src/test/java/org/opensearch/security/bwc/helper/RestHelper.java create mode 100644 src/main/java/org/opensearch/security/support/Base64CustomHelper.java create mode 100644 src/main/java/org/opensearch/security/support/Base64JDKHelper.java create mode 100644 src/main/java/org/opensearch/security/support/SafeSerializationUtils.java create mode 100644 src/main/java/org/opensearch/security/support/StreamableRegistry.java create mode 100644 src/test/java/org/opensearch/security/support/Base64CustomHelperTest.java create mode 100644 src/test/java/org/opensearch/security/support/Base64JDKHelperTest.java create mode 100644 src/test/java/org/opensearch/security/support/StreamableRegistryTest.java create mode 100644 src/test/java/org/opensearch/security/transport/SecuritySSLRequestHandlerTests.java diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index d5d5e3430d..9e58ff4b4f 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -108,6 +108,24 @@ jobs: arguments: | integrationTest -Dbuild.snapshot=false + backward-compatibility-build: + runs-on: ubuntu-latest + steps: + - uses: actions/setup-java@v3 + with: + distribution: temurin # Temurin is a distribution of adoptium + java-version: 17 + + - name: Checkout Security Repo + uses: actions/checkout@v4 + + - name: Build BWC tests + uses: gradle/gradle-build-action@v2 + with: + cache-disabled: true + arguments: | + -p bwc-test build -x test -x integTest + backward-compatibility: strategy: fail-fast: false diff --git a/bwc-test/build.gradle b/bwc-test/build.gradle index 24cc645ba1..6fb7fc2348 100644 --- a/bwc-test/build.gradle +++ b/bwc-test/build.gradle @@ -47,6 +47,7 @@ buildscript { opensearch_version = System.getProperty("opensearch.version", "3.0.0-SNAPSHOT") opensearch_group = "org.opensearch" common_utils_version = System.getProperty("common_utils.version", '2.9.0.0-SNAPSHOT') + jackson_version = System.getProperty("jackson_version", "2.15.2") } repositories { mavenLocal() @@ -72,6 +73,9 @@ dependencies { testImplementation "org.opensearch.test:framework:${opensearch_version}" testImplementation "org.apache.logging.log4j:log4j-core:${versions.log4j}" testImplementation "org.opensearch:common-utils:${common_utils_version}" + testImplementation "com.fasterxml.jackson.core:jackson-databind:${jackson_version}" + testImplementation "com.fasterxml.jackson.core:jackson-annotations:${jackson_version}" + } loggerUsageCheck.enabled = false diff --git a/bwc-test/src/test/java/SecurityBackwardsCompatibilityIT.java b/bwc-test/src/test/java/SecurityBackwardsCompatibilityIT.java deleted file mode 100644 index 3758b43265..0000000000 --- a/bwc-test/src/test/java/SecurityBackwardsCompatibilityIT.java +++ /dev/null @@ -1,205 +0,0 @@ -/* - * SPDX-License-Identifier: Apache-2.0 - * - * The OpenSearch Contributors require contributions made to - * this file be licensed under the Apache-2.0 license or a - * compatible open source license. - */ -package org.opensearch.security.bwc; - -import java.io.IOException; -import java.util.List; -import java.util.Map; -import java.util.Optional; -import java.util.Set; -import java.util.stream.Collectors; - -import org.apache.hc.client5.http.auth.AuthScope; -import org.apache.hc.client5.http.auth.UsernamePasswordCredentials; -import org.apache.hc.client5.http.impl.auth.BasicCredentialsProvider; -import org.apache.hc.client5.http.impl.nio.PoolingAsyncClientConnectionManagerBuilder; -import org.apache.hc.client5.http.nio.AsyncClientConnectionManager; -import org.apache.hc.client5.http.ssl.ClientTlsStrategyBuilder; -import org.apache.hc.client5.http.ssl.NoopHostnameVerifier; -import org.apache.hc.core5.function.Factory; -import org.apache.hc.core5.http.Header; -import org.apache.hc.core5.http.HttpHost; -import org.apache.hc.core5.http.message.BasicHeader; -import org.apache.hc.core5.http.nio.ssl.TlsStrategy; -import org.apache.hc.core5.reactor.ssl.TlsDetails; -import org.apache.hc.core5.ssl.SSLContextBuilder; -import org.junit.Assume; -import org.junit.Before; -import org.opensearch.common.settings.Settings; -import org.opensearch.common.util.concurrent.ThreadContext; -import org.opensearch.test.rest.OpenSearchRestTestCase; - -import org.opensearch.Version; -import org.opensearch.common.settings.Settings; -import org.opensearch.test.rest.OpenSearchRestTestCase; - -import static org.hamcrest.MatcherAssert.assertThat; -import static org.hamcrest.Matchers.hasItem; - -import org.opensearch.client.RestClient; -import org.opensearch.client.RestClientBuilder; - -import org.junit.Assert; - -import javax.net.ssl.SSLContext; -import javax.net.ssl.SSLEngine; - -public class SecurityBackwardsCompatibilityIT extends OpenSearchRestTestCase { - - private ClusterType CLUSTER_TYPE; - private String CLUSTER_NAME; - - @Before - private void testSetup() { - final String bwcsuiteString = System.getProperty("tests.rest.bwcsuite"); - Assume.assumeTrue("Test cannot be run outside the BWC gradle task 'bwcTestSuite' or its dependent tasks", bwcsuiteString != null); - CLUSTER_TYPE = ClusterType.parse(bwcsuiteString); - CLUSTER_NAME = System.getProperty("tests.clustername"); - } - - @Override - protected final boolean preserveClusterUponCompletion() { - return true; - } - - @Override - protected final boolean preserveIndicesUponCompletion() { - return true; - } - - @Override - protected final boolean preserveReposUponCompletion() { - return true; - } - - @Override - protected boolean preserveTemplatesUponCompletion() { - return true; - } - - @Override - protected String getProtocol() { - return "https"; - } - - @Override - protected final Settings restClientSettings() { - return Settings.builder() - .put(super.restClientSettings()) - // increase the timeout here to 90 seconds to handle long waits for a green - // cluster health. the waits for green need to be longer than a minute to - // account for delayed shards - .put(OpenSearchRestTestCase.CLIENT_SOCKET_TIMEOUT, "90s") - .build(); - } - - @Override - protected RestClient buildClient(Settings settings, HttpHost[] hosts) throws IOException { - RestClientBuilder builder = RestClient.builder(hosts); - configureHttpsClient(builder, settings); - boolean strictDeprecationMode = settings.getAsBoolean("strictDeprecationMode", true); - builder.setStrictDeprecationMode(strictDeprecationMode); - return builder.build(); - } - - protected static void configureHttpsClient(RestClientBuilder builder, Settings settings) throws IOException { - Map headers = ThreadContext.buildDefaultHeaders(settings); - Header[] defaultHeaders = new Header[headers.size()]; - int i = 0; - for (Map.Entry entry : headers.entrySet()) { - defaultHeaders[i++] = new BasicHeader(entry.getKey(), entry.getValue()); - } - builder.setDefaultHeaders(defaultHeaders); - builder.setHttpClientConfigCallback(httpClientBuilder -> { - String userName = Optional.ofNullable(System.getProperty("tests.opensearch.username")) - .orElseThrow(() -> new RuntimeException("user name is missing")); - String password = Optional.ofNullable(System.getProperty("tests.opensearch.password")) - .orElseThrow(() -> new RuntimeException("password is missing")); - BasicCredentialsProvider credentialsProvider = new BasicCredentialsProvider(); - credentialsProvider.setCredentials(new AuthScope(null, -1), new UsernamePasswordCredentials(userName, password.toCharArray())); - try { - SSLContext sslContext = SSLContextBuilder.create().loadTrustMaterial(null, (chains, authType) -> true).build(); - - TlsStrategy tlsStrategy = ClientTlsStrategyBuilder.create() - .setSslContext(sslContext) - .setTlsVersions(new String[] { "TLSv1", "TLSv1.1", "TLSv1.2", "SSLv3" }) - .setHostnameVerifier(NoopHostnameVerifier.INSTANCE) - // See please https://issues.apache.org/jira/browse/HTTPCLIENT-2219 - .setTlsDetailsFactory(new Factory() { - @Override - public TlsDetails create(final SSLEngine sslEngine) { - return new TlsDetails(sslEngine.getSession(), sslEngine.getApplicationProtocol()); - } - }) - .build(); - - final AsyncClientConnectionManager cm = PoolingAsyncClientConnectionManagerBuilder.create() - .setTlsStrategy(tlsStrategy) - .build(); - return httpClientBuilder.setDefaultCredentialsProvider(credentialsProvider).setConnectionManager(cm); - } catch (Exception e) { - throw new RuntimeException(e); - } - }); - } - - public void testBasicBackwardsCompatibility() throws Exception { - String round = System.getProperty("tests.rest.bwcsuite_round"); - - if (round.equals("first") || round.equals("old")) { - assertPluginUpgrade("_nodes/" + CLUSTER_NAME + "-0/plugins"); - } else if (round.equals("second")) { - assertPluginUpgrade("_nodes/" + CLUSTER_NAME + "-1/plugins"); - } else if (round.equals("third")) { - assertPluginUpgrade("_nodes/" + CLUSTER_NAME + "-2/plugins"); - } - } - - @SuppressWarnings("unchecked") - public void testWhoAmI() throws Exception { - Map responseMap = (Map) getAsMap("_plugins/_security/whoami"); - Assert.assertTrue(responseMap.containsKey("dn")); - } - - private enum ClusterType { - OLD, - MIXED, - UPGRADED; - - public static ClusterType parse(String value) { - switch (value) { - case "old_cluster": - return OLD; - case "mixed_cluster": - return MIXED; - case "upgraded_cluster": - return UPGRADED; - default: - throw new AssertionError("unknown cluster type: " + value); - } - } - } - - @SuppressWarnings("unchecked") - private void assertPluginUpgrade(String uri) throws Exception { - Map> responseMap = (Map>) getAsMap(uri).get("nodes"); - for (Map response : responseMap.values()) { - List> plugins = (List>) response.get("plugins"); - Set pluginNames = plugins.stream().map(map -> (String) map.get("name")).collect(Collectors.toSet()); - - final Version minNodeVersion = this.minimumNodeVersion(); - - if (minNodeVersion.major <= 1) { - assertThat(pluginNames, hasItem("opensearch_security")); - } else { - assertThat(pluginNames, hasItem("opensearch-security")); - } - - } - } -} diff --git a/bwc-test/src/test/java/org/opensearch/security/bwc/ClusterType.java b/bwc-test/src/test/java/org/opensearch/security/bwc/ClusterType.java new file mode 100644 index 0000000000..7fe849d5b3 --- /dev/null +++ b/bwc-test/src/test/java/org/opensearch/security/bwc/ClusterType.java @@ -0,0 +1,28 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.security.bwc; + +public enum ClusterType { + OLD, + MIXED, + UPGRADED; + + public static ClusterType parse(String value) { + switch (value) { + case "old_cluster": + return OLD; + case "mixed_cluster": + return MIXED; + case "upgraded_cluster": + return UPGRADED; + default: + throw new AssertionError("unknown cluster type: " + value); + } + } +} diff --git a/bwc-test/src/test/java/org/opensearch/security/bwc/SecurityBackwardsCompatibilityIT.java b/bwc-test/src/test/java/org/opensearch/security/bwc/SecurityBackwardsCompatibilityIT.java new file mode 100644 index 0000000000..1647dbb132 --- /dev/null +++ b/bwc-test/src/test/java/org/opensearch/security/bwc/SecurityBackwardsCompatibilityIT.java @@ -0,0 +1,367 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.security.bwc; + +import java.io.IOException; +import java.util.HashMap; +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.Optional; +import java.util.Set; +import java.util.stream.Collectors; +import javax.net.ssl.SSLContext; + +import com.fasterxml.jackson.databind.ObjectMapper; +import org.apache.hc.client5.http.auth.AuthScope; +import org.apache.hc.client5.http.auth.UsernamePasswordCredentials; +import org.apache.hc.client5.http.impl.auth.BasicCredentialsProvider; +import org.apache.hc.client5.http.impl.nio.PoolingAsyncClientConnectionManagerBuilder; +import org.apache.hc.client5.http.nio.AsyncClientConnectionManager; +import org.apache.hc.client5.http.ssl.ClientTlsStrategyBuilder; +import org.apache.hc.client5.http.ssl.NoopHostnameVerifier; +import org.apache.hc.core5.http.Header; +import org.apache.hc.core5.http.HttpHost; +import org.apache.hc.core5.http.message.BasicHeader; +import org.apache.hc.core5.http.nio.ssl.TlsStrategy; +import org.apache.hc.core5.reactor.ssl.TlsDetails; +import org.apache.hc.core5.ssl.SSLContextBuilder; +import org.junit.AfterClass; +import org.junit.Assert; +import org.junit.Assume; +import org.junit.Before; +import org.opensearch.client.Response; +import org.opensearch.client.ResponseException; +import org.opensearch.client.RestClient; +import org.opensearch.client.RestClientBuilder; +import org.opensearch.common.Randomness; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.common.util.io.IOUtils; +import org.opensearch.security.bwc.helper.RestHelper; +import org.opensearch.test.rest.OpenSearchRestTestCase; +import org.opensearch.Version; + +import static org.hamcrest.Matchers.hasItem; +import static org.hamcrest.Matchers.hasKey; +import static org.hamcrest.Matchers.anyOf; +import static org.hamcrest.Matchers.equalTo; + +public class SecurityBackwardsCompatibilityIT extends OpenSearchRestTestCase { + + private ClusterType CLUSTER_TYPE; + private String CLUSTER_NAME; + + private final String TEST_USER = "user"; + private final String TEST_PASSWORD = "290735c0-355d-4aaf-9b42-1aaa1f2a3cee"; + private final String TEST_ROLE = "test-dls-fls-role"; + private static RestClient testUserRestClient = null; + + @Before + public void testSetup() { + final String bwcsuiteString = System.getProperty("tests.rest.bwcsuite"); + Assume.assumeTrue("Test cannot be run outside the BWC gradle task 'bwcTestSuite' or its dependent tasks", bwcsuiteString != null); + CLUSTER_TYPE = ClusterType.parse(bwcsuiteString); + CLUSTER_NAME = System.getProperty("tests.clustername"); + if (testUserRestClient == null) { + testUserRestClient = buildClient( + super.restClientSettings(), + super.getClusterHosts().toArray(new HttpHost[0]), + TEST_USER, + TEST_PASSWORD + ); + } + } + + @Override + protected final boolean preserveClusterUponCompletion() { + return true; + } + + @Override + protected final boolean preserveIndicesUponCompletion() { + return true; + } + + @Override + protected final boolean preserveReposUponCompletion() { + return true; + } + + @Override + protected boolean preserveTemplatesUponCompletion() { + return true; + } + + @Override + protected String getProtocol() { + return "https"; + } + + @Override + protected final Settings restClientSettings() { + return Settings.builder() + .put(super.restClientSettings()) + // increase the timeout here to 90 seconds to handle long waits for a green + // cluster health. the waits for green need to be longer than a minute to + // account for delayed shards + .put(OpenSearchRestTestCase.CLIENT_SOCKET_TIMEOUT, "90s") + .build(); + } + + protected RestClient buildClient(Settings settings, HttpHost[] hosts, String username, String password) { + RestClientBuilder builder = RestClient.builder(hosts); + configureHttpsClient(builder, settings, username, password); + boolean strictDeprecationMode = settings.getAsBoolean("strictDeprecationMode", true); + builder.setStrictDeprecationMode(strictDeprecationMode); + return builder.build(); + } + + @Override + protected RestClient buildClient(Settings settings, HttpHost[] hosts) { + String username = Optional.ofNullable(System.getProperty("tests.opensearch.username")) + .orElseThrow(() -> new RuntimeException("user name is missing")); + String password = Optional.ofNullable(System.getProperty("tests.opensearch.password")) + .orElseThrow(() -> new RuntimeException("password is missing")); + return buildClient(super.restClientSettings(), super.getClusterHosts().toArray(new HttpHost[0]), username, password); + } + + private static void configureHttpsClient(RestClientBuilder builder, Settings settings, String userName, String password) { + Map headers = ThreadContext.buildDefaultHeaders(settings); + Header[] defaultHeaders = new Header[headers.size()]; + int i = 0; + for (Map.Entry entry : headers.entrySet()) { + defaultHeaders[i++] = new BasicHeader(entry.getKey(), entry.getValue()); + } + builder.setDefaultHeaders(defaultHeaders); + builder.setHttpClientConfigCallback(httpClientBuilder -> { + BasicCredentialsProvider credentialsProvider = new BasicCredentialsProvider(); + credentialsProvider.setCredentials(new AuthScope(null, -1), new UsernamePasswordCredentials(userName, password.toCharArray())); + try { + SSLContext sslContext = SSLContextBuilder.create().loadTrustMaterial(null, (chains, authType) -> true).build(); + + TlsStrategy tlsStrategy = ClientTlsStrategyBuilder.create() + .setSslContext(sslContext) + .setTlsVersions(new String[] { "TLSv1", "TLSv1.1", "TLSv1.2", "SSLv3" }) + .setHostnameVerifier(NoopHostnameVerifier.INSTANCE) + // See please https://issues.apache.org/jira/browse/HTTPCLIENT-2219 + .setTlsDetailsFactory(sslEngine -> new TlsDetails(sslEngine.getSession(), sslEngine.getApplicationProtocol())) + .build(); + + final AsyncClientConnectionManager cm = PoolingAsyncClientConnectionManagerBuilder.create() + .setTlsStrategy(tlsStrategy) + .build(); + return httpClientBuilder.setDefaultCredentialsProvider(credentialsProvider).setConnectionManager(cm); + } catch (Exception e) { + throw new RuntimeException(e); + } + }); + } + + public void testWhoAmI() throws Exception { + Map responseMap = getAsMap("_plugins/_security/whoami"); + assertThat(responseMap, hasKey("dn")); + } + + public void testBasicBackwardsCompatibility() throws Exception { + String round = System.getProperty("tests.rest.bwcsuite_round"); + + if (round.equals("first") || round.equals("old")) { + assertPluginUpgrade("_nodes/" + CLUSTER_NAME + "-0/plugins"); + } else if (round.equals("second")) { + assertPluginUpgrade("_nodes/" + CLUSTER_NAME + "-1/plugins"); + } else if (round.equals("third")) { + assertPluginUpgrade("_nodes/" + CLUSTER_NAME + "-2/plugins"); + } + } + + /** + * Tests backward compatibility by created a test user and role with DLS, FLS and masked field settings. Ingests + * data into a test index and runs a matchAll query against the same. + */ + public void testDataIngestionAndSearchBackwardsCompatibility() throws Exception { + String round = System.getProperty("tests.rest.bwcsuite_round"); + String index = "test_index"; + if (round.equals("old")) { + createTestRoleIfNotExists(TEST_ROLE); + createUserIfNotExists(TEST_USER, TEST_PASSWORD, TEST_ROLE); + createIndexIfNotExists(index); + } + ingestData(index); + searchMatchAll(index); + } + + public void testNodeStats() throws IOException { + List responses = RestHelper.requestAgainstAllNodes(client(), "GET", "_nodes/stats", null); + responses.forEach(r -> Assert.assertEquals(200, r.getStatusLine().getStatusCode())); + } + + @SuppressWarnings("unchecked") + private void assertPluginUpgrade(String uri) throws Exception { + Map> responseMap = (Map>) getAsMap(uri).get("nodes"); + for (Map response : responseMap.values()) { + List> plugins = (List>) response.get("plugins"); + Set pluginNames = plugins.stream().map(map -> (String) map.get("name")).collect(Collectors.toSet()); + + final Version minNodeVersion = minimumNodeVersion(); + + if (minNodeVersion.major <= 1) { + assertThat(pluginNames, hasItem("opensearch_security")); // With underscore seperator + } else { + assertThat(pluginNames, hasItem("opensearch-security")); // With dash seperator + } + } + } + + /** + * Ingests data into the test index + * @param index index to ingest data into + */ + + private void ingestData(String index) throws IOException { + StringBuilder bulkRequestBody = new StringBuilder(); + ObjectMapper objectMapper = new ObjectMapper(); + int numberOfRequests = Randomness.get().nextInt(10); + while (numberOfRequests-- > 0) { + for (int i = 0; i < Randomness.get().nextInt(100); i++) { + Map> indexRequest = new HashMap<>(); + indexRequest.put("index", new HashMap<>() { + { + put("_index", index); + } + }); + bulkRequestBody.append(objectMapper.writeValueAsString(indexRequest) + "\n"); + bulkRequestBody.append(objectMapper.writeValueAsString(Song.randomSong().asJson()) + "\n"); + } + List responses = RestHelper.requestAgainstAllNodes( + testUserRestClient, + "POST", + "_bulk?refresh=wait_for", + RestHelper.toHttpEntity(bulkRequestBody.toString()) + ); + responses.forEach(r -> assertEquals(200, r.getStatusLine().getStatusCode())); + } + } + + /** + * Runs a matchAll query against the test index + * @param index index to search + */ + private void searchMatchAll(String index) throws IOException { + String matchAllQuery = "{\n" + " \"query\": {\n" + " \"match_all\": {}\n" + " }\n" + "}"; + int numberOfRequests = Randomness.get().nextInt(10); + while (numberOfRequests-- > 0) { + List responses = RestHelper.requestAgainstAllNodes( + testUserRestClient, + "POST", + index + "/_search", + RestHelper.toHttpEntity(matchAllQuery) + ); + responses.forEach(r -> assertEquals(200, r.getStatusLine().getStatusCode())); + } + } + + /** + * Checks if a resource at the specified URL exists + * @param url of the resource to be checked for existence + * @return true if the resource exists, false otherwise + */ + + private boolean resourceExists(String url) throws IOException { + try { + RestHelper.get(adminClient(), url); + return true; + } catch (ResponseException e) { + if (e.getResponse().getStatusLine().getStatusCode() == 404) { + return false; + } else { + throw e; + } + } + } + + /** + * Creates a test role with DLS, FLS and masked field settings on the test index. + */ + private void createTestRoleIfNotExists(String role) throws IOException { + String url = "_plugins/_security/api/roles/" + role; + String roleSettings = "{\n" + + " \"cluster_permissions\": [\n" + + " \"unlimited\"\n" + + " ],\n" + + " \"index_permissions\": [\n" + + " {\n" + + " \"index_patterns\": [\n" + + " \"test_index*\"\n" + + " ],\n" + + " \"dls\": \"{ \\\"bool\\\": { \\\"must\\\": { \\\"match\\\": { \\\"genre\\\": \\\"rock\\\" } } } }\",\n" + + " \"fls\": [\n" + + " \"~lyrics\"\n" + + " ],\n" + + " \"masked_fields\": [\n" + + " \"artist\"\n" + + " ],\n" + + " \"allowed_actions\": [\n" + + " \"read\",\n" + + " \"write\"\n" + + " ]\n" + + " }\n" + + " ],\n" + + " \"tenant_permissions\": []\n" + + "}\n"; + Response response = RestHelper.makeRequest(adminClient(), "PUT", url, RestHelper.toHttpEntity(roleSettings)); + + assertThat(response.getStatusLine().getStatusCode(), anyOf(equalTo(200), equalTo(201))); + } + + /** + * Creates a test index if it does not exist already + * @param index index to create + */ + + private void createIndexIfNotExists(String index) throws IOException { + String settings = "{\n" + + " \"settings\": {\n" + + " \"index\": {\n" + + " \"number_of_shards\": 3,\n" + + " \"number_of_replicas\": 1\n" + + " }\n" + + " }\n" + + "}"; + if (!resourceExists(index)) { + Response response = RestHelper.makeRequest(client(), "PUT", index, RestHelper.toHttpEntity(settings)); + assertThat(response.getStatusLine().getStatusCode(), equalTo(200)); + } + } + + /** + * Creates the test user if it does not exist already and maps it to the test role with DLS/FLS settings. + * @param user user to be created + * @param password password for the new user + * @param role roles that the user has to be mapped to + */ + private void createUserIfNotExists(String user, String password, String role) throws IOException { + String url = "_plugins/_security/api/internalusers/" + user; + if (!resourceExists(url)) { + String userSettings = String.format( + Locale.ENGLISH, + "{\n" + " \"password\": \"%s\",\n" + " \"opendistro_security_roles\": [\"%s\"],\n" + " \"backend_roles\": []\n" + "}", + password, + role + ); + Response response = RestHelper.makeRequest(adminClient(), "PUT", url, RestHelper.toHttpEntity(userSettings)); + assertThat(response.getStatusLine().getStatusCode(), equalTo(201)); + } + } + + @AfterClass + public static void cleanUp() throws IOException { + OpenSearchRestTestCase.closeClients(); + IOUtils.close(testUserRestClient); + } +} diff --git a/bwc-test/src/test/java/org/opensearch/security/bwc/Song.java b/bwc-test/src/test/java/org/opensearch/security/bwc/Song.java new file mode 100644 index 0000000000..3cfd2c03e8 --- /dev/null +++ b/bwc-test/src/test/java/org/opensearch/security/bwc/Song.java @@ -0,0 +1,117 @@ +/* +* Copyright OpenSearch Contributors +* SPDX-License-Identifier: Apache-2.0 +* +* The OpenSearch Contributors require contributions made to +* this file be licensed under the Apache-2.0 license or a +* compatible open source license. +* +*/ +package org.opensearch.security.bwc; + +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.ObjectMapper; +import org.opensearch.common.Randomness; + +import java.util.Map; +import java.util.Objects; +import java.util.UUID; + +public class Song { + + public static final String FIELD_TITLE = "title"; + public static final String FIELD_ARTIST = "artist"; + public static final String FIELD_LYRICS = "lyrics"; + public static final String FIELD_STARS = "stars"; + public static final String FIELD_GENRE = "genre"; + public static final String ARTIST_FIRST = "First artist"; + public static final String ARTIST_STRING = "String"; + public static final String ARTIST_TWINS = "Twins"; + public static final String TITLE_MAGNUM_OPUS = "Magnum Opus"; + public static final String TITLE_SONG_1_PLUS_1 = "Song 1+1"; + public static final String TITLE_NEXT_SONG = "Next song"; + public static final String ARTIST_NO = "No!"; + public static final String TITLE_POISON = "Poison"; + + public static final String ARTIST_YES = "yes"; + + public static final String TITLE_AFFIRMATIVE = "Affirmative"; + + public static final String ARTIST_UNKNOWN = "unknown"; + public static final String TITLE_CONFIDENTIAL = "confidential"; + + public static final String LYRICS_1 = "Very deep subject"; + public static final String LYRICS_2 = "Once upon a time"; + public static final String LYRICS_3 = "giant nonsense"; + public static final String LYRICS_4 = "Much too much"; + public static final String LYRICS_5 = "Little to little"; + public static final String LYRICS_6 = "confidential secret classified"; + + public static final String GENRE_ROCK = "rock"; + public static final String GENRE_JAZZ = "jazz"; + public static final String GENRE_BLUES = "blues"; + + public static final String QUERY_TITLE_NEXT_SONG = FIELD_TITLE + ":" + "\"" + TITLE_NEXT_SONG + "\""; + public static final String QUERY_TITLE_POISON = FIELD_TITLE + ":" + TITLE_POISON; + public static final String QUERY_TITLE_MAGNUM_OPUS = FIELD_TITLE + ":" + TITLE_MAGNUM_OPUS; + + public static final Song[] SONGS = { + new Song(ARTIST_FIRST, TITLE_MAGNUM_OPUS, LYRICS_1, 1, GENRE_ROCK), + new Song(ARTIST_STRING, TITLE_SONG_1_PLUS_1, LYRICS_2, 2, GENRE_BLUES), + new Song(ARTIST_TWINS, TITLE_NEXT_SONG, LYRICS_3, 3, GENRE_JAZZ), + new Song(ARTIST_NO, TITLE_POISON, LYRICS_4, 4, GENRE_ROCK), + new Song(ARTIST_YES, TITLE_AFFIRMATIVE, LYRICS_5, 5, GENRE_BLUES), + new Song(ARTIST_UNKNOWN, TITLE_CONFIDENTIAL, LYRICS_6, 6, GENRE_JAZZ) }; + + private final String artist; + private final String title; + private final String lyrics; + private final Integer stars; + private final String genre; + + public Song(String artist, String title, String lyrics, Integer stars, String genre) { + this.artist = Objects.requireNonNull(artist, "Artist is required"); + this.title = Objects.requireNonNull(title, "Title is required"); + this.lyrics = Objects.requireNonNull(lyrics, "Lyrics is required"); + this.stars = Objects.requireNonNull(stars, "Stars field is required"); + this.genre = Objects.requireNonNull(genre, "Genre field is required"); + } + + public String getArtist() { + return artist; + } + + public String getTitle() { + return title; + } + + public String getLyrics() { + return lyrics; + } + + public Integer getStars() { + return stars; + } + + public String getGenre() { + return genre; + } + + public Map asMap() { + return Map.of(FIELD_ARTIST, artist, FIELD_TITLE, title, FIELD_LYRICS, lyrics, FIELD_STARS, stars, FIELD_GENRE, genre); + } + + public String asJson() throws JsonProcessingException { + return new ObjectMapper().writeValueAsString(this.asMap()); + } + + public static Song randomSong() { + return new Song( + UUID.randomUUID().toString(), + UUID.randomUUID().toString(), + UUID.randomUUID().toString(), + Randomness.get().nextInt(5), + UUID.randomUUID().toString() + ); + } +} diff --git a/bwc-test/src/test/java/org/opensearch/security/bwc/helper/RestHelper.java b/bwc-test/src/test/java/org/opensearch/security/bwc/helper/RestHelper.java new file mode 100644 index 0000000000..3272ac736a --- /dev/null +++ b/bwc-test/src/test/java/org/opensearch/security/bwc/helper/RestHelper.java @@ -0,0 +1,90 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.security.bwc.helper; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; + +import org.apache.hc.core5.http.Header; +import org.apache.hc.core5.http.HttpEntity; +import org.apache.hc.core5.http.io.entity.StringEntity; +import org.apache.hc.core5.http.message.BasicHeader; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.client.Request; +import org.opensearch.client.RequestOptions; +import org.opensearch.client.Response; +import org.opensearch.client.RestClient; +import org.opensearch.client.WarningsHandler; + +import static org.apache.hc.core5.http.ContentType.APPLICATION_JSON; + +public class RestHelper { + + private static final Logger log = LogManager.getLogger(RestHelper.class); + + public static HttpEntity toHttpEntity(String jsonString) { + return new StringEntity(jsonString, APPLICATION_JSON); + } + + public static Response get(RestClient client, String url) throws IOException { + return makeRequest(client, "GET", url, null, null); + } + + public static Response makeRequest(RestClient client, String method, String endpoint, HttpEntity entity) throws IOException { + return makeRequest(client, method, endpoint, entity, null); + } + + public static Response makeRequest(RestClient client, String method, String endpoint, HttpEntity entity, List
headers) + throws IOException { + log.info("Making request " + method + " " + endpoint + ", with headers " + headers); + + Request request = new Request(method, endpoint); + + RequestOptions.Builder options = RequestOptions.DEFAULT.toBuilder(); + options.setWarningsHandler(WarningsHandler.PERMISSIVE); + if (headers != null) { + headers.forEach(header -> options.addHeader(header.getName(), header.getValue())); + } + request.setOptions(options.build()); + + if (entity != null) { + request.setEntity(entity); + } + + Response response = client.performRequest(request); + log.info("Recieved response " + response.getStatusLine()); + return response; + } + + public static List requestAgainstAllNodes(RestClient client, String method, String endpoint, HttpEntity entity) + throws IOException { + return requestAgainstAllNodes(client, method, endpoint, entity, null); + } + + public static List requestAgainstAllNodes( + RestClient client, + String method, + String endpoint, + HttpEntity entity, + List
headers + ) throws IOException { + int nodeCount = client.getNodes().size(); + List responses = new ArrayList<>(); + while (nodeCount-- > 0) { + responses.add(makeRequest(client, method, endpoint, entity, headers)); + } + return responses; + } + + public static Header getAuthorizationHeader(String username, String password) { + return new BasicHeader("Authorization", "Basic " + username + ":" + password); + } +} diff --git a/src/main/java/com/amazon/dlic/auth/ldap/LdapUser.java b/src/main/java/com/amazon/dlic/auth/ldap/LdapUser.java index 907d605860..f752ce4a49 100755 --- a/src/main/java/com/amazon/dlic/auth/ldap/LdapUser.java +++ b/src/main/java/com/amazon/dlic/auth/ldap/LdapUser.java @@ -11,6 +11,7 @@ package com.amazon.dlic.auth.ldap; +import java.io.IOException; import java.util.Collections; import java.util.HashMap; import java.util.Map; @@ -20,6 +21,8 @@ import com.amazon.dlic.auth.ldap.util.Utils; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.security.support.WildcardMatcher; import org.opensearch.security.user.AuthCredentials; import org.opensearch.security.user.User; @@ -45,6 +48,12 @@ public LdapUser( attributes.putAll(extractLdapAttributes(originalUsername, userEntry, customAttrMaxValueLen, allowlistedCustomLdapAttrMatcher)); } + public LdapUser(StreamInput in) throws IOException { + super(in); + userEntry = null; + originalUsername = in.readString(); + } + /** * May return null because ldapEntry is transient * @@ -88,4 +97,10 @@ public static Map extractLdapAttributes( } return Collections.unmodifiableMap(attributes); } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeString(originalUsername); + } } diff --git a/src/main/java/org/opensearch/security/auditlog/impl/AbstractAuditLog.java b/src/main/java/org/opensearch/security/auditlog/impl/AbstractAuditLog.java index 804e0a2114..a8f511be97 100644 --- a/src/main/java/org/opensearch/security/auditlog/impl/AbstractAuditLog.java +++ b/src/main/java/org/opensearch/security/auditlog/impl/AbstractAuditLog.java @@ -773,7 +773,8 @@ private TransportAddress getRemoteAddress() { if (address == null && threadPool.getThreadContext().getHeader(ConfigConstants.OPENDISTRO_SECURITY_REMOTE_ADDRESS_HEADER) != null) { address = new TransportAddress( (InetSocketAddress) Base64Helper.deserializeObject( - threadPool.getThreadContext().getHeader(ConfigConstants.OPENDISTRO_SECURITY_REMOTE_ADDRESS_HEADER) + threadPool.getThreadContext().getHeader(ConfigConstants.OPENDISTRO_SECURITY_REMOTE_ADDRESS_HEADER), + threadPool.getThreadContext().getTransient(ConfigConstants.USE_JDK_SERIALIZATION) ) ); } @@ -784,7 +785,8 @@ private String getUser() { User user = threadPool.getThreadContext().getTransient(ConfigConstants.OPENDISTRO_SECURITY_USER); if (user == null && threadPool.getThreadContext().getHeader(ConfigConstants.OPENDISTRO_SECURITY_USER_HEADER) != null) { user = (User) Base64Helper.deserializeObject( - threadPool.getThreadContext().getHeader(ConfigConstants.OPENDISTRO_SECURITY_USER_HEADER) + threadPool.getThreadContext().getHeader(ConfigConstants.OPENDISTRO_SECURITY_USER_HEADER), + threadPool.getThreadContext().getTransient(ConfigConstants.USE_JDK_SERIALIZATION) ); } return user == null ? null : user.getName(); diff --git a/src/main/java/org/opensearch/security/auth/UserInjector.java b/src/main/java/org/opensearch/security/auth/UserInjector.java index 3e89a52e93..30df84ef5f 100644 --- a/src/main/java/org/opensearch/security/auth/UserInjector.java +++ b/src/main/java/org/opensearch/security/auth/UserInjector.java @@ -26,6 +26,7 @@ package org.opensearch.security.auth; +import java.io.IOException; import java.io.ObjectStreamException; import java.net.InetAddress; import java.net.UnknownHostException; @@ -36,6 +37,8 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.common.settings.Settings; import org.opensearch.core.common.transport.TransportAddress; import org.opensearch.rest.RestRequest; @@ -63,13 +66,18 @@ public UserInjector(Settings settings, ThreadPool threadPool, AuditLog auditLog, } - static class InjectedUser extends User { + public static class InjectedUser extends User { private transient TransportAddress transportAddress; public InjectedUser(String name) { super(name); } + public InjectedUser(StreamInput in) throws IOException { + super(in); + this.setInjected(true); + } + private Object writeReplace() throws ObjectStreamException { User user = new User(getName()); user.addRoles(getRoles()); @@ -96,6 +104,11 @@ public void setTransportAddress(String addr) throws UnknownHostException, Illega this.transportAddress = new TransportAddress(iAdress, port); } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + } } public InjectedUser getInjectedUser() { diff --git a/src/main/java/org/opensearch/security/configuration/DlsFlsValveImpl.java b/src/main/java/org/opensearch/security/configuration/DlsFlsValveImpl.java index 14eaed4e0d..b35137a35d 100644 --- a/src/main/java/org/opensearch/security/configuration/DlsFlsValveImpl.java +++ b/src/main/java/org/opensearch/security/configuration/DlsFlsValveImpl.java @@ -443,7 +443,8 @@ private void setDlsHeaders(EvaluatedDlsFlsConfig dlsFls, ActionRequest request) } else { if (threadContext.getHeader(ConfigConstants.OPENDISTRO_SECURITY_DLS_QUERY_HEADER) != null) { Object deserializedDlsQueries = Base64Helper.deserializeObject( - threadContext.getHeader(ConfigConstants.OPENDISTRO_SECURITY_DLS_QUERY_HEADER) + threadContext.getHeader(ConfigConstants.OPENDISTRO_SECURITY_DLS_QUERY_HEADER), + threadContext.getTransient(ConfigConstants.USE_JDK_SERIALIZATION) ); if (!dlsQueries.equals(deserializedDlsQueries)) { throw new OpenSearchSecurityException( @@ -506,7 +507,10 @@ private void setFlsHeaders(EvaluatedDlsFlsConfig dlsFls, ActionRequest request) if (threadContext.getHeader(ConfigConstants.OPENDISTRO_SECURITY_MASKED_FIELD_HEADER) != null) { if (!maskedFieldsMap.equals( - Base64Helper.deserializeObject(threadContext.getHeader(ConfigConstants.OPENDISTRO_SECURITY_MASKED_FIELD_HEADER)) + Base64Helper.deserializeObject( + threadContext.getHeader(ConfigConstants.OPENDISTRO_SECURITY_MASKED_FIELD_HEADER), + threadContext.getTransient(ConfigConstants.USE_JDK_SERIALIZATION) + ) )) { throw new OpenSearchSecurityException( ConfigConstants.OPENDISTRO_SECURITY_MASKED_FIELD_HEADER + " does not match (SG 901D)" @@ -542,7 +546,10 @@ private void setFlsHeaders(EvaluatedDlsFlsConfig dlsFls, ActionRequest request) } else { if (threadContext.getHeader(ConfigConstants.OPENDISTRO_SECURITY_FLS_FIELDS_HEADER) != null) { if (!flsFields.equals( - Base64Helper.deserializeObject(threadContext.getHeader(ConfigConstants.OPENDISTRO_SECURITY_FLS_FIELDS_HEADER)) + Base64Helper.deserializeObject( + threadContext.getHeader(ConfigConstants.OPENDISTRO_SECURITY_FLS_FIELDS_HEADER), + threadContext.getTransient(ConfigConstants.USE_JDK_SERIALIZATION) + ) )) { throw new OpenSearchSecurityException( ConfigConstants.OPENDISTRO_SECURITY_FLS_FIELDS_HEADER @@ -550,7 +557,8 @@ private void setFlsHeaders(EvaluatedDlsFlsConfig dlsFls, ActionRequest request) + flsFields + "---" + Base64Helper.deserializeObject( - threadContext.getHeader(ConfigConstants.OPENDISTRO_SECURITY_FLS_FIELDS_HEADER) + threadContext.getHeader(ConfigConstants.OPENDISTRO_SECURITY_FLS_FIELDS_HEADER), + threadContext.getTransient(ConfigConstants.USE_JDK_SERIALIZATION) ) ); } else { diff --git a/src/main/java/org/opensearch/security/filter/SecurityFilter.java b/src/main/java/org/opensearch/security/filter/SecurityFilter.java index 06f2fae397..f433a5857d 100644 --- a/src/main/java/org/opensearch/security/filter/SecurityFilter.java +++ b/src/main/java/org/opensearch/security/filter/SecurityFilter.java @@ -183,6 +183,10 @@ private void ap threadContext.putTransient(ConfigConstants.OPENDISTRO_SECURITY_ORIGIN, Origin.LOCAL.toString()); } + if (threadContext.getTransient(ConfigConstants.USE_JDK_SERIALIZATION) == null) { + threadContext.putTransient(ConfigConstants.USE_JDK_SERIALIZATION, false); + } + final ComplianceConfig complianceConfig = auditLog.getComplianceConfig(); if (complianceConfig != null && complianceConfig.isEnabled()) { attachSourceFieldContext(request); diff --git a/src/main/java/org/opensearch/security/ssl/transport/SecuritySSLRequestHandler.java b/src/main/java/org/opensearch/security/ssl/transport/SecuritySSLRequestHandler.java index 0a1b94548e..c67579e30f 100644 --- a/src/main/java/org/opensearch/security/ssl/transport/SecuritySSLRequestHandler.java +++ b/src/main/java/org/opensearch/security/ssl/transport/SecuritySSLRequestHandler.java @@ -83,8 +83,14 @@ protected ThreadContext getThreadContext() { @Override public final void messageReceived(T request, TransportChannel channel, Task task) throws Exception { + ThreadContext threadContext = getThreadContext(); + threadContext.putTransient( + ConfigConstants.USE_JDK_SERIALIZATION, + channel.getVersion().before(ConfigConstants.FIRST_CUSTOM_SERIALIZATION_SUPPORTED_OS_VERSION) + ); + if (SSLRequestHelper.containsBadHeader(threadContext, "_opendistro_security_ssl_")) { final Exception exception = ExceptionUtils.createBadHeaderException(); channel.sendResponse(exception); diff --git a/src/main/java/org/opensearch/security/support/Base64CustomHelper.java b/src/main/java/org/opensearch/security/support/Base64CustomHelper.java new file mode 100644 index 0000000000..dc66268fcd --- /dev/null +++ b/src/main/java/org/opensearch/security/support/Base64CustomHelper.java @@ -0,0 +1,225 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.security.support; + +import com.amazon.dlic.auth.ldap.LdapUser; +import com.google.common.base.Preconditions; +import com.google.common.collect.BiMap; +import com.google.common.collect.HashBiMap; +import com.google.common.io.BaseEncoding; +import org.opensearch.OpenSearchException; +import org.opensearch.common.Nullable; +import org.opensearch.core.common.io.stream.BytesStreamInput; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.Writeable; +import org.opensearch.core.common.Strings; +import org.opensearch.security.auth.UserInjector; +import org.opensearch.security.user.User; + +import java.io.IOException; +import java.io.Serializable; + +import static org.opensearch.security.support.SafeSerializationUtils.prohibitUnsafeClasses; + +/** + * Provides support for Serialization/Deserialization of objects of supported classes into/from Base64 encoded stream + * using the OpenSearch's custom serialization protocol implemented by the StreamInput/StreamOutput classes. + */ +public class Base64CustomHelper { + + private enum CustomSerializationFormat { + + WRITEABLE(1), + STREAMABLE(2), + GENERIC(3); + + private final int id; + + CustomSerializationFormat(int id) { + this.id = id; + } + + static CustomSerializationFormat fromId(int id) { + switch (id) { + case 1: + return WRITEABLE; + case 2: + return STREAMABLE; + case 3: + return GENERIC; + default: + throw new IllegalArgumentException(String.format("%d is not a valid id", id)); + } + } + + } + + private static final BiMap, Integer> writeableClassToIdMap = HashBiMap.create(); + private static final StreamableRegistry streamableRegistry = StreamableRegistry.getInstance(); + + static { + registerAllWriteables(); + } + + protected static String serializeObject(final Serializable object) { + + Preconditions.checkArgument(object != null, "object must not be null"); + final BytesStreamOutput streamOutput = new SafeBytesStreamOutput(128); + Class clazz = object.getClass(); + try { + prohibitUnsafeClasses(clazz); + CustomSerializationFormat customSerializationFormat = getCustomSerializationMode(clazz); + switch (customSerializationFormat) { + case WRITEABLE: + streamOutput.writeByte((byte) CustomSerializationFormat.WRITEABLE.id); + streamOutput.writeByte((byte) getWriteableClassID(clazz).intValue()); + ((Writeable) object).writeTo(streamOutput); + break; + case STREAMABLE: + streamOutput.writeByte((byte) CustomSerializationFormat.STREAMABLE.id); + streamableRegistry.writeTo(streamOutput, object); + break; + case GENERIC: + streamOutput.writeByte((byte) CustomSerializationFormat.GENERIC.id); + streamOutput.writeGenericValue(object); + break; + default: + throw new IllegalArgumentException( + String.format("Could not determine custom serialization mode for class %s", clazz.getName()) + ); + } + } catch (final Exception e) { + throw new OpenSearchException("Instance {} of class {} is not serializable", e, object, object.getClass()); + } + final byte[] bytes = streamOutput.bytes().toBytesRef().bytes; + streamOutput.close(); + return BaseEncoding.base64().encode(bytes); + } + + protected static Serializable deserializeObject(final String string) { + + Preconditions.checkArgument(!Strings.isNullOrEmpty(string), "object must not be null or empty"); + final byte[] bytes = BaseEncoding.base64().decode(string); + Serializable obj = null; + try (final BytesStreamInput streamInput = new SafeBytesStreamInput(bytes)) { + CustomSerializationFormat serializationFormat = CustomSerializationFormat.fromId(streamInput.readByte()); + switch (serializationFormat) { + case WRITEABLE: + final int classId = streamInput.readByte(); + Class clazz = getWriteableClassFromId(classId); + obj = (Serializable) clazz.getConstructor(StreamInput.class).newInstance(streamInput); + break; + case STREAMABLE: + obj = (Serializable) streamableRegistry.readFrom(streamInput); + break; + case GENERIC: + obj = (Serializable) streamInput.readGenericValue(); + break; + default: + throw new IllegalArgumentException("Could not determine custom deserialization mode"); + } + prohibitUnsafeClasses(obj.getClass()); + return obj; + } catch (final Exception e) { + throw new OpenSearchException(e); + } + } + + private static boolean isWriteable(Class clazz) { + return Writeable.class.isAssignableFrom(clazz); + } + + /** + * Returns integer ID for the registered Writeable class + *
+ * Protected for testing + */ + protected static Integer getWriteableClassID(Class clazz) { + if (!isWriteable(clazz)) { + throw new OpenSearchException("clazz should implement Writeable ", clazz); + } + if (!writeableClassToIdMap.containsKey(clazz)) { + throw new OpenSearchException("Writeable clazz not registered ", clazz); + } + return writeableClassToIdMap.get(clazz); + } + + private static Class getWriteableClassFromId(int id) { + return writeableClassToIdMap.inverse().get(id); + } + + /** + * Registers the given Writeable class for custom serialization by assigning an incrementing integer ID + * IDs are stored in a HashBiMap + * @param clazz class to be registered + */ + private static void registerWriteable(Class clazz) { + if (writeableClassToIdMap.containsKey(clazz)) { + throw new OpenSearchException("writeable clazz is already registered ", clazz.getName()); + } + int id = writeableClassToIdMap.size() + 1; + writeableClassToIdMap.put(clazz, id); + } + + /** + * Registers all Writeable classes for custom serialization support. + * Removing existing classes / changing order of registration will cause a breaking change in the serialization protocol + * as registerWriteable assigns an incrementing integer ID to each of the classes in the order it is called + * starting from 1. + *
+ * New classes can safely be added towards the end. + */ + private static void registerAllWriteables() { + registerWriteable(User.class); + registerWriteable(LdapUser.class); + registerWriteable(UserInjector.InjectedUser.class); + registerWriteable(SourceFieldsContext.class); + } + + private static CustomSerializationFormat getCustomSerializationMode(Class clazz) { + if (isWriteable(clazz)) { + return CustomSerializationFormat.WRITEABLE; + } else if (streamableRegistry.isStreamable(clazz)) { + return CustomSerializationFormat.STREAMABLE; + } else { + return CustomSerializationFormat.GENERIC; + } + } + + private static class SafeBytesStreamOutput extends BytesStreamOutput { + + public SafeBytesStreamOutput(int expectedSize) { + super(expectedSize); + } + + @Override + public void writeGenericValue(@Nullable Object value) throws IOException { + prohibitUnsafeClasses(value.getClass()); + super.writeGenericValue(value); + } + } + + private static class SafeBytesStreamInput extends BytesStreamInput { + + public SafeBytesStreamInput(byte[] bytes) { + super(bytes); + } + + @Override + public Object readGenericValue() throws IOException { + Object object = super.readGenericValue(); + prohibitUnsafeClasses(object.getClass()); + return object; + } + } +} diff --git a/src/main/java/org/opensearch/security/support/Base64Helper.java b/src/main/java/org/opensearch/security/support/Base64Helper.java index 836858decb..a5fbab8515 100644 --- a/src/main/java/org/opensearch/security/support/Base64Helper.java +++ b/src/main/java/org/opensearch/security/support/Base64Helper.java @@ -26,174 +26,47 @@ package org.opensearch.security.support; -import java.io.ByteArrayInputStream; -import java.io.ByteArrayOutputStream; -import java.io.IOException; -import java.io.InputStream; -import java.io.InvalidClassException; -import java.io.ObjectInputStream; -import java.io.ObjectOutputStream; -import java.io.ObjectStreamClass; -import java.io.OutputStream; import java.io.Serializable; -import java.net.InetAddress; -import java.net.InetSocketAddress; -import java.net.SocketAddress; -import java.security.AccessController; -import java.security.PrivilegedAction; -import java.util.Collection; -import java.util.Collections; -import java.util.List; -import java.util.Map; -import java.util.Set; -import java.util.regex.Pattern; - -import com.google.common.base.Preconditions; -import com.google.common.collect.ImmutableList; -import com.google.common.collect.ImmutableSet; -import com.google.common.io.BaseEncoding; -import org.ldaptive.AbstractLdapBean; -import org.ldaptive.LdapAttribute; -import org.ldaptive.LdapEntry; -import org.ldaptive.SearchEntry; - -import com.amazon.dlic.auth.ldap.LdapUser; - -import org.opensearch.OpenSearchException; -import org.opensearch.SpecialPermission; -import org.opensearch.core.common.Strings; -import org.opensearch.security.user.User; public class Base64Helper { - private static final Set> SAFE_CLASSES = ImmutableSet.of( - String.class, - SocketAddress.class, - InetSocketAddress.class, - Pattern.class, - User.class, - SourceFieldsContext.class, - LdapUser.class, - SearchEntry.class, - LdapEntry.class, - AbstractLdapBean.class, - LdapAttribute.class - ); - - private static final List> SAFE_ASSIGNABLE_FROM_CLASSES = ImmutableList.of( - InetAddress.class, - Number.class, - Collection.class, - Map.class, - Enum.class - ); - - private static final Set SAFE_CLASS_NAMES = Collections.singleton("org.ldaptive.LdapAttribute$LdapAttributeValues"); - - private static boolean isSafeClass(Class cls) { - return cls.isArray() - || SAFE_CLASSES.contains(cls) - || SAFE_CLASS_NAMES.contains(cls.getName()) - || SAFE_ASSIGNABLE_FROM_CLASSES.stream().anyMatch(c -> c.isAssignableFrom(cls)); - } - - private final static class SafeObjectOutputStream extends ObjectOutputStream { - - private static final boolean useSafeObjectOutputStream = checkSubstitutionPermission(); - - @SuppressWarnings("removal") - private static boolean checkSubstitutionPermission() { - SecurityManager sm = System.getSecurityManager(); - if (sm != null) { - try { - sm.checkPermission(new SpecialPermission()); - - AccessController.doPrivileged((PrivilegedAction) () -> { - AccessController.checkPermission(SUBSTITUTION_PERMISSION); - return null; - }); - } catch (SecurityException e) { - return false; - } - } - return true; - } - - static ObjectOutputStream create(ByteArrayOutputStream out) throws IOException { - try { - return useSafeObjectOutputStream ? new SafeObjectOutputStream(out) : new ObjectOutputStream(out); - } catch (SecurityException e) { - // As we try to create SafeObjectOutputStream only when necessary permissions are granted, we should - // not reach here, but if we do, we can still return ObjectOutputStream after resetting ByteArrayOutputStream - out.reset(); - return new ObjectOutputStream(out); - } - } - - @SuppressWarnings("removal") - private SafeObjectOutputStream(OutputStream out) throws IOException { - super(out); - - SecurityManager sm = System.getSecurityManager(); - if (sm != null) { - sm.checkPermission(new SpecialPermission()); - } - - AccessController.doPrivileged((PrivilegedAction) () -> enableReplaceObject(true)); - } - - @Override - protected Object replaceObject(Object obj) throws IOException { - Class clazz = obj.getClass(); - if (isSafeClass(clazz)) { - return obj; - } - throw new IOException("Unauthorized serialization attempt " + clazz.getName()); - } + public static String serializeObject(final Serializable object, final boolean useJDKSerialization) { + return useJDKSerialization ? Base64JDKHelper.serializeObject(object) : Base64CustomHelper.serializeObject(object); } public static String serializeObject(final Serializable object) { - - Preconditions.checkArgument(object != null, "object must not be null"); - - final ByteArrayOutputStream bos = new ByteArrayOutputStream(); - try (final ObjectOutputStream out = SafeObjectOutputStream.create(bos)) { - out.writeObject(object); - } catch (final Exception e) { - throw new OpenSearchException("Instance {} of class {} is not serializable", e, object, object.getClass()); - } - final byte[] bytes = bos.toByteArray(); - return BaseEncoding.base64().encode(bytes); + return serializeObject(object, false); } public static Serializable deserializeObject(final String string) { - - Preconditions.checkArgument(!Strings.isNullOrEmpty(string), "string must not be null or empty"); - - final byte[] bytes = BaseEncoding.base64().decode(string); - final ByteArrayInputStream bis = new ByteArrayInputStream(bytes); - try (SafeObjectInputStream in = new SafeObjectInputStream(bis)) { - return (Serializable) in.readObject(); - } catch (final Exception e) { - throw new OpenSearchException(e); - } + return deserializeObject(string, false); } - private final static class SafeObjectInputStream extends ObjectInputStream { - - public SafeObjectInputStream(InputStream in) throws IOException { - super(in); - } - - @Override - protected Class resolveClass(ObjectStreamClass desc) throws IOException, ClassNotFoundException { - - Class clazz = super.resolveClass(desc); - if (isSafeClass(clazz)) { - return clazz; - } + public static Serializable deserializeObject(final String string, final boolean useJDKDeserialization) { + return useJDKDeserialization ? Base64JDKHelper.deserializeObject(string) : Base64CustomHelper.deserializeObject(string); + } - throw new InvalidClassException("Unauthorized deserialization attempt ", clazz.getName()); + /** + * Ensures that the returned string is JDK serialized. + * + * If the supplied string is a custom serialized representation, will deserialize it and further serialize using + * JDK, otherwise returns the string as is. + * + * @param string original string, can be JDK or custom serialized + * @return jdk serialized string + */ + public static String ensureJDKSerialized(final String string) { + Serializable serializable; + try { + serializable = Base64Helper.deserializeObject(string, false); + } catch (Exception e) { + // We received an exception when de-serializing the given string. It is probably JDK serialized. + // Try to deserialize using JDK + Base64Helper.deserializeObject(string, true); + // Since we could deserialize the object using JDK, the string is already JDK serialized, return as is + return string; } + // If we see an exception now, we want the caller to see it - + return Base64Helper.serializeObject(serializable, true); } } diff --git a/src/main/java/org/opensearch/security/support/Base64JDKHelper.java b/src/main/java/org/opensearch/security/support/Base64JDKHelper.java new file mode 100644 index 0000000000..a4ab87d813 --- /dev/null +++ b/src/main/java/org/opensearch/security/support/Base64JDKHelper.java @@ -0,0 +1,156 @@ +/* + * Copyright 2015-2018 _floragunn_ GmbH + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.security.support; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.InputStream; +import java.io.InvalidClassException; +import java.io.ObjectInputStream; +import java.io.ObjectOutputStream; +import java.io.ObjectStreamClass; +import java.io.OutputStream; +import java.io.Serializable; +import java.security.AccessController; +import java.security.PrivilegedAction; + +import com.google.common.base.Preconditions; +import com.google.common.io.BaseEncoding; + +import org.opensearch.OpenSearchException; +import org.opensearch.SpecialPermission; +import org.opensearch.core.common.Strings; + +import static org.opensearch.security.support.SafeSerializationUtils.isSafeClass; + +/** + * Provides support for Serialization/Deserialization of objects of supported classes into/from Base64 encoded stream + * using JDK's in-built serialization protocol implemented by the ObjectOutputStream and ObjectInputStream classes. + */ +public class Base64JDKHelper { + + private final static class SafeObjectOutputStream extends ObjectOutputStream { + + private static final boolean useSafeObjectOutputStream = checkSubstitutionPermission(); + + @SuppressWarnings("removal") + private static boolean checkSubstitutionPermission() { + SecurityManager sm = System.getSecurityManager(); + if (sm != null) { + try { + sm.checkPermission(new SpecialPermission()); + + AccessController.doPrivileged((PrivilegedAction) () -> { + AccessController.checkPermission(SUBSTITUTION_PERMISSION); + return null; + }); + } catch (SecurityException e) { + return false; + } + } + return true; + } + + static ObjectOutputStream create(ByteArrayOutputStream out) throws IOException { + try { + return useSafeObjectOutputStream ? new SafeObjectOutputStream(out) : new ObjectOutputStream(out); + } catch (SecurityException e) { + // As we try to create SafeObjectOutputStream only when necessary permissions are granted, we should + // not reach here, but if we do, we can still return ObjectOutputStream after resetting ByteArrayOutputStream + out.reset(); + return new ObjectOutputStream(out); + } + } + + @SuppressWarnings("removal") + private SafeObjectOutputStream(OutputStream out) throws IOException { + super(out); + + SecurityManager sm = System.getSecurityManager(); + if (sm != null) { + sm.checkPermission(new SpecialPermission()); + } + + AccessController.doPrivileged((PrivilegedAction) () -> enableReplaceObject(true)); + } + + @Override + protected Object replaceObject(Object obj) throws IOException { + Class clazz = obj.getClass(); + if (isSafeClass(clazz)) { + return obj; + } + throw new IOException("Unauthorized serialization attempt " + clazz.getName()); + } + } + + public static String serializeObject(final Serializable object) { + + Preconditions.checkArgument(object != null, "object must not be null"); + + final ByteArrayOutputStream bos = new ByteArrayOutputStream(); + try (final ObjectOutputStream out = SafeObjectOutputStream.create(bos)) { + out.writeObject(object); + } catch (final Exception e) { + throw new OpenSearchException("Instance {} of class {} is not serializable", e, object, object.getClass()); + } + final byte[] bytes = bos.toByteArray(); + return BaseEncoding.base64().encode(bytes); + } + + public static Serializable deserializeObject(final String string) { + + Preconditions.checkArgument(!Strings.isNullOrEmpty(string), "object must not be null or empty"); + + final byte[] bytes = BaseEncoding.base64().decode(string); + final ByteArrayInputStream bis = new ByteArrayInputStream(bytes); + try (SafeObjectInputStream in = new SafeObjectInputStream(bis)) { + return (Serializable) in.readObject(); + } catch (final Exception e) { + throw new OpenSearchException(e); + } + } + + private final static class SafeObjectInputStream extends ObjectInputStream { + + public SafeObjectInputStream(InputStream in) throws IOException { + super(in); + } + + @Override + protected Class resolveClass(ObjectStreamClass desc) throws IOException, ClassNotFoundException { + + Class clazz = super.resolveClass(desc); + if (isSafeClass(clazz)) { + return clazz; + } + + throw new InvalidClassException("Unauthorized deserialization attempt ", clazz.getName()); + } + } +} diff --git a/src/main/java/org/opensearch/security/support/ConfigConstants.java b/src/main/java/org/opensearch/security/support/ConfigConstants.java index 8317d65335..9ac73cd579 100644 --- a/src/main/java/org/opensearch/security/support/ConfigConstants.java +++ b/src/main/java/org/opensearch/security/support/ConfigConstants.java @@ -35,6 +35,7 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; +import org.opensearch.Version; import org.opensearch.common.settings.Settings; import org.opensearch.security.auditlog.impl.AuditCategory; @@ -242,6 +243,7 @@ public class ConfigConstants { "opendistro_security.compliance.history.write.ignore_users"; public static final String OPENDISTRO_SECURITY_COMPLIANCE_HISTORY_EXTERNAL_CONFIG_ENABLED = "opendistro_security.compliance.history.external_config_enabled"; + public static final String OPENDISTRO_SECURITY_SOURCE_FIELD_CONTEXT = OPENDISTRO_SECURITY_CONFIG_PREFIX + "source_field_context"; public static final String SECURITY_COMPLIANCE_DISABLE_ANONYMOUS_AUTHENTICATION = "plugins.security.compliance.disable_anonymous_authentication"; public static final String SECURITY_COMPLIANCE_IMMUTABLE_INDICES = "plugins.security.compliance.immutable_indices"; @@ -323,6 +325,9 @@ public enum RolesMappingResolution { public static final String TENANCY_GLOBAL_TENANT_NAME = "global"; public static final String TENANCY_GLOBAL_TENANT_DEFAULT_NAME = ""; + public static final String USE_JDK_SERIALIZATION = "plugins.security.use_jdk_serialization"; + public static final Version FIRST_CUSTOM_SERIALIZATION_SUPPORTED_OS_VERSION = Version.V_3_0_0; + // On-behalf-of endpoints settings // CS-SUPPRESS-SINGLE: RegexpSingleline get Extensions Settings public static final String EXTENSIONS_BWC_PLUGIN_MODE = "bwcPluginMode"; diff --git a/src/main/java/org/opensearch/security/support/HeaderHelper.java b/src/main/java/org/opensearch/security/support/HeaderHelper.java index e8d50346a8..bbb44664fa 100644 --- a/src/main/java/org/opensearch/security/support/HeaderHelper.java +++ b/src/main/java/org/opensearch/security/support/HeaderHelper.java @@ -27,6 +27,8 @@ package org.opensearch.security.support; import java.io.Serializable; +import java.util.Arrays; +import java.util.List; import com.google.common.base.Strings; @@ -68,7 +70,7 @@ public static Serializable deserializeSafeFromHeader(final ThreadContext context final String objectAsBase64 = getSafeFromHeader(context, headerName); if (!Strings.isNullOrEmpty(objectAsBase64)) { - return Base64Helper.deserializeObject(objectAsBase64); + return Base64Helper.deserializeObject(objectAsBase64, context.getTransient(ConfigConstants.USE_JDK_SERIALIZATION)); } return null; @@ -77,4 +79,16 @@ public static Serializable deserializeSafeFromHeader(final ThreadContext context public static boolean isTrustedClusterRequest(final ThreadContext context) { return context.getTransient(ConfigConstants.OPENDISTRO_SECURITY_SSL_TRANSPORT_TRUSTED_CLUSTER_REQUEST) == Boolean.TRUE; } + + public static List getAllSerializedHeaderNames() { + return Arrays.asList( + ConfigConstants.OPENDISTRO_SECURITY_REMOTE_ADDRESS_HEADER, + ConfigConstants.OPENDISTRO_SECURITY_USER_HEADER, + ConfigConstants.OPENDISTRO_SECURITY_DLS_QUERY_HEADER, + ConfigConstants.OPENDISTRO_SECURITY_FLS_FIELDS_HEADER, + ConfigConstants.OPENDISTRO_SECURITY_MASKED_FIELD_HEADER, + ConfigConstants.OPENDISTRO_SECURITY_DLS_FILTER_LEVEL_QUERY_HEADER, + ConfigConstants.OPENDISTRO_SECURITY_SOURCE_FIELD_CONTEXT + ); + } } diff --git a/src/main/java/org/opensearch/security/support/SafeSerializationUtils.java b/src/main/java/org/opensearch/security/support/SafeSerializationUtils.java new file mode 100644 index 0000000000..c980959f68 --- /dev/null +++ b/src/main/java/org/opensearch/security/support/SafeSerializationUtils.java @@ -0,0 +1,81 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.security.support; + +import com.amazon.dlic.auth.ldap.LdapUser; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableSet; +import org.ldaptive.AbstractLdapBean; +import org.ldaptive.LdapAttribute; +import org.ldaptive.LdapEntry; +import org.ldaptive.SearchEntry; +import org.opensearch.security.auth.UserInjector; +import org.opensearch.security.user.User; + +import java.io.IOException; +import java.net.InetAddress; +import java.net.InetSocketAddress; +import java.net.SocketAddress; +import java.util.Collection; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.regex.Pattern; + +/** + * Provides functionality to verify if a class is categorised to be safe for serialization or + * deserialization by the security plugin. + *
+ * All methods are package private. + */ +public final class SafeSerializationUtils { + + private static final Set> SAFE_CLASSES = ImmutableSet.of( + String.class, + SocketAddress.class, + InetSocketAddress.class, + Pattern.class, + User.class, + UserInjector.InjectedUser.class, + SourceFieldsContext.class, + LdapUser.class, + SearchEntry.class, + LdapEntry.class, + AbstractLdapBean.class, + LdapAttribute.class + ); + + private static final List> SAFE_ASSIGNABLE_FROM_CLASSES = ImmutableList.of( + InetAddress.class, + Number.class, + Collection.class, + Map.class, + Enum.class + ); + + private static final Set SAFE_CLASS_NAMES = Collections.singleton("org.ldaptive.LdapAttribute$LdapAttributeValues"); + + static boolean isSafeClass(Class cls) { + return cls.isArray() + || SAFE_CLASSES.contains(cls) + || SAFE_CLASS_NAMES.contains(cls.getName()) + || SAFE_ASSIGNABLE_FROM_CLASSES.stream().anyMatch(c -> c.isAssignableFrom(cls)); + } + + static void prohibitUnsafeClasses(Class clazz) throws IOException { + if (!isSafeClass(clazz)) { + throw new IOException("Unauthorized serialization attempt " + clazz.getName()); + } + } + +} diff --git a/src/main/java/org/opensearch/security/support/SourceFieldsContext.java b/src/main/java/org/opensearch/security/support/SourceFieldsContext.java index 02f0ad9226..83bbb683e9 100644 --- a/src/main/java/org/opensearch/security/support/SourceFieldsContext.java +++ b/src/main/java/org/opensearch/security/support/SourceFieldsContext.java @@ -26,13 +26,18 @@ package org.opensearch.security.support; +import java.io.IOException; import java.io.Serializable; import java.util.Arrays; +import java.util.Objects; import org.opensearch.action.get.GetRequest; import org.opensearch.action.search.SearchRequest; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.common.io.stream.Writeable; -public class SourceFieldsContext implements Serializable { +public class SourceFieldsContext implements Serializable, Writeable { private String[] includes; private String[] excludes; @@ -77,6 +82,18 @@ public SourceFieldsContext(SearchRequest request) { // } } + public SourceFieldsContext(StreamInput in) throws IOException { + includes = in.readStringArray(); + if (includes.length == 0) { + includes = null; + } + excludes = in.readStringArray(); + if (excludes.length == 0) { + excludes = null; + } + fetchSource = in.readBoolean(); + } + public SourceFieldsContext(GetRequest request) { if (request.fetchSourceContext() != null) { includes = request.fetchSourceContext().includes(); @@ -117,4 +134,11 @@ public String toString() { + fetchSource + "]"; } + + @Override + public void writeTo(StreamOutput streamOutput) throws IOException { + streamOutput.writeStringArray(Objects.requireNonNullElseGet(includes, () -> new String[] {})); + streamOutput.writeStringArray(Objects.requireNonNullElseGet(excludes, () -> new String[] {})); + streamOutput.writeBoolean(fetchSource); + } } diff --git a/src/main/java/org/opensearch/security/support/StreamableRegistry.java b/src/main/java/org/opensearch/security/support/StreamableRegistry.java new file mode 100644 index 0000000000..bfde866376 --- /dev/null +++ b/src/main/java/org/opensearch/security/support/StreamableRegistry.java @@ -0,0 +1,134 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.security.support; + +import java.io.IOException; +import java.net.InetAddress; +import java.net.InetSocketAddress; +import java.util.HashMap; +import java.util.Map; + +import com.google.common.collect.BiMap; +import com.google.common.collect.HashBiMap; + +import org.opensearch.OpenSearchException; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.common.io.stream.Writeable; + +/** + * Registry for any class that does NOT implement the Writeable interface + * and needs to be serialized over the wire. Supports registration of writer and reader via registerStreamable + * for such classes and provides methods writeTo and readFrom for objects of such registered classes. + *
+ * Methods are protected and intended to be accessed from only within the package. (mostly by Base64Helper) + */ +public class StreamableRegistry { + + private static final StreamableRegistry INSTANCE = new StreamableRegistry(); + public final BiMap, Integer> classToIdMap = HashBiMap.create(); + private final Map idToEntryMap = new HashMap<>(); + + private StreamableRegistry() { + registerAllStreamables(); + } + + private static class Entry { + Writeable.Writer writer; + Writeable.Reader reader; + + Entry(Writeable.Writer writer, Writeable.Reader reader) { + this.writer = writer; + this.reader = reader; + } + } + + private Writeable.Writer getWriter(Class clazz) { + if (!classToIdMap.containsKey(clazz)) { + throw new OpenSearchException(String.format("No writer registered for class %s", clazz.getName())); + } + return idToEntryMap.get(classToIdMap.get(clazz)).writer; + } + + private Writeable.Reader getReader(int id) { + if (!idToEntryMap.containsKey(id)) { + throw new OpenSearchException(String.format("No reader registered for id %s", id)); + } + return idToEntryMap.get(id).reader; + } + + private int getId(Class clazz) { + if (!classToIdMap.containsKey(clazz)) { + throw new OpenSearchException(String.format("No writer registered for class %s", clazz.getName())); + } + return classToIdMap.get(clazz); + } + + protected boolean isStreamable(Class clazz) { + return classToIdMap.containsKey(clazz); + } + + protected void writeTo(StreamOutput out, Object object) throws IOException { + out.writeByte((byte) getId(object.getClass())); + getWriter(object.getClass()).write(out, object); + } + + protected Object readFrom(StreamInput in) throws IOException { + int id = in.readByte(); + return getReader(id).read(in); + } + + protected static StreamableRegistry getInstance() { + return INSTANCE; + } + + protected void registerStreamable(int streamableId, Class clazz, Writeable.Writer writer, Writeable.Reader reader) { + if (Writeable.class.isAssignableFrom(clazz)) { + throw new IllegalArgumentException( + String.format("%s is Writeable and should not be registered as a streamable", clazz.getName()) + ); + } + classToIdMap.put(clazz, streamableId); + idToEntryMap.put(streamableId, new Entry(writer, reader)); + } + + protected int getStreamableID(Class clazz) { + if (!isStreamable(clazz)) { + throw new OpenSearchException(String.format("class %s is in streamable registry", clazz.getName())); + } else { + return classToIdMap.get(clazz); + } + } + + /** + * Register all streamables here. + *
+ * Caution - Register new streamables towards the end. Removing / reordering a registered streamable will change the typeIDs associated with the streamables + * causing a breaking change in the serialization format. + */ + private void registerAllStreamables() { + + // InetSocketAddress + this.registerStreamable(1, InetSocketAddress.class, (o, v) -> { + final InetSocketAddress inetSocketAddress = (InetSocketAddress) v; + o.writeString(inetSocketAddress.getHostString()); + o.writeByteArray(inetSocketAddress.getAddress().getAddress()); + o.writeInt(inetSocketAddress.getPort()); + }, i -> { + String host = i.readString(); + byte[] addressBytes = i.readByteArray(); + int port = i.readInt(); + return new InetSocketAddress(InetAddress.getByAddress(host, addressBytes), port); + }); + } + +} diff --git a/src/main/java/org/opensearch/security/transport/SecurityInterceptor.java b/src/main/java/org/opensearch/security/transport/SecurityInterceptor.java index 0c645c9a00..f064f0af04 100644 --- a/src/main/java/org/opensearch/security/transport/SecurityInterceptor.java +++ b/src/main/java/org/opensearch/security/transport/SecurityInterceptor.java @@ -59,6 +59,7 @@ import org.opensearch.security.ssl.transport.SSLConfig; import org.opensearch.security.support.Base64Helper; import org.opensearch.security.support.ConfigConstants; +import org.opensearch.security.support.HeaderHelper; import org.opensearch.security.user.User; import org.opensearch.threadpool.ThreadPool; import org.opensearch.transport.Transport.Connection; @@ -147,6 +148,7 @@ public void sendRequestDecorate( final String origCCSTransientMf = getThreadContext().getTransient(ConfigConstants.OPENDISTRO_SECURITY_MASKED_FIELD_CCS); final boolean isDebugEnabled = log.isDebugEnabled(); + final boolean useJDKSerialization = connection.getVersion().before(ConfigConstants.FIRST_CUSTOM_SERIALIZATION_SUPPORTED_OS_VERSION); final boolean isSameNodeRequest = localNode != null && localNode.equals(connection.getNode()); try (ThreadContext.StoredContext stashedContext = getThreadContext().stashContext()) { @@ -224,9 +226,26 @@ && getThreadContext().getHeader(ConfigConstants.OPENDISTRO_SECURITY_INJECTED_ROL ); } + if (useJDKSerialization) { + Map jdkSerializedHeaders = new HashMap<>(); + HeaderHelper.getAllSerializedHeaderNames() + .stream() + .filter(k -> headerMap.get(k) != null) + .forEach(k -> jdkSerializedHeaders.put(k, Base64Helper.ensureJDKSerialized(headerMap.get(k)))); + headerMap.putAll(jdkSerializedHeaders); + } + getThreadContext().putHeader(headerMap); - ensureCorrectHeaders(remoteAddress0, user0, origin0, injectedUserString, injectedRolesString, isSameNodeRequest); + ensureCorrectHeaders( + remoteAddress0, + user0, + origin0, + injectedUserString, + injectedRolesString, + isSameNodeRequest, + useJDKSerialization + ); if (isActionTraceEnabled()) { getThreadContext().putHeader( @@ -253,7 +272,8 @@ private void ensureCorrectHeaders( final String origin, final String injectedUserString, final String injectedRolesString, - boolean isSameNodeRequest + final boolean isSameNodeRequest, + final boolean useJDKSerialization ) { // keep original address @@ -294,7 +314,7 @@ && getThreadContext().getHeader(ConfigConstants.OPENDISTRO_SECURITY_ORIGIN_HEADE if (transportAddress != null) { getThreadContext().putHeader( ConfigConstants.OPENDISTRO_SECURITY_REMOTE_ADDRESS_HEADER, - Base64Helper.serializeObject(transportAddress.address()) + Base64Helper.serializeObject(transportAddress.address(), useJDKSerialization) ); } @@ -302,7 +322,10 @@ && getThreadContext().getHeader(ConfigConstants.OPENDISTRO_SECURITY_ORIGIN_HEADE if (userHeader == null) { // put as headers for other requests if (origUser != null) { - getThreadContext().putHeader(ConfigConstants.OPENDISTRO_SECURITY_USER_HEADER, Base64Helper.serializeObject(origUser)); + getThreadContext().putHeader( + ConfigConstants.OPENDISTRO_SECURITY_USER_HEADER, + Base64Helper.serializeObject(origUser, useJDKSerialization) + ); } else if (StringUtils.isNotEmpty(injectedRolesString)) { getThreadContext().putHeader(ConfigConstants.OPENDISTRO_SECURITY_INJECTED_ROLES_HEADER, injectedRolesString); } else if (StringUtils.isNotEmpty(injectedUserString)) { diff --git a/src/main/java/org/opensearch/security/transport/SecurityRequestHandler.java b/src/main/java/org/opensearch/security/transport/SecurityRequestHandler.java index 1284ca9781..3ba379dd67 100644 --- a/src/main/java/org/opensearch/security/transport/SecurityRequestHandler.java +++ b/src/main/java/org/opensearch/security/transport/SecurityRequestHandler.java @@ -107,6 +107,8 @@ protected void messageReceivedDecorate( resolvedActionClass = ((ConcreteShardRequest) request).getRequest().getClass().getSimpleName(); } + final boolean useJDKSerialization = getThreadContext().getTransient(ConfigConstants.USE_JDK_SERIALIZATION); + String initialActionClassValue = getThreadContext().getHeader(ConfigConstants.OPENDISTRO_SECURITY_INITIAL_ACTION_CLASS_HEADER); final ThreadContext.StoredContext sgContext = getThreadContext().newStoredContext(false); @@ -181,7 +183,7 @@ protected void messageReceivedDecorate( } else { getThreadContext().putTransient( ConfigConstants.OPENDISTRO_SECURITY_USER, - Objects.requireNonNull((User) Base64Helper.deserializeObject(userHeader)) + Objects.requireNonNull((User) Base64Helper.deserializeObject(userHeader, useJDKSerialization)) ); } @@ -190,7 +192,7 @@ protected void messageReceivedDecorate( if (!Strings.isNullOrEmpty(originalRemoteAddress)) { getThreadContext().putTransient( ConfigConstants.OPENDISTRO_SECURITY_REMOTE_ADDRESS, - new TransportAddress((InetSocketAddress) Base64Helper.deserializeObject(originalRemoteAddress)) + new TransportAddress((InetSocketAddress) Base64Helper.deserializeObject(originalRemoteAddress, useJDKSerialization)) ); } else { getThreadContext().putTransient(ConfigConstants.OPENDISTRO_SECURITY_REMOTE_ADDRESS, request.remoteAddress()); diff --git a/src/main/java/org/opensearch/security/user/User.java b/src/main/java/org/opensearch/security/user/User.java index 2642b368d7..394b251271 100644 --- a/src/main/java/org/opensearch/security/user/User.java +++ b/src/main/java/org/opensearch/security/user/User.java @@ -83,6 +83,9 @@ public User(final StreamInput in) throws IOException { name = in.readString(); roles.addAll(in.readList(StreamInput::readString)); requestedTenant = in.readString(); + if (requestedTenant.isEmpty()) { + requestedTenant = null; + } attributes = Collections.synchronizedMap(in.readMap(StreamInput::readString, StreamInput::readString)); securityRoles.addAll(in.readList(StreamInput::readString)); } @@ -167,9 +170,9 @@ public final boolean isUserInRole(final String role) { } /** - * Associate this user with a set of backend roles + * Associate this user with a set of custom attributes * - * @param roles The backend roles + * @param attributes custom attributes */ public final void addAttributes(final Map attributes) { if (attributes != null) { @@ -255,7 +258,7 @@ public final void copyRolesFrom(final User user) { public void writeTo(StreamOutput out) throws IOException { out.writeString(name); out.writeStringCollection(new ArrayList(roles)); - out.writeString(requestedTenant); + out.writeString(requestedTenant == null ? "" : requestedTenant); out.writeMap(attributes, StreamOutput::writeString, StreamOutput::writeString); out.writeStringCollection(securityRoles == null ? Collections.emptyList() : new ArrayList(securityRoles)); } diff --git a/src/test/java/org/opensearch/security/support/Base64CustomHelperTest.java b/src/test/java/org/opensearch/security/support/Base64CustomHelperTest.java new file mode 100644 index 0000000000..e35e1d72ba --- /dev/null +++ b/src/test/java/org/opensearch/security/support/Base64CustomHelperTest.java @@ -0,0 +1,159 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.security.support; + +import com.amazon.dlic.auth.ldap.LdapUser; +import org.junit.Assert; +import org.junit.Test; +import org.ldaptive.LdapEntry; +import org.opensearch.OpenSearchException; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.common.io.stream.Writeable; +import org.opensearch.security.auth.UserInjector; +import org.opensearch.security.user.AuthCredentials; +import org.opensearch.security.user.User; + +import java.io.Serializable; +import java.net.InetSocketAddress; +import java.time.ZonedDateTime; +import java.util.ArrayList; +import java.util.HashMap; + +import static org.opensearch.security.support.Base64CustomHelper.deserializeObject; +import static org.opensearch.security.support.Base64CustomHelper.serializeObject; + +public class Base64CustomHelperTest { + + private static final class NotSafeStreamable implements Serializable { + private static final long serialVersionUID = 5135559266828470092L; + } + + private static final class NotSafeWriteable implements Writeable, Serializable { + @Override + public void writeTo(StreamOutput out) { + + } + } + + private static Serializable ds(Serializable s) { + return deserializeObject(serializeObject(s)); + } + + @Test + public void testString() { + String string = "string"; + Assert.assertEquals(string, ds(string)); + } + + @Test + public void testInteger() { + Integer integer = 0; + Assert.assertEquals(integer, ds(integer)); + } + + @Test + public void testDouble() { + Double number = 0.; + Assert.assertEquals(number, ds(number)); + } + + @Test + public void testInetSocketAddress() { + InetSocketAddress inetSocketAddress = new InetSocketAddress(0); + Assert.assertEquals(inetSocketAddress, ds(inetSocketAddress)); + } + + @Test + public void testUser() { + User user = new User("user"); + Assert.assertEquals(user, ds(user)); + } + + @Test + public void testSourceFieldsContext() { + SourceFieldsContext sourceFieldsContext = new SourceFieldsContext(new SearchRequest("")); + Assert.assertEquals(sourceFieldsContext.toString(), ds(sourceFieldsContext).toString()); + } + + @Test + public void testHashMap() { + HashMap map = new HashMap<>() { + { + put("key", "value"); + } + }; + Assert.assertEquals(map, ds(map)); + } + + @Test + public void testArrayList() { + ArrayList list = new ArrayList<>() { + { + add("value"); + } + }; + Assert.assertEquals(list, ds(list)); + } + + @Test + public void testLdapUser() { + LdapUser ldapUser = new LdapUser( + "username", + "originalusername", + new LdapEntry("dn"), + new AuthCredentials("originalusername", "12345"), + 34, + WildcardMatcher.ANY + ); + Assert.assertEquals(ldapUser, ds(ldapUser)); + } + + @Test + public void testGetWriteableClassID() { + // a need to make a change in this test signifies a breaking change in security plugin's custom serialization + // format + Assert.assertEquals(Integer.valueOf(1), Base64CustomHelper.getWriteableClassID(User.class)); + Assert.assertEquals(Integer.valueOf(2), Base64CustomHelper.getWriteableClassID(LdapUser.class)); + Assert.assertEquals(Integer.valueOf(3), Base64CustomHelper.getWriteableClassID(UserInjector.InjectedUser.class)); + Assert.assertEquals(Integer.valueOf(4), Base64CustomHelper.getWriteableClassID(SourceFieldsContext.class)); + } + + @Test + public void testInjectedUser() { + UserInjector.InjectedUser injectedUser = new UserInjector.InjectedUser("username"); + + // for custom serialization, we expect InjectedUser to be returned on deserialization + UserInjector.InjectedUser deserializedInjecteduser = (UserInjector.InjectedUser) ds(injectedUser); + Assert.assertEquals(injectedUser, deserializedInjecteduser); + Assert.assertTrue(deserializedInjecteduser.isInjected()); + } + + @Test(expected = OpenSearchException.class) + public void testNotSafeStreamable() { + Base64JDKHelper.serializeObject(new NotSafeStreamable()); + } + + @Test(expected = OpenSearchException.class) + public void testNotSafeWriteable() { + Base64JDKHelper.serializeObject(new NotSafeWriteable()); + } + + @Test(expected = OpenSearchException.class) + public void testNotSafeGeneric() { + HashMap map = new HashMap<>(); + map.put(1, ZonedDateTime.now()); + map.put(2, ZonedDateTime.now()); + Base64JDKHelper.serializeObject(map); + } + +} diff --git a/src/test/java/org/opensearch/security/support/Base64HelperTest.java b/src/test/java/org/opensearch/security/support/Base64HelperTest.java index 81c2505985..f55581c7e7 100644 --- a/src/test/java/org/opensearch/security/support/Base64HelperTest.java +++ b/src/test/java/org/opensearch/security/support/Base64HelperTest.java @@ -10,100 +10,44 @@ */ package org.opensearch.security.support; -import java.io.ByteArrayOutputStream; -import java.io.ObjectOutputStream; import java.io.Serializable; -import java.net.InetSocketAddress; -import java.util.ArrayList; -import java.util.HashMap; -import java.util.regex.Pattern; -import com.google.common.io.BaseEncoding; import org.junit.Assert; import org.junit.Test; -import org.opensearch.OpenSearchException; -import org.opensearch.action.search.SearchRequest; -import org.opensearch.security.user.User; - import static org.opensearch.security.support.Base64Helper.deserializeObject; import static org.opensearch.security.support.Base64Helper.serializeObject; public class Base64HelperTest { - private static final class NotSafeSerializable implements Serializable { - private static final long serialVersionUID = 5135559266828470092L; + private static Serializable dsJDK(Serializable s) { + return deserializeObject(serializeObject(s, true), true); } private static Serializable ds(Serializable s) { return deserializeObject(serializeObject(s)); } + /** + * Just one sanity test comprising invocation of JDK and Custom Serialization. + * + * Individual scenarios are covered by Base64CustomHelperTest and Base64JDKHelperTest + */ @Test - public void testString() { - String string = "string"; - Assert.assertEquals(string, ds(string)); - } - - @Test - public void testInteger() { - Integer integer = Integer.valueOf(0); - Assert.assertEquals(integer, ds(integer)); - } - - @Test - public void testDouble() { - Double number = Double.valueOf(0.); - Assert.assertEquals(number, ds(number)); - } - - @Test - public void testInetSocketAddress() { - InetSocketAddress inetSocketAddress = new InetSocketAddress(0); - Assert.assertEquals(inetSocketAddress, ds(inetSocketAddress)); - } - - @Test - public void testPattern() { - Pattern pattern = Pattern.compile(".*"); - Assert.assertEquals(pattern.pattern(), ((Pattern) ds(pattern)).pattern()); - } - - @Test - public void testUser() { - User user = new User("user"); - Assert.assertEquals(user, ds(user)); - } - - @Test - public void testSourceFieldsContext() { - SourceFieldsContext sourceFieldsContext = new SourceFieldsContext(new SearchRequest("")); - Assert.assertEquals(sourceFieldsContext.toString(), ds(sourceFieldsContext).toString()); - } - - @Test - public void testHashMap() { - HashMap map = new HashMap(); - Assert.assertEquals(map, ds(map)); + public void testSerde() { + String test = "string"; + Assert.assertEquals(test, ds(test)); + Assert.assertEquals(test, dsJDK(test)); } @Test - public void testArrayList() { - ArrayList list = new ArrayList(); - Assert.assertEquals(list, ds(list)); - } + public void testEnsureJDKSerialized() { + String test = "string"; + String jdkSerialized = Base64Helper.serializeObject(test, true); + String customSerialized = Base64Helper.serializeObject(test, false); + Assert.assertEquals(jdkSerialized, Base64Helper.ensureJDKSerialized(jdkSerialized)); + Assert.assertEquals(jdkSerialized, Base64Helper.ensureJDKSerialized(customSerialized)); - @Test(expected = OpenSearchException.class) - public void notSafeSerializable() { - serializeObject(new NotSafeSerializable()); } - @Test(expected = OpenSearchException.class) - public void notSafeDeserializable() throws Exception { - final ByteArrayOutputStream bos = new ByteArrayOutputStream(); - try (final ObjectOutputStream out = new ObjectOutputStream(bos)) { - out.writeObject(new NotSafeSerializable()); - } - deserializeObject(BaseEncoding.base64().encode(bos.toByteArray())); - } } diff --git a/src/test/java/org/opensearch/security/support/Base64JDKHelperTest.java b/src/test/java/org/opensearch/security/support/Base64JDKHelperTest.java new file mode 100644 index 0000000000..704f1dc1d7 --- /dev/null +++ b/src/test/java/org/opensearch/security/support/Base64JDKHelperTest.java @@ -0,0 +1,128 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.security.support; + +import com.amazon.dlic.auth.ldap.LdapUser; +import com.google.common.io.BaseEncoding; +import org.junit.Assert; +import org.junit.Test; +import org.ldaptive.LdapEntry; +import org.opensearch.OpenSearchException; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.security.auth.UserInjector; +import org.opensearch.security.user.AuthCredentials; +import org.opensearch.security.user.User; + +import java.io.ByteArrayOutputStream; +import java.io.ObjectOutputStream; +import java.io.Serializable; +import java.net.InetSocketAddress; +import java.util.ArrayList; +import java.util.HashMap; + +public class Base64JDKHelperTest { + private static final class NotSafeSerializable implements Serializable { + private static final long serialVersionUID = 5135559266828470092L; + } + + private static Serializable ds(Serializable s) { + return Base64JDKHelper.deserializeObject(Base64JDKHelper.serializeObject(s)); + } + + @Test + public void testString() { + String string = "string"; + Assert.assertEquals(string, ds(string)); + } + + @Test + public void testInteger() { + Integer integer = 0; + Assert.assertEquals(integer, ds(integer)); + } + + @Test + public void testDouble() { + Double number = 0.0; + Assert.assertEquals(number, ds(number)); + } + + @Test + public void testInetSocketAddress() { + InetSocketAddress inetSocketAddress = new InetSocketAddress(0); + Assert.assertEquals(inetSocketAddress, ds(inetSocketAddress)); + } + + @Test + public void testUser() { + User user = new User("user"); + Assert.assertEquals(user, ds(user)); + } + + @Test + public void testSourceFieldsContext() { + SourceFieldsContext sourceFieldsContext = new SourceFieldsContext(new SearchRequest("")); + Assert.assertEquals(sourceFieldsContext.toString(), ds(sourceFieldsContext).toString()); + } + + @Test + public void testHashMap() { + HashMap map = new HashMap<>(); + map.put("key", "value"); + Assert.assertEquals(map, ds(map)); + } + + @Test + public void testArrayList() { + ArrayList list = new ArrayList<>(); + list.add("value"); + Assert.assertEquals(list, ds(list)); + } + + @Test(expected = OpenSearchException.class) + public void notSafeSerializable() { + Base64JDKHelper.serializeObject(new NotSafeSerializable()); + } + + @Test(expected = OpenSearchException.class) + public void notSafeDeserializable() throws Exception { + final ByteArrayOutputStream bos = new ByteArrayOutputStream(); + try (final ObjectOutputStream out = new ObjectOutputStream(bos)) { + out.writeObject(new NotSafeSerializable()); + } + Base64JDKHelper.deserializeObject(BaseEncoding.base64().encode(bos.toByteArray())); + } + + @Test + public void testLdapUser() { + LdapUser ldapUser = new LdapUser( + "username", + "originalusername", + new LdapEntry("dn"), + new AuthCredentials("originalusername", "12345"), + 34, + WildcardMatcher.ANY + ); + Assert.assertEquals(ldapUser, ds(ldapUser)); + } + + @Test + public void testInjectedUser() { + UserInjector.InjectedUser injectedUser = new UserInjector.InjectedUser("username"); + + // we expect to get User object when deserializing InjectedUser via JDK serialization + User user = new User("username"); + User deserializedUser = (User) ds(injectedUser); + Assert.assertEquals(user, deserializedUser); + Assert.assertTrue(deserializedUser.isInjected()); + } +} diff --git a/src/test/java/org/opensearch/security/support/StreamableRegistryTest.java b/src/test/java/org/opensearch/security/support/StreamableRegistryTest.java new file mode 100644 index 0000000000..13f2448b30 --- /dev/null +++ b/src/test/java/org/opensearch/security/support/StreamableRegistryTest.java @@ -0,0 +1,29 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.security.support; + +import org.junit.Assert; +import org.junit.Test; +import org.opensearch.OpenSearchException; + +import java.net.InetSocketAddress; + +public class StreamableRegistryTest { + + StreamableRegistry streamableRegistry = StreamableRegistry.getInstance(); + + @Test + public void testStreamableTypeIDs() { + Assert.assertEquals(1, streamableRegistry.getStreamableID(InetSocketAddress.class)); + Assert.assertThrows(OpenSearchException.class, () -> streamableRegistry.getStreamableID(String.class)); + } +} diff --git a/src/test/java/org/opensearch/security/transport/SecurityInterceptorTests.java b/src/test/java/org/opensearch/security/transport/SecurityInterceptorTests.java index d3363c54d8..abc0e314ef 100644 --- a/src/test/java/org/opensearch/security/transport/SecurityInterceptorTests.java +++ b/src/test/java/org/opensearch/security/transport/SecurityInterceptorTests.java @@ -47,9 +47,6 @@ import static java.util.Collections.emptySet; import static org.junit.Assert.assertEquals; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.eq; -import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; // CS-ENFORCE-SINGLE @@ -110,9 +107,8 @@ public void setup() { ); } - @Test - public void testSendRequestDecorate() { - + private void testSendRequestDecorate(Version remoteNodeVersion) { + boolean useJDKSerialization = remoteNodeVersion.before(ConfigConstants.FIRST_CUSTOM_SERIALIZATION_SUPPORTED_OS_VERSION); ClusterName clusterName = ClusterName.DEFAULT; when(clusterService.getClusterName()).thenReturn(clusterName); @@ -140,7 +136,6 @@ public void testSendRequestDecorate() { User user = new User("John Doe"); threadPool.getThreadContext().putTransient(ConfigConstants.OPENDISTRO_SECURITY_USER, user); - AsyncSender sender = mock(AsyncSender.class); String action = "testAction"; TransportRequest request = mock(TransportRequest.class); TransportRequestOptions options = mock(TransportRequestOptions.class); @@ -156,37 +151,65 @@ public void testSendRequestDecorate() { DiscoveryNode localNode = new DiscoveryNode("local-node", new TransportAddress(localAddress, 1234), Version.CURRENT); Connection connection1 = transportService.getConnection(localNode); - DiscoveryNode otherNode = new DiscoveryNode("local-node", new TransportAddress(localAddress, 4321), Version.CURRENT); + DiscoveryNode otherNode = new DiscoveryNode("remote-node", new TransportAddress(localAddress, 4321), remoteNodeVersion); Connection connection2 = transportService.getConnection(otherNode); + // from thread context inside sendRequestDecorate + AsyncSender sender = new AsyncSender() { + @Override + public void sendRequest( + Connection connection, + String action, + TransportRequest request, + TransportRequestOptions options, + TransportResponseHandler handler + ) { + User transientUser = threadPool.getThreadContext().getTransient(ConfigConstants.OPENDISTRO_SECURITY_USER); + assertEquals(transientUser, user); + } + }; // isSameNodeRequest = true securityInterceptor.sendRequestDecorate(sender, connection1, action, request, options, handler, localNode); - // from thread context inside sendRequestDecorate - doAnswer(i -> { - User transientUser = threadPool.getThreadContext().getTransient(ConfigConstants.OPENDISTRO_SECURITY_USER); - assertEquals(transientUser, user); - return null; - }).when(sender).sendRequest(any(Connection.class), eq(action), eq(request), eq(options), eq(handler)); // from original context User transientUser = threadPool.getThreadContext().getTransient(ConfigConstants.OPENDISTRO_SECURITY_USER); assertEquals(transientUser, user); assertEquals(threadPool.getThreadContext().getHeader(ConfigConstants.OPENDISTRO_SECURITY_USER_HEADER), null); - // isSameNodeRequest = false - securityInterceptor.sendRequestDecorate(sender, connection2, action, request, options, handler, otherNode); // checking thread context inside sendRequestDecorate - doAnswer(i -> { - String serializedUserHeader = threadPool.getThreadContext().getHeader(ConfigConstants.OPENDISTRO_SECURITY_USER_HEADER); - assertEquals(serializedUserHeader, Base64Helper.serializeObject(user)); - return null; - }).when(sender).sendRequest(any(Connection.class), eq(action), eq(request), eq(options), eq(handler)); + sender = new AsyncSender() { + @Override + public void sendRequest( + Connection connection, + String action, + TransportRequest request, + TransportRequestOptions options, + TransportResponseHandler handler + ) { + String serializedUserHeader = threadPool.getThreadContext().getHeader(ConfigConstants.OPENDISTRO_SECURITY_USER_HEADER); + assertEquals(serializedUserHeader, Base64Helper.serializeObject(user, useJDKSerialization)); + } + }; + // isSameNodeRequest = false + securityInterceptor.sendRequestDecorate(sender, connection2, action, request, options, handler, localNode); // from original context User transientUser2 = threadPool.getThreadContext().getTransient(ConfigConstants.OPENDISTRO_SECURITY_USER); assertEquals(transientUser2, user); assertEquals(threadPool.getThreadContext().getHeader(ConfigConstants.OPENDISTRO_SECURITY_USER_HEADER), null); + } + + @Test + public void testSendRequestDecorate() { + testSendRequestDecorate(Version.CURRENT); + } + /** + * Tests the scenario when remote node does not implement custom serialization protocol and uses JDK serialization + */ + @Test + public void testSendRequestDecorateWhenRemoteNodeUsesJDKSerde() { + testSendRequestDecorate(Version.V_2_0_0); } } diff --git a/src/test/java/org/opensearch/security/transport/SecuritySSLRequestHandlerTests.java b/src/test/java/org/opensearch/security/transport/SecuritySSLRequestHandlerTests.java new file mode 100644 index 0000000000..23a64e4be3 --- /dev/null +++ b/src/test/java/org/opensearch/security/transport/SecuritySSLRequestHandlerTests.java @@ -0,0 +1,80 @@ +package org.opensearch.security.transport; + +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; +import org.mockito.ArgumentMatchers; +import org.mockito.Mock; +import org.opensearch.Version; +import org.opensearch.common.settings.Settings; +import org.opensearch.security.ssl.SslExceptionHandler; +import org.opensearch.security.ssl.transport.PrincipalExtractor; +import org.opensearch.security.ssl.transport.SSLConfig; +import org.opensearch.security.ssl.transport.SecuritySSLRequestHandler; +import org.opensearch.security.support.ConfigConstants; +import org.opensearch.tasks.Task; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.transport.TransportChannel; +import org.opensearch.transport.TransportRequest; +import org.opensearch.transport.TransportRequestHandler; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyInt; +import static org.mockito.Mockito.doNothing; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +public class SecuritySSLRequestHandlerTests { + + @Mock + TransportRequestHandler actualHandler; + @Mock + SSLConfig sslConfig; + ThreadPool threadPool; + SslExceptionHandler sslExceptionHandler; + Settings settings; + SecuritySSLRequestHandler securitySSLRequestHandler; + String testAction; + + @Mock + private PrincipalExtractor principalExtractor; + + @Before + public void setUp() { + settings = Settings.builder() + .put("node.name", SecurityInterceptorTests.class.getSimpleName()) + .put("request.headers.default", "1") + .build(); + threadPool = new ThreadPool(settings); + testAction = "test_action"; + sslExceptionHandler = mock(SslExceptionHandler.class); + securitySSLRequestHandler = new SecuritySSLRequestHandler<>( + testAction, + actualHandler, + threadPool, + principalExtractor, + sslConfig, + sslExceptionHandler + ); + doNothing().when(sslExceptionHandler) + .logError(any(Exception.class), any(TransportRequest.class), any(String.class), any(Task.class), anyInt()); + } + + @Test + public void testUseJDKSerializationHeaderIsSetOnMessageReceived() throws Exception { + TransportRequest transportRequest = mock(TransportRequest.class); + TransportChannel transportChannel = mock(TransportChannel.class); + Task task = mock(Task.class); + doNothing().when(transportChannel).sendResponse(ArgumentMatchers.any(Exception.class)); + when(transportChannel.getVersion()).thenReturn(Version.V_2_11_0); + when(transportChannel.getChannelType()).thenReturn("transport"); + + Assert.assertThrows(Exception.class, () -> securitySSLRequestHandler.messageReceived(transportRequest, transportChannel, task)); + Assert.assertTrue(threadPool.getThreadContext().getTransient(ConfigConstants.USE_JDK_SERIALIZATION)); + + threadPool.getThreadContext().stashContext(); + when(transportChannel.getVersion()).thenReturn(Version.V_3_0_0); + Assert.assertThrows(Exception.class, () -> securitySSLRequestHandler.messageReceived(transportRequest, transportChannel, task)); + Assert.assertFalse(threadPool.getThreadContext().getTransient(ConfigConstants.USE_JDK_SERIALIZATION)); + } +} From 8ea5e56c3344a5a6b6b9333fedf30c9bf542717f Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 2 Oct 2023 12:25:31 -0400 Subject: [PATCH 2/8] dependabot: bump org.xerial.snappy:snappy-java from 1.1.10.4 to 1.1.10.5 (#3435) Bumps [org.xerial.snappy:snappy-java](https://github.com/xerial/snappy-java) from 1.1.10.4 to 1.1.10.5. Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- build.gradle | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/build.gradle b/build.gradle index d3b3846edd..8e192900cc 100644 --- a/build.gradle +++ b/build.gradle @@ -430,7 +430,7 @@ configurations { force "io.netty:netty-transport-native-unix-common:${versions.netty}" force "org.apache.bcel:bcel:6.7.0" // This line should be removed once Spotbugs is upgraded to 4.7.4 force "com.github.luben:zstd-jni:${versions.zstd}" - force "org.xerial.snappy:snappy-java:1.1.10.4" + force "org.xerial.snappy:snappy-java:1.1.10.5" force "com.google.guava:guava:${guava_version}" } } @@ -559,7 +559,7 @@ dependencies { runtimeOnly 'io.dropwizard.metrics:metrics-core:4.2.19' runtimeOnly 'org.slf4j:slf4j-api:1.7.36' runtimeOnly "org.apache.logging.log4j:log4j-slf4j-impl:${versions.log4j}" - runtimeOnly 'org.xerial.snappy:snappy-java:1.1.10.4' + runtimeOnly 'org.xerial.snappy:snappy-java:1.1.10.5' runtimeOnly 'org.codehaus.woodstox:stax2-api:4.2.1' runtimeOnly "org.glassfish.jaxb:txw2:${jaxb_version}" runtimeOnly 'com.fasterxml.woodstox:woodstox-core:6.5.1' From 97dea2a9be8c01c3481688f7996a6f9539baef05 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 2 Oct 2023 12:28:13 -0400 Subject: [PATCH 3/8] dependabot: bump org.ow2.asm:asm from 9.5 to 9.6 (#3433) Bumps org.ow2.asm:asm from 9.5 to 9.6. Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- build.gradle | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/build.gradle b/build.gradle index 8e192900cc..92b7844252 100644 --- a/build.gradle +++ b/build.gradle @@ -525,7 +525,7 @@ dependencies { runtimeOnly 'com.google.errorprone:error_prone_annotations:2.22.0' runtimeOnly 'com.sun.istack:istack-commons-runtime:4.2.0' runtimeOnly 'jakarta.xml.bind:jakarta.xml.bind-api:4.0.0' - runtimeOnly 'org.ow2.asm:asm:9.5' + runtimeOnly 'org.ow2.asm:asm:9.6' testImplementation 'org.apache.camel:camel-xmlsecurity:3.21.0' From 567217cd5b4a28ce4ceb7fe6c8923ce7e1e489c9 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 2 Oct 2023 13:02:09 -0400 Subject: [PATCH 4/8] dependabot: bump org.apache.camel:camel-xmlsecurity from 3.21.0 to 3.21.1 (#3432) Bumps org.apache.camel:camel-xmlsecurity from 3.21.0 to 3.21.1. [![Dependabot compatibility score](https://dependabot-badges.githubapp.com/badges/compatibility_score?dependency-name=org.apache.camel:camel-xmlsecurity&package-manager=gradle&previous-version=3.21.0&new-version=3.21.1)](https://docs.github.com/en/github/managing-security-vulnerabilities/about-dependabot-security-updates#about-compatibility-scores) Dependabot will resolve any conflicts with this PR as long as you don't alter it yourself. You can also trigger a rebase manually by commenting `@dependabot rebase`. [//]: # (dependabot-automerge-start) [//]: # (dependabot-automerge-end) ---
Dependabot commands and options
You can trigger Dependabot actions by commenting on this PR: - `@dependabot rebase` will rebase this PR - `@dependabot recreate` will recreate this PR, overwriting any edits that have been made to it - `@dependabot merge` will merge this PR after your CI passes on it - `@dependabot squash and merge` will squash and merge this PR after your CI passes on it - `@dependabot cancel merge` will cancel a previously requested merge and block automerging - `@dependabot reopen` will reopen this PR if it is closed - `@dependabot close` will close this PR and stop Dependabot recreating it. You can achieve the same result by closing it manually - `@dependabot show ignore conditions` will show all of the ignore conditions of the specified dependency - `@dependabot ignore this major version` will close this PR and stop Dependabot creating any more for this major version (unless you reopen the PR or upgrade to it yourself) - `@dependabot ignore this minor version` will close this PR and stop Dependabot creating any more for this minor version (unless you reopen the PR or upgrade to it yourself) - `@dependabot ignore this dependency` will close this PR and stop Dependabot creating any more for this dependency (unless you reopen the PR or upgrade to it yourself)
Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- build.gradle | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/build.gradle b/build.gradle index 92b7844252..d2e08af0d7 100644 --- a/build.gradle +++ b/build.gradle @@ -527,7 +527,7 @@ dependencies { runtimeOnly 'jakarta.xml.bind:jakarta.xml.bind-api:4.0.0' runtimeOnly 'org.ow2.asm:asm:9.6' - testImplementation 'org.apache.camel:camel-xmlsecurity:3.21.0' + testImplementation 'org.apache.camel:camel-xmlsecurity:3.21.1' //OpenSAML implementation 'net.shibboleth.utilities:java-support:8.4.0' From fa52472ad18863716ed201ab2bf1422726ea8fa8 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 2 Oct 2023 13:02:40 -0400 Subject: [PATCH 5/8] dependabot: bump com.github.wnameless.json:json-base from 2.4.2 to 2.4.3 (#3434) Bumps [com.github.wnameless.json:json-base](https://github.com/wnameless/json-base) from 2.4.2 to 2.4.3.
Changelog

Sourced from com.github.wnameless.json:json-base's changelog.

Version 1.0.0

  • First release

Version 1.1.0

  • Support Java 9 Module
  • Add isEmpty() to JsonArrayBase and JsonObjectBase
  • Change package name from com.github.wnameless.json to com.github.wnameless.json.base

Version 1.1.1

  • Using "requires static" on Gson and Jackson

Version 1.2.0

  • Add Jsonable interface

Version 2.0.0

  • Add #asBigInteger, #asBigDecimal, #asNumber, #asNull
  • Add #toMap, #toList
  • Add JsonValueCore, JsonObjectCore, JsonArrayCore, JsonCore, JsonSource
  • Add JsonPrinter, JsonValueUtils

Version 2.1.0

  • Fix JsonProter#prettyPrint bug
  • Improve module-info.java

Version 2.2.0

  • Alter all "requires static transitive" to "requires static" in module-info.java to avoid "module not found" error while compiling by other projects

Version 2.2.1

  • Fix JsonPrinter bug on the edge case: having backslash before ending double quotes

Version 2.3.0

  • Add org.json lib support
  • Add Jakarta lib support
  • Change the return type of JsonArrayCore#remove(int) from boolean to JsonArrayCore
  • Increase JUnit code coverage to 100%
  • Remove Cobertura maven dependency

Version 2.4.0

  • Add #stream to JsonArrayBase and JsonObjectBase

Version 2.4.1

  • Improve OrgJsonValue#asNumber
  • Add JsonPrinter#toJsonString

Version 2.4.2

  • Upgrade POM

Version 2.4.3

  • Modify JsonValueUtils#toJavaNumber for preserving precise scale of the float number
Commits
  • a2785b7 [maven-release-plugin] prepare release json-base-2.4.3
  • 849d039 Improve GsonJsonValue implementation
  • 2442050 To preserve precise scale of the float number
  • 10f6d99 [maven-release-plugin] prepare for next development iteration
  • See full diff in compare view

[![Dependabot compatibility score](https://dependabot-badges.githubapp.com/badges/compatibility_score?dependency-name=com.github.wnameless.json:json-base&package-manager=gradle&previous-version=2.4.2&new-version=2.4.3)](https://docs.github.com/en/github/managing-security-vulnerabilities/about-dependabot-security-updates#about-compatibility-scores) Dependabot will resolve any conflicts with this PR as long as you don't alter it yourself. You can also trigger a rebase manually by commenting `@dependabot rebase`. [//]: # (dependabot-automerge-start) [//]: # (dependabot-automerge-end) ---
Dependabot commands and options
You can trigger Dependabot actions by commenting on this PR: - `@dependabot rebase` will rebase this PR - `@dependabot recreate` will recreate this PR, overwriting any edits that have been made to it - `@dependabot merge` will merge this PR after your CI passes on it - `@dependabot squash and merge` will squash and merge this PR after your CI passes on it - `@dependabot cancel merge` will cancel a previously requested merge and block automerging - `@dependabot reopen` will reopen this PR if it is closed - `@dependabot close` will close this PR and stop Dependabot recreating it. You can achieve the same result by closing it manually - `@dependabot show ignore conditions` will show all of the ignore conditions of the specified dependency - `@dependabot ignore this major version` will close this PR and stop Dependabot creating any more for this major version (unless you reopen the PR or upgrade to it yourself) - `@dependabot ignore this minor version` will close this PR and stop Dependabot creating any more for this minor version (unless you reopen the PR or upgrade to it yourself) - `@dependabot ignore this dependency` will close this PR and stop Dependabot creating any more for this dependency (unless you reopen the PR or upgrade to it yourself)
Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- build.gradle | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/build.gradle b/build.gradle index d2e08af0d7..5649e8058a 100644 --- a/build.gradle +++ b/build.gradle @@ -494,7 +494,7 @@ dependencies { implementation "io.jsonwebtoken:jjwt-impl:${jjwt_version}" implementation "io.jsonwebtoken:jjwt-jackson:${jjwt_version}" // JSON flattener - implementation ("com.github.wnameless.json:json-base:2.4.2") { + implementation ("com.github.wnameless.json:json-base:2.4.3") { exclude group: "org.glassfish", module: "jakarta.json" exclude group: "com.google.code.gson", module: "gson" exclude group: "org.json", module: "json" From 702cc3ca4d08bd44c4ae2de3a357524ed08fb10d Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 2 Oct 2023 14:20:48 -0400 Subject: [PATCH 6/8] dependabot: bump commons-io:commons-io from 2.13.0 to 2.14.0 (#3431) Bumps commons-io:commons-io from 2.13.0 to 2.14.0. Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- build.gradle | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/build.gradle b/build.gradle index 5649e8058a..741fd41262 100644 --- a/build.gradle +++ b/build.gradle @@ -627,7 +627,7 @@ dependencies { integrationTestImplementation 'junit:junit:4.13.2' integrationTestImplementation "org.opensearch.plugin:reindex-client:${opensearch_version}" integrationTestImplementation "org.opensearch.plugin:percolator-client:${opensearch_version}" - integrationTestImplementation 'commons-io:commons-io:2.13.0' + integrationTestImplementation 'commons-io:commons-io:2.14.0' integrationTestImplementation "org.apache.logging.log4j:log4j-core:${versions.log4j}" integrationTestImplementation "org.apache.logging.log4j:log4j-jul:${versions.log4j}" integrationTestImplementation 'org.hamcrest:hamcrest:2.2' From 1ffa23cc96a6bf96752c33b7f134e4e6139dc828 Mon Sep 17 00:00:00 2001 From: Ryan Liang <109499885+RyanL1997@users.noreply.github.com> Date: Mon, 2 Oct 2023 13:53:06 -0700 Subject: [PATCH 7/8] [Enhancement] Setup auth token utils for obo (#3419) Setup auth token utils for obo (#3419) --------- Signed-off-by: Ryan Liang --- .../security/authtoken/jwt/JwtVendor.java | 10 +-- .../http/OnBehalfOfAuthenticator.java | 6 +- .../securityconf/DynamicConfigModelV7.java | 4 +- .../security/util/AuthTokenUtils.java | 41 ++++++++++ .../authtoken/jwt/AuthTokenUtilsTest.java | 78 +++++++++++++++++++ 5 files changed, 129 insertions(+), 10 deletions(-) create mode 100644 src/main/java/org/opensearch/security/util/AuthTokenUtils.java create mode 100644 src/test/java/org/opensearch/security/authtoken/jwt/AuthTokenUtilsTest.java diff --git a/src/main/java/org/opensearch/security/authtoken/jwt/JwtVendor.java b/src/main/java/org/opensearch/security/authtoken/jwt/JwtVendor.java index 5d3262799f..e68a5ef2d7 100644 --- a/src/main/java/org/opensearch/security/authtoken/jwt/JwtVendor.java +++ b/src/main/java/org/opensearch/security/authtoken/jwt/JwtVendor.java @@ -16,7 +16,6 @@ import java.util.Optional; import java.util.function.LongSupplier; -import com.google.common.base.Strings; import org.apache.cxf.jaxrs.json.basic.JsonMapObjectReaderWriter; import org.apache.cxf.rs.security.jose.jwk.JsonWebKey; import org.apache.cxf.rs.security.jose.jwk.KeyType; @@ -32,6 +31,8 @@ import org.opensearch.common.settings.Settings; import org.opensearch.security.ssl.util.ExceptionUtils; +import static org.opensearch.security.util.AuthTokenUtils.isKeyNull; + public class JwtVendor { private static final Logger logger = LogManager.getLogger(JwtVendor.class); @@ -53,7 +54,7 @@ public JwtVendor(final Settings settings, final Optional timeProvi throw ExceptionUtils.createJwkCreationException(e); } this.jwtProducer = jwtProducer; - if (settings.get("encryption_key") == null) { + if (isKeyNull(settings, "encryption_key")) { throw new IllegalArgumentException("encryption_key cannot be null"); } else { this.claimsEncryptionKey = settings.get("encryption_key"); @@ -73,9 +74,8 @@ public JwtVendor(final Settings settings, final Optional timeProvi * Encryption Algorithm: HS512 * */ static JsonWebKey createJwkFromSettings(Settings settings) throws Exception { - String signingKey = settings.get("signing_key"); - - if (!Strings.isNullOrEmpty(signingKey)) { + if (!isKeyNull(settings, "signing_key")) { + String signingKey = settings.get("signing_key"); JsonWebKey jwk = new JsonWebKey(); diff --git a/src/main/java/org/opensearch/security/http/OnBehalfOfAuthenticator.java b/src/main/java/org/opensearch/security/http/OnBehalfOfAuthenticator.java index 467edd8ac4..c47e850b75 100644 --- a/src/main/java/org/opensearch/security/http/OnBehalfOfAuthenticator.java +++ b/src/main/java/org/opensearch/security/http/OnBehalfOfAuthenticator.java @@ -43,13 +43,12 @@ import static org.opensearch.security.OpenSearchSecurityPlugin.LEGACY_OPENDISTRO_PREFIX; import static org.opensearch.security.OpenSearchSecurityPlugin.PLUGINS_PREFIX; +import static org.opensearch.security.util.AuthTokenUtils.isAccessToRestrictedEndpoints; public class OnBehalfOfAuthenticator implements HTTPAuthenticator { private static final String REGEX_PATH_PREFIX = "/(" + LEGACY_OPENDISTRO_PREFIX + "|" + PLUGINS_PREFIX + ")/" + "(.*)"; private static final Pattern PATTERN_PATH_PREFIX = Pattern.compile(REGEX_PATH_PREFIX); - private static final String ON_BEHALF_OF_SUFFIX = "api/generateonbehalfoftoken"; - private static final String ACCOUNT_SUFFIX = "api/account"; protected final Logger log = LogManager.getLogger(this.getClass()); @@ -233,8 +232,7 @@ private void logDebug(String message, Object... args) { public Boolean isRequestAllowed(final RestRequest request) { Matcher matcher = PATTERN_PATH_PREFIX.matcher(request.path()); final String suffix = matcher.matches() ? matcher.group(2) : null; - if (request.method() == RestRequest.Method.POST && ON_BEHALF_OF_SUFFIX.equals(suffix) - || request.method() == RestRequest.Method.PUT && ACCOUNT_SUFFIX.equals(suffix)) { + if (isAccessToRestrictedEndpoints(request, suffix)) { final OpenSearchException exception = ExceptionUtils.invalidUsageOfOBOTokenException(); log.error(exception.toString()); return false; diff --git a/src/main/java/org/opensearch/security/securityconf/DynamicConfigModelV7.java b/src/main/java/org/opensearch/security/securityconf/DynamicConfigModelV7.java index fcbf985f60..0de83f2e2e 100644 --- a/src/main/java/org/opensearch/security/securityconf/DynamicConfigModelV7.java +++ b/src/main/java/org/opensearch/security/securityconf/DynamicConfigModelV7.java @@ -67,6 +67,8 @@ import org.opensearch.security.securityconf.impl.v7.ConfigV7.AuthzDomain; import org.opensearch.security.support.ReflectionHelper; +import static org.opensearch.security.util.AuthTokenUtils.isKeyNull; + public class DynamicConfigModelV7 extends DynamicConfigModel { private final ConfigV7 config; @@ -383,7 +385,7 @@ private void buildAAA() { * order: -1 - prioritize the OBO authentication when it gets enabled */ Settings oboSettings = getDynamicOnBehalfOfSettings(); - if (oboSettings.get("signing_key") != null && oboSettings.get("encryption_key") != null) { + if (!isKeyNull(oboSettings, "signing_key") && !isKeyNull(oboSettings, "encryption_key")) { final AuthDomain _ad = new AuthDomain( new NoOpAuthenticationBackend(Settings.EMPTY, null), new OnBehalfOfAuthenticator(getDynamicOnBehalfOfSettings(), this.cih.getClusterName()), diff --git a/src/main/java/org/opensearch/security/util/AuthTokenUtils.java b/src/main/java/org/opensearch/security/util/AuthTokenUtils.java new file mode 100644 index 0000000000..30f331d3a7 --- /dev/null +++ b/src/main/java/org/opensearch/security/util/AuthTokenUtils.java @@ -0,0 +1,41 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.security.util; + +import org.opensearch.common.settings.Settings; +import org.opensearch.rest.RestRequest; + +import static org.opensearch.rest.RestRequest.Method.POST; +import static org.opensearch.rest.RestRequest.Method.PUT; + +public class AuthTokenUtils { + private static final String ON_BEHALF_OF_SUFFIX = "api/generateonbehalfoftoken"; + private static final String ACCOUNT_SUFFIX = "api/account"; + + public static Boolean isAccessToRestrictedEndpoints(final RestRequest request, final String suffix) { + if (suffix == null) { + return false; + } + switch (suffix) { + case ON_BEHALF_OF_SUFFIX: + return request.method() == POST; + case ACCOUNT_SUFFIX: + return request.method() == PUT; + default: + return false; + } + } + + public static Boolean isKeyNull(Settings settings, String key) { + return settings.get(key) == null; + } +} diff --git a/src/test/java/org/opensearch/security/authtoken/jwt/AuthTokenUtilsTest.java b/src/test/java/org/opensearch/security/authtoken/jwt/AuthTokenUtilsTest.java new file mode 100644 index 0000000000..d563308e31 --- /dev/null +++ b/src/test/java/org/opensearch/security/authtoken/jwt/AuthTokenUtilsTest.java @@ -0,0 +1,78 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.security.authtoken.jwt; + +import org.opensearch.common.settings.Settings; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.rest.RestRequest; +import org.opensearch.security.util.AuthTokenUtils; +import org.opensearch.test.rest.FakeRestRequest; +import org.junit.Test; + +import java.util.Collections; + +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; + +public class AuthTokenUtilsTest { + + @Test + public void testIsAccessToRestrictedEndpointsForOnBehalfOfToken() { + NamedXContentRegistry namedXContentRegistry = new NamedXContentRegistry(Collections.emptyList()); + + FakeRestRequest request = new FakeRestRequest.Builder(namedXContentRegistry).withPath("/api/generateonbehalfoftoken") + .withMethod(RestRequest.Method.POST) + .build(); + + assertTrue(AuthTokenUtils.isAccessToRestrictedEndpoints(request, "api/generateonbehalfoftoken")); + } + + @Test + public void testIsAccessToRestrictedEndpointsForAccount() { + NamedXContentRegistry namedXContentRegistry = new NamedXContentRegistry(Collections.emptyList()); + + FakeRestRequest request = new FakeRestRequest.Builder(namedXContentRegistry).withPath("/api/account") + .withMethod(RestRequest.Method.PUT) + .build(); + + assertTrue(AuthTokenUtils.isAccessToRestrictedEndpoints(request, "api/account")); + } + + @Test + public void testIsAccessToRestrictedEndpointsFalseCase() { + NamedXContentRegistry namedXContentRegistry = new NamedXContentRegistry(Collections.emptyList()); + + FakeRestRequest request = new FakeRestRequest.Builder(namedXContentRegistry).withPath("/api/someotherendpoint") + .withMethod(RestRequest.Method.GET) + .build(); + + assertFalse(AuthTokenUtils.isAccessToRestrictedEndpoints(request, "api/someotherendpoint")); + } + + @Test + public void testIsKeyNullWithNullValue() { + Settings settings = Settings.builder().put("someKey", (String) null).build(); + assertTrue(AuthTokenUtils.isKeyNull(settings, "someKey")); + } + + @Test + public void testIsKeyNullWithNonNullValue() { + Settings settings = Settings.builder().put("someKey", "value").build(); + assertFalse(AuthTokenUtils.isKeyNull(settings, "someKey")); + } + + @Test + public void testIsKeyNullWithAbsentKey() { + Settings settings = Settings.builder().build(); + assertTrue(AuthTokenUtils.isKeyNull(settings, "absentKey")); + } +} From 7924da13a57ecbf3352d84e6d020012723b81fa1 Mon Sep 17 00:00:00 2001 From: Darshit Chanpura <35282393+DarshitChanpura@users.noreply.github.com> Date: Tue, 3 Oct 2023 10:32:15 -0400 Subject: [PATCH 8/8] Refactors reRequestAuthentication to call notifyIpAuthFailureListener before sending the response to the channel (#3411) Prior to this change, the ip auth failure listener was not called upon challengeAuthenticator check invocation, which caused AddressBasedRateLimiter to not be invoked. With this change AddressBasedRateLimiter will be invoked upon multiple wrong requests from an ip. Signed-off-by: Darshit Chanpura --- .../IpBruteForceAttacksPreventionTests.java | 27 +++++----- ...cksPreventionWithDomainChallengeTests.java | 32 ++++++++++++ .../cluster/LocalOpenSearchCluster.java | 2 +- .../jwt/AbstractHTTPJwtAuthenticator.java | 6 +-- .../auth/http/jwt/HTTPJwtAuthenticator.java | 6 +-- .../kerberos/HTTPSpnegoAuthenticator.java | 11 ++-- .../http/saml/AuthTokenProcessorHandler.java | 41 ++++----------- .../auth/http/saml/HTTPSamlAuthenticator.java | 19 ++++--- .../security/auth/BackendRegistry.java | 51 ++++++++++-------- .../security/auth/HTTPAuthenticator.java | 15 +++--- .../security/http/HTTPBasicAuthenticator.java | 6 +-- .../http/HTTPClientCertAuthenticator.java | 6 +-- .../security/http/HTTPProxyAuthenticator.java | 6 +-- .../http/OnBehalfOfAuthenticator.java | 6 +-- .../proxy/HTTPExtendedProxyAuthenticator.java | 6 --- .../http/saml/HTTPSamlAuthenticatorTest.java | 52 +++++++++++++------ .../limiting/AddressBasedRateLimiterTest.java | 21 ++++---- .../UserNameBasedRateLimiterTest.java | 21 ++++---- .../cache/DummyHTTPAuthenticator.java | 6 +-- 19 files changed, 181 insertions(+), 159 deletions(-) create mode 100644 src/integrationTest/java/org/opensearch/security/IpBruteForceAttacksPreventionWithDomainChallengeTests.java diff --git a/src/integrationTest/java/org/opensearch/security/IpBruteForceAttacksPreventionTests.java b/src/integrationTest/java/org/opensearch/security/IpBruteForceAttacksPreventionTests.java index bb16e0be1b..34e79613f6 100644 --- a/src/integrationTest/java/org/opensearch/security/IpBruteForceAttacksPreventionTests.java +++ b/src/integrationTest/java/org/opensearch/security/IpBruteForceAttacksPreventionTests.java @@ -12,7 +12,6 @@ import java.util.concurrent.TimeUnit; import com.carrotsearch.randomizedtesting.annotations.ThreadLeakScope; -import org.junit.ClassRule; import org.junit.Rule; import org.junit.Test; import org.junit.runner.RunWith; @@ -36,8 +35,8 @@ @RunWith(com.carrotsearch.randomizedtesting.RandomizedRunner.class) @ThreadLeakScope(ThreadLeakScope.Scope.NONE) public class IpBruteForceAttacksPreventionTests { - private static final User USER_1 = new User("simple-user-1").roles(ALL_ACCESS); - private static final User USER_2 = new User("simple-user-2").roles(ALL_ACCESS); + static final User USER_1 = new User("simple-user-1").roles(ALL_ACCESS); + static final User USER_2 = new User("simple-user-2").roles(ALL_ACCESS); public static final int ALLOWED_TRIES = 3; public static final int TIME_WINDOW_SECONDS = 3; @@ -51,7 +50,7 @@ public class IpBruteForceAttacksPreventionTests { public static final String CLIENT_IP_8 = "127.0.0.8"; public static final String CLIENT_IP_9 = "127.0.0.9"; - private static final AuthFailureListeners listener = new AuthFailureListeners().addRateLimit( + static final AuthFailureListeners listener = new AuthFailureListeners().addRateLimit( new RateLimiting("internal_authentication_backend_limiting").type("ip") .allowedTries(ALLOWED_TRIES) .timeWindowSeconds(TIME_WINDOW_SECONDS) @@ -60,13 +59,17 @@ public class IpBruteForceAttacksPreventionTests { .maxTrackedClients(500) ); - @ClassRule - public static final LocalCluster cluster = new LocalCluster.Builder().clusterManager(ClusterManager.SINGLENODE) - .anonymousAuth(false) - .authFailureListeners(listener) - .authc(AUTHC_HTTPBASIC_INTERNAL_WITHOUT_CHALLENGE) - .users(USER_1, USER_2) - .build(); + @Rule + public LocalCluster cluster = createCluster(); + + public LocalCluster createCluster() { + return new LocalCluster.Builder().clusterManager(ClusterManager.SINGLENODE) + .anonymousAuth(false) + .authFailureListeners(listener) + .authc(AUTHC_HTTPBASIC_INTERNAL_WITHOUT_CHALLENGE) + .users(USER_1, USER_2) + .build(); + } @Rule public LogsRule logsRule = new LogsRule("org.opensearch.security.auth.BackendRegistry"); @@ -151,7 +154,7 @@ public void shouldReleaseIpAddressLock() throws InterruptedException { } } - private static void authenticateUserWithIncorrectPassword(String sourceIpAddress, User user, int numberOfRequests) { + void authenticateUserWithIncorrectPassword(String sourceIpAddress, User user, int numberOfRequests) { var clientConfiguration = new TestRestClientConfiguration().username(user.getName()) .password("incorrect password") .sourceInetAddress(sourceIpAddress); diff --git a/src/integrationTest/java/org/opensearch/security/IpBruteForceAttacksPreventionWithDomainChallengeTests.java b/src/integrationTest/java/org/opensearch/security/IpBruteForceAttacksPreventionWithDomainChallengeTests.java new file mode 100644 index 0000000000..6159599119 --- /dev/null +++ b/src/integrationTest/java/org/opensearch/security/IpBruteForceAttacksPreventionWithDomainChallengeTests.java @@ -0,0 +1,32 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + */ + +package org.opensearch.security; + +import com.carrotsearch.randomizedtesting.annotations.ThreadLeakScope; +import org.junit.runner.RunWith; +import org.opensearch.test.framework.cluster.ClusterManager; +import org.opensearch.test.framework.cluster.LocalCluster; + +import static org.opensearch.test.framework.TestSecurityConfig.AuthcDomain.AUTHC_HTTPBASIC_INTERNAL; + +@RunWith(com.carrotsearch.randomizedtesting.RandomizedRunner.class) +@ThreadLeakScope(ThreadLeakScope.Scope.NONE) +public class IpBruteForceAttacksPreventionWithDomainChallengeTests extends IpBruteForceAttacksPreventionTests { + @Override + public LocalCluster createCluster() { + return new LocalCluster.Builder().clusterManager(ClusterManager.SINGLENODE) + .anonymousAuth(false) + .authFailureListeners(listener) + .authc(AUTHC_HTTPBASIC_INTERNAL) + .users(USER_1, USER_2) + .build(); + } +} diff --git a/src/integrationTest/java/org/opensearch/test/framework/cluster/LocalOpenSearchCluster.java b/src/integrationTest/java/org/opensearch/test/framework/cluster/LocalOpenSearchCluster.java index c09127e592..77890a4645 100644 --- a/src/integrationTest/java/org/opensearch/test/framework/cluster/LocalOpenSearchCluster.java +++ b/src/integrationTest/java/org/opensearch/test/framework/cluster/LocalOpenSearchCluster.java @@ -344,7 +344,7 @@ public String toString() { String clusterManagerNodes = nodeByTypeToString(CLUSTER_MANAGER); String dataNodes = nodeByTypeToString(DATA); String clientNodes = nodeByTypeToString(CLIENT); - return "\nES Cluster " + return "\nOS Cluster " + clusterName + "\ncluster manager nodes: " + clusterManagerNodes diff --git a/src/main/java/com/amazon/dlic/auth/http/jwt/AbstractHTTPJwtAuthenticator.java b/src/main/java/com/amazon/dlic/auth/http/jwt/AbstractHTTPJwtAuthenticator.java index 4dab3c7740..6dbb7b7676 100644 --- a/src/main/java/com/amazon/dlic/auth/http/jwt/AbstractHTTPJwtAuthenticator.java +++ b/src/main/java/com/amazon/dlic/auth/http/jwt/AbstractHTTPJwtAuthenticator.java @@ -36,7 +36,6 @@ import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.core.common.Strings; import org.opensearch.rest.BytesRestResponse; -import org.opensearch.rest.RestChannel; import org.opensearch.rest.RestRequest; import org.opensearch.core.rest.RestStatus; import org.opensearch.security.auth.HTTPAuthenticator; @@ -236,11 +235,10 @@ public String[] extractRoles(JwtClaims claims) { protected abstract KeyProvider initKeyProvider(Settings settings, Path configPath) throws Exception; @Override - public boolean reRequestAuthentication(RestChannel channel, AuthCredentials authCredentials) { + public BytesRestResponse reRequestAuthentication(RestRequest request, AuthCredentials authCredentials) { final BytesRestResponse wwwAuthenticateResponse = new BytesRestResponse(RestStatus.UNAUTHORIZED, ""); wwwAuthenticateResponse.addHeader("WWW-Authenticate", "Bearer realm=\"OpenSearch Security\""); - channel.sendResponse(wwwAuthenticateResponse); - return true; + return wwwAuthenticateResponse; } public String getRequiredAudience() { diff --git a/src/main/java/com/amazon/dlic/auth/http/jwt/HTTPJwtAuthenticator.java b/src/main/java/com/amazon/dlic/auth/http/jwt/HTTPJwtAuthenticator.java index 03e385d5c0..338a490037 100644 --- a/src/main/java/com/amazon/dlic/auth/http/jwt/HTTPJwtAuthenticator.java +++ b/src/main/java/com/amazon/dlic/auth/http/jwt/HTTPJwtAuthenticator.java @@ -31,7 +31,6 @@ import org.opensearch.common.settings.Settings; import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.rest.BytesRestResponse; -import org.opensearch.rest.RestChannel; import org.opensearch.rest.RestRequest; import org.opensearch.core.rest.RestStatus; import org.opensearch.security.auth.HTTPAuthenticator; @@ -171,11 +170,10 @@ private AuthCredentials extractCredentials0(final RestRequest request) { } @Override - public boolean reRequestAuthentication(final RestChannel channel, AuthCredentials creds) { + public BytesRestResponse reRequestAuthentication(RestRequest request, AuthCredentials credentials) { final BytesRestResponse wwwAuthenticateResponse = new BytesRestResponse(RestStatus.UNAUTHORIZED, ""); wwwAuthenticateResponse.addHeader("WWW-Authenticate", "Bearer realm=\"OpenSearch Security\""); - channel.sendResponse(wwwAuthenticateResponse); - return true; + return wwwAuthenticateResponse; } @Override diff --git a/src/main/java/com/amazon/dlic/auth/http/kerberos/HTTPSpnegoAuthenticator.java b/src/main/java/com/amazon/dlic/auth/http/kerberos/HTTPSpnegoAuthenticator.java index 29f537e899..99bc23b746 100644 --- a/src/main/java/com/amazon/dlic/auth/http/kerberos/HTTPSpnegoAuthenticator.java +++ b/src/main/java/com/amazon/dlic/auth/http/kerberos/HTTPSpnegoAuthenticator.java @@ -49,7 +49,6 @@ import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.env.Environment; import org.opensearch.rest.BytesRestResponse; -import org.opensearch.rest.RestChannel; import org.opensearch.rest.RestRequest; import org.opensearch.core.rest.RestStatus; import org.opensearch.security.auth.HTTPAuthenticator; @@ -280,8 +279,7 @@ public GSSCredential run() throws GSSException { } @Override - public boolean reRequestAuthentication(final RestChannel channel, AuthCredentials creds) { - + public BytesRestResponse reRequestAuthentication(RestRequest request, AuthCredentials credentials) { final BytesRestResponse wwwAuthenticateResponse; XContentBuilder response = getNegotiateResponseBody(); @@ -291,16 +289,15 @@ public boolean reRequestAuthentication(final RestChannel channel, AuthCredential wwwAuthenticateResponse = new BytesRestResponse(RestStatus.UNAUTHORIZED, EMPTY_STRING); } - if (creds == null || creds.getNativeCredentials() == null) { + if (credentials == null || credentials.getNativeCredentials() == null) { wwwAuthenticateResponse.addHeader("WWW-Authenticate", "Negotiate"); } else { wwwAuthenticateResponse.addHeader( "WWW-Authenticate", - "Negotiate " + Base64.getEncoder().encodeToString((byte[]) creds.getNativeCredentials()) + "Negotiate " + Base64.getEncoder().encodeToString((byte[]) credentials.getNativeCredentials()) ); } - channel.sendResponse(wwwAuthenticateResponse); - return true; + return wwwAuthenticateResponse; } @Override diff --git a/src/main/java/com/amazon/dlic/auth/http/saml/AuthTokenProcessorHandler.java b/src/main/java/com/amazon/dlic/auth/http/saml/AuthTokenProcessorHandler.java index 3fee4a9444..7210ed5950 100644 --- a/src/main/java/com/amazon/dlic/auth/http/saml/AuthTokenProcessorHandler.java +++ b/src/main/java/com/amazon/dlic/auth/http/saml/AuthTokenProcessorHandler.java @@ -23,7 +23,6 @@ import java.util.regex.Pattern; import java.util.stream.Collectors; -import javax.xml.parsers.ParserConfigurationException; import javax.xml.xpath.XPathExpressionException; import com.fasterxml.jackson.core.JsonParseException; @@ -32,7 +31,6 @@ import com.fasterxml.jackson.databind.node.ObjectNode; import com.google.common.base.Strings; import com.onelogin.saml2.authn.SamlResponse; -import com.onelogin.saml2.exception.SettingsException; import com.onelogin.saml2.exception.ValidationError; import com.onelogin.saml2.settings.Saml2Settings; import com.onelogin.saml2.util.Util; @@ -49,7 +47,6 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.joda.time.DateTime; -import org.xml.sax.SAXException; import org.opensearch.OpenSearchSecurityException; import org.opensearch.SpecialPermission; @@ -57,7 +54,6 @@ import org.opensearch.common.settings.Settings; import org.opensearch.common.xcontent.XContentType; import org.opensearch.rest.BytesRestResponse; -import org.opensearch.rest.RestChannel; import org.opensearch.rest.RestRequest; import org.opensearch.rest.RestRequest.Method; import org.opensearch.core.rest.RestStatus; @@ -122,7 +118,7 @@ class AuthTokenProcessorHandler { } @SuppressWarnings("removal") - boolean handle(RestRequest restRequest, RestChannel restChannel) throws Exception { + BytesRestResponse handle(RestRequest restRequest) throws Exception { try { final SecurityManager sm = System.getSecurityManager(); @@ -130,11 +126,10 @@ boolean handle(RestRequest restRequest, RestChannel restChannel) throws Exceptio sm.checkPermission(new SpecialPermission()); } - return AccessController.doPrivileged(new PrivilegedExceptionAction() { + return AccessController.doPrivileged(new PrivilegedExceptionAction() { @Override - public Boolean run() throws XPathExpressionException, SamlConfigException, IOException, ParserConfigurationException, - SAXException, SettingsException { - return handleLowLevel(restRequest, restChannel); + public BytesRestResponse run() throws SamlConfigException, IOException { + return handleLowLevel(restRequest); } }); } catch (PrivilegedActionException e) { @@ -147,13 +142,11 @@ public Boolean run() throws XPathExpressionException, SamlConfigException, IOExc } private AuthTokenProcessorAction.Response handleImpl( - RestRequest restRequest, - RestChannel restChannel, String samlResponseBase64, String samlRequestId, String acsEndpoint, Saml2Settings saml2Settings - ) throws XPathExpressionException, ParserConfigurationException, SAXException, IOException, SettingsException { + ) { if (token_log.isDebugEnabled()) { try { token_log.debug( @@ -188,8 +181,7 @@ private AuthTokenProcessorAction.Response handleImpl( } } - private boolean handleLowLevel(RestRequest restRequest, RestChannel restChannel) throws SamlConfigException, IOException, - XPathExpressionException, ParserConfigurationException, SAXException, SettingsException { + private BytesRestResponse handleLowLevel(RestRequest restRequest) throws SamlConfigException, IOException { try { if (restRequest.getMediaType() != XContentType.JSON) { @@ -234,31 +226,18 @@ private boolean handleLowLevel(RestRequest restRequest, RestChannel restChannel) acsEndpoint = getAbsoluteAcsEndpoint(((ObjectNode) jsonRoot).get("acsEndpoint").textValue()); } - AuthTokenProcessorAction.Response responseBody = this.handleImpl( - restRequest, - restChannel, - samlResponseBase64, - samlRequestId, - acsEndpoint, - saml2Settings - ); + AuthTokenProcessorAction.Response responseBody = this.handleImpl(samlResponseBase64, samlRequestId, acsEndpoint, saml2Settings); if (responseBody == null) { - return false; + return null; } String responseBodyString = DefaultObjectMapper.objectMapper.writeValueAsString(responseBody); - BytesRestResponse authenticateResponse = new BytesRestResponse(RestStatus.OK, "application/json", responseBodyString); - restChannel.sendResponse(authenticateResponse); - - return true; + return new BytesRestResponse(RestStatus.OK, "application/json", responseBodyString); } catch (JsonProcessingException e) { log.warn("Error while parsing JSON for /_opendistro/_security/api/authtoken", e); - - BytesRestResponse authenticateResponse = new BytesRestResponse(RestStatus.BAD_REQUEST, "JSON could not be parsed"); - restChannel.sendResponse(authenticateResponse); - return true; + return new BytesRestResponse(RestStatus.BAD_REQUEST, "JSON could not be parsed"); } } diff --git a/src/main/java/com/amazon/dlic/auth/http/saml/HTTPSamlAuthenticator.java b/src/main/java/com/amazon/dlic/auth/http/saml/HTTPSamlAuthenticator.java index cd6209952f..64d816fabf 100644 --- a/src/main/java/com/amazon/dlic/auth/http/saml/HTTPSamlAuthenticator.java +++ b/src/main/java/com/amazon/dlic/auth/http/saml/HTTPSamlAuthenticator.java @@ -56,7 +56,6 @@ import org.opensearch.common.settings.Settings; import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.rest.BytesRestResponse; -import org.opensearch.rest.RestChannel; import org.opensearch.rest.RestRequest; import org.opensearch.core.rest.RestStatus; import org.opensearch.security.auth.Destroyable; @@ -171,13 +170,15 @@ public String getType() { } @Override - public boolean reRequestAuthentication(RestChannel restChannel, AuthCredentials authCredentials) { + public BytesRestResponse reRequestAuthentication(RestRequest request, AuthCredentials credentials) { try { - RestRequest restRequest = restChannel.request(); - Matcher matcher = PATTERN_PATH_PREFIX.matcher(restRequest.path()); + Matcher matcher = PATTERN_PATH_PREFIX.matcher(request.path()); final String suffix = matcher.matches() ? matcher.group(2) : null; - if (API_AUTHTOKEN_SUFFIX.equals(suffix) && this.authTokenProcessorHandler.handle(restRequest, restChannel)) { - return true; + if (API_AUTHTOKEN_SUFFIX.equals(suffix)) { + final BytesRestResponse restResponse = this.authTokenProcessorHandler.handle(request); + if (restResponse != null) { + return restResponse; + } } Saml2Settings saml2Settings = this.saml2SettingsProvider.getCached(); @@ -185,12 +186,10 @@ public boolean reRequestAuthentication(RestChannel restChannel, AuthCredentials authenticateResponse.addHeader("WWW-Authenticate", getWwwAuthenticateHeader(saml2Settings)); - restChannel.sendResponse(authenticateResponse); - - return true; + return authenticateResponse; } catch (Exception e) { log.error("Error in reRequestAuthentication()", e); - return false; + return null; } } diff --git a/src/main/java/org/opensearch/security/auth/BackendRegistry.java b/src/main/java/org/opensearch/security/auth/BackendRegistry.java index ad1406426b..35d4e05c4c 100644 --- a/src/main/java/org/opensearch/security/auth/BackendRegistry.java +++ b/src/main/java/org/opensearch/security/auth/BackendRegistry.java @@ -230,7 +230,7 @@ && isBlocked(((InetSocketAddress) request.getHttpChannel().getRemoteAddress()).g User authenticatedUser = null; - AuthCredentials authCredenetials = null; + AuthCredentials authCredentials = null; HTTPAuthenticator firstChallengingHttpAuthenticator = null; @@ -272,7 +272,7 @@ && isBlocked(((InetSocketAddress) request.getHttpChannel().getRemoteAddress()).g continue; } - authCredenetials = ac; + authCredentials = ac; if (ac == null) { // no credentials found in request @@ -280,12 +280,18 @@ && isBlocked(((InetSocketAddress) request.getHttpChannel().getRemoteAddress()).g continue; } - if (authDomain.isChallenge() && httpAuthenticator.reRequestAuthentication(channel, null)) { - auditLog.logFailedLogin("", false, null, request); - if (isTraceEnabled) { - log.trace("No 'Authorization' header, send 401 and 'WWW-Authenticate Basic'"); + if (authDomain.isChallenge()) { + final BytesRestResponse restResponse = httpAuthenticator.reRequestAuthentication(request, null); + if (restResponse != null) { + auditLog.logFailedLogin("", false, null, request); + if (isTraceEnabled) { + log.trace("No 'Authorization' header, send 401 and 'WWW-Authenticate Basic'"); + } + notifyIpAuthFailureListeners(request, authCredentials); + channel.sendResponse(restResponse); + return false; } - return false; + } else { // no reRequest possible if (isTraceEnabled) { @@ -296,9 +302,12 @@ && isBlocked(((InetSocketAddress) request.getHttpChannel().getRemoteAddress()).g } else { org.apache.logging.log4j.ThreadContext.put("user", ac.getUsername()); if (!ac.isComplete()) { + final BytesRestResponse restResponse = httpAuthenticator.reRequestAuthentication(request, ac); // credentials found in request but we need another client challenge - if (httpAuthenticator.reRequestAuthentication(channel, ac)) { + if (restResponse != null) { // auditLog.logFailedLogin(ac.getUsername()+" ", request); --noauditlog + notifyIpAuthFailureListeners(request, ac); + channel.sendResponse(restResponse); return false; } else { // no reRequest possible @@ -376,7 +385,7 @@ && isBlocked(((InetSocketAddress) request.getHttpChannel().getRemoteAddress()).g log.debug("User still not authenticated after checking {} auth domains", restAuthDomains.size()); } - if (authCredenetials == null && anonymousAuthEnabled) { + if (authCredentials == null && anonymousAuthEnabled) { final String tenant = Utils.coalesce(request.header("securitytenant"), request.header("security_tenant")); User anonymousUser = new User(User.ANONYMOUS.getName(), new HashSet(User.ANONYMOUS.getRoles()), null); anonymousUser.setRequestedTenant(tenant); @@ -388,6 +397,7 @@ && isBlocked(((InetSocketAddress) request.getHttpChannel().getRemoteAddress()).g } return true; } + BytesRestResponse challengeResponse = null; if (firstChallengingHttpAuthenticator != null) { @@ -395,31 +405,28 @@ && isBlocked(((InetSocketAddress) request.getHttpChannel().getRemoteAddress()).g log.debug("Rerequest with {}", firstChallengingHttpAuthenticator.getClass()); } - if (firstChallengingHttpAuthenticator.reRequestAuthentication(channel, null)) { + challengeResponse = firstChallengingHttpAuthenticator.reRequestAuthentication(request, null); + if (challengeResponse != null) { if (isDebugEnabled) { log.debug("Rerequest {} failed", firstChallengingHttpAuthenticator.getClass()); } - - log.warn( - "Authentication finally failed for {} from {}", - authCredenetials == null ? null : authCredenetials.getUsername(), - remoteAddress - ); - auditLog.logFailedLogin(authCredenetials == null ? null : authCredenetials.getUsername(), false, null, request); - return false; } } log.warn( "Authentication finally failed for {} from {}", - authCredenetials == null ? null : authCredenetials.getUsername(), + authCredentials == null ? null : authCredentials.getUsername(), remoteAddress ); - auditLog.logFailedLogin(authCredenetials == null ? null : authCredenetials.getUsername(), false, null, request); + auditLog.logFailedLogin(authCredentials == null ? null : authCredentials.getUsername(), false, null, request); - notifyIpAuthFailureListeners(request, authCredenetials); + notifyIpAuthFailureListeners(request, authCredentials); - channel.sendResponse(new BytesRestResponse(RestStatus.UNAUTHORIZED, "Authentication finally failed")); + channel.sendResponse( + challengeResponse != null + ? challengeResponse + : new BytesRestResponse(RestStatus.UNAUTHORIZED, "Authentication finally failed") + ); return false; } diff --git a/src/main/java/org/opensearch/security/auth/HTTPAuthenticator.java b/src/main/java/org/opensearch/security/auth/HTTPAuthenticator.java index fa5065ef68..70b94d6dbf 100644 --- a/src/main/java/org/opensearch/security/auth/HTTPAuthenticator.java +++ b/src/main/java/org/opensearch/security/auth/HTTPAuthenticator.java @@ -28,7 +28,7 @@ import org.opensearch.OpenSearchSecurityException; import org.opensearch.common.util.concurrent.ThreadContext; -import org.opensearch.rest.RestChannel; +import org.opensearch.rest.BytesRestResponse; import org.opensearch.rest.RestRequest; import org.opensearch.security.user.AuthCredentials; @@ -71,15 +71,14 @@ public interface HTTPAuthenticator { /** * If the {@code extractCredentials()} call was not successful or the authentication flow needs another roundtrip this method - * will be called. If the custom HTTP authenticator does not support this method is a no-op and false should be returned. - * + * will be called. If the custom HTTP authenticator does not support this method is a no-op and null response should be returned. * If the custom HTTP authenticator does support re-request authentication or supports authentication flows with multiple roundtrips - * then the response should be sent (through the channel) and true must be returned. + * then the response will be returned which can then be sent via response channel. * - * @param channel The rest channel to sent back the response via {@code channel.sendResponse()} + * @param request * @param credentials The credentials from the prior authentication attempt - * @return false if re-request is not supported/necessary, true otherwise. - * If true is returned {@code channel.sendResponse()} must be called so that the request completes. + * @return null if re-request is not supported/necessary, response object otherwise. + * If an object is returned {@code channel.sendResponse()} must be called so that the request completes. */ - boolean reRequestAuthentication(final RestChannel channel, AuthCredentials credentials); + BytesRestResponse reRequestAuthentication(RestRequest request, AuthCredentials credentials); } diff --git a/src/main/java/org/opensearch/security/http/HTTPBasicAuthenticator.java b/src/main/java/org/opensearch/security/http/HTTPBasicAuthenticator.java index 4be83bc2e2..16a822df6d 100644 --- a/src/main/java/org/opensearch/security/http/HTTPBasicAuthenticator.java +++ b/src/main/java/org/opensearch/security/http/HTTPBasicAuthenticator.java @@ -34,7 +34,6 @@ import org.opensearch.common.settings.Settings; import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.rest.BytesRestResponse; -import org.opensearch.rest.RestChannel; import org.opensearch.rest.RestRequest; import org.opensearch.core.rest.RestStatus; import org.opensearch.security.auth.HTTPAuthenticator; @@ -65,11 +64,10 @@ public AuthCredentials extractCredentials(final RestRequest request, ThreadConte } @Override - public boolean reRequestAuthentication(final RestChannel channel, AuthCredentials creds) { + public BytesRestResponse reRequestAuthentication(RestRequest request, AuthCredentials credentials) { final BytesRestResponse wwwAuthenticateResponse = new BytesRestResponse(RestStatus.UNAUTHORIZED, "Unauthorized"); wwwAuthenticateResponse.addHeader("WWW-Authenticate", "Basic realm=\"OpenSearch Security\""); - channel.sendResponse(wwwAuthenticateResponse); - return true; + return wwwAuthenticateResponse; } @Override diff --git a/src/main/java/org/opensearch/security/http/HTTPClientCertAuthenticator.java b/src/main/java/org/opensearch/security/http/HTTPClientCertAuthenticator.java index b1e5d4ef40..cea59f7311 100644 --- a/src/main/java/org/opensearch/security/http/HTTPClientCertAuthenticator.java +++ b/src/main/java/org/opensearch/security/http/HTTPClientCertAuthenticator.java @@ -41,7 +41,7 @@ import org.opensearch.common.settings.Settings; import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.core.common.Strings; -import org.opensearch.rest.RestChannel; +import org.opensearch.rest.BytesRestResponse; import org.opensearch.rest.RestRequest; import org.opensearch.security.auth.HTTPAuthenticator; import org.opensearch.security.support.ConfigConstants; @@ -98,8 +98,8 @@ public AuthCredentials extractCredentials(final RestRequest request, final Threa } @Override - public boolean reRequestAuthentication(final RestChannel channel, AuthCredentials creds) { - return false; + public BytesRestResponse reRequestAuthentication(RestRequest request, AuthCredentials credentials) { + return null; } @Override diff --git a/src/main/java/org/opensearch/security/http/HTTPProxyAuthenticator.java b/src/main/java/org/opensearch/security/http/HTTPProxyAuthenticator.java index a58a842394..1db7b9769b 100644 --- a/src/main/java/org/opensearch/security/http/HTTPProxyAuthenticator.java +++ b/src/main/java/org/opensearch/security/http/HTTPProxyAuthenticator.java @@ -37,7 +37,7 @@ import org.opensearch.common.settings.Settings; import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.core.common.Strings; -import org.opensearch.rest.RestChannel; +import org.opensearch.rest.BytesRestResponse; import org.opensearch.rest.RestRequest; import org.opensearch.security.auth.HTTPAuthenticator; import org.opensearch.security.support.ConfigConstants; @@ -89,8 +89,8 @@ public AuthCredentials extractCredentials(final RestRequest request, ThreadConte } @Override - public boolean reRequestAuthentication(final RestChannel channel, AuthCredentials creds) { - return false; + public BytesRestResponse reRequestAuthentication(RestRequest request, AuthCredentials credentials) { + return null; } @Override diff --git a/src/main/java/org/opensearch/security/http/OnBehalfOfAuthenticator.java b/src/main/java/org/opensearch/security/http/OnBehalfOfAuthenticator.java index c47e850b75..02077bed7c 100644 --- a/src/main/java/org/opensearch/security/http/OnBehalfOfAuthenticator.java +++ b/src/main/java/org/opensearch/security/http/OnBehalfOfAuthenticator.java @@ -33,7 +33,7 @@ import org.opensearch.SpecialPermission; import org.opensearch.common.settings.Settings; import org.opensearch.common.util.concurrent.ThreadContext; -import org.opensearch.rest.RestChannel; +import org.opensearch.rest.BytesRestResponse; import org.opensearch.rest.RestRequest; import org.opensearch.security.auth.HTTPAuthenticator; import org.opensearch.security.authtoken.jwt.EncryptionDecryptionUtil; @@ -241,8 +241,8 @@ public Boolean isRequestAllowed(final RestRequest request) { } @Override - public boolean reRequestAuthentication(final RestChannel channel, AuthCredentials creds) { - return false; + public BytesRestResponse reRequestAuthentication(RestRequest request, AuthCredentials credentials) { + return null; } @Override diff --git a/src/main/java/org/opensearch/security/http/proxy/HTTPExtendedProxyAuthenticator.java b/src/main/java/org/opensearch/security/http/proxy/HTTPExtendedProxyAuthenticator.java index ef20374d69..0423fecefe 100644 --- a/src/main/java/org/opensearch/security/http/proxy/HTTPExtendedProxyAuthenticator.java +++ b/src/main/java/org/opensearch/security/http/proxy/HTTPExtendedProxyAuthenticator.java @@ -37,7 +37,6 @@ import org.opensearch.common.settings.Settings; import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.core.common.Strings; -import org.opensearch.rest.RestChannel; import org.opensearch.rest.RestRequest; import org.opensearch.security.http.HTTPProxyAuthenticator; import org.opensearch.security.user.AuthCredentials; @@ -84,11 +83,6 @@ public AuthCredentials extractCredentials(final RestRequest request, ThreadConte return credentials.markComplete(); } - @Override - public boolean reRequestAuthentication(final RestChannel channel, AuthCredentials creds) { - return false; - } - @Override public String getType() { return "extended-proxy"; diff --git a/src/test/java/com/amazon/dlic/auth/http/saml/HTTPSamlAuthenticatorTest.java b/src/test/java/com/amazon/dlic/auth/http/saml/HTTPSamlAuthenticatorTest.java index ff9ec19b09..2594388128 100644 --- a/src/test/java/com/amazon/dlic/auth/http/saml/HTTPSamlAuthenticatorTest.java +++ b/src/test/java/com/amazon/dlic/auth/http/saml/HTTPSamlAuthenticatorTest.java @@ -47,6 +47,7 @@ import org.opensearch.common.settings.Settings; import org.opensearch.core.xcontent.MediaType; import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.rest.BytesRestResponse; import org.opensearch.rest.RestChannel; import org.opensearch.rest.RestRequest; import org.opensearch.rest.RestRequest.Method; @@ -141,7 +142,8 @@ public void basicTest() throws Exception { RestRequest tokenRestRequest = buildTokenExchangeRestRequest(encodedSamlResponse, authenticateHeaders); TestRestChannel tokenRestChannel = new TestRestChannel(tokenRestRequest); - samlAuthenticator.reRequestAuthentication(tokenRestChannel, null); + final BytesRestResponse authenticatorResponse = samlAuthenticator.reRequestAuthentication(tokenRestRequest, null); + tokenRestChannel.sendResponse(authenticatorResponse); String responseJson = new String(BytesReference.toBytes(tokenRestChannel.response.content())); HashMap response = DefaultObjectMapper.objectMapper.readValue( @@ -188,7 +190,8 @@ public void decryptAssertionsTest() throws Exception { RestRequest tokenRestRequest = buildTokenExchangeRestRequest(encodedSamlResponse, authenticateHeaders); TestRestChannel tokenRestChannel = new TestRestChannel(tokenRestRequest); - samlAuthenticator.reRequestAuthentication(tokenRestChannel, null); + final BytesRestResponse authenticatorResponse = samlAuthenticator.reRequestAuthentication(tokenRestRequest, null); + tokenRestChannel.sendResponse(authenticatorResponse); String responseJson = new String(BytesReference.toBytes(tokenRestChannel.response.content())); HashMap response = DefaultObjectMapper.objectMapper.readValue( @@ -236,7 +239,8 @@ public void shouldUnescapeSamlEntitiesTest() throws Exception { RestRequest tokenRestRequest = buildTokenExchangeRestRequest(encodedSamlResponse, authenticateHeaders); TestRestChannel tokenRestChannel = new TestRestChannel(tokenRestRequest); - samlAuthenticator.reRequestAuthentication(tokenRestChannel, null); + final BytesRestResponse authenticatorResponse = samlAuthenticator.reRequestAuthentication(tokenRestRequest, null); + tokenRestChannel.sendResponse(authenticatorResponse); String responseJson = new String(BytesReference.toBytes(tokenRestChannel.response.content())); HashMap response = DefaultObjectMapper.objectMapper.readValue( @@ -287,7 +291,8 @@ public void shouldUnescapeSamlEntitiesTest2() throws Exception { RestRequest tokenRestRequest = buildTokenExchangeRestRequest(encodedSamlResponse, authenticateHeaders); TestRestChannel tokenRestChannel = new TestRestChannel(tokenRestRequest); - samlAuthenticator.reRequestAuthentication(tokenRestChannel, null); + final BytesRestResponse authenticatorResponse = samlAuthenticator.reRequestAuthentication(tokenRestRequest, null); + tokenRestChannel.sendResponse(authenticatorResponse); String responseJson = new String(BytesReference.toBytes(tokenRestChannel.response.content())); HashMap response = DefaultObjectMapper.objectMapper.readValue( @@ -338,7 +343,8 @@ public void shouldNotEscapeSamlEntities() throws Exception { RestRequest tokenRestRequest = buildTokenExchangeRestRequest(encodedSamlResponse, authenticateHeaders); TestRestChannel tokenRestChannel = new TestRestChannel(tokenRestRequest); - samlAuthenticator.reRequestAuthentication(tokenRestChannel, null); + final BytesRestResponse authenticatorResponse = samlAuthenticator.reRequestAuthentication(tokenRestRequest, null); + tokenRestChannel.sendResponse(authenticatorResponse); String responseJson = new String(BytesReference.toBytes(tokenRestChannel.response.content())); HashMap response = DefaultObjectMapper.objectMapper.readValue( @@ -389,7 +395,8 @@ public void shouldNotTrimWhitespaceInJwtRoles() throws Exception { RestRequest tokenRestRequest = buildTokenExchangeRestRequest(encodedSamlResponse, authenticateHeaders); TestRestChannel tokenRestChannel = new TestRestChannel(tokenRestRequest); - samlAuthenticator.reRequestAuthentication(tokenRestChannel, null); + final BytesRestResponse authenticatorResponse = samlAuthenticator.reRequestAuthentication(tokenRestRequest, null); + tokenRestChannel.sendResponse(authenticatorResponse); String responseJson = new String(BytesReference.toBytes(tokenRestChannel.response.content())); HashMap response = DefaultObjectMapper.objectMapper.readValue( @@ -436,7 +443,8 @@ public void testMetadataBody() throws Exception { RestRequest tokenRestRequest = buildTokenExchangeRestRequest(encodedSamlResponse, authenticateHeaders); TestRestChannel tokenRestChannel = new TestRestChannel(tokenRestRequest); - samlAuthenticator.reRequestAuthentication(tokenRestChannel, null); + final BytesRestResponse authenticatorResponse = samlAuthenticator.reRequestAuthentication(tokenRestRequest, null); + tokenRestChannel.sendResponse(authenticatorResponse); String responseJson = new String(BytesReference.toBytes(tokenRestChannel.response.content())); HashMap response = DefaultObjectMapper.objectMapper.readValue( @@ -501,7 +509,8 @@ public void unsolicitedSsoTest() throws Exception { ); TestRestChannel tokenRestChannel = new TestRestChannel(tokenRestRequest); - samlAuthenticator.reRequestAuthentication(tokenRestChannel, null); + final BytesRestResponse authenticatorResponse = samlAuthenticator.reRequestAuthentication(tokenRestRequest, null); + tokenRestChannel.sendResponse(authenticatorResponse); String responseJson = new String(BytesReference.toBytes(tokenRestChannel.response.content())); HashMap response = DefaultObjectMapper.objectMapper.readValue( @@ -552,7 +561,8 @@ public void badUnsolicitedSsoTest() throws Exception { ); TestRestChannel tokenRestChannel = new TestRestChannel(tokenRestRequest); - samlAuthenticator.reRequestAuthentication(tokenRestChannel, null); + final BytesRestResponse authenticatorResponse = samlAuthenticator.reRequestAuthentication(tokenRestRequest, null); + tokenRestChannel.sendResponse(authenticatorResponse); Assert.assertEquals(RestStatus.UNAUTHORIZED, tokenRestChannel.response.status()); } @@ -584,7 +594,8 @@ public void wrongCertTest() throws Exception { RestRequest tokenRestRequest = buildTokenExchangeRestRequest(encodedSamlResponse, authenticateHeaders); TestRestChannel tokenRestChannel = new TestRestChannel(tokenRestRequest); - samlAuthenticator.reRequestAuthentication(tokenRestChannel, null); + final BytesRestResponse authenticatorResponse = samlAuthenticator.reRequestAuthentication(tokenRestRequest, null); + tokenRestChannel.sendResponse(authenticatorResponse); Assert.assertEquals(401, tokenRestChannel.response.status().getStatus()); } @@ -613,7 +624,8 @@ public void noSignatureTest() throws Exception { RestRequest tokenRestRequest = buildTokenExchangeRestRequest(encodedSamlResponse, authenticateHeaders); TestRestChannel tokenRestChannel = new TestRestChannel(tokenRestRequest); - samlAuthenticator.reRequestAuthentication(tokenRestChannel, null); + final BytesRestResponse authenticatorResponse = samlAuthenticator.reRequestAuthentication(tokenRestRequest, null); + tokenRestChannel.sendResponse(authenticatorResponse); Assert.assertEquals(401, tokenRestChannel.response.status().getStatus()); } @@ -646,7 +658,8 @@ public void rolesTest() throws Exception { RestRequest tokenRestRequest = buildTokenExchangeRestRequest(encodedSamlResponse, authenticateHeaders); TestRestChannel tokenRestChannel = new TestRestChannel(tokenRestRequest); - samlAuthenticator.reRequestAuthentication(tokenRestChannel, null); + final BytesRestResponse authenticatorResponse = samlAuthenticator.reRequestAuthentication(tokenRestRequest, null); + tokenRestChannel.sendResponse(authenticatorResponse); String responseJson = new String(BytesReference.toBytes(tokenRestChannel.response.content())); HashMap response = DefaultObjectMapper.objectMapper.readValue( @@ -693,7 +706,8 @@ public void idpEndpointWithQueryStringTest() throws Exception { RestRequest tokenRestRequest = buildTokenExchangeRestRequest(encodedSamlResponse, authenticateHeaders); TestRestChannel tokenRestChannel = new TestRestChannel(tokenRestRequest); - samlAuthenticator.reRequestAuthentication(tokenRestChannel, null); + final BytesRestResponse authenticatorResponse = samlAuthenticator.reRequestAuthentication(tokenRestRequest, null); + tokenRestChannel.sendResponse(authenticatorResponse); String responseJson = new String(BytesReference.toBytes(tokenRestChannel.response.content())); HashMap response = DefaultObjectMapper.objectMapper.readValue( @@ -747,7 +761,8 @@ private void commaSeparatedRoles(final String rolesAsString, final Settings.Buil RestRequest tokenRestRequest = buildTokenExchangeRestRequest(encodedSamlResponse, authenticateHeaders); TestRestChannel tokenRestChannel = new TestRestChannel(tokenRestRequest); - samlAuthenticator.reRequestAuthentication(tokenRestChannel, null); + final BytesRestResponse authenticatorResponse = samlAuthenticator.reRequestAuthentication(tokenRestRequest, null); + tokenRestChannel.sendResponse(authenticatorResponse); String responseJson = new String(BytesReference.toBytes(tokenRestChannel.response.content())); HashMap response = DefaultObjectMapper.objectMapper.readValue( @@ -850,7 +865,8 @@ public void initialConnectionFailureTest() throws Exception { RestRequest restRequest = new FakeRestRequest(ImmutableMap.of(), new HashMap()); TestRestChannel restChannel = new TestRestChannel(restRequest); - samlAuthenticator.reRequestAuthentication(restChannel, null); + BytesRestResponse authenticatorResponse = samlAuthenticator.reRequestAuthentication(restRequest, null); + restChannel.sendResponse(authenticatorResponse); Assert.assertNull(restChannel.response); @@ -870,7 +886,8 @@ public void initialConnectionFailureTest() throws Exception { RestRequest tokenRestRequest = buildTokenExchangeRestRequest(encodedSamlResponse, authenticateHeaders); TestRestChannel tokenRestChannel = new TestRestChannel(tokenRestRequest); - samlAuthenticator.reRequestAuthentication(tokenRestChannel, null); + authenticatorResponse = samlAuthenticator.reRequestAuthentication(tokenRestRequest, null); + tokenRestChannel.sendResponse(authenticatorResponse); String responseJson = new String(BytesReference.toBytes(tokenRestChannel.response.content())); HashMap response = DefaultObjectMapper.objectMapper.readValue( @@ -893,7 +910,8 @@ private AuthenticateHeaders getAutenticateHeaders(HTTPSamlAuthenticator samlAuth RestRequest restRequest = new FakeRestRequest(ImmutableMap.of(), new HashMap()); TestRestChannel restChannel = new TestRestChannel(restRequest); - samlAuthenticator.reRequestAuthentication(restChannel, null); + final BytesRestResponse authenticatorResponse = samlAuthenticator.reRequestAuthentication(restRequest, null); + restChannel.sendResponse(authenticatorResponse); List wwwAuthenticateHeaders = restChannel.response.getHeaders().get("WWW-Authenticate"); diff --git a/src/test/java/org/opensearch/security/auth/limiting/AddressBasedRateLimiterTest.java b/src/test/java/org/opensearch/security/auth/limiting/AddressBasedRateLimiterTest.java index 827bfa24b6..69ddc5c03a 100644 --- a/src/test/java/org/opensearch/security/auth/limiting/AddressBasedRateLimiterTest.java +++ b/src/test/java/org/opensearch/security/auth/limiting/AddressBasedRateLimiterTest.java @@ -20,28 +20,27 @@ import org.junit.Test; import org.opensearch.common.settings.Settings; -import org.opensearch.security.user.AuthCredentials; + +import java.net.InetAddress; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertTrue; public class AddressBasedRateLimiterTest { - private final static byte[] PASSWORD = new byte[] { '1', '2', '3' }; - @Test public void simpleTest() throws Exception { Settings settings = Settings.builder().put("allowed_tries", 3).build(); - UserNameBasedRateLimiter rateLimiter = new UserNameBasedRateLimiter(settings, null); + AddressBasedRateLimiter rateLimiter = new AddressBasedRateLimiter(settings, null); - assertFalse(rateLimiter.isBlocked("a")); - rateLimiter.onAuthFailure(null, new AuthCredentials("a", PASSWORD), null); - assertFalse(rateLimiter.isBlocked("a")); - rateLimiter.onAuthFailure(null, new AuthCredentials("a", PASSWORD), null); - assertFalse(rateLimiter.isBlocked("a")); - rateLimiter.onAuthFailure(null, new AuthCredentials("a", PASSWORD), null); - assertTrue(rateLimiter.isBlocked("a")); + assertFalse(rateLimiter.isBlocked(InetAddress.getByAddress(new byte[] { 1, 2, 3, 4 }))); + rateLimiter.onAuthFailure(InetAddress.getByAddress(new byte[] { 1, 2, 3, 4 }), null, null); + assertFalse(rateLimiter.isBlocked(InetAddress.getByAddress(new byte[] { 1, 2, 3, 4 }))); + rateLimiter.onAuthFailure(InetAddress.getByAddress(new byte[] { 1, 2, 3, 4 }), null, null); + assertFalse(rateLimiter.isBlocked(InetAddress.getByAddress(new byte[] { 1, 2, 3, 4 }))); + rateLimiter.onAuthFailure(InetAddress.getByAddress(new byte[] { 1, 2, 3, 4 }), null, null); + assertTrue(rateLimiter.isBlocked(InetAddress.getByAddress(new byte[] { 1, 2, 3, 4 }))); } } diff --git a/src/test/java/org/opensearch/security/auth/limiting/UserNameBasedRateLimiterTest.java b/src/test/java/org/opensearch/security/auth/limiting/UserNameBasedRateLimiterTest.java index e42d2bd1b8..a8285c42a7 100644 --- a/src/test/java/org/opensearch/security/auth/limiting/UserNameBasedRateLimiterTest.java +++ b/src/test/java/org/opensearch/security/auth/limiting/UserNameBasedRateLimiterTest.java @@ -17,30 +17,31 @@ package org.opensearch.security.auth.limiting; -import java.net.InetAddress; - import org.junit.Test; import org.opensearch.common.settings.Settings; +import org.opensearch.security.user.AuthCredentials; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertTrue; public class UserNameBasedRateLimiterTest { + private final static byte[] PASSWORD = new byte[] { '1', '2', '3' }; + @Test public void simpleTest() throws Exception { Settings settings = Settings.builder().put("allowed_tries", 3).build(); - AddressBasedRateLimiter rateLimiter = new AddressBasedRateLimiter(settings, null); + UserNameBasedRateLimiter rateLimiter = new UserNameBasedRateLimiter(settings, null); - assertFalse(rateLimiter.isBlocked(InetAddress.getByAddress(new byte[] { 1, 2, 3, 4 }))); - rateLimiter.onAuthFailure(InetAddress.getByAddress(new byte[] { 1, 2, 3, 4 }), null, null); - assertFalse(rateLimiter.isBlocked(InetAddress.getByAddress(new byte[] { 1, 2, 3, 4 }))); - rateLimiter.onAuthFailure(InetAddress.getByAddress(new byte[] { 1, 2, 3, 4 }), null, null); - assertFalse(rateLimiter.isBlocked(InetAddress.getByAddress(new byte[] { 1, 2, 3, 4 }))); - rateLimiter.onAuthFailure(InetAddress.getByAddress(new byte[] { 1, 2, 3, 4 }), null, null); - assertTrue(rateLimiter.isBlocked(InetAddress.getByAddress(new byte[] { 1, 2, 3, 4 }))); + assertFalse(rateLimiter.isBlocked("a")); + rateLimiter.onAuthFailure(null, new AuthCredentials("a", PASSWORD), null); + assertFalse(rateLimiter.isBlocked("a")); + rateLimiter.onAuthFailure(null, new AuthCredentials("a", PASSWORD), null); + assertFalse(rateLimiter.isBlocked("a")); + rateLimiter.onAuthFailure(null, new AuthCredentials("a", PASSWORD), null); + assertTrue(rateLimiter.isBlocked("a")); } } diff --git a/src/test/java/org/opensearch/security/cache/DummyHTTPAuthenticator.java b/src/test/java/org/opensearch/security/cache/DummyHTTPAuthenticator.java index 55c2e789c6..37ac45080b 100644 --- a/src/test/java/org/opensearch/security/cache/DummyHTTPAuthenticator.java +++ b/src/test/java/org/opensearch/security/cache/DummyHTTPAuthenticator.java @@ -16,7 +16,7 @@ import org.opensearch.OpenSearchSecurityException; import org.opensearch.common.settings.Settings; import org.opensearch.common.util.concurrent.ThreadContext; -import org.opensearch.rest.RestChannel; +import org.opensearch.rest.BytesRestResponse; import org.opensearch.rest.RestRequest; import org.opensearch.security.auth.HTTPAuthenticator; import org.opensearch.security.user.AuthCredentials; @@ -39,8 +39,8 @@ public AuthCredentials extractCredentials(RestRequest request, ThreadContext con } @Override - public boolean reRequestAuthentication(RestChannel channel, AuthCredentials credentials) { - return false; + public BytesRestResponse reRequestAuthentication(RestRequest request, AuthCredentials credentials) { + return null; } public static long getCount() {