diff --git a/datasources/src/main/java/org/opensearch/sql/datasources/rest/RestDataSourceQueryAction.java b/datasources/src/main/java/org/opensearch/sql/datasources/rest/RestDataSourceQueryAction.java index 02f87a69f2..43249e8a28 100644 --- a/datasources/src/main/java/org/opensearch/sql/datasources/rest/RestDataSourceQueryAction.java +++ b/datasources/src/main/java/org/opensearch/sql/datasources/rest/RestDataSourceQueryAction.java @@ -8,8 +8,8 @@ package org.opensearch.sql.datasources.rest; import static org.opensearch.core.rest.RestStatus.BAD_REQUEST; +import static org.opensearch.core.rest.RestStatus.INTERNAL_SERVER_ERROR; import static org.opensearch.core.rest.RestStatus.NOT_FOUND; -import static org.opensearch.core.rest.RestStatus.SERVICE_UNAVAILABLE; import static org.opensearch.rest.RestRequest.Method.*; import com.google.common.collect.ImmutableList; @@ -20,6 +20,7 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.opensearch.OpenSearchException; +import org.opensearch.OpenSearchSecurityException; import org.opensearch.client.node.NodeClient; import org.opensearch.core.action.ActionListener; import org.opensearch.core.rest.RestStatus; @@ -282,6 +283,10 @@ private void handleException(Exception e, RestChannel restChannel) { if (e instanceof DataSourceNotFoundException) { MetricUtils.incrementNumericalMetric(MetricName.DATASOURCE_FAILED_REQ_COUNT_CUS); reportError(restChannel, e, NOT_FOUND); + } else if (e instanceof OpenSearchSecurityException) { + MetricUtils.incrementNumericalMetric(MetricName.DATASOURCE_FAILED_REQ_COUNT_CUS); + OpenSearchSecurityException exception = (OpenSearchSecurityException) e; + reportError(restChannel, exception, exception.status()); } else if (e instanceof OpenSearchException) { MetricUtils.incrementNumericalMetric(MetricName.DATASOURCE_FAILED_REQ_COUNT_SYS); OpenSearchException exception = (OpenSearchException) e; @@ -293,7 +298,7 @@ private void handleException(Exception e, RestChannel restChannel) { reportError(restChannel, e, BAD_REQUEST); } else { MetricUtils.incrementNumericalMetric(MetricName.DATASOURCE_FAILED_REQ_COUNT_SYS); - reportError(restChannel, e, SERVICE_UNAVAILABLE); + reportError(restChannel, e, INTERNAL_SERVER_ERROR); } } } diff --git a/docs/user/interfaces/asyncqueryinterface.rst b/docs/user/interfaces/asyncqueryinterface.rst index 3fbc16d15f..983b66b055 100644 --- a/docs/user/interfaces/asyncqueryinterface.rst +++ b/docs/user/interfaces/asyncqueryinterface.rst @@ -30,10 +30,8 @@ Sample Setting Value :: "region":"eu-west-1", "sparkSubmitParameter": "--conf spark.dynamicAllocation.enabled=false" }' -If this setting is not configured during bootstrap, Async Query APIs will be disabled and it requires a cluster restart to enable them back again. -We make use of default aws credentials chain to make calls to the emr serverless application and also make sure the default credentials -have pass role permissions for emr-job-execution-role mentioned in the engine configuration. - +The user must be careful before transitioning to a new application or region, as changing these parameters might lead to failures in the retrieval of results from previous async query jobs. +The system relies on the default AWS credentials chain for making calls to the EMR serverless application. It is essential to confirm that the default credentials possess the necessary permissions to pass the role required for EMR job execution, as specified in the engine configuration. * ``applicationId``, ``executionRoleARN`` and ``region`` are required parameters. * ``sparkSubmitParameter`` is an optional parameter. It can take the form ``--conf A=1 --conf B=2 ...``. diff --git a/integ-test/src/test/java/org/opensearch/sql/legacy/QueryIT.java b/integ-test/src/test/java/org/opensearch/sql/legacy/QueryIT.java index 71795b1fb7..3fe8b50eef 100644 --- a/integ-test/src/test/java/org/opensearch/sql/legacy/QueryIT.java +++ b/integ-test/src/test/java/org/opensearch/sql/legacy/QueryIT.java @@ -1755,7 +1755,7 @@ public void multipleIndicesOneNotExistWithoutHint() throws IOException { Assert.fail("Expected exception, but call succeeded"); } catch (ResponseException e) { Assert.assertEquals( - RestStatus.BAD_REQUEST.getStatus(), e.getResponse().getStatusLine().getStatusCode()); + RestStatus.NOT_FOUND.getStatus(), e.getResponse().getStatusLine().getStatusCode()); final String entity = TestUtils.getResponseBody(e.getResponse()); Assert.assertThat(entity, containsString("\"type\": \"IndexNotFoundException\"")); } diff --git a/integ-test/src/test/java/org/opensearch/sql/ppl/PPLPluginIT.java b/integ-test/src/test/java/org/opensearch/sql/ppl/PPLPluginIT.java index 0c638be1e7..44f79a8944 100644 --- a/integ-test/src/test/java/org/opensearch/sql/ppl/PPLPluginIT.java +++ b/integ-test/src/test/java/org/opensearch/sql/ppl/PPLPluginIT.java @@ -58,7 +58,7 @@ public void testQueryEndpointShouldFail() throws IOException { @Test public void testQueryEndpointShouldFailWithNonExistIndex() throws IOException { exceptionRule.expect(ResponseException.class); - exceptionRule.expect(hasProperty("response", statusCode(400))); + exceptionRule.expect(hasProperty("response", statusCode(404))); client().performRequest(makePPLRequest("search source=non_exist_index")); } diff --git a/integ-test/src/test/java/org/opensearch/sql/ppl/ResourceMonitorIT.java b/integ-test/src/test/java/org/opensearch/sql/ppl/ResourceMonitorIT.java index 56b54ba748..eed2369590 100644 --- a/integ-test/src/test/java/org/opensearch/sql/ppl/ResourceMonitorIT.java +++ b/integ-test/src/test/java/org/opensearch/sql/ppl/ResourceMonitorIT.java @@ -31,7 +31,7 @@ public void queryExceedResourceLimitShouldFail() throws IOException { String query = String.format("search source=%s age=20", TEST_INDEX_DOG); ResponseException exception = expectThrows(ResponseException.class, () -> executeQuery(query)); - assertEquals(503, exception.getResponse().getStatusLine().getStatusCode()); + assertEquals(500, exception.getResponse().getStatusLine().getStatusCode()); assertThat( exception.getMessage(), Matchers.containsString("resource is not enough to run the" + " query, quit.")); diff --git a/integ-test/src/test/java/org/opensearch/sql/security/CrossClusterSearchIT.java b/integ-test/src/test/java/org/opensearch/sql/security/CrossClusterSearchIT.java index 086f32cba7..cdf467706c 100644 --- a/integ-test/src/test/java/org/opensearch/sql/security/CrossClusterSearchIT.java +++ b/integ-test/src/test/java/org/opensearch/sql/security/CrossClusterSearchIT.java @@ -87,7 +87,7 @@ public void testCrossClusterSearchWithoutLocalFieldMappingShouldFail() throws IO () -> executeQuery(String.format("search source=%s", TEST_INDEX_ACCOUNT_REMOTE))); assertTrue( exception.getMessage().contains("IndexNotFoundException") - && exception.getMessage().contains("400 Bad Request")); + && exception.getMessage().contains("404 Not Found")); } @Test diff --git a/integ-test/src/test/java/org/opensearch/sql/sql/NestedIT.java b/integ-test/src/test/java/org/opensearch/sql/sql/NestedIT.java index 54831cb561..96bbae94e5 100644 --- a/integ-test/src/test/java/org/opensearch/sql/sql/NestedIT.java +++ b/integ-test/src/test/java/org/opensearch/sql/sql/NestedIT.java @@ -588,7 +588,7 @@ public void nested_function_all_subfields_in_wrong_clause() { + " \"details\": \"Invalid use of expression nested(message.*)\",\n" + " \"type\": \"UnsupportedOperationException\"\n" + " },\n" - + " \"status\": 503\n" + + " \"status\": 500\n" + "}")); } } diff --git a/legacy/src/main/java/org/opensearch/sql/legacy/plugin/RestSqlAction.java b/legacy/src/main/java/org/opensearch/sql/legacy/plugin/RestSqlAction.java index fc8934dd73..c47e5f05bd 100644 --- a/legacy/src/main/java/org/opensearch/sql/legacy/plugin/RestSqlAction.java +++ b/legacy/src/main/java/org/opensearch/sql/legacy/plugin/RestSqlAction.java @@ -6,8 +6,8 @@ package org.opensearch.sql.legacy.plugin; import static org.opensearch.core.rest.RestStatus.BAD_REQUEST; +import static org.opensearch.core.rest.RestStatus.INTERNAL_SERVER_ERROR; import static org.opensearch.core.rest.RestStatus.OK; -import static org.opensearch.core.rest.RestStatus.SERVICE_UNAVAILABLE; import com.alibaba.druid.sql.parser.ParserException; import com.google.common.collect.ImmutableList; @@ -23,6 +23,7 @@ import java.util.regex.Pattern; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; +import org.opensearch.OpenSearchException; import org.opensearch.client.Client; import org.opensearch.client.node.NodeClient; import org.opensearch.common.inject.Injector; @@ -171,21 +172,23 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli QueryAction queryAction = explainRequest(client, sqlRequest, format); executeSqlRequest(request, queryAction, client, restChannel); } catch (Exception e) { - logAndPublishMetrics(e); - reportError(restChannel, e, isClientError(e) ? BAD_REQUEST : SERVICE_UNAVAILABLE); + handleException(restChannel, e); } }, - (restChannel, exception) -> { - logAndPublishMetrics(exception); - reportError( - restChannel, - exception, - isClientError(exception) ? BAD_REQUEST : SERVICE_UNAVAILABLE); - }); + this::handleException); } catch (Exception e) { - logAndPublishMetrics(e); - return channel -> - reportError(channel, e, isClientError(e) ? BAD_REQUEST : SERVICE_UNAVAILABLE); + return channel -> handleException(channel, e); + } + } + + private void handleException(RestChannel restChannel, Exception exception) { + logAndPublishMetrics(exception); + if (exception instanceof OpenSearchException) { + OpenSearchException openSearchException = (OpenSearchException) exception; + reportError(restChannel, openSearchException, openSearchException.status()); + } else { + reportError( + restChannel, exception, isClientError(exception) ? BAD_REQUEST : INTERNAL_SERVER_ERROR); } } diff --git a/legacy/src/main/java/org/opensearch/sql/legacy/plugin/RestSqlStatsAction.java b/legacy/src/main/java/org/opensearch/sql/legacy/plugin/RestSqlStatsAction.java index 383363b1e3..6f9d1e4117 100644 --- a/legacy/src/main/java/org/opensearch/sql/legacy/plugin/RestSqlStatsAction.java +++ b/legacy/src/main/java/org/opensearch/sql/legacy/plugin/RestSqlStatsAction.java @@ -5,7 +5,7 @@ package org.opensearch.sql.legacy.plugin; -import static org.opensearch.core.rest.RestStatus.SERVICE_UNAVAILABLE; +import static org.opensearch.core.rest.RestStatus.INTERNAL_SERVER_ERROR; import com.google.common.collect.ImmutableList; import java.util.Arrays; @@ -84,8 +84,8 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli return channel -> channel.sendResponse( new BytesRestResponse( - SERVICE_UNAVAILABLE, - ErrorMessageFactory.createErrorMessage(e, SERVICE_UNAVAILABLE.getStatus()) + INTERNAL_SERVER_ERROR, + ErrorMessageFactory.createErrorMessage(e, INTERNAL_SERVER_ERROR.getStatus()) .toString())); } } diff --git a/legacy/src/main/java/org/opensearch/sql/legacy/rewriter/matchtoterm/TermFieldRewriter.java b/legacy/src/main/java/org/opensearch/sql/legacy/rewriter/matchtoterm/TermFieldRewriter.java index 2c837a7b2b..f9744ab841 100644 --- a/legacy/src/main/java/org/opensearch/sql/legacy/rewriter/matchtoterm/TermFieldRewriter.java +++ b/legacy/src/main/java/org/opensearch/sql/legacy/rewriter/matchtoterm/TermFieldRewriter.java @@ -151,7 +151,15 @@ public void collect( indexToType.put(tableName, null); } else if (sqlExprTableSource.getExpr() instanceof SQLBinaryOpExpr) { SQLBinaryOpExpr sqlBinaryOpExpr = (SQLBinaryOpExpr) sqlExprTableSource.getExpr(); - tableName = ((SQLIdentifierExpr) sqlBinaryOpExpr.getLeft()).getName(); + SQLExpr leftSideOfExpression = sqlBinaryOpExpr.getLeft(); + if (leftSideOfExpression instanceof SQLIdentifierExpr) { + tableName = ((SQLIdentifierExpr) sqlBinaryOpExpr.getLeft()).getName(); + } else { + throw new ParserException( + "Left side of the expression [" + + leftSideOfExpression.toString() + + "] is expected to be an identifier"); + } SQLExpr rightSideOfExpression = sqlBinaryOpExpr.getRight(); // This assumes that right side of the expression is different name in query diff --git a/legacy/src/test/java/org/opensearch/sql/legacy/unittest/rewriter/term/TermFieldRewriterTest.java b/legacy/src/test/java/org/opensearch/sql/legacy/unittest/rewriter/term/TermFieldRewriterTest.java index 44d3e2cbc0..7922d60647 100644 --- a/legacy/src/test/java/org/opensearch/sql/legacy/unittest/rewriter/term/TermFieldRewriterTest.java +++ b/legacy/src/test/java/org/opensearch/sql/legacy/unittest/rewriter/term/TermFieldRewriterTest.java @@ -10,6 +10,7 @@ import com.alibaba.druid.sql.SQLUtils; import com.alibaba.druid.sql.ast.expr.SQLQueryExpr; +import com.alibaba.druid.sql.parser.ParserException; import org.junit.Before; import org.junit.Rule; import org.junit.Test; @@ -100,6 +101,15 @@ public void testSelectTheFieldWithConflictMappingShouldThrowException() { rewriteTerm(sql); } + @Test + public void testIssue2391_WithWrongBinaryOperation() { + String sql = "SELECT * from I_THINK/IM/A_URL"; + exception.expect(ParserException.class); + exception.expectMessage( + "Left side of the expression [I_THINK / IM] is expected to be an identifier"); + rewriteTerm(sql); + } + private String rewriteTerm(String sql) { SQLQueryExpr sqlQueryExpr = SqlParserUtils.parse(sql); sqlQueryExpr.accept(new TermFieldRewriter()); diff --git a/plugin/src/main/java/org/opensearch/sql/plugin/rest/RestPPLQueryAction.java b/plugin/src/main/java/org/opensearch/sql/plugin/rest/RestPPLQueryAction.java index d35962be91..7e6d3c1422 100644 --- a/plugin/src/main/java/org/opensearch/sql/plugin/rest/RestPPLQueryAction.java +++ b/plugin/src/main/java/org/opensearch/sql/plugin/rest/RestPPLQueryAction.java @@ -8,7 +8,6 @@ import static org.opensearch.core.rest.RestStatus.BAD_REQUEST; import static org.opensearch.core.rest.RestStatus.INTERNAL_SERVER_ERROR; import static org.opensearch.core.rest.RestStatus.OK; -import static org.opensearch.core.rest.RestStatus.SERVICE_UNAVAILABLE; import com.google.common.collect.ImmutableList; import java.util.Arrays; @@ -17,7 +16,7 @@ import java.util.Set; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; -import org.opensearch.OpenSearchSecurityException; +import org.opensearch.OpenSearchException; import org.opensearch.client.node.NodeClient; import org.opensearch.core.action.ActionListener; import org.opensearch.core.rest.RestStatus; @@ -116,8 +115,11 @@ public void onFailure(Exception e) { channel, INTERNAL_SERVER_ERROR, "Failed to explain the query due to error: " + e.getMessage()); - } else if (e instanceof OpenSearchSecurityException) { - OpenSearchSecurityException exception = (OpenSearchSecurityException) e; + } else if (e instanceof OpenSearchException) { + Metrics.getInstance() + .getNumericalMetric(MetricName.PPL_FAILED_REQ_COUNT_CUS) + .increment(); + OpenSearchException exception = (OpenSearchException) e; reportError(channel, exception, exception.status()); } else { LOG.error("Error happened during query handling", e); @@ -130,7 +132,7 @@ public void onFailure(Exception e) { Metrics.getInstance() .getNumericalMetric(MetricName.PPL_FAILED_REQ_COUNT_SYS) .increment(); - reportError(channel, e, SERVICE_UNAVAILABLE); + reportError(channel, e, INTERNAL_SERVER_ERROR); } } } diff --git a/plugin/src/main/java/org/opensearch/sql/plugin/rest/RestPPLStatsAction.java b/plugin/src/main/java/org/opensearch/sql/plugin/rest/RestPPLStatsAction.java index 39a3d20abb..d3d7074b20 100644 --- a/plugin/src/main/java/org/opensearch/sql/plugin/rest/RestPPLStatsAction.java +++ b/plugin/src/main/java/org/opensearch/sql/plugin/rest/RestPPLStatsAction.java @@ -5,7 +5,7 @@ package org.opensearch.sql.plugin.rest; -import static org.opensearch.core.rest.RestStatus.SERVICE_UNAVAILABLE; +import static org.opensearch.core.rest.RestStatus.INTERNAL_SERVER_ERROR; import com.google.common.collect.ImmutableList; import java.util.Arrays; @@ -79,8 +79,8 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli return channel -> channel.sendResponse( new BytesRestResponse( - SERVICE_UNAVAILABLE, - ErrorMessageFactory.createErrorMessage(e, SERVICE_UNAVAILABLE.getStatus()) + INTERNAL_SERVER_ERROR, + ErrorMessageFactory.createErrorMessage(e, INTERNAL_SERVER_ERROR.getStatus()) .toString())); } } diff --git a/spark/src/main/java/org/opensearch/sql/spark/asyncquery/OpensearchAsyncQueryJobMetadataStorageService.java b/spark/src/main/java/org/opensearch/sql/spark/asyncquery/OpensearchAsyncQueryJobMetadataStorageService.java index 6de8c35f03..cef3b6ede2 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/asyncquery/OpensearchAsyncQueryJobMetadataStorageService.java +++ b/spark/src/main/java/org/opensearch/sql/spark/asyncquery/OpensearchAsyncQueryJobMetadataStorageService.java @@ -11,6 +11,9 @@ import java.util.Optional; import lombok.RequiredArgsConstructor; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.sql.spark.asyncquery.exceptions.AsyncQueryNotFoundException; import org.opensearch.sql.spark.asyncquery.model.AsyncQueryId; import org.opensearch.sql.spark.asyncquery.model.AsyncQueryJobMetadata; import org.opensearch.sql.spark.execution.statestore.StateStore; @@ -22,6 +25,9 @@ public class OpensearchAsyncQueryJobMetadataStorageService private final StateStore stateStore; + private static final Logger LOGGER = + LogManager.getLogger(OpensearchAsyncQueryJobMetadataStorageService.class); + @Override public void storeJobMetadata(AsyncQueryJobMetadata asyncQueryJobMetadata) { AsyncQueryId queryId = asyncQueryJobMetadata.getQueryId(); @@ -30,8 +36,13 @@ public void storeJobMetadata(AsyncQueryJobMetadata asyncQueryJobMetadata) { @Override public Optional getJobMetadata(String qid) { - AsyncQueryId queryId = new AsyncQueryId(qid); - return StateStore.getJobMetaData(stateStore, queryId.getDataSourceName()) - .apply(queryId.docId()); + try { + AsyncQueryId queryId = new AsyncQueryId(qid); + return StateStore.getJobMetaData(stateStore, queryId.getDataSourceName()) + .apply(queryId.docId()); + } catch (Exception e) { + LOGGER.error("Error while fetching the job metadata.", e); + throw new AsyncQueryNotFoundException(String.format("Invalid QueryId: %s", qid)); + } } } diff --git a/spark/src/main/java/org/opensearch/sql/spark/rest/RestAsyncQueryManagementAction.java b/spark/src/main/java/org/opensearch/sql/spark/rest/RestAsyncQueryManagementAction.java index ae4adc6de9..90d5d73696 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/rest/RestAsyncQueryManagementAction.java +++ b/spark/src/main/java/org/opensearch/sql/spark/rest/RestAsyncQueryManagementAction.java @@ -6,7 +6,7 @@ package org.opensearch.sql.spark.rest; import static org.opensearch.core.rest.RestStatus.BAD_REQUEST; -import static org.opensearch.core.rest.RestStatus.SERVICE_UNAVAILABLE; +import static org.opensearch.core.rest.RestStatus.INTERNAL_SERVER_ERROR; import static org.opensearch.core.rest.RestStatus.TOO_MANY_REQUESTS; import static org.opensearch.rest.RestRequest.Method.DELETE; import static org.opensearch.rest.RestRequest.Method.GET; @@ -26,10 +26,12 @@ import org.opensearch.rest.BytesRestResponse; import org.opensearch.rest.RestChannel; import org.opensearch.rest.RestRequest; +import org.opensearch.sql.datasources.exceptions.DataSourceNotFoundException; import org.opensearch.sql.datasources.exceptions.ErrorMessage; import org.opensearch.sql.datasources.utils.Scheduler; import org.opensearch.sql.legacy.metrics.MetricName; import org.opensearch.sql.legacy.utils.MetricUtils; +import org.opensearch.sql.spark.asyncquery.exceptions.AsyncQueryNotFoundException; import org.opensearch.sql.spark.leasemanager.ConcurrencyLimitExceededException; import org.opensearch.sql.spark.rest.model.CreateAsyncQueryRequest; import org.opensearch.sql.spark.transport.TransportCancelAsyncQueryRequestAction; @@ -112,12 +114,12 @@ protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient } } - private RestChannelConsumer executePostRequest(RestRequest restRequest, NodeClient nodeClient) - throws IOException { - MetricUtils.incrementNumericalMetric(MetricName.ASYNC_QUERY_CREATE_API_REQUEST_COUNT); - CreateAsyncQueryRequest submitJobRequest = - CreateAsyncQueryRequest.fromXContentParser(restRequest.contentParser()); - return restChannel -> + private RestChannelConsumer executePostRequest(RestRequest restRequest, NodeClient nodeClient) { + return restChannel -> { + try { + MetricUtils.incrementNumericalMetric(MetricName.ASYNC_QUERY_CREATE_API_REQUEST_COUNT); + CreateAsyncQueryRequest submitJobRequest = + CreateAsyncQueryRequest.fromXContentParser(restRequest.contentParser()); Scheduler.schedule( nodeClient, () -> @@ -140,6 +142,10 @@ public void onFailure(Exception e) { handleException(e, restChannel, restRequest.method()); } })); + } catch (Exception e) { + handleException(e, restChannel, restRequest.method()); + } + }; } private RestChannelConsumer executeGetAsyncQueryResultRequest( @@ -187,7 +193,7 @@ private void handleException( reportError(restChannel, e, BAD_REQUEST); addCustomerErrorMetric(requestMethod); } else { - reportError(restChannel, e, SERVICE_UNAVAILABLE); + reportError(restChannel, e, INTERNAL_SERVER_ERROR); addSystemErrorMetric(requestMethod); } } @@ -227,7 +233,10 @@ private void reportError(final RestChannel channel, final Exception e, final Res } private static boolean isClientError(Exception e) { - return e instanceof IllegalArgumentException || e instanceof IllegalStateException; + return e instanceof IllegalArgumentException + || e instanceof IllegalStateException + || e instanceof DataSourceNotFoundException + || e instanceof AsyncQueryNotFoundException; } private void addSystemErrorMetric(RestRequest.Method requestMethod) { diff --git a/spark/src/main/java/org/opensearch/sql/spark/rest/model/CreateAsyncQueryRequest.java b/spark/src/main/java/org/opensearch/sql/spark/rest/model/CreateAsyncQueryRequest.java index 6acf6bc9a8..98527b6241 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/rest/model/CreateAsyncQueryRequest.java +++ b/spark/src/main/java/org/opensearch/sql/spark/rest/model/CreateAsyncQueryRequest.java @@ -41,23 +41,28 @@ public static CreateAsyncQueryRequest fromXContentParser(XContentParser parser) LangType lang = null; String datasource = null; String sessionId = null; - ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); - while (parser.nextToken() != XContentParser.Token.END_OBJECT) { - String fieldName = parser.currentName(); - parser.nextToken(); - if (fieldName.equals("query")) { - query = parser.textOrNull(); - } else if (fieldName.equals("lang")) { - String langString = parser.textOrNull(); - lang = LangType.fromString(langString); - } else if (fieldName.equals("datasource")) { - datasource = parser.textOrNull(); - } else if (fieldName.equals(SESSION_ID)) { - sessionId = parser.textOrNull(); - } else { - throw new IllegalArgumentException("Unknown field: " + fieldName); + try { + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); + while (parser.nextToken() != XContentParser.Token.END_OBJECT) { + String fieldName = parser.currentName(); + parser.nextToken(); + if (fieldName.equals("query")) { + query = parser.textOrNull(); + } else if (fieldName.equals("lang")) { + String langString = parser.textOrNull(); + lang = LangType.fromString(langString); + } else if (fieldName.equals("datasource")) { + datasource = parser.textOrNull(); + } else if (fieldName.equals(SESSION_ID)) { + sessionId = parser.textOrNull(); + } else { + throw new IllegalArgumentException("Unknown field: " + fieldName); + } } + return new CreateAsyncQueryRequest(query, datasource, lang, sessionId); + } catch (Exception e) { + throw new IllegalArgumentException( + String.format("Error while parsing the request body: %s", e.getMessage())); } - return new CreateAsyncQueryRequest(query, datasource, lang, sessionId); } } diff --git a/spark/src/test/java/org/opensearch/sql/spark/asyncquery/OpensearchAsyncQueryAsyncQueryJobMetadataStorageServiceTest.java b/spark/src/test/java/org/opensearch/sql/spark/asyncquery/OpensearchAsyncQueryAsyncQueryJobMetadataStorageServiceTest.java index cf838db829..20c944fd0a 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/asyncquery/OpensearchAsyncQueryAsyncQueryJobMetadataStorageServiceTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/asyncquery/OpensearchAsyncQueryAsyncQueryJobMetadataStorageServiceTest.java @@ -11,6 +11,8 @@ import java.util.Optional; import org.junit.Before; import org.junit.Test; +import org.junit.jupiter.api.Assertions; +import org.opensearch.sql.spark.asyncquery.exceptions.AsyncQueryNotFoundException; import org.opensearch.sql.spark.asyncquery.model.AsyncQueryId; import org.opensearch.sql.spark.asyncquery.model.AsyncQueryJobMetadata; import org.opensearch.sql.spark.execution.statestore.StateStore; @@ -22,6 +24,7 @@ public class OpensearchAsyncQueryAsyncQueryJobMetadataStorageServiceTest public static final String DS_NAME = "mys3"; private static final String MOCK_SESSION_ID = "sessionId"; private static final String MOCK_RESULT_INDEX = "resultIndex"; + private static final String MOCK_QUERY_ID = "00fdo6u94n7abo0q"; private OpensearchAsyncQueryJobMetadataStorageService opensearchJobMetadataStorageService; @Before @@ -69,4 +72,24 @@ public void testStoreJobMetadataWithResultExtraData() { assertEquals("resultIndex", actual.get().getResultIndex()); assertEquals(MOCK_SESSION_ID, actual.get().getSessionId()); } + + @Test + public void testGetJobMetadataWithMalformedQueryId() { + AsyncQueryNotFoundException asyncQueryNotFoundException = + Assertions.assertThrows( + AsyncQueryNotFoundException.class, + () -> opensearchJobMetadataStorageService.getJobMetadata(MOCK_QUERY_ID)); + Assertions.assertEquals( + String.format("Invalid QueryId: %s", MOCK_QUERY_ID), + asyncQueryNotFoundException.getMessage()); + } + + @Test + public void testGetJobMetadataWithEmptyQueryId() { + AsyncQueryNotFoundException asyncQueryNotFoundException = + Assertions.assertThrows( + AsyncQueryNotFoundException.class, + () -> opensearchJobMetadataStorageService.getJobMetadata("")); + Assertions.assertEquals("Invalid QueryId: ", asyncQueryNotFoundException.getMessage()); + } } diff --git a/spark/src/test/java/org/opensearch/sql/spark/rest/model/CreateAsyncQueryRequestTest.java b/spark/src/test/java/org/opensearch/sql/spark/rest/model/CreateAsyncQueryRequestTest.java index dd634d6055..24f5a9d6fe 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/rest/model/CreateAsyncQueryRequestTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/rest/model/CreateAsyncQueryRequestTest.java @@ -30,6 +30,64 @@ public void fromXContent() throws IOException { Assertions.assertEquals("select 1", queryRequest.getQuery()); } + @Test + public void testConstructor() { + Assertions.assertDoesNotThrow( + () -> new CreateAsyncQueryRequest("select * from apple", "my_glue", LangType.SQL)); + } + + @Test + public void fromXContentWithDuplicateFields() throws IOException { + String request = + "{\n" + + " \"datasource\": \"my_glue\",\n" + + " \"datasource\": \"my_glue_1\",\n" + + " \"lang\": \"sql\",\n" + + " \"query\": \"select 1\"\n" + + "}"; + IllegalArgumentException illegalArgumentException = + Assertions.assertThrows( + IllegalArgumentException.class, + () -> CreateAsyncQueryRequest.fromXContentParser(xContentParser(request))); + Assertions.assertEquals( + "Error while parsing the request body: Duplicate field 'datasource'\n" + + " at [Source: REDACTED (`StreamReadFeature.INCLUDE_SOURCE_IN_LOCATION` disabled);" + + " line: 3, column: 15]", + illegalArgumentException.getMessage()); + } + + @Test + public void fromXContentWithUnknownField() throws IOException { + String request = + "{\n" + + " \"datasource\": \"my_glue\",\n" + + " \"random\": \"my_gue_1\",\n" + + " \"lang\": \"sql\",\n" + + " \"query\": \"select 1\"\n" + + "}"; + IllegalArgumentException illegalArgumentException = + Assertions.assertThrows( + IllegalArgumentException.class, + () -> CreateAsyncQueryRequest.fromXContentParser(xContentParser(request))); + Assertions.assertEquals( + "Error while parsing the request body: Unknown field: random", + illegalArgumentException.getMessage()); + } + + @Test + public void fromXContentWithWrongDatatype() throws IOException { + String request = + "{\"datasource\": [\"my_glue\", \"my_glue_1\"], \"lang\": \"sql\", \"query\": \"select" + + " 1\"}"; + IllegalArgumentException illegalArgumentException = + Assertions.assertThrows( + IllegalArgumentException.class, + () -> CreateAsyncQueryRequest.fromXContentParser(xContentParser(request))); + Assertions.assertEquals( + "Error while parsing the request body: Can't get text on a START_ARRAY at 1:16", + illegalArgumentException.getMessage()); + } + @Test public void fromXContentWithSessionId() throws IOException { String request =