Skip to content

Commit

Permalink
Drop mocksocket in favour of custom security manager checks (tests on…
Browse files Browse the repository at this point in the history
…ly) (#1205)

* Drop mocksocket in favour of custom security manager checks (tests only)

Signed-off-by: Andriy Redko <[email protected]>

* Slightly relaxed host checks to allow all local addresses

Signed-off-by: Andriy Redko <[email protected]>
  • Loading branch information
reta authored Sep 16, 2021
1 parent cbbf967 commit b6c8bdf
Show file tree
Hide file tree
Showing 18 changed files with 83 additions and 38 deletions.
1 change: 0 additions & 1 deletion buildSrc/version.properties
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@ commonslogging = 1.1.3
commonscodec = 1.13
hamcrest = 2.1
securemock = 1.2
mocksocket = 1.2
mockito = 1.9.5
objenesis = 1.0

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,9 @@
import java.security.AccessController;
import java.security.Permission;
import java.security.PrivilegedAction;
import java.util.Arrays;
import java.util.Objects;
import java.util.Set;

/**
* Extension of SecurityManager that works around a few design flaws in Java Security.
Expand Down Expand Up @@ -104,11 +106,49 @@ public SecureSM(final String[] classesThatCanExit) {
* <li><code>org.eclipse.internal.junit.runner.</code></li>
* <li><code>com.intellij.rt.execution.junit.</code></li>
* </ul>
*
* For testing purposes, the security manager grants network permissions "connect, accept"
* to following classes, granted they only access local network interfaces.
*
* <ul>
* <li><code>sun.net.httpserver.ServerImpl</code></li>
* <li><code>java.net.ServerSocket"</code></li>
* <li><code>java.net.Socket</code></li>
* </ul>
*
* @return an instance of SecureSM where test packages can halt or exit the virtual machine
*/
public static SecureSM createTestSecureSM() {
return new SecureSM(TEST_RUNNER_PACKAGES);
public static SecureSM createTestSecureSM(final Set<String> trustedHosts) {
return new SecureSM(TEST_RUNNER_PACKAGES) {
// Trust these callers inside the test suite only
final String[] TRUSTED_CALLERS = new String[] {
"sun.net.httpserver.ServerImpl",
"java.net.ServerSocket",
"java.net.Socket"
};

@Override
public void checkConnect(String host, int port) {
// Allow to connect from selected trusted classes to local addresses only
if (!hasTrustedCallerChain() || !trustedHosts.contains(host)) {
super.checkConnect(host, port);
}
}

@Override
public void checkAccept(String host, int port) {
// Allow to accept connections from selected trusted classes to local addresses only
if (!hasTrustedCallerChain() || !trustedHosts.contains(host)) {
super.checkAccept(host, port);
}
}

private boolean hasTrustedCallerChain() {
return Arrays
.stream(getClassContext())
.anyMatch(c -> Arrays.stream(TRUSTED_CALLERS).anyMatch(t -> c.getName().startsWith(t)));
}
};
}

static final String[] TEST_RUNNER_PACKAGES = new String[] {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
import java.security.Permission;
import java.security.Policy;
import java.security.ProtectionDomain;
import java.util.Collections;
import java.util.concurrent.atomic.AtomicBoolean;

/** Simple tests for SecureSM */
Expand All @@ -57,7 +58,7 @@ public boolean implies(ProtectionDomain domain, Permission permission) {
return true;
}
});
System.setSecurityManager(SecureSM.createTestSecureSM());
System.setSecurityManager(SecureSM.createTestSecureSM(Collections.emptySet()));
}

