Skip to content

Commit

Permalink
Support Mutual TLS authentication for Core Peer Forwarding. Resolves o…
Browse files Browse the repository at this point in the history
  • Loading branch information
dlvenable authored Sep 16, 2022
1 parent 21afeb6 commit 7e2331f
Show file tree
Hide file tree
Showing 16 changed files with 458 additions and 59 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.dataprepper.peerforwarder;

import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import java.util.Optional;

public enum ForwardingAuthentication {
MUTUAL_TLS("mutual_tls"),
UNAUTHENTICATED("unauthenticated");

private static final Map<String, ForwardingAuthentication> STRING_NAME_TO_ENUM_MAP = new HashMap<>();

private final String name;

static {
Arrays.stream(ForwardingAuthentication.values())
.forEach(enumValue -> STRING_NAME_TO_ENUM_MAP.put(enumValue.name, enumValue));
}

ForwardingAuthentication(final String name){
this.name = name;
}

public String getName(){
return name;
}

static ForwardingAuthentication getByName(final String name) {
return Optional.ofNullable(STRING_NAME_TO_ENUM_MAP.get(name))
.orElseThrow(() -> new IllegalArgumentException("Unrecognized ForwardingAuthentication: " + name));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,12 @@

package org.opensearch.dataprepper.peerforwarder;

import org.opensearch.dataprepper.plugins.certificate.model.Certificate;
import com.linecorp.armeria.client.ClientBuilder;
import com.linecorp.armeria.client.ClientFactory;
import com.linecorp.armeria.client.ClientFactoryBuilder;
import com.linecorp.armeria.client.Clients;
import com.linecorp.armeria.client.WebClient;
import org.opensearch.dataprepper.plugins.certificate.model.Certificate;

import java.io.ByteArrayInputStream;
import java.nio.charset.StandardCharsets;
Expand All @@ -26,6 +27,7 @@ public class PeerClientPool {
private int clientTimeoutSeconds = 3;
private boolean ssl;
private Certificate certificate;
private ForwardingAuthentication authentication;

public PeerClientPool() {
peerClients = new ConcurrentHashMap<>();
Expand All @@ -47,6 +49,10 @@ public void setCertificate(final Certificate certificate) {
this.certificate = certificate;
}

public void setAuthentication(ForwardingAuthentication authentication) {
this.authentication = authentication;
}

public WebClient getClient(final String address) {
return peerClients.computeIfAbsent(address, this::getHTTPClient);
}
Expand All @@ -58,14 +64,20 @@ private WebClient getHTTPClient(final String ipAddress) {
.writeTimeout(Duration.ofSeconds(clientTimeoutSeconds));

if (ssl) {
final ClientFactory clientFactory = ClientFactory.builder()
final ClientFactoryBuilder clientFactoryBuilder = ClientFactory.builder()
.tlsCustomizer(sslContextBuilder -> sslContextBuilder.trustManager(
new ByteArrayInputStream(certificate.getCertificate().getBytes(StandardCharsets.UTF_8))
new ByteArrayInputStream(certificate.getCertificate().getBytes(StandardCharsets.UTF_8))
)
).tlsNoVerifyHosts(ipAddress)
.build();

clientBuilder = clientBuilder.factory(clientFactory);
)
.tlsNoVerifyHosts(ipAddress);
// TODO: Add keyManager configuration here
if (authentication == ForwardingAuthentication.MUTUAL_TLS) {
clientFactoryBuilder.tlsCustomizer(sslContextBuilder -> sslContextBuilder.keyManager(
new ByteArrayInputStream(certificate.getCertificate().getBytes(StandardCharsets.UTF_8)),
new ByteArrayInputStream(certificate.getPrivateKey().getBytes(StandardCharsets.UTF_8))
));
}
clientBuilder = clientBuilder.factory(clientFactoryBuilder.build());
}

return clientBuilder.build(WebClient.class);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ public PeerClientPool setPeerClientPool() {
final boolean ssl = peerForwarderConfiguration.isSsl();
final boolean useAcmCertForSsl = peerForwarderConfiguration.isUseAcmCertificateForSsl();

peerClientPool.setAuthentication(peerForwarderConfiguration.getAuthentication());

if (ssl || useAcmCertForSsl) {
peerClientPool.setSsl(true);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ public class PeerForwarderConfiguration {
private boolean ssl = false;
private String sslCertificateFile;
private String sslKeyFile;
private ForwardingAuthentication authentication = ForwardingAuthentication.UNAUTHENTICATED;
private boolean useAcmCertificateForSsl = false;
private String acmCertificateArn;
private String acmPrivateKeyPassword;
Expand Down Expand Up @@ -60,6 +61,7 @@ public PeerForwarderConfiguration (
@JsonProperty("ssl") final Boolean ssl,
@JsonProperty("ssl_certificate_file") final String sslCertificateFile,
@JsonProperty("ssl_key_file") final String sslKeyFile,
@JsonProperty("authentication") final Map<String, Object> authentication,
@JsonProperty("use_acm_certificate_for_ssl") final Boolean useAcmCertificateForSsl,
@JsonProperty("acm_certificate_arn") final String acmCertificateArn,
@JsonProperty("acm_private_key_password") final String acmPrivateKeyPassword,
Expand All @@ -84,6 +86,7 @@ public PeerForwarderConfiguration (
setUseAcmCertificateForSsl(useAcmCertificateForSsl);
setSslCertificateFile(sslCertificateFile);
setSslKeyFile(sslKeyFile);
setAuthentication(authentication);
setAcmCertificateArn(acmCertificateArn);
this.acmPrivateKeyPassword = acmPrivateKeyPassword;
setAcmCertificateTimeoutMillis(acmCertificateTimeoutMillis);
Expand All @@ -98,6 +101,7 @@ public PeerForwarderConfiguration (
setBatchSize(batchSize);
setBufferSize(bufferSize);
checkForCertAndKeyFileInS3();
validateSslAndAuthentication();
}

public int getServerPort() {
Expand Down Expand Up @@ -259,6 +263,25 @@ private void setSslKeyFile(final String sslKeyFile) {
}
}

private void setAuthentication(final Map<String, Object> authentication) {
if(authentication == null)
return;

if (authentication.isEmpty())
return;

if (authentication.size() > 1)
throw new IllegalArgumentException("Invalid authentication configuration.");

final String authenticationName = authentication.keySet().iterator().next();

this.authentication = ForwardingAuthentication.getByName(authenticationName);
}

public ForwardingAuthentication getAuthentication() {
return authentication;
}

private void setUseAcmCertificateForSsl(final Boolean useAcmCertificateForSsl) {
if (useAcmCertificateForSsl != null) {
this.useAcmCertificateForSsl = useAcmCertificateForSsl;
Expand Down Expand Up @@ -382,4 +405,9 @@ private void checkForCertAndKeyFileInS3() {
public boolean isSslCertAndKeyFileInS3() {
return sslCertAndKeyFileInS3;
}

private void validateSslAndAuthentication() {
if(authentication == ForwardingAuthentication.MUTUAL_TLS && !ssl)
throw new IllegalArgumentException("Mutual TLS is only available when SSL is enabled.");
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,14 @@

package org.opensearch.dataprepper.peerforwarder.server;

import org.opensearch.dataprepper.plugins.certificate.CertificateProvider;
import org.opensearch.dataprepper.plugins.certificate.model.Certificate;
import com.linecorp.armeria.server.Server;
import com.linecorp.armeria.server.ServerBuilder;
import io.netty.handler.ssl.ClientAuth;
import org.opensearch.dataprepper.peerforwarder.ForwardingAuthentication;
import org.opensearch.dataprepper.peerforwarder.PeerForwarderConfiguration;
import org.opensearch.dataprepper.peerforwarder.certificate.CertificateProviderFactory;
import org.opensearch.dataprepper.plugins.certificate.CertificateProvider;
import org.opensearch.dataprepper.plugins.certificate.model.Certificate;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

Expand Down Expand Up @@ -47,21 +49,29 @@ public Server get() {
sb.disableServerHeader();

if (peerForwarderConfiguration.isSsl()) {
LOG.info("Creating http source with SSL/TLS enabled.");
final CertificateProvider certificateProvider = certificateProviderFactory.getCertificateProvider();
final Certificate certificate = certificateProvider.getCertificate();
LOG.info("Creating http source with SSL/TLS enabled.");
// TODO: enable encrypted key with password
sb.https(peerForwarderConfiguration.getServerPort())
.tls(
new ByteArrayInputStream(certificate.getCertificate().getBytes(StandardCharsets.UTF_8)),
new ByteArrayInputStream(certificate.getPrivateKey().getBytes(StandardCharsets.UTF_8)
)
);

if (peerForwarderConfiguration.getAuthentication() == ForwardingAuthentication.MUTUAL_TLS) {
sb.tlsCustomizer(sslContextBuilder -> sslContextBuilder.trustManager(
new ByteArrayInputStream(certificate.getCertificate().getBytes(StandardCharsets.UTF_8))
)
.clientAuth(ClientAuth.REQUIRE));
}
} else {
LOG.warn("Creating Peer Forwarder server without SSL/TLS. This is not secure.");
sb.http(peerForwarderConfiguration.getServerPort());
}


sb.maxNumConnections(peerForwarderConfiguration.getMaxConnectionCount());
sb.requestTimeout(Duration.ofMillis(peerForwarderConfiguration.getRequestTimeout()));
final int threadCount = peerForwarderConfiguration.getServerThreadCount();
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.dataprepper.peerforwarder;

import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtensionContext;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.ArgumentsProvider;
import org.junit.jupiter.params.provider.ArgumentsSource;
import org.junit.jupiter.params.provider.EnumSource;

import java.util.UUID;
import java.util.stream.Stream;

import static org.hamcrest.CoreMatchers.equalTo;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.params.provider.Arguments.arguments;

class ForwardingAuthenticationTest {

@ParameterizedTest
@ArgumentsSource(EnumToStringNameArgumentsProvider.class)
void getValue_returns_expected_value (final ForwardingAuthentication enumValue, final String expectedName) {
assertThat(enumValue.getName(), equalTo(expectedName));
}

@ParameterizedTest
@EnumSource(ForwardingAuthentication.class)
void getByName_returns_correct_enum_from_expected_name(final ForwardingAuthentication enumValue) {

final String stringName = enumValue.getName();

assertThat(ForwardingAuthentication.getByName(stringName), equalTo(enumValue));
}

@Test
void getByName_throws_for_null() {
assertThrows(IllegalArgumentException.class, () -> ForwardingAuthentication.getByName(null));
}

@Test
void getByName_throws_for_empty_string() {
assertThrows(IllegalArgumentException.class, () -> ForwardingAuthentication.getByName(""));
}

@Test
void getByName_throws_for_unrecognized_non_empty_name() {
assertThrows(IllegalArgumentException.class, () -> ForwardingAuthentication.getByName(UUID.randomUUID().toString()));
}

private static class EnumToStringNameArgumentsProvider implements ArgumentsProvider {
@Override
public Stream<? extends Arguments> provideArguments(final ExtensionContext context) {
return Stream.of(
arguments(ForwardingAuthentication.MUTUAL_TLS, "mutual_tls"),
arguments(ForwardingAuthentication.UNAUTHENTICATED, "unauthenticated")
);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import java.util.Objects;

import static org.hamcrest.CoreMatchers.equalTo;
import static org.hamcrest.CoreMatchers.notNullValue;
import static org.hamcrest.MatcherAssert.assertThat;

@ExtendWith(MockitoExtension.class)
Expand Down Expand Up @@ -62,4 +63,26 @@ void testGetClientWithSSL(final String address) throws IOException {
assertThat(client.uri(), equalTo(URI.create("https://" + address + ":" + PORT + "/")));
}

@ParameterizedTest
@ValueSource(strings = {VALID_ADDRESS, LOCALHOST})
void testGetClientWithMutualTls(final String address) throws IOException {
final PeerClientPool objectUnderTest = new PeerClientPool();
objectUnderTest.setSsl(true);
objectUnderTest.setPort(PORT);
objectUnderTest.setAuthentication(ForwardingAuthentication.MUTUAL_TLS);

final Path certFilePath = new File(Objects.requireNonNull(PeerClientPoolTest.class.getClassLoader().getResource("test-crt.crt")).getFile()).toPath();
final Path keyFilePath = new File(Objects.requireNonNull(PeerClientPoolTest.class.getClassLoader().getResource("test-key.key")).getFile()).toPath();
final String certAsString = Files.readString(certFilePath);
final String keyAsString = Files.readString(keyFilePath);
final Certificate certificate = new Certificate(certAsString, keyAsString);

objectUnderTest.setCertificate(certificate);

final WebClient client = objectUnderTest.getClient(address);

assertThat(client, notNullValue());
assertThat(client.uri(), equalTo(URI.create("https://" + address + ":" + PORT + "/")));
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -7,17 +7,21 @@

import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.EnumSource;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;

import java.util.Collections;

import static org.hamcrest.CoreMatchers.equalTo;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.hamcrest.MatcherAssert.assertThat;
import org.hamcrest.core.IsInstanceOf;
import org.opensearch.dataprepper.peerforwarder.certificate.CertificateProviderFactory;
import org.opensearch.dataprepper.peerforwarder.discovery.DiscoveryMode;

import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;

@ExtendWith(MockitoExtension.class)
Expand Down Expand Up @@ -58,10 +62,16 @@ void testCreateHashRing_without_endpoints_should_throw() {
void testCreatePeerClientPool_should_return() {
PeerForwarderClientFactory peerForwarderClientFactory = createObjectUnderTest();

PeerClientPool peerClientPool = peerForwarderClientFactory.setPeerClientPool();
PeerClientPool returnedPeerClientPool = peerForwarderClientFactory.setPeerClientPool();

assertThat(peerClientPool, new IsInstanceOf(PeerClientPool.class));
assertThat(returnedPeerClientPool, equalTo(peerClientPool));
}


@ParameterizedTest
@EnumSource(ForwardingAuthentication.class)
void testCreatePeerClientPool_should_set_the_authentication(final ForwardingAuthentication authentication) {
when(peerForwarderConfiguration.getAuthentication()).thenReturn(authentication);
createObjectUnderTest().setPeerClientPool();
verify(peerClientPool).setAuthentication(authentication);
}
}
Loading

0 comments on commit 7e2331f

Please sign in to comment.