diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index a1b6d92453..36d3b65ec2 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -122,6 +122,24 @@ jobs: arguments: | integrationTest -Dbuild.snapshot=false + backward-compatibility-build: + runs-on: ubuntu-latest + steps: + - uses: actions/setup-java@v3 + with: + distribution: temurin # Temurin is a distribution of adoptium + java-version: 17 + + - name: Checkout Security Repo + uses: actions/checkout@v4 + + - name: Build BWC tests + uses: gradle/gradle-build-action@v2 + with: + cache-disabled: true + arguments: | + -p bwc-test build -x test -x integTest + backward-compatibility: strategy: fail-fast: false diff --git a/build.gradle b/build.gradle index 07068bf2e5..e62021b4e0 100644 --- a/build.gradle +++ b/build.gradle @@ -432,7 +432,7 @@ configurations { force "io.netty:netty-transport-native-unix-common:${versions.netty}" force "org.apache.bcel:bcel:6.7.0" // This line should be removed once Spotbugs is upgraded to 4.7.4 force "com.github.luben:zstd-jni:${versions.zstd}" - force "org.xerial.snappy:snappy-java:1.1.10.4" + force "org.xerial.snappy:snappy-java:1.1.10.5" force "com.google.guava:guava:${guava_version}" } } @@ -496,7 +496,7 @@ dependencies { implementation "io.jsonwebtoken:jjwt-impl:${jjwt_version}" implementation "io.jsonwebtoken:jjwt-jackson:${jjwt_version}" // JSON flattener - implementation ("com.github.wnameless.json:json-base:2.4.2") { + implementation ("com.github.wnameless.json:json-base:2.4.3") { exclude group: "org.glassfish", module: "jakarta.json" exclude group: "com.google.code.gson", module: "gson" exclude group: "org.json", module: "json" @@ -527,9 +527,9 @@ dependencies { runtimeOnly 'com.google.errorprone:error_prone_annotations:2.22.0' runtimeOnly 'com.sun.istack:istack-commons-runtime:4.2.0' runtimeOnly 'jakarta.xml.bind:jakarta.xml.bind-api:4.0.0' - runtimeOnly 'org.ow2.asm:asm:9.5' + runtimeOnly 'org.ow2.asm:asm:9.6' - testImplementation 'org.apache.camel:camel-xmlsecurity:3.21.0' + testImplementation 'org.apache.camel:camel-xmlsecurity:3.21.1' //OpenSAML implementation 'net.shibboleth.utilities:java-support:8.4.0' @@ -561,7 +561,7 @@ dependencies { runtimeOnly 'io.dropwizard.metrics:metrics-core:4.2.19' runtimeOnly 'org.slf4j:slf4j-api:1.7.36' runtimeOnly "org.apache.logging.log4j:log4j-slf4j-impl:${versions.log4j}" - runtimeOnly 'org.xerial.snappy:snappy-java:1.1.10.4' + runtimeOnly 'org.xerial.snappy:snappy-java:1.1.10.5' runtimeOnly 'org.codehaus.woodstox:stax2-api:4.2.1' runtimeOnly "org.glassfish.jaxb:txw2:${jaxb_version}" runtimeOnly 'com.fasterxml.woodstox:woodstox-core:6.5.1' @@ -629,7 +629,7 @@ dependencies { integrationTestImplementation 'junit:junit:4.13.2' integrationTestImplementation "org.opensearch.plugin:reindex-client:${opensearch_version}" integrationTestImplementation "org.opensearch.plugin:percolator-client:${opensearch_version}" - integrationTestImplementation 'commons-io:commons-io:2.13.0' + integrationTestImplementation 'commons-io:commons-io:2.14.0' integrationTestImplementation "org.apache.logging.log4j:log4j-core:${versions.log4j}" integrationTestImplementation "org.apache.logging.log4j:log4j-jul:${versions.log4j}" integrationTestImplementation 'org.hamcrest:hamcrest:2.2' diff --git a/bwc-test/build.gradle b/bwc-test/build.gradle index 24cc645ba1..6fb7fc2348 100644 --- a/bwc-test/build.gradle +++ b/bwc-test/build.gradle @@ -47,6 +47,7 @@ buildscript { opensearch_version = System.getProperty("opensearch.version", "3.0.0-SNAPSHOT") opensearch_group = "org.opensearch" common_utils_version = System.getProperty("common_utils.version", '2.9.0.0-SNAPSHOT') + jackson_version = System.getProperty("jackson_version", "2.15.2") } repositories { mavenLocal() @@ -72,6 +73,9 @@ dependencies { testImplementation "org.opensearch.test:framework:${opensearch_version}" testImplementation "org.apache.logging.log4j:log4j-core:${versions.log4j}" testImplementation "org.opensearch:common-utils:${common_utils_version}" + testImplementation "com.fasterxml.jackson.core:jackson-databind:${jackson_version}" + testImplementation "com.fasterxml.jackson.core:jackson-annotations:${jackson_version}" + } loggerUsageCheck.enabled = false diff --git a/bwc-test/src/test/java/SecurityBackwardsCompatibilityIT.java b/bwc-test/src/test/java/SecurityBackwardsCompatibilityIT.java deleted file mode 100644 index 3758b43265..0000000000 --- a/bwc-test/src/test/java/SecurityBackwardsCompatibilityIT.java +++ /dev/null @@ -1,205 +0,0 @@ -/* - * SPDX-License-Identifier: Apache-2.0 - * - * The OpenSearch Contributors require contributions made to - * this file be licensed under the Apache-2.0 license or a - * compatible open source license. - */ -package org.opensearch.security.bwc; - -import java.io.IOException; -import java.util.List; -import java.util.Map; -import java.util.Optional; -import java.util.Set; -import java.util.stream.Collectors; - -import org.apache.hc.client5.http.auth.AuthScope; -import org.apache.hc.client5.http.auth.UsernamePasswordCredentials; -import org.apache.hc.client5.http.impl.auth.BasicCredentialsProvider; -import org.apache.hc.client5.http.impl.nio.PoolingAsyncClientConnectionManagerBuilder; -import org.apache.hc.client5.http.nio.AsyncClientConnectionManager; -import org.apache.hc.client5.http.ssl.ClientTlsStrategyBuilder; -import org.apache.hc.client5.http.ssl.NoopHostnameVerifier; -import org.apache.hc.core5.function.Factory; -import org.apache.hc.core5.http.Header; -import org.apache.hc.core5.http.HttpHost; -import org.apache.hc.core5.http.message.BasicHeader; -import org.apache.hc.core5.http.nio.ssl.TlsStrategy; -import org.apache.hc.core5.reactor.ssl.TlsDetails; -import org.apache.hc.core5.ssl.SSLContextBuilder; -import org.junit.Assume; -import org.junit.Before; -import org.opensearch.common.settings.Settings; -import org.opensearch.common.util.concurrent.ThreadContext; -import org.opensearch.test.rest.OpenSearchRestTestCase; - -import org.opensearch.Version; -import org.opensearch.common.settings.Settings; -import org.opensearch.test.rest.OpenSearchRestTestCase; - -import static org.hamcrest.MatcherAssert.assertThat; -import static org.hamcrest.Matchers.hasItem; - -import org.opensearch.client.RestClient; -import org.opensearch.client.RestClientBuilder; - -import org.junit.Assert; - -import javax.net.ssl.SSLContext; -import javax.net.ssl.SSLEngine; - -public class SecurityBackwardsCompatibilityIT extends OpenSearchRestTestCase { - - private ClusterType CLUSTER_TYPE; - private String CLUSTER_NAME; - - @Before - private void testSetup() { - final String bwcsuiteString = System.getProperty("tests.rest.bwcsuite"); - Assume.assumeTrue("Test cannot be run outside the BWC gradle task 'bwcTestSuite' or its dependent tasks", bwcsuiteString != null); - CLUSTER_TYPE = ClusterType.parse(bwcsuiteString); - CLUSTER_NAME = System.getProperty("tests.clustername"); - } - - @Override - protected final boolean preserveClusterUponCompletion() { - return true; - } - - @Override - protected final boolean preserveIndicesUponCompletion() { - return true; - } - - @Override - protected final boolean preserveReposUponCompletion() { - return true; - } - - @Override - protected boolean preserveTemplatesUponCompletion() { - return true; - } - - @Override - protected String getProtocol() { - return "https"; - } - - @Override - protected final Settings restClientSettings() { - return Settings.builder() - .put(super.restClientSettings()) - // increase the timeout here to 90 seconds to handle long waits for a green - // cluster health. the waits for green need to be longer than a minute to - // account for delayed shards - .put(OpenSearchRestTestCase.CLIENT_SOCKET_TIMEOUT, "90s") - .build(); - } - - @Override - protected RestClient buildClient(Settings settings, HttpHost[] hosts) throws IOException { - RestClientBuilder builder = RestClient.builder(hosts); - configureHttpsClient(builder, settings); - boolean strictDeprecationMode = settings.getAsBoolean("strictDeprecationMode", true); - builder.setStrictDeprecationMode(strictDeprecationMode); - return builder.build(); - } - - protected static void configureHttpsClient(RestClientBuilder builder, Settings settings) throws IOException { - Map headers = ThreadContext.buildDefaultHeaders(settings); - Header[] defaultHeaders = new Header[headers.size()]; - int i = 0; - for (Map.Entry entry : headers.entrySet()) { - defaultHeaders[i++] = new BasicHeader(entry.getKey(), entry.getValue()); - } - builder.setDefaultHeaders(defaultHeaders); - builder.setHttpClientConfigCallback(httpClientBuilder -> { - String userName = Optional.ofNullable(System.getProperty("tests.opensearch.username")) - .orElseThrow(() -> new RuntimeException("user name is missing")); - String password = Optional.ofNullable(System.getProperty("tests.opensearch.password")) - .orElseThrow(() -> new RuntimeException("password is missing")); - BasicCredentialsProvider credentialsProvider = new BasicCredentialsProvider(); - credentialsProvider.setCredentials(new AuthScope(null, -1), new UsernamePasswordCredentials(userName, password.toCharArray())); - try { - SSLContext sslContext = SSLContextBuilder.create().loadTrustMaterial(null, (chains, authType) -> true).build(); - - TlsStrategy tlsStrategy = ClientTlsStrategyBuilder.create() - .setSslContext(sslContext) - .setTlsVersions(new String[] { "TLSv1", "TLSv1.1", "TLSv1.2", "SSLv3" }) - .setHostnameVerifier(NoopHostnameVerifier.INSTANCE) - // See please https://issues.apache.org/jira/browse/HTTPCLIENT-2219 - .setTlsDetailsFactory(new Factory() { - @Override - public TlsDetails create(final SSLEngine sslEngine) { - return new TlsDetails(sslEngine.getSession(), sslEngine.getApplicationProtocol()); - } - }) - .build(); - - final AsyncClientConnectionManager cm = PoolingAsyncClientConnectionManagerBuilder.create() - .setTlsStrategy(tlsStrategy) - .build(); - return httpClientBuilder.setDefaultCredentialsProvider(credentialsProvider).setConnectionManager(cm); - } catch (Exception e) { - throw new RuntimeException(e); - } - }); - } - - public void testBasicBackwardsCompatibility() throws Exception { - String round = System.getProperty("tests.rest.bwcsuite_round"); - - if (round.equals("first") || round.equals("old")) { - assertPluginUpgrade("_nodes/" + CLUSTER_NAME + "-0/plugins"); - } else if (round.equals("second")) { - assertPluginUpgrade("_nodes/" + CLUSTER_NAME + "-1/plugins"); - } else if (round.equals("third")) { - assertPluginUpgrade("_nodes/" + CLUSTER_NAME + "-2/plugins"); - } - } - - @SuppressWarnings("unchecked") - public void testWhoAmI() throws Exception { - Map responseMap = (Map) getAsMap("_plugins/_security/whoami"); - Assert.assertTrue(responseMap.containsKey("dn")); - } - - private enum ClusterType { - OLD, - MIXED, - UPGRADED; - - public static ClusterType parse(String value) { - switch (value) { - case "old_cluster": - return OLD; - case "mixed_cluster": - return MIXED; - case "upgraded_cluster": - return UPGRADED; - default: - throw new AssertionError("unknown cluster type: " + value); - } - } - } - - @SuppressWarnings("unchecked") - private void assertPluginUpgrade(String uri) throws Exception { - Map> responseMap = (Map>) getAsMap(uri).get("nodes"); - for (Map response : responseMap.values()) { - List> plugins = (List>) response.get("plugins"); - Set pluginNames = plugins.stream().map(map -> (String) map.get("name")).collect(Collectors.toSet()); - - final Version minNodeVersion = this.minimumNodeVersion(); - - if (minNodeVersion.major <= 1) { - assertThat(pluginNames, hasItem("opensearch_security")); - } else { - assertThat(pluginNames, hasItem("opensearch-security")); - } - - } - } -} diff --git a/bwc-test/src/test/java/org/opensearch/security/bwc/ClusterType.java b/bwc-test/src/test/java/org/opensearch/security/bwc/ClusterType.java new file mode 100644 index 0000000000..7fe849d5b3 --- /dev/null +++ b/bwc-test/src/test/java/org/opensearch/security/bwc/ClusterType.java @@ -0,0 +1,28 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.security.bwc; + +public enum ClusterType { + OLD, + MIXED, + UPGRADED; + + public static ClusterType parse(String value) { + switch (value) { + case "old_cluster": + return OLD; + case "mixed_cluster": + return MIXED; + case "upgraded_cluster": + return UPGRADED; + default: + throw new AssertionError("unknown cluster type: " + value); + } + } +} diff --git a/bwc-test/src/test/java/org/opensearch/security/bwc/SecurityBackwardsCompatibilityIT.java b/bwc-test/src/test/java/org/opensearch/security/bwc/SecurityBackwardsCompatibilityIT.java new file mode 100644 index 0000000000..1647dbb132 --- /dev/null +++ b/bwc-test/src/test/java/org/opensearch/security/bwc/SecurityBackwardsCompatibilityIT.java @@ -0,0 +1,367 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.security.bwc; + +import java.io.IOException; +import java.util.HashMap; +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.Optional; +import java.util.Set; +import java.util.stream.Collectors; +import javax.net.ssl.SSLContext; + +import com.fasterxml.jackson.databind.ObjectMapper; +import org.apache.hc.client5.http.auth.AuthScope; +import org.apache.hc.client5.http.auth.UsernamePasswordCredentials; +import org.apache.hc.client5.http.impl.auth.BasicCredentialsProvider; +import org.apache.hc.client5.http.impl.nio.PoolingAsyncClientConnectionManagerBuilder; +import org.apache.hc.client5.http.nio.AsyncClientConnectionManager; +import org.apache.hc.client5.http.ssl.ClientTlsStrategyBuilder; +import org.apache.hc.client5.http.ssl.NoopHostnameVerifier; +import org.apache.hc.core5.http.Header; +import org.apache.hc.core5.http.HttpHost; +import org.apache.hc.core5.http.message.BasicHeader; +import org.apache.hc.core5.http.nio.ssl.TlsStrategy; +import org.apache.hc.core5.reactor.ssl.TlsDetails; +import org.apache.hc.core5.ssl.SSLContextBuilder; +import org.junit.AfterClass; +import org.junit.Assert; +import org.junit.Assume; +import org.junit.Before; +import org.opensearch.client.Response; +import org.opensearch.client.ResponseException; +import org.opensearch.client.RestClient; +import org.opensearch.client.RestClientBuilder; +import org.opensearch.common.Randomness; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.common.util.io.IOUtils; +import org.opensearch.security.bwc.helper.RestHelper; +import org.opensearch.test.rest.OpenSearchRestTestCase; +import org.opensearch.Version; + +import static org.hamcrest.Matchers.hasItem; +import static org.hamcrest.Matchers.hasKey; +import static org.hamcrest.Matchers.anyOf; +import static org.hamcrest.Matchers.equalTo; + +public class SecurityBackwardsCompatibilityIT extends OpenSearchRestTestCase { + + private ClusterType CLUSTER_TYPE; + private String CLUSTER_NAME; + + private final String TEST_USER = "user"; + private final String TEST_PASSWORD = "290735c0-355d-4aaf-9b42-1aaa1f2a3cee"; + private final String TEST_ROLE = "test-dls-fls-role"; + private static RestClient testUserRestClient = null; + + @Before + public void testSetup() { + final String bwcsuiteString = System.getProperty("tests.rest.bwcsuite"); + Assume.assumeTrue("Test cannot be run outside the BWC gradle task 'bwcTestSuite' or its dependent tasks", bwcsuiteString != null); + CLUSTER_TYPE = ClusterType.parse(bwcsuiteString); + CLUSTER_NAME = System.getProperty("tests.clustername"); + if (testUserRestClient == null) { + testUserRestClient = buildClient( + super.restClientSettings(), + super.getClusterHosts().toArray(new HttpHost[0]), + TEST_USER, + TEST_PASSWORD + ); + } + } + + @Override + protected final boolean preserveClusterUponCompletion() { + return true; + } + + @Override + protected final boolean preserveIndicesUponCompletion() { + return true; + } + + @Override + protected final boolean preserveReposUponCompletion() { + return true; + } + + @Override + protected boolean preserveTemplatesUponCompletion() { + return true; + } + + @Override + protected String getProtocol() { + return "https"; + } + + @Override + protected final Settings restClientSettings() { + return Settings.builder() + .put(super.restClientSettings()) + // increase the timeout here to 90 seconds to handle long waits for a green + // cluster health. the waits for green need to be longer than a minute to + // account for delayed shards + .put(OpenSearchRestTestCase.CLIENT_SOCKET_TIMEOUT, "90s") + .build(); + } + + protected RestClient buildClient(Settings settings, HttpHost[] hosts, String username, String password) { + RestClientBuilder builder = RestClient.builder(hosts); + configureHttpsClient(builder, settings, username, password); + boolean strictDeprecationMode = settings.getAsBoolean("strictDeprecationMode", true); + builder.setStrictDeprecationMode(strictDeprecationMode); + return builder.build(); + } + + @Override + protected RestClient buildClient(Settings settings, HttpHost[] hosts) { + String username = Optional.ofNullable(System.getProperty("tests.opensearch.username")) + .orElseThrow(() -> new RuntimeException("user name is missing")); + String password = Optional.ofNullable(System.getProperty("tests.opensearch.password")) + .orElseThrow(() -> new RuntimeException("password is missing")); + return buildClient(super.restClientSettings(), super.getClusterHosts().toArray(new HttpHost[0]), username, password); + } + + private static void configureHttpsClient(RestClientBuilder builder, Settings settings, String userName, String password) { + Map headers = ThreadContext.buildDefaultHeaders(settings); + Header[] defaultHeaders = new Header[headers.size()]; + int i = 0; + for (Map.Entry entry : headers.entrySet()) { + defaultHeaders[i++] = new BasicHeader(entry.getKey(), entry.getValue()); + } + builder.setDefaultHeaders(defaultHeaders); + builder.setHttpClientConfigCallback(httpClientBuilder -> { + BasicCredentialsProvider credentialsProvider = new BasicCredentialsProvider(); + credentialsProvider.setCredentials(new AuthScope(null, -1), new UsernamePasswordCredentials(userName, password.toCharArray())); + try { + SSLContext sslContext = SSLContextBuilder.create().loadTrustMaterial(null, (chains, authType) -> true).build(); + + TlsStrategy tlsStrategy = ClientTlsStrategyBuilder.create() + .setSslContext(sslContext) + .setTlsVersions(new String[] { "TLSv1", "TLSv1.1", "TLSv1.2", "SSLv3" }) + .setHostnameVerifier(NoopHostnameVerifier.INSTANCE) + // See please https://issues.apache.org/jira/browse/HTTPCLIENT-2219 + .setTlsDetailsFactory(sslEngine -> new TlsDetails(sslEngine.getSession(), sslEngine.getApplicationProtocol())) + .build(); + + final AsyncClientConnectionManager cm = PoolingAsyncClientConnectionManagerBuilder.create() + .setTlsStrategy(tlsStrategy) + .build(); + return httpClientBuilder.setDefaultCredentialsProvider(credentialsProvider).setConnectionManager(cm); + } catch (Exception e) { + throw new RuntimeException(e); + } + }); + } + + public void testWhoAmI() throws Exception { + Map responseMap = getAsMap("_plugins/_security/whoami"); + assertThat(responseMap, hasKey("dn")); + } + + public void testBasicBackwardsCompatibility() throws Exception { + String round = System.getProperty("tests.rest.bwcsuite_round"); + + if (round.equals("first") || round.equals("old")) { + assertPluginUpgrade("_nodes/" + CLUSTER_NAME + "-0/plugins"); + } else if (round.equals("second")) { + assertPluginUpgrade("_nodes/" + CLUSTER_NAME + "-1/plugins"); + } else if (round.equals("third")) { + assertPluginUpgrade("_nodes/" + CLUSTER_NAME + "-2/plugins"); + } + } + + /** + * Tests backward compatibility by created a test user and role with DLS, FLS and masked field settings. Ingests + * data into a test index and runs a matchAll query against the same. + */ + public void testDataIngestionAndSearchBackwardsCompatibility() throws Exception { + String round = System.getProperty("tests.rest.bwcsuite_round"); + String index = "test_index"; + if (round.equals("old")) { + createTestRoleIfNotExists(TEST_ROLE); + createUserIfNotExists(TEST_USER, TEST_PASSWORD, TEST_ROLE); + createIndexIfNotExists(index); + } + ingestData(index); + searchMatchAll(index); + } + + public void testNodeStats() throws IOException { + List responses = RestHelper.requestAgainstAllNodes(client(), "GET", "_nodes/stats", null); + responses.forEach(r -> Assert.assertEquals(200, r.getStatusLine().getStatusCode())); + } + + @SuppressWarnings("unchecked") + private void assertPluginUpgrade(String uri) throws Exception { + Map> responseMap = (Map>) getAsMap(uri).get("nodes"); + for (Map response : responseMap.values()) { + List> plugins = (List>) response.get("plugins"); + Set pluginNames = plugins.stream().map(map -> (String) map.get("name")).collect(Collectors.toSet()); + + final Version minNodeVersion = minimumNodeVersion(); + + if (minNodeVersion.major <= 1) { + assertThat(pluginNames, hasItem("opensearch_security")); // With underscore seperator + } else { + assertThat(pluginNames, hasItem("opensearch-security")); // With dash seperator + } + } + } + + /** + * Ingests data into the test index + * @param index index to ingest data into + */ + + private void ingestData(String index) throws IOException { + StringBuilder bulkRequestBody = new StringBuilder(); + ObjectMapper objectMapper = new ObjectMapper(); + int numberOfRequests = Randomness.get().nextInt(10); + while (numberOfRequests-- > 0) { + for (int i = 0; i < Randomness.get().nextInt(100); i++) { + Map> indexRequest = new HashMap<>(); + indexRequest.put("index", new HashMap<>() { + { + put("_index", index); + } + }); + bulkRequestBody.append(objectMapper.writeValueAsString(indexRequest) + "\n"); + bulkRequestBody.append(objectMapper.writeValueAsString(Song.randomSong().asJson()) + "\n"); + } + List responses = RestHelper.requestAgainstAllNodes( + testUserRestClient, + "POST", + "_bulk?refresh=wait_for", + RestHelper.toHttpEntity(bulkRequestBody.toString()) + ); + responses.forEach(r -> assertEquals(200, r.getStatusLine().getStatusCode())); + } + } + + /** + * Runs a matchAll query against the test index + * @param index index to search + */ + private void searchMatchAll(String index) throws IOException { + String matchAllQuery = "{\n" + " \"query\": {\n" + " \"match_all\": {}\n" + " }\n" + "}"; + int numberOfRequests = Randomness.get().nextInt(10); + while (numberOfRequests-- > 0) { + List responses = RestHelper.requestAgainstAllNodes( + testUserRestClient, + "POST", + index + "/_search", + RestHelper.toHttpEntity(matchAllQuery) + ); + responses.forEach(r -> assertEquals(200, r.getStatusLine().getStatusCode())); + } + } + + /** + * Checks if a resource at the specified URL exists + * @param url of the resource to be checked for existence + * @return true if the resource exists, false otherwise + */ + + private boolean resourceExists(String url) throws IOException { + try { + RestHelper.get(adminClient(), url); + return true; + } catch (ResponseException e) { + if (e.getResponse().getStatusLine().getStatusCode() == 404) { + return false; + } else { + throw e; + } + } + } + + /** + * Creates a test role with DLS, FLS and masked field settings on the test index. + */ + private void createTestRoleIfNotExists(String role) throws IOException { + String url = "_plugins/_security/api/roles/" + role; + String roleSettings = "{\n" + + " \"cluster_permissions\": [\n" + + " \"unlimited\"\n" + + " ],\n" + + " \"index_permissions\": [\n" + + " {\n" + + " \"index_patterns\": [\n" + + " \"test_index*\"\n" + + " ],\n" + + " \"dls\": \"{ \\\"bool\\\": { \\\"must\\\": { \\\"match\\\": { \\\"genre\\\": \\\"rock\\\" } } } }\",\n" + + " \"fls\": [\n" + + " \"~lyrics\"\n" + + " ],\n" + + " \"masked_fields\": [\n" + + " \"artist\"\n" + + " ],\n" + + " \"allowed_actions\": [\n" + + " \"read\",\n" + + " \"write\"\n" + + " ]\n" + + " }\n" + + " ],\n" + + " \"tenant_permissions\": []\n" + + "}\n"; + Response response = RestHelper.makeRequest(adminClient(), "PUT", url, RestHelper.toHttpEntity(roleSettings)); + + assertThat(response.getStatusLine().getStatusCode(), anyOf(equalTo(200), equalTo(201))); + } + + /** + * Creates a test index if it does not exist already + * @param index index to create + */ + + private void createIndexIfNotExists(String index) throws IOException { + String settings = "{\n" + + " \"settings\": {\n" + + " \"index\": {\n" + + " \"number_of_shards\": 3,\n" + + " \"number_of_replicas\": 1\n" + + " }\n" + + " }\n" + + "}"; + if (!resourceExists(index)) { + Response response = RestHelper.makeRequest(client(), "PUT", index, RestHelper.toHttpEntity(settings)); + assertThat(response.getStatusLine().getStatusCode(), equalTo(200)); + } + } + + /** + * Creates the test user if it does not exist already and maps it to the test role with DLS/FLS settings. + * @param user user to be created + * @param password password for the new user + * @param role roles that the user has to be mapped to + */ + private void createUserIfNotExists(String user, String password, String role) throws IOException { + String url = "_plugins/_security/api/internalusers/" + user; + if (!resourceExists(url)) { + String userSettings = String.format( + Locale.ENGLISH, + "{\n" + " \"password\": \"%s\",\n" + " \"opendistro_security_roles\": [\"%s\"],\n" + " \"backend_roles\": []\n" + "}", + password, + role + ); + Response response = RestHelper.makeRequest(adminClient(), "PUT", url, RestHelper.toHttpEntity(userSettings)); + assertThat(response.getStatusLine().getStatusCode(), equalTo(201)); + } + } + + @AfterClass + public static void cleanUp() throws IOException { + OpenSearchRestTestCase.closeClients(); + IOUtils.close(testUserRestClient); + } +} diff --git a/bwc-test/src/test/java/org/opensearch/security/bwc/Song.java b/bwc-test/src/test/java/org/opensearch/security/bwc/Song.java new file mode 100644 index 0000000000..3cfd2c03e8 --- /dev/null +++ b/bwc-test/src/test/java/org/opensearch/security/bwc/Song.java @@ -0,0 +1,117 @@ +/* +* Copyright OpenSearch Contributors +* SPDX-License-Identifier: Apache-2.0 +* +* The OpenSearch Contributors require contributions made to +* this file be licensed under the Apache-2.0 license or a +* compatible open source license. +* +*/ +package org.opensearch.security.bwc; + +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.ObjectMapper; +import org.opensearch.common.Randomness; + +import java.util.Map; +import java.util.Objects; +import java.util.UUID; + +public class Song { + + public static final String FIELD_TITLE = "title"; + public static final String FIELD_ARTIST = "artist"; + public static final String FIELD_LYRICS = "lyrics"; + public static final String FIELD_STARS = "stars"; + public static final String FIELD_GENRE = "genre"; + public static final String ARTIST_FIRST = "First artist"; + public static final String ARTIST_STRING = "String"; + public static final String ARTIST_TWINS = "Twins"; + public static final String TITLE_MAGNUM_OPUS = "Magnum Opus"; + public static final String TITLE_SONG_1_PLUS_1 = "Song 1+1"; + public static final String TITLE_NEXT_SONG = "Next song"; + public static final String ARTIST_NO = "No!"; + public static final String TITLE_POISON = "Poison"; + + public static final String ARTIST_YES = "yes"; + + public static final String TITLE_AFFIRMATIVE = "Affirmative"; + + public static final String ARTIST_UNKNOWN = "unknown"; + public static final String TITLE_CONFIDENTIAL = "confidential"; + + public static final String LYRICS_1 = "Very deep subject"; + public static final String LYRICS_2 = "Once upon a time"; + public static final String LYRICS_3 = "giant nonsense"; + public static final String LYRICS_4 = "Much too much"; + public static final String LYRICS_5 = "Little to little"; + public static final String LYRICS_6 = "confidential secret classified"; + + public static final String GENRE_ROCK = "rock"; + public static final String GENRE_JAZZ = "jazz"; + public static final String GENRE_BLUES = "blues"; + + public static final String QUERY_TITLE_NEXT_SONG = FIELD_TITLE + ":" + "\"" + TITLE_NEXT_SONG + "\""; + public static final String QUERY_TITLE_POISON = FIELD_TITLE + ":" + TITLE_POISON; + public static final String QUERY_TITLE_MAGNUM_OPUS = FIELD_TITLE + ":" + TITLE_MAGNUM_OPUS; + + public static final Song[] SONGS = { + new Song(ARTIST_FIRST, TITLE_MAGNUM_OPUS, LYRICS_1, 1, GENRE_ROCK), + new Song(ARTIST_STRING, TITLE_SONG_1_PLUS_1, LYRICS_2, 2, GENRE_BLUES), + new Song(ARTIST_TWINS, TITLE_NEXT_SONG, LYRICS_3, 3, GENRE_JAZZ), + new Song(ARTIST_NO, TITLE_POISON, LYRICS_4, 4, GENRE_ROCK), + new Song(ARTIST_YES, TITLE_AFFIRMATIVE, LYRICS_5, 5, GENRE_BLUES), + new Song(ARTIST_UNKNOWN, TITLE_CONFIDENTIAL, LYRICS_6, 6, GENRE_JAZZ) }; + + private final String artist; + private final String title; + private final String lyrics; + private final Integer stars; + private final String genre; + + public Song(String artist, String title, String lyrics, Integer stars, String genre) { + this.artist = Objects.requireNonNull(artist, "Artist is required"); + this.title = Objects.requireNonNull(title, "Title is required"); + this.lyrics = Objects.requireNonNull(lyrics, "Lyrics is required"); + this.stars = Objects.requireNonNull(stars, "Stars field is required"); + this.genre = Objects.requireNonNull(genre, "Genre field is required"); + } + + public String getArtist() { + return artist; + } + + public String getTitle() { + return title; + } + + public String getLyrics() { + return lyrics; + } + + public Integer getStars() { + return stars; + } + + public String getGenre() { + return genre; + } + + public Map asMap() { + return Map.of(FIELD_ARTIST, artist, FIELD_TITLE, title, FIELD_LYRICS, lyrics, FIELD_STARS, stars, FIELD_GENRE, genre); + } + + public String asJson() throws JsonProcessingException { + return new ObjectMapper().writeValueAsString(this.asMap()); + } + + public static Song randomSong() { + return new Song( + UUID.randomUUID().toString(), + UUID.randomUUID().toString(), + UUID.randomUUID().toString(), + Randomness.get().nextInt(5), + UUID.randomUUID().toString() + ); + } +} diff --git a/bwc-test/src/test/java/org/opensearch/security/bwc/helper/RestHelper.java b/bwc-test/src/test/java/org/opensearch/security/bwc/helper/RestHelper.java new file mode 100644 index 0000000000..3272ac736a --- /dev/null +++ b/bwc-test/src/test/java/org/opensearch/security/bwc/helper/RestHelper.java @@ -0,0 +1,90 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.security.bwc.helper; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; + +import org.apache.hc.core5.http.Header; +import org.apache.hc.core5.http.HttpEntity; +import org.apache.hc.core5.http.io.entity.StringEntity; +import org.apache.hc.core5.http.message.BasicHeader; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.client.Request; +import org.opensearch.client.RequestOptions; +import org.opensearch.client.Response; +import org.opensearch.client.RestClient; +import org.opensearch.client.WarningsHandler; + +import static org.apache.hc.core5.http.ContentType.APPLICATION_JSON; + +public class RestHelper { + + private static final Logger log = LogManager.getLogger(RestHelper.class); + + public static HttpEntity toHttpEntity(String jsonString) { + return new StringEntity(jsonString, APPLICATION_JSON); + } + + public static Response get(RestClient client, String url) throws IOException { + return makeRequest(client, "GET", url, null, null); + } + + public static Response makeRequest(RestClient client, String method, String endpoint, HttpEntity entity) throws IOException { + return makeRequest(client, method, endpoint, entity, null); + } + + public static Response makeRequest(RestClient client, String method, String endpoint, HttpEntity entity, List
headers) + throws IOException { + log.info("Making request " + method + " " + endpoint + ", with headers " + headers); + + Request request = new Request(method, endpoint); + + RequestOptions.Builder options = RequestOptions.DEFAULT.toBuilder(); + options.setWarningsHandler(WarningsHandler.PERMISSIVE); + if (headers != null) { + headers.forEach(header -> options.addHeader(header.getName(), header.getValue())); + } + request.setOptions(options.build()); + + if (entity != null) { + request.setEntity(entity); + } + + Response response = client.performRequest(request); + log.info("Recieved response " + response.getStatusLine()); + return response; + } + + public static List requestAgainstAllNodes(RestClient client, String method, String endpoint, HttpEntity entity) + throws IOException { + return requestAgainstAllNodes(client, method, endpoint, entity, null); + } + + public static List requestAgainstAllNodes( + RestClient client, + String method, + String endpoint, + HttpEntity entity, + List
headers + ) throws IOException { + int nodeCount = client.getNodes().size(); + List responses = new ArrayList<>(); + while (nodeCount-- > 0) { + responses.add(makeRequest(client, method, endpoint, entity, headers)); + } + return responses; + } + + public static Header getAuthorizationHeader(String username, String password) { + return new BasicHeader("Authorization", "Basic " + username + ":" + password); + } +} diff --git a/src/integrationTest/java/org/opensearch/security/IpBruteForceAttacksPreventionTests.java b/src/integrationTest/java/org/opensearch/security/IpBruteForceAttacksPreventionTests.java index bb16e0be1b..34e79613f6 100644 --- a/src/integrationTest/java/org/opensearch/security/IpBruteForceAttacksPreventionTests.java +++ b/src/integrationTest/java/org/opensearch/security/IpBruteForceAttacksPreventionTests.java @@ -12,7 +12,6 @@ import java.util.concurrent.TimeUnit; import com.carrotsearch.randomizedtesting.annotations.ThreadLeakScope; -import org.junit.ClassRule; import org.junit.Rule; import org.junit.Test; import org.junit.runner.RunWith; @@ -36,8 +35,8 @@ @RunWith(com.carrotsearch.randomizedtesting.RandomizedRunner.class) @ThreadLeakScope(ThreadLeakScope.Scope.NONE) public class IpBruteForceAttacksPreventionTests { - private static final User USER_1 = new User("simple-user-1").roles(ALL_ACCESS); - private static final User USER_2 = new User("simple-user-2").roles(ALL_ACCESS); + static final User USER_1 = new User("simple-user-1").roles(ALL_ACCESS); + static final User USER_2 = new User("simple-user-2").roles(ALL_ACCESS); public static final int ALLOWED_TRIES = 3; public static final int TIME_WINDOW_SECONDS = 3; @@ -51,7 +50,7 @@ public class IpBruteForceAttacksPreventionTests { public static final String CLIENT_IP_8 = "127.0.0.8"; public static final String CLIENT_IP_9 = "127.0.0.9"; - private static final AuthFailureListeners listener = new AuthFailureListeners().addRateLimit( + static final AuthFailureListeners listener = new AuthFailureListeners().addRateLimit( new RateLimiting("internal_authentication_backend_limiting").type("ip") .allowedTries(ALLOWED_TRIES) .timeWindowSeconds(TIME_WINDOW_SECONDS) @@ -60,13 +59,17 @@ public class IpBruteForceAttacksPreventionTests { .maxTrackedClients(500) ); - @ClassRule - public static final LocalCluster cluster = new LocalCluster.Builder().clusterManager(ClusterManager.SINGLENODE) - .anonymousAuth(false) - .authFailureListeners(listener) - .authc(AUTHC_HTTPBASIC_INTERNAL_WITHOUT_CHALLENGE) - .users(USER_1, USER_2) - .build(); + @Rule + public LocalCluster cluster = createCluster(); + + public LocalCluster createCluster() { + return new LocalCluster.Builder().clusterManager(ClusterManager.SINGLENODE) + .anonymousAuth(false) + .authFailureListeners(listener) + .authc(AUTHC_HTTPBASIC_INTERNAL_WITHOUT_CHALLENGE) + .users(USER_1, USER_2) + .build(); + } @Rule public LogsRule logsRule = new LogsRule("org.opensearch.security.auth.BackendRegistry"); @@ -151,7 +154,7 @@ public void shouldReleaseIpAddressLock() throws InterruptedException { } } - private static void authenticateUserWithIncorrectPassword(String sourceIpAddress, User user, int numberOfRequests) { + void authenticateUserWithIncorrectPassword(String sourceIpAddress, User user, int numberOfRequests) { var clientConfiguration = new TestRestClientConfiguration().username(user.getName()) .password("incorrect password") .sourceInetAddress(sourceIpAddress); diff --git a/src/integrationTest/java/org/opensearch/security/IpBruteForceAttacksPreventionWithDomainChallengeTests.java b/src/integrationTest/java/org/opensearch/security/IpBruteForceAttacksPreventionWithDomainChallengeTests.java new file mode 100644 index 0000000000..6159599119 --- /dev/null +++ b/src/integrationTest/java/org/opensearch/security/IpBruteForceAttacksPreventionWithDomainChallengeTests.java @@ -0,0 +1,32 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + */ + +package org.opensearch.security; + +import com.carrotsearch.randomizedtesting.annotations.ThreadLeakScope; +import org.junit.runner.RunWith; +import org.opensearch.test.framework.cluster.ClusterManager; +import org.opensearch.test.framework.cluster.LocalCluster; + +import static org.opensearch.test.framework.TestSecurityConfig.AuthcDomain.AUTHC_HTTPBASIC_INTERNAL; + +@RunWith(com.carrotsearch.randomizedtesting.RandomizedRunner.class) +@ThreadLeakScope(ThreadLeakScope.Scope.NONE) +public class IpBruteForceAttacksPreventionWithDomainChallengeTests extends IpBruteForceAttacksPreventionTests { + @Override + public LocalCluster createCluster() { + return new LocalCluster.Builder().clusterManager(ClusterManager.SINGLENODE) + .anonymousAuth(false) + .authFailureListeners(listener) + .authc(AUTHC_HTTPBASIC_INTERNAL) + .users(USER_1, USER_2) + .build(); + } +} diff --git a/src/integrationTest/java/org/opensearch/test/framework/cluster/LocalOpenSearchCluster.java b/src/integrationTest/java/org/opensearch/test/framework/cluster/LocalOpenSearchCluster.java index c09127e592..77890a4645 100644 --- a/src/integrationTest/java/org/opensearch/test/framework/cluster/LocalOpenSearchCluster.java +++ b/src/integrationTest/java/org/opensearch/test/framework/cluster/LocalOpenSearchCluster.java @@ -344,7 +344,7 @@ public String toString() { String clusterManagerNodes = nodeByTypeToString(CLUSTER_MANAGER); String dataNodes = nodeByTypeToString(DATA); String clientNodes = nodeByTypeToString(CLIENT); - return "\nES Cluster " + return "\nOS Cluster " + clusterName + "\ncluster manager nodes: " + clusterManagerNodes diff --git a/src/main/java/com/amazon/dlic/auth/http/jwt/AbstractHTTPJwtAuthenticator.java b/src/main/java/com/amazon/dlic/auth/http/jwt/AbstractHTTPJwtAuthenticator.java index 4dab3c7740..6dbb7b7676 100644 --- a/src/main/java/com/amazon/dlic/auth/http/jwt/AbstractHTTPJwtAuthenticator.java +++ b/src/main/java/com/amazon/dlic/auth/http/jwt/AbstractHTTPJwtAuthenticator.java @@ -36,7 +36,6 @@ import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.core.common.Strings; import org.opensearch.rest.BytesRestResponse; -import org.opensearch.rest.RestChannel; import org.opensearch.rest.RestRequest; import org.opensearch.core.rest.RestStatus; import org.opensearch.security.auth.HTTPAuthenticator; @@ -236,11 +235,10 @@ public String[] extractRoles(JwtClaims claims) { protected abstract KeyProvider initKeyProvider(Settings settings, Path configPath) throws Exception; @Override - public boolean reRequestAuthentication(RestChannel channel, AuthCredentials authCredentials) { + public BytesRestResponse reRequestAuthentication(RestRequest request, AuthCredentials authCredentials) { final BytesRestResponse wwwAuthenticateResponse = new BytesRestResponse(RestStatus.UNAUTHORIZED, ""); wwwAuthenticateResponse.addHeader("WWW-Authenticate", "Bearer realm=\"OpenSearch Security\""); - channel.sendResponse(wwwAuthenticateResponse); - return true; + return wwwAuthenticateResponse; } public String getRequiredAudience() { diff --git a/src/main/java/com/amazon/dlic/auth/http/jwt/HTTPJwtAuthenticator.java b/src/main/java/com/amazon/dlic/auth/http/jwt/HTTPJwtAuthenticator.java index 03e385d5c0..338a490037 100644 --- a/src/main/java/com/amazon/dlic/auth/http/jwt/HTTPJwtAuthenticator.java +++ b/src/main/java/com/amazon/dlic/auth/http/jwt/HTTPJwtAuthenticator.java @@ -31,7 +31,6 @@ import org.opensearch.common.settings.Settings; import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.rest.BytesRestResponse; -import org.opensearch.rest.RestChannel; import org.opensearch.rest.RestRequest; import org.opensearch.core.rest.RestStatus; import org.opensearch.security.auth.HTTPAuthenticator; @@ -171,11 +170,10 @@ private AuthCredentials extractCredentials0(final RestRequest request) { } @Override - public boolean reRequestAuthentication(final RestChannel channel, AuthCredentials creds) { + public BytesRestResponse reRequestAuthentication(RestRequest request, AuthCredentials credentials) { final BytesRestResponse wwwAuthenticateResponse = new BytesRestResponse(RestStatus.UNAUTHORIZED, ""); wwwAuthenticateResponse.addHeader("WWW-Authenticate", "Bearer realm=\"OpenSearch Security\""); - channel.sendResponse(wwwAuthenticateResponse); - return true; + return wwwAuthenticateResponse; } @Override diff --git a/src/main/java/com/amazon/dlic/auth/http/kerberos/HTTPSpnegoAuthenticator.java b/src/main/java/com/amazon/dlic/auth/http/kerberos/HTTPSpnegoAuthenticator.java index 29f537e899..99bc23b746 100644 --- a/src/main/java/com/amazon/dlic/auth/http/kerberos/HTTPSpnegoAuthenticator.java +++ b/src/main/java/com/amazon/dlic/auth/http/kerberos/HTTPSpnegoAuthenticator.java @@ -49,7 +49,6 @@ import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.env.Environment; import org.opensearch.rest.BytesRestResponse; -import org.opensearch.rest.RestChannel; import org.opensearch.rest.RestRequest; import org.opensearch.core.rest.RestStatus; import org.opensearch.security.auth.HTTPAuthenticator; @@ -280,8 +279,7 @@ public GSSCredential run() throws GSSException { } @Override - public boolean reRequestAuthentication(final RestChannel channel, AuthCredentials creds) { - + public BytesRestResponse reRequestAuthentication(RestRequest request, AuthCredentials credentials) { final BytesRestResponse wwwAuthenticateResponse; XContentBuilder response = getNegotiateResponseBody(); @@ -291,16 +289,15 @@ public boolean reRequestAuthentication(final RestChannel channel, AuthCredential wwwAuthenticateResponse = new BytesRestResponse(RestStatus.UNAUTHORIZED, EMPTY_STRING); } - if (creds == null || creds.getNativeCredentials() == null) { + if (credentials == null || credentials.getNativeCredentials() == null) { wwwAuthenticateResponse.addHeader("WWW-Authenticate", "Negotiate"); } else { wwwAuthenticateResponse.addHeader( "WWW-Authenticate", - "Negotiate " + Base64.getEncoder().encodeToString((byte[]) creds.getNativeCredentials()) + "Negotiate " + Base64.getEncoder().encodeToString((byte[]) credentials.getNativeCredentials()) ); } - channel.sendResponse(wwwAuthenticateResponse); - return true; + return wwwAuthenticateResponse; } @Override diff --git a/src/main/java/com/amazon/dlic/auth/http/saml/AuthTokenProcessorHandler.java b/src/main/java/com/amazon/dlic/auth/http/saml/AuthTokenProcessorHandler.java index 3fee4a9444..7210ed5950 100644 --- a/src/main/java/com/amazon/dlic/auth/http/saml/AuthTokenProcessorHandler.java +++ b/src/main/java/com/amazon/dlic/auth/http/saml/AuthTokenProcessorHandler.java @@ -23,7 +23,6 @@ import java.util.regex.Pattern; import java.util.stream.Collectors; -import javax.xml.parsers.ParserConfigurationException; import javax.xml.xpath.XPathExpressionException; import com.fasterxml.jackson.core.JsonParseException; @@ -32,7 +31,6 @@ import com.fasterxml.jackson.databind.node.ObjectNode; import com.google.common.base.Strings; import com.onelogin.saml2.authn.SamlResponse; -import com.onelogin.saml2.exception.SettingsException; import com.onelogin.saml2.exception.ValidationError; import com.onelogin.saml2.settings.Saml2Settings; import com.onelogin.saml2.util.Util; @@ -49,7 +47,6 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.joda.time.DateTime; -import org.xml.sax.SAXException; import org.opensearch.OpenSearchSecurityException; import org.opensearch.SpecialPermission; @@ -57,7 +54,6 @@ import org.opensearch.common.settings.Settings; import org.opensearch.common.xcontent.XContentType; import org.opensearch.rest.BytesRestResponse; -import org.opensearch.rest.RestChannel; import org.opensearch.rest.RestRequest; import org.opensearch.rest.RestRequest.Method; import org.opensearch.core.rest.RestStatus; @@ -122,7 +118,7 @@ class AuthTokenProcessorHandler { } @SuppressWarnings("removal") - boolean handle(RestRequest restRequest, RestChannel restChannel) throws Exception { + BytesRestResponse handle(RestRequest restRequest) throws Exception { try { final SecurityManager sm = System.getSecurityManager(); @@ -130,11 +126,10 @@ boolean handle(RestRequest restRequest, RestChannel restChannel) throws Exceptio sm.checkPermission(new SpecialPermission()); } - return AccessController.doPrivileged(new PrivilegedExceptionAction() { + return AccessController.doPrivileged(new PrivilegedExceptionAction() { @Override - public Boolean run() throws XPathExpressionException, SamlConfigException, IOException, ParserConfigurationException, - SAXException, SettingsException { - return handleLowLevel(restRequest, restChannel); + public BytesRestResponse run() throws SamlConfigException, IOException { + return handleLowLevel(restRequest); } }); } catch (PrivilegedActionException e) { @@ -147,13 +142,11 @@ public Boolean run() throws XPathExpressionException, SamlConfigException, IOExc } private AuthTokenProcessorAction.Response handleImpl( - RestRequest restRequest, - RestChannel restChannel, String samlResponseBase64, String samlRequestId, String acsEndpoint, Saml2Settings saml2Settings - ) throws XPathExpressionException, ParserConfigurationException, SAXException, IOException, SettingsException { + ) { if (token_log.isDebugEnabled()) { try { token_log.debug( @@ -188,8 +181,7 @@ private AuthTokenProcessorAction.Response handleImpl( } } - private boolean handleLowLevel(RestRequest restRequest, RestChannel restChannel) throws SamlConfigException, IOException, - XPathExpressionException, ParserConfigurationException, SAXException, SettingsException { + private BytesRestResponse handleLowLevel(RestRequest restRequest) throws SamlConfigException, IOException { try { if (restRequest.getMediaType() != XContentType.JSON) { @@ -234,31 +226,18 @@ private boolean handleLowLevel(RestRequest restRequest, RestChannel restChannel) acsEndpoint = getAbsoluteAcsEndpoint(((ObjectNode) jsonRoot).get("acsEndpoint").textValue()); } - AuthTokenProcessorAction.Response responseBody = this.handleImpl( - restRequest, - restChannel, - samlResponseBase64, - samlRequestId, - acsEndpoint, - saml2Settings - ); + AuthTokenProcessorAction.Response responseBody = this.handleImpl(samlResponseBase64, samlRequestId, acsEndpoint, saml2Settings); if (responseBody == null) { - return false; + return null; } String responseBodyString = DefaultObjectMapper.objectMapper.writeValueAsString(responseBody); - BytesRestResponse authenticateResponse = new BytesRestResponse(RestStatus.OK, "application/json", responseBodyString); - restChannel.sendResponse(authenticateResponse); - - return true; + return new BytesRestResponse(RestStatus.OK, "application/json", responseBodyString); } catch (JsonProcessingException e) { log.warn("Error while parsing JSON for /_opendistro/_security/api/authtoken", e); - - BytesRestResponse authenticateResponse = new BytesRestResponse(RestStatus.BAD_REQUEST, "JSON could not be parsed"); - restChannel.sendResponse(authenticateResponse); - return true; + return new BytesRestResponse(RestStatus.BAD_REQUEST, "JSON could not be parsed"); } } diff --git a/src/main/java/com/amazon/dlic/auth/http/saml/HTTPSamlAuthenticator.java b/src/main/java/com/amazon/dlic/auth/http/saml/HTTPSamlAuthenticator.java index cd6209952f..64d816fabf 100644 --- a/src/main/java/com/amazon/dlic/auth/http/saml/HTTPSamlAuthenticator.java +++ b/src/main/java/com/amazon/dlic/auth/http/saml/HTTPSamlAuthenticator.java @@ -56,7 +56,6 @@ import org.opensearch.common.settings.Settings; import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.rest.BytesRestResponse; -import org.opensearch.rest.RestChannel; import org.opensearch.rest.RestRequest; import org.opensearch.core.rest.RestStatus; import org.opensearch.security.auth.Destroyable; @@ -171,13 +170,15 @@ public String getType() { } @Override - public boolean reRequestAuthentication(RestChannel restChannel, AuthCredentials authCredentials) { + public BytesRestResponse reRequestAuthentication(RestRequest request, AuthCredentials credentials) { try { - RestRequest restRequest = restChannel.request(); - Matcher matcher = PATTERN_PATH_PREFIX.matcher(restRequest.path()); + Matcher matcher = PATTERN_PATH_PREFIX.matcher(request.path()); final String suffix = matcher.matches() ? matcher.group(2) : null; - if (API_AUTHTOKEN_SUFFIX.equals(suffix) && this.authTokenProcessorHandler.handle(restRequest, restChannel)) { - return true; + if (API_AUTHTOKEN_SUFFIX.equals(suffix)) { + final BytesRestResponse restResponse = this.authTokenProcessorHandler.handle(request); + if (restResponse != null) { + return restResponse; + } } Saml2Settings saml2Settings = this.saml2SettingsProvider.getCached(); @@ -185,12 +186,10 @@ public boolean reRequestAuthentication(RestChannel restChannel, AuthCredentials authenticateResponse.addHeader("WWW-Authenticate", getWwwAuthenticateHeader(saml2Settings)); - restChannel.sendResponse(authenticateResponse); - - return true; + return authenticateResponse; } catch (Exception e) { log.error("Error in reRequestAuthentication()", e); - return false; + return null; } } diff --git a/src/main/java/com/amazon/dlic/auth/ldap/LdapUser.java b/src/main/java/com/amazon/dlic/auth/ldap/LdapUser.java index 907d605860..f752ce4a49 100755 --- a/src/main/java/com/amazon/dlic/auth/ldap/LdapUser.java +++ b/src/main/java/com/amazon/dlic/auth/ldap/LdapUser.java @@ -11,6 +11,7 @@ package com.amazon.dlic.auth.ldap; +import java.io.IOException; import java.util.Collections; import java.util.HashMap; import java.util.Map; @@ -20,6 +21,8 @@ import com.amazon.dlic.auth.ldap.util.Utils; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.security.support.WildcardMatcher; import org.opensearch.security.user.AuthCredentials; import org.opensearch.security.user.User; @@ -45,6 +48,12 @@ public LdapUser( attributes.putAll(extractLdapAttributes(originalUsername, userEntry, customAttrMaxValueLen, allowlistedCustomLdapAttrMatcher)); } + public LdapUser(StreamInput in) throws IOException { + super(in); + userEntry = null; + originalUsername = in.readString(); + } + /** * May return null because ldapEntry is transient * @@ -88,4 +97,10 @@ public static Map extractLdapAttributes( } return Collections.unmodifiableMap(attributes); } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeString(originalUsername); + } } diff --git a/src/main/java/org/opensearch/security/auditlog/impl/AbstractAuditLog.java b/src/main/java/org/opensearch/security/auditlog/impl/AbstractAuditLog.java index 804e0a2114..a8f511be97 100644 --- a/src/main/java/org/opensearch/security/auditlog/impl/AbstractAuditLog.java +++ b/src/main/java/org/opensearch/security/auditlog/impl/AbstractAuditLog.java @@ -773,7 +773,8 @@ private TransportAddress getRemoteAddress() { if (address == null && threadPool.getThreadContext().getHeader(ConfigConstants.OPENDISTRO_SECURITY_REMOTE_ADDRESS_HEADER) != null) { address = new TransportAddress( (InetSocketAddress) Base64Helper.deserializeObject( - threadPool.getThreadContext().getHeader(ConfigConstants.OPENDISTRO_SECURITY_REMOTE_ADDRESS_HEADER) + threadPool.getThreadContext().getHeader(ConfigConstants.OPENDISTRO_SECURITY_REMOTE_ADDRESS_HEADER), + threadPool.getThreadContext().getTransient(ConfigConstants.USE_JDK_SERIALIZATION) ) ); } @@ -784,7 +785,8 @@ private String getUser() { User user = threadPool.getThreadContext().getTransient(ConfigConstants.OPENDISTRO_SECURITY_USER); if (user == null && threadPool.getThreadContext().getHeader(ConfigConstants.OPENDISTRO_SECURITY_USER_HEADER) != null) { user = (User) Base64Helper.deserializeObject( - threadPool.getThreadContext().getHeader(ConfigConstants.OPENDISTRO_SECURITY_USER_HEADER) + threadPool.getThreadContext().getHeader(ConfigConstants.OPENDISTRO_SECURITY_USER_HEADER), + threadPool.getThreadContext().getTransient(ConfigConstants.USE_JDK_SERIALIZATION) ); } return user == null ? null : user.getName(); diff --git a/src/main/java/org/opensearch/security/auth/BackendRegistry.java b/src/main/java/org/opensearch/security/auth/BackendRegistry.java index 1ad93e5d8b..82c75239ba 100644 --- a/src/main/java/org/opensearch/security/auth/BackendRegistry.java +++ b/src/main/java/org/opensearch/security/auth/BackendRegistry.java @@ -231,7 +231,7 @@ && isBlocked(((InetSocketAddress) request.getHttpChannel().getRemoteAddress()).g User authenticatedUser = null; - AuthCredentials authCredenetials = null; + AuthCredentials authCredentials = null; HTTPAuthenticator firstChallengingHttpAuthenticator = null; @@ -273,7 +273,7 @@ && isBlocked(((InetSocketAddress) request.getHttpChannel().getRemoteAddress()).g continue; } - authCredenetials = ac; + authCredentials = ac; if (ac == null) { // no credentials found in request @@ -281,12 +281,18 @@ && isBlocked(((InetSocketAddress) request.getHttpChannel().getRemoteAddress()).g continue; } - if (authDomain.isChallenge() && httpAuthenticator.reRequestAuthentication(channel, null)) { - auditLog.logFailedLogin("", false, null, request); - if (isTraceEnabled) { - log.trace("No 'Authorization' header, send 401 and 'WWW-Authenticate Basic'"); + if (authDomain.isChallenge()) { + final BytesRestResponse restResponse = httpAuthenticator.reRequestAuthentication(request, null); + if (restResponse != null) { + auditLog.logFailedLogin("", false, null, request); + if (isTraceEnabled) { + log.trace("No 'Authorization' header, send 401 and 'WWW-Authenticate Basic'"); + } + notifyIpAuthFailureListeners(request, authCredentials); + channel.sendResponse(restResponse); + return false; } - return false; + } else { // no reRequest possible if (isTraceEnabled) { @@ -297,9 +303,12 @@ && isBlocked(((InetSocketAddress) request.getHttpChannel().getRemoteAddress()).g } else { org.apache.logging.log4j.ThreadContext.put("user", ac.getUsername()); if (!ac.isComplete()) { + final BytesRestResponse restResponse = httpAuthenticator.reRequestAuthentication(request, ac); // credentials found in request but we need another client challenge - if (httpAuthenticator.reRequestAuthentication(channel, ac)) { + if (restResponse != null) { // auditLog.logFailedLogin(ac.getUsername()+" ", request); --noauditlog + notifyIpAuthFailureListeners(request, ac); + channel.sendResponse(restResponse); return false; } else { // no reRequest possible @@ -377,7 +386,7 @@ && isBlocked(((InetSocketAddress) request.getHttpChannel().getRemoteAddress()).g log.debug("User still not authenticated after checking {} auth domains", restAuthDomains.size()); } - if (authCredenetials == null && anonymousAuthEnabled) { + if (authCredentials == null && anonymousAuthEnabled) { final String tenant = Utils.coalesce(request.header("securitytenant"), request.header("security_tenant")); User anonymousUser = new User(User.ANONYMOUS.getName(), new HashSet(User.ANONYMOUS.getRoles()), null); anonymousUser.setRequestedTenant(tenant); @@ -389,6 +398,7 @@ && isBlocked(((InetSocketAddress) request.getHttpChannel().getRemoteAddress()).g } return true; } + BytesRestResponse challengeResponse = null; if (firstChallengingHttpAuthenticator != null) { @@ -396,31 +406,28 @@ && isBlocked(((InetSocketAddress) request.getHttpChannel().getRemoteAddress()).g log.debug("Rerequest with {}", firstChallengingHttpAuthenticator.getClass()); } - if (firstChallengingHttpAuthenticator.reRequestAuthentication(channel, null)) { + challengeResponse = firstChallengingHttpAuthenticator.reRequestAuthentication(request, null); + if (challengeResponse != null) { if (isDebugEnabled) { log.debug("Rerequest {} failed", firstChallengingHttpAuthenticator.getClass()); } - - log.warn( - "Authentication finally failed for {} from {}", - authCredenetials == null ? null : authCredenetials.getUsername(), - remoteAddress - ); - auditLog.logFailedLogin(authCredenetials == null ? null : authCredenetials.getUsername(), false, null, request); - return false; } } log.warn( "Authentication finally failed for {} from {}", - authCredenetials == null ? null : authCredenetials.getUsername(), + authCredentials == null ? null : authCredentials.getUsername(), remoteAddress ); - auditLog.logFailedLogin(authCredenetials == null ? null : authCredenetials.getUsername(), false, null, request); + auditLog.logFailedLogin(authCredentials == null ? null : authCredentials.getUsername(), false, null, request); - notifyIpAuthFailureListeners(request, authCredenetials); + notifyIpAuthFailureListeners(request, authCredentials); - channel.sendResponse(new BytesRestResponse(RestStatus.UNAUTHORIZED, "Authentication finally failed")); + channel.sendResponse( + challengeResponse != null + ? challengeResponse + : new BytesRestResponse(RestStatus.UNAUTHORIZED, "Authentication finally failed") + ); return false; } diff --git a/src/main/java/org/opensearch/security/auth/HTTPAuthenticator.java b/src/main/java/org/opensearch/security/auth/HTTPAuthenticator.java index fa5065ef68..70b94d6dbf 100644 --- a/src/main/java/org/opensearch/security/auth/HTTPAuthenticator.java +++ b/src/main/java/org/opensearch/security/auth/HTTPAuthenticator.java @@ -28,7 +28,7 @@ import org.opensearch.OpenSearchSecurityException; import org.opensearch.common.util.concurrent.ThreadContext; -import org.opensearch.rest.RestChannel; +import org.opensearch.rest.BytesRestResponse; import org.opensearch.rest.RestRequest; import org.opensearch.security.user.AuthCredentials; @@ -71,15 +71,14 @@ public interface HTTPAuthenticator { /** * If the {@code extractCredentials()} call was not successful or the authentication flow needs another roundtrip this method - * will be called. If the custom HTTP authenticator does not support this method is a no-op and false should be returned. - * + * will be called. If the custom HTTP authenticator does not support this method is a no-op and null response should be returned. * If the custom HTTP authenticator does support re-request authentication or supports authentication flows with multiple roundtrips - * then the response should be sent (through the channel) and true must be returned. + * then the response will be returned which can then be sent via response channel. * - * @param channel The rest channel to sent back the response via {@code channel.sendResponse()} + * @param request * @param credentials The credentials from the prior authentication attempt - * @return false if re-request is not supported/necessary, true otherwise. - * If true is returned {@code channel.sendResponse()} must be called so that the request completes. + * @return null if re-request is not supported/necessary, response object otherwise. + * If an object is returned {@code channel.sendResponse()} must be called so that the request completes. */ - boolean reRequestAuthentication(final RestChannel channel, AuthCredentials credentials); + BytesRestResponse reRequestAuthentication(RestRequest request, AuthCredentials credentials); } diff --git a/src/main/java/org/opensearch/security/auth/UserInjector.java b/src/main/java/org/opensearch/security/auth/UserInjector.java index 3e89a52e93..30df84ef5f 100644 --- a/src/main/java/org/opensearch/security/auth/UserInjector.java +++ b/src/main/java/org/opensearch/security/auth/UserInjector.java @@ -26,6 +26,7 @@ package org.opensearch.security.auth; +import java.io.IOException; import java.io.ObjectStreamException; import java.net.InetAddress; import java.net.UnknownHostException; @@ -36,6 +37,8 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.common.settings.Settings; import org.opensearch.core.common.transport.TransportAddress; import org.opensearch.rest.RestRequest; @@ -63,13 +66,18 @@ public UserInjector(Settings settings, ThreadPool threadPool, AuditLog auditLog, } - static class InjectedUser extends User { + public static class InjectedUser extends User { private transient TransportAddress transportAddress; public InjectedUser(String name) { super(name); } + public InjectedUser(StreamInput in) throws IOException { + super(in); + this.setInjected(true); + } + private Object writeReplace() throws ObjectStreamException { User user = new User(getName()); user.addRoles(getRoles()); @@ -96,6 +104,11 @@ public void setTransportAddress(String addr) throws UnknownHostException, Illega this.transportAddress = new TransportAddress(iAdress, port); } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + } } public InjectedUser getInjectedUser() { diff --git a/src/main/java/org/opensearch/security/authtoken/jwt/JwtVendor.java b/src/main/java/org/opensearch/security/authtoken/jwt/JwtVendor.java index 5d3262799f..e68a5ef2d7 100644 --- a/src/main/java/org/opensearch/security/authtoken/jwt/JwtVendor.java +++ b/src/main/java/org/opensearch/security/authtoken/jwt/JwtVendor.java @@ -16,7 +16,6 @@ import java.util.Optional; import java.util.function.LongSupplier; -import com.google.common.base.Strings; import org.apache.cxf.jaxrs.json.basic.JsonMapObjectReaderWriter; import org.apache.cxf.rs.security.jose.jwk.JsonWebKey; import org.apache.cxf.rs.security.jose.jwk.KeyType; @@ -32,6 +31,8 @@ import org.opensearch.common.settings.Settings; import org.opensearch.security.ssl.util.ExceptionUtils; +import static org.opensearch.security.util.AuthTokenUtils.isKeyNull; + public class JwtVendor { private static final Logger logger = LogManager.getLogger(JwtVendor.class); @@ -53,7 +54,7 @@ public JwtVendor(final Settings settings, final Optional timeProvi throw ExceptionUtils.createJwkCreationException(e); } this.jwtProducer = jwtProducer; - if (settings.get("encryption_key") == null) { + if (isKeyNull(settings, "encryption_key")) { throw new IllegalArgumentException("encryption_key cannot be null"); } else { this.claimsEncryptionKey = settings.get("encryption_key"); @@ -73,9 +74,8 @@ public JwtVendor(final Settings settings, final Optional timeProvi * Encryption Algorithm: HS512 * */ static JsonWebKey createJwkFromSettings(Settings settings) throws Exception { - String signingKey = settings.get("signing_key"); - - if (!Strings.isNullOrEmpty(signingKey)) { + if (!isKeyNull(settings, "signing_key")) { + String signingKey = settings.get("signing_key"); JsonWebKey jwk = new JsonWebKey(); diff --git a/src/main/java/org/opensearch/security/configuration/DlsFlsValveImpl.java b/src/main/java/org/opensearch/security/configuration/DlsFlsValveImpl.java index 14eaed4e0d..b35137a35d 100644 --- a/src/main/java/org/opensearch/security/configuration/DlsFlsValveImpl.java +++ b/src/main/java/org/opensearch/security/configuration/DlsFlsValveImpl.java @@ -443,7 +443,8 @@ private void setDlsHeaders(EvaluatedDlsFlsConfig dlsFls, ActionRequest request) } else { if (threadContext.getHeader(ConfigConstants.OPENDISTRO_SECURITY_DLS_QUERY_HEADER) != null) { Object deserializedDlsQueries = Base64Helper.deserializeObject( - threadContext.getHeader(ConfigConstants.OPENDISTRO_SECURITY_DLS_QUERY_HEADER) + threadContext.getHeader(ConfigConstants.OPENDISTRO_SECURITY_DLS_QUERY_HEADER), + threadContext.getTransient(ConfigConstants.USE_JDK_SERIALIZATION) ); if (!dlsQueries.equals(deserializedDlsQueries)) { throw new OpenSearchSecurityException( @@ -506,7 +507,10 @@ private void setFlsHeaders(EvaluatedDlsFlsConfig dlsFls, ActionRequest request) if (threadContext.getHeader(ConfigConstants.OPENDISTRO_SECURITY_MASKED_FIELD_HEADER) != null) { if (!maskedFieldsMap.equals( - Base64Helper.deserializeObject(threadContext.getHeader(ConfigConstants.OPENDISTRO_SECURITY_MASKED_FIELD_HEADER)) + Base64Helper.deserializeObject( + threadContext.getHeader(ConfigConstants.OPENDISTRO_SECURITY_MASKED_FIELD_HEADER), + threadContext.getTransient(ConfigConstants.USE_JDK_SERIALIZATION) + ) )) { throw new OpenSearchSecurityException( ConfigConstants.OPENDISTRO_SECURITY_MASKED_FIELD_HEADER + " does not match (SG 901D)" @@ -542,7 +546,10 @@ private void setFlsHeaders(EvaluatedDlsFlsConfig dlsFls, ActionRequest request) } else { if (threadContext.getHeader(ConfigConstants.OPENDISTRO_SECURITY_FLS_FIELDS_HEADER) != null) { if (!flsFields.equals( - Base64Helper.deserializeObject(threadContext.getHeader(ConfigConstants.OPENDISTRO_SECURITY_FLS_FIELDS_HEADER)) + Base64Helper.deserializeObject( + threadContext.getHeader(ConfigConstants.OPENDISTRO_SECURITY_FLS_FIELDS_HEADER), + threadContext.getTransient(ConfigConstants.USE_JDK_SERIALIZATION) + ) )) { throw new OpenSearchSecurityException( ConfigConstants.OPENDISTRO_SECURITY_FLS_FIELDS_HEADER @@ -550,7 +557,8 @@ private void setFlsHeaders(EvaluatedDlsFlsConfig dlsFls, ActionRequest request) + flsFields + "---" + Base64Helper.deserializeObject( - threadContext.getHeader(ConfigConstants.OPENDISTRO_SECURITY_FLS_FIELDS_HEADER) + threadContext.getHeader(ConfigConstants.OPENDISTRO_SECURITY_FLS_FIELDS_HEADER), + threadContext.getTransient(ConfigConstants.USE_JDK_SERIALIZATION) ) ); } else { diff --git a/src/main/java/org/opensearch/security/filter/SecurityFilter.java b/src/main/java/org/opensearch/security/filter/SecurityFilter.java index 06f2fae397..f433a5857d 100644 --- a/src/main/java/org/opensearch/security/filter/SecurityFilter.java +++ b/src/main/java/org/opensearch/security/filter/SecurityFilter.java @@ -183,6 +183,10 @@ private void ap threadContext.putTransient(ConfigConstants.OPENDISTRO_SECURITY_ORIGIN, Origin.LOCAL.toString()); } + if (threadContext.getTransient(ConfigConstants.USE_JDK_SERIALIZATION) == null) { + threadContext.putTransient(ConfigConstants.USE_JDK_SERIALIZATION, false); + } + final ComplianceConfig complianceConfig = auditLog.getComplianceConfig(); if (complianceConfig != null && complianceConfig.isEnabled()) { attachSourceFieldContext(request); diff --git a/src/main/java/org/opensearch/security/http/HTTPBasicAuthenticator.java b/src/main/java/org/opensearch/security/http/HTTPBasicAuthenticator.java index 4be83bc2e2..16a822df6d 100644 --- a/src/main/java/org/opensearch/security/http/HTTPBasicAuthenticator.java +++ b/src/main/java/org/opensearch/security/http/HTTPBasicAuthenticator.java @@ -34,7 +34,6 @@ import org.opensearch.common.settings.Settings; import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.rest.BytesRestResponse; -import org.opensearch.rest.RestChannel; import org.opensearch.rest.RestRequest; import org.opensearch.core.rest.RestStatus; import org.opensearch.security.auth.HTTPAuthenticator; @@ -65,11 +64,10 @@ public AuthCredentials extractCredentials(final RestRequest request, ThreadConte } @Override - public boolean reRequestAuthentication(final RestChannel channel, AuthCredentials creds) { + public BytesRestResponse reRequestAuthentication(RestRequest request, AuthCredentials credentials) { final BytesRestResponse wwwAuthenticateResponse = new BytesRestResponse(RestStatus.UNAUTHORIZED, "Unauthorized"); wwwAuthenticateResponse.addHeader("WWW-Authenticate", "Basic realm=\"OpenSearch Security\""); - channel.sendResponse(wwwAuthenticateResponse); - return true; + return wwwAuthenticateResponse; } @Override diff --git a/src/main/java/org/opensearch/security/http/HTTPClientCertAuthenticator.java b/src/main/java/org/opensearch/security/http/HTTPClientCertAuthenticator.java index b1e5d4ef40..cea59f7311 100644 --- a/src/main/java/org/opensearch/security/http/HTTPClientCertAuthenticator.java +++ b/src/main/java/org/opensearch/security/http/HTTPClientCertAuthenticator.java @@ -41,7 +41,7 @@ import org.opensearch.common.settings.Settings; import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.core.common.Strings; -import org.opensearch.rest.RestChannel; +import org.opensearch.rest.BytesRestResponse; import org.opensearch.rest.RestRequest; import org.opensearch.security.auth.HTTPAuthenticator; import org.opensearch.security.support.ConfigConstants; @@ -98,8 +98,8 @@ public AuthCredentials extractCredentials(final RestRequest request, final Threa } @Override - public boolean reRequestAuthentication(final RestChannel channel, AuthCredentials creds) { - return false; + public BytesRestResponse reRequestAuthentication(RestRequest request, AuthCredentials credentials) { + return null; } @Override diff --git a/src/main/java/org/opensearch/security/http/HTTPProxyAuthenticator.java b/src/main/java/org/opensearch/security/http/HTTPProxyAuthenticator.java index a58a842394..1db7b9769b 100644 --- a/src/main/java/org/opensearch/security/http/HTTPProxyAuthenticator.java +++ b/src/main/java/org/opensearch/security/http/HTTPProxyAuthenticator.java @@ -37,7 +37,7 @@ import org.opensearch.common.settings.Settings; import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.core.common.Strings; -import org.opensearch.rest.RestChannel; +import org.opensearch.rest.BytesRestResponse; import org.opensearch.rest.RestRequest; import org.opensearch.security.auth.HTTPAuthenticator; import org.opensearch.security.support.ConfigConstants; @@ -89,8 +89,8 @@ public AuthCredentials extractCredentials(final RestRequest request, ThreadConte } @Override - public boolean reRequestAuthentication(final RestChannel channel, AuthCredentials creds) { - return false; + public BytesRestResponse reRequestAuthentication(RestRequest request, AuthCredentials credentials) { + return null; } @Override diff --git a/src/main/java/org/opensearch/security/http/OnBehalfOfAuthenticator.java b/src/main/java/org/opensearch/security/http/OnBehalfOfAuthenticator.java index 467edd8ac4..02077bed7c 100644 --- a/src/main/java/org/opensearch/security/http/OnBehalfOfAuthenticator.java +++ b/src/main/java/org/opensearch/security/http/OnBehalfOfAuthenticator.java @@ -33,7 +33,7 @@ import org.opensearch.SpecialPermission; import org.opensearch.common.settings.Settings; import org.opensearch.common.util.concurrent.ThreadContext; -import org.opensearch.rest.RestChannel; +import org.opensearch.rest.BytesRestResponse; import org.opensearch.rest.RestRequest; import org.opensearch.security.auth.HTTPAuthenticator; import org.opensearch.security.authtoken.jwt.EncryptionDecryptionUtil; @@ -43,13 +43,12 @@ import static org.opensearch.security.OpenSearchSecurityPlugin.LEGACY_OPENDISTRO_PREFIX; import static org.opensearch.security.OpenSearchSecurityPlugin.PLUGINS_PREFIX; +import static org.opensearch.security.util.AuthTokenUtils.isAccessToRestrictedEndpoints; public class OnBehalfOfAuthenticator implements HTTPAuthenticator { private static final String REGEX_PATH_PREFIX = "/(" + LEGACY_OPENDISTRO_PREFIX + "|" + PLUGINS_PREFIX + ")/" + "(.*)"; private static final Pattern PATTERN_PATH_PREFIX = Pattern.compile(REGEX_PATH_PREFIX); - private static final String ON_BEHALF_OF_SUFFIX = "api/generateonbehalfoftoken"; - private static final String ACCOUNT_SUFFIX = "api/account"; protected final Logger log = LogManager.getLogger(this.getClass()); @@ -233,8 +232,7 @@ private void logDebug(String message, Object... args) { public Boolean isRequestAllowed(final RestRequest request) { Matcher matcher = PATTERN_PATH_PREFIX.matcher(request.path()); final String suffix = matcher.matches() ? matcher.group(2) : null; - if (request.method() == RestRequest.Method.POST && ON_BEHALF_OF_SUFFIX.equals(suffix) - || request.method() == RestRequest.Method.PUT && ACCOUNT_SUFFIX.equals(suffix)) { + if (isAccessToRestrictedEndpoints(request, suffix)) { final OpenSearchException exception = ExceptionUtils.invalidUsageOfOBOTokenException(); log.error(exception.toString()); return false; @@ -243,8 +241,8 @@ public Boolean isRequestAllowed(final RestRequest request) { } @Override - public boolean reRequestAuthentication(final RestChannel channel, AuthCredentials creds) { - return false; + public BytesRestResponse reRequestAuthentication(RestRequest request, AuthCredentials credentials) { + return null; } @Override diff --git a/src/main/java/org/opensearch/security/http/proxy/HTTPExtendedProxyAuthenticator.java b/src/main/java/org/opensearch/security/http/proxy/HTTPExtendedProxyAuthenticator.java index ef20374d69..0423fecefe 100644 --- a/src/main/java/org/opensearch/security/http/proxy/HTTPExtendedProxyAuthenticator.java +++ b/src/main/java/org/opensearch/security/http/proxy/HTTPExtendedProxyAuthenticator.java @@ -37,7 +37,6 @@ import org.opensearch.common.settings.Settings; import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.core.common.Strings; -import org.opensearch.rest.RestChannel; import org.opensearch.rest.RestRequest; import org.opensearch.security.http.HTTPProxyAuthenticator; import org.opensearch.security.user.AuthCredentials; @@ -84,11 +83,6 @@ public AuthCredentials extractCredentials(final RestRequest request, ThreadConte return credentials.markComplete(); } - @Override - public boolean reRequestAuthentication(final RestChannel channel, AuthCredentials creds) { - return false; - } - @Override public String getType() { return "extended-proxy"; diff --git a/src/main/java/org/opensearch/security/securityconf/DynamicConfigModelV7.java b/src/main/java/org/opensearch/security/securityconf/DynamicConfigModelV7.java index fcbf985f60..0de83f2e2e 100644 --- a/src/main/java/org/opensearch/security/securityconf/DynamicConfigModelV7.java +++ b/src/main/java/org/opensearch/security/securityconf/DynamicConfigModelV7.java @@ -67,6 +67,8 @@ import org.opensearch.security.securityconf.impl.v7.ConfigV7.AuthzDomain; import org.opensearch.security.support.ReflectionHelper; +import static org.opensearch.security.util.AuthTokenUtils.isKeyNull; + public class DynamicConfigModelV7 extends DynamicConfigModel { private final ConfigV7 config; @@ -383,7 +385,7 @@ private void buildAAA() { * order: -1 - prioritize the OBO authentication when it gets enabled */ Settings oboSettings = getDynamicOnBehalfOfSettings(); - if (oboSettings.get("signing_key") != null && oboSettings.get("encryption_key") != null) { + if (!isKeyNull(oboSettings, "signing_key") && !isKeyNull(oboSettings, "encryption_key")) { final AuthDomain _ad = new AuthDomain( new NoOpAuthenticationBackend(Settings.EMPTY, null), new OnBehalfOfAuthenticator(getDynamicOnBehalfOfSettings(), this.cih.getClusterName()), diff --git a/src/main/java/org/opensearch/security/ssl/transport/SecuritySSLRequestHandler.java b/src/main/java/org/opensearch/security/ssl/transport/SecuritySSLRequestHandler.java index 0a1b94548e..c67579e30f 100644 --- a/src/main/java/org/opensearch/security/ssl/transport/SecuritySSLRequestHandler.java +++ b/src/main/java/org/opensearch/security/ssl/transport/SecuritySSLRequestHandler.java @@ -83,8 +83,14 @@ protected ThreadContext getThreadContext() { @Override public final void messageReceived(T request, TransportChannel channel, Task task) throws Exception { + ThreadContext threadContext = getThreadContext(); + threadContext.putTransient( + ConfigConstants.USE_JDK_SERIALIZATION, + channel.getVersion().before(ConfigConstants.FIRST_CUSTOM_SERIALIZATION_SUPPORTED_OS_VERSION) + ); + if (SSLRequestHelper.containsBadHeader(threadContext, "_opendistro_security_ssl_")) { final Exception exception = ExceptionUtils.createBadHeaderException(); channel.sendResponse(exception); diff --git a/src/main/java/org/opensearch/security/support/Base64CustomHelper.java b/src/main/java/org/opensearch/security/support/Base64CustomHelper.java new file mode 100644 index 0000000000..dc66268fcd --- /dev/null +++ b/src/main/java/org/opensearch/security/support/Base64CustomHelper.java @@ -0,0 +1,225 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.security.support; + +import com.amazon.dlic.auth.ldap.LdapUser; +import com.google.common.base.Preconditions; +import com.google.common.collect.BiMap; +import com.google.common.collect.HashBiMap; +import com.google.common.io.BaseEncoding; +import org.opensearch.OpenSearchException; +import org.opensearch.common.Nullable; +import org.opensearch.core.common.io.stream.BytesStreamInput; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.Writeable; +import org.opensearch.core.common.Strings; +import org.opensearch.security.auth.UserInjector; +import org.opensearch.security.user.User; + +import java.io.IOException; +import java.io.Serializable; + +import static org.opensearch.security.support.SafeSerializationUtils.prohibitUnsafeClasses; + +/** + * Provides support for Serialization/Deserialization of objects of supported classes into/from Base64 encoded stream + * using the OpenSearch's custom serialization protocol implemented by the StreamInput/StreamOutput classes. + */ +public class Base64CustomHelper { + + private enum CustomSerializationFormat { + + WRITEABLE(1), + STREAMABLE(2), + GENERIC(3); + + private final int id; + + CustomSerializationFormat(int id) { + this.id = id; + } + + static CustomSerializationFormat fromId(int id) { + switch (id) { + case 1: + return WRITEABLE; + case 2: + return STREAMABLE; + case 3: + return GENERIC; + default: + throw new IllegalArgumentException(String.format("%d is not a valid id", id)); + } + } + + } + + private static final BiMap, Integer> writeableClassToIdMap = HashBiMap.create(); + private static final StreamableRegistry streamableRegistry = StreamableRegistry.getInstance(); + + static { + registerAllWriteables(); + } + + protected static String serializeObject(final Serializable object) { + + Preconditions.checkArgument(object != null, "object must not be null"); + final BytesStreamOutput streamOutput = new SafeBytesStreamOutput(128); + Class clazz = object.getClass(); + try { + prohibitUnsafeClasses(clazz); + CustomSerializationFormat customSerializationFormat = getCustomSerializationMode(clazz); + switch (customSerializationFormat) { + case WRITEABLE: + streamOutput.writeByte((byte) CustomSerializationFormat.WRITEABLE.id); + streamOutput.writeByte((byte) getWriteableClassID(clazz).intValue()); + ((Writeable) object).writeTo(streamOutput); + break; + case STREAMABLE: + streamOutput.writeByte((byte) CustomSerializationFormat.STREAMABLE.id); + streamableRegistry.writeTo(streamOutput, object); + break; + case GENERIC: + streamOutput.writeByte((byte) CustomSerializationFormat.GENERIC.id); + streamOutput.writeGenericValue(object); + break; + default: + throw new IllegalArgumentException( + String.format("Could not determine custom serialization mode for class %s", clazz.getName()) + ); + } + } catch (final Exception e) { + throw new OpenSearchException("Instance {} of class {} is not serializable", e, object, object.getClass()); + } + final byte[] bytes = streamOutput.bytes().toBytesRef().bytes; + streamOutput.close(); + return BaseEncoding.base64().encode(bytes); + } + + protected static Serializable deserializeObject(final String string) { + + Preconditions.checkArgument(!Strings.isNullOrEmpty(string), "object must not be null or empty"); + final byte[] bytes = BaseEncoding.base64().decode(string); + Serializable obj = null; + try (final BytesStreamInput streamInput = new SafeBytesStreamInput(bytes)) { + CustomSerializationFormat serializationFormat = CustomSerializationFormat.fromId(streamInput.readByte()); + switch (serializationFormat) { + case WRITEABLE: + final int classId = streamInput.readByte(); + Class clazz = getWriteableClassFromId(classId); + obj = (Serializable) clazz.getConstructor(StreamInput.class).newInstance(streamInput); + break; + case STREAMABLE: + obj = (Serializable) streamableRegistry.readFrom(streamInput); + break; + case GENERIC: + obj = (Serializable) streamInput.readGenericValue(); + break; + default: + throw new IllegalArgumentException("Could not determine custom deserialization mode"); + } + prohibitUnsafeClasses(obj.getClass()); + return obj; + } catch (final Exception e) { + throw new OpenSearchException(e); + } + } + + private static boolean isWriteable(Class clazz) { + return Writeable.class.isAssignableFrom(clazz); + } + + /** + * Returns integer ID for the registered Writeable class + *
+ * Protected for testing + */ + protected static Integer getWriteableClassID(Class clazz) { + if (!isWriteable(clazz)) { + throw new OpenSearchException("clazz should implement Writeable ", clazz); + } + if (!writeableClassToIdMap.containsKey(clazz)) { + throw new OpenSearchException("Writeable clazz not registered ", clazz); + } + return writeableClassToIdMap.get(clazz); + } + + private static Class getWriteableClassFromId(int id) { + return writeableClassToIdMap.inverse().get(id); + } + + /** + * Registers the given Writeable class for custom serialization by assigning an incrementing integer ID + * IDs are stored in a HashBiMap + * @param clazz class to be registered + */ + private static void registerWriteable(Class clazz) { + if (writeableClassToIdMap.containsKey(clazz)) { + throw new OpenSearchException("writeable clazz is already registered ", clazz.getName()); + } + int id = writeableClassToIdMap.size() + 1; + writeableClassToIdMap.put(clazz, id); + } + + /** + * Registers all Writeable classes for custom serialization support. + * Removing existing classes / changing order of registration will cause a breaking change in the serialization protocol + * as registerWriteable assigns an incrementing integer ID to each of the classes in the order it is called + * starting from 1. + *
+ * New classes can safely be added towards the end. + */ + private static void registerAllWriteables() { + registerWriteable(User.class); + registerWriteable(LdapUser.class); + registerWriteable(UserInjector.InjectedUser.class); + registerWriteable(SourceFieldsContext.class); + } + + private static CustomSerializationFormat getCustomSerializationMode(Class clazz) { + if (isWriteable(clazz)) { + return CustomSerializationFormat.WRITEABLE; + } else if (streamableRegistry.isStreamable(clazz)) { + return CustomSerializationFormat.STREAMABLE; + } else { + return CustomSerializationFormat.GENERIC; + } + } + + private static class SafeBytesStreamOutput extends BytesStreamOutput { + + public SafeBytesStreamOutput(int expectedSize) { + super(expectedSize); + } + + @Override + public void writeGenericValue(@Nullable Object value) throws IOException { + prohibitUnsafeClasses(value.getClass()); + super.writeGenericValue(value); + } + } + + private static class SafeBytesStreamInput extends BytesStreamInput { + + public SafeBytesStreamInput(byte[] bytes) { + super(bytes); + } + + @Override + public Object readGenericValue() throws IOException { + Object object = super.readGenericValue(); + prohibitUnsafeClasses(object.getClass()); + return object; + } + } +} diff --git a/src/main/java/org/opensearch/security/support/Base64Helper.java b/src/main/java/org/opensearch/security/support/Base64Helper.java index 836858decb..a5fbab8515 100644 --- a/src/main/java/org/opensearch/security/support/Base64Helper.java +++ b/src/main/java/org/opensearch/security/support/Base64Helper.java @@ -26,174 +26,47 @@ package org.opensearch.security.support; -import java.io.ByteArrayInputStream; -import java.io.ByteArrayOutputStream; -import java.io.IOException; -import java.io.InputStream; -import java.io.InvalidClassException; -import java.io.ObjectInputStream; -import java.io.ObjectOutputStream; -import java.io.ObjectStreamClass; -import java.io.OutputStream; import java.io.Serializable; -import java.net.InetAddress; -import java.net.InetSocketAddress; -import java.net.SocketAddress; -import java.security.AccessController; -import java.security.PrivilegedAction; -import java.util.Collection; -import java.util.Collections; -import java.util.List; -import java.util.Map; -import java.util.Set; -import java.util.regex.Pattern; - -import com.google.common.base.Preconditions; -import com.google.common.collect.ImmutableList; -import com.google.common.collect.ImmutableSet; -import com.google.common.io.BaseEncoding; -import org.ldaptive.AbstractLdapBean; -import org.ldaptive.LdapAttribute; -import org.ldaptive.LdapEntry; -import org.ldaptive.SearchEntry; - -import com.amazon.dlic.auth.ldap.LdapUser; - -import org.opensearch.OpenSearchException; -import org.opensearch.SpecialPermission; -import org.opensearch.core.common.Strings; -import org.opensearch.security.user.User; public class Base64Helper { - private static final Set> SAFE_CLASSES = ImmutableSet.of( - String.class, - SocketAddress.class, - InetSocketAddress.class, - Pattern.class, - User.class, - SourceFieldsContext.class, - LdapUser.class, - SearchEntry.class, - LdapEntry.class, - AbstractLdapBean.class, - LdapAttribute.class - ); - - private static final List> SAFE_ASSIGNABLE_FROM_CLASSES = ImmutableList.of( - InetAddress.class, - Number.class, - Collection.class, - Map.class, - Enum.class - ); - - private static final Set SAFE_CLASS_NAMES = Collections.singleton("org.ldaptive.LdapAttribute$LdapAttributeValues"); - - private static boolean isSafeClass(Class cls) { - return cls.isArray() - || SAFE_CLASSES.contains(cls) - || SAFE_CLASS_NAMES.contains(cls.getName()) - || SAFE_ASSIGNABLE_FROM_CLASSES.stream().anyMatch(c -> c.isAssignableFrom(cls)); - } - - private final static class SafeObjectOutputStream extends ObjectOutputStream { - - private static final boolean useSafeObjectOutputStream = checkSubstitutionPermission(); - - @SuppressWarnings("removal") - private static boolean checkSubstitutionPermission() { - SecurityManager sm = System.getSecurityManager(); - if (sm != null) { - try { - sm.checkPermission(new SpecialPermission()); - - AccessController.doPrivileged((PrivilegedAction) () -> { - AccessController.checkPermission(SUBSTITUTION_PERMISSION); - return null; - }); - } catch (SecurityException e) { - return false; - } - } - return true; - } - - static ObjectOutputStream create(ByteArrayOutputStream out) throws IOException { - try { - return useSafeObjectOutputStream ? new SafeObjectOutputStream(out) : new ObjectOutputStream(out); - } catch (SecurityException e) { - // As we try to create SafeObjectOutputStream only when necessary permissions are granted, we should - // not reach here, but if we do, we can still return ObjectOutputStream after resetting ByteArrayOutputStream - out.reset(); - return new ObjectOutputStream(out); - } - } - - @SuppressWarnings("removal") - private SafeObjectOutputStream(OutputStream out) throws IOException { - super(out); - - SecurityManager sm = System.getSecurityManager(); - if (sm != null) { - sm.checkPermission(new SpecialPermission()); - } - - AccessController.doPrivileged((PrivilegedAction) () -> enableReplaceObject(true)); - } - - @Override - protected Object replaceObject(Object obj) throws IOException { - Class clazz = obj.getClass(); - if (isSafeClass(clazz)) { - return obj; - } - throw new IOException("Unauthorized serialization attempt " + clazz.getName()); - } + public static String serializeObject(final Serializable object, final boolean useJDKSerialization) { + return useJDKSerialization ? Base64JDKHelper.serializeObject(object) : Base64CustomHelper.serializeObject(object); } public static String serializeObject(final Serializable object) { - - Preconditions.checkArgument(object != null, "object must not be null"); - - final ByteArrayOutputStream bos = new ByteArrayOutputStream(); - try (final ObjectOutputStream out = SafeObjectOutputStream.create(bos)) { - out.writeObject(object); - } catch (final Exception e) { - throw new OpenSearchException("Instance {} of class {} is not serializable", e, object, object.getClass()); - } - final byte[] bytes = bos.toByteArray(); - return BaseEncoding.base64().encode(bytes); + return serializeObject(object, false); } public static Serializable deserializeObject(final String string) { - - Preconditions.checkArgument(!Strings.isNullOrEmpty(string), "string must not be null or empty"); - - final byte[] bytes = BaseEncoding.base64().decode(string); - final ByteArrayInputStream bis = new ByteArrayInputStream(bytes); - try (SafeObjectInputStream in = new SafeObjectInputStream(bis)) { - return (Serializable) in.readObject(); - } catch (final Exception e) { - throw new OpenSearchException(e); - } + return deserializeObject(string, false); } - private final static class SafeObjectInputStream extends ObjectInputStream { - - public SafeObjectInputStream(InputStream in) throws IOException { - super(in); - } - - @Override - protected Class resolveClass(ObjectStreamClass desc) throws IOException, ClassNotFoundException { - - Class clazz = super.resolveClass(desc); - if (isSafeClass(clazz)) { - return clazz; - } + public static Serializable deserializeObject(final String string, final boolean useJDKDeserialization) { + return useJDKDeserialization ? Base64JDKHelper.deserializeObject(string) : Base64CustomHelper.deserializeObject(string); + } - throw new InvalidClassException("Unauthorized deserialization attempt ", clazz.getName()); + /** + * Ensures that the returned string is JDK serialized. + * + * If the supplied string is a custom serialized representation, will deserialize it and further serialize using + * JDK, otherwise returns the string as is. + * + * @param string original string, can be JDK or custom serialized + * @return jdk serialized string + */ + public static String ensureJDKSerialized(final String string) { + Serializable serializable; + try { + serializable = Base64Helper.deserializeObject(string, false); + } catch (Exception e) { + // We received an exception when de-serializing the given string. It is probably JDK serialized. + // Try to deserialize using JDK + Base64Helper.deserializeObject(string, true); + // Since we could deserialize the object using JDK, the string is already JDK serialized, return as is + return string; } + // If we see an exception now, we want the caller to see it - + return Base64Helper.serializeObject(serializable, true); } } diff --git a/src/main/java/org/opensearch/security/support/Base64JDKHelper.java b/src/main/java/org/opensearch/security/support/Base64JDKHelper.java new file mode 100644 index 0000000000..a4ab87d813 --- /dev/null +++ b/src/main/java/org/opensearch/security/support/Base64JDKHelper.java @@ -0,0 +1,156 @@ +/* + * Copyright 2015-2018 _floragunn_ GmbH + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.security.support; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.InputStream; +import java.io.InvalidClassException; +import java.io.ObjectInputStream; +import java.io.ObjectOutputStream; +import java.io.ObjectStreamClass; +import java.io.OutputStream; +import java.io.Serializable; +import java.security.AccessController; +import java.security.PrivilegedAction; + +import com.google.common.base.Preconditions; +import com.google.common.io.BaseEncoding; + +import org.opensearch.OpenSearchException; +import org.opensearch.SpecialPermission; +import org.opensearch.core.common.Strings; + +import static org.opensearch.security.support.SafeSerializationUtils.isSafeClass; + +/** + * Provides support for Serialization/Deserialization of objects of supported classes into/from Base64 encoded stream + * using JDK's in-built serialization protocol implemented by the ObjectOutputStream and ObjectInputStream classes. + */ +public class Base64JDKHelper { + + private final static class SafeObjectOutputStream extends ObjectOutputStream { + + private static final boolean useSafeObjectOutputStream = checkSubstitutionPermission(); + + @SuppressWarnings("removal") + private static boolean checkSubstitutionPermission() { + SecurityManager sm = System.getSecurityManager(); + if (sm != null) { + try { + sm.checkPermission(new SpecialPermission()); + + AccessController.doPrivileged((PrivilegedAction) () -> { + AccessController.checkPermission(SUBSTITUTION_PERMISSION); + return null; + }); + } catch (SecurityException e) { + return false; + } + } + return true; + } + + static ObjectOutputStream create(ByteArrayOutputStream out) throws IOException { + try { + return useSafeObjectOutputStream ? new SafeObjectOutputStream(out) : new ObjectOutputStream(out); + } catch (SecurityException e) { + // As we try to create SafeObjectOutputStream only when necessary permissions are granted, we should + // not reach here, but if we do, we can still return ObjectOutputStream after resetting ByteArrayOutputStream + out.reset(); + return new ObjectOutputStream(out); + } + } + + @SuppressWarnings("removal") + private SafeObjectOutputStream(OutputStream out) throws IOException { + super(out); + + SecurityManager sm = System.getSecurityManager(); + if (sm != null) { + sm.checkPermission(new SpecialPermission()); + } + + AccessController.doPrivileged((PrivilegedAction) () -> enableReplaceObject(true)); + } + + @Override + protected Object replaceObject(Object obj) throws IOException { + Class clazz = obj.getClass(); + if (isSafeClass(clazz)) { + return obj; + } + throw new IOException("Unauthorized serialization attempt " + clazz.getName()); + } + } + + public static String serializeObject(final Serializable object) { + + Preconditions.checkArgument(object != null, "object must not be null"); + + final ByteArrayOutputStream bos = new ByteArrayOutputStream(); + try (final ObjectOutputStream out = SafeObjectOutputStream.create(bos)) { + out.writeObject(object); + } catch (final Exception e) { + throw new OpenSearchException("Instance {} of class {} is not serializable", e, object, object.getClass()); + } + final byte[] bytes = bos.toByteArray(); + return BaseEncoding.base64().encode(bytes); + } + + public static Serializable deserializeObject(final String string) { + + Preconditions.checkArgument(!Strings.isNullOrEmpty(string), "object must not be null or empty"); + + final byte[] bytes = BaseEncoding.base64().decode(string); + final ByteArrayInputStream bis = new ByteArrayInputStream(bytes); + try (SafeObjectInputStream in = new SafeObjectInputStream(bis)) { + return (Serializable) in.readObject(); + } catch (final Exception e) { + throw new OpenSearchException(e); + } + } + + private final static class SafeObjectInputStream extends ObjectInputStream { + + public SafeObjectInputStream(InputStream in) throws IOException { + super(in); + } + + @Override + protected Class resolveClass(ObjectStreamClass desc) throws IOException, ClassNotFoundException { + + Class clazz = super.resolveClass(desc); + if (isSafeClass(clazz)) { + return clazz; + } + + throw new InvalidClassException("Unauthorized deserialization attempt ", clazz.getName()); + } + } +} diff --git a/src/main/java/org/opensearch/security/support/ConfigConstants.java b/src/main/java/org/opensearch/security/support/ConfigConstants.java index 8317d65335..9ac73cd579 100644 --- a/src/main/java/org/opensearch/security/support/ConfigConstants.java +++ b/src/main/java/org/opensearch/security/support/ConfigConstants.java @@ -35,6 +35,7 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; +import org.opensearch.Version; import org.opensearch.common.settings.Settings; import org.opensearch.security.auditlog.impl.AuditCategory; @@ -242,6 +243,7 @@ public class ConfigConstants { "opendistro_security.compliance.history.write.ignore_users"; public static final String OPENDISTRO_SECURITY_COMPLIANCE_HISTORY_EXTERNAL_CONFIG_ENABLED = "opendistro_security.compliance.history.external_config_enabled"; + public static final String OPENDISTRO_SECURITY_SOURCE_FIELD_CONTEXT = OPENDISTRO_SECURITY_CONFIG_PREFIX + "source_field_context"; public static final String SECURITY_COMPLIANCE_DISABLE_ANONYMOUS_AUTHENTICATION = "plugins.security.compliance.disable_anonymous_authentication"; public static final String SECURITY_COMPLIANCE_IMMUTABLE_INDICES = "plugins.security.compliance.immutable_indices"; @@ -323,6 +325,9 @@ public enum RolesMappingResolution { public static final String TENANCY_GLOBAL_TENANT_NAME = "global"; public static final String TENANCY_GLOBAL_TENANT_DEFAULT_NAME = ""; + public static final String USE_JDK_SERIALIZATION = "plugins.security.use_jdk_serialization"; + public static final Version FIRST_CUSTOM_SERIALIZATION_SUPPORTED_OS_VERSION = Version.V_3_0_0; + // On-behalf-of endpoints settings // CS-SUPPRESS-SINGLE: RegexpSingleline get Extensions Settings public static final String EXTENSIONS_BWC_PLUGIN_MODE = "bwcPluginMode"; diff --git a/src/main/java/org/opensearch/security/support/HeaderHelper.java b/src/main/java/org/opensearch/security/support/HeaderHelper.java index e8d50346a8..bbb44664fa 100644 --- a/src/main/java/org/opensearch/security/support/HeaderHelper.java +++ b/src/main/java/org/opensearch/security/support/HeaderHelper.java @@ -27,6 +27,8 @@ package org.opensearch.security.support; import java.io.Serializable; +import java.util.Arrays; +import java.util.List; import com.google.common.base.Strings; @@ -68,7 +70,7 @@ public static Serializable deserializeSafeFromHeader(final ThreadContext context final String objectAsBase64 = getSafeFromHeader(context, headerName); if (!Strings.isNullOrEmpty(objectAsBase64)) { - return Base64Helper.deserializeObject(objectAsBase64); + return Base64Helper.deserializeObject(objectAsBase64, context.getTransient(ConfigConstants.USE_JDK_SERIALIZATION)); } return null; @@ -77,4 +79,16 @@ public static Serializable deserializeSafeFromHeader(final ThreadContext context public static boolean isTrustedClusterRequest(final ThreadContext context) { return context.getTransient(ConfigConstants.OPENDISTRO_SECURITY_SSL_TRANSPORT_TRUSTED_CLUSTER_REQUEST) == Boolean.TRUE; } + + public static List getAllSerializedHeaderNames() { + return Arrays.asList( + ConfigConstants.OPENDISTRO_SECURITY_REMOTE_ADDRESS_HEADER, + ConfigConstants.OPENDISTRO_SECURITY_USER_HEADER, + ConfigConstants.OPENDISTRO_SECURITY_DLS_QUERY_HEADER, + ConfigConstants.OPENDISTRO_SECURITY_FLS_FIELDS_HEADER, + ConfigConstants.OPENDISTRO_SECURITY_MASKED_FIELD_HEADER, + ConfigConstants.OPENDISTRO_SECURITY_DLS_FILTER_LEVEL_QUERY_HEADER, + ConfigConstants.OPENDISTRO_SECURITY_SOURCE_FIELD_CONTEXT + ); + } } diff --git a/src/main/java/org/opensearch/security/support/SafeSerializationUtils.java b/src/main/java/org/opensearch/security/support/SafeSerializationUtils.java new file mode 100644 index 0000000000..c980959f68 --- /dev/null +++ b/src/main/java/org/opensearch/security/support/SafeSerializationUtils.java @@ -0,0 +1,81 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.security.support; + +import com.amazon.dlic.auth.ldap.LdapUser; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableSet; +import org.ldaptive.AbstractLdapBean; +import org.ldaptive.LdapAttribute; +import org.ldaptive.LdapEntry; +import org.ldaptive.SearchEntry; +import org.opensearch.security.auth.UserInjector; +import org.opensearch.security.user.User; + +import java.io.IOException; +import java.net.InetAddress; +import java.net.InetSocketAddress; +import java.net.SocketAddress; +import java.util.Collection; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.regex.Pattern; + +/** + * Provides functionality to verify if a class is categorised to be safe for serialization or + * deserialization by the security plugin. + *
+ * All methods are package private. + */ +public final class SafeSerializationUtils { + + private static final Set> SAFE_CLASSES = ImmutableSet.of( + String.class, + SocketAddress.class, + InetSocketAddress.class, + Pattern.class, + User.class, + UserInjector.InjectedUser.class, + SourceFieldsContext.class, + LdapUser.class, + SearchEntry.class, + LdapEntry.class, + AbstractLdapBean.class, + LdapAttribute.class + ); + + private static final List> SAFE_ASSIGNABLE_FROM_CLASSES = ImmutableList.of( + InetAddress.class, + Number.class, + Collection.class, + Map.class, + Enum.class + ); + + private static final Set SAFE_CLASS_NAMES = Collections.singleton("org.ldaptive.LdapAttribute$LdapAttributeValues"); + + static boolean isSafeClass(Class cls) { + return cls.isArray() + || SAFE_CLASSES.contains(cls) + || SAFE_CLASS_NAMES.contains(cls.getName()) + || SAFE_ASSIGNABLE_FROM_CLASSES.stream().anyMatch(c -> c.isAssignableFrom(cls)); + } + + static void prohibitUnsafeClasses(Class clazz) throws IOException { + if (!isSafeClass(clazz)) { + throw new IOException("Unauthorized serialization attempt " + clazz.getName()); + } + } + +} diff --git a/src/main/java/org/opensearch/security/support/SourceFieldsContext.java b/src/main/java/org/opensearch/security/support/SourceFieldsContext.java index 02f0ad9226..83bbb683e9 100644 --- a/src/main/java/org/opensearch/security/support/SourceFieldsContext.java +++ b/src/main/java/org/opensearch/security/support/SourceFieldsContext.java @@ -26,13 +26,18 @@ package org.opensearch.security.support; +import java.io.IOException; import java.io.Serializable; import java.util.Arrays; +import java.util.Objects; import org.opensearch.action.get.GetRequest; import org.opensearch.action.search.SearchRequest; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.common.io.stream.Writeable; -public class SourceFieldsContext implements Serializable { +public class SourceFieldsContext implements Serializable, Writeable { private String[] includes; private String[] excludes; @@ -77,6 +82,18 @@ public SourceFieldsContext(SearchRequest request) { // } } + public SourceFieldsContext(StreamInput in) throws IOException { + includes = in.readStringArray(); + if (includes.length == 0) { + includes = null; + } + excludes = in.readStringArray(); + if (excludes.length == 0) { + excludes = null; + } + fetchSource = in.readBoolean(); + } + public SourceFieldsContext(GetRequest request) { if (request.fetchSourceContext() != null) { includes = request.fetchSourceContext().includes(); @@ -117,4 +134,11 @@ public String toString() { + fetchSource + "]"; } + + @Override + public void writeTo(StreamOutput streamOutput) throws IOException { + streamOutput.writeStringArray(Objects.requireNonNullElseGet(includes, () -> new String[] {})); + streamOutput.writeStringArray(Objects.requireNonNullElseGet(excludes, () -> new String[] {})); + streamOutput.writeBoolean(fetchSource); + } } diff --git a/src/main/java/org/opensearch/security/support/StreamableRegistry.java b/src/main/java/org/opensearch/security/support/StreamableRegistry.java new file mode 100644 index 0000000000..bfde866376 --- /dev/null +++ b/src/main/java/org/opensearch/security/support/StreamableRegistry.java @@ -0,0 +1,134 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.security.support; + +import java.io.IOException; +import java.net.InetAddress; +import java.net.InetSocketAddress; +import java.util.HashMap; +import java.util.Map; + +import com.google.common.collect.BiMap; +import com.google.common.collect.HashBiMap; + +import org.opensearch.OpenSearchException; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.common.io.stream.Writeable; + +/** + * Registry for any class that does NOT implement the Writeable interface + * and needs to be serialized over the wire. Supports registration of writer and reader via registerStreamable + * for such classes and provides methods writeTo and readFrom for objects of such registered classes. + *
+ * Methods are protected and intended to be accessed from only within the package. (mostly by Base64Helper) + */ +public class StreamableRegistry { + + private static final StreamableRegistry INSTANCE = new StreamableRegistry(); + public final BiMap, Integer> classToIdMap = HashBiMap.create(); + private final Map idToEntryMap = new HashMap<>(); + + private StreamableRegistry() { + registerAllStreamables(); + } + + private static class Entry { + Writeable.Writer writer; + Writeable.Reader reader; + + Entry(Writeable.Writer writer, Writeable.Reader reader) { + this.writer = writer; + this.reader = reader; + } + } + + private Writeable.Writer getWriter(Class clazz) { + if (!classToIdMap.containsKey(clazz)) { + throw new OpenSearchException(String.format("No writer registered for class %s", clazz.getName())); + } + return idToEntryMap.get(classToIdMap.get(clazz)).writer; + } + + private Writeable.Reader getReader(int id) { + if (!idToEntryMap.containsKey(id)) { + throw new OpenSearchException(String.format("No reader registered for id %s", id)); + } + return idToEntryMap.get(id).reader; + } + + private int getId(Class clazz) { + if (!classToIdMap.containsKey(clazz)) { + throw new OpenSearchException(String.format("No writer registered for class %s", clazz.getName())); + } + return classToIdMap.get(clazz); + } + + protected boolean isStreamable(Class clazz) { + return classToIdMap.containsKey(clazz); + } + + protected void writeTo(StreamOutput out, Object object) throws IOException { + out.writeByte((byte) getId(object.getClass())); + getWriter(object.getClass()).write(out, object); + } + + protected Object readFrom(StreamInput in) throws IOException { + int id = in.readByte(); + return getReader(id).read(in); + } + + protected static StreamableRegistry getInstance() { + return INSTANCE; + } + + protected void registerStreamable(int streamableId, Class clazz, Writeable.Writer writer, Writeable.Reader reader) { + if (Writeable.class.isAssignableFrom(clazz)) { + throw new IllegalArgumentException( + String.format("%s is Writeable and should not be registered as a streamable", clazz.getName()) + ); + } + classToIdMap.put(clazz, streamableId); + idToEntryMap.put(streamableId, new Entry(writer, reader)); + } + + protected int getStreamableID(Class clazz) { + if (!isStreamable(clazz)) { + throw new OpenSearchException(String.format("class %s is in streamable registry", clazz.getName())); + } else { + return classToIdMap.get(clazz); + } + } + + /** + * Register all streamables here. + *
+ * Caution - Register new streamables towards the end. Removing / reordering a registered streamable will change the typeIDs associated with the streamables + * causing a breaking change in the serialization format. + */ + private void registerAllStreamables() { + + // InetSocketAddress + this.registerStreamable(1, InetSocketAddress.class, (o, v) -> { + final InetSocketAddress inetSocketAddress = (InetSocketAddress) v; + o.writeString(inetSocketAddress.getHostString()); + o.writeByteArray(inetSocketAddress.getAddress().getAddress()); + o.writeInt(inetSocketAddress.getPort()); + }, i -> { + String host = i.readString(); + byte[] addressBytes = i.readByteArray(); + int port = i.readInt(); + return new InetSocketAddress(InetAddress.getByAddress(host, addressBytes), port); + }); + } + +} diff --git a/src/main/java/org/opensearch/security/transport/SecurityInterceptor.java b/src/main/java/org/opensearch/security/transport/SecurityInterceptor.java index 0c645c9a00..f064f0af04 100644 --- a/src/main/java/org/opensearch/security/transport/SecurityInterceptor.java +++ b/src/main/java/org/opensearch/security/transport/SecurityInterceptor.java @@ -59,6 +59,7 @@ import org.opensearch.security.ssl.transport.SSLConfig; import org.opensearch.security.support.Base64Helper; import org.opensearch.security.support.ConfigConstants; +import org.opensearch.security.support.HeaderHelper; import org.opensearch.security.user.User; import org.opensearch.threadpool.ThreadPool; import org.opensearch.transport.Transport.Connection; @@ -147,6 +148,7 @@ public void sendRequestDecorate( final String origCCSTransientMf = getThreadContext().getTransient(ConfigConstants.OPENDISTRO_SECURITY_MASKED_FIELD_CCS); final boolean isDebugEnabled = log.isDebugEnabled(); + final boolean useJDKSerialization = connection.getVersion().before(ConfigConstants.FIRST_CUSTOM_SERIALIZATION_SUPPORTED_OS_VERSION); final boolean isSameNodeRequest = localNode != null && localNode.equals(connection.getNode()); try (ThreadContext.StoredContext stashedContext = getThreadContext().stashContext()) { @@ -224,9 +226,26 @@ && getThreadContext().getHeader(ConfigConstants.OPENDISTRO_SECURITY_INJECTED_ROL ); } + if (useJDKSerialization) { + Map jdkSerializedHeaders = new HashMap<>(); + HeaderHelper.getAllSerializedHeaderNames() + .stream() + .filter(k -> headerMap.get(k) != null) + .forEach(k -> jdkSerializedHeaders.put(k, Base64Helper.ensureJDKSerialized(headerMap.get(k)))); + headerMap.putAll(jdkSerializedHeaders); + } + getThreadContext().putHeader(headerMap); - ensureCorrectHeaders(remoteAddress0, user0, origin0, injectedUserString, injectedRolesString, isSameNodeRequest); + ensureCorrectHeaders( + remoteAddress0, + user0, + origin0, + injectedUserString, + injectedRolesString, + isSameNodeRequest, + useJDKSerialization + ); if (isActionTraceEnabled()) { getThreadContext().putHeader( @@ -253,7 +272,8 @@ private void ensureCorrectHeaders( final String origin, final String injectedUserString, final String injectedRolesString, - boolean isSameNodeRequest + final boolean isSameNodeRequest, + final boolean useJDKSerialization ) { // keep original address @@ -294,7 +314,7 @@ && getThreadContext().getHeader(ConfigConstants.OPENDISTRO_SECURITY_ORIGIN_HEADE if (transportAddress != null) { getThreadContext().putHeader( ConfigConstants.OPENDISTRO_SECURITY_REMOTE_ADDRESS_HEADER, - Base64Helper.serializeObject(transportAddress.address()) + Base64Helper.serializeObject(transportAddress.address(), useJDKSerialization) ); } @@ -302,7 +322,10 @@ && getThreadContext().getHeader(ConfigConstants.OPENDISTRO_SECURITY_ORIGIN_HEADE if (userHeader == null) { // put as headers for other requests if (origUser != null) { - getThreadContext().putHeader(ConfigConstants.OPENDISTRO_SECURITY_USER_HEADER, Base64Helper.serializeObject(origUser)); + getThreadContext().putHeader( + ConfigConstants.OPENDISTRO_SECURITY_USER_HEADER, + Base64Helper.serializeObject(origUser, useJDKSerialization) + ); } else if (StringUtils.isNotEmpty(injectedRolesString)) { getThreadContext().putHeader(ConfigConstants.OPENDISTRO_SECURITY_INJECTED_ROLES_HEADER, injectedRolesString); } else if (StringUtils.isNotEmpty(injectedUserString)) { diff --git a/src/main/java/org/opensearch/security/transport/SecurityRequestHandler.java b/src/main/java/org/opensearch/security/transport/SecurityRequestHandler.java index 1284ca9781..3ba379dd67 100644 --- a/src/main/java/org/opensearch/security/transport/SecurityRequestHandler.java +++ b/src/main/java/org/opensearch/security/transport/SecurityRequestHandler.java @@ -107,6 +107,8 @@ protected void messageReceivedDecorate( resolvedActionClass = ((ConcreteShardRequest) request).getRequest().getClass().getSimpleName(); } + final boolean useJDKSerialization = getThreadContext().getTransient(ConfigConstants.USE_JDK_SERIALIZATION); + String initialActionClassValue = getThreadContext().getHeader(ConfigConstants.OPENDISTRO_SECURITY_INITIAL_ACTION_CLASS_HEADER); final ThreadContext.StoredContext sgContext = getThreadContext().newStoredContext(false); @@ -181,7 +183,7 @@ protected void messageReceivedDecorate( } else { getThreadContext().putTransient( ConfigConstants.OPENDISTRO_SECURITY_USER, - Objects.requireNonNull((User) Base64Helper.deserializeObject(userHeader)) + Objects.requireNonNull((User) Base64Helper.deserializeObject(userHeader, useJDKSerialization)) ); } @@ -190,7 +192,7 @@ protected void messageReceivedDecorate( if (!Strings.isNullOrEmpty(originalRemoteAddress)) { getThreadContext().putTransient( ConfigConstants.OPENDISTRO_SECURITY_REMOTE_ADDRESS, - new TransportAddress((InetSocketAddress) Base64Helper.deserializeObject(originalRemoteAddress)) + new TransportAddress((InetSocketAddress) Base64Helper.deserializeObject(originalRemoteAddress, useJDKSerialization)) ); } else { getThreadContext().putTransient(ConfigConstants.OPENDISTRO_SECURITY_REMOTE_ADDRESS, request.remoteAddress()); diff --git a/src/main/java/org/opensearch/security/user/User.java b/src/main/java/org/opensearch/security/user/User.java index 2642b368d7..394b251271 100644 --- a/src/main/java/org/opensearch/security/user/User.java +++ b/src/main/java/org/opensearch/security/user/User.java @@ -83,6 +83,9 @@ public User(final StreamInput in) throws IOException { name = in.readString(); roles.addAll(in.readList(StreamInput::readString)); requestedTenant = in.readString(); + if (requestedTenant.isEmpty()) { + requestedTenant = null; + } attributes = Collections.synchronizedMap(in.readMap(StreamInput::readString, StreamInput::readString)); securityRoles.addAll(in.readList(StreamInput::readString)); } @@ -167,9 +170,9 @@ public final boolean isUserInRole(final String role) { } /** - * Associate this user with a set of backend roles + * Associate this user with a set of custom attributes * - * @param roles The backend roles + * @param attributes custom attributes */ public final void addAttributes(final Map attributes) { if (attributes != null) { @@ -255,7 +258,7 @@ public final void copyRolesFrom(final User user) { public void writeTo(StreamOutput out) throws IOException { out.writeString(name); out.writeStringCollection(new ArrayList(roles)); - out.writeString(requestedTenant); + out.writeString(requestedTenant == null ? "" : requestedTenant); out.writeMap(attributes, StreamOutput::writeString, StreamOutput::writeString); out.writeStringCollection(securityRoles == null ? Collections.emptyList() : new ArrayList(securityRoles)); } diff --git a/src/main/java/org/opensearch/security/util/AuthTokenUtils.java b/src/main/java/org/opensearch/security/util/AuthTokenUtils.java new file mode 100644 index 0000000000..30f331d3a7 --- /dev/null +++ b/src/main/java/org/opensearch/security/util/AuthTokenUtils.java @@ -0,0 +1,41 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.security.util; + +import org.opensearch.common.settings.Settings; +import org.opensearch.rest.RestRequest; + +import static org.opensearch.rest.RestRequest.Method.POST; +import static org.opensearch.rest.RestRequest.Method.PUT; + +public class AuthTokenUtils { + private static final String ON_BEHALF_OF_SUFFIX = "api/generateonbehalfoftoken"; + private static final String ACCOUNT_SUFFIX = "api/account"; + + public static Boolean isAccessToRestrictedEndpoints(final RestRequest request, final String suffix) { + if (suffix == null) { + return false; + } + switch (suffix) { + case ON_BEHALF_OF_SUFFIX: + return request.method() == POST; + case ACCOUNT_SUFFIX: + return request.method() == PUT; + default: + return false; + } + } + + public static Boolean isKeyNull(Settings settings, String key) { + return settings.get(key) == null; + } +} diff --git a/src/test/java/com/amazon/dlic/auth/http/saml/HTTPSamlAuthenticatorTest.java b/src/test/java/com/amazon/dlic/auth/http/saml/HTTPSamlAuthenticatorTest.java index ff9ec19b09..2594388128 100644 --- a/src/test/java/com/amazon/dlic/auth/http/saml/HTTPSamlAuthenticatorTest.java +++ b/src/test/java/com/amazon/dlic/auth/http/saml/HTTPSamlAuthenticatorTest.java @@ -47,6 +47,7 @@ import org.opensearch.common.settings.Settings; import org.opensearch.core.xcontent.MediaType; import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.rest.BytesRestResponse; import org.opensearch.rest.RestChannel; import org.opensearch.rest.RestRequest; import org.opensearch.rest.RestRequest.Method; @@ -141,7 +142,8 @@ public void basicTest() throws Exception { RestRequest tokenRestRequest = buildTokenExchangeRestRequest(encodedSamlResponse, authenticateHeaders); TestRestChannel tokenRestChannel = new TestRestChannel(tokenRestRequest); - samlAuthenticator.reRequestAuthentication(tokenRestChannel, null); + final BytesRestResponse authenticatorResponse = samlAuthenticator.reRequestAuthentication(tokenRestRequest, null); + tokenRestChannel.sendResponse(authenticatorResponse); String responseJson = new String(BytesReference.toBytes(tokenRestChannel.response.content())); HashMap response = DefaultObjectMapper.objectMapper.readValue( @@ -188,7 +190,8 @@ public void decryptAssertionsTest() throws Exception { RestRequest tokenRestRequest = buildTokenExchangeRestRequest(encodedSamlResponse, authenticateHeaders); TestRestChannel tokenRestChannel = new TestRestChannel(tokenRestRequest); - samlAuthenticator.reRequestAuthentication(tokenRestChannel, null); + final BytesRestResponse authenticatorResponse = samlAuthenticator.reRequestAuthentication(tokenRestRequest, null); + tokenRestChannel.sendResponse(authenticatorResponse); String responseJson = new String(BytesReference.toBytes(tokenRestChannel.response.content())); HashMap response = DefaultObjectMapper.objectMapper.readValue( @@ -236,7 +239,8 @@ public void shouldUnescapeSamlEntitiesTest() throws Exception { RestRequest tokenRestRequest = buildTokenExchangeRestRequest(encodedSamlResponse, authenticateHeaders); TestRestChannel tokenRestChannel = new TestRestChannel(tokenRestRequest); - samlAuthenticator.reRequestAuthentication(tokenRestChannel, null); + final BytesRestResponse authenticatorResponse = samlAuthenticator.reRequestAuthentication(tokenRestRequest, null); + tokenRestChannel.sendResponse(authenticatorResponse); String responseJson = new String(BytesReference.toBytes(tokenRestChannel.response.content())); HashMap response = DefaultObjectMapper.objectMapper.readValue( @@ -287,7 +291,8 @@ public void shouldUnescapeSamlEntitiesTest2() throws Exception { RestRequest tokenRestRequest = buildTokenExchangeRestRequest(encodedSamlResponse, authenticateHeaders); TestRestChannel tokenRestChannel = new TestRestChannel(tokenRestRequest); - samlAuthenticator.reRequestAuthentication(tokenRestChannel, null); + final BytesRestResponse authenticatorResponse = samlAuthenticator.reRequestAuthentication(tokenRestRequest, null); + tokenRestChannel.sendResponse(authenticatorResponse); String responseJson = new String(BytesReference.toBytes(tokenRestChannel.response.content())); HashMap response = DefaultObjectMapper.objectMapper.readValue( @@ -338,7 +343,8 @@ public void shouldNotEscapeSamlEntities() throws Exception { RestRequest tokenRestRequest = buildTokenExchangeRestRequest(encodedSamlResponse, authenticateHeaders); TestRestChannel tokenRestChannel = new TestRestChannel(tokenRestRequest); - samlAuthenticator.reRequestAuthentication(tokenRestChannel, null); + final BytesRestResponse authenticatorResponse = samlAuthenticator.reRequestAuthentication(tokenRestRequest, null); + tokenRestChannel.sendResponse(authenticatorResponse); String responseJson = new String(BytesReference.toBytes(tokenRestChannel.response.content())); HashMap response = DefaultObjectMapper.objectMapper.readValue( @@ -389,7 +395,8 @@ public void shouldNotTrimWhitespaceInJwtRoles() throws Exception { RestRequest tokenRestRequest = buildTokenExchangeRestRequest(encodedSamlResponse, authenticateHeaders); TestRestChannel tokenRestChannel = new TestRestChannel(tokenRestRequest); - samlAuthenticator.reRequestAuthentication(tokenRestChannel, null); + final BytesRestResponse authenticatorResponse = samlAuthenticator.reRequestAuthentication(tokenRestRequest, null); + tokenRestChannel.sendResponse(authenticatorResponse); String responseJson = new String(BytesReference.toBytes(tokenRestChannel.response.content())); HashMap response = DefaultObjectMapper.objectMapper.readValue( @@ -436,7 +443,8 @@ public void testMetadataBody() throws Exception { RestRequest tokenRestRequest = buildTokenExchangeRestRequest(encodedSamlResponse, authenticateHeaders); TestRestChannel tokenRestChannel = new TestRestChannel(tokenRestRequest); - samlAuthenticator.reRequestAuthentication(tokenRestChannel, null); + final BytesRestResponse authenticatorResponse = samlAuthenticator.reRequestAuthentication(tokenRestRequest, null); + tokenRestChannel.sendResponse(authenticatorResponse); String responseJson = new String(BytesReference.toBytes(tokenRestChannel.response.content())); HashMap response = DefaultObjectMapper.objectMapper.readValue( @@ -501,7 +509,8 @@ public void unsolicitedSsoTest() throws Exception { ); TestRestChannel tokenRestChannel = new TestRestChannel(tokenRestRequest); - samlAuthenticator.reRequestAuthentication(tokenRestChannel, null); + final BytesRestResponse authenticatorResponse = samlAuthenticator.reRequestAuthentication(tokenRestRequest, null); + tokenRestChannel.sendResponse(authenticatorResponse); String responseJson = new String(BytesReference.toBytes(tokenRestChannel.response.content())); HashMap response = DefaultObjectMapper.objectMapper.readValue( @@ -552,7 +561,8 @@ public void badUnsolicitedSsoTest() throws Exception { ); TestRestChannel tokenRestChannel = new TestRestChannel(tokenRestRequest); - samlAuthenticator.reRequestAuthentication(tokenRestChannel, null); + final BytesRestResponse authenticatorResponse = samlAuthenticator.reRequestAuthentication(tokenRestRequest, null); + tokenRestChannel.sendResponse(authenticatorResponse); Assert.assertEquals(RestStatus.UNAUTHORIZED, tokenRestChannel.response.status()); } @@ -584,7 +594,8 @@ public void wrongCertTest() throws Exception { RestRequest tokenRestRequest = buildTokenExchangeRestRequest(encodedSamlResponse, authenticateHeaders); TestRestChannel tokenRestChannel = new TestRestChannel(tokenRestRequest); - samlAuthenticator.reRequestAuthentication(tokenRestChannel, null); + final BytesRestResponse authenticatorResponse = samlAuthenticator.reRequestAuthentication(tokenRestRequest, null); + tokenRestChannel.sendResponse(authenticatorResponse); Assert.assertEquals(401, tokenRestChannel.response.status().getStatus()); } @@ -613,7 +624,8 @@ public void noSignatureTest() throws Exception { RestRequest tokenRestRequest = buildTokenExchangeRestRequest(encodedSamlResponse, authenticateHeaders); TestRestChannel tokenRestChannel = new TestRestChannel(tokenRestRequest); - samlAuthenticator.reRequestAuthentication(tokenRestChannel, null); + final BytesRestResponse authenticatorResponse = samlAuthenticator.reRequestAuthentication(tokenRestRequest, null); + tokenRestChannel.sendResponse(authenticatorResponse); Assert.assertEquals(401, tokenRestChannel.response.status().getStatus()); } @@ -646,7 +658,8 @@ public void rolesTest() throws Exception { RestRequest tokenRestRequest = buildTokenExchangeRestRequest(encodedSamlResponse, authenticateHeaders); TestRestChannel tokenRestChannel = new TestRestChannel(tokenRestRequest); - samlAuthenticator.reRequestAuthentication(tokenRestChannel, null); + final BytesRestResponse authenticatorResponse = samlAuthenticator.reRequestAuthentication(tokenRestRequest, null); + tokenRestChannel.sendResponse(authenticatorResponse); String responseJson = new String(BytesReference.toBytes(tokenRestChannel.response.content())); HashMap response = DefaultObjectMapper.objectMapper.readValue( @@ -693,7 +706,8 @@ public void idpEndpointWithQueryStringTest() throws Exception { RestRequest tokenRestRequest = buildTokenExchangeRestRequest(encodedSamlResponse, authenticateHeaders); TestRestChannel tokenRestChannel = new TestRestChannel(tokenRestRequest); - samlAuthenticator.reRequestAuthentication(tokenRestChannel, null); + final BytesRestResponse authenticatorResponse = samlAuthenticator.reRequestAuthentication(tokenRestRequest, null); + tokenRestChannel.sendResponse(authenticatorResponse); String responseJson = new String(BytesReference.toBytes(tokenRestChannel.response.content())); HashMap response = DefaultObjectMapper.objectMapper.readValue( @@ -747,7 +761,8 @@ private void commaSeparatedRoles(final String rolesAsString, final Settings.Buil RestRequest tokenRestRequest = buildTokenExchangeRestRequest(encodedSamlResponse, authenticateHeaders); TestRestChannel tokenRestChannel = new TestRestChannel(tokenRestRequest); - samlAuthenticator.reRequestAuthentication(tokenRestChannel, null); + final BytesRestResponse authenticatorResponse = samlAuthenticator.reRequestAuthentication(tokenRestRequest, null); + tokenRestChannel.sendResponse(authenticatorResponse); String responseJson = new String(BytesReference.toBytes(tokenRestChannel.response.content())); HashMap response = DefaultObjectMapper.objectMapper.readValue( @@ -850,7 +865,8 @@ public void initialConnectionFailureTest() throws Exception { RestRequest restRequest = new FakeRestRequest(ImmutableMap.of(), new HashMap()); TestRestChannel restChannel = new TestRestChannel(restRequest); - samlAuthenticator.reRequestAuthentication(restChannel, null); + BytesRestResponse authenticatorResponse = samlAuthenticator.reRequestAuthentication(restRequest, null); + restChannel.sendResponse(authenticatorResponse); Assert.assertNull(restChannel.response); @@ -870,7 +886,8 @@ public void initialConnectionFailureTest() throws Exception { RestRequest tokenRestRequest = buildTokenExchangeRestRequest(encodedSamlResponse, authenticateHeaders); TestRestChannel tokenRestChannel = new TestRestChannel(tokenRestRequest); - samlAuthenticator.reRequestAuthentication(tokenRestChannel, null); + authenticatorResponse = samlAuthenticator.reRequestAuthentication(tokenRestRequest, null); + tokenRestChannel.sendResponse(authenticatorResponse); String responseJson = new String(BytesReference.toBytes(tokenRestChannel.response.content())); HashMap response = DefaultObjectMapper.objectMapper.readValue( @@ -893,7 +910,8 @@ private AuthenticateHeaders getAutenticateHeaders(HTTPSamlAuthenticator samlAuth RestRequest restRequest = new FakeRestRequest(ImmutableMap.of(), new HashMap()); TestRestChannel restChannel = new TestRestChannel(restRequest); - samlAuthenticator.reRequestAuthentication(restChannel, null); + final BytesRestResponse authenticatorResponse = samlAuthenticator.reRequestAuthentication(restRequest, null); + restChannel.sendResponse(authenticatorResponse); List wwwAuthenticateHeaders = restChannel.response.getHeaders().get("WWW-Authenticate"); diff --git a/src/test/java/org/opensearch/security/auth/limiting/AddressBasedRateLimiterTest.java b/src/test/java/org/opensearch/security/auth/limiting/AddressBasedRateLimiterTest.java index 827bfa24b6..69ddc5c03a 100644 --- a/src/test/java/org/opensearch/security/auth/limiting/AddressBasedRateLimiterTest.java +++ b/src/test/java/org/opensearch/security/auth/limiting/AddressBasedRateLimiterTest.java @@ -20,28 +20,27 @@ import org.junit.Test; import org.opensearch.common.settings.Settings; -import org.opensearch.security.user.AuthCredentials; + +import java.net.InetAddress; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertTrue; public class AddressBasedRateLimiterTest { - private final static byte[] PASSWORD = new byte[] { '1', '2', '3' }; - @Test public void simpleTest() throws Exception { Settings settings = Settings.builder().put("allowed_tries", 3).build(); - UserNameBasedRateLimiter rateLimiter = new UserNameBasedRateLimiter(settings, null); + AddressBasedRateLimiter rateLimiter = new AddressBasedRateLimiter(settings, null); - assertFalse(rateLimiter.isBlocked("a")); - rateLimiter.onAuthFailure(null, new AuthCredentials("a", PASSWORD), null); - assertFalse(rateLimiter.isBlocked("a")); - rateLimiter.onAuthFailure(null, new AuthCredentials("a", PASSWORD), null); - assertFalse(rateLimiter.isBlocked("a")); - rateLimiter.onAuthFailure(null, new AuthCredentials("a", PASSWORD), null); - assertTrue(rateLimiter.isBlocked("a")); + assertFalse(rateLimiter.isBlocked(InetAddress.getByAddress(new byte[] { 1, 2, 3, 4 }))); + rateLimiter.onAuthFailure(InetAddress.getByAddress(new byte[] { 1, 2, 3, 4 }), null, null); + assertFalse(rateLimiter.isBlocked(InetAddress.getByAddress(new byte[] { 1, 2, 3, 4 }))); + rateLimiter.onAuthFailure(InetAddress.getByAddress(new byte[] { 1, 2, 3, 4 }), null, null); + assertFalse(rateLimiter.isBlocked(InetAddress.getByAddress(new byte[] { 1, 2, 3, 4 }))); + rateLimiter.onAuthFailure(InetAddress.getByAddress(new byte[] { 1, 2, 3, 4 }), null, null); + assertTrue(rateLimiter.isBlocked(InetAddress.getByAddress(new byte[] { 1, 2, 3, 4 }))); } } diff --git a/src/test/java/org/opensearch/security/auth/limiting/UserNameBasedRateLimiterTest.java b/src/test/java/org/opensearch/security/auth/limiting/UserNameBasedRateLimiterTest.java index e42d2bd1b8..a8285c42a7 100644 --- a/src/test/java/org/opensearch/security/auth/limiting/UserNameBasedRateLimiterTest.java +++ b/src/test/java/org/opensearch/security/auth/limiting/UserNameBasedRateLimiterTest.java @@ -17,30 +17,31 @@ package org.opensearch.security.auth.limiting; -import java.net.InetAddress; - import org.junit.Test; import org.opensearch.common.settings.Settings; +import org.opensearch.security.user.AuthCredentials; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertTrue; public class UserNameBasedRateLimiterTest { + private final static byte[] PASSWORD = new byte[] { '1', '2', '3' }; + @Test public void simpleTest() throws Exception { Settings settings = Settings.builder().put("allowed_tries", 3).build(); - AddressBasedRateLimiter rateLimiter = new AddressBasedRateLimiter(settings, null); + UserNameBasedRateLimiter rateLimiter = new UserNameBasedRateLimiter(settings, null); - assertFalse(rateLimiter.isBlocked(InetAddress.getByAddress(new byte[] { 1, 2, 3, 4 }))); - rateLimiter.onAuthFailure(InetAddress.getByAddress(new byte[] { 1, 2, 3, 4 }), null, null); - assertFalse(rateLimiter.isBlocked(InetAddress.getByAddress(new byte[] { 1, 2, 3, 4 }))); - rateLimiter.onAuthFailure(InetAddress.getByAddress(new byte[] { 1, 2, 3, 4 }), null, null); - assertFalse(rateLimiter.isBlocked(InetAddress.getByAddress(new byte[] { 1, 2, 3, 4 }))); - rateLimiter.onAuthFailure(InetAddress.getByAddress(new byte[] { 1, 2, 3, 4 }), null, null); - assertTrue(rateLimiter.isBlocked(InetAddress.getByAddress(new byte[] { 1, 2, 3, 4 }))); + assertFalse(rateLimiter.isBlocked("a")); + rateLimiter.onAuthFailure(null, new AuthCredentials("a", PASSWORD), null); + assertFalse(rateLimiter.isBlocked("a")); + rateLimiter.onAuthFailure(null, new AuthCredentials("a", PASSWORD), null); + assertFalse(rateLimiter.isBlocked("a")); + rateLimiter.onAuthFailure(null, new AuthCredentials("a", PASSWORD), null); + assertTrue(rateLimiter.isBlocked("a")); } } diff --git a/src/test/java/org/opensearch/security/authtoken/jwt/AuthTokenUtilsTest.java b/src/test/java/org/opensearch/security/authtoken/jwt/AuthTokenUtilsTest.java new file mode 100644 index 0000000000..d563308e31 --- /dev/null +++ b/src/test/java/org/opensearch/security/authtoken/jwt/AuthTokenUtilsTest.java @@ -0,0 +1,78 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.security.authtoken.jwt; + +import org.opensearch.common.settings.Settings; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.rest.RestRequest; +import org.opensearch.security.util.AuthTokenUtils; +import org.opensearch.test.rest.FakeRestRequest; +import org.junit.Test; + +import java.util.Collections; + +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; + +public class AuthTokenUtilsTest { + + @Test + public void testIsAccessToRestrictedEndpointsForOnBehalfOfToken() { + NamedXContentRegistry namedXContentRegistry = new NamedXContentRegistry(Collections.emptyList()); + + FakeRestRequest request = new FakeRestRequest.Builder(namedXContentRegistry).withPath("/api/generateonbehalfoftoken") + .withMethod(RestRequest.Method.POST) + .build(); + + assertTrue(AuthTokenUtils.isAccessToRestrictedEndpoints(request, "api/generateonbehalfoftoken")); + } + + @Test + public void testIsAccessToRestrictedEndpointsForAccount() { + NamedXContentRegistry namedXContentRegistry = new NamedXContentRegistry(Collections.emptyList()); + + FakeRestRequest request = new FakeRestRequest.Builder(namedXContentRegistry).withPath("/api/account") + .withMethod(RestRequest.Method.PUT) + .build(); + + assertTrue(AuthTokenUtils.isAccessToRestrictedEndpoints(request, "api/account")); + } + + @Test + public void testIsAccessToRestrictedEndpointsFalseCase() { + NamedXContentRegistry namedXContentRegistry = new NamedXContentRegistry(Collections.emptyList()); + + FakeRestRequest request = new FakeRestRequest.Builder(namedXContentRegistry).withPath("/api/someotherendpoint") + .withMethod(RestRequest.Method.GET) + .build(); + + assertFalse(AuthTokenUtils.isAccessToRestrictedEndpoints(request, "api/someotherendpoint")); + } + + @Test + public void testIsKeyNullWithNullValue() { + Settings settings = Settings.builder().put("someKey", (String) null).build(); + assertTrue(AuthTokenUtils.isKeyNull(settings, "someKey")); + } + + @Test + public void testIsKeyNullWithNonNullValue() { + Settings settings = Settings.builder().put("someKey", "value").build(); + assertFalse(AuthTokenUtils.isKeyNull(settings, "someKey")); + } + + @Test + public void testIsKeyNullWithAbsentKey() { + Settings settings = Settings.builder().build(); + assertTrue(AuthTokenUtils.isKeyNull(settings, "absentKey")); + } +} diff --git a/src/test/java/org/opensearch/security/cache/DummyHTTPAuthenticator.java b/src/test/java/org/opensearch/security/cache/DummyHTTPAuthenticator.java index 55c2e789c6..37ac45080b 100644 --- a/src/test/java/org/opensearch/security/cache/DummyHTTPAuthenticator.java +++ b/src/test/java/org/opensearch/security/cache/DummyHTTPAuthenticator.java @@ -16,7 +16,7 @@ import org.opensearch.OpenSearchSecurityException; import org.opensearch.common.settings.Settings; import org.opensearch.common.util.concurrent.ThreadContext; -import org.opensearch.rest.RestChannel; +import org.opensearch.rest.BytesRestResponse; import org.opensearch.rest.RestRequest; import org.opensearch.security.auth.HTTPAuthenticator; import org.opensearch.security.user.AuthCredentials; @@ -39,8 +39,8 @@ public AuthCredentials extractCredentials(RestRequest request, ThreadContext con } @Override - public boolean reRequestAuthentication(RestChannel channel, AuthCredentials credentials) { - return false; + public BytesRestResponse reRequestAuthentication(RestRequest request, AuthCredentials credentials) { + return null; } public static long getCount() { diff --git a/src/test/java/org/opensearch/security/support/Base64CustomHelperTest.java b/src/test/java/org/opensearch/security/support/Base64CustomHelperTest.java new file mode 100644 index 0000000000..e35e1d72ba --- /dev/null +++ b/src/test/java/org/opensearch/security/support/Base64CustomHelperTest.java @@ -0,0 +1,159 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.security.support; + +import com.amazon.dlic.auth.ldap.LdapUser; +import org.junit.Assert; +import org.junit.Test; +import org.ldaptive.LdapEntry; +import org.opensearch.OpenSearchException; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.common.io.stream.Writeable; +import org.opensearch.security.auth.UserInjector; +import org.opensearch.security.user.AuthCredentials; +import org.opensearch.security.user.User; + +import java.io.Serializable; +import java.net.InetSocketAddress; +import java.time.ZonedDateTime; +import java.util.ArrayList; +import java.util.HashMap; + +import static org.opensearch.security.support.Base64CustomHelper.deserializeObject; +import static org.opensearch.security.support.Base64CustomHelper.serializeObject; + +public class Base64CustomHelperTest { + + private static final class NotSafeStreamable implements Serializable { + private static final long serialVersionUID = 5135559266828470092L; + } + + private static final class NotSafeWriteable implements Writeable, Serializable { + @Override + public void writeTo(StreamOutput out) { + + } + } + + private static Serializable ds(Serializable s) { + return deserializeObject(serializeObject(s)); + } + + @Test + public void testString() { + String string = "string"; + Assert.assertEquals(string, ds(string)); + } + + @Test + public void testInteger() { + Integer integer = 0; + Assert.assertEquals(integer, ds(integer)); + } + + @Test + public void testDouble() { + Double number = 0.; + Assert.assertEquals(number, ds(number)); + } + + @Test + public void testInetSocketAddress() { + InetSocketAddress inetSocketAddress = new InetSocketAddress(0); + Assert.assertEquals(inetSocketAddress, ds(inetSocketAddress)); + } + + @Test + public void testUser() { + User user = new User("user"); + Assert.assertEquals(user, ds(user)); + } + + @Test + public void testSourceFieldsContext() { + SourceFieldsContext sourceFieldsContext = new SourceFieldsContext(new SearchRequest("")); + Assert.assertEquals(sourceFieldsContext.toString(), ds(sourceFieldsContext).toString()); + } + + @Test + public void testHashMap() { + HashMap map = new HashMap<>() { + { + put("key", "value"); + } + }; + Assert.assertEquals(map, ds(map)); + } + + @Test + public void testArrayList() { + ArrayList list = new ArrayList<>() { + { + add("value"); + } + }; + Assert.assertEquals(list, ds(list)); + } + + @Test + public void testLdapUser() { + LdapUser ldapUser = new LdapUser( + "username", + "originalusername", + new LdapEntry("dn"), + new AuthCredentials("originalusername", "12345"), + 34, + WildcardMatcher.ANY + ); + Assert.assertEquals(ldapUser, ds(ldapUser)); + } + + @Test + public void testGetWriteableClassID() { + // a need to make a change in this test signifies a breaking change in security plugin's custom serialization + // format + Assert.assertEquals(Integer.valueOf(1), Base64CustomHelper.getWriteableClassID(User.class)); + Assert.assertEquals(Integer.valueOf(2), Base64CustomHelper.getWriteableClassID(LdapUser.class)); + Assert.assertEquals(Integer.valueOf(3), Base64CustomHelper.getWriteableClassID(UserInjector.InjectedUser.class)); + Assert.assertEquals(Integer.valueOf(4), Base64CustomHelper.getWriteableClassID(SourceFieldsContext.class)); + } + + @Test + public void testInjectedUser() { + UserInjector.InjectedUser injectedUser = new UserInjector.InjectedUser("username"); + + // for custom serialization, we expect InjectedUser to be returned on deserialization + UserInjector.InjectedUser deserializedInjecteduser = (UserInjector.InjectedUser) ds(injectedUser); + Assert.assertEquals(injectedUser, deserializedInjecteduser); + Assert.assertTrue(deserializedInjecteduser.isInjected()); + } + + @Test(expected = OpenSearchException.class) + public void testNotSafeStreamable() { + Base64JDKHelper.serializeObject(new NotSafeStreamable()); + } + + @Test(expected = OpenSearchException.class) + public void testNotSafeWriteable() { + Base64JDKHelper.serializeObject(new NotSafeWriteable()); + } + + @Test(expected = OpenSearchException.class) + public void testNotSafeGeneric() { + HashMap map = new HashMap<>(); + map.put(1, ZonedDateTime.now()); + map.put(2, ZonedDateTime.now()); + Base64JDKHelper.serializeObject(map); + } + +} diff --git a/src/test/java/org/opensearch/security/support/Base64HelperTest.java b/src/test/java/org/opensearch/security/support/Base64HelperTest.java index 81c2505985..f55581c7e7 100644 --- a/src/test/java/org/opensearch/security/support/Base64HelperTest.java +++ b/src/test/java/org/opensearch/security/support/Base64HelperTest.java @@ -10,100 +10,44 @@ */ package org.opensearch.security.support; -import java.io.ByteArrayOutputStream; -import java.io.ObjectOutputStream; import java.io.Serializable; -import java.net.InetSocketAddress; -import java.util.ArrayList; -import java.util.HashMap; -import java.util.regex.Pattern; -import com.google.common.io.BaseEncoding; import org.junit.Assert; import org.junit.Test; -import org.opensearch.OpenSearchException; -import org.opensearch.action.search.SearchRequest; -import org.opensearch.security.user.User; - import static org.opensearch.security.support.Base64Helper.deserializeObject; import static org.opensearch.security.support.Base64Helper.serializeObject; public class Base64HelperTest { - private static final class NotSafeSerializable implements Serializable { - private static final long serialVersionUID = 5135559266828470092L; + private static Serializable dsJDK(Serializable s) { + return deserializeObject(serializeObject(s, true), true); } private static Serializable ds(Serializable s) { return deserializeObject(serializeObject(s)); } + /** + * Just one sanity test comprising invocation of JDK and Custom Serialization. + * + * Individual scenarios are covered by Base64CustomHelperTest and Base64JDKHelperTest + */ @Test - public void testString() { - String string = "string"; - Assert.assertEquals(string, ds(string)); - } - - @Test - public void testInteger() { - Integer integer = Integer.valueOf(0); - Assert.assertEquals(integer, ds(integer)); - } - - @Test - public void testDouble() { - Double number = Double.valueOf(0.); - Assert.assertEquals(number, ds(number)); - } - - @Test - public void testInetSocketAddress() { - InetSocketAddress inetSocketAddress = new InetSocketAddress(0); - Assert.assertEquals(inetSocketAddress, ds(inetSocketAddress)); - } - - @Test - public void testPattern() { - Pattern pattern = Pattern.compile(".*"); - Assert.assertEquals(pattern.pattern(), ((Pattern) ds(pattern)).pattern()); - } - - @Test - public void testUser() { - User user = new User("user"); - Assert.assertEquals(user, ds(user)); - } - - @Test - public void testSourceFieldsContext() { - SourceFieldsContext sourceFieldsContext = new SourceFieldsContext(new SearchRequest("")); - Assert.assertEquals(sourceFieldsContext.toString(), ds(sourceFieldsContext).toString()); - } - - @Test - public void testHashMap() { - HashMap map = new HashMap(); - Assert.assertEquals(map, ds(map)); + public void testSerde() { + String test = "string"; + Assert.assertEquals(test, ds(test)); + Assert.assertEquals(test, dsJDK(test)); } @Test - public void testArrayList() { - ArrayList list = new ArrayList(); - Assert.assertEquals(list, ds(list)); - } + public void testEnsureJDKSerialized() { + String test = "string"; + String jdkSerialized = Base64Helper.serializeObject(test, true); + String customSerialized = Base64Helper.serializeObject(test, false); + Assert.assertEquals(jdkSerialized, Base64Helper.ensureJDKSerialized(jdkSerialized)); + Assert.assertEquals(jdkSerialized, Base64Helper.ensureJDKSerialized(customSerialized)); - @Test(expected = OpenSearchException.class) - public void notSafeSerializable() { - serializeObject(new NotSafeSerializable()); } - @Test(expected = OpenSearchException.class) - public void notSafeDeserializable() throws Exception { - final ByteArrayOutputStream bos = new ByteArrayOutputStream(); - try (final ObjectOutputStream out = new ObjectOutputStream(bos)) { - out.writeObject(new NotSafeSerializable()); - } - deserializeObject(BaseEncoding.base64().encode(bos.toByteArray())); - } } diff --git a/src/test/java/org/opensearch/security/support/Base64JDKHelperTest.java b/src/test/java/org/opensearch/security/support/Base64JDKHelperTest.java new file mode 100644 index 0000000000..704f1dc1d7 --- /dev/null +++ b/src/test/java/org/opensearch/security/support/Base64JDKHelperTest.java @@ -0,0 +1,128 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.security.support; + +import com.amazon.dlic.auth.ldap.LdapUser; +import com.google.common.io.BaseEncoding; +import org.junit.Assert; +import org.junit.Test; +import org.ldaptive.LdapEntry; +import org.opensearch.OpenSearchException; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.security.auth.UserInjector; +import org.opensearch.security.user.AuthCredentials; +import org.opensearch.security.user.User; + +import java.io.ByteArrayOutputStream; +import java.io.ObjectOutputStream; +import java.io.Serializable; +import java.net.InetSocketAddress; +import java.util.ArrayList; +import java.util.HashMap; + +public class Base64JDKHelperTest { + private static final class NotSafeSerializable implements Serializable { + private static final long serialVersionUID = 5135559266828470092L; + } + + private static Serializable ds(Serializable s) { + return Base64JDKHelper.deserializeObject(Base64JDKHelper.serializeObject(s)); + } + + @Test + public void testString() { + String string = "string"; + Assert.assertEquals(string, ds(string)); + } + + @Test + public void testInteger() { + Integer integer = 0; + Assert.assertEquals(integer, ds(integer)); + } + + @Test + public void testDouble() { + Double number = 0.0; + Assert.assertEquals(number, ds(number)); + } + + @Test + public void testInetSocketAddress() { + InetSocketAddress inetSocketAddress = new InetSocketAddress(0); + Assert.assertEquals(inetSocketAddress, ds(inetSocketAddress)); + } + + @Test + public void testUser() { + User user = new User("user"); + Assert.assertEquals(user, ds(user)); + } + + @Test + public void testSourceFieldsContext() { + SourceFieldsContext sourceFieldsContext = new SourceFieldsContext(new SearchRequest("")); + Assert.assertEquals(sourceFieldsContext.toString(), ds(sourceFieldsContext).toString()); + } + + @Test + public void testHashMap() { + HashMap map = new HashMap<>(); + map.put("key", "value"); + Assert.assertEquals(map, ds(map)); + } + + @Test + public void testArrayList() { + ArrayList list = new ArrayList<>(); + list.add("value"); + Assert.assertEquals(list, ds(list)); + } + + @Test(expected = OpenSearchException.class) + public void notSafeSerializable() { + Base64JDKHelper.serializeObject(new NotSafeSerializable()); + } + + @Test(expected = OpenSearchException.class) + public void notSafeDeserializable() throws Exception { + final ByteArrayOutputStream bos = new ByteArrayOutputStream(); + try (final ObjectOutputStream out = new ObjectOutputStream(bos)) { + out.writeObject(new NotSafeSerializable()); + } + Base64JDKHelper.deserializeObject(BaseEncoding.base64().encode(bos.toByteArray())); + } + + @Test + public void testLdapUser() { + LdapUser ldapUser = new LdapUser( + "username", + "originalusername", + new LdapEntry("dn"), + new AuthCredentials("originalusername", "12345"), + 34, + WildcardMatcher.ANY + ); + Assert.assertEquals(ldapUser, ds(ldapUser)); + } + + @Test + public void testInjectedUser() { + UserInjector.InjectedUser injectedUser = new UserInjector.InjectedUser("username"); + + // we expect to get User object when deserializing InjectedUser via JDK serialization + User user = new User("username"); + User deserializedUser = (User) ds(injectedUser); + Assert.assertEquals(user, deserializedUser); + Assert.assertTrue(deserializedUser.isInjected()); + } +} diff --git a/src/test/java/org/opensearch/security/support/StreamableRegistryTest.java b/src/test/java/org/opensearch/security/support/StreamableRegistryTest.java new file mode 100644 index 0000000000..13f2448b30 --- /dev/null +++ b/src/test/java/org/opensearch/security/support/StreamableRegistryTest.java @@ -0,0 +1,29 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.security.support; + +import org.junit.Assert; +import org.junit.Test; +import org.opensearch.OpenSearchException; + +import java.net.InetSocketAddress; + +public class StreamableRegistryTest { + + StreamableRegistry streamableRegistry = StreamableRegistry.getInstance(); + + @Test + public void testStreamableTypeIDs() { + Assert.assertEquals(1, streamableRegistry.getStreamableID(InetSocketAddress.class)); + Assert.assertThrows(OpenSearchException.class, () -> streamableRegistry.getStreamableID(String.class)); + } +} diff --git a/src/test/java/org/opensearch/security/transport/SecurityInterceptorTests.java b/src/test/java/org/opensearch/security/transport/SecurityInterceptorTests.java index d3363c54d8..abc0e314ef 100644 --- a/src/test/java/org/opensearch/security/transport/SecurityInterceptorTests.java +++ b/src/test/java/org/opensearch/security/transport/SecurityInterceptorTests.java @@ -47,9 +47,6 @@ import static java.util.Collections.emptySet; import static org.junit.Assert.assertEquals; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.eq; -import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; // CS-ENFORCE-SINGLE @@ -110,9 +107,8 @@ public void setup() { ); } - @Test - public void testSendRequestDecorate() { - + private void testSendRequestDecorate(Version remoteNodeVersion) { + boolean useJDKSerialization = remoteNodeVersion.before(ConfigConstants.FIRST_CUSTOM_SERIALIZATION_SUPPORTED_OS_VERSION); ClusterName clusterName = ClusterName.DEFAULT; when(clusterService.getClusterName()).thenReturn(clusterName); @@ -140,7 +136,6 @@ public void testSendRequestDecorate() { User user = new User("John Doe"); threadPool.getThreadContext().putTransient(ConfigConstants.OPENDISTRO_SECURITY_USER, user); - AsyncSender sender = mock(AsyncSender.class); String action = "testAction"; TransportRequest request = mock(TransportRequest.class); TransportRequestOptions options = mock(TransportRequestOptions.class); @@ -156,37 +151,65 @@ public void testSendRequestDecorate() { DiscoveryNode localNode = new DiscoveryNode("local-node", new TransportAddress(localAddress, 1234), Version.CURRENT); Connection connection1 = transportService.getConnection(localNode); - DiscoveryNode otherNode = new DiscoveryNode("local-node", new TransportAddress(localAddress, 4321), Version.CURRENT); + DiscoveryNode otherNode = new DiscoveryNode("remote-node", new TransportAddress(localAddress, 4321), remoteNodeVersion); Connection connection2 = transportService.getConnection(otherNode); + // from thread context inside sendRequestDecorate + AsyncSender sender = new AsyncSender() { + @Override + public void sendRequest( + Connection connection, + String action, + TransportRequest request, + TransportRequestOptions options, + TransportResponseHandler handler + ) { + User transientUser = threadPool.getThreadContext().getTransient(ConfigConstants.OPENDISTRO_SECURITY_USER); + assertEquals(transientUser, user); + } + }; // isSameNodeRequest = true securityInterceptor.sendRequestDecorate(sender, connection1, action, request, options, handler, localNode); - // from thread context inside sendRequestDecorate - doAnswer(i -> { - User transientUser = threadPool.getThreadContext().getTransient(ConfigConstants.OPENDISTRO_SECURITY_USER); - assertEquals(transientUser, user); - return null; - }).when(sender).sendRequest(any(Connection.class), eq(action), eq(request), eq(options), eq(handler)); // from original context User transientUser = threadPool.getThreadContext().getTransient(ConfigConstants.OPENDISTRO_SECURITY_USER); assertEquals(transientUser, user); assertEquals(threadPool.getThreadContext().getHeader(ConfigConstants.OPENDISTRO_SECURITY_USER_HEADER), null); - // isSameNodeRequest = false - securityInterceptor.sendRequestDecorate(sender, connection2, action, request, options, handler, otherNode); // checking thread context inside sendRequestDecorate - doAnswer(i -> { - String serializedUserHeader = threadPool.getThreadContext().getHeader(ConfigConstants.OPENDISTRO_SECURITY_USER_HEADER); - assertEquals(serializedUserHeader, Base64Helper.serializeObject(user)); - return null; - }).when(sender).sendRequest(any(Connection.class), eq(action), eq(request), eq(options), eq(handler)); + sender = new AsyncSender() { + @Override + public void sendRequest( + Connection connection, + String action, + TransportRequest request, + TransportRequestOptions options, + TransportResponseHandler handler + ) { + String serializedUserHeader = threadPool.getThreadContext().getHeader(ConfigConstants.OPENDISTRO_SECURITY_USER_HEADER); + assertEquals(serializedUserHeader, Base64Helper.serializeObject(user, useJDKSerialization)); + } + }; + // isSameNodeRequest = false + securityInterceptor.sendRequestDecorate(sender, connection2, action, request, options, handler, localNode); // from original context User transientUser2 = threadPool.getThreadContext().getTransient(ConfigConstants.OPENDISTRO_SECURITY_USER); assertEquals(transientUser2, user); assertEquals(threadPool.getThreadContext().getHeader(ConfigConstants.OPENDISTRO_SECURITY_USER_HEADER), null); + } + + @Test + public void testSendRequestDecorate() { + testSendRequestDecorate(Version.CURRENT); + } + /** + * Tests the scenario when remote node does not implement custom serialization protocol and uses JDK serialization + */ + @Test + public void testSendRequestDecorateWhenRemoteNodeUsesJDKSerde() { + testSendRequestDecorate(Version.V_2_0_0); } } diff --git a/src/test/java/org/opensearch/security/transport/SecuritySSLRequestHandlerTests.java b/src/test/java/org/opensearch/security/transport/SecuritySSLRequestHandlerTests.java new file mode 100644 index 0000000000..23a64e4be3 --- /dev/null +++ b/src/test/java/org/opensearch/security/transport/SecuritySSLRequestHandlerTests.java @@ -0,0 +1,80 @@ +package org.opensearch.security.transport; + +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; +import org.mockito.ArgumentMatchers; +import org.mockito.Mock; +import org.opensearch.Version; +import org.opensearch.common.settings.Settings; +import org.opensearch.security.ssl.SslExceptionHandler; +import org.opensearch.security.ssl.transport.PrincipalExtractor; +import org.opensearch.security.ssl.transport.SSLConfig; +import org.opensearch.security.ssl.transport.SecuritySSLRequestHandler; +import org.opensearch.security.support.ConfigConstants; +import org.opensearch.tasks.Task; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.transport.TransportChannel; +import org.opensearch.transport.TransportRequest; +import org.opensearch.transport.TransportRequestHandler; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyInt; +import static org.mockito.Mockito.doNothing; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +public class SecuritySSLRequestHandlerTests { + + @Mock + TransportRequestHandler actualHandler; + @Mock + SSLConfig sslConfig; + ThreadPool threadPool; + SslExceptionHandler sslExceptionHandler; + Settings settings; + SecuritySSLRequestHandler securitySSLRequestHandler; + String testAction; + + @Mock + private PrincipalExtractor principalExtractor; + + @Before + public void setUp() { + settings = Settings.builder() + .put("node.name", SecurityInterceptorTests.class.getSimpleName()) + .put("request.headers.default", "1") + .build(); + threadPool = new ThreadPool(settings); + testAction = "test_action"; + sslExceptionHandler = mock(SslExceptionHandler.class); + securitySSLRequestHandler = new SecuritySSLRequestHandler<>( + testAction, + actualHandler, + threadPool, + principalExtractor, + sslConfig, + sslExceptionHandler + ); + doNothing().when(sslExceptionHandler) + .logError(any(Exception.class), any(TransportRequest.class), any(String.class), any(Task.class), anyInt()); + } + + @Test + public void testUseJDKSerializationHeaderIsSetOnMessageReceived() throws Exception { + TransportRequest transportRequest = mock(TransportRequest.class); + TransportChannel transportChannel = mock(TransportChannel.class); + Task task = mock(Task.class); + doNothing().when(transportChannel).sendResponse(ArgumentMatchers.any(Exception.class)); + when(transportChannel.getVersion()).thenReturn(Version.V_2_11_0); + when(transportChannel.getChannelType()).thenReturn("transport"); + + Assert.assertThrows(Exception.class, () -> securitySSLRequestHandler.messageReceived(transportRequest, transportChannel, task)); + Assert.assertTrue(threadPool.getThreadContext().getTransient(ConfigConstants.USE_JDK_SERIALIZATION)); + + threadPool.getThreadContext().stashContext(); + when(transportChannel.getVersion()).thenReturn(Version.V_3_0_0); + Assert.assertThrows(Exception.class, () -> securitySSLRequestHandler.messageReceived(transportRequest, transportChannel, task)); + Assert.assertFalse(threadPool.getThreadContext().getTransient(ConfigConstants.USE_JDK_SERIALIZATION)); + } +}