@SuppressForbidden(reason = "testing that System#exit is blocked")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@
import org.opensearch.common.ssl.PemTrustConfig;
import org.opensearch.env.Environment;
import org.opensearch.env.TestEnvironment;
import org.elasticsearch.mocksocket.MockHttpServer;
import org.opensearch.test.OpenSearchTestCase;
import org.opensearch.watcher.ResourceWatcherService;
import org.hamcrest.Matchers;
Expand Down Expand Up @@ -91,7 +90,7 @@ public class ReindexRestClientSslTests extends OpenSearchTestCase {
public static void setupHttpServer() throws Exception {
InetSocketAddress address = new InetSocketAddress("localhost", 0);
SSLContext sslContext = buildServerSslContext();
server = MockHttpServer.createHttps(address, 0);
server = HttpsServer.create(address, 0);
server.setHttpsConfigurator(new ClientAuthHttpsConfigurator(sslContext));
server.start();
server.createContext("/", http -> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@
import org.opensearch.common.blobstore.BlobContainer;
import org.opensearch.common.blobstore.BlobPath;
import org.opensearch.common.settings.Settings;
import org.elasticsearch.mocksocket.MockHttpServer;
import org.opensearch.test.OpenSearchTestCase;
import org.junit.AfterClass;
import org.junit.Before;
Expand Down Expand Up @@ -67,7 +66,7 @@ public static void startHttp() throws Exception {
}
blobName = randomAlphaOfLength(8);

httpServer = MockHttpServer.createHttp(new InetSocketAddress(InetAddress.getLoopbackAddress().getHostAddress(), 6001), 0);
httpServer = HttpServer.create(new InetSocketAddress(InetAddress.getLoopbackAddress().getHostAddress(), 6001), 0);

httpServer.createContext("/indices/" + blobName, (s) -> {
s.sendResponseHeaders(200, message.length);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@
import org.opensearch.common.util.MockPageCacheRecycler;
import org.opensearch.common.util.PageCacheRecycler;
import org.opensearch.indices.breaker.NoneCircuitBreakerService;
import org.elasticsearch.mocksocket.MockSocket;
import org.opensearch.test.OpenSearchTestCase;
import org.opensearch.threadpool.ThreadPool;
import org.opensearch.transport.SharedGroupFactory;
Expand Down Expand Up @@ -100,7 +99,7 @@ public void testThatTextMessageIsReturnedOnHTTPLikeRequest() throws Exception {
String randomMethod = randomFrom("GET", "POST", "PUT", "DELETE", "HEAD", "OPTIONS", "PATCH");
String data = randomMethod + " / HTTP/1.1";

try (Socket socket = new MockSocket(host, port)) {
try (Socket socket = new Socket(host, port)) {
socket.getOutputStream().write(data.getBytes(StandardCharsets.UTF_8));
socket.getOutputStream().flush();

Expand All @@ -111,7 +110,7 @@ public void testThatTextMessageIsReturnedOnHTTPLikeRequest() throws Exception {
}

public void testThatNothingIsReturnedForOtherInvalidPackets() throws Exception {
try (Socket socket = new MockSocket(host, port)) {
try (Socket socket = new Socket(host, port)) {
socket.getOutputStream().write("FOOBAR".getBytes(StandardCharsets.UTF_8));
socket.getOutputStream().flush();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,6 @@
import org.opensearch.common.settings.Settings;
import org.opensearch.discovery.DiscoveryModule;
import org.opensearch.env.Environment;
import org.elasticsearch.mocksocket.MockHttpServer;
import org.opensearch.node.Node;
import org.opensearch.plugin.discovery.azure.classic.AzureDiscoveryPlugin;
import org.opensearch.plugins.Plugin;
Expand Down Expand Up @@ -163,7 +162,7 @@ protected Path nodeConfigPath(int nodeOrdinal) {
public static void startHttpd() throws Exception {
logDir = createTempDir();
SSLContext sslContext = getSSLContext();
httpsServer = MockHttpServer.createHttps(new InetSocketAddress(InetAddress.getLoopbackAddress().getHostAddress(), 0), 0);
httpsServer = HttpsServer.create(new InetSocketAddress(InetAddress.getLoopbackAddress().getHostAddress(), 0), 0);
httpsServer.setHttpsConfigurator(new HttpsConfigurator(sslContext));
httpsServer.createContext("/subscription/services/hostedservices/myservice", (s) -> {
Headers headers = s.getResponseHeaders();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@
import org.opensearch.common.settings.MockSecureSettings;
import org.opensearch.common.settings.Settings;
import org.opensearch.core.internal.io.IOUtils;
import org.elasticsearch.mocksocket.MockHttpServer;
import org.opensearch.test.OpenSearchTestCase;
import org.opensearch.test.transport.MockTransportService;
import org.opensearch.threadpool.TestThreadPool;
Expand Down Expand Up @@ -74,7 +73,7 @@ public abstract class AbstractEC2MockAPITestCase extends OpenSearchTestCase {

@Before
public void setUp() throws Exception {
httpServer = MockHttpServer.createHttp(new InetSocketAddress(InetAddress.getLoopbackAddress(), 0), 0);
httpServer = HttpServer.create(new InetSocketAddress(InetAddress.getLoopbackAddress(), 0), 0);
httpServer.start();
threadPool = new TestThreadPool(EC2RetriesTests.class.getName());
transportService = createTransportService();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@
import org.opensearch.common.SuppressForbidden;
import org.opensearch.common.network.NetworkService;
import org.opensearch.common.settings.Settings;
import org.elasticsearch.mocksocket.MockHttpServer;
import org.opensearch.rest.RestStatus;
import org.opensearch.test.OpenSearchTestCase;

Expand Down Expand Up @@ -74,7 +73,7 @@ public class Ec2NetworkTests extends OpenSearchTestCase {

@BeforeClass
public static void startHttp() throws Exception {
httpServer = MockHttpServer.createHttp(new InetSocketAddress(InetAddress.getLoopbackAddress().getHostAddress(), 0), 0);
httpServer = HttpServer.create(new InetSocketAddress(InetAddress.getLoopbackAddress().getHostAddress(), 0), 0);

BiConsumer<String, String> registerContext = (path, v) ->{
final byte[] message = v.getBytes(UTF_8);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@

package org.opensearch.example.resthandler;

import org.elasticsearch.mocksocket.MockSocket;
import org.opensearch.test.OpenSearchTestCase;

import java.io.BufferedReader;
Expand All @@ -57,7 +56,7 @@ public void testExample() throws Exception {
final URL url = new URL("http://" + externalAddress);
final InetAddress address = InetAddress.getByName(url.getHost());
try (
Socket socket = new MockSocket(address, url.getPort());
Socket socket = new Socket(address, url.getPort());
BufferedWriter writer = new BufferedWriter(new OutputStreamWriter(socket.getOutputStream(), StandardCharsets.UTF_8));
BufferedReader reader = new BufferedReader(new InputStreamReader(socket.getInputStream(), StandardCharsets.UTF_8))
) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,6 @@
import org.opensearch.common.unit.ByteSizeUnit;
import org.opensearch.common.unit.TimeValue;
import org.opensearch.common.util.concurrent.CountDown;
import org.elasticsearch.mocksocket.MockHttpServer;
import org.opensearch.rest.RestStatus;
import org.opensearch.rest.RestUtils;
import org.opensearch.test.OpenSearchTestCase;
Expand Down Expand Up @@ -118,7 +117,7 @@ public class AzureBlobContainerRetriesTests extends OpenSearchTestCase {
@Before
public void setUp() throws Exception {
threadPool = new TestThreadPool(getTestClass().getName(), AzureRepositoryPlugin.executorBuilder());
httpServer = MockHttpServer.createHttp(new InetSocketAddress(InetAddress.getLoopbackAddress(), 0), 0);
httpServer = HttpServer.create(new InetSocketAddress(InetAddress.getLoopbackAddress(), 0), 0);
httpServer.start();
super.setUp();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,11 +73,6 @@ grant codeBase "${codebase.junit}" {
permission java.lang.reflect.ReflectPermission "suppressAccessChecks";
};

grant codeBase "${codebase.mocksocket}" {
// mocksocket makes and accepts socket connections
permission java.net.SocketPermission "*", "accept,connect";
};

grant codeBase "${codebase.opensearch-nio}" {
// opensearch-nio makes and accepts socket connections
permission java.net.SocketPermission "*", "accept,connect";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,6 @@
import org.opensearch.common.xcontent.XContentFactory;
import org.opensearch.core.internal.io.IOUtils;
import org.opensearch.index.IndexNotFoundException;
import org.elasticsearch.mocksocket.MockServerSocket;
import org.opensearch.search.SearchHit;
import org.opensearch.search.SearchHits;
import org.opensearch.search.aggregations.InternalAggregations;
Expand Down Expand Up @@ -180,7 +179,7 @@ public static MockTransportService startTransport(

@SuppressForbidden(reason = "calls getLocalHost here but it's fine in this case")
public void testSlowNodeCanBeCancelled() throws IOException, InterruptedException {
try (ServerSocket socket = new MockServerSocket()) {
try (ServerSocket socket = new ServerSocket()) {
socket.bind(new InetSocketAddress(InetAddress.getLocalHost(), 0), 1);
socket.setReuseAddress(true);
DiscoveryNode seedNode = new DiscoveryNode("TEST", new TransportAddress(socket.getInetAddress(),
Expand Down
1 change: 0 additions & 1 deletion test/framework/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,6 @@ dependencies {
api "commons-logging:commons-logging:${versions.commonslogging}"
api "commons-codec:commons-codec:${versions.commonscodec}"
api "org.elasticsearch:securemock:${versions.securemock}"
api "org.elasticsearch:mocksocket:${versions.mocksocket}"
}

compileJava.options.compilerArgs -= '-Xlint:cast'
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,12 +42,15 @@
import org.opensearch.common.io.FileSystemUtils;
import org.opensearch.common.io.PathUtils;
import org.opensearch.common.network.IfConfig;
import org.opensearch.common.network.NetworkAddress;
import org.opensearch.common.settings.Settings;
import org.opensearch.plugins.PluginInfo;
import org.opensearch.secure_sm.SecureSM;
import org.junit.Assert;

import java.io.InputStream;
import java.net.NetworkInterface;
import java.net.SocketException;
import java.net.SocketPermission;
import java.net.URL;
import java.nio.file.Files;
Expand All @@ -66,6 +69,7 @@
import java.util.Objects;
import java.util.Properties;
import java.util.Set;
import java.util.stream.Collectors;

import static com.carrotsearch.randomizedtesting.RandomizedTest.systemPropertyAsBoolean;

Expand Down Expand Up @@ -161,7 +165,7 @@ public boolean implies(ProtectionDomain domain, Permission permission) {
return opensearchPolicy.implies(domain, permission) || testFramework.implies(domain, permission);
}
});
System.setSecurityManager(SecureSM.createTestSecureSM());
System.setSecurityManager(SecureSM.createTestSecureSM(getTrustedHosts()));
Security.selfTest();

// guarantee plugin classes are initialized first, in case they have one-time hacks.
Expand Down Expand Up @@ -271,6 +275,25 @@ static Set<URL> parseClassPathWithSymlinks() throws Exception {
}
return raw;
}

/**
* Collect host addresses of all local interfaces so we could check
* if the network connection is being made only on those.
* @return host names and addresses of all local interfaces
*/
private static Set<String> getTrustedHosts() {
//
try {
return Collections
.list(NetworkInterface.getNetworkInterfaces())
.stream()
.flatMap(iface -> Collections.list(iface.getInetAddresses()).stream())
.map(address -> NetworkAddress.format(address))
.collect(Collectors.toSet());
} catch (final SocketException e) {
return Collections.emptySet();
}
}

// does nothing, just easy way to make sure the class is loaded.
public static void ensureInitialized() {}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,6 @@
import org.opensearch.common.unit.ByteSizeValue;
import org.opensearch.common.unit.TimeValue;
import org.opensearch.common.util.concurrent.CountDown;
import org.elasticsearch.mocksocket.MockHttpServer;
import org.opensearch.test.OpenSearchTestCase;
import org.junit.After;
import org.junit.Before;
Expand Down Expand Up @@ -81,7 +80,7 @@ public abstract class AbstractBlobContainerRetriesTestCase extends OpenSearchTes

@Before
public void setUp() throws Exception {
httpServer = MockHttpServer.createHttp(new InetSocketAddress(InetAddress.getLoopbackAddress(), 0), 0);
httpServer = HttpServer.create(new InetSocketAddress(InetAddress.getLoopbackAddress(), 0), 0);
httpServer.start();
super.setUp();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@
import org.opensearch.common.bytes.BytesReference;
import org.opensearch.common.network.InetAddresses;
import org.opensearch.common.settings.Settings;
import org.elasticsearch.mocksocket.MockHttpServer;
import org.opensearch.repositories.RepositoriesService;
import org.opensearch.repositories.Repository;
import org.opensearch.repositories.RepositoryMissingException;
Expand Down Expand Up @@ -102,7 +101,7 @@ protected interface BlobStoreHttpHandler extends HttpHandler {

@BeforeClass
public static void startHttpServer() throws Exception {
httpServer = MockHttpServer.createHttp(new InetSocketAddress(InetAddress.getLoopbackAddress(), 0), 0);
httpServer = HttpServer.create(new InetSocketAddress(InetAddress.getLoopbackAddress(), 0), 0);
httpServer.setExecutor(r -> {
try {
r.run();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,6 @@
import org.opensearch.common.util.concurrent.AbstractRunnable;
import org.opensearch.common.util.concurrent.ConcurrentCollections;
import org.opensearch.core.internal.io.IOUtils;
import org.elasticsearch.mocksocket.MockServerSocket;
import org.opensearch.node.Node;
import org.opensearch.tasks.Task;
import org.opensearch.test.OpenSearchTestCase;
Expand Down Expand Up @@ -1938,7 +1937,7 @@ public void testRegisterHandlerTwice() {

public void testTimeoutPerConnection() throws IOException {
assumeTrue("Works only on BSD network stacks", Constants.MAC_OS_X || Constants.FREE_BSD);
try (ServerSocket socket = new MockServerSocket()) {
try (ServerSocket socket = new ServerSocket()) {
// note - this test uses backlog=1 which is implementation specific ie. it might not work on some TCP/IP stacks
// on linux (at least newer ones) the listen(addr, backlog=1) should just ignore new connections if the queue is full which
// means that once we received an ACK from the client we just drop the packet on the floor (which is what we want) and we run
Expand Down Expand Up @@ -2057,7 +2056,7 @@ public void testTcpHandshake() {
}

public void testTcpHandshakeTimeout() throws IOException {
try (ServerSocket socket = new MockServerSocket()) {
try (ServerSocket socket = new ServerSocket()) {
socket.bind(getLocalEphemeral(), 1);
socket.setReuseAddress(true);
DiscoveryNode dummy = new DiscoveryNode("TEST", new TransportAddress(socket.getInetAddress(),
Expand All @@ -2078,7 +2077,7 @@ public void testTcpHandshakeTimeout() throws IOException {
}

public void testTcpHandshakeConnectionReset() throws IOException, InterruptedException {
try (ServerSocket socket = new MockServerSocket()) {
try (ServerSocket socket = new ServerSocket()) {
socket.bind(getLocalEphemeral(), 1);
socket.setReuseAddress(true);
DiscoveryNode dummy = new DiscoveryNode("TEST", new TransportAddress(socket.getInetAddress(),
Expand Down

0 comments on commit b6c8bdf

Please sign in to comment.