diff --git a/datasources/src/main/java/org/opensearch/sql/datasources/rest/RestDataSourceQueryAction.java b/datasources/src/main/java/org/opensearch/sql/datasources/rest/RestDataSourceQueryAction.java index 02f87a69f2..779a8bf772 100644 --- a/datasources/src/main/java/org/opensearch/sql/datasources/rest/RestDataSourceQueryAction.java +++ b/datasources/src/main/java/org/opensearch/sql/datasources/rest/RestDataSourceQueryAction.java @@ -8,8 +8,8 @@ package org.opensearch.sql.datasources.rest; import static org.opensearch.core.rest.RestStatus.BAD_REQUEST; +import static org.opensearch.core.rest.RestStatus.INTERNAL_SERVER_ERROR; import static org.opensearch.core.rest.RestStatus.NOT_FOUND; -import static org.opensearch.core.rest.RestStatus.SERVICE_UNAVAILABLE; import static org.opensearch.rest.RestRequest.Method.*; import com.google.common.collect.ImmutableList; @@ -293,7 +293,7 @@ private void handleException(Exception e, RestChannel restChannel) { reportError(restChannel, e, BAD_REQUEST); } else { MetricUtils.incrementNumericalMetric(MetricName.DATASOURCE_FAILED_REQ_COUNT_SYS); - reportError(restChannel, e, SERVICE_UNAVAILABLE); + reportError(restChannel, e, INTERNAL_SERVER_ERROR); } } } diff --git a/integ-test/src/test/java/org/opensearch/sql/ppl/ResourceMonitorIT.java b/integ-test/src/test/java/org/opensearch/sql/ppl/ResourceMonitorIT.java index 56b54ba748..eed2369590 100644 --- a/integ-test/src/test/java/org/opensearch/sql/ppl/ResourceMonitorIT.java +++ b/integ-test/src/test/java/org/opensearch/sql/ppl/ResourceMonitorIT.java @@ -31,7 +31,7 @@ public void queryExceedResourceLimitShouldFail() throws IOException { String query = String.format("search source=%s age=20", TEST_INDEX_DOG); ResponseException exception = expectThrows(ResponseException.class, () -> executeQuery(query)); - assertEquals(503, exception.getResponse().getStatusLine().getStatusCode()); + assertEquals(500, exception.getResponse().getStatusLine().getStatusCode()); assertThat( exception.getMessage(), Matchers.containsString("resource is not enough to run the" + " query, quit.")); diff --git a/legacy/src/main/java/org/opensearch/sql/legacy/plugin/RestSqlAction.java b/legacy/src/main/java/org/opensearch/sql/legacy/plugin/RestSqlAction.java index fc8934dd73..d75f5abc76 100644 --- a/legacy/src/main/java/org/opensearch/sql/legacy/plugin/RestSqlAction.java +++ b/legacy/src/main/java/org/opensearch/sql/legacy/plugin/RestSqlAction.java @@ -6,8 +6,8 @@ package org.opensearch.sql.legacy.plugin; import static org.opensearch.core.rest.RestStatus.BAD_REQUEST; +import static org.opensearch.core.rest.RestStatus.INTERNAL_SERVER_ERROR; import static org.opensearch.core.rest.RestStatus.OK; -import static org.opensearch.core.rest.RestStatus.SERVICE_UNAVAILABLE; import com.alibaba.druid.sql.parser.ParserException; import com.google.common.collect.ImmutableList; @@ -23,6 +23,7 @@ import java.util.regex.Pattern; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; +import org.opensearch.OpenSearchSecurityException; import org.opensearch.client.Client; import org.opensearch.client.node.NodeClient; import org.opensearch.common.inject.Injector; @@ -171,21 +172,23 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli QueryAction queryAction = explainRequest(client, sqlRequest, format); executeSqlRequest(request, queryAction, client, restChannel); } catch (Exception e) { - logAndPublishMetrics(e); - reportError(restChannel, e, isClientError(e) ? BAD_REQUEST : SERVICE_UNAVAILABLE); + handleException(restChannel, e); } }, - (restChannel, exception) -> { - logAndPublishMetrics(exception); - reportError( - restChannel, - exception, - isClientError(exception) ? BAD_REQUEST : SERVICE_UNAVAILABLE); - }); + this::handleException); } catch (Exception e) { - logAndPublishMetrics(e); - return channel -> - reportError(channel, e, isClientError(e) ? BAD_REQUEST : SERVICE_UNAVAILABLE); + return channel -> handleException(channel, e); + } + } + + private void handleException(RestChannel restChannel, Exception exception) { + logAndPublishMetrics(exception); + if (exception instanceof OpenSearchSecurityException) { + OpenSearchSecurityException securityException = (OpenSearchSecurityException) exception; + reportError(restChannel, exception, securityException.status()); + } else { + reportError( + restChannel, exception, isClientError(exception) ? BAD_REQUEST : INTERNAL_SERVER_ERROR); } } diff --git a/legacy/src/main/java/org/opensearch/sql/legacy/plugin/RestSqlStatsAction.java b/legacy/src/main/java/org/opensearch/sql/legacy/plugin/RestSqlStatsAction.java index 383363b1e3..6f9d1e4117 100644 --- a/legacy/src/main/java/org/opensearch/sql/legacy/plugin/RestSqlStatsAction.java +++ b/legacy/src/main/java/org/opensearch/sql/legacy/plugin/RestSqlStatsAction.java @@ -5,7 +5,7 @@ package org.opensearch.sql.legacy.plugin; -import static org.opensearch.core.rest.RestStatus.SERVICE_UNAVAILABLE; +import static org.opensearch.core.rest.RestStatus.INTERNAL_SERVER_ERROR; import com.google.common.collect.ImmutableList; import java.util.Arrays; @@ -84,8 +84,8 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli return channel -> channel.sendResponse( new BytesRestResponse( - SERVICE_UNAVAILABLE, - ErrorMessageFactory.createErrorMessage(e, SERVICE_UNAVAILABLE.getStatus()) + INTERNAL_SERVER_ERROR, + ErrorMessageFactory.createErrorMessage(e, INTERNAL_SERVER_ERROR.getStatus()) .toString())); } } diff --git a/legacy/src/main/java/org/opensearch/sql/legacy/rewriter/matchtoterm/TermFieldRewriter.java b/legacy/src/main/java/org/opensearch/sql/legacy/rewriter/matchtoterm/TermFieldRewriter.java index 2c837a7b2b..f9744ab841 100644 --- a/legacy/src/main/java/org/opensearch/sql/legacy/rewriter/matchtoterm/TermFieldRewriter.java +++ b/legacy/src/main/java/org/opensearch/sql/legacy/rewriter/matchtoterm/TermFieldRewriter.java @@ -151,7 +151,15 @@ public void collect( indexToType.put(tableName, null); } else if (sqlExprTableSource.getExpr() instanceof SQLBinaryOpExpr) { SQLBinaryOpExpr sqlBinaryOpExpr = (SQLBinaryOpExpr) sqlExprTableSource.getExpr(); - tableName = ((SQLIdentifierExpr) sqlBinaryOpExpr.getLeft()).getName(); + SQLExpr leftSideOfExpression = sqlBinaryOpExpr.getLeft(); + if (leftSideOfExpression instanceof SQLIdentifierExpr) { + tableName = ((SQLIdentifierExpr) sqlBinaryOpExpr.getLeft()).getName(); + } else { + throw new ParserException( + "Left side of the expression [" + + leftSideOfExpression.toString() + + "] is expected to be an identifier"); + } SQLExpr rightSideOfExpression = sqlBinaryOpExpr.getRight(); // This assumes that right side of the expression is different name in query diff --git a/legacy/src/test/java/org/opensearch/sql/legacy/unittest/rewriter/term/TermFieldRewriterTest.java b/legacy/src/test/java/org/opensearch/sql/legacy/unittest/rewriter/term/TermFieldRewriterTest.java index 44d3e2cbc0..7922d60647 100644 --- a/legacy/src/test/java/org/opensearch/sql/legacy/unittest/rewriter/term/TermFieldRewriterTest.java +++ b/legacy/src/test/java/org/opensearch/sql/legacy/unittest/rewriter/term/TermFieldRewriterTest.java @@ -10,6 +10,7 @@ import com.alibaba.druid.sql.SQLUtils; import com.alibaba.druid.sql.ast.expr.SQLQueryExpr; +import com.alibaba.druid.sql.parser.ParserException; import org.junit.Before; import org.junit.Rule; import org.junit.Test; @@ -100,6 +101,15 @@ public void testSelectTheFieldWithConflictMappingShouldThrowException() { rewriteTerm(sql); } + @Test + public void testIssue2391_WithWrongBinaryOperation() { + String sql = "SELECT * from I_THINK/IM/A_URL"; + exception.expect(ParserException.class); + exception.expectMessage( + "Left side of the expression [I_THINK / IM] is expected to be an identifier"); + rewriteTerm(sql); + } + private String rewriteTerm(String sql) { SQLQueryExpr sqlQueryExpr = SqlParserUtils.parse(sql); sqlQueryExpr.accept(new TermFieldRewriter()); diff --git a/plugin/src/main/java/org/opensearch/sql/plugin/rest/RestPPLQueryAction.java b/plugin/src/main/java/org/opensearch/sql/plugin/rest/RestPPLQueryAction.java index d35962be91..4dda11718f 100644 --- a/plugin/src/main/java/org/opensearch/sql/plugin/rest/RestPPLQueryAction.java +++ b/plugin/src/main/java/org/opensearch/sql/plugin/rest/RestPPLQueryAction.java @@ -8,7 +8,6 @@ import static org.opensearch.core.rest.RestStatus.BAD_REQUEST; import static org.opensearch.core.rest.RestStatus.INTERNAL_SERVER_ERROR; import static org.opensearch.core.rest.RestStatus.OK; -import static org.opensearch.core.rest.RestStatus.SERVICE_UNAVAILABLE; import com.google.common.collect.ImmutableList; import java.util.Arrays; @@ -130,7 +129,7 @@ public void onFailure(Exception e) { Metrics.getInstance() .getNumericalMetric(MetricName.PPL_FAILED_REQ_COUNT_SYS) .increment(); - reportError(channel, e, SERVICE_UNAVAILABLE); + reportError(channel, e, INTERNAL_SERVER_ERROR); } } } diff --git a/plugin/src/main/java/org/opensearch/sql/plugin/rest/RestPPLStatsAction.java b/plugin/src/main/java/org/opensearch/sql/plugin/rest/RestPPLStatsAction.java index 39a3d20abb..d3d7074b20 100644 --- a/plugin/src/main/java/org/opensearch/sql/plugin/rest/RestPPLStatsAction.java +++ b/plugin/src/main/java/org/opensearch/sql/plugin/rest/RestPPLStatsAction.java @@ -5,7 +5,7 @@ package org.opensearch.sql.plugin.rest; -import static org.opensearch.core.rest.RestStatus.SERVICE_UNAVAILABLE; +import static org.opensearch.core.rest.RestStatus.INTERNAL_SERVER_ERROR; import com.google.common.collect.ImmutableList; import java.util.Arrays; @@ -79,8 +79,8 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli return channel -> channel.sendResponse( new BytesRestResponse( - SERVICE_UNAVAILABLE, - ErrorMessageFactory.createErrorMessage(e, SERVICE_UNAVAILABLE.getStatus()) + INTERNAL_SERVER_ERROR, + ErrorMessageFactory.createErrorMessage(e, INTERNAL_SERVER_ERROR.getStatus()) .toString())); } } diff --git a/spark/src/main/java/org/opensearch/sql/spark/asyncquery/OpensearchAsyncQueryJobMetadataStorageService.java b/spark/src/main/java/org/opensearch/sql/spark/asyncquery/OpensearchAsyncQueryJobMetadataStorageService.java index 6de8c35f03..cef3b6ede2 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/asyncquery/OpensearchAsyncQueryJobMetadataStorageService.java +++ b/spark/src/main/java/org/opensearch/sql/spark/asyncquery/OpensearchAsyncQueryJobMetadataStorageService.java @@ -11,6 +11,9 @@ import java.util.Optional; import lombok.RequiredArgsConstructor; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.sql.spark.asyncquery.exceptions.AsyncQueryNotFoundException; import org.opensearch.sql.spark.asyncquery.model.AsyncQueryId; import org.opensearch.sql.spark.asyncquery.model.AsyncQueryJobMetadata; import org.opensearch.sql.spark.execution.statestore.StateStore; @@ -22,6 +25,9 @@ public class OpensearchAsyncQueryJobMetadataStorageService private final StateStore stateStore; + private static final Logger LOGGER = + LogManager.getLogger(OpensearchAsyncQueryJobMetadataStorageService.class); + @Override public void storeJobMetadata(AsyncQueryJobMetadata asyncQueryJobMetadata) { AsyncQueryId queryId = asyncQueryJobMetadata.getQueryId(); @@ -30,8 +36,13 @@ public void storeJobMetadata(AsyncQueryJobMetadata asyncQueryJobMetadata) { @Override public Optional getJobMetadata(String qid) { - AsyncQueryId queryId = new AsyncQueryId(qid); - return StateStore.getJobMetaData(stateStore, queryId.getDataSourceName()) - .apply(queryId.docId()); + try { + AsyncQueryId queryId = new AsyncQueryId(qid); + return StateStore.getJobMetaData(stateStore, queryId.getDataSourceName()) + .apply(queryId.docId()); + } catch (Exception e) { + LOGGER.error("Error while fetching the job metadata.", e); + throw new AsyncQueryNotFoundException(String.format("Invalid QueryId: %s", qid)); + } } } diff --git a/spark/src/main/java/org/opensearch/sql/spark/asyncquery/model/AsyncQueryId.java b/spark/src/main/java/org/opensearch/sql/spark/asyncquery/model/AsyncQueryId.java index b99ebe0e8c..106028cdc9 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/asyncquery/model/AsyncQueryId.java +++ b/spark/src/main/java/org/opensearch/sql/spark/asyncquery/model/AsyncQueryId.java @@ -19,6 +19,15 @@ public static AsyncQueryId newAsyncQueryId(String datasourceName) { return new AsyncQueryId(encode(datasourceName)); } + public static boolean isValidQueryId(String id) { + try { + decode(id); + return true; + } catch (Throwable e) { + return false; + } + } + public String getDataSourceName() { return decode(id); } diff --git a/spark/src/main/java/org/opensearch/sql/spark/rest/RestAsyncQueryManagementAction.java b/spark/src/main/java/org/opensearch/sql/spark/rest/RestAsyncQueryManagementAction.java index ae4adc6de9..90d5d73696 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/rest/RestAsyncQueryManagementAction.java +++ b/spark/src/main/java/org/opensearch/sql/spark/rest/RestAsyncQueryManagementAction.java @@ -6,7 +6,7 @@ package org.opensearch.sql.spark.rest; import static org.opensearch.core.rest.RestStatus.BAD_REQUEST; -import static org.opensearch.core.rest.RestStatus.SERVICE_UNAVAILABLE; +import static org.opensearch.core.rest.RestStatus.INTERNAL_SERVER_ERROR; import static org.opensearch.core.rest.RestStatus.TOO_MANY_REQUESTS; import static org.opensearch.rest.RestRequest.Method.DELETE; import static org.opensearch.rest.RestRequest.Method.GET; @@ -26,10 +26,12 @@ import org.opensearch.rest.BytesRestResponse; import org.opensearch.rest.RestChannel; import org.opensearch.rest.RestRequest; +import org.opensearch.sql.datasources.exceptions.DataSourceNotFoundException; import org.opensearch.sql.datasources.exceptions.ErrorMessage; import org.opensearch.sql.datasources.utils.Scheduler; import org.opensearch.sql.legacy.metrics.MetricName; import org.opensearch.sql.legacy.utils.MetricUtils; +import org.opensearch.sql.spark.asyncquery.exceptions.AsyncQueryNotFoundException; import org.opensearch.sql.spark.leasemanager.ConcurrencyLimitExceededException; import org.opensearch.sql.spark.rest.model.CreateAsyncQueryRequest; import org.opensearch.sql.spark.transport.TransportCancelAsyncQueryRequestAction; @@ -112,12 +114,12 @@ protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient } } - private RestChannelConsumer executePostRequest(RestRequest restRequest, NodeClient nodeClient) - throws IOException { - MetricUtils.incrementNumericalMetric(MetricName.ASYNC_QUERY_CREATE_API_REQUEST_COUNT); - CreateAsyncQueryRequest submitJobRequest = - CreateAsyncQueryRequest.fromXContentParser(restRequest.contentParser()); - return restChannel -> + private RestChannelConsumer executePostRequest(RestRequest restRequest, NodeClient nodeClient) { + return restChannel -> { + try { + MetricUtils.incrementNumericalMetric(MetricName.ASYNC_QUERY_CREATE_API_REQUEST_COUNT); + CreateAsyncQueryRequest submitJobRequest = + CreateAsyncQueryRequest.fromXContentParser(restRequest.contentParser()); Scheduler.schedule( nodeClient, () -> @@ -140,6 +142,10 @@ public void onFailure(Exception e) { handleException(e, restChannel, restRequest.method()); } })); + } catch (Exception e) { + handleException(e, restChannel, restRequest.method()); + } + }; } private RestChannelConsumer executeGetAsyncQueryResultRequest( @@ -187,7 +193,7 @@ private void handleException( reportError(restChannel, e, BAD_REQUEST); addCustomerErrorMetric(requestMethod); } else { - reportError(restChannel, e, SERVICE_UNAVAILABLE); + reportError(restChannel, e, INTERNAL_SERVER_ERROR); addSystemErrorMetric(requestMethod); } } @@ -227,7 +233,10 @@ private void reportError(final RestChannel channel, final Exception e, final Res } private static boolean isClientError(Exception e) { - return e instanceof IllegalArgumentException || e instanceof IllegalStateException; + return e instanceof IllegalArgumentException + || e instanceof IllegalStateException + || e instanceof DataSourceNotFoundException + || e instanceof AsyncQueryNotFoundException; } private void addSystemErrorMetric(RestRequest.Method requestMethod) { diff --git a/spark/src/main/java/org/opensearch/sql/spark/rest/model/CreateAsyncQueryRequest.java b/spark/src/main/java/org/opensearch/sql/spark/rest/model/CreateAsyncQueryRequest.java index 6acf6bc9a8..98527b6241 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/rest/model/CreateAsyncQueryRequest.java +++ b/spark/src/main/java/org/opensearch/sql/spark/rest/model/CreateAsyncQueryRequest.java @@ -41,23 +41,28 @@ public static CreateAsyncQueryRequest fromXContentParser(XContentParser parser) LangType lang = null; String datasource = null; String sessionId = null; - ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); - while (parser.nextToken() != XContentParser.Token.END_OBJECT) { - String fieldName = parser.currentName(); - parser.nextToken(); - if (fieldName.equals("query")) { - query = parser.textOrNull(); - } else if (fieldName.equals("lang")) { - String langString = parser.textOrNull(); - lang = LangType.fromString(langString); - } else if (fieldName.equals("datasource")) { - datasource = parser.textOrNull(); - } else if (fieldName.equals(SESSION_ID)) { - sessionId = parser.textOrNull(); - } else { - throw new IllegalArgumentException("Unknown field: " + fieldName); + try { + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); + while (parser.nextToken() != XContentParser.Token.END_OBJECT) { + String fieldName = parser.currentName(); + parser.nextToken(); + if (fieldName.equals("query")) { + query = parser.textOrNull(); + } else if (fieldName.equals("lang")) { + String langString = parser.textOrNull(); + lang = LangType.fromString(langString); + } else if (fieldName.equals("datasource")) { + datasource = parser.textOrNull(); + } else if (fieldName.equals(SESSION_ID)) { + sessionId = parser.textOrNull(); + } else { + throw new IllegalArgumentException("Unknown field: " + fieldName); + } } + return new CreateAsyncQueryRequest(query, datasource, lang, sessionId); + } catch (Exception e) { + throw new IllegalArgumentException( + String.format("Error while parsing the request body: %s", e.getMessage())); } - return new CreateAsyncQueryRequest(query, datasource, lang, sessionId); } } diff --git a/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceSpec.java b/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceSpec.java index c9b4b6fc88..1c7e806462 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceSpec.java +++ b/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceSpec.java @@ -207,7 +207,7 @@ protected AsyncQueryExecutorService createAsyncQueryExecutorService( JobExecutionResponseReader jobExecutionResponseReader) { StateStore stateStore = new StateStore(client, clusterService); AsyncQueryJobMetadataStorageService asyncQueryJobMetadataStorageService = - new OpensearchAsyncQueryJobMetadataStorageService(stateStore); + new OpensearchAsyncQueryJobMetadataStorageService(stateStore, dataSourceService); SparkQueryDispatcher sparkQueryDispatcher = new SparkQueryDispatcher( emrServerlessClientFactory, diff --git a/spark/src/test/java/org/opensearch/sql/spark/asyncquery/OpensearchAsyncQueryAsyncQueryJobMetadataStorageServiceTest.java b/spark/src/test/java/org/opensearch/sql/spark/asyncquery/OpensearchAsyncQueryAsyncQueryJobMetadataStorageServiceTest.java index cf838db829..c5e327bac9 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/asyncquery/OpensearchAsyncQueryAsyncQueryJobMetadataStorageServiceTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/asyncquery/OpensearchAsyncQueryAsyncQueryJobMetadataStorageServiceTest.java @@ -5,15 +5,43 @@ package org.opensearch.sql.spark.asyncquery; +import static org.opensearch.sql.opensearch.setting.OpenSearchSettings.DATASOURCE_URI_HOSTS_DENY_LIST; import static org.opensearch.sql.spark.constants.TestConstants.EMRS_APPLICATION_ID; import static org.opensearch.sql.spark.constants.TestConstants.EMR_JOB_ID; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; +import java.util.Arrays; +import java.util.Collection; +import java.util.Collections; +import java.util.List; import java.util.Optional; +import org.apache.commons.lang3.StringUtils; +import org.junit.After; import org.junit.Before; import org.junit.Test; +import org.junit.jupiter.api.Assertions; +import org.opensearch.client.node.NodeClient; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.ClusterSettings; +import org.opensearch.common.settings.Setting; +import org.opensearch.common.settings.Settings; +import org.opensearch.plugins.Plugin; +import org.opensearch.sql.datasource.model.DataSourceMetadata; +import org.opensearch.sql.datasource.model.DataSourceType; +import org.opensearch.sql.datasources.encryptor.EncryptorImpl; +import org.opensearch.sql.datasources.glue.GlueDataSourceFactory; +import org.opensearch.sql.datasources.service.DataSourceMetadataStorage; +import org.opensearch.sql.datasources.service.DataSourceServiceImpl; +import org.opensearch.sql.datasources.storage.OpenSearchDataSourceMetadataStorage; +import org.opensearch.sql.legacy.esdomain.LocalClusterState; +import org.opensearch.sql.opensearch.setting.OpenSearchSettings; +import org.opensearch.sql.spark.asyncquery.exceptions.AsyncQueryNotFoundException; import org.opensearch.sql.spark.asyncquery.model.AsyncQueryId; import org.opensearch.sql.spark.asyncquery.model.AsyncQueryJobMetadata; import org.opensearch.sql.spark.execution.statestore.StateStore; +import org.opensearch.sql.storage.DataSourceFactory; import org.opensearch.test.OpenSearchIntegTestCase; public class OpensearchAsyncQueryAsyncQueryJobMetadataStorageServiceTest @@ -22,13 +50,89 @@ public class OpensearchAsyncQueryAsyncQueryJobMetadataStorageServiceTest public static final String DS_NAME = "mys3"; private static final String MOCK_SESSION_ID = "sessionId"; private static final String MOCK_RESULT_INDEX = "resultIndex"; + private static final String MOCK_QUERY_ID = "00fdo6u94n7abo0q"; + private OpensearchAsyncQueryJobMetadataStorageService opensearchJobMetadataStorageService; + protected ClusterService clusterService; + protected org.opensearch.sql.common.setting.Settings pluginSettings; + protected ClusterSettings clusterSettings; + protected NodeClient client; + protected DataSourceServiceImpl dataSourceService; + + @Override + protected Collection> nodePlugins() { + return Arrays.asList(AsyncQueryExecutorServiceSpec.TestSettingPlugin.class); + } + + public static class TestSettingPlugin extends Plugin { + @Override + public List> getSettings() { + return OpenSearchSettings.pluginSettings(); + } + } @Before public void setup() { + this.clusterService = clusterService(); + this.client = (NodeClient) cluster().client(); + this.clusterSettings = clusterService.getClusterSettings(); + this.pluginSettings = new OpenSearchSettings(clusterSettings); + LocalClusterState.state().setClusterService(clusterService); + LocalClusterState.state().setPluginSettings((OpenSearchSettings) pluginSettings); + this.dataSourceService = createDataSourceService(); opensearchJobMetadataStorageService = new OpensearchAsyncQueryJobMetadataStorageService( - new StateStore(client(), clusterService())); + new StateStore(client, clusterService), this.dataSourceService); + client + .admin() + .cluster() + .prepareUpdateSettings() + .setTransientSettings( + Settings.builder() + .putList(DATASOURCE_URI_HOSTS_DENY_LIST.getKey(), Collections.emptyList()) + .build()) + .get(); + DataSourceMetadata dm = + new DataSourceMetadata( + DS_NAME, + StringUtils.EMPTY, + DataSourceType.S3GLUE, + ImmutableList.of(), + ImmutableMap.of( + "glue.auth.type", + "iam_role", + "glue.auth.role_arn", + "arn:aws:iam::924196221507:role/FlintOpensearchServiceRole", + "glue.indexstore.opensearch.uri", + "http://localhost:9200", + "glue.indexstore.opensearch.auth", + "noauth"), + null); + dataSourceService.createDataSource(dm); + } + + @After + public void clean() { + client + .admin() + .cluster() + .prepareUpdateSettings() + .setTransientSettings( + Settings.builder().putNull(DATASOURCE_URI_HOSTS_DENY_LIST.getKey()).build()) + .get(); + } + + private DataSourceServiceImpl createDataSourceService() { + String masterKey = "a57d991d9b573f75b9bba1df"; + DataSourceMetadataStorage dataSourceMetadataStorage = + new OpenSearchDataSourceMetadataStorage( + client, clusterService, new EncryptorImpl(masterKey)); + return new DataSourceServiceImpl( + new ImmutableSet.Builder() + .add(new GlueDataSourceFactory(pluginSettings)) + .build(), + dataSourceMetadataStorage, + meta -> {}); } @Test @@ -69,4 +173,44 @@ public void testStoreJobMetadataWithResultExtraData() { assertEquals("resultIndex", actual.get().getResultIndex()); assertEquals(MOCK_SESSION_ID, actual.get().getSessionId()); } + + @Test + public void testStoreJobMetadataWithBas() { + AsyncQueryJobMetadata expected = + new AsyncQueryJobMetadata( + AsyncQueryId.newAsyncQueryId(DS_NAME), + EMR_JOB_ID, + EMRS_APPLICATION_ID, + MOCK_RESULT_INDEX, + MOCK_SESSION_ID); + + opensearchJobMetadataStorageService.storeJobMetadata(expected); + Optional actual = + opensearchJobMetadataStorageService.getJobMetadata(expected.getQueryId().getId()); + + assertTrue(actual.isPresent()); + assertEquals(expected, actual.get()); + assertEquals("resultIndex", actual.get().getResultIndex()); + assertEquals(MOCK_SESSION_ID, actual.get().getSessionId()); + } + + @Test + public void testGetJobMetadataWithMalformedQueryId() { + AsyncQueryNotFoundException asyncQueryNotFoundException = + Assertions.assertThrows( + AsyncQueryNotFoundException.class, + () -> opensearchJobMetadataStorageService.getJobMetadata(MOCK_QUERY_ID)); + Assertions.assertEquals( + String.format("Invalid queryId: %s", MOCK_QUERY_ID), + asyncQueryNotFoundException.getMessage()); + } + + @Test + public void testGetJobMetadataWithEmptyQueryId() { + AsyncQueryNotFoundException asyncQueryNotFoundException = + Assertions.assertThrows( + AsyncQueryNotFoundException.class, + () -> opensearchJobMetadataStorageService.getJobMetadata("")); + Assertions.assertEquals("Invalid queryId: ", asyncQueryNotFoundException.getMessage()); + } }