Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use auth tokens passed from core and introduce extension and user REST clients #892

Draft
wants to merge 8 commits into
base: main
Choose a base branch
from
2 changes: 2 additions & 0 deletions src/main/java/org/opensearch/sdk/ExtensionSettings.java
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

import org.yaml.snakeyaml.Yaml;

import static org.opensearch.sdk.ssl.SSLConfigConstants.SSL_HTTP_ENABLED;
import static org.opensearch.sdk.ssl.SSLConfigConstants.SSL_TRANSPORT_CLIENT_PEMCERT_FILEPATH;
import static org.opensearch.sdk.ssl.SSLConfigConstants.SSL_TRANSPORT_CLIENT_PEMKEY_FILEPATH;
import static org.opensearch.sdk.ssl.SSLConfigConstants.SSL_TRANSPORT_CLIENT_PEMTRUSTEDCAS_FILEPATH;
Expand Down Expand Up @@ -62,6 +63,7 @@ public class ExtensionSettings {
*/
public static final Set<String> SECURITY_SETTINGS_KEYS = Set.of(
"path.home", // TODO Find the right place to put this setting
SSL_HTTP_ENABLED,
SSL_TRANSPORT_CLIENT_PEMCERT_FILEPATH,
SSL_TRANSPORT_CLIENT_PEMKEY_FILEPATH,
SSL_TRANSPORT_CLIENT_PEMTRUSTEDCAS_FILEPATH,
Expand Down
34 changes: 34 additions & 0 deletions src/main/java/org/opensearch/sdk/ExtensionsRunner.java
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,10 @@
import org.apache.logging.log4j.Logger;
import org.opensearch.action.ActionType;
import org.opensearch.action.support.TransportAction;
import org.opensearch.client.opensearch.OpenSearchClient;
import org.opensearch.cluster.metadata.IndexNameExpressionResolver;
import org.opensearch.core.common.io.stream.NamedWriteableRegistry;
import org.opensearch.discovery.InitializeExtensionSecurityRequest;
import org.opensearch.extensions.rest.ExtensionRestRequest;
import org.opensearch.common.settings.Setting;
import org.opensearch.common.settings.Settings;
Expand Down Expand Up @@ -125,6 +127,7 @@ public class ExtensionsRunner {
private final SDKNamedXContentRegistry sdkNamedXContentRegistry;
private final SDKNamedWriteableRegistry sdkNamedWriteableRegistry;
private final SDKClient sdkClient;
private OpenSearchClient extensionRestClient;
private final SDKClusterService sdkClusterService;
private final SDKTransportService sdkTransportService;
private final SDKActionModule sdkActionModule;
Expand Down Expand Up @@ -344,6 +347,19 @@ public void setExtensionNode(DiscoveryExtensionNode extensionNode) {
this.extensionNode = extensionNode;
}

/**
* Initializes a REST Client for this extension to interact with an OpenSearch cluster on its own behalf
*
* @param serviceAccountToken Access token that permits an extension to make requests on its own behalf.
* Common examples of usages of service account tokens include interacting with
* an extension's reserved indices.
*/
public void initializeExtensionRestClient(String serviceAccountToken) {
OpenSearchClient restClient = getSdkClient()
.initializeJavaClientWithHeaders(Map.of("Authorization", "Bearer " + serviceAccountToken));
this.extensionRestClient = restClient;
}

/**
* Returns the discovery extension node set during extension initialization
*
Expand Down Expand Up @@ -403,6 +419,15 @@ public void startTransportService(TransportService transportService) {
(request, channel, task) -> channel.sendResponse(extensionsInitRequestHandler.handleExtensionInitRequest(request))
);

transportService.registerRequestHandler(
ExtensionsManager.REQUEST_EXTENSION_REGISTER_SECURITY_SETTINGS,
ThreadPool.Names.GENERIC,
false,
false,
InitializeExtensionSecurityRequest::new,
(request, channel, task) -> channel.sendResponse(extensionsInitRequestHandler.handleExtensionSecurityInitRequest(request))
);

transportService.registerRequestHandler(
ExtensionsManager.REQUEST_REST_EXECUTE_ON_EXTENSION_ACTION,
ThreadPool.Names.GENERIC,
Expand Down Expand Up @@ -528,6 +553,15 @@ public SDKClient getSdkClient() {
return sdkClient;
}

/**
* Returns the Extension rest client instance used by this extension.
*
* @return The Extension rest client instance.
*/
public OpenSearchClient getExtensionRestClient() {
return extensionRestClient;
}

/**
* @return The SDKClusterService instance associated with this object.
*/
Expand Down
53 changes: 45 additions & 8 deletions src/main/java/org/opensearch/sdk/SDKClient.java
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,13 @@
import com.fasterxml.jackson.datatype.jsr310.JavaTimeModule;
import com.fasterxml.jackson.datatype.guava.GuavaModule;
import org.apache.hc.core5.function.Factory;
import org.apache.hc.core5.http.Header;
import org.apache.hc.core5.http.HttpHost;
import org.apache.hc.client5.http.ssl.NoopHostnameVerifier;
import org.apache.hc.client5.http.ssl.ClientTlsStrategyBuilder;
import org.apache.hc.client5.http.impl.nio.PoolingAsyncClientConnectionManager;
import org.apache.hc.client5.http.impl.nio.PoolingAsyncClientConnectionManagerBuilder;
import org.apache.hc.core5.http.message.BasicHeader;
import org.apache.hc.core5.http.nio.ssl.TlsStrategy;
import org.apache.hc.core5.reactor.ssl.TlsDetails;
import org.apache.hc.core5.ssl.SSLContextBuilder;
Expand Down Expand Up @@ -95,6 +97,9 @@

import javax.net.ssl.SSLEngine;

import static org.opensearch.sdk.ssl.SSLConfigConstants.SSL_HTTP_ENABLED;
import static org.opensearch.sdk.ssl.SSLConfigConstants.SSL_TRANSPORT_ENABLED;

/**
* This class creates SDKClient for an extension to make requests to OpenSearch
*/
Expand Down Expand Up @@ -164,8 +169,11 @@ public void updateOpenSearchNodeSettings(String address, String httpPort) {
* @param port The port the client should connect to
* @return An instance of the builder
*/
private static RestClientBuilder builder(String hostAddress, int port) {
RestClientBuilder builder = RestClient.builder(new HttpHost(hostAddress, port));
private static RestClientBuilder builder(String hostAddress, int port, ExtensionSettings extensionSettings) {
boolean httpsEnabled = extensionSettings.getSecuritySettings().containsKey(SSL_HTTP_ENABLED)
&& "true".equals(extensionSettings.getSecuritySettings().get(SSL_HTTP_ENABLED));
String scheme = httpsEnabled ? "https" : "http";
RestClientBuilder builder = RestClient.builder(new HttpHost(scheme, hostAddress, port));
builder.setStrictDeprecationMode(true);
builder.setHttpClientConfigCallback(httpClientBuilder -> {
try {
Expand Down Expand Up @@ -201,8 +209,9 @@ public TlsDetails create(final SSLEngine sslEngine) {
* @param port The port of OpenSearch cluster
* @return The OpenSearchTransport implementation of RestClientTransport.
*/
private OpenSearchTransport initializeTransport(String hostAddress, int port) {
RestClientBuilder builder = builder(hostAddress, port);
private OpenSearchTransport initializeTransport(String hostAddress, int port, Map<String, String> headers) {
RestClientBuilder builder = builder(hostAddress, port, extensionSettings);
builder.setDefaultHeaders(headers.keySet().stream().map(k -> new BasicHeader(k, headers.get(k))).toArray(Header[]::new));

restClient = builder.build();
ObjectMapper mapper = new ObjectMapper();
Expand All @@ -227,6 +236,20 @@ public OpenSearchClient initializeJavaClient() {
return initializeJavaClient(extensionSettings.getOpensearchAddress(), Integer.parseInt(extensionSettings.getOpensearchPort()));
}

/**
* Initializes an OpenSearchClient using OpenSearch JavaClient
*
* @return The SDKClient implementation of OpenSearchClient. The user is responsible for calling
* {@link #doCloseJavaClients()} when finished with the client
*/
public OpenSearchClient initializeJavaClientWithHeaders(Map<String, String> headers) {
return initializeJavaClientWithHeaders(
extensionSettings.getOpensearchAddress(),
Integer.parseInt(extensionSettings.getOpensearchPort()),
headers
);
}

/**
* Initializes an OpenSearchClient using OpenSearch JavaClient
*
Expand All @@ -236,7 +259,21 @@ public OpenSearchClient initializeJavaClient() {
* {@link #doCloseJavaClients()} when finished with the client
*/
public OpenSearchClient initializeJavaClient(String hostAddress, int port) {
OpenSearchTransport transport = initializeTransport(hostAddress, port);
OpenSearchTransport transport = initializeTransport(hostAddress, port, Map.of());
javaClient = new OpenSearchClient(transport);
return javaClient;
}

/**
* Initializes an OpenSearchClient using OpenSearch JavaClient
*
* @param hostAddress The address of OpenSearch cluster, client can connect to
* @param port The port of OpenSearch cluster
* @return The SDKClient implementation of OpenSearchClient. The user is responsible for calling
* {@link #doCloseJavaClients()} when finished with the client
*/
public OpenSearchClient initializeJavaClientWithHeaders(String hostAddress, int port, Map<String, String> headers) {
OpenSearchTransport transport = initializeTransport(hostAddress, port, headers);
javaClient = new OpenSearchClient(transport);
return javaClient;
}
Expand All @@ -260,7 +297,7 @@ public OpenSearchAsyncClient initializeJavaAsyncClient() {
* {@link #doCloseJavaClients()} when finished with the client
*/
public OpenSearchAsyncClient initalizeJavaAsyncClient(String hostAddress, int port) {
OpenSearchTransport transport = initializeTransport(hostAddress, port);
OpenSearchTransport transport = initializeTransport(hostAddress, port, Map.of());
javaAsyncClient = new OpenSearchAsyncClient(transport);
return javaAsyncClient;
}
Expand Down Expand Up @@ -300,7 +337,7 @@ public SDKRestClient initializeRestClient() {
*/
@Deprecated
public SDKRestClient initializeRestClient(String hostAddress, int port) {
this.sdkRestClient = new SDKRestClient(this, new RestHighLevelClient(builder(hostAddress, port)));
this.sdkRestClient = new SDKRestClient(this, new RestHighLevelClient(builder(hostAddress, port, extensionSettings)));
return this.sdkRestClient;
}

Expand Down Expand Up @@ -556,7 +593,7 @@ public void bulk(BulkRequest request, ActionListener<BulkResponse> listener) {
* @return the response returned by OpenSearch
* @throws IOException in case of a problem or the connection was aborted
*/
public Response performRequest(Request request) throws IOException {
public Response uest(Request request) throws IOException {
cwperks marked this conversation as resolved.
Show resolved Hide resolved
return restHighLevelClient.getLowLevelClient().performRequest(request);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,32 @@

package org.opensearch.sdk.handlers;

import com.fasterxml.jackson.databind.JsonNode;
import org.apache.logging.log4j.LogManager;

import org.apache.logging.log4j.Logger;
import org.opensearch.client.RequestOptions;
import org.opensearch.client.WarningFailureException;
import org.opensearch.client.opensearch.OpenSearchClient;
import org.opensearch.client.opensearch.core.IndexRequest;
import org.opensearch.client.opensearch.core.SearchRequest;
import org.opensearch.client.opensearch.core.SearchResponse;
import org.opensearch.client.opensearch.indices.CreateIndexRequest;
import org.opensearch.client.opensearch.indices.DeleteIndexRequest;
import org.opensearch.common.settings.Settings;
import org.opensearch.discovery.InitializeExtensionRequest;
import org.opensearch.discovery.InitializeExtensionResponse;
import org.opensearch.discovery.InitializeExtensionSecurityRequest;
import org.opensearch.discovery.InitializeExtensionSecurityResponse;
import org.opensearch.sdk.ExtensionsRunner;
import org.opensearch.sdk.SDKTransportService;
import org.opensearch.transport.TransportService;

import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.Base64;
import java.util.Map;

import static org.opensearch.sdk.ExtensionsRunner.NODE_NAME_SETTING;

/**
Expand Down Expand Up @@ -94,4 +110,20 @@ public InitializeExtensionResponse handleExtensionInitRequest(InitializeExtensio
extensionsRunner.getSdkClusterService().getClusterSettings().sendPendingSettingsUpdateConsumers();
}
}

/**
* Handles a extension request from OpenSearch. This is the first request for the transport communication and will initialize the extension and will be a part of OpenSearch bootstrap.
*
* @param extensionInitSecurityRequest The request to handle.
* @return A response to OpenSearch validating that this is an extension.
*/
public InitializeExtensionSecurityResponse handleExtensionSecurityInitRequest(
InitializeExtensionSecurityRequest extensionInitSecurityRequest
) {
logger.info("Registering Extension Request received from OpenSearch");

extensionsRunner.initializeExtensionRestClient(extensionInitSecurityRequest.getServiceAccountToken());

return new InitializeExtensionSecurityResponse(extensionsRunner.getExtensionNode().getId());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

package org.opensearch.sdk.handlers;

import joptsimple.internal.Strings;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.core.common.bytes.BytesReference;
Expand All @@ -22,6 +23,10 @@
import org.opensearch.sdk.rest.SDKHttpRequest;
import org.opensearch.sdk.rest.SDKRestRequest;

import java.util.HashMap;
import java.util.List;
import java.util.Map;

import static java.nio.charset.StandardCharsets.UTF_8;
import static java.util.Collections.emptyMap;
import static java.util.Collections.emptyList;
Expand Down Expand Up @@ -68,11 +73,19 @@ public RestExecuteOnExtensionResponse handleRestExecuteOnExtensionRequest(Extens
);
}

String oboToken = request.getRequestIssuerIdentity();
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seems like this should be called an onBehalfOfToken both here and on the request object?

What do you think about making the type Optional<> ?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agreed. I think this needs to be updated in core. I can create a PR for that.

Map<String, List<String>> headers = new HashMap<>();
headers.putAll(request.headers());
System.out.println("oboToken: " + oboToken);
if (!Strings.isNullOrEmpty(oboToken)) {
headers.put("Authorization", List.of("Bearer " + oboToken));
}

SDKRestRequest sdkRestRequest = new SDKRestRequest(
sdkNamedXContentRegistry.getRegistry(),
request.params(),
request.path(),
request.headers(),
headers,
new SDKHttpRequest(request),
null
);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,9 @@

import java.nio.charset.StandardCharsets;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;

Expand All @@ -25,6 +27,7 @@
import static org.apache.hc.core5.http.ContentType.APPLICATION_JSON;

import org.opensearch.OpenSearchException;
import org.opensearch.client.opensearch.OpenSearchClient;
import org.opensearch.common.logging.DeprecationLogger;
import org.opensearch.common.xcontent.json.JsonXContent;
import org.opensearch.core.common.Strings;
Expand All @@ -38,6 +41,7 @@
import org.opensearch.rest.RestRequest.Method;
import org.opensearch.rest.RestResponse;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.sdk.SDKClient;

/**
* Provides convenience methods to reduce boilerplate code in an {@link ExtensionRestHandler} implementation.
Expand All @@ -48,6 +52,14 @@ public abstract class BaseExtensionRestHandler implements ExtensionRestHandler {

private String routeNamePrefix;

private SDKClient sdkClient;

protected OpenSearchClient userRestClient;

public BaseExtensionRestHandler(SDKClient sdkClient) {
this.sdkClient = sdkClient;
}

/**
* Constant for JSON content type
*/
Expand Down Expand Up @@ -114,6 +126,15 @@ public List<ReplacedRoute> replacedRoutes() {

@Override
public ExtensionRestResponse handleRequest(RestRequest request) {
if (request instanceof SDKRestRequest) {
SDKRestRequest sdkRestRequest = (SDKRestRequest) request;
List<String> authorizationHeaders = sdkRestRequest.getHeaders().get("Authorization");
Map<String, String> headers = new HashMap<>();
if (!authorizationHeaders.isEmpty()) {
headers.put("Authorization", authorizationHeaders.get(0));
}
this.userRestClient = sdkClient.initializeJavaClientWithHeaders(headers);
}
Optional<NamedRoute> route = routes().stream()
.filter(rh -> rh.getMethod().equals(request.method()))
.filter(rh -> restPathMatches(request.path(), rh.getPath()))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ public HelloWorldExtension() {

@Override
public List<ExtensionRestHandler> getExtensionRestHandlers() {
return List.of(new RestHelloAction(), new RestRemoteHelloAction(extensionsRunner()));
return List.of(new RestHelloAction(extensionsRunner().getSdkClient()), new RestRemoteHelloAction(extensionsRunner()));
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import org.opensearch.rest.NamedRoute;
import org.opensearch.rest.RestRequest;
import org.opensearch.rest.RestResponse;
import org.opensearch.sdk.SDKClient;
import org.opensearch.sdk.rest.BaseExtensionRestHandler;
import org.opensearch.sdk.rest.ExtensionRestHandler;

Expand Down Expand Up @@ -48,7 +49,7 @@
public class RestHelloAction extends BaseExtensionRestHandler {

private static final String TEXT_CONTENT_TYPE = "text/plain; charset=UTF-8";
private static final String GREETING = "Hello, %s!";
public static final String GREETING = "Hello, %s!";
private static final String DEFAULT_NAME = "World";

private String worldName = DEFAULT_NAME;
Expand All @@ -58,7 +59,9 @@ public class RestHelloAction extends BaseExtensionRestHandler {
/**
* Instantiate this action
*/
public RestHelloAction() {}
public RestHelloAction(SDKClient sdkClient) {
super(sdkClient);
}

@Override
public List<NamedRoute> routes() {
Expand Down
Loading