Skip to content

Commit

Permalink
Added Setting to Toggle Data Source Management Code Paths
Browse files Browse the repository at this point in the history
Signed-off-by: Frank Dattalo <[email protected]>
  • Loading branch information
fddattal committed Jun 12, 2024
1 parent 1d703e8 commit a827d03
Show file tree
Hide file tree
Showing 13 changed files with 512 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ public enum Key {
ENCYRPTION_MASTER_KEY("plugins.query.datasources.encryption.masterkey"),
DATASOURCES_URI_HOSTS_DENY_LIST("plugins.query.datasources.uri.hosts.denylist"),
DATASOURCES_LIMIT("plugins.query.datasources.limit"),
DATASOURCES_ENABLED("plugins.query.datasources.enabled"),

METRICS_ROLLING_WINDOW("plugins.query.metrics.rolling_window"),
METRICS_ROLLING_INTERVAL("plugins.query.metrics.rolling_interval"),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,13 +35,14 @@ public String toString() {
}

/** Register DataSourceType to be used in fromString method */
public static void register(DataSourceType ... dataSourceTypes) {
public static void register(DataSourceType... dataSourceTypes) {
for (DataSourceType type : dataSourceTypes) {
String upperCaseName = type.name().toUpperCase();
if (!knownValues.containsKey(upperCaseName)) {
knownValues.put(type.name().toUpperCase(), type);
} else {
throw new IllegalArgumentException("DataSourceType with name " + type.name() + " already exists");
throw new IllegalArgumentException(
"DataSourceType with name " + type.name() + " already exists");
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,17 +17,20 @@
import java.util.List;
import java.util.Locale;
import java.util.Map;
import lombok.RequiredArgsConstructor;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.OpenSearchException;
import org.opensearch.OpenSearchSecurityException;
import org.opensearch.OpenSearchStatusException;
import org.opensearch.client.node.NodeClient;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.rest.BaseRestHandler;
import org.opensearch.rest.BytesRestResponse;
import org.opensearch.rest.RestChannel;
import org.opensearch.rest.RestRequest;
import org.opensearch.sql.common.setting.Settings;
import org.opensearch.sql.datasource.model.DataSourceMetadata;
import org.opensearch.sql.datasources.exceptions.DataSourceNotFoundException;
import org.opensearch.sql.datasources.exceptions.ErrorMessage;
Expand All @@ -37,14 +40,18 @@
import org.opensearch.sql.datasources.utils.XContentParserUtils;
import org.opensearch.sql.legacy.metrics.MetricName;
import org.opensearch.sql.legacy.utils.MetricUtils;
import org.opensearch.sql.opensearch.setting.OpenSearchSettings;

@RequiredArgsConstructor
public class RestDataSourceQueryAction extends BaseRestHandler {

public static final String DATASOURCE_ACTIONS = "datasource_actions";
public static final String BASE_DATASOURCE_ACTION_URL = "/_plugins/_query/_datasources";

private static final Logger LOG = LogManager.getLogger(RestDataSourceQueryAction.class);

private final OpenSearchSettings settings;

@Override
public String getName() {
return DATASOURCE_ACTIONS;
Expand Down Expand Up @@ -115,6 +122,9 @@ public List<Route> routes() {
@Override
protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient nodeClient)
throws IOException {
if (!enabled()) {
return disabledError(restRequest);
}
switch (restRequest.method()) {
case POST:
return executePostRequest(restRequest, nodeClient);
Expand Down Expand Up @@ -314,4 +324,24 @@ private static boolean isClientError(Exception e) {
|| e instanceof IllegalArgumentException
|| e instanceof IllegalStateException;
}

private boolean enabled() {
return settings.getSettingValue(Settings.Key.DATASOURCES_ENABLED);
}

private RestChannelConsumer disabledError(RestRequest request) {

// consume all the params of the request to ensure that the BaseRestHandler
// doesn't fail the request with an unconsumed parameter exception
request.params().keySet().forEach(request::param);

return channel -> {
reportError(
channel,
new OpenSearchStatusException(
String.format("%s is disabled", Settings.Key.DATASOURCES_ENABLED.getKeyValue()),
BAD_REQUEST),
BAD_REQUEST);
};
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,13 @@
import org.opensearch.index.query.QueryBuilders;
import org.opensearch.search.SearchHit;
import org.opensearch.search.builder.SearchSourceBuilder;
import org.opensearch.sql.common.setting.Settings;
import org.opensearch.sql.datasource.model.DataSourceMetadata;
import org.opensearch.sql.datasources.encryptor.Encryptor;
import org.opensearch.sql.datasources.exceptions.DataSourceNotFoundException;
import org.opensearch.sql.datasources.service.DataSourceMetadataStorage;
import org.opensearch.sql.datasources.utils.XContentParserUtils;
import org.opensearch.sql.opensearch.setting.OpenSearchSettings;

public class OpenSearchDataSourceMetadataStorage implements DataSourceMetadataStorage {

Expand All @@ -61,6 +63,7 @@ public class OpenSearchDataSourceMetadataStorage implements DataSourceMetadataSt
private final ClusterService clusterService;

private final Encryptor encryptor;
private final OpenSearchSettings settings;

/**
* This class implements DataSourceMetadataStorage interface using OpenSearch as underlying
Expand All @@ -71,14 +74,21 @@ public class OpenSearchDataSourceMetadataStorage implements DataSourceMetadataSt
* @param encryptor Encryptor.
*/
public OpenSearchDataSourceMetadataStorage(
Client client, ClusterService clusterService, Encryptor encryptor) {
Client client,
ClusterService clusterService,
Encryptor encryptor,
OpenSearchSettings settings) {
this.client = client;
this.clusterService = clusterService;
this.encryptor = encryptor;
this.settings = settings;
}

@Override
public List<DataSourceMetadata> getDataSourceMetadata() {
if (!isEnabled()) {
return Collections.emptyList();
}
if (!this.clusterService.state().routingTable().hasIndex(DATASOURCE_INDEX_NAME)) {
createDataSourcesIndex();
return Collections.emptyList();
Expand All @@ -88,6 +98,9 @@ public List<DataSourceMetadata> getDataSourceMetadata() {

@Override
public Optional<DataSourceMetadata> getDataSourceMetadata(String datasourceName) {
if (!isEnabled()) {
return Optional.empty();
}
if (!this.clusterService.state().routingTable().hasIndex(DATASOURCE_INDEX_NAME)) {
createDataSourcesIndex();
return Optional.empty();
Expand All @@ -101,6 +114,9 @@ public Optional<DataSourceMetadata> getDataSourceMetadata(String datasourceName)

@Override
public void createDataSourceMetadata(DataSourceMetadata dataSourceMetadata) {
if (!isEnabled()) {
throw new IllegalStateException("Data source management is disabled");
}
encryptDecryptAuthenticationData(dataSourceMetadata, true);
if (!this.clusterService.state().routingTable().hasIndex(DATASOURCE_INDEX_NAME)) {
createDataSourcesIndex();
Expand Down Expand Up @@ -134,6 +150,9 @@ public void createDataSourceMetadata(DataSourceMetadata dataSourceMetadata) {

@Override
public void updateDataSourceMetadata(DataSourceMetadata dataSourceMetadata) {
if (!isEnabled()) {
throw new IllegalStateException("Data source management is disabled");
}
encryptDecryptAuthenticationData(dataSourceMetadata, true);
UpdateRequest updateRequest =
new UpdateRequest(DATASOURCE_INDEX_NAME, dataSourceMetadata.getName());
Expand Down Expand Up @@ -163,6 +182,9 @@ public void updateDataSourceMetadata(DataSourceMetadata dataSourceMetadata) {

@Override
public void deleteDataSourceMetadata(String datasourceName) {
if (!isEnabled()) {
throw new IllegalStateException("Data source management is disabled");
}
DeleteRequest deleteRequest = new DeleteRequest(DATASOURCE_INDEX_NAME);
deleteRequest.id(datasourceName);
deleteRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
Expand Down Expand Up @@ -302,4 +324,8 @@ private void handleSigV4PropertiesEncryptionDecryption(
.ifPresent(list::add);
encryptOrDecrypt(propertiesMap, isEncryption, list);
}

private boolean isEnabled() {
return settings.getSettingValue(Settings.Key.DATASOURCES_ENABLED);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
package org.opensearch.sql.datasources.rest;

import com.google.gson.Gson;
import com.google.gson.JsonObject;
import lombok.SneakyThrows;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.mockito.ArgumentCaptor;
import org.mockito.ArgumentMatchers;
import org.mockito.Mockito;
import org.opensearch.client.node.NodeClient;
import org.opensearch.rest.RestChannel;
import org.opensearch.rest.RestRequest;
import org.opensearch.rest.RestResponse;
import org.opensearch.sql.common.setting.Settings;
import org.opensearch.sql.opensearch.setting.OpenSearchSettings;
import org.opensearch.threadpool.ThreadPool;

public class RestDataSourceQueryActionTest {

private OpenSearchSettings settings;
private RestRequest request;
private RestChannel channel;
private NodeClient nodeClient;
private ThreadPool threadPool;
private RestDataSourceQueryAction unit;

@BeforeEach
public void setup() {
settings = Mockito.mock(OpenSearchSettings.class);
request = Mockito.mock(RestRequest.class);
channel = Mockito.mock(RestChannel.class);
nodeClient = Mockito.mock(NodeClient.class);
threadPool = Mockito.mock(ThreadPool.class);

Mockito.when(nodeClient.threadPool()).thenReturn(threadPool);

unit = new RestDataSourceQueryAction(settings);
}

@Test
@SneakyThrows
public void testWhenDataSourcesAreDisabled() {
setDataSourcesEnabled(false);
unit.handleRequest(request, channel, nodeClient);
Mockito.verifyNoInteractions(nodeClient);
ArgumentCaptor<RestResponse> response = ArgumentCaptor.forClass(RestResponse.class);
Mockito.verify(channel, Mockito.times(1)).sendResponse(response.capture());
Assertions.assertEquals(400, response.getValue().status().getStatus());
JsonObject actualResponseJson =
new Gson().fromJson(response.getValue().content().utf8ToString(), JsonObject.class);
JsonObject expectedResponseJson = new JsonObject();
expectedResponseJson.addProperty("status", 400);
expectedResponseJson.add("error", new JsonObject());
expectedResponseJson.getAsJsonObject("error").addProperty("type", "OpenSearchStatusException");
expectedResponseJson.getAsJsonObject("error").addProperty("reason", "Invalid Request");
expectedResponseJson
.getAsJsonObject("error")
.addProperty("details", "plugins.query.datasources.enabled is disabled");
Assertions.assertEquals(expectedResponseJson, actualResponseJson);
}

@Test
@SneakyThrows
public void testWhenDataSourcesAreEnabled() {
setDataSourcesEnabled(true);
Mockito.when(request.method()).thenReturn(RestRequest.Method.GET);
unit.handleRequest(request, channel, nodeClient);
Mockito.verify(threadPool, Mockito.times(1))
.schedule(ArgumentMatchers.any(), ArgumentMatchers.any(), ArgumentMatchers.any());
Mockito.verifyNoInteractions(channel);
}

@Test
public void testGetName() {
Assertions.assertEquals("datasource_actions", unit.getName());
}

private void setDataSourcesEnabled(boolean value) {
Mockito.when(settings.getSettingValue(Settings.Key.DATASOURCES_ENABLED)).thenReturn(value);
}
}
Loading

0 comments on commit a827d03

Please sign in to comment.